Press "Enter" to skip to content

winograd int8实现技巧

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

随着深度学习的应用普及,推理框架在深度学习模型的落地产品化过程中扮演了越来越重要的角色。可用于加速卷积计算且历史悠久的的winograd算法重新粉墨登场,又一次开始了它的神奇表演。

 

winograd算法的原理网上有很多文章介绍这里不再赘述,笔者这里给出一张计算流程图和计算公式和转换矩阵:

 

winograd计算流程图

 

, ,

 

 

, ,

 

 

,

 

,

 

winograd在float32浮点运算下表现确实不错,在3×3 stride=1的卷积层时取得了很好的加速效果,但对于int8量化卷积来说,加速效果不甚理想,整型可没有浮点这幺大的动态范围,如何保证运算过程整型不会溢出将是个令人头疼的问题。

 

假设将网络量化成int8, int8的权重和int8的输入,先看 ,

 

 

 

 

为了保证计算不溢出, 需要额外的2个bit, 而 需要额外的1个bit,换言之,为了保证安全的int8*int8, 权重和输入分别需要量化到int6和int7.

 

由上分析推广到 需要额外的4个bit, 需要额外的2个bit,换言之,为了保证安全的int8*int8, 权重和输入实际需要分别量化到int4和int6才可以!

 

由上面 的分析方法推广到 , 为了保证安全的int8*int8, 需要额外的10个bit, 需要额外的7个bit……换言之,压根没法保证int8*int8的安全计算,这时候只能将int8扩展到int16,执行int16*int16的计算。

 

为了保证计算结果的安全性并且不损失输入量化数据的动态范围,实际当中int8 winograd执行的基本都是int16*int16

 

对于 来说,只有权值变换矩阵G含有小数,在int8计算时只需要把G’=2*G作为新的变换矩阵即可,在int8输出前requant(int8->int8)/dequant(int8->fp32)时修改w_ scale值w_ scale’ = w_scale * 0.25f即可,考虑int16表示的动态范围是-32768~32767,可以简单验证下这个变换的过程不会产生溢出。

 

但对于 来说就不同了:

 

,

 

,

 

 

 

 

注:ok表示不会溢出, overflow表示有可能溢出,实际上右下角数据= ,实际当中非常有可能溢出( 即可溢出), 实际int8量化后的取值大部分选择去除掉-128这个点,实际范围选择在-127~127.

 

这里我们采取一个trick来避免int8计算weight_transform带来的数据溢出问题,利用提取公共因子法,最后一列和最后一行转int16输出前除24,中间计算结果改用int32来存储以保持权值矩阵转换过程计算的正确性,

 

在后面计算batched gemm计算时,修改原来使用的int16*int16的计算kernel函数接口加入alpha系数:

 

gemm_int16(M, N, K, int16* A, int16* B, int32* C, int alpha)

 

在batched gemm kernel函数内部, 把int32的计算结果store出去之前,先mul alpha;

 

在kernel函数外部调用的地方对index进行判断alpha分别传1或者24, 把在weight_ transform过程中除掉的系数在gemm计算时再乘回来,下面给出weights_ transform示意代码:

 

void weight_trans_4x4_3x3_int8(int16_t* dest, const int8_t* din, int ch_in,
                           int ch_out, void* workspace) {
  const int32_t coeff[6][3] = {{6, 0, 0},
                                {-4, -4, -4},
                                {-4, 4, -4},
                                {1, 2, 4},
                                {1, -2, 4},
                                {0, 0, 24}};
  int32_t* ptr_out = static_cast<int32_t*>(workspace);
  for (int i = 0; i < ch_out; i++) {
    for (int j = 0; j < ch_in; j++) {
      const int8_t* kernel0 =
          static_cast<const int8_t*>(din) + (i * ch_in + j) * 9;
      const int8_t* k0 = kernel0;
      const int8_t* k1 = kernel0 + 3;
      const int8_t* k2 = kernel0 + 6;
      int32_t tmp[6][3];
      for (int i = 0; i < 6; i++) {
        tmp[i][0] =
            static_cast<int32_t>(k0[0]) * coeff[i][0] +
            static_cast<int32_t>(k0[1]) * coeff[i][1] +
            static_cast<int32_t>(k0[2]) * coeff[i][2];
        tmp[i][1] =
            static_cast<int32_t>(k1[0]) * coeff[i][0] +
            static_cast<int32_t>(k1[1]) * coeff[i][1] +
            static_cast<int32_t>(k1[2]) * coeff[i][2];
        tmp[i][2] =
            static_cast<int32_t>(k2[0]) * coeff[i][0] +
            static_cast<int32_t>(k2[1]) * coeff[i][1] +
            static_cast<int32_t>(k2[2]) * coeff[i][2];
      }
      for (int j = 0; j < 6; j++) {
        int32_t* tmpp = &tmp[j][0];
        for (int i = 0; i < 6; i++) {
          ptr_channel[j * 6 + i] = tmpp[0] * coeff[i][0] +
                                   tmpp[1] * coeff[i][1] +
                                   tmpp[2] * coeff[i][2];
          if (i == 5 || j == 5)
            ptr_channel[j * 6 + i] /= 24;
        }
      }
    }
  }
... ... 
}

 

batched gemm接口调用示意代码:

 

for (int gi = 0; gi < 36; ++gi) {
     ... ...
     ... ...
     int col_idx = gi / 6;
     int row_idx = gi % 6;
     if (col_idx == 5 || row_idx == 5) {
        gemm_int16_alpha(
            M, N, K, A, B, C, 24);
     } else {
        gemm_int16_alpha(
            M, N, K, A, B, C, 1);
     }
}

 

scale转换部分代码:

 

... ...
for (auto& ws : w_scale_) {
      ws /= 576;
}
... ...

 

这里给出了一个优化winograd int8 4x4_3x3实现的一个trick, 笔者在华为P30 ARM A76大核单线程上测试与之前的winograd int8 2x2_3x3实现相比,有明显的算法性能增益,这里就不给出具体的测试性能数据了.

 

容易掉进的坑:在计算weight_transform时采用float32形式,最后把计算的结果static_cast到int16,这样会丢弃float32计算结果的小数部分,直接导致最后计算结果的错误。

 

对于 ,情况要更复杂一些,除了权值转换矩阵G之外, 输入转换矩阵B和输出转换矩阵A也含有小数,感兴趣的同学可以推导一下能否采取相同的trick保证input_ transform和weight_ transform转换之后的结果落在int16有效的表示范围内(-32768~32767),在batched gemm时进行系数纠正.

 

天色已晚,晚安。

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注