当前位置: 首页 > article >正文

小波神经网络:结合小波变换与神经网络的力量(附Pytorch代码实现)

小波神经网络:结合小波变换与神经网络的力量

引言

在人工智能和机器学习的领域中,小波神经网络(WNN)是一种结合了小波变换的多分辨率分析能力和神经网络的学习能力的混合模型。这种网络不仅能够捕捉数据的局部特征,还能够通过学习过程来逼近复杂的函数和模式。本文将深入探讨小波神经网络的工作原理、特点以及实现方法。

什么是小波神经网络?

小波神经网络是一种前馈神经网络,其隐藏层的激活函数由小波基函数导出。这种网络结构与多层感知器(MLP)相似,但关键在于其隐藏层的神经元,这些神经元被称为wavelons,它们的激活函数是基于小波变换的。

构建小波神经网络主要有两种方法:

  1. 分离方法:在这种方法中,小波变换和小波神经网络的处理是分开进行的。输入信号首先在隐藏层中被分解,然后小波系数用于调整输入的突触权重。

  2. 集成方法:在这种方法中,小波的平移和缩放与输入权重一起根据学习算法进行调整。

特点

小波神经网络的一个显著特点是其在隐藏层和输出层中没有极化或阈值,这使得网络能够更加灵活地逼近复杂的函数。

方程

前向计算

前向计算涉及到网络的输入、权重和激活函数。对于每个隐藏层神经元,其净输入计算如下:

n e t j ( n ) = ∑ i = 1 N i x i ( n ) ⋅ t j i net_j(n) = \sum_{i = 1}^{N_i} x_i(n) \cdot t_{ji} netj(n)=i=1Nixi(n)tji

其中 x i ( n ) x_i(n) xi(n) 是输入, t j i t_{ji} tji 是权重。激活函数是关于其参数 θ \theta θ 的sigmoid函数的二阶导数,其中 θ = n e t j − t j i r j \theta = \frac{net_j - t_{ji}}{r_j} θ=rjnetjtji

y j = φ ( n e t j ) = d 2 d θ 2 s i g ( θ ) y_j = \varphi(net_j) = \frac{d^2}{d\theta^2} sig(\theta) yj=φ(netj)=dθ2d2sig(θ)

最终输出 z z z 是隐藏层输出的线性组合:

z = ∑ j = 1 N h a j ( n ) y j ( n ) z = \sum_{j = 1}^{N_h} a_j(n)y_j(n) z=j=1Nhaj(n)yj(n)

反向计算

反向计算涉及到误差的传播和权重的调整。误差 E ( n ) E(n) E(n) 定义为:

E ( n ) = 1 2 e ( n ) 2 E(n) = \frac{1}{2} e(n)^2 E(n)=21e(n)2

e ( n ) = d ( n ) − z ( n ) e(n) = d(n) - z(n) e(n)=d(n)z(n)

自由变量的调整

权重 a j a_j aj、中心 t j i t_{ji} tji 和宽度(或缩放) r j r_j rj 的调整公式如下:

权重 a j a_j aj 的调整

Δ a j = − η ∂ E ∂ a j \Delta a_j = -\eta \frac{\partial E}{\partial a_j} Δaj=ηajE

∂ E ∂ a j = e ⋅ ( − 1 ) ⋅ y j \frac{\partial E}{\partial a_j} = e \cdot (-1) \cdot y_j ajE=e(1)yj

因此,

Δ a j = η ⋅ e ⋅ d 2 d θ 2 s i g ( θ ) \Delta a_j = \eta \cdot e \cdot \frac{d^2}{d\theta^2} sig(\theta) Δaj=ηedθ2d2sig(θ)

中心 t j i t_{ji} tji 的调整

Δ t j i = − η ∂ E ∂ t j i \Delta t_{ji} = -\eta \frac{\partial E}{\partial t_{ji}} Δtji=ηtjiE

∂ E ∂ t j i = e ⋅ ( − 1 ) ⋅ a j ⋅ [ d 3 d θ 3 s i g ( θ ) ⋅ ( − 1 r j ) ] \frac{\partial E}{\partial t_{ji}} = e \cdot (-1) \cdot a_j \cdot \left[\frac{d^3}{d\theta^3}sig(\theta)\cdot\left(\frac{-1}{r_j}\right)\right] tjiE=e(1)aj[dθ3d3sig(θ)(rj1)]

因此,

Δ t j i = − η e ⋅ a j r j ⋅ d 3 d θ 3 s i g ( θ ) \Delta t_{ji} = -\eta\frac{e\cdot a_j}{r_j}\cdotp \frac{d^3}{d\theta^3}sig(\theta) Δtji=ηrjeajdθ3d3sig(θ)

宽度(或缩放) r j r_j rj 的调整

