当前位置: 首页 > article >正文

PyTorch API 详细中文文档,基于PyTorch2.5


PyTorch API 详细中文文档

按模块分类,涵盖核心函数与用法示例


目录

  1. 张量操作 (Tensor Operations)
  2. 数学运算 (Math Operations)
  3. 自动求导 (Autograd)
  4. 神经网络模块 (torch.nn)
  5. 优化器 (torch.optim)
  6. 数据加载与处理 (torch.utils.data)
  7. 设备管理 (Device Management)
  8. 模型保存与加载
  9. 分布式训练 (Distributed Training)
  10. 实用工具函数

1. 张量操作 (Tensor Operations)

1.1 张量创建
函数描述示例
torch.tensor(data, dtype, device)从数据创建张量torch.tensor([1,2,3], dtype=torch.float32)
torch.zeros(shape)创建全零张量torch.zeros(2,3)
torch.ones(shape)创建全一张量torch.ones(5)
torch.rand(shape)均匀分布随机张量torch.rand(3,3)
torch.randn(shape)标准正态分布张量torch.randn(4,4)
torch.arange(start, end, step)创建等差序列torch.arange(0, 10, 2)[0,2,4,6,8]
torch.linspace(start, end, steps)线性间隔序列torch.linspace(0, 1, 5)[0, 0.25, 0.5, 0.75, 1]
1.2 张量属性
属性/方法描述示例
.shape张量维度x = torch.rand(2,3); x.shape → torch.Size([2,3])
.dtype数据类型x.dtype → torch.float32
.device所在设备x.device → device(type='cpu')
.requires_grad是否追踪梯度x.requires_grad = True
1.3 张量变形
函数描述示例
.view(shape)调整形状(不复制数据)x = torch.arange(6); x.view(2,3)
.reshape(shape)类似 view,但自动处理内存连续性x.reshape(3,2)
.permute(dims)调整维度顺序x = torch.rand(2,3,4); x.permute(1,2,0)
.squeeze(dim)去除大小为1的维度x = torch.rand(1,3); x.squeeze(0)shape [3]
.unsqueeze(dim)添加大小为1的维度x = torch.rand(3); x.unsqueeze(0)shape [1,3]

2. 数学运算 (Math Operations)

2.1 逐元素运算
函数描述示例
torch.add(x, y)加法torch.add(x, y)x + y
torch.mul(x, y)乘法torch.mul(x, y)x * y
torch.exp(x)指数运算torch.exp(torch.tensor([1.0]))[2.7183]
torch.log(x)自然对数torch.log(torch.exp(tensor([2.0])))[2.0]
torch.clamp(x, min, max)限制值范围torch.clamp(x, min=0, max=1)
2.2 矩阵运算
函数描述示例
torch.matmul(x, y)矩阵乘法x = torch.rand(2,3); y = torch.rand(3,4); torch.matmul(x, y)
torch.inverse(x)矩阵求逆x = torch.rand(3,3); inv_x = torch.inverse(x)
torch.eig(x)特征值分解eigenvalues, eigenvectors = torch.eig(x)
2.3 统计运算
函数描述示例
torch.sum(x, dim)沿维度求和x = torch.rand(2,3); torch.sum(x, dim=1)
torch.mean(x, dim)沿维度求均值torch.mean(x, dim=0)
torch.max(x, dim)沿维度求最大值values, indices = torch.max(x, dim=1)
torch.argmax(x, dim)最大值索引indices = torch.argmax(x, dim=1)

3. 自动求导 (Autograd)

