模型量化相关知识汇总
量化&反量化
量化操作可以将浮点数转换为低比特位数据表示,比如int8和 uint8.
Q(x_fp32, scale, zero_point) = round(x_fp32/scale) + zero_point,
量化后的数据可以经过反量化操作来获取浮点数
x_fp32 = (Q - zero_point)* scale
pytorch中 quantize_per_tensor的解释
pytorch可以使用quantize_per_tensor
函数来对一个浮点tensor做8bit量化.
ts_quant = torch.quantize_per_tensor(ts, scale = 0.1, zero_point = 10, dtype = torch.quint8)
print(f'fp32 ts:{ts}, quant ts:{ts_quant}, int_repr:{ts_quant.int_repr()}')
# 截断后的浮点数.
naive_quant = np.array([24.5, 1.0, 2.0]) / 0.1 + 10
print(f'naive_quant:{naive_quant}')
打印结果:
fp32 ts:tensor([100., 1., 2.]), quant ts:tensor([24.5000, 1.0000, 2.0000], size=(3,), dtype=torch.quint8,
quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), int_repr:tensor([255, 20, 30], dtype=torch.uint8)
naive_quant:[255. 20. 30.]
有几点值得说明
- ts_quant直接打印出来的值并不是quint8数据,依然是个浮点,它表示的是映射到8bit后被截断剩下有效的浮点范围. scale=0.1的前提下,uint8只能表示出[0, 25.5]范围的浮点数,加上zero_point=10, 那么输入的浮点必须要在[0, 24.5]范围内才能被表示,超出部分会被截断,所以打印出来的是
24.5, 1, 2
. - 浮点表示的[24.5, 1.0, 2.0]可以通过int_repr方法打印出定点表示.
- int_repr等效与naive_quant的实现