当前位置: 首页 > article >正文

在 PyTorch 训练中使用 `tqdm` 显示进度条


在 PyTorch 训练中使用 tqdm 显示进度条

在深度学习的训练过程中,实时查看训练进度是非常重要的,它可以帮助我们更好地理解训练的效率,并及时调整模型或优化参数。使用 tqdm 库来为训练过程添加进度条是一个非常有效的方式,本文将介绍如何在 PyTorch 中结合 tqdm 来动态显示训练进度。

1. 安装 tqdm

首先,如果你还没有安装 tqdm,可以通过 pip 命令进行安装:

pip install tqdm

tqdm 是一个非常轻量级的 Python 库,它可以快速地为循环加上进度条,并提供非常友好的终端显示效果。

2. 如何在训练循环中使用 tqdm

2.1 包装 train_loader 和显示进度条

在训练过程中,最常见的任务是通过数据加载器(train_loader)来批量读取训练数据并进行前向传播、反向传播、优化等操作。为了显示训练的进度条,我们可以使用 tqdm 来包装 train_loader。这样,我们就能在每次读取数据时自动更新进度条。

示例代码:

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

# 假设模型、损失函数、优化器等已经定义好
model = MyModel().to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# 假设 train_loader 已经定义好
train_loader = ...

num_epochs = 10

# 训练循环
for epoch in range(num_epochs):
    model.train()  # 设置模型为训练模式
    running_loss = 0.0
    correct = 0
    total = 0

    # 使用 tqdm 包装 train_loader,自动显示进度条
    for batch_idx, (audio, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", ncols=100)):
        audio = audio.to(device)
        labels = labels.to(device)

        # 前向传播
        optimizer.zero_grad()
        outputs = model(audio)

        # 计算损失
        loss = criterion(outputs, labels)

        # 反向传播
        loss.backward()
        optimizer.step()

        # 更新统计信息
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # 输出每个 epoch 的总结信息
    print(f"\nEpoch {epoch+1} complete. Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

2.2 进度条的显示效果

当你使用 tqdm 包装 train_loader 时,进度条会自动显示。每个 epoch 内,tqdm 会根据 train_loader 的批次数量动态更新进度条。显示的内容通常包括:

  • 当前 epoch 和 batch 的编号
  • 已完成的批次数量
  • 当前 batch 的损失值
  • 进度百分比
  • 剩余时间估计
进度条示例:
Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████| 2043/2043 [03:30<00:00,  9.70it/s]

进度条字段说明:

  • Epoch 1/10: 表示当前是第 1 个 epoch,总共有 10 个 epoch。
  • 100%: 表示该 epoch 的训练已经完成。
  • █████████████████████████████████████████████████████████████████████: 代表进度条,显示当前训练的进度。
  • 2043/2043: 当前 epoch 中处理的 batch 数,表示已处理 2043 个 batch,总共有 2043 个。
  • 03:30<00:00: 训练已经花费的时间(03:30)和预计剩余时间(<00:00)。
  • 9.70it/s: 表示每秒处理的批次数(即训练速度)。

2.3 进度条中的其他信息

TQDM 还可以显示更多的训练相关信息,例如:

  • 损失(Loss):当前 batch 的损失值。
  • 训练速度(it/s):每秒处理的批次数。
  • 估算剩余时间:根据当前训练速度估算的剩余时间。

这些信息可以帮助你在训练过程中更好地评估模型的学习情况。

3. 总结

使用 tqdm 包装 train_loader 可以极大地提高训练过程中的可视化效果,使得我们能够实时了解训练的进展。进度条的显示不仅可以告诉我们当前 epoch 的训练进度,还能够实时更新损失、准确率等信息,帮助我们更加高效地调试和优化模型。

如果你在进行长期训练时,tqdm 显示的剩余时间可以帮助你更好地掌控时间管理。在大规模训练时,动态显示进度条也能有效提升训练的可操作性和可视化体验。

4. 常见错误:TypeError: 'module' object is not callable

解决方法看: 我的另一篇博客

5. 参考文献

  • tqdm 官方文档
  • PyTorch 官方文档

文章包含了 tqdm 的基本用法以及在 PyTorch 训练中如何集成它,应该能帮助你快速理解如何提高训练过程的可视化效果。如果你有其他细节需要补充或修改,评论区告诉我!


http://www.kler.cn/a/411471.html

相关文章:

  • JVM类加载和垃圾回收算法详解
  • 自动化的内存管理技术之垃圾回收机制-JavaScript引用数据内存回收机制
  • k8s集群增加nfs-subdir-external-provisioner存储类
  • .Net与C#
  • 文件内容扫描工具
  • 【PX4_Autopolite飞控源码】中飞控板初始化过程中的引脚IO控制(拉低/拉高)
  • PYNQ 框架 - 时钟系统 + pl_clk 时钟输出不准确问题
  • 将VSCode设置成中文语言环境
  • JAVA面试题、八股文学习之JVM篇
  • web day03 Maven基础 Junit
  • Postman设置接口关联,实现参数化
  • 【工作总结】2. 链路追踪与 APM 系统构建
  • HTTP中GET和POST的区别是什么?
  • 【排版教程】Word、WPS 分节符(奇数页等) 自动变成 分节符(下一页) 解决办法
  • 流媒体中ES流、PS流 、TS流怎么理解
  • Vscode终端出现在此系统上禁止运行脚本解决方法
  • 快速排序 归并排序
  • spring boot框架漏洞复现
  • 《白帽子讲Web安全》13-14章
  • 解决:Openstack创建实例进入控制台报错Something went wrong, connection is closed
  • 6.STM32之通信接口《精讲》之IIC通信---硬件IIC(STM32自带的硬件收发器)
  • Flink cdc同步增量数据timestamp字段相差八小时(分析|解决)不是粘贴复制的!
  • 2024APMCM亚太杯数学建模C题【宠物行业】原创论文分享
  • kali Linux中foremost安装
  • 实现乱序函数?(面试常考)
  • 计算(a+b)/c的值