Pytorch学习笔记(四)Learn the Basics - Transforms
这篇博客瞄准的是 pytorch 官方教程中 Learn the Basics 章节的 Transforms 部分。
- 官网链接:https://pytorch.org/tutorials/beginner/basics/transforms_tutorial.html
完整网盘链接: https://pan.baidu.com/s/1L9PVZ-KRDGVER-AJnXOvlQ?pwd=aa2m 提取码: aa2m
Transforms
由于大部分数据并不是以训练机器学习算法所需的最终处理形式出现,所以需要使用 Transforms 来对数据进行一些操作以适配机器学习。
所有 TorchVision 数据集都有两个参数: -transform
用于修改特征、target_transform
用于修改标签。torchvision.transforms
模块提供了几种常用的开箱即用转换。
FashionMNIST
特征采用 PIL
图像格式,label 为整数。为了训练需要将特征作为归一化 Tensor,使用 ToTensor 和 Lambda将label作为one-hot编码的Tensor。
- 导入数据并用Lambda将数据转换为Tensor
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(
lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
)
)