After digging into quantization I had a reasonable grasp of both dynamic quant and static quant, and after reading an expert's write-up I also understood the W8A8 computation in TensorRT and why quantization gives a speedup there. But GPTQ's W8A16 / W4A16 kept bothering me: is it dynamic quant or static quant? I went back and forth on this for a long time. Reading the GPTQ source code finally made the process clear: the quantized weights are first dequantized back to fp16, and only then is the W*X multiplication carried out. The relevant kernel from the GPTQ source is excerpted below.
But that only raised the next question: if there is a dequantization step before the multiply, shouldn't that make things slower? Why does everyone say it gets faster? Where does the speedup actually come from? I was stuck on this again.
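To make the mechanism concrete before diving into the CUDA kernel, here is a minimal CPU reference sketch of a W8A16-style matvec. This is my own illustration: the names are made up and the uint8 weights are stored unpacked, with one (scale, zero) per group and output column, which is simpler than GPTQ's packed int32 storage format.

#include <cstdint>
#include <vector>

// Reference sketch (hypothetical layout): y = W * x, where W is kept as uint8
// plus per-group (scale, zero). Dequantization happens on the fly, right before
// the multiply, which is what the GPTQ kernel below does in registers.
void dequant_matvec_ref(const std::vector<uint8_t>& qweight, // K x N, row-major, unpacked
                        const std::vector<float>& scales,    // (K / group_size) x N
                        const std::vector<uint8_t>& zeros,   // (K / group_size) x N
                        const std::vector<float>& x,         // K activations
                        std::vector<float>& y,               // N outputs
                        int K, int N, int group_size) {
  for (int n = 0; n < N; ++n) {
    float acc = 0.f;
    for (int k = 0; k < K; ++k) {
      int g = k / group_size;                                // quantization group of this row
      float w = scales[g * N + n] *
                (float(qweight[k * N + n]) - float(zeros[g * N + n])); // dequantize
      acc += w * x[k];                                       // multiply with the fp activation
    }
    y[n] = acc;
  }
}

The real kernel below computes the same thing, except that the weights and zero points are packed four 8-bit values per int32, the group of each row is looked up through g_idx, and the work is split across CUDA blocks and threads.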
void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);     // rows of packed int32 (each int32 holds four 8-bit weights)
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH); // launch resources: one thread per output column in each tile

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>( // the actual CUDA kernel, launched with the grid above
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
}
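A note on the launch geometry: height here counts packed int32 rows of mat, so the grid tiles the packed weight matrix and every thread owns one output column of a BLOCKWIDTH-wide tile. BLOCKWIDTH and BLOCKHEIGHT8 are compile-time constants defined elsewhere in the kernel file; in the GPTQ-for-LLaMa copy I have seen they are 256 and 64 (= 256 * 8 / 32), but treat the concrete numbers in this sketch as assumptions and check your own copy.

#include <cstdio>

int main() {
  // Assumed constants from the GPTQ-for-LLaMa kernel file; verify against your copy.
  const int BLOCKWIDTH   = 256;                  // weight columns (and unpacked rows) per block
  const int BLOCKHEIGHT8 = BLOCKWIDTH * 8 / 32;  // 64 packed int32 rows = 256 unpacked weight rows

  // Example: a 4096 x 4096 linear layer quantized to 8 bit.
  int in_features = 4096, width = 4096;
  int height = in_features * 8 / 32;             // 1024 rows of packed int32
  int blocks_x = (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8;  // 16
  int blocks_y = (width  + BLOCKWIDTH   - 1) / BLOCKWIDTH;    // 16
  printf("grid = (%d, %d), threads per block = %d\n", blocks_x, blocks_y, BLOCKWIDTH);
  return 0;
}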
template <typename scalar_t> // type template: the dispatched floating type of the activations
__global__ void VecQuant8MatMulKernel(
    const  scalar_t* __restrict__ vec,    // x, the floating-point activation vector(s)
    const       int* __restrict__ mat,    // w, quantized weights packed four 8-bit values per int32
           scalar_t* __restrict__ mul,    // output buffer, accumulates w*x
    const  scalar_t* __restrict__ scales, // per-group quantization scales (floating point)
    const       int* __restrict__ zeros,  // per-group zero points, packed four 8-bit values per int32
    const       int* __restrict__ g_idx,  // maps each weight row to its quantization group
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width
) {
  int h = BLOCKHEIGHT8 * blockIdx.x;               // first packed int32 row handled by this block
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;   // output column handled by this thread

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 4;                                 // first unpacked weight row (4 rows per int32)
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 4;                                 // which int32 of the packed zero points
  int z_mod = (w % 4) * 8;                         // bit offset of this column's zero point

  float weight[BLOCKWIDTH];

  // Unpack and dequantize this thread's BLOCKWIDTH weights once, into registers.
  for (k = 0; k < BLOCKWIDTH; ++k){
    int k_w = (k / 4);
    int k_bit = (k % 4) * 8;

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];        // per-group scale, fp16/fp32
    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); // zero point (stored minus one, hence the +1)

    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);

    weight[k] = scale * (w_tmp - zero);            // dequantize: int8 -> floating point
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b){
    res = 0;

    // Stage the slice of x that this block needs into shared memory.
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k){
      res += weight[k] * blockvec[k];              // floating-point multiply-accumulate
    }
    atomicAdd(&mul[b * width + w], res);           // accumulate this block's partial dot product
    __syncthreads();
  }
}
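For completeness, this is roughly how the wrapper above would be driven from C++ for one linear layer. The shapes follow the kernel's indexing; the helper name and the dtype choices are my assumptions (AT_DISPATCH_FLOATING_TYPES covers float/double but not half, so the activations here are float32), and the output must start at zero because the kernel accumulates partial sums with atomicAdd.

#include <torch/extension.h>

// Declaration of the wrapper defined above.
void vecquant8matmul_cuda(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                          torch::Tensor scales, torch::Tensor zeros, torch::Tensor g_idx);

// Hypothetical call site for y = x * W with an 8-bit packed weight.
torch::Tensor quant_linear_forward(torch::Tensor x,        // (batch, K) float32, on CUDA
                                   torch::Tensor qweight,  // (K / 4, N) int32, four 8-bit weights per int
                                   torch::Tensor scales,   // (groups, N) float32
                                   torch::Tensor qzeros,   // (groups, N / 4) int32, packed zero points
                                   torch::Tensor g_idx) {  // (K,) int32, row -> group
  auto y = torch::zeros({x.size(0), qweight.size(1)}, x.options()); // atomicAdd target must be zeroed
  vecquant8matmul_cuda(x, qweight, y, scales, qzeros, g_idx);
  return y;
}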