Δ r j = − η ∂ E ∂ r j \Delta r_{j} = -\eta \frac{\partial E}{\partial r_{j}} Δrj=ηrjE

∂ E ∂ r j = e ⋅ ( − 1 ) ⋅ a j ⋅ [ d 3 d θ 3 s i g ( θ ) ⋅ ( − ( n e t j − t j i ) r j 2 ) ] \frac{\partial E}{\partial r_{j}} = e \cdot (-1) \cdot a_j \cdot \left[\frac{d^3}{d\theta^3}sig(\theta)\cdot\left(-\frac{(net_j - t_{ji})}{r_j^2}\right)\right] rjE=e(1)aj[dθ3d3sig(θ)(rj2(netjtji))]

因此,

Δ r j i = − η e ⋅ a j ⋅ ( n e t j − t j i ) r j 2 ⋅ d 3 d θ 3 s i g ( θ ) \Delta r_{ji} = -\eta\frac{e\cdot a_j \cdot (net_j - t_{ji})}{r_j^2}\cdotp \frac{d^3}{d\theta^3}sig(\theta) Δrji=ηrj2eaj(netjtji)dθ3d3sig(θ)

代码实现

以下是小波神经网络的一个简单实现,使用Python语言。这个实现包括了网络的初始化、训练和绘图功能。数值结果如下图
在这里插入图片描述

import matplotlib.pyplot as plt  # 用于绘制图形
import numpy as np  # 用于处理数组
from math import sqrt, pi  # 从math库导入平方根和圆周率

class WNN(object):
    def __init__(self, eta=0.008, epoch_max=50000, Ni=1, Nh=40, Ns=1):
        """
        初始化WNN类的实例。
        
        :param eta: 学习率
        :param epoch_max: 最大训练周期数
        :param Ni: 输入层神经元数量
        :param Nh: 隐藏层神经元数量
        :param Ns: 输出层神经元数量
        """
        # 初始化参数
        self.eta = eta
        self.epoch_max = epoch_max
        self.Ni = Ni
        self.Nh = Nh
        self.Ns = Ns
        self.Aini = 0.01  # 初始化权重的放大因子

    def load_first_function(self):
        """
        加载第一个函数的数据。
        """
        x = np.arange(-6, 6, 0.15)  # 生成x值数组,步长为0.15
        self.N = x.shape[0]  # 样本数量
        xmax = np.max(x)  # x的最大值

        self.X_train = x / xmax  # 归一化x值
        self.d = 1 / (1 + np.exp(-1 * x)) * (np.cos(x) - np.sin(x))  # 目标函数值

    def sig_dev2(self, theta):
        """
        计算sigmoid函数的二阶导数。
        
        :param theta: sigmoid函数的参数
        :return: 二阶导数值
        """
        return 2 * (1 / (1 + np.exp(-theta)))**3 - 3 * (1 / (1 + np.exp(-theta)))**2 + (1 / (1 + np.exp(-theta)))

    def sig_dev3(self, theta):
        """
        计算sigmoid函数的三阶导数。
        
        :param theta: sigmoid函数的参数
        :return: 三阶导数值
        """
        return -6 * (1 / (1 + np.exp(-theta)))**4 + 12 * (1 / (1 + np.exp(-theta)))**3 - 7 * (1 / (1 + np.exp(-theta)))**2 + (1 / (1 + np.exp(-theta)))

    def sig_dev4(self, theta):
        """
        计算sigmoid函数的四阶导数。
        
        :param theta: sigmoid函数的参数
        :return: 四阶导数值
        """
        return 24 * (1 / (1 + np.exp(-theta)))**5 - 60 * (1 / (1 + np.exp(-theta)))**4 + 50 * (1 / (1 + np.exp(-theta)))**3 - 15 * (1 / (1 + np.exp(-theta)))**2 + (1 / (1 + np.exp(-theta)))

    def sig_dev5(self, theta):
        """
        计算sigmoid函数的五阶导数。
        
        :param theta: sigmoid函数的参数
        :return: 五阶导数值
        """
        return -120 * (1 / (1 + np.exp(-theta)))**6 + 360 * (1 / (1 + np.exp(-theta)))**5 - 390 * (1 / (1 + np.exp(-theta)))**4 + 180 * (1 / (1 + np.exp(-theta)))**3 - 31 * (1 / (1 + np.exp(-theta)))**2 + (1 / (1 + np.exp(-theta)))
    
    def train(self):
        """
        训练WNN模型。
        """
        # 初始化权重
        self.A = np.random.rand(self.Ns, self.Nh) * self.Aini

        # 初始化中心
        self.t = np.zeros((1, self.Nh))

        idx = np.random.permutation(self.Nh)
        for j in range(self.Nh):  
            self.t[0, j] = self.d[idx[j]]
        
        # 初始化宽度(或缩放)
        self.R = abs(np.max(self.t) - np.min(self.t)) / 2

        MSE = np.zeros(self.epoch_max)  # 存储每个epoch的MSE
        plt.ion()  # 开启交互模式

        for epoca in range(self.epoch_max): 
            z = np.zeros(self.N)  # 存储每个样本的输出
            E = np.zeros(self.N)  # 存储每个样本的误差

            index = np.random.permutation(self.N)

            for i in index:
                xi = self.X_train[i]  # 当前输入
                theta = (xi - self.t) / self.R  # 计算theta
                yj = self.sig_dev2(theta)  # 计算sigmoid二阶导数
                z[i] = np.dot(self.A, yj.T)[0][0]  # 计算WNN输出

                e = self.d[i] - z[i]  # 计算误差
                self.A = self.A + (self.eta * e * yj)  # 更新权重
                self.t = self.t - (self.eta * e * self.A / self.R * self.sig_dev3(theta))  # 更新中心
                self.R = self.R - (((self.eta * e * self.A * (xi - self.t)) / self.R**2) * self.sig_dev3(theta))  # 更新宽度

                E[i] = 0.5 * e**2  # 计算误差平方

            MSE[epoca] = np.sum(E) / self.N  # 计算MSE

            if (epoca % 200 == 0 or epoca == self.epoch_max - 1):
                if (epoca != 0):
                    plt.cla()
                    plt.clf()
                
                self.plot(z, epoca)  # 绘制进度图
        
        print(MSE[-1])  # 输出最终的MSE

        plt.ioff()  # 关闭交互模式
        plt.figure(1)
        plt.title('Mean Square Error (MSE)')  # 绘制MSE图
        plt.xlabel('Training Epochs')
        plt.ylabel('MSE')
        plt.plot(np.arange(0, MSE.size), MSE)
        plt.show()

    def plot(self, saida, epoca):
        """
        绘制训练进度图。
        
        :param saida: 当前输出
        :param epoca: 当前epoch
        """
        plt.figure(0)
        y, = plt.plot(self.X_train, saida, label="y")  # 绘制WNN输出
        d, = plt.plot(self.X_train, self.d, '.', label="d")  # 绘制目标值
        plt.legend([y, d], ['WNN Output', 'Desired Value'])
        plt.xlabel('x')
        plt.ylabel('f(x)')
        plt.text(np.min(self.X_train) - np.max(self.X_train) * 0.17, np.min(self.d) - np.max(self.d) * 0.17, 'Progress: ' + str(round(float(epoca) / self.epoch_max * 100, 2)) + '%')
        plt.axis([np.min(self.X_train) - np.max(self.X_train) * 0.2, np.max(self.X_train) * 1.2, np.min(self.d) - np.max(self.d) * 0.2, np.max(self.d) * 1.4])
        plt.show()
        plt.pause(1e-100)  # 暂停一段时间以便观察图形

    def show_function(self):
        """
        显示目标函数的图形。
        """
        plt.figure(0)
        plt.title('Function')
        plt.xlabel('x')
        plt.ylabel('f(x)')
        plt.plot(self.X_train, self.d)
        plt.show()