3.1 梯度计算
函数/属性描述示例
x.backward()反向传播计算梯度x = torch.tensor(2.0, requires_grad=True); y = x**2; y.backward()
x.grad查看梯度值x.grad4.0(若 y = x²
torch.no_grad()禁用梯度追踪with torch.no_grad(): y = x * 2
detach()分离张量(不追踪梯度)y = x.detach()
3.2 梯度控制
函数描述
x.retain_grad()保留非叶子节点的梯度
torch.autograd.grad(outputs, inputs)手动计算梯度

示例

x = torch.tensor(3.0, requires_grad=True)  
y = x**3 + 2*x  
dy_dx = torch.autograd.grad(y, x)  # 返回 (torch.tensor(29.0),)  

4. 神经网络模块 (torch.nn)

4.1 层定义
描述示例
nn.Linear(in_features, out_features)全连接层layer = nn.Linear(784, 256)
nn.Conv2d(in_channels, out_channels, kernel_size)卷积层conv = nn.Conv2d(3, 16, kernel_size=3)
nn.LSTM(input_size, hidden_size)LSTM 层lstm = nn.LSTM(100, 50)
nn.Dropout(p=0.5)Dropout 层dropout = nn.Dropout(0.2)
4.2 激活函数
函数描述示例
nn.ReLU()ReLU 激活F.relu(x)nn.ReLU()(x)
nn.Sigmoid()Sigmoid 函数torch.sigmoid(x)
nn.Softmax(dim)Softmax 归一化F.softmax(x, dim=1)
4.3 损失函数
描述示例
nn.MSELoss()均方误差loss_fn = nn.MSELoss()
nn.CrossEntropyLoss()交叉熵损失loss = loss_fn(outputs, labels)
nn.BCELoss()二分类交叉熵loss_fn = nn.BCELoss()

5. 优化器 (torch.optim)

5.1 优化器定义
描述示例
optim.SGD(params, lr)随机梯度下降optimizer = optim.SGD(model.parameters(), lr=0.01)
optim.Adam(params, lr)Adam 优化器optimizer = optim.Adam(model.parameters(), lr=0.001)
optim.RMSprop(params, lr)RMSprop 优化器optimizer = optim.RMSprop(params, lr=0.01)
5.2 优化器方法
方法描述示例
optimizer.zero_grad()清空梯度optimizer.zero_grad()
optimizer.step()更新参数loss.backward(); optimizer.step()
optimizer.state_dict()获取优化器状态state = optimizer.state_dict()

6. 数据加载与处理 (torch.utils.data)

6.1 数据集类
类/函数描述示例
Dataset自定义数据集基类继承并实现 __len____getitem__
DataLoader(dataset, batch_size, shuffle)数据加载器loader = DataLoader(dataset, batch_size=64, shuffle=True)

自定义数据集示例

class MyDataset(Dataset):  
    def __init__(self, data, labels):  
        self.data = data  
        self.labels = labels  
    def __len__(self):  
        return len(self.data)  
    def __getitem__(self, idx):  
        return self.data[idx], self.labels[idx]  
6.2 数据预处理 (TorchVision)
from torchvision import transforms  

transform = transforms.Compose([  
    transforms.Resize(256),          # 调整图像大小  
    transforms.ToTensor(),           # 转为张量  
    transforms.Normalize(mean=[0.5], std=[0.5])  # 标准化  
])  

7. 设备管理 (Device Management)

7.1 设备切换
函数/方法描述示例
.to(device)移动张量/模型到设备x = x.to('cuda:0')
torch.cuda.is_available()检查 GPU 是否可用if torch.cuda.is_available(): ...
torch.cuda.empty_cache()清空 GPU 缓存torch.cuda.empty_cache()

8. 模型保存与加载

函数描述示例
torch.save(obj, path)保存对象(模型/参数)torch.save(model.state_dict(), 'model.pth')
torch.load(path)加载对象model.load_state_dict(torch.load('model.pth'))
model.state_dict()获取模型参数字典params = model.state_dict()

9. 分布式训练 (Distributed Training)

函数/类描述示例
nn.DataParallel(model)单机多卡并行model = nn.DataParallel(model)
torch.distributed.init_process_group()初始化分布式训练需配合多进程使用

10. 实用工具函数

函数描述示例
torch.cat(tensors, dim)沿维度拼接张量torch.cat([x, y], dim=0)
torch.stack(tensors, dim)堆叠张量(新建维度)torch.stack([x, y], dim=1)
torch.split(tensor, split_size, dim)分割张量chunks = torch.split(x, 2, dim=0)

常见问题与技巧

  1. GPU 内存不足

    • 使用 batch_size 较小的值
    • 启用混合精度训练 (torch.cuda.amp)
    • 使用 torch.utils.checkpoint 节省内存
  2. 梯度爆炸/消失

    • 使用梯度裁剪:torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    • 调整权重初始化方法
  3. 模型推理模式

    model.eval()  # 关闭 Dropout 和 BatchNorm 的随机性  
    with torch.no_grad():  
        outputs = model(inputs)  
    

文档说明

  • 本文档基于 PyTorch 2.5 编写,部分 API 可能不兼容旧版本。
  • 更详细的参数说明请参考 PyTorch 官方文档。

http://www.kler.cn/a/522219.html

相关文章:

  • 2025美赛数学建模MCM/ICM选题建议与分析,思路+模型+代码
  • LangChain的开发流程
  • C语言实现统计数组正负元素相关数据
  • 漏洞修复:Apache Tomcat 安全漏洞(CVE-2024-50379) | Apache Tomcat 安全漏洞(CVE-2024-52318)
  • 您与此网站之间建立的连接不安全
  • Versal - 基础3(AXI NoC 专题+仿真+QoS)
  • 【PySide6快速入门】QFileDialog 文件选择对话框
  • RAG与CAG的较量与融合
  • python接口测试:2.8 Pytest之pytest-html报告生成
  • 【Rust自学】15.6. RefCell与内部可变性:“摆脱”安全性限制
  • 计算生物学与生物信息学:一周年创作纪念
  • 系统思考—转型
  • Lucene常用的字段类型lucene检索打分原理
  • Go-并行编程新手指南
  • 【深度学习】搭建卷积神经网络并进行参数解读
  • ROS应用之SwarmSim在ROS 中的协同路径规划
  • obsidian插件——Metadata Hider
  • 软工_软件工程
  • Dest1ny漏洞库:用友 U8-CRM 系统 ajaxgetborrowdata.php 存在 SQL 注入漏洞
  • EtherCAT主站IGH-- 18 -- IGH之fsm_mbox_gateway.h/c文件解析
  • 使用Python Dotenv库管理环境变量
  • 日志收集Day008
  • 【系统架构设计师】操作系统 ① ( 知识的三种层次 - 系统知识、高频考点、试题拆解 - 软考备考策略 | 操作系统涉及的软考知识点 | 操作系统简介 )
  • 人机环境系统中的贝叶斯与非贝叶斯
  • 【算法学习笔记】36:中国剩余定理(Chinese Remainder Theorem)求解线性同余方程组
  • 06-机器学习-数据预处理