大模型量化技术的原理和代码实现
本文最后更新于 2024年7月10日凌晨1点18分
大模型量化技术的原理
大模型量化,简单而言,就是对大模型中的参数(比如权重参数)转换数据类型,比如从16位浮点型转为8位整型,转换后数据只占用一般存储空间且推理加快,但模型性能损失不大。
目前主要有两个权重量化技术:
- PTQ (Post-Training Quantization)训练后量化:先训练好模型,再把模型权重转为较低精度,而无需任何重新训练。PTQ方法易于实施,但是会导致潜在性能下降。
- QAT (Quantization-Aware Training)量化感知训练:在预训练或者微调阶段结合了权重转换过程,提高模型性能。但是QAT的计算成本高,且需要有代表性的训练数据。
数据类型和存储
计算机上存储数据有这么几种类型,不同类型数据需要不同的存储空间。深度学习中主要用浮点型,下面看看浮点型的数据结构.
浮点型数据使用n位来存储数值,比如float32使用32bit存储一个数,float16使用16bit存储一个数。具体的,这n位分为三部分:
- 符号sign:符号位表示这个数是正数或者负数,占用1bit,0表示正数,1表示负数;
- 指数Exponent:指数位一般占用8bit,表示基数(二进制中通常是2)的幂,指数可以是正数或者负数,让数字很大或者很小;
- 有效数/尾数 Significand/Mantissa:剩余位存储有效数,也称为尾数。
浮点数的这种设计让其有不同精度能覆盖广泛的值,公式如下:
下面给出三个例子方便理解:
float32:1位表示符号,8bit表示指数,剩余23bit表示有效数;
float16: 1bit符号,5bit指数,10bit有效数;
bfloat16: 1bit符号,8bit指数,7bit有效数,和float16相比扩大了范围但降低精度;
量化方法
MinMax量化
MinMax量化属于线性量化,也称为均匀量化。MinMax量化分为对称量化和非对称量化两种。公式如下(这个公式参考的LLM-QAT论文,但是为什么round取整之后还要乘以缩放因子$\alpha$?这个公式1好像是把quantization和dequantization两个过程结合了):
其中,$X_Q$和$X_R$分别表示量化后变量和全精度变量。$i$表示张量中的第$i$个元素。$\alpha$表示放缩因子,$\beta$是零点值。
对于对称量化:
对于非对称量化:
具体的,我们以float32量化为int8为例,此时全精度变量$X_R$的每个元素是float32型,量化后的$X_{quant}$的每个元素是int8型,$X_{dequant}$表示反量化得到的值 它不等于全精度变量$X_R$有一定误差,$N=8$,$X^i$表示变量X的第i个元素,括号$\lfloor \rceil$表示round取整,可以是上取整也可以是下取整。
堆成量化和反量化公式如下:
下面是pytorch实现的$N=8$的对称量化:1
2
3
4
5
6
7
8
9
10
11
12
13import torch
def absmax_quantize(X):
# Calculate scale
scale = 127 / torch.max(torch.abs(X))
# Quantize
X_quant = (scale * X).round()
# Dequantize
X_dequant = X_quant / scale
return X_quant.to(torch.int8), X_dequant
$N=8$的非对称量化:
pytorch代码实现:(这个代码的公式和上述公式略有不同)1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18def zeropoint_quantize(X):
# Calculate value range (denominator)
x_range = torch.max(X) - torch.min(X)
x_range = 1 if x_range == 0 else x_range
# Calculate scale
scale = 255 / x_range
# Shift by zero-point
zeropoint = (-scale * torch.min(X) - 128).round()
# Scale and round the inputs
X_quant = torch.clip((X * scale + zeropoint).round(), -128, 127)
# Dequantize
X_dequant = (X_quant - zeropoint) / scale
return X_quant.to(torch.int8), X_dequant
待办:
机器学习中数据类型有哪些 除了最常见的float32/float16/int8之外。
量化方法MinMax的pytorch实现
量化除了MinMax还有哪些