当前位置: 首页 > article >正文

神经网络入门实战:(二十二)只训练 (多层网络的) 指定层 / (单层网络的) 指定参数

只训练 (多层网络的) 指定层 / (单层网络的) 指定参数

在训练的时候,有时候并不需要将网络层每次都从头训练,对于只训练指定层的情况,可以通过设置该层参数的 requires_grad =True ,其余层通过设置 requires_grad = False 来冻结(不更新权重)。

【注意】:此时就不要使用 self.model = nn.Sequential(...) 函数将所有层都放在一起,变成一个整体了。

1)多层网络

在定义模型时,在 __init__ 函数后, forward 函数前,加上下方这段代码:

# 初始化时先冻结所有层
for name, param in self.named_parameters():
    param.requires_grad = False

# 然后单独设置指定层的 requires_grad 为 True
for param in self.layer_name.parameters(): # layer_name 就是要单独训练的层的自定义名字
    param.requires_grad = True

随后在创建优化器时,只输入指定层的参数。【如果想要训练所有层的话,优化器的输入一般是 model_Instance.parameters()

optimizer = torch.optim.SGD(model_instance.layer_name.parameters(), lr=0.01)

具体示例:

CIFAR10 配套网络模型为例(只训练第三个卷积层):

import torch
import torch.nn as nn
import torch.optim as optim

class CIFAR10_NET(nn.Module):
    def __init__(self):
        super(CIFAR10_NET, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5, padding=2)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 32, 5, padding=2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2)  # 这是我们要训练的层
        self.pool3 = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(1024, 64)  # 注意:这里的1024是基于输入图像大小和前面的卷积+池化层计算得出的
        self.fc2 = nn.Linear(64, 10)

        # 初始化时先冻结所有层
        for name, param in self.named_parameters():
            param.requires_grad = False

        # 然后单独设置第三个卷积层的requires_grad为True
        for param in self.conv3.parameters():
            param.requires_grad = True

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = self.pool3(torch.relu(self.conv3(x)))  # 第三个卷积层
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化模型
model_instance = CIFAR10_NET()

# 创建优化器,只包含第三个卷积层的参数
optimizer = optim.SGD(model_instance.conv3.parameters(), lr=0.01)

# 损失函数
criterion = nn.CrossEntropyLoss()

# 示例输入和目标(这里需要真实的CIFAR-10数据,但这里只是示例)
inputs = torch.randn(1, 3, 32, 32)  # 假设输入图像大小为32x32
targets = torch.tensor([1])  # 假设目标类别为1

# 训练循环(简化版)
model.train()
for epoch in range(10):  # 假设训练10个epoch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

2)从官网或本地加载预训练模型,并修改或添加某个层,并且只训练该层

由于这种情况,代码里不会再完整的定义一遍网络模型,所以无法直接在模型定义时选择该指定哪一层训练。

此时需要先导入预训练模型,然后加入下方代码进行选择:

for name, param in model_name.named_parameters(): # model_name为模型名
    # layer_name为指定层名
    if 'layer_name' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False  # 冻结其他所有参数

然后再指定优化器的输入参数即可:

model_instance = model_name() # 实例化模型
optimizer = optim.SGD(model_instance.layer_name.parameters(), lr=0.01)

具体示例:

仍以 CIFAR10 配套网络模型为例(将最后的 nn.Linear(64, 10) 层改为 nn.Linear(64, 20) )。并且原先的网络模型已经训练好了并保存在本地了。

import torch
import torch.nn as nn
import torch.optim as optim

CIFAR10_NET_new = torch.load("E:\\5_NN_model\\CIFAR10_NET.pth") # 原先将整个预训练模型保存在了本地
CIFAR10_NET_new.fc2 = nn.Linear(64, 20) # 修改最后一个全连接层,其名字为: fc2 

for name, param in CIFAR10_NET_new.named_parameters():
    # 如果参数属于修改后的全连接层,则设置为可训练
    if 'fc2' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False  # 冻结其他所有参数

model_instance = CIFAR10_NET_new() # 实例化模型
        
# 创建优化器,只包含修改后的全连接层的参数
optimizer = optim.SGD(model_instance.fc2.parameters(), lr=0.01)

# 损失函数
criterion = nn.CrossEntropyLoss()

# 示例输入和目标(这里需要真实的CIFAR-10数据,但这里只是示例)
inputs = torch.randn(1, 3, 32, 32)  # 假设输入图像大小为32x32
targets = torch.tensor([1])  # 假设目标类别为1

# 训练循环(简化版)
model.train()
for epoch in range(10):  # 假设训练10个epoch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

2)单层网络

此单层网络需要是使用张量类型的。

如果只训练某个参数,那么令其 requires_grad 属性为 True ,其他的参数该属性设置为 False 即可。

另外在训练过程中,只使用 w2.data -= learning_rate * w2.grad.data 来更新 w2 的值即可。

具体示例:

以线性模型 y = w1*x^2+w2*x+b 为例(只训练 w2 )。

import torch

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 初始化权重
w1 = torch.tensor([0.5], requires_grad=False)  # 不需要训练
w2 = torch.tensor([0.0], requires_grad=True)  # 需要训练,初始化为0.0以更清楚地看到训练效果
b = torch.tensor([0.5], requires_grad=False)  # 不需要训练

# 定义前向传播函数
def forward(x):
    return x * x * w1 + x * w2 + b

# 定义损失函数(均方误差)
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# 训练过程
learning_rate = 0.01
for epoch in range(100):
    total_loss = 0.0
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()  # 反向传播,计算梯度

        # 只更新w2
        w2.data -= learning_rate * w2.grad.data
        w2.grad.zero_()# 清零梯度,为下一次迭代做准备

        total_loss += l.item()

    average_loss = total_loss / len(x_data)
    print(f"Epoch {epoch+1}, Average Loss: {average_loss}")

上一篇下一篇
神经网络入门实战(二十一)待发布

http://www.kler.cn/a/457126.html

相关文章:

  • 每日一练 | 时延和抖动
  • 工业以太网交换机怎么挑选?
  • 简单XXE漏洞理解以及在实战中演练
  • 大数据Scala面试题汇总
  • HALCON中用于分类的高斯混合模型create_class_gmm
  • Idea创建JDK17的maven项目失败
  • 青少年编程与数学 02-005 移动Web编程基础 06课题、响应式设计
  • Web 漏洞之 CSRF 漏洞挖掘:攻防深度剖析
  • SelectionArea 实现富文本
  • 【源码 导入教程 文档 讲解】基于springboot校园新闻管理系统源码和论文
  • 【13】MySQL如何选择合适的索引?
  • 【GlobalMapper精品教程】091:根据指定字段融合图斑(字段值相同融合到一起)
  • C++学习指南
  • 初识MySQL · 库的操作
  • linux内核系列---网络
  • Java圣诞树
  • 数据结构:二叉树部分接口(链式)
  • 力扣算法--求两数之和等于目标数
  • MySQL的TIMESTAMP类型字段非空和默认值属性的影响
  • 用科技的方法能否实现真正的智能
  • DAY3 QT简易登陆界面优化
  • blender中合并的模型,在threejs中显示多个mesh;blender多材质烘培成一个材质
  • Debian 12 安装配置 fail2ban 保护 SSH 访问
  • 数据安全中间件的好处
  • OpenCV-Python实战(6)——图相运算
  • adb无线连接手机后scrcpy连接报错ERROR: Could not find any ADB device