PyTorch入门教学——torchvision中数据集的使用
1、torchvision.datasets
- datasets是torchvision工具集中的一个工具。
- 可以理解为调用官方数据集的一种方式,其中有很多开源的数据集,可供我们学习使用。
- datasets官网:Datasets — Torchvision 0.16 documentation (pytorch.org)
2、使用
- 这里以使用CIFAR10中的数据为例。
- 其中有这个数据集的使用方法和具体介绍。
- 参数:(每个数据集的参数大致相同)
- root:数据集下载后存放的目录。
- train:如果为True,则从训练集创建数据集,否则从测试集创建。
- transform:接收PIL图像的转换方式,并返回转换后的版本。
- download:如果为True,则从互联网下载数据集,然后将其放在设置的目录中。如果数据集已下载,则不会再次下载。
- 代码演示——查看数据集中图片的信息
-
import torchvision train_set = torchvision.datasets.CIFAR10(root="./Dataset/CIFAR10", train=True, download=True) # root:数据集要存放在什么位置 test_set = torchvision.datasets.CIFAR10(root="./Dataset/CIFAR10", train=False, download=True) print(test_set[0]) # 第一张图片的信息,包含格式和标签 print(test_set.classes) # 数据集中所包含的图片类别 img, target = test_set[0] print(img) print(target) # 标签 print(test_set.classes[target]) # 第一张图片的标签为猫 img.show() # 显示图片
-
- 代码演示——将数据集中的前10张图片在tensorboard中展示出来。
-
import torchvision from torch.utils.tensorboard import SummaryWriter test_set = torchvision.datasets.CIFAR10( root="./Dataset/CIFAR10", transform=torchvision.transforms.ToTensor(), # 将图片转换为totensor数据类型 train=False, download=True) writer = SummaryWriter('logs') # writer把summary内容写在哪个目录下 for i in range(10): img, target = test_set[i] writer.add_image('test_set', img, i) writer.close()
- 运行程序后,打开终端,输入下列命令打开tensorboard。
-
tensorboard --logdir=logs --port=6007
- (该数据集的图片像素为32*32,所以比较模糊)
-