当前位置: 首页 > article >正文

深度学习(3):Tensor和Optimizer

文章目录

  • 是什么
  • Tensor
    • 1. Tensor 的基本概念
    • 2. 自动求导(Autograd)机制
    • 3. `requires_grad` 属性
    • 4. `.data` 和 `.item()`
    • 5. 梯度清零
  • Optimizer

是什么

Tensor(张量):在 PyTorch 中,Tensor 是一种多维数组, 可以在 GPU 上进行高效的数值计算。

Optimizer(优化器):在 PyTorch 中,优化器负责管理和更新模型的参数,特别是在训练过程中根据计算出的梯度来更新参数以最小化损失函数。

Tensor

1. Tensor 的基本概念

  • 定义:Tensor 是 PyTorch 中的基本数据结构,表示多维数组。
  • 创建:使用 torch.tensor() 或其他方法(如 torch.zeros()torch.ones() 等)创建。
  • 属性
    • data:Tensor 的数据部分。
    • grad:Tensor 的梯度,只有当 requires_grad = True 时,才会在反向传播后被赋值。
    • requires_grad:是否需要对该 Tensor 计算梯度,默认值为 False

2. 自动求导(Autograd)机制

  • 计算图:PyTorch 通过记录 Tensor 的操作,构建一个有向无环图(DAG),称为计算图。
  • 反向传播:调用 loss.backward(),PyTorch 会自动计算损失对所有叶子节点的梯度。
  • 梯度累积:在默认情况下,梯度会在每次反向传播中累积,需要手动清零。

3. requires_grad 属性

  • 作用:指定 Tensor 是否需要计算梯度。
  • 设置方法:在创建 Tensor 时,通过参数 requires_grad=True,或者在已有的 Tensor 上设置 tensor.requires_grad = True

4. .data.item()

  • .data
    • tensor.data:获取 Tensor 的数据部分,不会影响计算图的构建。
    • 这个地方得注意,在老版本里面可以使用,但新版本里面有时候会导致梯度计算出现问题,一般使用使用 with torch.no_grad(): 块来避免计算图的影响
  • .item()
    • 作用:将单元素 Tensor 转换为 Python 数值。
    • 用途:用于打印、记录损失值或中间结果。

5. 梯度清零

  • 原因:防止梯度累积,确保每次反向传播计算的梯度是针对当前样本或批次的。
  • 方法
    • w.grad.data.zero_():将梯度清零。

Optimizer

优化器位于 torch.optim 包中,是实现各种优化算法的核心组件。它们的主要功能是调整网络参数以减少计算出的损失值。每一个优化器都继承自 torch.optim.Optimizer,并实现特定的优化策略。

使用主要包括三个步骤:计算梯度、执行参数更新和清除梯度。

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)  # 示例模型
criterion = nn.MSELoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 初始化优化器

for inputs, targets in data_loader:
    optimizer.zero_grad()   # 清除过去的梯度
    outputs = model(inputs)  # 前向传播
    loss = criterion(outputs, targets)  # 计算损失
    loss.backward()  # 反向传播,计算当前梯度
    optimizer.step()  # 根据梯度更新参数


http://www.kler.cn/news/324004.html

相关文章:

  • 求职Leetcode题目(11)
  • 如何使用C语言接入Doris数据库
  • 线性表二——栈stack
  • 微信小程序开发系列之-在微信小程序中使用云开发
  • How to install JetBrains ToolBox in Ubuntu 22.04 LTS?
  • ELK-03-skywalking监控linux系统
  • JAVA JDK华为云镜像下载,速度很快
  • AIGC入门:Comfyui整合包,解压即用!
  • Goweb---Gorm操作数据库(二)
  • project_object_model_3d
  • ES6中迭代器与生成器知识浅析
  • Python知识点:如何使用Python与.NET进行互操作(IronPython)
  • ubuntu 安装harbor
  • 解锁MySQL高可用新境界:深入探索MHA架构的无限魅力与实战部署
  • HI3520DV510 22AP80/SS522V100 芯片及开发板
  • 认识 Linux操作系统
  • 新疆交投路桥桥梁公司:向“新”求“质”,积蓄发展新势能
  • Tkinter制作登录界面以及登陆后页面切换(一)
  • Linux 基本指令的学习
  • 【深度学习】03-神经网络 3-3 梯度下降的优化方法-动量算法Momentum
  • mysql数据库sql语句总结
  • 综合业务区的数字化创新与智能化蓝图
  • GitLab CI/CD脚本入门
  • 04_OpenCV图片缩放
  • 先进制造aps专题二十六 基于强化学习的人工智能ai生产排程aps模型简介
  • Oracle 数据库安装和配置指南(新)
  • 进阶SpringBoot之分布式系统与 RPC 原理
  • 数据结构:成员运算符(.)+ 指向运算符(->)
  • 创建javaWeb项目(详细版本)2021年2月
  • 【递归】8. leetcode 671 二叉树中第二小的节点