torchvision数据集使用
文章目录
- 一、下载torchvision中的数据集文件
- 二、断点知识点
- 三、数据集形式建立
- 四、展示数据集中的图片
一、下载torchvision中的数据集文件
这段代码是使用PyTorch的torchvision库来加载CIFAR-10数据集。
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
root指在什么位置,train的True表示创建一个训练集,为False表示创建一个测试集,download为True的话则是直接下载
./dataset表示的是相对路径,把数据保存进名为Dataset中
回车后得到:
复制蓝色的链接,还可以在迅雷中添加链接进行下载,这样子下载的速度可能相对较快。
在下载好了文件包,可以在pycharm文件中见到Dataset文件包
二、断点知识点
- 只要我们在代码行的最左侧点击一下鼠标左键,就完成设置断点
- 设置好断点后我们可以进入调试模式
- 调试模式不同于运行模式,如果进行代码运行那么断点就依然忽略不计
鼠标右键点击一下,可以看到一只瓢虫,点击就是进行调试
- 调试的话就会将代码运行到断点就不运行了,同时在下面可以看到具体数据内容
调试的具体用法:
再来看看更重要的横排按钮:
1.跳转到当前断点(断点后你为了查看逻辑可能去了其他文件或行,点这个就能回到当前断点的行)
2.step over(F8快捷键):在当前层代码单步执行。
3.step into(F7快捷键):单步执行,但会进入子函数。如果一直按F7,则会一层层一直进入。
4.step into my code(Alt+Shift+F7快捷键):单步执行,只进入自己代码的子函数,不会进入导入包的子函数。
三、数据集形式建立
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
#root指在什么位置,train的True表示创建一个训练集,为False表示创建一个测试集
#download为True的话则是直接下载
#./dataset表示的是相对路径,把数据保存进去
#把转换成totensor格式的transform对每张照片进行处理
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
#打印测试集中的第一个数据项
print(test_set[0])
#打印测试集中所有的类别名称
print(test_set.classes)
# 提取测试集中的第一个数据项,img是图像数据,target是图像对应的类别索引。
img, target = test_set[0]
# 打印提取出的图像数据。
print(img)
# 打印提取出的类别索引。
print(target)
运行结果:
最下面的 3 表示类名classes的第三项,也就是[‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’] 中列表的第三项 “cat”
四、展示数据集中的图片
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
writer = SummaryWriter("p10")
# 循环遍历测试集中的前10张图像。
for i in range(10):
# 提取图像和对应的标签。
img, target = test_set[i]
# 使用SummaryWriter的add_image方法将图像写入TensorBoard日志。
# "test_set"是图像的标签,img是图像数据,i是图像的索引。
writer.add_image("test_set",img,i)
writer.close()
在Terminal终端中输入:tensorboard --logdir="p10"
运行结果: