当前位置: 首页 > article >正文

【Python】Tensorboard (Pytorch版)用法

PyTorch TensorBoard 全面指南

1. 安装 TensorBoard

如果你还没有安装 TensorBoard,可以使用 pip 安装:

pip install tensorboard

然后,在终端运行:

tensorboard --logdir=logs --port=6006

在浏览器中打开 http://localhost:6006,就可以看到可视化界面。


但是平时我们一般是服务器端进行训练,本地通过ssh连接服务器进行查看,这时候就要通过本地连接服务器的tensorboard进行查看日志了。就需要服务器端开启tensorboard映射到本地,本地查看服务器端tensorboard。做法有两步:

# 1. Windows本地端口映射至服务器(Windows端或本机操作)
# ssh -L 本地端口:本地IP:远程端口 远程服务器用户名@远程服务器IP -p 服务器连接端口
ssh -L 6006:127.0.0.1:6006 wangyaqi@servers.eezhilu.com -p 19025

# 2. 服务器开启tensorboard(服务器端操作)
tensorboard --logdir=logs
# 本地就可以浏览器访问tensorboard记录了

2. TensorBoard 基本用法

PyTorch 提供 SummaryWriter 来记录训练数据。所有日志都会存储在指定的 log_dir 目录下。
目录的定义一般有所要求,如下。

# ROOT代表当前文件的根目录
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
# 日志保存目录
save_dir = ROOT / 'runs' / Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
save_dir.mkdir(parents=True, exist_ok=True)
log_dir = save_dir

目录一般保存在 logsruns 目录中,而且我们通常会再创建包含文件名的子目录(用以表示不同的实验文件),在该子目录下再创建由运行时间代表的不同目录。

最终显示效果可能是这样的:不同的实验任务名字,每个任务下面有多个记录。
在这里插入图片描述
在上方可以看到tensorboard的各种选项卡:
在这里插入图片描述

2.1 记录标量(Scalars)

在深度学习训练过程中,我们通常会记录 lossaccuracy 等指标:

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir)

for step in range(100):
    loss = np.exp(-0.1 * step)  # 模拟 loss 下降
    writer.add_scalar("Loss/train", loss, step)

writer.close()

2.2 记录直方图(Histograms)

可以用直方图来分析模型参数的变化,例如权重分布:

import torch.nn as nn

writer = SummaryWriter(log_dir)

# 定义简单的线性模型
model = nn.Linear(10, 5)

for epoch in range(10):
    weights = model.weight.data.numpy()
    writer.add_histogram("weights", weights, epoch)

writer.close()

在 TensorBoard “Histograms” 选项卡可以看到 weights 随时间的变化。


2.3 记录图片(Images)

如果你的数据是图片(如 MNIST),可以使用 add_image 记录:

import torchvision.utils as vutils

writer = SummaryWriter(log_dir)

# 生成 16 张随机噪声图片
images = torch.randn(16, 3, 64, 64)  
img_grid = vutils.make_grid(images, normalize=True)

writer.add_image("Random Images", img_grid, global_step=0)
writer.close()

然后在 TensorBoard “Images” 选项卡查看图片。


2.4 记录文本(Text)

writer = SummaryWriter(log_dir)

writer.add_text("Example", "This is an example text.", global_step=0)
writer.close()

2.5 记录计算图(Graph)

import torch.nn.functional as F

class SimpleNN(torch.nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = torch.nn.Linear(10, 5)
    
    def forward(self, x):
        return F.relu(self.fc1(x))

model = SimpleNN()

writer = SummaryWriter(log_dir)
dummy_input = torch.randn(1, 10)
writer.add_graph(model, dummy_input)
writer.close()

3. 进阶操作

3.1 监控超参数

可以使用 TensorBoard HPARAMS 记录不同超参数的实验结果:

from torch.utils.tensorboard.summary import hparams

writer = SummaryWriter(log_dir)

hparam_dict = {"lr": 0.001, "batch_size": 32}
metrics_dict = {"accuracy": 0.85}

writer.add_hparams(hparam_dict, metrics_dict)
writer.close()

在 TensorBoard “HPARAMS” 选项卡可以比较不同超参数的影响。


3.2 记录 Embeddings(嵌入向量)

add_embedding(mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None)
  • mat:这是一个形状为 [num_samples, embedding_dim] 的张量,代表高维嵌入数据。
  • metadata:这是一个可选参数,类型为列表或者元组,用于提供每个样本的元数据(例如标签)。
  • label_img:同样是可选参数,是形状为 [num_samples, channels, height, width] 的张量,用于为每个样本提供图像。
  • global_step:也是可选参数,代表全局步数,通常用于记录训练的进度。
  • tag:同样是可选参数,是一个字符串,用于标识这个嵌入可视化。
  • metadata_header:同样是可选参数,是一个列表或者元组,用于指定元数据的列名。

可以用 TensorBoard 可视化高维数据,比如 NLP 词向量:

# 生成 100 个 64 维嵌入向量
embeddings = torch.randn(100, 64)

writer = SummaryWriter(log_dir)
writer.add_embedding(embeddings, metadata=["word" + str(i) for i in range(100)])
writer.close()

在 TensorBoard Projector 选项卡可以查看高维数据的降维投影。


http://www.kler.cn/a/611056.html

相关文章:

  • springboot+mybatisplus
  • 【git拉取冲突解决】Please move or remove them before you merge. Aborting
  • 【Unity网络编程知识】使用Socket实现简单TCP通讯
  • Walrus 经济模型 101
  • 6.1 模拟专题:LeetCode 1576. 替换所有的问号
  • linux,防火墙,firewall,常用命令
  • 基于灵动微单片机SPIN系列的两轮车解决方案
  • java8循环解压zip文件---实现Excel文件数据追加
  • Elasticsearch 之 ElasticsearchRestTemplate 普通查询
  • EMC知识学习一
  • 利用Openfeign远程调用第三方接口(案例:百度地图逆地理编码接口,实现通过经纬度坐标获取详细地址)
  • 【工具分享 - Redis桌面客户端】Tiny RDM
  • Linux:(模拟HTTP协议,GET和POST方法,Http的状态码)
  • DeepSeek概述
  • Spring Boot 整合 OpenFeign 教程
  • 游戏引擎 Unity - Unity 主要窗口(层级、场景、游戏、检查器、项目、 控制台)
  • node-ddk,electron,主进程通讯,窗口间通讯
  • 图解AUTOSAR_SWS_UDPNetworkManagement
  • 26考研——图_图的应用(6)
  • Maven工具学习使用(一)——MAVEN安装与配置