from torchvision import datasets
""" 下载训练数据集 (包含训练数据+标签)"""
training_data = datasets.MNIST(
root='data',
train=True,
download=True,
transform=ToTensor() # 张量,图片是不能直接传入神经网络模型
) # 对于pytorch库能够识别的数据一般是tensor张量.
# NumPy 数组只能在CPU上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。
""" 下载测试数据集(包含训练图片+标签)"""
test_data = datasets.MNIST(
root='data',
train=False,
download=True,
transform=ToTensor()
)
print(len(training_data))
""" 展示手写字图片 """
# tensor --> numpy 矩阵类型的数据
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):
img, label = training_data[i + 59000] # 提取第59000张图片
figure.add_subplot(3, 3, i + 1) # 图像窗口中创建多个小窗口,小窗口用于显示图片
plt.title(label)
plt.axis("off") # 关闭坐标
plt.imshow(img.squeeze(), cmap="gray")
a = img.squeeze() # img.squeeze()从张量img中去掉维度为1的(降维)
plt.show()