
cosine_with_warmup_scheduler (cosine_scheduler.py in lrgb)

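This code implements a cosine-annealing learning-rate scheduler with a warmup phase. It adjusts the learning rate dynamically during training. During warmup the learning rate rises linearly from 0 to the initial value set in the optimizer, and after that it decays toward 0 along a cosine curve. The warmup helps keep the early phase of training stable, and the gradual decay makes later updates progressively smaller so that training can settle instead of oscillating late in the run.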
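Written out, the factor by which this kind of schedule multiplies the optimizer's initial learning rate $\eta_0$ at step $t$ is sketched below. The notation ($W$ warmup steps, $T$ total steps, $c$ for `num_cycles`) is my own; the piecewise form follows the Hugging Face implementation that the docstring in the code cites.

$$
\lambda(t) =
\begin{cases}
\dfrac{t}{\max(1, W)}, & t < W,\\[6pt]
\max\!\left(0,\ \tfrac{1}{2}\bigl(1 + \cos(2\pi c\, p)\bigr)\right), & t \ge W,
\end{cases}
\qquad p = \frac{t - W}{\max(1,\, T - W)},
\qquad \eta(t) = \eta_0\,\lambda(t).
$$

With the default $c = 0.5$ the cosine term is $\cos(\pi p)$, so $\lambda$ falls from 1 at $t = W$ to 0 at $t = T$.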

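Import:

```python
from lrgb.cosine_scheduler import cosine_with_warmup_scheduler
```

Source of `cosine_scheduler.py`:

```python
import math

import torch.optim as optim
from torch.optim import Optimizer


def cosine_with_warmup_scheduler(optimizer: Optimizer,
                                 num_warmup_epochs: int, max_epoch: int):
    # Epoch counts are passed in as "steps", so the returned scheduler
    # advances by one step per call to scheduler.step() (i.e. per epoch).
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=num_warmup_epochs,
        num_training_steps=max_epoch
    )
    return scheduler


def get_cosine_schedule_with_warmup(
        optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
        num_cycles: float = 0.5, last_epoch: int = -1):
    """
    Implementation by Huggingface:
    https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py

    Create a schedule with a learning rate that decreases following the values
    of the cosine function between the initial lr set in the optimizer to 0,
    after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.
    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just
            decrease from the max value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    # Body as in the Hugging Face function cited above.
    def lr_lambda(current_step):
        # Warmup: scale the initial lr linearly from 0 up to 1x.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Fraction of the post-warmup phase completed, in [0, 1].
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        # Cosine decay; with num_cycles=0.5 this is a half-cosine from 1 to 0.
        return max(0.0, 0.5 * (1.0 + math.cos(
            math.pi * float(num_cycles) * 2.0 * progress)))

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
```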
