pytorch入门(2)——TensorBoard的使用
TensorBoard 是Google开发的一个机器学习可视化工具。其主要用于记录机器学习过程,例如:
- 记录损失变化、准确率变化等
- 记录图片变化、语音变化、文本变化等,例如在做GAN时,可以过一段时间记录一张生成的图片
- 绘制模型
TensorBoard 安装
pip install tensorboard
安装后,在命令行输入,若可以正常输出,则说明安装成功。
tensorboard --help
TensorBoard 运行
tensorboard --logdir=my_log
Pytorch 使用 TensorBoard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
writer.add_image()
writer.add_scalar()
writer.close()
Pytorch使用Tensorboard主要用到了三个API:
SummaryWriter:这个用来创建一个log文件,TensorBoard面板查看时,也是需要选择查看那个log文件。
add_something: 向log文件里面增添数据。例如可以通过add_scalar增添折线图数据,add_image可以增添图片。
close:当训练结束后,我们可以通过close方法结束log写入。
实例1——绘制折线图
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
# y = x
for i in range(100):
writer.add_scalar("y=x", i, i)
writer.close()
实例2——写入图像
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
writer = SummaryWriter("logs")
img_path = "hymenoptera_data/train/ants/0013035.jpg"
img_PIL = Image.open(img_path)
img_array = np.array(img_PIL)
writer.add_image("test", img_array, 1, dataformats='HWC')
writer.close()