当前位置: 首页 > article >正文

激活函数 05 ——Swish

Swish背景


发展阶段典型函数主要特性局限性
早期阶段Sigmoid/Tanh平滑可导,输出有界梯度消失问题
现代阶段ReLU计算高效,缓解梯度消失神经元死亡现象
改进阶段LeakyReLU改善负区间响应参数敏感性
新星阶段Swish/GELU自适应非线性计算复杂度略高

Swish激活函数由Google Brain团队在2017年首次提出,论文:PDF

Swish ( x ) = x ⋅ σ ( β x ) \text{Swish}(x) = x \cdot \sigma(\beta x) \quad Swish(x)=xσ(βx)
其中  σ ( z ) = 1 1 + e − z \text{其中} \ \sigma(z) = \frac{1}{1+e^{-z}} 其中 σ(z)=1+ez1

原始版本中 β=1,后续研究发现可训练参数β能获得更好效果。该函数结合了ReLU的线性响应特性和Sigmoid的平滑特性。建议从固定 β=1 开始,当模型参数量>1M时考虑可训练参数版本。

函数形态

令β=1时的函数形态:

  • 正区间 lim ⁡ x → + ∞ Swish ( x ) = x \lim_{x→+\infty}\text{Swish}(x)=x limx+Swish(x)=x(与ReLU渐近一致)
  • 负区间 lim ⁡ x → − ∞ Swish ( x ) = 0 \lim_{x→-∞}\text{Swish}(x)=0 limxSwish(x)=0(保留有限梯度)
  • 原点特性 Swish ( 0 ) = 0 ⋅ σ ( 0 ) = 0 \text {Swish}(0) = 0·σ(0) = 0 Swish(0)=0σ(0)=0
  • 非单调性:在 x ∈ ( − 1.278 , 0 ) x∈(-1.278, 0) x(1.278,0) 区间呈现局部极小值

在这里插入图片描述

d d x Swish ( x ) = σ ( x ) + x ⋅ σ ( x ) ( 1 − σ ( x ) ) = σ ( x ) ( 1 + x − x σ ( x ) ) = Swish ( x ) + σ ( x ) ( 1 − Swish ( x ) ) \begin{aligned} \frac{d}{dx}\text{Swish}(x) &= \sigma(x) + x \cdot \sigma(x)(1-\sigma(x)) \\ &= \sigma(x)(1 + x - x\sigma(x)) \\ &= \text{Swish}(x) + \sigma(x)(1 - \text{Swish}(x)) \end{aligned} dxdSwish(x)=σ(x)+xσ(x)(1σ(x))=σ(x)(1+xxσ(x))=Swish(x)+σ(x)(1Swish(x))

导数特性:最大梯度值约为1.099,出现在x≈1.0处;负区间保持非零梯度(解决ReLU死亡问题);梯度曲线二阶可导,有利于高阶优化方法

参数化Swish变体

后续研究提出可训练参数版本:

Swish β ( x ) = x ⋅ σ ( β x ) \text{Swish}_β(x) = x \cdot \sigma(\beta x) Swishβ(x)=xσ(βx)

通过反向传播自动学习β参数,实验显示在图像分类任务中β常收敛到区间 [ 1.0 , 1.5 ] [1.0,1.5] [1.0,1.5]

适合场景

深层网络:尤其适合50层以上的深度架构
低初始化场景:对参数初始化敏感性低于ReLU
长期训练:在超过100epoch的训练中优势更明显

计算开销:相比ReLU增加约15%的计算量
硬件优化:需要开启自动混合精度训练(AMP)
初始化策略:建议配合He正态初始化

Pytorch实现

import torch
import torch.nn as nn

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

自动求导特性实现内存高效

class MemoryEfficientSwish(nn.Module):
    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * torch.sigmoid(x)

        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            sigmoid_x = torch.sigmoid(x)
            return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))

    def forward(self, x):
        return self.F.apply(x)

参数化

class ParametricSwish(nn.Module):
    def __init__(self, beta=1.0, trainable=True):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(beta), requires_grad=trainable)
        
    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)

http://www.kler.cn/a/542810.html

相关文章:

  • 【DeepSeek服务器繁忙,请稍后再试...如何解决?】
  • chrome-mojo C++ Bindings API
  • SSM仓库物品管理系统 附带详细运行指导视频
  • 【文本处理】如何在批量WORD和txt文本提取手机号码,固话号码,提取邮箱,删除中文,删除英文,提取车牌号等等一些文本提取固定格式的操作,基于WPF的解决方案
  • SwiftUI 中 .overlay 两种写法的区别及拓展
  • 第四十章:职场转折:突破困境,重新出发
  • 二、通义灵码插件保姆级教学-IDEA(使用篇)
  • jenkins备份还原配置文件
  • PHP的JIT编译器
  • Lombok使用指南
  • 使用mermaid画流程图
  • ubuntu22.04 git clone问题
  • Springboot集成Milvus和Embedding服务,实现向量化检索
  • vue3自定义提示框和下载
  • 1313:【例3.5】位数问题
  • 【python】http.server内置库构建临时文件服务
  • 【Vue2】vue2项目中如何使用mavon-editor编辑器,数据如何回显到网页,如何回显到编辑器二次编辑
  • 玩转工厂模式
  • 【Unity】【VR开发】如何让手主动吸附到物体上
  • Linux 实操篇 时间日期类、搜索查找类、压缩和解压类
  • 高效利用Python爬虫开发批量获取商品信息
  • Stylelint 如何处理 CSS 预处理器
  • 微服务中如何使用openfeign上传文件
  • 【Oracle专栏】本地 expdp 导出远程库
  • 免费申请 | FRDM-MCXA156评测活动发布!
  • 01-SDRAM控制器的设计——案例总概述