当前位置: 首页 > article >正文

Pytorch构建网络模型结构都有哪些方式

目录

前言

1.使用nn.Module基类

2.使用nn.Sequential容器

3. 使用nn.ModuleList

4. 使用nn.ModuleDict

5. 混合使用nn.Module和原生Python代码

6.表格总结


前言

  • nn.Module:最通用、最灵活的方式,适用于几乎所有场景。
  • nn.Sequential:适合简单的顺序模型,代码简洁。
  • nn.ModuleListnn.ModuleDict:适合需要动态调整层的模型,方便子模块的管理和访问。
  • 混合使用原生Python代码:适合需要动态逻辑或复杂决策的网络模型。

这些方式可以根据具体项目需求进行选择,通常,nn.Module是最常用的方式,它能够满足几乎所有的模型设计需求。

1.使用nn.Module基类

这是最常用的方法之一。你可以通过继承nn.Module基类来定义自己的神经网络。nn.Module提供了神经网络层的封装以及模型参数的管理。

示例:

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleNet()

详细步骤:

  • __init__方法:在这里定义网络层。self.fc1 = nn.Linear(784, 128) 表示创建了一个输入大小为784、输出大小为128的全连接层。
  • forward方法:定义了数据的前向传播方式。输入数据依次通过定义的各个层,最后得到输出。

优点:灵活,适合复杂网络。

2.使用nn.Sequential容器

如果你的模型是一个简单的顺序网络(即各层按顺序逐个执行,没有复杂的网络结构),可以使用nn.Sequential来简化代码。

示例:

import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

详细步骤:

  • nn.Sequential接受一系列的层作为参数,并按顺序逐个应用于输入数据。
  • 各个层之间的前向传播方式自动处理,减少了手动编写forward方法的工作。

优点:简洁,适合简单的线性模型。

3. 使用nn.ModuleList

nn.ModuleList可以用来存储一个nn.Module的列表,但不会定义网络的前向传播逻辑,需要在forward方法中手动实现。

class CustomNet(nn.Module):
    def __init__(self):
        super(CustomNet, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(100, 100) for i in range(5)])
        self.relu = nn.ReLU()

    def forward(self, x):
        for layer in self.layers:
            x = self.relu(layer(x))
        return x

model = CustomNet()

详细步骤:

  • 使用nn.ModuleList存储多个相同或不同的层。
  • forward方法中循环这些层,自定义前向传播逻辑。

优点:适合需要灵活定义多个子层的网络结构。

4. 使用nn.ModuleDict

nn.ModuleDictnn.ModuleList类似,但它以字典的形式存储模块,允许通过键值对的方式来访问不同的子模块。

class CustomNet(nn.Module):
    def __init__(self):
        super(CustomNet, self).__init__()
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(784, 128),
            'fc2': nn.Linear(128, 64),
            'fc3': nn.Linear(64, 10)
        })
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.layers['fc1'](x))
        x = self.relu(self.layers['fc2'](x))
        x = self.layers['fc3'](x)
        return x

model = CustomNet()

详细步骤:

  • 使用nn.ModuleDict来存储模块,可以通过键值访问。
  • 灵活构建前向传播路径,适合需要不同路径的网络结构。

优点:适合需要动态访问或选择子模块的网络。

5. 混合使用nn.Module和原生Python代码

在某些情况下,你可能需要在模型中嵌入一些动态的逻辑。此时,可以将nn.Module与原生Python控制流(如if-elsefor循环等)结合使用,构建更加复杂的模型。

class DynamicNet(nn.Module):
    def __init__(self):
        super(DynamicNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        if x.mean() > 0.5:
            x = self.fc1(x)
        else:
            x = self.fc2(x)
        return self.relu(x)

model = DynamicNet()

详细步骤:

  • forward中使用Python原生的控制流来决定前向传播路径。
  • 这种方式非常灵活,适合复杂的模型逻辑需求。

优点:灵活且强大,适合复杂模型。

6.表格总结


http://www.kler.cn/a/281835.html

相关文章:

  • Matplotlib | 理解直方图中bins表示的数据含义
  • 散户持股增厚工具:智能T0算法交易
  • java基础知识全集(一篇看到爽)(持续更新中)
  • 退款成功订阅消息点击后提示订单不存在
  • 5G的SUCI、SUPI、5G-GUTI使用场景及关系
  • 基于普中51单片机开发板的电子门铃设计( proteus仿真+程序+设计报告+讲解视频)
  • 买了服务器后如何正确挂载数据盘|什么是系统盘,什么是数据盘
  • 33.鼠标悬停时的波浪线效果 CSS 重置
  • FLUX 1 将像 Stable Diffusion 一样完整支持ControlNet组件
  • JavaScript异步编程中的常见陷阱与解决方案
  • YASKAWA机器人维修操作命令攻略-移动命令运用案例
  • jupyter notebook修改默认浏览器(改chrome)
  • 深度学习学习经验——长短期记忆网络(LSTM)
  • 爆改YOLOv8 | yolov8添加MSDA注意力机制
  • 代码随想录算法训练营第五十五天 | 图论part05
  • 怎么整合spring security和JWT
  • eclipse 配置 ABAP 连接操作手册
  • 北摩高科半年度军航民航双突破,技术创新引领行业发展
  • Ubuntu 22.04上稳定安装与配置搜狗输入法详细教程
  • C#入门篇5
  • 千益畅行,旅游卡,案例分享
  • Android UI绘制原理:UI的绘制流程是怎么样呢?为什么子线程不能刷新UI呢?讲解大体的流程是怎么样的
  • Flutter中的Key
  • 人工智能 | AutoGPT理念与应用
  • Android 获取安装包的签名,获取签名文件的MD5值
  • 测开新手:pytest+requests+allure自动化测试接入Jenkins学习