当前位置: 首页 > article >正文

pytorch小记(八):pytorch中有关于.detach()的浅显见解

pytorch小记(八):pytorch中有关于 .detach() 使用的见解

  • 理解神经网络中的局部梯度冻结:为什么部分节点不参与参数更新?
    • 引言
    • 1. 神经网络的计算图机制
      • 关键特性
      • 类比示例
    • 2. 局部梯度冻结的原理
      • 代码示例
      • 结果分析
    • 3. 可视化计算图
    • 4. 为什么局部冻结不会影响全局参数?
    • 5. 实际应用场景
      • 场景1:自洽目标值训练
      • 场景2:多任务学习
    • 6. 代码实战:线性回归中的局部冻结
      • 代码解释
    • 总结
  • 对比实验:使用与不使用梯度冻结的训练效果
    • 1. 实验设置
    • 2. 代码实现
    • 3. 结果分析
    • 4. 核心结论
    • 附录:输出示例


理解神经网络中的局部梯度冻结:为什么部分节点不参与参数更新?


引言

在训练神经网络时,我们经常面临一个挑战:如何让网络的一部分输出保持稳定,同时允许其他部分继续学习?
例如,在复杂任务中,某些中间结果需要作为固定目标值参与损失计算,但这些目标值本身可能来自同一网络的预测。如果直接使用这些预测值,训练过程会因目标值的不断变化而难以收敛。

此时,一种常见的技术是 局部梯度冻结:仅冻结网络中特定节点的梯度回传,而非整个网络。本文将通过拆解计算图与梯度传播机制,解释这一技术的原理和实现方式。


1. 神经网络的计算图机制

神经网络的前向传播会动态构建一个 计算图,记录输入数据与参数之间的运算关系。反向传播时,梯度根据计算图从损失函数回传至参数,指导参数更新。

关键特性

  • 参数共享:同一层的权重参数(如 W W W b b b)在所有输入数据的前向传播中共享。
  • 计算路径独立:不同输入数据生成独立的前向计算路径,但共享同一套参数。

类比示例

假设网络为线性函数 f ( x ) = W x + b f(x) = Wx + b f(x)=Wx+b

  • 输入 x 1 x_1 x1 时,计算路径为 f ( x 1 ) = W x 1 + b f(x_1) = Wx_1 + b f(x1)=Wx1+b
  • 输入 x 2 x_2 x2 时,计算路径为 f ( x 2 ) = W x 2 + b f(x_2) = Wx_2 + b f(x2)=Wx2+b
  • 两次计算共享参数 W W W b b b,但生成独立的计算节点。

2. 局部梯度冻结的原理

在PyTorch中,.detach() 是一个关键操作,用于 切断某条计算路径的梯度回传,使该路径的输出被视为固定值。其作用可总结为:

  • 不冻结参数:网络参数仍可通过其他计算路径更新。
  • 仅冻结指定节点的梯度:避免该节点的输出影响参数更新。

代码示例

import torch

# 定义一个简单的线性网络
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.W = torch.nn.Parameter(torch.tensor(2.0))
        self.b = torch.nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        return self.W * x + self.b

# 初始化网络
net = SimpleNet()

# 输入数据
x1 = torch.tensor(3.0)
x2 = torch.tensor(5.0)

# 前向传播:计算 x1 和 x2 的输出
y1 = net(x1)          # 保留梯度
y2 = net(x2).detach() # 冻结梯度

# 计算损失(假设目标值为10.0)
loss = (y1 - 10.0)**2

# 反向传播并更新参数
loss.backward()

结果分析

  • y1 的梯度会更新参数:梯度从 y1 回传,调整 W W W b b b
  • y2 的梯度被切断y2 的计算节点不参与参数更新。

3. 可视化计算图

       ↗ [y2] (detached) → 不参与梯度计算  
参数(W, b)  
       ↘ [y1] → 参与梯度计算 → 损失函数
  • 参数共享y1y2 均使用 ( W ) 和 ( b ) 计算。
  • 梯度路径独立:仅 y1 的梯度回传影响参数更新。

4. 为什么局部冻结不会影响全局参数?

  1. 反向传播的局部性
    梯度回传时,只有参与损失计算且未被冻结的节点(如 y1)会贡献梯度。冻结节点(如 y2)的梯度在计算图中被标记为“无需计算”。

  2. 参数更新的独立性
    参数的更新仅由未冻结的路径驱动。例如,若 y1 的梯度指示 W W W 需要增大,优化器会根据这一信息调整 W W W,而冻结的 y2 不影响这一过程。


5. 实际应用场景

场景1:自洽目标值训练

假设需要训练一个网络,使其输出值逐渐逼近自身的历史预测值(类似移动平均)。此时:

  • 使用历史参数生成目标值,并冻结其梯度。
  • 当前参数通过当前预测值与历史目标值的差异更新。

场景2:多任务学习

在网络的不同分支中,某些分支需要固定特征提取器(如预训练层),而其他分支需要继续训练。通过冻结特定分支的梯度,可以实现参数的局部更新。


6. 代码实战:线性回归中的局部冻结

import torch
import matplotlib.pyplot as plt

# 生成数据
X = torch.linspace(0, 10, 100).reshape(-1, 1)
Y = 2 * X + 3 + torch.randn(100).reshape(-1, 1) * 2  # Y = 2X + 3 + 噪声

# 定义网络
class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = LinearModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

# 训练过程
losses = []
for epoch in range(100):
    # 前向传播(冻结部分输入的梯度)
    y_pred_frozen = model(X[:50]).detach()  # 冻结前50个样本的预测
    y_pred_train = model(X[50:])            # 后50个样本参与训练

    # 计算损失
    loss = loss_fn(y_pred_train, Y[50:])
    
    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    losses.append(loss.item())

