当前位置: 首页 > article >正文

002.动手实现softmax回归(pytorch简洁版)

  1. 相关操作可复用002.从零开始实现softMax回归(pytorch)
    中的代码
import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size,root='/Users/wydi/PycharmProjects/DeepLearning_with_LiMu/datasets/FashionMnist')


batch_size = 256
num_input = 784
num_output = 10
class LineraNet(nn.Module):
    def __init__(self,num_input,num_output):
        super(LineraNet,self).__init__()
        self.linear = nn.Linear(num_input,num_output)
    def forward(self,X):
        # 数据返回的每个batch样本x的形状为(batch_size, 1, 28, 28), 所以我们先用view()将x的形状转换成(batch_size, 784)才送入全连接层。
        y =self.linear(X.view(X.shape[0]),-1)
        return y
net = LineraNet(num_input,num_output)


# 封装自定义的结构转换函数
class FlattenLayer(nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)


from collections import OrderedDict
net = nn.Sequential(
    # FlattenLayer(),
    # nn.Linear(num_inputs, num_outputs)
    OrderedDict([
        ('flatten', FlattenLayer()),
        ('linear', nn.Linear(num_input, num_output))
    ])
)

init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0)


loss = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)


num_epochs = 5
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)


epoch 1, loss 0.0031, train acc 0.748, test acc 0.781
epoch 2, loss 0.0022, train acc 0.813, test acc 0.793
epoch 3, loss 0.0021, train acc 0.825, test acc 0.819
epoch 4, loss 0.0020, train acc 0.833, test acc 0.823
epoch 5, loss 0.0019, train acc 0.837, test acc 0.821


http://www.kler.cn/a/324377.html

相关文章:

  • 在 Node.js 中解决极验验证码:使用 Puppeteer 自动化
  • # ubuntu 安装的pycharm不能输入中文的解决方法
  • vue3 element el-table实现表格动态增加/删除/编辑表格行,带有校验规则
  • Java项目实战II基于微信小程序的电子商城购物平台(开发文档+数据库+源码)
  • python:用 sklearn 构建 K-Means 聚类模型
  • C++中 ,new int(10),new int(),new int[10],new int[10]()
  • AutosarMCAL开发——基于EB MCU驱动
  • 爬虫逆向学习(八):Canvas画图滑块验证码解决思路与绕过骚操作
  • 第十四章:html和css做一个心在跳动,为你而动的表白动画
  • Maven 实现依赖统一管理
  • 树莓派外挂Camera(基操)(TODO)
  • 如何通过 GitHub Actions 使用 SSH 自动化部署到阿里云 ECS 实例
  • Hadoop三大组件之YARN(一)
  • 丹摩智算(damodel)部署stable diffusion实验
  • 计241 作业2:C程序设计初步
  • 19.3 打镜像部署到k8s中,prometheus配置采集并在grafana看图
  • 《程序猿之Redis缓存实战(1) · 基础知识》
  • 哈希知识点总结:哈希、哈希表、位图、布隆过滤器
  • 视频融合共享平台LntonAIServer视频智能分析抖动检测算法和过亮过暗检测算法
  • vue3 实现文本内容超过N行折叠并显示“...展开”组件
  • 基于Hive和Hadoop的图书分析系统
  • jdk1.6版本发送HTTPS请求,报错Could not generate DH keypair问题解决
  • Synchronized和 ReentrantLock有什么区别?
  • OFDM通信系统发射端需要做ifftshift的原因分析
  • C语言课程设计题目六:学生信息管理系统设计
  • Excel提取数据