当前位置: 首页 > article >正文

动手学深度学习(四)---多层感知机

文章目录

  • 一、理论知识
    • 1.感知机
    • 2.XOR问题
    • 3.多层感知机
    • 4.多层感知机的从零开始实现
  • 【相关总结】
    • 1.torch.randn()
    • 2.torch.zeros_like()

一、理论知识

1.感知机

给定输入x,权重w,和偏移b,感知机输出:
在这里插入图片描述
在这里插入图片描述

2.XOR问题

感知机不能拟合XOR问题,他只能产生线性分割面
在这里插入图片描述

3.多层感知机

多层感知机和softmax没有本质区别,只是多加了一层隐藏层 没有隐藏层就是softmax回归,加上隐藏层就是多层感知机

4.多层感知机的从零开始实现

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

2.实现一个具有单隐藏层的多层感知机,他包含256个隐藏单元

num_inputs, num_outputs, num_hiddens = 784, 10, 256
# 28 * 28

# 声明是torch的Parameter
W1 = nn.Parameter(
#     生成随机数字的tensor
    torch.randn(num_inputs, num_hiddens, requires_grad=True))
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad = True))
W2 = nn.Parameter(
    torch.randn(num_hiddens, num_outputs, requires_grad=T rue))
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))

params = [W1, b1, W2, b2]

【相关总结】

1.torch.randn()

生成随机数字的tensor
这些随机数字满足标准正态分布
torch.randn(size) size可以为一个数字或者一个元组

import torch
x = torch.randn(3)
y = torch.randn(2,3)
print(x)
print(y)

tensor([-0.1201, -1.0340, 0.7885])
tensor([[-0.5694, 0.0461, 1.0315],
[-1.0342, -0.9757, -0.1844]])

2.torch.zeros_like()

torch.zeros_like(input, dtype=None, layout=None, device=None, requires_grad=False)
返回一个与给定输入张量形状和数据类型相同,但所有元素都被设置为零的新张量。

import torch

x = torch.tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
y = torch.zeros_like(x)
print(y)

tensor([[0, 0, 0],
[0, 0, 0],
[0, 0, 0]])


http://www.kler.cn/a/145732.html

相关文章:

  • QT:QTabWidget设置tabPosition为West时,文字向上
  • 梯度提升决策树树(GBDT)公式推导
  • qml Timer详解
  • C#与AI的共同发展
  • 【JVM-9】Java性能调优利器:jmap工具使用指南与应用案例
  • CTTSHOW-WEB入门-爆破25-28
  • 【蓝桥杯】刷题
  • 卷积神经网络经典backbone
  • 使用Selenium、Python和图鉴打码平台实现B站登录
  • 让代码变美的第三天 - 简单工厂模式
  • 27、Nuxt.js项目整合ElementUI组件库
  • 【线性代数与矩阵论】坐标变换与相似矩阵
  • HTML的学习
  • kafka的设计原理
  • FO-like Transformation
  • [ruby on rails] array、jsonb字段
  • Java 文件常用操作与流转换
  • 单细胞seurat入门—— 从原始数据到表达矩阵
  • 隐写-MISC-bugku-解题步骤
  • QXDM Filter使用指南
  • P17C++析构函数
  • java - 定时器
  • 机器学习【04重要】pycharm中关闭jupyter服务器
  • 交叉编译 和 软硬链接 的初识(面试重点)
  • 【面经八股】搜广推方向:常见面试题(五)
  • 流量主如何在广告收益和用户体验中找到平衡