PyTorch快速入门教程【小土堆】之torchvision中的数据集使用
视频地址torchvision中的数据集使用_哔哩哔哩_bilibili
本次选取了CIFAR10数据集作为演示
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor() # 转换为tensor型
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform,
download=True) # root为数据集放置位置,train为true说明训练集,transform为上边定义的dataset——transform,需要下载
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform,
download=True) # root为数据集放置位置,train为false说明测试集,transform为上边定义的dataset——transform,需要下载
# print(test_set[0])
# print(test_set.classes) # 类别为['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
#
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target]) # 为cat
# img.show()
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i) # 把tensor类型的10张图片用tensorboard表示出来
writer.close()