gptq 中W4A16 或者 W8A16 中具體是怎麼計算的呢?

沉淀fc發表於2024-08-19

在深入瞭解了 quantization 之後,對quant有所瞭解之後,不論是 dynamic quant還是static quant都有所瞭解,但是因為看了大佬的有關量化之後,理解了trt中的W8A8的運算,理解了為什麼量化之後會加速的原因,但是針對gptq的 W8A16或者W4A16 卻不明白到底屬於是 dynamic quant 還是 static quant,因此糾結了好久,後續透過看了gptq的原始碼理解到,整個過程其實是 將量化的 weight 先反量化為 fp16 然後再和 W*X再進行運算,具體原始碼可以參看gptq的原始碼。

但看完之後,又糾結了,就是覺得既然在相乘之前,有個反量化的過程,豈不是速度變慢了?為啥大家都說速度加快了呢?為什麼加速了呢?還是糾結


void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH); // 申請資源

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>( // 真正的cuda函式,包裝了thread之後重新呼叫
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), 
        batch, vec_height, height, width, zero_width
      );
    })
  );
}

template <typename scalar_t> // 型別模版
__global__ void VecQuant8MatMulKernel(
    const  scalar_t* __restrict__ vec, // x
    const       int* __restrict__ mat, // w
           scalar_t* __restrict__ mul, // w*x 的結果
    const  scalar_t* __restrict__ scales, // w 量化過程中的 scale
    const       int* __restrict__ zeros, // w 量化過程中的 zero
    const   	int* __restrict__ g_idx,
    int batch,
    int vec_height,
    int height,
    int width,
	int zero_width
) {
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
  
  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 4;
  int k;
  unsigned int g;
  scalar_t w_tmp;
  
  int z_w = w / 4; 
  int z_mod = (w % 4) * 8;
  
  float weight[BLOCKWIDTH];
  
  for (k = 0; k <  BLOCKWIDTH; ++k){	
	int k_w = (k / 4); 
	int k_bit = (k % 4) * 8;
	
    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w]; // 獲取 scale fp16型別
    scalar_t zero = scalar_t((((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1) & 0x0f);
	
    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
    
	  weight[k] = scale * (w_tmp - zero); // 反量化
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b){	
	res = 0;
	
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
	for (k = 0; k <  BLOCKWIDTH; ++k){	
	  res += weight[k] * blockvec[k]; // 相乘
    }
    atomicAdd(&mul[b * width + w], res); // 賦值相乘結果
    __syncthreads();
  }
}

相關文章