当前位置: 首页 > article >正文

神经网络多层感知器异或问题求解-学习篇

多层感知器可以解决单层感知器无法解决的异或问题
首先给了四个输入样本,输入样本和位置信息如下所示,现在要学习一个模型,在二维空间中把两个样本分开,输入数据是个矩阵,矩阵中有四个样本,样本的维度是三维,三个维度分别表示偏置,x坐标,y坐标。对应的标签Y是 区分矩阵中的样本,例如:[1,0,0]对应0,[1,0,1]对应1,以此类推。使用单层感知器无法解决此异或问题,怎么样使用多层感知器求解这个问题?
在这里插入图片描述

#输入数据 各维度表示偏置,x坐标,y坐标
X = np.array([[1,0,0],
              [1,0,1],
              [1,1,0],
              [1,1,1]])
#标签
Y = np.array([[0,1,1,0]])
#第一个网络层参数矩阵,初始化输入层权值,取值范围-1 到 1
V = (np.random.random((3,4)) - 0.5) * 2
#第二个网络层参数矩阵,初始化输入层权值,取值范围-1 到 1
W = (np.random.random((4,1)) - 0.5) * 2
使用误差反向传播算法原理

算法原理:
误差反向传播算法的基本思想是通过两个过程来实现神经网络的训练:信号的正向传播与误差的反向传播。

正向传播: 输入一个训练样本,通过神经网络的前向传播计算出输出结果。具体来说,从输入层开始,计算每一层节点的输出值,直到得到网络最终的输出结果。每一层节点的输出都是基于前一层的输出和当前层的权重、偏置计算得到的。
误差反向传播: 将输出结果与实际结果进行比较,计算出误差。然后,将误差从输出层向输入层反向传播,计算每个节点对误差的贡献,并根据贡献值调整每个节点的权重和偏置。这个过程是通过链式法则实现的,即利用误差的梯度信息来逐层调整权重和偏置。

算法步骤:
初始化: 随机初始化神经网络的权重和偏置。
输入训练样本对: 将训练样本输入到神经网络的输入层。
前向传播: 根据当前的权重和偏置,计算每一层节点的输出值,直到得到网络最终的输出结果。
计算网络输出误差: 将输出结果与实际结果进行比较,计算出误差。误差的计算通常使用损失函数,如均方误差、交叉熵损失等。
反向传播: 根据误差计算每个节点对误差的贡献,并将此贡献值反向传播回去。具体来说,对于每个输出节点,计算其误差,然后将此误差沿着连接线进行反向传播,直到到达输入层的节点。
调整权重和偏置: 根据误差贡献值,对每个权重和偏置进行调整。调整的方法通常是使用梯度下降算法,即根据误差的梯度信息来更新权重和偏置。
检查网络总误差: 检查网络的总误差是否达到精度要求。如果满足,则训练结束;如果不满足,则返回步骤2,继续训练过程。

接下来我们进行一下公式迭代:

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
根据上面的计算,我们就可以求得Wji的参数

下面使用Python解决多层感知器异或问题

import numpy as np
import matplotlib.pyplot as plt

#输入数据
X = np.array([[1,0,0],
              [1,0,1],
              [1,1,0],
              [1,1,1]])
#标签
Y = np.array([[0,1,1,0]])
#第一个网络层参数矩阵,初始化输入层权值,取值范围-1 到 1
V = (np.random.random((3,4)) - 0.5) * 2
#第二个网络层参数矩阵,初始化输入层权值,取值范围-1 到 1
W = (np.random.random((4,1)) - 0.5) * 2

def get_show():
    #正样本
    all_positive_x = [0,1]
    all_positive_y = [0,1]
    #负样本
    all_negative_x = [0,1]
    all_negative_y = [1,0]

    plt.figure()
    plt.plot(all_positive_x, all_positive_y,'bo')
    plt.plot(all_negative_x, all_negative_y,'yo')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()

#get_show()

lr = 0.11  #学习率
#激活函数(从0-1)
def sigmoid(x):
    return 1/(1+np.exp(-x))
