小土堆学习笔记15:搭建小实战和Sequential的使用
nn.Sequential
是什么?
nn.Sequential
是 PyTorch 中的一个容器模块,用于按顺序存放多个神经网络层。
当输入数据传入 nn.Sequential
容器时,它会按顺序依次通过每个网络层,最后得到输出。这样可以简化网络的定义,尤其适用于堆叠的简单层次结构。
例如,代码中的 self.model1
使用 nn.Sequential
将多个卷积层、池化层和全连接层串联起来,使得输入数据可以自动按顺序通过这些层。
为什么需要导入 from torch.nn import Conv2d, MaxPool2d, Flatten
?
虽然 torch.nn
模块包含 Conv2d
, MaxPool2d
, Flatten
等类,但直接 import nn
后,nn
是整个模块,并不包含具体的类引用。因此,要使用这些类时需要通过 nn.Conv2d
, nn.MaxPool2d
, nn.Flatten
来引用。
为了简化代码和提高可读性,from torch.nn import Conv2d, MaxPool2d, Flatten
可以直接将这几个类导入,这样在代码中直接使用 Conv2d
、MaxPool2d
和 Flatten
即可,而无需加 nn.
前缀。