
pyro ExponentialLR: how to set the optimizer learning rate (PyTorch deep neural networks, BNN)

First: Pyro does not support "ReduceLROnPlateau" well, because that scheduler needs the loss value as an input at every step, which adds computational cost.
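
For contrast, here is a minimal plain-PyTorch sketch of what ReduceLROnPlateau demands (the toy linear model and random data are purely illustrative): its step() cannot be called without a metric, which is exactly the loss bookkeeping mentioned above.

    import torch

    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=20)

    x, y = torch.randn(8, 2), torch.randn(8, 1)
    for epoch in range(100):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        scheduler.step(loss)  # the loss argument is required here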

For a video introduction to PyTorch learning-rate scheduling, see this uploader's video:

05-01-学习率调整策略_哔哩哔哩_bilibili

Second, a caveat about scheduler support in SVI.

Three optimizers belong to pyro.optim.PyroOptim:
AdagradRMSProp, ClippedAdam and DCTAdam. Using a pyro.optim optimizer with the scheduler still fails, with an error like the one below:

optimizer = pyro.optim.SGD
# exponentially decaying learning rate
pyro_scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': learn_rate}, 'gamma': 0.1})
Traceback (most recent call last):
  File "/home/aistudio/bnn_pyro_fso_middle_2_16__256.py", line 441, in <module>
    svi = SVI(model, mean_field_guide, optimizer, loss=Trace_ELBO())
  File "/home/aistudio/external-libraries/pyro/infer/svi.py", line 72, in __init__
    raise ValueError(
ValueError: Optimizer should be an instance of pyro.optim.PyroOptim class.

The correct approach is:

optimizer = torch.optim.SGD
# exponentially decaying learning rate
pyro_scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': learn_rate}, 'gamma': 0.1})


# Setting up a ReduceLROnPlateau scheduler (kept here for reference):
# 'min' mode means the learning rate is reduced when the validation loss stops decreasing.
# factor=0.5 multiplies the learning rate by 0.5, and patience=20 means the learning rate
# is reduced if the validation loss has not improved for 20 consecutive epochs.
# scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=20, verbose=True)

# svi = SVI(model, mean_field_guide, optimizer, loss=Trace_ELBO())
svi = SVI(model, mean_field_guide, pyro_scheduler, loss=Trace_ELBO())

ExponentialLR

ExponentialLR is an exponentially decaying learning-rate scheduler: every epoch it multiplies the learning rate by gamma. Be careful not to set gamma too small, otherwise the learning rate collapses to almost zero within a few epochs.

scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
# exponential decay; the learning rate follows lr = initial_lr * gamma**epoch
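
To make the decay speed concrete, a quick standalone check (the single dummy parameter is purely illustrative): with gamma=0.9 the learning rate falls to about 59% of its initial value after 5 epochs, whereas gamma=0.1 would already cut it to 1e-5.

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.1)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    for epoch in range(5):
        optimizer.step()                        # training step omitted
        scheduler.step()
        print(epoch, scheduler.get_last_lr())   # lr = 0.1 * 0.9**(epoch + 1)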

This is a wrapper for `torch.optim.lr_scheduler` objects that adjusts the learning rates of dynamically generated parameters.

    :param scheduler_constructor: a `torch.optim.lr_scheduler` class
    :param optim_args: a dictionary of learning arguments for the optimizer, or a callable that returns such a dictionary. It must contain the key 'optimizer', whose value is a PyTorch optimizer class.
    :param clip_args: a dictionary of `clip_norm` and/or `clip_value` arguments, or a callable that returns such a dictionary.

    Example::

        optimizer = torch.optim.SGD
        scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
        svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
        for i in range(epochs):
            for minibatch in DataLoader(dataset, batch_size):
                svi.step(minibatch)
            scheduler.step()
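
Putting the pieces together, here is a self-contained toy run. The Normal model, AutoNormal guide, and data are illustrative stand-ins, not the BNN from the traceback above:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.infer.autoguide import AutoNormal

    def model(data):
        loc = pyro.sample('loc', dist.Normal(0., 1.))
        with pyro.plate('data', len(data)):
            pyro.sample('obs', dist.Normal(loc, 1.), obs=data)

    guide = AutoNormal(model)
    data = torch.randn(32) + 2.0

    optimizer = torch.optim.SGD  # the torch class, not pyro.optim.SGD
    scheduler = pyro.optim.ExponentialLR(
        {'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.9})
    svi = SVI(model, guide, scheduler, loss=Trace_ELBO())

    for epoch in range(10):
        loss = svi.step(data)
        scheduler.step()  # decay once per epoch, after all minibatches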
 

"""
A wrapper for :class:`~torch.optim.lr_scheduler` objects that adjusts learning rates
for dynamically generated parameters.

:param scheduler_constructor: a :class:`~torch.optim.lr_scheduler`
:param optim_args: a dictionary of learning arguments for the optimizer or a callable that returns
    such dictionaries. must contain the key 'optimizer' with pytorch optimizer value
:param clip_args: a dictionary of clip_norm and/or clip_value args or a callable that returns
    such dictionaries.

Example::

    optimizer = torch.optim.SGD
    scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
    svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
    for i in range(epochs):
        for minibatch in DataLoader(dataset, batch_size):
            svi.step(minibatch)
        scheduler.step()

The source code of pyro.optim.lr_scheduler:

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Iterable, List, Optional, Union, ValuesView

from torch import Tensor

from pyro.optim.optim import PyroOptim


class PyroLRScheduler(PyroOptim):
    """
    A wrapper for :class:`~torch.optim.lr_scheduler` objects that adjusts learning rates
    for dynamically generated parameters.

    :param scheduler_constructor: a :class:`~torch.optim.lr_scheduler`
    :param optim_args: a dictionary of learning arguments for the optimizer or a callable that returns
        such dictionaries. must contain the key 'optimizer' with pytorch optimizer value
    :param clip_args: a dictionary of clip_norm and/or clip_value args or a callable that returns
        such dictionaries.

    Example::

        optimizer = torch.optim.SGD
        scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
        svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
        for i in range(epochs):
            for minibatch in DataLoader(dataset, batch_size):
                svi.step(minibatch)
            scheduler.step()
    """

    def __init__(
        self,
        scheduler_constructor,
        optim_args: Union[Dict],
        clip_args: Optional[Union[Dict]] = None,
    ):
        # pytorch scheduler
        self.pt_scheduler_constructor = scheduler_constructor
        # torch optimizer
        pt_optim_constructor = optim_args.pop("optimizer")
        # kwargs for the torch optimizer
        optim_kwargs = optim_args.pop("optim_args")
        self.kwargs = optim_args
        super().__init__(pt_optim_constructor, optim_kwargs, clip_args)

    def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -> None:
        super().__call__(params, *args, **kwargs)

    def _get_optim(
        self, params: Union[Tensor, Iterable[Tensor], Iterable[Dict[Any, Any]]]
    ):
        optim = super()._get_optim(params)
        return self.pt_scheduler_constructor(optim, **self.kwargs)

    def step(self, *args, **kwargs) -> None:
        """
        Takes the same arguments as the PyTorch scheduler
        (e.g. optional ``loss`` for ``ReduceLROnPlateau``)
        """
        for scheduler in self.optim_objs.values():
            scheduler.step(*args, **kwargs)
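
The mechanics are worth spelling out: PyroOptim lazily creates one torch optimizer per dynamically generated parameter, _get_optim wraps each of them in a fresh scheduler, and step() advances every wrapped scheduler. A simplified re-enactment of that behavior (hypothetical names, not the actual Pyro code path):

    import torch

    params = [torch.nn.Parameter(torch.zeros(3)) for _ in range(2)]

    # one (optimizer, scheduler) pair per dynamically created parameter,
    # mirroring what _get_optim stores in optim_objs
    optim_objs = {}
    for p in params:
        opt = torch.optim.SGD([p], lr=0.01)
        optim_objs[p] = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.1)

    # PyroLRScheduler.step() simply loops over every wrapped scheduler
    for scheduler in optim_objs.values():
        scheduler.step()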

