当前位置: 首页 > article >正文

002.动手实现softmax回归(pytorch简洁版)

  1. 相关操作可复用002.从零开始实现softMax回归(pytorch)
    中的代码
import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size,root='/Users/wydi/PycharmProjects/DeepLearning_with_LiMu/datasets/FashionMnist')


batch_size = 256
num_input = 784
num_output = 10
class LineraNet(nn.Module):
    def __init__(self,num_input,num_output):
        super(LineraNet,self).__init__()
        self.linear = nn.Linear(num_input,num_output)
    def forward(self,X):
        # 数据返回的每个batch样本x的形状为(batch_size, 1, 28, 28), 所以我们先用view()将x的形状转换成(batch_size, 784)才送入全连接层。
        y =self.linear(X.view(X.shape[0]),-1)
        return y
net = LineraNet(num_input,num_output)


# 封装自定义的结构转换函数
class FlattenLayer(nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)


from collections import OrderedDict
net = nn.Sequential(
    # FlattenLayer(),
    # nn.Linear(num_inputs, num_outputs)
    OrderedDict([
        ('flatten', FlattenLayer()),
        ('linear', nn.Linear(num_input, num_output))
    ])
)

init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0)


loss = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)


num_epochs = 5
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)


epoch 1, loss 0.0031, train acc 0.748, test acc 0.781
epoch 2, loss 0.0022, train acc 0.813, test acc 0.793
epoch 3, loss 0.0021, train acc 0.825, test acc 0.819
epoch 4, loss 0.0020, train acc 0.833, test acc 0.823
epoch 5, loss 0.0019, train acc 0.837, test acc 0.821


http://www.kler.cn/news/324377.html

相关文章:

  • AutosarMCAL开发——基于EB MCU驱动
  • 爬虫逆向学习(八):Canvas画图滑块验证码解决思路与绕过骚操作
  • 第十四章:html和css做一个心在跳动,为你而动的表白动画
  • Maven 实现依赖统一管理
  • 树莓派外挂Camera(基操)(TODO)
  • 如何通过 GitHub Actions 使用 SSH 自动化部署到阿里云 ECS 实例
  • Hadoop三大组件之YARN(一)
  • 丹摩智算(damodel)部署stable diffusion实验
  • 计241 作业2:C程序设计初步
  • 19.3 打镜像部署到k8s中,prometheus配置采集并在grafana看图
  • 《程序猿之Redis缓存实战(1) · 基础知识》
  • 哈希知识点总结:哈希、哈希表、位图、布隆过滤器
  • 视频融合共享平台LntonAIServer视频智能分析抖动检测算法和过亮过暗检测算法
  • vue3 实现文本内容超过N行折叠并显示“...展开”组件
  • 基于Hive和Hadoop的图书分析系统
  • jdk1.6版本发送HTTPS请求,报错Could not generate DH keypair问题解决
  • Synchronized和 ReentrantLock有什么区别?
  • OFDM通信系统发射端需要做ifftshift的原因分析
  • C语言课程设计题目六:学生信息管理系统设计
  • Excel提取数据
  • FPGA IP 和 开源 HDL 一般去哪找?
  • Linux基础命令zip,unzip详解
  • 【ESP32】Arduino开发 | I2C控制器+I2C主从收发例程
  • 2024华为OD机试E卷-构成正方形的数量-(C++/Java/Python)
  • Redis 实现分布式锁时需要考虑的问题
  • 使用 Frida Hook Android App
  • Linux文件IO(十一)-复制文件描述符与截断文件
  • 大数据复习知识点2
  • Deep Learning for Video Anomaly Detection: A Review 深度学习视频异常检测综述阅读
  • flink设置保存点和恢复保存点