, ,

, ,

,

,

,

,

gemm_int16(M, N, K, int16* A, int16* B, int32* C, int alpha)

```void weight_trans_4x4_3x3_int8(int16_t* dest, const int8_t* din, int ch_in,
int ch_out, void* workspace) {
// Expands every 3x3 int8 kernel into a 6x6 int32 tile using the integer
// coefficient table below.
//
// dest:      int16 output buffer; it is not written in the visible part of
//            this excerpt (the elided tail of the function presumably fills it).
// din:       int8 kernels, read as ch_out * ch_in consecutive groups of 9 values.
// ch_in:     number of input channels.
// ch_out:    number of output channels.
// workspace: scratch memory, reinterpreted as an int32_t buffer.
//
// NOTE(review): the table looks like the F(4x4, 3x3) Winograd kernel-transform
// matrix multiplied by 24 so that every entry is an integer - confirm against
// the surrounding text.
const int32_t coeff[6][3] = {{6, 0, 0},
{-4, -4, -4},
{-4, 4, -4},
{1, 2, 4},
{1, -2, 4},
{0, 0, 24}};
// NOTE(review): ptr_out is never read in the visible lines; it is presumably
// the base from which ptr_channel is derived - confirm.
int32_t* ptr_out = static_cast<int32_t*>(workspace);
for (int i = 0; i < ch_out; i++) {
for (int j = 0; j < ch_in; j++) {
// Start of the 9-element kernel belonging to (output channel i, input channel j).
const int8_t* kernel0 =
static_cast<const int8_t*>(din) + (i * ch_in + j) * 9;
// k0, k1, k2 address the three 3-element rows of that kernel.
const int8_t* k0 = kernel0;
const int8_t* k1 = kernel0 + 3;
const int8_t* k2 = kernel0 + 6;
// First stage: tmp[r][c] is the dot product of coeff row r with kernel row c.
// NOTE(review): this inner `i` shadows the outer channel index `i`.
int32_t tmp[6][3];
for (int i = 0; i < 6; i++) {
tmp[i][0] =
static_cast<int32_t>(k0[0]) * coeff[i][0] +
static_cast<int32_t>(k0[1]) * coeff[i][1] +
static_cast<int32_t>(k0[2]) * coeff[i][2];
tmp[i][1] =
static_cast<int32_t>(k1[0]) * coeff[i][0] +
static_cast<int32_t>(k1[1]) * coeff[i][1] +
static_cast<int32_t>(k1[2]) * coeff[i][2];
tmp[i][2] =
static_cast<int32_t>(k2[0]) * coeff[i][0] +
static_cast<int32_t>(k2[1]) * coeff[i][1] +
static_cast<int32_t>(k2[2]) * coeff[i][2];
}
// Second stage: element (j, i) of the 6x6 tile is the dot product of tmp row j
// with coeff row i, stored at ptr_channel[j * 6 + i].
// NOTE(review): ptr_channel has no visible declaration in this excerpt, and the
// inner `j` / `i` shadow the outer channel indices.
for (int j = 0; j < 6; j++) {
int32_t* tmpp = &tmp[j][0];
for (int i = 0; i < 6; i++) {
ptr_channel[j * 6 + i] = tmpp[0] * coeff[i][0] +
tmpp[1] * coeff[i][1] +
tmpp[2] * coeff[i][2];
// Entries in row 5 or column 5 are divided by 24. The division is exact:
// coeff[5] is {0, 0, 24}, so for i == 5 the value is tmpp[2] * 24, and for
// j == 5 every tmp[5][*] term is already a multiple of 24.
if (i == 5 || j == 5)
ptr_channel[j * 6 + i] /= 24;
}
}
}
}
... ...
}```

batched gemm接口调用示意代码：

```for (int gi = 0; gi < 36; ++gi) {
// Issues one gemm_int16_alpha call for each of the 36 values of gi.
... ...
... ...
// gi is split into gi / 6 and gi % 6, i.e. a position in a 6x6 grid.
int col_idx = gi / 6;
int row_idx = gi % 6;
// Positions whose row or column index is 5 pass 24 as the last argument;
// every other position passes 1.
// NOTE(review): presumably this restores the factor of 24 that the weight
// transform divides out of row 5 / column 5 - confirm against the
// gemm_int16_alpha implementation, which is not shown here.
if (col_idx == 5 || row_idx == 5) {
gemm_int16_alpha(
M, N, K, A, B, C, 24);
} else {
gemm_int16_alpha(
M, N, K, A, B, C, 1);
}
}```

scale转换部分代码：

```... ...
// Divides every entry of w_scale_ by 576 in place (576 == 24 * 24).
// NOTE(review): presumably this cancels the factor of 24 that the integer
// coefficient table applies on each side of the weight transform - confirm.
for (auto& ws : w_scale_) {
ws /= 576;
}
... ...```