当前位置: 首页 > article >正文

【深度学习】多层感知机的简洁实现

多层感知机的简洁实现

本节将介绍(通过高级API更简洁地实现多层感知机)。

import torch
from torch import nn
from d2l import torch as d2l

模型

与softmax回归的简洁实现相比,唯一的区别是我们添加了2个全连接层(之前我们只添加了1个全连接层)。
第一层是[隐藏层],它(包含256个隐藏单元,并使用了ReLU激活函数)。
第二层是输出层。

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);
  • nn.Flatten() 是 PyTorch 中 torch.nn 模块里的一个类,它的主要作用是将输入的多维张量进行扁平化处理,也就是把除了第 0 维(通常代表批量大小 batch_size)之外的其余维度合并成一个一维向量
  • torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
    参数说明:
    tensor:需要进行初始化的张量,通常是神经网络层的权重(如 m.weight)。
    mean:正态分布的均值,默认值为 0.0。
    std:正态分布的标准差,默认值为 1.0。

nn.Linear 是 PyTorch 中 torch.nn 模块里用于构建全连接层(也称为线性层)的类

[训练过程]的实现与我们实现softmax回归时完全相同,这种模块化设计使我们能够将与模型架构有关的内容独立出来。

batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

net.parameters() 是一个生成器,它会返回模型中所有需要学习的参数(可训练的张量),包括权重和偏置等

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

在这里插入图片描述


http://www.kler.cn/a/532895.html

相关文章:

  • e2studio开发RA4M2(6)----GPIO外部中断(IRQ)配置
  • 神经网络参数量和运算量的计算- 基于deepspeed库和thop库函数
  • 【C语言篇】“三子棋”
  • 将markdown文件和LaTex公式转为word
  • 并行计算、分布式计算与云计算:概念剖析与对比研究(表格对比)
  • 51c嵌入式~电路~合集25
  • 渗透测试之文件包含漏洞 超详细的文件包含漏洞文章
  • 3、参数化测试
  • 【Redis实战】Chapter01-投票后端
  • 『 C++ 』中理解回调类型在 C++ 中的使用方式。
  • Android学习20 -- 手搓App2(Gradle)
  • leetcode 1482. 制作 m 束花所需的最少天数
  • git error: invalid path
  • Redis - String相关命令
  • UE编辑器工具
  • 【自学笔记】Git的重点知识点-持续更新
  • LeetCode:392.判断子序列
  • 接口游标分页
  • 本系统旨在为用户提供一个灵活且可扩展的信息安全管理解决方案,通过插件化的开发模式,使得信息安全的维护更加高效、便捷。
  • 云原生详解:构建未来应用的架构革命
  • 996引擎-怪物:Lua 刷怪+清怪+自动拾取
  • 2025_2_4 C语言中关于free函数及悬空指针,链表的一级指针和二级指指针
  • 【Block总结】CoT,上下文Transformer注意力|即插即用
  • IIC重难点-2
  • 【JavaScript】《JavaScript高级程序设计 (第4版) 》笔记-Chapter2-HTML 中的 JavaScript
  • mysql 学习7 DCL语句,用来管理数据库用户,控制数据库的访问权限