当前位置: 首页 > article >正文

Pytorch代码入门学习之分类任务(三):定义损失函数与优化器

目录

一、定义损失函数

1.1 代码

1.2 损失函数简介

1.3 交叉熵误差(cross entropy error)

二、定义优化器

2.1 代码

2.2 构造优化器

2.3 随机梯度下降法(SGD)


一、定义损失函数

1.1 代码

criterion = nn.CrossEntropyLoss()

1.2 损失函数简介

        神经网络的学习通过某个指标表示目前的状态,然后以这个指标为基准,寻找最优的权重参数。神经网络以某个指标为线索寻找最优权重参数,该指标称为损失函数(loss function)。这个损失函数可以使用任意函数, 但一般用均方误差和交叉熵误差等。损失函数是表示神经网络性能的“恶劣程度”的指标,即当前的神经网络对监督数据在多大程度上不拟合、不一致。这个值越低,表示网络的学习效果越好。

        但是,如果loss很低的话,可能出现过拟合现象。

        过拟合是指训练出来的模型在训练集上表现得很好,但是在测试集上表现的较差,模型训练的误差远小于它在测试集上的误差。

1.3 交叉熵误差(cross entropy error)

        交叉熵误差如下式所示:

E = -\sum_k{}t_{k} logy_{k}

         其中,log表示以e为底数的自然对数(log e );yk指神经网络的输出,tk是正确解标签。并且,tk中只有正确解标签的索引为1,其他均为0(one-hot表示)。 因此,上式实际上只计算对应正确解标签的输出的自然对数。比如,假设正确解标签的索引是“2”,与之对应的神经网络的输出是0.6,则交叉熵误差 是−log 0.6 = 0.51;若“2”对应的输出是0.1,则交叉熵误差为−log 0.1 = 2.30。因此,交叉熵误差的值是由正确解标签所对应的输出结果决定的。

        交叉熵误差函数需要两个参数,第一个是输入参数(预测值),第二个是正确值

二、定义优化器

2.1 代码

import torch.optim as optim
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

2.2 构造优化器

        optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9):第一个参数是需要更新的参数,第二个参数是指学习率(指每次更新学习率下降的大小),第三个参数为动量;

2.3 随机梯度下降法(SGD)

        用数学式子可以把SGD写为如下的式:

        其中,W记为需要更新的权重参数,\frac{\partial L}{\partial W}是指损失函数关于W的梯度,\eta表示学习率,一般情况下会取为0.01或0.001这类事先决定好的值。式子中的“箭头”表示用右边的值更新左边的值。

        SGD较为简单,且容易实现,但是在解决某些问题时可能没有效率。SGD是朝着梯度方向只前进一定距离的简单方法,且梯度的方法并没有指向最小值的方向。

        参考:004 第一个分类任务2_哔哩哔哩_bilibili


http://www.kler.cn/news/108085.html

相关文章:

  • 【Qt】绘图与绘图设备
  • C++不能在子类中构造函数的初始化成员列表中直接初始化基类成员变量
  • C++ 运算符
  • Linux touch命令:创建文件及修改文件时间
  • 底层驱动day8作业
  • 【C++】智能指针:auto_ptr、unique_ptr、share_ptr、weak_ptr(技术介绍 + 代码实现)(待更新)
  • Megatron-LM GPT 源码分析(三) Pipeline Parallel分析
  • AWS SAP-C02教程11-解决方案
  • C#,数值计算——分类与推理,基座向量机的 Svmgenkernel的计算方法与源程序
  • 中微爱芯74逻辑兼容替代TI/ON/NXP工规品质型号全
  • 【杂记】Ubuntu20.04装系统,安装CUDA等
  • python爬虫之feapder.AirSpider轻量爬虫案例:豆瓣
  • PHP简单实现预定义钩子和自定义钩子
  • Linux国产系统无法连接身份证读卡器USB权限解决办法
  • nrf52832 开发板入手笔记:J-Flash 蓝牙协议栈烧写
  • Nginx 的配置文件(负载均衡,反向代理)
  • Spring Security: 整体架构
  • uniapp-图片压缩(适配H5,APP)
  • 10月Java行情 回暖?
  • 【机器学习可解释性】4.SHAP 值
  • 第10期 | GPTSecurity周报
  • scratch接钻石 2023年9月中国电子学会图形化编程 少儿编程 scratch编程等级考试三级真题和答案解析
  • 力扣第763题 划分字母区间 c++ 哈希 + 双指针 + 小小贪心
  • 制作自己的前端组件库并上传到npm上
  • MySQL实战2
  • 华为c语言编程规范
  • 【Unity】RenderFeature应用(简单场景扫描效果)
  • Linux学习第26天:异步通知驱动开发: 主动
  • 基于Headless构建高可用spark+pyspark集群
  • React中useEffect Hook使用纠错