当前位置: 首页 > article >正文

llama源码学习·model.py[2]SwiGLU激活函数

一、激活函数的目的

激活函数的目的是为网络引入非线性,并使其能够学习并逼近复杂的数据模式

二、介绍GLU(Gated Linear Unit)

GLU:将输入分成两部分,一部分直接经过线性变换,另一部分经过 s i g m o i d sigmoid sigmoid 函数变换,然后将这两部分的输出逐点相乘

G L U ( x , W , V , B , c ) = σ ( x W + b ) ⊗ ( x V + c ) GLU(x, W, V, B, c) = \sigma (xW + b) \otimes (xV + c) GLU(x,W,V,B,c)=σ(xW+b)(xV+c)

  • $ \sigma $ 是 $ sigmoid $ 激活函数
  • $ W, V $ 权重
  • $ b, c $ 偏置

绘制GLU激活函数

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# 定义GLU激活函数
class GLU(nn.Module):
    def forward(self, x):
        a, b = x.chunk(2, dim=-1)  
        print('a:', a, 'b:', b)
        return a * F.sigmoid(b)  # 应用sigmoid函数然后进行逐元素乘法(权重和偏置为1)

# 实例化GLU模块
glu = GLU()


# torch.linspace(-3, 3, 100):在-3到3中生成一个等距的一维数组,数量为100个
# unsqueeze(-1)将原先 100 个元素 的一维数组,转换成 100*1 的二维数组
# expand(-1, 2)  复制 100*1的单列,生成 100*2的两列
x_range = torch.linspace(-3, 3, 100).unsqueeze(-1).expand(-1, 2)  

y_glu = glu(x_range) # 得到经过GLU变换的结果

plt.figure(figsize=(10, 4))
plt.plot(x_range[:, 0].numpy(), y_glu.detach().numpy(), label='GLU Function')
plt.xlabel('Input value')
plt.ylabel('Output value')
plt.title('GLU Activation Function') 
plt.legend()
plt.grid(True) 
plt.show()

在这里插入图片描述

三、介绍Swish激活函数

$ SwiGLU $ 是 $ GLU $ 的一种变体,其中包含了 G L U GLU GLU S w i s h Swish Swish 激活函数。

S w i s h β ( x ) = x σ ( β x ) Swish_{\beta}(x) = x \sigma(\beta x) Swishβ(x)=xσ(βx)

  • $ \beta $ 是一个可学习参数
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
class Swish(nn.Module):
    def forward(self, x, beta):
        print(x)
        return x * F.sigmoid(beta * x) 
swish = Swish()
x_range = torch.linspace(-3, 3, 100).unsqueeze(-1) 
betas = [0.1, 1.0, 10.0]
plt.figure(figsize=(10, 4)) 
for beta in betas:
    y_swish = swish(x_range, beta)
    plt.plot(x, y_swish, label=f'beta={beta}')
plt.xlabel('Input value')
plt.ylabel('Output value')
plt.title('Swish Activation Function') 
plt.legend()
plt.grid(True) 
plt.show()

在这里插入图片描述

四、介绍SwiGLU

G L U GLU GLU 中的激活函数 s i g m o i d sigmoid sigmoid 改为 S w i s h Swish Swish 就是 S w i G L U SwiGLU SwiGLU 激活函数。

S w i G L U ( x , W , V , B , c ) = S w i s h β ( x W + b ) ⊗ ( x V + c ) SwiGLU(x, W, V, B, c) = Swish_\beta(xW + b) \otimes (xV + c) SwiGLU(x,W,V,B,c)=Swishβ(xW+b)(xV+c)

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

class SwiGLU(nn.Module):
    def forward(self, x):
        a, b = x.chunk(2, dim=-1) 
        return a * F.silu(b)  # 使用Swish激活函数,F.silu就是Swish

swiglu = SwiGLU()
x_range = torch.linspace(-3, 3, 100)  # 创建一个范围为-3到3的线性空间
y_swiglu = swiglu(x_range.unsqueeze(-1).expand(-1, 2))  # 应用 SwiGLU 函数,确保维度是偶数


# 绘制 SwiGLU 函数的图像
plt.figure(figsize=(10, 4))
plt.plot(x_range.numpy(), y_swiglu.detach().numpy(), label='SwiGLU Function')
plt.xlabel('Input value')
plt.ylabel('Output value')
plt.title('SwiGLU Activation Function')
plt.legend()
plt.grid(True)
plt.show()

在这里插入图片描述

五、GLU 和 SwiGLU 的区别

仅为 G L U GLU GLU 使用 s i g m o i d sigmoid sigmoid S w i G L U SwiGLU SwiGLU 使用 S w i s h Swish Swish


http://www.kler.cn/a/582690.html

相关文章:

  • docker部署jenkins,安装使用一条龙教程
  • Chrome 扩展开发 API实战:Extension(五)
  • 基于javaweb的SpringBoot+MyBatis实验室管理系统设计和实现(源码+文档+部署讲解)
  • SSH 安全致命漏洞:渗透路径与防御策略
  • Java 实现 WebSocket 客户端
  • 麒麟操作系统和统信的区别,上面一般用什么OFFICE,excel软件?
  • git subtree更新子仓库的方式
  • java项目之基于ssm的在线学习系统(源码+文档)
  • EG82088串口边缘计算网关
  • 蓝桥杯——又是二分
  • Flutter 小技巧之通过 MediaQuery 优化 App 性能
  • Spring Boot 项目零风险升级 Tomcat 指南:锁定版本也能修复漏洞
  • 【Leetcode 每日一题】2269. 找到一个数字的 K 美丽值
  • Python+jupyter进行数据分析与数据挖掘
  • Docker基础入门(一)
  • React 中如何实现表单的受控组件?
  • Linux_17进程控制
  • Flink 1.17.2 版本用 java 读取 starrocks
  • c#如何直接获取json中的某个值
  • Java中的加盐加密:提升密码存储安全性的关键实践