# 创建WNN类的实例
wnn = WNN()

# 加载数据并训练模型
wnn.load_first_function()
wnn.train()

http://www.kler.cn/a/392329.html

相关文章:

  • 图漾相机基础操作
  • github开源链游详细搭建文档
  • 【FlutterDart】 listView.builder例子二(14 /100)
  • CSS 学习之正确看待 CSS 世界里的 margin 合并
  • springboot适配mybatis+guassdb与Mysql兼容性问题处理
  • Neo4j的部署和操作
  • 详细介绍MySQL、Mongo、Redis等数据库的索引
  • Prometheus常用查询PromQL表达式
  • 国家网络安全法律法规
  • mqtt学习笔记(一)
  • 汽车共享管理:SpringBoot技术的应用与挑战
  • 操作系统离散存储练习题
  • C#核心(9)静态类和静态构造函数
  • 机器学习——朴素贝叶斯
  • C++ QT 工具日志异步分批保存
  • 英伟达Isaac Manipulator产品体验
  • 【Vue3】知识汇总,附详细定义和源码详解,后续出微信小程序项目(3)
  • Error response from daemon:
  • OCRSpace申请free api流程
  • Power bi中的lookupvalue函数
  • Oracle In子句
  • 每日OJ题_牛客_春游_贪心+数学_C++_Java
  • Spark:背压机制
  • 南山前海13元一份的猪脚饭
  • mysql 几种启动和关闭mysql方法介绍
  • 青少年编程与数学 02-003 Go语言网络编程 18课题、Go语言Session编程