【Datawhale组队学习】模型减肥秘籍:模型压缩技术6——项目实践
NNI (Neural Network Intelligence) 是由微软开发的一个开源自动化机器学习(AutoML)库,用于帮助研究人员和开发人员高效地进行机器学习实验。它提供了一套丰富的工具来进行模型调优、神经网络架构搜索、模型压缩以及自动化的超参数搜索。
1.模型剪枝
代码的主要目的是展示如何通过 NNI 进行神经网络模型剪枝,以减少模型大小和计算复杂度,之后通过微调来恢复模型的性能,并实现剪枝后的加速部署。这种方法有助于将复杂的神经网络压缩,使其更适合在资源受限的设备上运行,同时保持尽可能高的准确率。
核心代码:
config_list = [{
'op_types': ['Linear', 'Conv2d'],
'exclude_op_names': ['fc3'],
'sparse_ratio': 0.8
}]
这段代码定义了一个剪枝配置 config_list,用于指定如何对模型的特定层进行剪枝。下面是每个参数的解释:
1. op_types: [‘Linear’, ‘Conv2d’]
这个参数指定了需要进行剪枝的层类型。‘Linear’ 和 ‘Conv2d’ 表示对模型中的全连接层和卷积层(Conv2d)进行剪枝。
2. exclude_op_names: [‘fc3’]
这个参数指定了在剪枝过程中需要排除的层。名为 ‘fc3’ 的层将不会被剪枝,即该层不受剪枝影响。
3. sparse_ratio: 0.8
这个参数指定了剪枝的稀疏度比例。‘sparse_ratio’: 0.8 表示在选定的层中,要将 80% 的参数剪枝掉,只保留 20% 的权重。
2.模型量化
使用 NNI 框架对深度学习模型进行量化处理的实践。代码使用了一种训练后量化的技术,通过减少模型中参数的位数,来减小模型大小并加快推理速度。
# 量化配置,将卷积层和全连接层量化为int8类型
config_list = [{
'op_names': ['conv1', 'conv2', 'fc1', 'fc2'], # 需要量化的操作
'target_names': ['_input_', 'weight', '_output_'], # 量化输入、权重和输出
'quant_dtype': 'int8', # 使用int8类型进行量化
'quant_scheme': 'affine', # 量化方法,使用仿射变换
'granularity': 'default',
},{
'op_names': ['relu1', 'relu2'], # 需要量化的激活函数
'target_names': ['_output_'], # 量化输出
'quant_dtype': 'int8',
'quant_scheme': 'affine',
'granularity': 'default',
}]
# 创建QATQuantizer对象进行量化感知训练
quantizer = QATQuantizer(model, config_list, evaluator, len(train_loader))
通过量化感知训练减少模型的复杂度,使模型可以被压缩为 int8 类型。这样能够在保证模型性能的同时,显著降低模型大小并加快推理速度,尤其适合在资源受限的环境中使用。
3.NAS
代码使用 NNI 进行神经网络架构搜索(NAS)的实践,展示如何通过定义模型空间、选择搜索策略、训练评估模型并启动实验来寻找最优模型架构。
核心代码:
定义了一个名为 MyModelSpace 的模型空间,通过 LayerChoice 和 MutableXXX 使其包含多种可能的结构,用于搜索不同的架构组合。
class MyModelSpace(ModelSpace):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
# LayerChoice用于选择卷积层类型(标准卷积或深度可分离卷积)
self.conv2 = LayerChoice([
nn.Conv2d(32, 64, 3, 1),
DepthwiseSeparableConv(32, 64)
], label='conv2')
# MutableDropout用于从指定的概率中选择一个dropout率
self.dropout1 = MutableDropout(nni.choice('dropout', [0.25, 0.5, 0.75]))
self.dropout2 = nn.Dropout(0.5)
feature = nni.choice('feature', [64, 128, 256])
self.fc1 = MutableLinear(9216, feature)
self.fc2 = MutableLinear(feature, 10)
使用随机搜索策略来探索模型空间,以选择不同的层配置。
import nni.nas.strategy as strategy
search_strategy = strategy.Random() # 使用随机搜索策略
代码通过 NNI 的 NAS 功能自动化地探索不同的神经网络结构组合,减少了人工设计架构的复杂性。结合随机搜索策略,能够有效地在不同架构之间进行搜索,并通过实验来评估每个模型的性能。
4.使用NNI对模型进行剪枝、量化、蒸馏压缩
使用 NNI 框架对 ResNet18 模型进行融合压缩,涉及模型剪枝、量化和知识蒸馏等方法。
使用了 TaylorPruner 和 AGPPruner,分别对模型进行基于泰勒展开法的重要性评估以及渐进剪枝。
# 设置剪枝配置
bn_list = [module_name for module_name, module in model.named_modules() if isinstance(module, torch.nn.BatchNorm2d)]
p_config_list = [{
'op_types': ['Conv2d'],
'sparse_ratio': 0.5
}, *[{
'op_names': [name],
'target_names': ['_output_'],
'target_settings': {
'_output_': {
'align': {
'module_name': name.replace('bn', 'conv') if 'bn' in name else name.replace('downsample.1', 'downsample.0'),
'target_name': 'weight',
'dims': [0],
},
'granularity': 'per_channel'
}
}
} for name in bn_list]]
# 使用 TaylorPruner 和 AGPPruner 进行剪枝
sub_pruner = TaylorPruner(model, p_config_list, evaluator, training_steps=100)
scheduled_pruner = AGPPruner(sub_pruner, interval_steps=100, total_times=30)
使用了 QATQuantizer,以量化感知训练(QAT)的方式对模型进行量化。
q_config_list = [{
'op_types': ['Conv2d'],
'quant_dtype': 'int8',
'target_names': ['_input_'],
'granularity': 'per_channel'
}, {
'op_types': ['BatchNorm2d'],
'quant_dtype': 'int8',
'target_names': ['_output_'],
'granularity': 'per_channel'
}]
quantizer = QATQuantizer.from_compressor(scheduled_pruner, q_config_list, quant_start_step=100)
使用 DynamicLayerwiseDistiller 对模型进行蒸馏,将教师模型的知识传递给学生模型。
def teacher_predict(batch, teacher_model):
return teacher_model(batch[0])
d_config_list = [{
'op_types': ['Conv2d'],
'lambda': 0.1,
'apply_method': 'mse',
}]
distiller = DynamicLayerwiseDistiller.from_compressor(quantizer, d_config_list, teacher_model, teacher_predict, 0.1)
通过剪枝和量化得到的稀疏性,可以利用 ModelSpeedup 来加速模型推理。
masks = scheduled_pruner.get_masks()
speedup = ModelSpeedup(model, dummy_input, masks)
model = speedup.speedup_model()
展示了如何使用 NNI 对 ResNet18 模型进行融合压缩,在减少模型参数数量和加速推理的同时,最大限度地保留模型的性能。这种方式能够显著减小模型的存储需求和计算成本,非常适合在资源受限的设备上进行模型部署。
参考文献
- https://www.datawhale.cn/learn/content/68/966
- https://github.com/datawhalechina/awesome-compression/tree/main/docs/notebook/ch07
- https://nni.readthedocs.io/en/latest/