#激活函数的导数
def dsigmoid(x):
    x = x*(1-x)
    return x
#更新权值(2个权值矩阵,V和W)
def update():
    global X,Y,V,W,lr
    L1 = sigmoid(np.dot(X,V))  #隐藏层输出(4*3)x(3*4)=(4,4)
    L2 = sigmoid(np.dot(L1,W))  #输出层输出(4,4)x(4*1)=(4,1)
    L2_delta = (Y.T - L2) * dsigmoid(L2) #输出层的误差=下一层的误差*激活函数导数*与下一层的连接权重矩阵(全为1)
    L1_delta = L2_delta.dot(W.T)*dsigmoid(L1) #隐藏层的误差=下一层的误差*激活函数导数*与下一层的连接权重矩阵
    W_C = lr*L1.T.dot(L2_delta)
    V_C = lr*X.T.dot(L1_delta)
    W = W + W_C #对W矩阵的参数更新 模型的学习
    V = V + V_C #对V矩阵的参数更新

errors = []  #记录误差
for i in range(100000):
    update() #更新权值
    if i % 1000 == 0: #输出误差
        L1 = sigmoid(np.dot(X,V))
        L2 = sigmoid(np.dot(L1,W))
        errors.append(np.mean(np.abs(Y.T-L2)))
        print("Error:",np.mean(np.abs(Y.T-L2)))
plt.plot(errors)
plt.ylabel('Errors')
plt.show()

L1 = sigmoid(np.dot(X, V))  # 隐藏层输出(4*3)x(3*4)=(4,4)
L2 = sigmoid(np.dot(L1, W))  # 输出层输出(4,4)x(4*1)=(4,1)
print(L2) #第二层的结果 概率矩阵 》0.5是一类,小于05是一类

def classify(x):
    if x > 0.5:
        return 1
    else:
        return 0

for i in map(classify, L2): #L2一共4个数
    print(i)

运行结果为:
在这里插入图片描述

#是与 标签Y = np.array([[0,1,1,0]]) 趋向一致的
[[0.01011728]
 [0.98925078]
 [0.99013233]
 [0.01323669]]
0
1
1
0

http://www.kler.cn/news/303476.html

相关文章:

  • mysql数据库如何开启binlog日志
  • cesium.js 入门到精通(7)
  • 修改centos7系统语言en_US.UTF-8为中文zh_CN.UTF-8
  • 高防服务器的优势与劣势分析
  • 【LLM:Fan】
  • 踩坑记:Poco库,MySql,解析大文本的bug
  • 递归、排序、二分查找(C语言实现)
  • mybatis与concat实现模糊查询、mybatis中模糊查询concat传入参数为空时的解决方法
  • nacos安装使用调优及面试题分享
  • Apple发布会都有哪些亮点?如何在苹果手机和电脑上录制屏幕?
  • MATLAB默认工作路径修改
  • 串口通信数据包介绍和包结构定义实例
  • 【Echarts】vue3打开echarts的正确方式
  • real, dimension(3) :: rho1 和 real :: rho1(3) 的区别
  • C++学习笔记----7、使用类与对象获得高性能(一)---- 书写类(1)
  • element表格合并列数据相同合并单元格
  • 【Flutter 面试题】 无需上下文进行路由跳转原理是怎么样的
  • Python用MarkovRNN马尔可夫递归神经网络建模序列数据t-SNE可视化研究
  • 医疗报销|基于springboot的医疗报销系统设计与实现(附项目源码+论文+数据库)
  • RocketMQ 集群搭建详细指南
  • F12抓包10:UI自动化 - Elements(元素)定位页面元素
  • 【devops】devops-git之git分支与标签使用
  • Kubernetes 容器与镜像管理
  • 五、Django 路由配置
  • 如何编写ChatGPT提示词
  • LabVIEW中EPICS客户端/服务端的测试
  • 数据库系统概论(3,4)
  • 【网络安全】漏洞挖掘之会话管理缺陷
  • Layout 布局组件快速搭建
  • 如何建设数据中台(五)——数据汇集—打破企业数据孤岛