当前位置: 首页 > article >正文

重新理解一个类中的forward()和__init__()函数

forward()函数和__init__()的关系

__init__() 是一个类的构造函数,用于初始化对象的属性。它会在创建对象时自动调用,而且通常在这里完成对象所需的所有初始化操作。

forward() 是一个神经网络模型中的方法,用于定义数据流的向前传播过程。它接受输入数据,通过网络的各个层进行计算,最终返回输出结果。

在神经网络的 PyTorch 实现中,__init__() 方法通常用于实例化各个网络层(例如卷积层、池化层、全连接层的维度等【这里只是执行了初始化,但是可以通过后面实例化时调用的forward()重新给神经网络维度赋值】),并设置各层的超参数(例如卷积核大小、步幅、填充等)。而 forward() 方法则定义了这些网络层之间的计算顺序与逻辑,它负责将输入数据传递到网络中,并返回计算结果【这里输入进forward的数据维度要和forward()接收的第一个参数维度相同,虽然你看它只接受了一个参数‘x’,但是这个x的维度是多维的(在本代码中就是(input_dim, hidden_dim)两个大维度),而不是普通意义上的一个自然数

因此,两个方法通常一起使用,__init__() 用于设置网络结构和超参数,forward() 则定义了从输入到输出的完整计算流程。

例子:

定义类:

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

在上面的代码中,我们定义了一个名为 SimpleNet 的神经网络模型,它继承自 PyTorch 中的 nn.Module 类。我们在 __init__() 方法中定义了三层网络结构,分别是输入层 fc1、激活层 relu 和输出层 fc2。其中,输入层和输出层都使用了全连接层(nn.Linear),而激活层使用了 ReLU 激活函数。

forward() 方法中,我们按照输入数据 x 经过 fc1relufc2 三层的顺序进行计算,最终返回输出结果 out

调用

调用上述代码的 forward() 方法需要先创建一个 SimpleNet 类的对象,并将输入数据传递给该对象。以下是一个简单的示例:

# 创建一个 SimpleNet 对象,设置输入维度为 10,隐藏层维度为 20,输出维度为 5
net = SimpleNet(10, 20, 5)

# 构造一个随机的输入张量,大小为 [batch_size, input_dim],这里令 batch_size=1
input_tensor = torch.randn(1, 10)

# 将输入张量传入网络中,得到输出张量
output_tensor = net(input_tensor)

# 打印输出张量的形状
print(output_tensor.shape)

为什么上面的代没有看到 __init__()、forword()函数的出现就完成了上述代码的调用呢?

初始化一个类时,则自动调用了该类的 __init__() 方法【net = SimpleNet(10, 20, 5)】

调用一个类的实例时,会自动调用该类的forward() 方法【output_tensor = net(input_tensor)】


http://www.kler.cn/a/7416.html

相关文章:

  • 计算机网络之---公钥基础设施(PKI)
  • MyBatisPlus-DML编程控制
  • Muduo库源码剖析(八)——TcpServer类
  • 腾讯云轻量应用服务器价格表(2023版)
  • 前端学习:HTML基本标签
  • cgroups是linux内核中限制、记录、隔离进程组(process groups)所使用的物理资源的机制
  • 【C++从0到1】22、C++中switch语句
  • 「SQL面试题库」 No_25 统计各专业学生人数
  • 【ChatGPT】ChatGPT 能否取代程序员?
  • 英语——不定词(二)
  • 对象的比较(数据结构系列12)
  • 2023中国程序员薪酬报告出炉,你拖后腿了吗?
  • ViewBinding用法
  • mysql知识点看这一篇就够了!
  • 振动监测信号的角度域之阶次分析(1)
  • ChatGPT 存在很大的隐私问题
  • Java分布式事务(七)
  • 前端后端交互系列之原生Ajax的使用
  • C的实用笔记39——结构体占用内存大小(了解)
  • linux信号量及其实例
  • Vue+H5如何适配各个移动端?