当前位置: 首页 > article >正文

手撕SwiGLU和GELU

GELU(Gaussian Error Linear Unit):
  • 公式
    GELU ( x ) = x ⋅ Φ ( x ) = x ⋅ 1 2 ( 1 + erf ( x 2 ) ) \text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right) GELU(x)=xΦ(x)=x21(1+erf(2 x))

  • 近似公式(在实践中经常使用的版本):
    GELU ( x ) ≈ 0.5 ⋅ x ⋅ ( 1 + tanh ⁡ ( 2 π ⋅ ( x + 0.044715 ⋅ x 3 ) ) ) \text{GELU}(x) \approx 0.5 \cdot x \cdot \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}} \cdot (x + 0.044715 \cdot x^3)\right)\right) GELU(x)0.5x(1+tanh(π2 (x+0.044715x3)))

SwiGLU(Swish-Gated Linear Unit):
  • 公式
    SwiGLU ( x ) = σ ( Linear ( x 1 ) ) ⋅ Swish ( Linear ( x 2 ) ) \text{SwiGLU}(x) = \sigma(\text{Linear}(x_1)) \cdot \text{Swish}(\text{Linear}(x_2)) SwiGLU(x)=σ(Linear(x1))Swish(Linear(x2))
    其中,Swish 是一个平滑的激活函数:
    Swish ( x ) = x ⋅ σ ( x ) = x 1 + e − x \text{Swish}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}} Swish(x)=xσ(x)=1+exx
GELU 实现(PyTorch 内置):
import torch
import torch.nn as nn

# GELU 激活函数 (PyTorch 内置)
gelu = nn.GELU()

# 输入张量
x = torch.randn(2, 5)
output = gelu(x)
print(output)
import torch
import torch.nn as nn

class GELUApprox(nn.Module):
    def forward(self, x):
        # GELU 近似实现
        return 0.5 * x * (1 + torch.tanh(torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * x ** 3)))

# 示例
x = torch.randn(2, 5)
gelu_approx = GELUApprox()
output = gelu_approx(x)
print(output)

SwiGLU 实现:
import torch
import torch.nn as nn

class SwiGLU(nn.Module):
    def __init__(self, d_model):
        super(SwiGLU, self).__init__()
        # 两个线性层,用于将输入拆分成两部分
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        return torch.sigmoid(self.linear1(x)) * torch.nn.functional.silu(self.linear2(x))  # SiLU 是 Swish 的实现

# 输入张量
x = torch.randn(2, 5)

# SwiGLU 激活函数
swiglu = SwiGLU(d_model=5)
output = swiglu(x)
print(output)

http://www.kler.cn/news/340656.html

相关文章:

  • 基于依赖注入技术的.net core WebApi框架创建实例
  • 前端开发中的高级技巧与最佳实践
  • 天气API接口调用
  • 【树形DP】AT_dp_p Independent Set 题解
  • yolov8/9/10/11模型在中医舌苔分类识别中的应用【代码+数据集+python环境+GUI系统】
  • 【2024】前端学习笔记11-网页布局-弹性布局flex
  • 【C++】输入输出缺省参数
  • k8s的pod管理及优化
  • linux线程 | 一篇文章带你理解线程的概念
  • STM32单片机(F03C8T6)-点灯(寄存器点灯和库函数点灯)
  • oracle查询表空间信息
  • 「小土堆」pytorch DataSet
  • Sequelize 做登录查询数据
  • OBOO鸥柏:布局于为无人机展厅行产业提供LCD液晶显示终端
  • 【TypeScript】知识点梳理(三)
  • 设计师找素材,收藏好这8个网站
  • 注意,学会解决路由问题!(未完)
  • 【AI知识点】机器学习中的常用优化算法(梯度下降、SGD、Adam等)
  • sqli-labs less-20 less-21 less-22 cookie注入
  • 【JNI】hello world