当前位置: 首页 > article >正文

前馈神经网络 (Feedforward Neural Network, FNN)

代码功能

网络定义:
使用 torch.nn 构建了一个简单的前馈神经网络。
隐藏层使用 ReLU 激活函数,输出层使用 Sigmoid 函数(适用于二分类问题)。
数据生成:
使用经典的 XOR 问题作为数据集。
数据点为二维输入,目标为 0 或 1。
训练过程:
使用二分类交叉熵损失函数 BCELoss。
优化器为 Adam,具有较快的收敛速度。
损失可视化:
每次训练后记录损失并绘制损失曲线。
结果输出:
显示最终预测值,并与真实标签进行比较。
在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# 1. 定义前馈神经网络
class FeedforwardNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FeedforwardNN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),  # 输入层到隐藏层
            nn.ReLU(),  # 激活函数
            nn.Linear(hidden_dim, output_dim),  # 隐藏层到输出层
            nn.Sigmoid()  # 输出层的激活函数(适用于二分类问题)
        )

    def forward(self, x):
        return self.fc(x)

# 2. 创建 XOR 数据集
def create_xor_data():
    X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
    y = np.array([[0], [1], [1], [0]], dtype=np.float32)
    return X, y

# 3. 训练前馈神经网络
def train_fnn():
    # 数据准备
    X, y = create_xor_data()
    X = torch.tensor(X, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.float32)

    # 初始化网络、损失函数和优化器
    input_dim = X.shape[1]
    hidden_dim = 10
    output_dim = 1
    model = FeedforwardNN(input_dim, hidden_dim, output_dim)
    criterion = nn.BCELoss()  # 二分类交叉熵损失
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    # 训练网络
    epochs = 1000
    loss_history = []
    for epoch in range(epochs):
        # 前向传播
        outputs = model(X)
        loss = criterion(outputs, y)

        # 反向传播与优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 记录损失
        loss_history.append(loss.item())
        if (epoch + 1) % 100 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

    # 绘制损失曲线
    plt.plot(loss_history)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.show()

    # 输出训练结果
    with torch.no_grad():
        predictions = model(X).round()
        print("Predictions:", predictions.numpy())
        print("Ground Truth:", y.numpy())

# 运行训练
if __name__ == "__main__":
    train_fnn()


http://www.kler.cn/a/401612.html

相关文章:

  • C# 异常处理、多个异常、自定义异常处理
  • 【HarmonyOS】鸿蒙系统在租房项目中的项目实战(一)
  • MODBUS TCP转CANOpen网关
  • 【会话文本nlp】对话文本解析库pyconverse使用教程版本报错、模型下载等问题解决超参数调试
  • 基于yolov8、yolov5的行人检测识别系统(含UI界面、训练好的模型、Python代码、数据集)
  • 大语言模型通用能力排行榜(2024年11月8日更新)
  • 如何理解Lua 使用虚拟堆栈
  • Windows11暂停更新(超长延期)
  • html5 实现视频播放
  • 【设计模式】模板方法模式 在java中的应用
  • javaScript交互补充3(JSON数据)
  • JavaEE-多线程基础知识
  • C++ ─── 哈希表(unordered_set 和unordered_map) 开散列和闭散列的模拟实现
  • 搜维尔科技:基于Touch力反馈与VR技术的虚拟气管切开术的虚拟操作软件平台
  • CentOS 环境下通过 YUM 安装软件
  • OpenAI 助力数据分析中的模式识别与趋势预测
  • 疫情期间基于Spring Boot的图书馆管理系统
  • 基于yolov8、yolov5的行人检测识别系统(含UI界面、训练好的模型、Python代码、数据集)
  • Vue所有图片预加载加上Token请求头信息、图片请求加载鉴权
  • 小米运动健康与华为运动健康在苹手机ios系统中无法识别蓝牙状态 (如何在ios系统中开启 蓝牙 相册 定位 通知 相机等功能权限,保你有用)
  • 23种设计模式-模板方法(Template Method)设计模式
  • Unix发展历程的深度探索
  • 时代变迁对传统机器人等方向课程的巨大撕裂
  • react 使用中注意事项提要
  • CSS3中的伸缩盒模型(弹性盒子、弹性布局)之伸缩容器、伸缩项目、主轴方向、主轴换行方式、复合属性flex-flow
  • Java 多线程详解