模型压缩之剪枝
(1)通道选择
这里要先解释一下:
(1)通道剪枝
那我们实际做法不是上面直接对所有层都添加L1正则项,而是仅仅对BN层权重添加L1正则项。通道剪枝具体步骤如下:
1.BN层权重添加L1正则项,进行稀疏训练
2.对BN层权重的scale factor进行排序,对scale factor低于阈值的通道进行裁剪,得到剪枝模型
3.对剪枝模型进行finetune
注:进行finetune的目的是因为剪枝完整个网络结构发生了变化,之前的训练的模型无法再加载进入,必须要finetune(或者这里用重新训练更合适),否则会发现推理结果都是0.
在深度学习中,Batch Normalization(BN)层通常用于加速训练过程并提高模型的泛化能力。BN层的权重参数包括scale factor(缩放因子)和shift factor(偏移因子)。通过对BN层的scale factor添加L1正则化,我们可以实现通道剪枝。
下面是一个示例代码,展示了如何对BN层的scale factor添加L1正则化,并进行通道剪枝和微调(finetune)。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# L1正则化参数
lambda_l1 = 0.001
# 稀疏训练
for epoch in range(10):
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
# 计算BN层scale factor的L1正则化项
l1_regularization = 0
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
l1_regularization += torch.norm(module.weight, p=1)
loss += lambda_l1 * l1_regularization
loss.backward()
optimizer.step()
print(f'Epoch {epoch + 1}, Loss: {loss.item()}')
print("稀疏训练完成")
# 通道剪枝
def prune_channels(model, sparsity_threshold):
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
weights = module.weight.data
mask = torch.abs(weights) > sparsity_threshold
module.weight.data = weights[mask]
module.bias.data = module.bias.data[mask]
module.num_features = int(torch.sum(mask))
# 更新卷积层的输入通道数
if hasattr(module, 'conv'):
conv_module = getattr(module, 'conv')
conv_module.out_channels = int(torch.sum(mask))
conv_module.weight.data = conv_module.weight.data[mask]
if conv_module.bias is not None:
conv_module.bias.data = conv_module.bias.data[mask]
# 设置稀疏性阈值
sparsity_threshold = 0.01
# 剪枝
prune_channels(model, sparsity_threshold)
print("通道剪枝完成")
# 微调
for epoch in range(10):
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(f'Finetune Epoch {epoch + 1}, Loss: {loss.item()}')
print("微调完成")
(2)卷积核剪枝
1.conv层权重添加L1正则项,进行稀疏训练
2.对conv层权重进行排序,对权重低于阈值的卷积核进行裁剪,得到剪枝模型
3.对剪枝模型进行finetune
下面我写了一个简单的示例代码,展示了如何在训练过程中计算权重的稀疏性,并根据稀疏性剪掉通道。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# L1正则化参数
lambda_l1 = 0.001
# 训练模型
for epoch in range(10):
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
# 计算L1正则化项
l1_regularization = 0
for param in model.parameters():
l1_regularization += torch.norm(param, p=1)
loss += lambda_l1 * l1_regularization
loss.backward()
optimizer.step()
print(f'Epoch {epoch + 1}, Loss: {loss.item()}')
print("训练完成")
# 根据稀疏性剪掉通道
def prune_channels(model, sparsity_threshold):
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
weights = module.weight.data
abs_weights = torch.abs(weights)
channel_sums = torch.sum(abs_weights, dim=(1, 2, 3))
mask = channel_sums > sparsity_threshold
module.weight.data = weights[mask]
module.out_channels = int(torch.sum(mask))
if module.bias is not None:
module.bias.data = module.bias.data[mask]
# 设置稀疏性阈值
sparsity_threshold = 0.01
# 剪枝
prune_channels(model, sparsity_threshold)
print("剪枝完成")
在上面例子中,我们在训练完成后,通过 prune_channels
函数根据稀疏性剪掉通道。
具体步骤如下:
-
计算权重的稀疏性:对于每个卷积层的权重,我们计算每个通道的权重绝对值之和。
-
剪枝:根据设定的稀疏性阈值,我们创建一个掩码(mask),只保留那些权重绝对值之和大于阈值的通道,并更新卷积层的权重和偏置。
通过这种方式,我们可以根据权重的稀疏性剪掉不重要的通道,从而减少模型的复杂度和计算量。
通道剪枝和卷积核剪枝小结:
卷积核剪枝(Kernel Pruning)和通道剪枝(Channel Pruning)是两种不同的模型剪枝技术,它们在剪枝的对象和目标上有所区别。
卷积核剪枝(Kernel Pruning)
卷积核剪枝 是指从卷积层中移除整个卷积核(kernel)。一个卷积核通常由一组权重组成,这些权重在卷积操作中与输入特征图的局部区域进行卷积运算。卷积核剪枝的目标是移除那些对模型性能贡献较小的卷积核,从而减少模型的计算量和参数数量。
-
剪枝对象:卷积核(kernel)。
-
剪枝目标:移除整个卷积核。
-
影响:减少卷积层的输出通道数。
通道剪枝(Channel Pruning)
通道剪枝 是指从卷积层或全连接层中移除整个通道(channel)。一个通道通常由一组权重组成,这些权重在卷积操作中与输入特征图的所有位置进行卷积运算。通道剪枝的目标是移除那些对模型性能贡献较小的通道,从而减少模型的计算量和参数数量。
-
剪枝对象:通道(channel)。
-
剪枝目标:移除整个通道。
-
影响:减少卷积层的输入或输出通道数。
主要区别
-
剪枝对象:
-
卷积核剪枝针对的是卷积核,即卷积层中的单个权重组。
-
通道剪枝针对的是通道,即卷积层或全连接层中的整个权重集合。
-
-
剪枝目标:
-
卷积核剪枝的目标是移除整个卷积核。
-
通道剪枝的目标是移除整个通道。
-
-
影响:
-
卷积核剪枝主要影响卷积层的输出通道数。
-
通道剪枝既可以影响卷积层的输入通道数,也可以影响输出通道数。
-
卷积核剪枝代码:
def prune_kernels(model, sparsity_threshold):
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
weights = module.weight.data
abs_weights = torch.abs(weights)
kernel_sums = torch.sum(abs_weights, dim=(1, 2, 3))
mask = kernel_sums > sparsity_threshold
module.weight.data = weights[mask]
module.out_channels = int(torch.sum(mask))
if module.bias is not None:
module.bias.data = module.bias.data[mask]
通道剪枝代码:
def prune_channels(model, sparsity_threshold):
#遍历模型中的所有模块
for module in model.modules():
#检查模块是否为BN层
if isinstance(module, nn.BatchNorm2d):
#获取BN层的权重
weights = module.weight.data
#根据稀疏性阈值创建掩码
mask = torch.abs(weights) > sparsity_threshold
#应用掩码到BN层的权重和偏置
module.weight.data = weights[mask]
module.bias.data = module.bias.data[mask]
module.num_features = int(torch.sum(mask))
#检查BN层是否有与之关联的卷积层
if hasattr(module, 'conv'):
conv_module = getattr(module, 'conv')
#应用掩码到卷积层的权重和偏置
conv_module.out_channels = int(torch.sum(mask))
conv_module.weight.data = conv_module.weight.data[mask]
if conv_module.bias is not None:
conv_module.bias.data = conv_module.bias.data[mask]
在通道剪枝中,我们不仅需要剪枝Batch Normalization(BN)层的权重,还需要相应地剪枝与之关联的卷积层的权重。具体来说,BN层的权重(scale factor)决定了哪些通道是重要的,因此我们需要根据BN层的权重来剪枝卷积层的通道。
通过这种方式,我们确保了BN层的剪枝与卷积层的剪枝是一致的,即剪枝后的BN层和卷积层具有相同的通道数。这样可以保证模型在剪枝后的结构是有效的,并且能够正常工作。
总结来说,通道剪枝不仅涉及BN层的权重剪枝,还涉及与之关联的卷积层的权重剪枝,以确保剪枝后的模型结构的一致性和有效性。
(3)特征图重构
特征图重构是一种在通道剪枝中常用的方法,旨在最小化剪枝后特征图与原始特征图之间的差异。通过这种方式,我们可以更直接地控制剪枝的力度,并确保剪枝后的模型在性能上与原始模型尽可能接近。
下面是一个示例代码,展示了如何使用最小二乘法(linear least squares)来实现特征图重构,从而控制通道剪枝的力度。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(10):
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(f'Epoch {epoch + 1}, Loss: {loss.item()}')
print("训练完成")
# 特征图重构
def feature_map_reconstruction(model, train_loader, alpha=0.01):
model.eval()
original_features = []
pruned_features = []
# 收集原始特征图
with torch.no_grad():
for data, _ in train_loader:
output = model(data)
original_features.append(output)
# 剪枝
def prune_channels(model, sparsity_threshold):
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
weights = module.weight.data
mask = torch.abs(weights) > sparsity_threshold
module.weight.data = weights[mask]
module.bias.data = module.bias.data[mask]
module.num_features = int(torch.sum(mask))
# 更新卷积层的输入通道数
if hasattr(module, 'conv'):
conv_module = getattr(module, 'conv')
conv_module.out_channels = int(torch.sum(mask))
conv_module.weight.data = conv_module.weight.data[mask]
if conv_module.bias is not None:
conv_module.bias.data = conv_module.bias.data[mask]
# 设置稀疏性阈值
sparsity_threshold = 0.01
prune_channels(model, sparsity_threshold)
# 收集剪枝后的特征图
with torch.no_grad():
for data, _ in train_loader:
output = model(data)
pruned_features.append(output)
# 计算特征图差异
original_features = torch.cat(original_features, dim=0)
pruned_features = torch.cat(pruned_features, dim=0)
diff = original_features - pruned_features
loss = alpha * torch.norm(diff, p=2)
# 反向传播和优化
loss.backward()
optimizer.step()
print(f'Feature Map Reconstruction Loss: {loss.item()}')
# 特征图重构
feature_map_reconstruction(model, train_loader)
print("特征图重构完成")