pytorch小记(八):pytorch中有关于.detach()的浅显见解
pytorch小记(八):pytorch中有关于 .detach() 使用的见解
- 理解神经网络中的局部梯度冻结:为什么部分节点不参与参数更新?
- 引言
- 1. 神经网络的计算图机制
- 关键特性
- 类比示例
- 2. 局部梯度冻结的原理
- 代码示例
- 结果分析
- 3. 可视化计算图
- 4. 为什么局部冻结不会影响全局参数?
- 5. 实际应用场景
- 场景1:自洽目标值训练
- 场景2:多任务学习
- 6. 代码实战:线性回归中的局部冻结
- 代码解释
- 总结
- 对比实验:使用与不使用梯度冻结的训练效果
- 1. 实验设置
- 2. 代码实现
- 3. 结果分析
- 4. 核心结论
- 附录:输出示例
理解神经网络中的局部梯度冻结:为什么部分节点不参与参数更新?
引言
在训练神经网络时,我们经常面临一个挑战:如何让网络的一部分输出保持稳定,同时允许其他部分继续学习?
例如,在复杂任务中,某些中间结果需要作为固定目标值参与损失计算,但这些目标值本身可能来自同一网络的预测。如果直接使用这些预测值,训练过程会因目标值的不断变化而难以收敛。
此时,一种常见的技术是 局部梯度冻结:仅冻结网络中特定节点的梯度回传,而非整个网络。本文将通过拆解计算图与梯度传播机制,解释这一技术的原理和实现方式。
1. 神经网络的计算图机制
神经网络的前向传播会动态构建一个 计算图,记录输入数据与参数之间的运算关系。反向传播时,梯度根据计算图从损失函数回传至参数,指导参数更新。
关键特性
- 参数共享:同一层的权重参数(如 W W W 和 b b b)在所有输入数据的前向传播中共享。
- 计算路径独立:不同输入数据生成独立的前向计算路径,但共享同一套参数。
类比示例
假设网络为线性函数 f ( x ) = W x + b f(x) = Wx + b f(x)=Wx+b:
- 输入 x 1 x_1 x1 时,计算路径为 f ( x 1 ) = W x 1 + b f(x_1) = Wx_1 + b f(x1)=Wx1+b。
- 输入 x 2 x_2 x2 时,计算路径为 f ( x 2 ) = W x 2 + b f(x_2) = Wx_2 + b f(x2)=Wx2+b。
- 两次计算共享参数 W W W 和 b b b,但生成独立的计算节点。
2. 局部梯度冻结的原理
在PyTorch中,.detach()
是一个关键操作,用于 切断某条计算路径的梯度回传,使该路径的输出被视为固定值。其作用可总结为:
- 不冻结参数:网络参数仍可通过其他计算路径更新。
- 仅冻结指定节点的梯度:避免该节点的输出影响参数更新。
代码示例
import torch
# 定义一个简单的线性网络
class SimpleNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.W = torch.nn.Parameter(torch.tensor(2.0))
self.b = torch.nn.Parameter(torch.tensor(0.5))
def forward(self, x):
return self.W * x + self.b
# 初始化网络
net = SimpleNet()
# 输入数据
x1 = torch.tensor(3.0)
x2 = torch.tensor(5.0)
# 前向传播:计算 x1 和 x2 的输出
y1 = net(x1) # 保留梯度
y2 = net(x2).detach() # 冻结梯度
# 计算损失(假设目标值为10.0)
loss = (y1 - 10.0)**2
# 反向传播并更新参数
loss.backward()
结果分析
y1
的梯度会更新参数:梯度从y1
回传,调整 W W W 和 b b b。y2
的梯度被切断:y2
的计算节点不参与参数更新。
3. 可视化计算图
↗ [y2] (detached) → 不参与梯度计算
参数(W, b)
↘ [y1] → 参与梯度计算 → 损失函数
- 参数共享:
y1
和y2
均使用 ( W ) 和 ( b ) 计算。 - 梯度路径独立:仅
y1
的梯度回传影响参数更新。
4. 为什么局部冻结不会影响全局参数?
-
反向传播的局部性
梯度回传时,只有参与损失计算且未被冻结的节点(如y1
)会贡献梯度。冻结节点(如y2
)的梯度在计算图中被标记为“无需计算”。 -
参数更新的独立性
参数的更新仅由未冻结的路径驱动。例如,若y1
的梯度指示 W W W 需要增大,优化器会根据这一信息调整 W W W,而冻结的y2
不影响这一过程。
5. 实际应用场景
场景1:自洽目标值训练
假设需要训练一个网络,使其输出值逐渐逼近自身的历史预测值(类似移动平均)。此时:
- 使用历史参数生成目标值,并冻结其梯度。
- 当前参数通过当前预测值与历史目标值的差异更新。
场景2:多任务学习
在网络的不同分支中,某些分支需要固定特征提取器(如预训练层),而其他分支需要继续训练。通过冻结特定分支的梯度,可以实现参数的局部更新。
6. 代码实战:线性回归中的局部冻结
import torch
import matplotlib.pyplot as plt
# 生成数据
X = torch.linspace(0, 10, 100).reshape(-1, 1)
Y = 2 * X + 3 + torch.randn(100).reshape(-1, 1) * 2 # Y = 2X + 3 + 噪声
# 定义网络
class LinearModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
return self.linear(x)
model = LinearModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()
# 训练过程
losses = []
for epoch in range(100):
# 前向传播(冻结部分输入的梯度)
y_pred_frozen = model(X[:50]).detach() # 冻结前50个样本的预测
y_pred_train = model(X[50:]) # 后50个样本参与训练
# 计算损失
loss = loss_fn(y_pred_train, Y[50:])
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
# 绘制损失曲线
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss (Partial Freezing)")
plt.show()
代码解释
- 冻结部分输入的预测值:前50个样本的预测值被冻结,仅后50个样本参与训练。
- 参数仍全局更新:尽管前50个样本的梯度被切断,后50个样本的梯度仍会更新所有参数。
总结
- 局部梯度冻结的意义:通过冻结特定节点的梯度,可以在保持参数全局共享的前提下,控制梯度回传路径,避免目标值波动。
- 实现方式:使用
.detach()
切断指定计算节点的梯度,而非冻结整个网络。 - 核心原则:参数更新由未冻结的节点驱动,冻结节点仅提供稳定的中间值。
理解这一机制,能够帮助开发者在复杂模型训练中灵活控制梯度流,平衡模型的稳定性与学习能力。
对比实验:使用与不使用梯度冻结的训练效果
以下代码演示了在自参考目标值(目标值依赖网络自身预测)的场景中,冻结梯度与不冻结梯度对训练结果的影响。
1. 实验设置
- 任务:训练一个线性网络,使其输出逼近自身预测值的移动平均(模拟自参考目标)。
- 网络结构:单层线性网络 y = W x + b y = Wx + b y=Wx+b。
- 数据生成:输入 x x x 服从均匀分布,目标值 y target = y pred_history + ϵ y_{\text{target}} = y_{\text{pred\_history}} + \epsilon ytarget=ypred_history+ϵ,其中 y pred_history y_{\text{pred\_history}} ypred_history 是网络历史预测的均值, ϵ \epsilon ϵ 为噪声。
- 关键区别:
- 实验组:冻结目标值的梯度(使用
.detach()
)。 - 对照组:不冻结目标值的梯度。
- 实验组:冻结目标值的梯度(使用
2. 代码实现
import torch
import matplotlib.pyplot as plt
# 生成输入数据
x = torch.linspace(0, 10, 100).reshape(-1, 1) # 输入特征
# 定义网络
class LinearModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
return self.linear(x)
# 训练函数(参数use_detach控制是否冻结梯度)
def train_model(use_detach=True, num_epochs=200):
model = LinearModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()
losses = []
y_pred_history = torch.zeros_like(x) # 历史预测均值
for epoch in range(num_epochs):
# 前向传播
y_pred = model(x)
# 生成目标值:历史预测均值 + 噪声
target = y_pred_history.detach().clone() + torch.randn_like(x) * 0.5
# 对照组:不冻结目标值的梯度(移除.detach())
if not use_detach:
target = y_pred_history + torch.randn_like(x) * 0.5
# 计算损失
loss = loss_fn(y_pred, target)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 更新历史预测均值(滑动平均)
y_pred_history = 0.9 * y_pred_history + 0.1 * y_pred.detach()
losses.append(loss.item())
return losses
# 运行实验
losses_detach = train_model(use_detach=True) # 冻结梯度
losses_no_detach = train_model(use_detach=False) # 不冻结梯度
# 可视化结果
plt.plot(losses_detach, label="Freeze Gradients (use_detach=True)")
plt.plot(losses_no_detach, label="No Freeze (use_detach=False)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Comparison")
plt.legend()
plt.show()
3. 结果分析
-
冻结梯度(use_detach=True):
损失曲线稳定下降,网络逐步学习到目标值的移动平均模式(图1蓝色)。 -
不冻结梯度(use_detach=False):
损失曲线剧烈震荡,甚至可能发散(图1橙色)。原因如下:- 目标值随网络参数变化:目标值 y target y_{\text{target}} ytarget 的计算包含网络自身的预测梯度,导致目标值随参数更新不断波动。
- 自我追逐问题:网络试图逼近一个动态变化的目标,类似于“追逐自己的影子”,无法收敛。
4. 核心结论
- 冻结梯度的意义:
当目标值依赖网络自身预测时,冻结其梯度可以稳定训练目标,避免目标值波动导致的震荡。 - 不冻结的后果:
目标值与网络参数形成耦合反馈,训练过程失去稳定性,最终难以收敛。
附录:输出示例
图1:冻结梯度(蓝)与不冻结梯度(橙)的损失曲线对比