# 绘制损失曲线
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss (Partial Freezing)")
plt.show()

代码解释

  • 冻结部分输入的预测值:前50个样本的预测值被冻结,仅后50个样本参与训练。
  • 参数仍全局更新:尽管前50个样本的梯度被切断,后50个样本的梯度仍会更新所有参数。

总结

  • 局部梯度冻结的意义:通过冻结特定节点的梯度,可以在保持参数全局共享的前提下,控制梯度回传路径,避免目标值波动。
  • 实现方式:使用 .detach() 切断指定计算节点的梯度,而非冻结整个网络。
  • 核心原则:参数更新由未冻结的节点驱动,冻结节点仅提供稳定的中间值。

理解这一机制,能够帮助开发者在复杂模型训练中灵活控制梯度流,平衡模型的稳定性与学习能力。


对比实验:使用与不使用梯度冻结的训练效果

以下代码演示了在自参考目标值(目标值依赖网络自身预测)的场景中,冻结梯度不冻结梯度对训练结果的影响。


1. 实验设置

  • 任务:训练一个线性网络,使其输出逼近自身预测值的移动平均(模拟自参考目标)。
  • 网络结构:单层线性网络 y = W x + b y = Wx + b y=Wx+b
  • 数据生成:输入 x x x 服从均匀分布,目标值 y target = y pred_history + ϵ y_{\text{target}} = y_{\text{pred\_history}} + \epsilon ytarget=ypred_history+ϵ,其中 y pred_history y_{\text{pred\_history}} ypred_history 是网络历史预测的均值, ϵ \epsilon ϵ 为噪声。
  • 关键区别
    • 实验组:冻结目标值的梯度(使用 .detach())。
    • 对照组:不冻结目标值的梯度。

2. 代码实现

import torch
import matplotlib.pyplot as plt

# 生成输入数据
x = torch.linspace(0, 10, 100).reshape(-1, 1)  # 输入特征

# 定义网络
class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

# 训练函数(参数use_detach控制是否冻结梯度)
def train_model(use_detach=True, num_epochs=200):
    model = LinearModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.MSELoss()
    losses = []
    y_pred_history = torch.zeros_like(x)  # 历史预测均值
    
    for epoch in range(num_epochs):
        # 前向传播
        y_pred = model(x)
        
        # 生成目标值:历史预测均值 + 噪声
        target = y_pred_history.detach().clone() + torch.randn_like(x) * 0.5
        
        # 对照组:不冻结目标值的梯度(移除.detach())
        if not use_detach:
            target = y_pred_history + torch.randn_like(x) * 0.5
        
        # 计算损失
        loss = loss_fn(y_pred, target)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 更新历史预测均值(滑动平均)
        y_pred_history = 0.9 * y_pred_history + 0.1 * y_pred.detach()
        
        losses.append(loss.item())
    
    return losses

# 运行实验
losses_detach = train_model(use_detach=True)    # 冻结梯度
losses_no_detach = train_model(use_detach=False) # 不冻结梯度

# 可视化结果
plt.plot(losses_detach, label="Freeze Gradients (use_detach=True)")
plt.plot(losses_no_detach, label="No Freeze (use_detach=False)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Comparison")
plt.legend()
plt.show()

3. 结果分析

  • 冻结梯度(use_detach=True)
    损失曲线稳定下降,网络逐步学习到目标值的移动平均模式(图1蓝色)。

  • 不冻结梯度(use_detach=False)
    损失曲线剧烈震荡,甚至可能发散(图1橙色)。原因如下:

    1. 目标值随网络参数变化:目标值 y target y_{\text{target}} ytarget 的计算包含网络自身的预测梯度,导致目标值随参数更新不断波动。
    2. 自我追逐问题:网络试图逼近一个动态变化的目标,类似于“追逐自己的影子”,无法收敛。

4. 核心结论

  • 冻结梯度的意义
    当目标值依赖网络自身预测时,冻结其梯度可以稳定训练目标,避免目标值波动导致的震荡。
  • 不冻结的后果
    目标值与网络参数形成耦合反馈,训练过程失去稳定性,最终难以收敛。

附录:输出示例

在这里插入图片描述图1:冻结梯度(蓝)与不冻结梯度(橙)的损失曲线对比


http://www.kler.cn/a/589023.html

相关文章:

  • PostgreSQL技术内幕26:PG聚合算子实现分析
  • 【VBA】excel获取股票实时行情(历史数据,基金数据下载)
  • 量子计算与医疗诊断的未来:超越传统的无限可能
  • 深度学习篇---Opencv中Haar级联分类器的自定义
  • Java 大视界 -- 基于 Java 的大数据机器学习模型的迁移学习应用与实践(129)
  • 五大基础算法——贪心算法
  • 各省水资源平台 水资源遥测终端机都用什么协议
  • CSS中绝对定位
  • 【报错问题】在visual studio 终端使用npm -v后报错禁止运行脚本怎么处理
  • 大三下找C++开发实习的感受分享
  • 网络通信(传输层协议:TCP/IP ,UDP):
  • ADA-YOLO模型深度解析 | 自适应动态注意力驱动的目标检测新范式
  • 本地部署Deep Seek-R1,搭建个人知识库——笔记
  • 【Go每日一练】计算整数数组的最大子数组和
  • 深入解析大语言模型的 Function Call 实现—— 以 Qwen2.5为例
  • 02-Canvas-fabric.ActiveSelection
  • 浅谈Mysql数据库事务操作 用mybatis操作mysql事务 再在Springboot中使用Spring事务控制mysql事务回滚
  • 数学 :矩阵
  • 【Gitee】删除仓库的详细步骤
  • ArcGIS 水利制图符号库:提升水利工作效率的利器