当前位置: 首页 > article >正文

4.PyTorch——优化器

import numpy as np
import pandas as pd

import torch as t

PyTorch将深度学习中常用的优化方法全部封装在torch.optim中,其设计十分灵活,能够很方便的扩展成自定义的优化方法。

所有的优化方法都是继承基类optim.Optimizer,并实现了自己的优化步骤。下面就以最基本的优化方法——随机梯度下降法(SGD)举例说明。这里需重点掌握:

  • 优化方法的基本使用方法
  • 如何对模型的不同部分设置不同的学习率
  • 如何调整学习率
# 定义一个LeNet网络
class Net(t.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.features = t.nn.Sequential(
                        t.nn.Conv2d(3, 6, 5),
                        t.nn.ReLU(),
                        t.nn.MaxPool2d(2, 2),
                        t.nn.Conv2d(6, 16, 5),
                        t.nn.ReLU(),
                        t.nn.MaxPool2d(2, 2)
        )
        self.classifier = t.nn.Sequential(
                        t.nn.Linear(16*5*5, 120),
                        t.nn.ReLU(),
                        t.nn.Linear(120, 84),
                        t.nn.ReLU(),
                        t.nn.Linear(84, 10)
        )

    def  forward(self, x):
        x = self.features(x)
        x = x.view(-1, 16*5*5)
        x = self.classifier(x)
        return x

net = Net()
optimizer = t.optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad()     # 梯度清零

input = t.randn(1, 3, 32, 32)
output = net(input)
output.backward(output)   

optimizer.step()   # 执行优化
# 为不同子网络设置不同的学习率,在finetune中经常用到
# 如果对某个参数不指定学习率,就使用最外层的默认学习率
optimizer = t.optim.SGD([
                {'params': net.features.parameters()}, # 学习率为1e-5
                {'params': net.classifier.parameters(), 'lr': 1e-2}
            ], lr=1e-5)
optimizer
SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 1e-05
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0

Parameter Group 1
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.01
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)
# 只为两个全连接层设置较大的学习率,其余层的学习率较小
special_layers = t.nn.ModuleList([net.classifier[0], net.classifier[3]])
special_layers_params = list(map(id, special_layers.parameters()))
base_params = filter(lambda p: id(p) not in special_layers_params, net.parameters())

optimizer = t.optim.SGD([{'params':base_params},
                         {'params':special_layers.parameters(), 'lr':0.01}], lr=0.001)
optimizer
SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.001
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0

Parameter Group 1
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.01
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)

对于如何调整学习率,主要有两种做法。一种是修改optimizer.param_groups中对应的学习率,另一种是更简单也是较为推荐的做法——新建优化器,由于optimizer十分轻量级,构建开销很小,故而可以构建新的optimizer。但是后者对于使用动量的优化器(如Adam),会丢失动量等状态信息,可能会造成损失函数的收敛出现震荡等情况。

# 方法1: 调整学习率,新建一个optimizer
old_lr = 0.1
optimizer1 = t.optim.SGD([
                {'params': net.features.parameters()},
                {'params': net.classifier.parameters(), 'lr': old_lr*0.1}
            ], lr=1e-5)
optimizer1
SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 1e-05
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0

Parameter Group 1
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.010000000000000002
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)
# 方法2: 调整学习率, 手动decay, 保存动量
for param_group in optimizer.param_groups:
    param_group['lr'] *= 0.1 # 学习率为之前的0.1倍
optimizer
SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.0001
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0

Parameter Group 1
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.001
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)

http://www.kler.cn/news/162952.html

相关文章:

  • Bert-vits2新版本V2.1英文模型本地训练以及中英文混合推理(mix)
  • 【c语言指针详解】指针的基本概念和用法
  • 面对对象基础案例
  • React中使用react-json-view展示JSON数据
  • 2023年甘肃职业院校技能大赛(中职教师组)网络安全竞赛样题(五)
  • 持续集成交付CICD:CentOS 7 安装 Nexus 3.63
  • Flask template+Vue +项目中include引入其他模版(其他模版也会用到vue)的使用探索
  • 独立服务器的主要应用方向有什么_Maizyun
  • 云原生(Cloud Native)——概念,技术,背景,优缺点,实践例子
  • Vue3如何优雅的跨组件通信
  • C++_对C数据类型的扩展
  • 整数以及浮点数在内存中的存储
  • 等待和通知
  • 联想电脑重装系统Win10步骤和详细教程
  • Ubuntu22.04 交叉编译fdk-aac for Rv1106
  • 【软件安装】VMware安装Centos7虚拟机并且设置静态IP,实现Windows和Centos7网络互相访问
  • Tair(2):Tair安装部署
  • 检测判断IP合法性API接口
  • Ubuntu 修改当前用户的名称
  • 膳食补充剂行业分析:2028年中国市场有望突破3700亿元
  • 有限空间作业中毒窒息事故频发,汉威科技创新方案护航
  • Flink 使用场景
  • K8S集群优化的可执行优化
  • 带大家做一个,易上手的家常辣子鸡
  • HbuilderX使用Uniapp+Vue3安装uview-plus
  • redis-学习笔记(list)
  • Conda常用命令总结
  • Apache Lucene 9.9,有史以来最快的 Lucene 版本
  • Python:核心知识点整理大全7-笔记
  • [网鼎杯 2020 朱雀组]phpweb1