Time Series Prediction: Multi-Head Attention + Broad Learning System
✨✨ Welcome to Srlua's blog (づ ̄3 ̄)づ╭❤~✨✨🌟🌟 Dear readers, thank you for taking the time to read my article.
I'm Srlua (小谢), and here I share my knowledge and experience. 🎥
📘📚 Column: 传知代码论文复现 (paper-reproduction code)
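Contents
Overview
Main Contributions of the Paper
Overall Architecture of Multi-Attn BLS
Reactivating Nonlinear Dynamic Features via BLS Random Mapping
Extracting Multi-Level Semantic Information with Multi-Head Attention
Core Code Reproduction
Code Optimization Methods
Experimental Results
Usage
Environment Setup
File Structure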
Overview
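In 2024, Liyun Su, Lang Xiong, and Jialing Yang published the paper "Multi-Attn BLS: Multi-head attention mechanism with broad learning system for chaotic time series prediction" in Applied Soft Computing (CiteScore 14.3, impact factor 8.7). To handle the high complexity and nonlinearity of chaotic time series, the paper proposes a new paradigm that combines the broad learning system (BLS) with multi-head self-attention. Previously, the main way to fuse these two highly nonlinear mapping algorithms was to extract features with stacked multi-head self-attention and then feed them to a BLS for classification or prediction. This paper instead integrates the multi-head attention module directly into the BLS, yielding an end-to-end prediction model.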
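The full project for this article, including source code with the detailed reproduction process, data, and pre-trained models, is available at: link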
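Deep neural networks rely on residual connections to preserve information, but they take a long time to train. BLS instead uses a cascade structure that reuses information and keeps the original input intact. It is a single, simple, specialized network that can be expanded incrementally without retraining from scratch, combining the fast solving of most machine-learning models with the fitting capability of most deep-learning models. For a deeper understanding of BLS, see the original BLS paper. In addition, the paper argues that multi-head attention can fully extract, and make effective use of, key features at different dimensions and levels. Citing earlier studies, the authors show that attention-based models keep the retained information useful by capturing part of the semantic information, and can therefore capture rich information at different levels.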
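The "fast solving" of BLS comes from the fact that it never back-propagates through its random hidden nodes. Only the output weights are fitted, in a single closed-form ridge-regression step W = (λI + AᵀA)⁻¹AᵀY. The `pinv` helper in the reproduction code implements the same formula. Below is a minimal NumPy sketch that assumes plain Gaussian hidden weights and a tanh activation. The function name `ridge_output_weights`, the toy data and `reg=1e-3` are illustrative only.

```python
import numpy as np

def ridge_output_weights(A, Y, reg=1e-3):
    """Closed-form BLS-style output weights: W = (reg*I + A^T A)^(-1) A^T Y.

    A: (n_samples, n_nodes) hidden-node outputs (feature + enhancement nodes)
    Y: (n_samples, n_outputs) targets
    """
    n_nodes = A.shape[1]
    # Solve the linear system instead of forming an explicit inverse (more stable)
    return np.linalg.solve(reg * np.eye(n_nodes) + A.T @ A, A.T @ Y)

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))
# Random, never-trained hidden layer: only W_out below is learned
W_hidden = rng.standard_normal((5, 40))
A = np.tanh(X @ W_hidden)
Y = np.sin(X[:, :1]) + 0.5 * X[:, 1:2]

W_out = ridge_output_weights(A, Y, reg=1e-3)
Y_hat = A @ W_out
print(W_out.shape)  # (40, 1)
```

The solve runs once and needs no iterations. That is why adding nodes and refitting stays cheap compared with gradient-based training.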
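The authors therefore propose using BLS to expand the dimensionality of chaotic time-series data. They add multi-head attention to extract semantic information at different levels, including linear and nonlinear correlations, chaotic mechanisms, and noise. They also use residual connections to preserve information integrity.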
Main Contributions of the Paper
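1. It proposes a new BLS paradigm, "Multi-Attn BLS", for dynamically modeling chaotic time series. Through its cascade and attention mechanisms, the model enriches the fixed features as much as possible and effectively extracts semantic information from chaotic time-series systems.
2. Multi-Attn BLS uses multi-head attention with positional encoding to learn complex chaotic time-series patterns. By capturing spatiotemporal relationships, it extracts as much semantic information as possible.
3. Multi-Attn BLS achieves strong prediction results on three benchmarks and is also highly interpretable for chaotic time series.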
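Self-attention on its own ignores token order, so positional encoding adds each token's position to its features. Below is a minimal NumPy sketch of the standard sinusoidal encoding. It uses the same sin/cos formula as the `pos_encoding` function in the reproduction code. The name `sinusoidal_pos_encoding` is illustrative. The `div_term` slice that allows an odd feature dimension is an addition of this sketch and does not come from the source.

```python
import numpy as np

def sinusoidal_pos_encoding(seq_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d))."""
    pe = np.zeros((seq_len, d_model))
    position = np.arange(seq_len)[:, None]  # (seq_len, 1)
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    # Slice div_term so an odd d_model does not cause a shape mismatch
    pe[:, 1::2] = np.cos(position * div_term[: d_model // 2])
    return pe

pe = sinusoidal_pos_encoding(seq_len=4, d_model=5)
print(pe.shape)  # (4, 5)
print(pe[0])     # [0. 1. 0. 1. 0.]
```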
Overall Architecture of Multi-Attn BLS
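Multi-Attn BLS consists of three main parts: 1) chaotic time-series preprocessing; 2) reactivation of nonlinear dynamic features via BLS random mapping; 3) extraction of multi-level semantic information with multi-head attention.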
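First, following phase space reconstruction theory, Liyun Su, Lang Xiong, and Jialing Yang use the C-C method to determine the embedding dimension and delay time. This recovers the chaotic system and turns the chaotic time series into a predictable pattern. Next, the feature layer and enhancement layer of the BLS randomly map the reconstructed data into a high-dimensional system and enhance it there. The result is a set of mixed features that carry the different patterns of the chaotic time series. Finally, multi-head attention with residual connections extracts the spatiotemporal relationships retained in the system, including linear correlation, nonlinear determinism, and noise.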
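Chaotic time-series preprocessing: recovering the chaotic system via phase space reconstruction
A chaotic time series is a univariate or multivariate time series generated by a dynamical system. Phase space reconstruction maps the series into a high-dimensional space to rebuild a representation of the original dynamical system. Given an embedding dimension m and a delay τ, each reconstructed state is the delay vector X(t) = [x(t), x(t+τ), …, x(t+(m−1)τ)]. According to Takens' embedding theorem, this reconstruction recovers the original chaotic attractor when m is large enough (m ≥ 2d + 1 suffices, where d is the dimension of the attractor). It also gives every sample of the chaotic series the same fixed dimension m.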
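The delay embedding described above can be sketched as follows. The sketch assumes m and τ are already known, because the C-C method that selects them is not shown. It also assumes a one-step-ahead target. The helper names `delay_embed` and `make_supervised` and the `horizon` parameter are hypothetical.

```python
import numpy as np

def delay_embed(x, m, tau):
    """Takens delay embedding: row t is [x[t], x[t+tau], ..., x[t+(m-1)*tau]]."""
    x = np.asarray(x, dtype=float)
    n_vectors = len(x) - (m - 1) * tau
    if n_vectors <= 0:
        raise ValueError("series too short for the given m and tau")
    return np.stack([x[i * tau : i * tau + n_vectors] for i in range(m)], axis=1)

def make_supervised(x, m, tau, horizon=1):
    """Pair each embedded vector with the value `horizon` steps after its last element."""
    x = np.asarray(x, dtype=float)
    X = delay_embed(x, m, tau)
    last = (m - 1) * tau + np.arange(len(X))  # index of each vector's last element
    keep = last + horizon < len(x)            # drop vectors without a future target
    return X[keep], x[last[keep] + horizon]

series = np.arange(10.0)
X, y = make_supervised(series, m=3, tau=2)
print(X[0], y[0])  # [0. 2. 4.] 5.0
print(X.shape)     # (5, 3)
```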
Reactivating Nonlinear Dynamic Features via BLS Random Mapping
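The overall BLS architecture is shown in the figure above. Here we actually use only its mapping capability, that is, the feature-node layer and the enhancement-node layer (the mapping feature nodes and enhancement feature nodes in the figure). The two layers are constructed as follows. Given input X, the i-th group of mapped feature nodes is

$$Z_i = \phi\left(X W_{e_i} + \beta_{e_i}\right), \quad i = 1, \dots, n, \qquad Z^n \equiv [Z_1, \dots, Z_n],$$

and the j-th group of enhancement nodes is

$$H_j = \xi\left(Z^n W_{h_j} + \beta_{h_j}\right), \quad j = 1, \dots, m,$$

where the weights W and biases β are randomly generated (in the original BLS the feature-node weights are further fine-tuned with a sparse autoencoder), and φ and ξ are activation functions.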
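The two node layers can be sketched in NumPy as shown here. This is a simplified illustration. It uses an identity φ, purely random Gaussian feature weights (no sparse-autoencoder fine-tuning), QR-orthonormalized enhancement weights and tanh for ξ. The function name `bls_nodes` and all sizes are hypothetical. The reproduction code further down differs in detail: it uses `sparse_bls`, `LA.orth`, `tansig` and a shrink factor.

```python
import numpy as np

def bls_nodes(X, n_per_window=4, n_windows=3, n_enhance=10, seed=0):
    """Return [Z | H]: mapped feature nodes Z and enhancement nodes H.

    Assumes n_enhance <= n_windows * n_per_window + 1 so the reduced QR gives
    a (n_windows*n_per_window + 1, n_enhance) weight matrix.
    """
    rng = np.random.default_rng(seed)
    X_b = np.hstack([X, np.ones((X.shape[0], 1))])  # append bias column
    # Feature nodes Z_i = phi(X W_ei + b_ei) with phi = identity, one block per window
    Z = np.hstack([X_b @ rng.standard_normal((X_b.shape[1], n_per_window))
                   for _ in range(n_windows)])
    Z_b = np.hstack([Z, np.ones((Z.shape[0], 1))])
    # Enhancement nodes H = xi(Z W_h + b_h) with orthonormal random W_h and xi = tanh
    W_h, _ = np.linalg.qr(rng.standard_normal((Z_b.shape[1], n_enhance)))
    H = np.tanh(Z_b @ W_h)
    return np.hstack([Z, H])

X = np.random.default_rng(1).standard_normal((8, 5))
A = bls_nodes(X)
print(A.shape)  # (8, 22)
```

Fixing the seed makes the random mapping reproducible. The reproduction code relies on the same property: it reuses the stored training-time weights to map the test set.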
Extracting Multi-Level Semantic Information with Multi-Head Attention
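In this part, the authors use stacked multi-head self-attention to process the high-dimensional nodes produced by the BLS mapping. We therefore focus on how multi-head self-attention works.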
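Attention needs a sequence of tokens, so the BLS node matrix has to be reshaped first. The `pos_encoding` function in the reproduction code reshapes the mapped-feature matrix to `(B, N, C)`. The sketch below is our reading, not something the source confirms. It assumes B is the number of samples, N = N2 windows and C = N1 nodes per window, so each window of mapped nodes becomes one token. The sizes are hypothetical.

```python
import numpy as np

n_samples, N1, N2 = 6, 4, 3
# Hypothetical mapped-feature matrix: window i occupies columns N1*i : N1*(i+1)
mapped = np.arange(n_samples * N1 * N2, dtype=float).reshape(n_samples, N1 * N2)

# Row-major reshape turns each window into one token of size N1
tokens = mapped.reshape(n_samples, N2, N1)  # (batch, sequence, features)

for i in range(N2):
    # Token i holds exactly the columns of window i
    assert np.array_equal(tokens[:, i, :], mapped[:, N1 * i : N1 * (i + 1)])
print(tokens.shape)  # (6, 3, 4)
```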
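Multi-head self-attention is derived as follows. First, define the multi-head self-attention operation:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O$$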
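where each head is computed as

$$\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q, K W_i^K, V W_i^V) = \mathrm{softmax}\!\left(\frac{(Q W_i^Q)(K W_i^K)^\top}{\sqrt{d_k}}\right) V W_i^V,$$

with d_k the per-head key dimension. In self-attention, Q, K and V are all the same input sequence, here the BLS-mapped nodes.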
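The two equations above translate almost line by line into NumPy. This is a from-scratch sketch for illustration, with random weights and made-up sizes. Each per-head W_i^Q is taken as a column slice of one d_model × d_model matrix, which is the usual equivalent formulation.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(X, W_q, W_k, W_v, W_o, n_heads):
    """X: (seq, d_model); each W_*: (d_model, d_model). Returns (seq, d_model)."""
    seq, d_model = X.shape
    d_k = d_model // n_heads

    def split(M):
        # (seq, d_model) -> (n_heads, seq, d_k); head i uses columns i*d_k:(i+1)*d_k
        return M.reshape(seq, n_heads, d_k).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)  # (n_heads, seq, seq)
    heads = softmax(scores) @ V                        # (n_heads, seq, d_k)
    concat = heads.transpose(1, 0, 2).reshape(seq, d_model)  # Concat(head_1..head_h)
    return concat @ W_o

rng = np.random.default_rng(0)
d_model, n_heads, seq = 8, 2, 5
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))
X = rng.standard_normal((seq, d_model))
out = multi_head_self_attention(X, W_q, W_k, W_v, W_o, n_heads)
print(out.shape)  # (5, 8)
```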
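After the input passes through a sufficient number of stacked multi-head attention modules, the authors use a fully connected layer to map the result to the output space.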
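This final stage can be sketched in PyTorch as a stack of self-attention blocks with residual connections, followed by one linear layer. The class name `StackedAttentionRegressor`, the layer count, the LayerNorm after each residual and the flatten-then-`nn.Linear` head are all illustrative choices. Neither the paper nor the reproduction code confirms them. The sketch also assumes input shaped `(batch, seq_len, d_model)`.

```python
import torch
import torch.nn as nn

class StackedAttentionRegressor(nn.Module):
    """Hypothetical sketch: n_layers self-attention blocks with residual connections,
    then one fully connected layer mapping to the output space."""

    def __init__(self, seq_len, d_model, n_heads=2, n_layers=2, out_dim=1):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.MultiheadAttention(d_model, n_heads, batch_first=True) for _ in range(n_layers)]
        )
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
        self.fc = nn.Linear(seq_len * d_model, out_dim)

    def forward(self, x):  # x: (batch, seq_len, d_model)
        for attn, norm in zip(self.blocks, self.norms):
            y, _ = attn(x, x, x)   # self-attention: query = key = value = x
            x = norm(x + y)        # residual connection keeps the input information
        return self.fc(x.flatten(start_dim=1))  # (batch, out_dim)

torch.manual_seed(0)
model = StackedAttentionRegressor(seq_len=3, d_model=4)
x = torch.randn(8, 3, 4)  # e.g. 8 samples, 3 windows of 4 mapped nodes each
print(model(x).shape)     # torch.Size([8, 1])
```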
Core Code Reproduction
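This article focuses on how Multi-Attn BLS fuses multi-head self-attention with the BLS model. Reproducing the time-series preprocessing is outside its scope. Below is the author's reproduction of the Multi-Attn BLS code: the BLS mapping layers are written in NumPy, and the attention part uses the PyTorch framework.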
<span style="background-color:#f8f8f8"><span style="color:#333333"><span style="color:#aa5500"># BLS mapping layers</span>
<span style="color:#770088">import</span> <span style="color:#000000">numpy</span> <span style="color:#770088">as</span> <span style="color:#000000">np</span>
<span style="color:#770088">from</span> <span style="color:#000000">sklearn</span> <span style="color:#770088">import</span> <span style="color:#000000">preprocessing</span>
<span style="color:#770088">from</span> <span style="color:#000000">numpy</span> <span style="color:#770088">import</span> <span style="color:#000000">random</span>
<span style="color:#770088">from</span> <span style="color:#000000">scipy</span> <span style="color:#770088">import</span> <span style="color:#000000">linalg</span> <span style="color:#770088">as</span> <span style="color:#000000">LA</span>
<span style="color:#770088">def</span> <span style="color:#0000ff">show_accuracy</span>(<span style="color:#000000">predictLabel</span>, <span style="color:#000000">Label</span>):
    <span style="color:#000000">count</span> <span style="color:#981a1a">=</span> <span style="color:#116644">0</span>
    <span style="color:#000000">label_1</span> <span style="color:#981a1a">=</span> <span style="color:#000000">Label</span>.<span style="color:#000000">argmax</span>(<span style="color:#000000">axis</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>)
    <span style="color:#000000">predlabel</span> <span style="color:#981a1a">=</span> <span style="color:#000000">predictLabel</span>.<span style="color:#000000">argmax</span>(<span style="color:#000000">axis</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>)
    <span style="color:#770088">for</span> <span style="color:#000000">j</span> <span style="color:#770088">in</span> <span style="color:#3300aa">list</span>(<span style="color:#3300aa">range</span>(<span style="color:#000000">Label</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">0</span>])):
        <span style="color:#770088">if</span> <span style="color:#000000">label_1</span>[<span style="color:#000000">j</span>] <span style="color:#981a1a">==</span> <span style="color:#000000">predlabel</span>[<span style="color:#000000">j</span>]:
            <span style="color:#000000">count</span> <span style="color:#981a1a">+=</span> <span style="color:#116644">1</span>
    <span style="color:#770088">return</span> (<span style="color:#3300aa">round</span>(<span style="color:#000000">count</span> <span style="color:#981a1a">/</span> <span style="color:#3300aa">len</span>(<span style="color:#000000">Label</span>), <span style="color:#116644">5</span>))
<span style="color:#770088">def</span> <span style="color:#0000ff">tansig</span>(<span style="color:#000000">x</span>):
    <span style="color:#770088">return</span> (<span style="color:#116644">2</span> <span style="color:#981a1a">/</span> (<span style="color:#116644">1</span> <span style="color:#981a1a">+</span> <span style="color:#000000">np</span>.<span style="color:#000000">exp</span>(<span style="color:#981a1a">-</span><span style="color:#116644">2</span> <span style="color:#981a1a">*</span> <span style="color:#000000">x</span>))) <span style="color:#981a1a">-</span> <span style="color:#116644">1</span>
<span style="color:#770088">def</span> <span style="color:#0000ff">sigmoid</span>(<span style="color:#000000">data</span>):
    <span style="color:#770088">return</span> <span style="color:#116644">1.0</span> <span style="color:#981a1a">/</span> (<span style="color:#116644">1</span> <span style="color:#981a1a">+</span> <span style="color:#000000">np</span>.<span style="color:#000000">exp</span>(<span style="color:#981a1a">-</span><span style="color:#000000">data</span>))
<span style="color:#770088">def</span> <span style="color:#0000ff">linear</span>(<span style="color:#000000">data</span>):
    <span style="color:#770088">return</span> <span style="color:#000000">data</span>
<span style="color:#770088">def</span> <span style="color:#0000ff">tanh</span>(<span style="color:#000000">data</span>):
    <span style="color:#770088">return</span> (<span style="color:#000000">np</span>.<span style="color:#000000">exp</span>(<span style="color:#000000">data</span>) <span style="color:#981a1a">-</span> <span style="color:#000000">np</span>.<span style="color:#000000">exp</span>(<span style="color:#981a1a">-</span><span style="color:#000000">data</span>)) <span style="color:#981a1a">/</span> (<span style="color:#000000">np</span>.<span style="color:#000000">exp</span>(<span style="color:#000000">data</span>) <span style="color:#981a1a">+</span> <span style="color:#000000">np</span>.<span style="color:#000000">exp</span>(<span style="color:#981a1a">-</span><span style="color:#000000">data</span>))
<span style="color:#770088">def</span> <span style="color:#0000ff">relu</span>(<span style="color:#000000">data</span>):
    <span style="color:#770088">return</span> <span style="color:#000000">np</span>.<span style="color:#000000">maximum</span>(<span style="color:#000000">data</span>, <span style="color:#116644">0</span>)
<span style="color:#770088">def</span> <span style="color:#0000ff">pinv</span>(<span style="color:#000000">A</span>, <span style="color:#000000">reg</span>):
    <span style="color:#aa5500"># ridge-regularized pseudoinverse: (reg*I + A^T A)^(-1) A^T</span>
    <span style="color:#770088">return</span> <span style="color:#000000">np</span>.<span style="color:#000000">linalg</span>.<span style="color:#000000">inv</span>(<span style="color:#000000">reg</span> <span style="color:#981a1a">*</span> <span style="color:#000000">np</span>.<span style="color:#000000">eye</span>(<span style="color:#000000">A</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">1</span>]) <span style="color:#981a1a">+</span> <span style="color:#000000">A</span>.<span style="color:#000000">T</span>.<span style="color:#000000">dot</span>(<span style="color:#000000">A</span>)).<span style="color:#000000">dot</span>(<span style="color:#000000">A</span>.<span style="color:#000000">T</span>)
<span style="color:#770088">def</span> <span style="color:#0000ff">shrinkage</span>(<span style="color:#000000">a</span>, <span style="color:#000000">b</span>):
    <span style="color:#000000">z</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">maximum</span>(<span style="color:#000000">a</span> <span style="color:#981a1a">-</span> <span style="color:#000000">b</span>, <span style="color:#116644">0</span>) <span style="color:#981a1a">-</span> <span style="color:#000000">np</span>.<span style="color:#000000">maximum</span>(<span style="color:#981a1a">-</span><span style="color:#000000">a</span> <span style="color:#981a1a">-</span> <span style="color:#000000">b</span>, <span style="color:#116644">0</span>)
    <span style="color:#770088">return</span> <span style="color:#000000">z</span>
<span style="color:#770088">def</span> <span style="color:#0000ff">sparse_bls</span>(<span style="color:#000000">A</span>, <span style="color:#000000">b</span>): <span style="color:#aa5500"># A: mapped nodes of one window, b: input data with the bias column appended</span>
    <span style="color:#aa5500"># ADMM (rho = 1) for the lasso problem min_x 0.5*||A x - b||^2 + lam*||x||_1</span>
    <span style="color:#000000">lam</span> <span style="color:#981a1a">=</span> <span style="color:#116644">0.001</span>
    <span style="color:#000000">itrs</span> <span style="color:#981a1a">=</span> <span style="color:#116644">50</span>
    <span style="color:#000000">AA</span> <span style="color:#981a1a">=</span> <span style="color:#000000">A</span>.<span style="color:#000000">T</span>.<span style="color:#000000">dot</span>(<span style="color:#000000">A</span>)
    <span style="color:#000000">m</span> <span style="color:#981a1a">=</span> <span style="color:#000000">A</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">1</span>]
    <span style="color:#000000">n</span> <span style="color:#981a1a">=</span> <span style="color:#000000">b</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">1</span>]
    <span style="color:#000000">x1</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">zeros</span>([<span style="color:#000000">m</span>, <span style="color:#000000">n</span>])
    <span style="color:#000000">wk</span> <span style="color:#981a1a">=</span> <span style="color:#000000">x1</span>
    <span style="color:#000000">ok</span> <span style="color:#981a1a">=</span> <span style="color:#000000">x1</span>
    <span style="color:#000000">uk</span> <span style="color:#981a1a">=</span> <span style="color:#000000">x1</span>
    <span style="color:#000000">L1</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">linalg</span>.<span style="color:#000000">inv</span>(<span style="color:#000000">AA</span> <span style="color:#981a1a">+</span> <span style="color:#000000">np</span>.<span style="color:#000000">eye</span>(<span style="color:#000000">m</span>))
    <span style="color:#000000">L2</span> <span style="color:#981a1a">=</span> (<span style="color:#000000">L1</span>.<span style="color:#000000">dot</span>(<span style="color:#000000">A</span>.<span style="color:#000000">T</span>)).<span style="color:#000000">dot</span>(<span style="color:#000000">b</span>)
    <span style="color:#770088">for</span> <span style="color:#000000">i</span> <span style="color:#770088">in</span> <span style="color:#3300aa">range</span>(<span style="color:#000000">itrs</span>):
        <span style="color:#000000">ck</span> <span style="color:#981a1a">=</span> <span style="color:#000000">L2</span> <span style="color:#981a1a">+</span> <span style="color:#000000">np</span>.<span style="color:#000000">dot</span>(<span style="color:#000000">L1</span>, (<span style="color:#000000">ok</span> <span style="color:#981a1a">-</span> <span style="color:#000000">uk</span>))
        <span style="color:#000000">ok</span> <span style="color:#981a1a">=</span> <span style="color:#000000">shrinkage</span>(<span style="color:#000000">ck</span> <span style="color:#981a1a">+</span> <span style="color:#000000">uk</span>, <span style="color:#000000">lam</span>)
        <span style="color:#000000">uk</span> <span style="color:#981a1a">=</span> <span style="color:#000000">uk</span> <span style="color:#981a1a">+</span> <span style="color:#000000">ck</span> <span style="color:#981a1a">-</span> <span style="color:#000000">ok</span>
        <span style="color:#000000">wk</span> <span style="color:#981a1a">=</span> <span style="color:#000000">ok</span>
    <span style="color:#770088">return</span> <span style="color:#000000">wk</span>
<span style="color:#770088">def</span> <span style="color:#0000ff">generate_mappingFeaturelayer</span>(<span style="color:#000000">train_x</span>, <span style="color:#000000">FeatureOfInputDataWithBias</span>, <span style="color:#000000">N1</span>, <span style="color:#000000">N2</span>, <span style="color:#000000">OutputOfFeatureMappingLayer</span>, <span style="color:#000000">u</span><span style="color:#981a1a">=</span><span style="color:#116644">0</span>):
    <span style="color:#000000">Beta1OfEachWindow</span> <span style="color:#981a1a">=</span> <span style="color:#3300aa">list</span>()
    <span style="color:#000000">distOfMaxAndMin</span> <span style="color:#981a1a">=</span> []
    <span style="color:#000000">minOfEachWindow</span> <span style="color:#981a1a">=</span> []
    <span style="color:#770088">for</span> <span style="color:#000000">i</span> <span style="color:#770088">in</span> <span style="color:#3300aa">range</span>(<span style="color:#000000">N2</span>):
        <span style="color:#000000">random</span>.<span style="color:#000000">seed</span>(<span style="color:#000000">i</span> <span style="color:#981a1a">+</span> <span style="color:#000000">u</span>)
        <span style="color:#000000">weightOfEachWindow</span> <span style="color:#981a1a">=</span> <span style="color:#116644">2</span> <span style="color:#981a1a">*</span> <span style="color:#000000">random</span>.<span style="color:#000000">randn</span>(<span style="color:#000000">train_x</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">1</span>] <span style="color:#981a1a">+</span> <span style="color:#116644">1</span>, <span style="color:#000000">N1</span>) <span style="color:#981a1a">-</span> <span style="color:#116644">1</span>
        <span style="color:#000000">FeatureOfEachWindow</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">dot</span>(<span style="color:#000000">FeatureOfInputDataWithBias</span>, <span style="color:#000000">weightOfEachWindow</span>)
        <span style="color:#000000">scaler1</span> <span style="color:#981a1a">=</span> <span style="color:#000000">preprocessing</span>.<span style="color:#000000">MinMaxScaler</span>(<span style="color:#000000">feature_range</span><span style="color:#981a1a">=</span>(<span style="color:#981a1a">-</span><span style="color:#116644">1</span>, <span style="color:#116644">1</span>)).<span style="color:#000000">fit</span>(<span style="color:#000000">FeatureOfEachWindow</span>)
        <span style="color:#000000">FeatureOfEachWindowAfterPreprocess</span> <span style="color:#981a1a">=</span> <span style="color:#000000">scaler1</span>.<span style="color:#000000">transform</span>(<span style="color:#000000">FeatureOfEachWindow</span>)
        <span style="color:#000000">betaOfEachWindow</span> <span style="color:#981a1a">=</span> <span style="color:#000000">sparse_bls</span>(<span style="color:#000000">FeatureOfEachWindowAfterPreprocess</span>, <span style="color:#000000">FeatureOfInputDataWithBias</span>).<span style="color:#000000">T</span>
        <span style="color:#aa5500"># betaOfEachWindow = graph_autoencoder(FeatureOfInputDataWithBias, FeatureOfEachWindowAfterPreprocess).T</span>
        <span style="color:#000000">Beta1OfEachWindow</span>.<span style="color:#000000">append</span>(<span style="color:#000000">betaOfEachWindow</span>)
        <span style="color:#000000">outputOfEachWindow</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">dot</span>(<span style="color:#000000">FeatureOfInputDataWithBias</span>, <span style="color:#000000">betaOfEachWindow</span>)
        <span style="color:#000000">distOfMaxAndMin</span>.<span style="color:#000000">append</span>(<span style="color:#000000">np</span>.<span style="color:#000000">max</span>(<span style="color:#000000">outputOfEachWindow</span>, <span style="color:#000000">axis</span><span style="color:#981a1a">=</span><span style="color:#116644">0</span>) <span style="color:#981a1a">-</span> <span style="color:#000000">np</span>.<span style="color:#000000">min</span>(<span style="color:#000000">outputOfEachWindow</span>, <span style="color:#000000">axis</span><span style="color:#981a1a">=</span><span style="color:#116644">0</span>))
        <span style="color:#aa5500"># centers each column on its mean (despite the variable name), then scales by its range</span>
        <span style="color:#000000">minOfEachWindow</span>.<span style="color:#000000">append</span>(<span style="color:#000000">np</span>.<span style="color:#000000">mean</span>(<span style="color:#000000">outputOfEachWindow</span>, <span style="color:#000000">axis</span><span style="color:#981a1a">=</span><span style="color:#116644">0</span>))
        <span style="color:#000000">outputOfEachWindow</span> <span style="color:#981a1a">=</span> (<span style="color:#000000">outputOfEachWindow</span> <span style="color:#981a1a">-</span> <span style="color:#000000">minOfEachWindow</span>[<span style="color:#000000">i</span>]) <span style="color:#981a1a">/</span> <span style="color:#000000">distOfMaxAndMin</span>[<span style="color:#000000">i</span>]
        <span style="color:#000000">OutputOfFeatureMappingLayer</span>[:, <span style="color:#000000">N1</span> <span style="color:#981a1a">*</span> <span style="color:#000000">i</span>:<span style="color:#000000">N1</span> <span style="color:#981a1a">*</span> (<span style="color:#000000">i</span> <span style="color:#981a1a">+</span> <span style="color:#116644">1</span>)] <span style="color:#981a1a">=</span> <span style="color:#000000">outputOfEachWindow</span>
    <span style="color:#770088">return</span> <span style="color:#000000">OutputOfFeatureMappingLayer</span>, <span style="color:#000000">Beta1OfEachWindow</span>, <span style="color:#000000">distOfMaxAndMin</span>, <span style="color:#000000">minOfEachWindow</span>
<span style="color:#770088">def</span> <span style="color:#0000ff">generate_enhancelayer</span>(<span style="color:#000000">OutputOfFeatureMappingLayer</span>, <span style="color:#000000">N1</span>, <span style="color:#000000">N2</span>, <span style="color:#000000">N3</span>, <span style="color:#000000">s</span>):
    <span style="color:#000000">InputOfEnhanceLayerWithBias</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">hstack</span>(
        [<span style="color:#000000">OutputOfFeatureMappingLayer</span>, <span style="color:#116644">0.1</span> <span style="color:#981a1a">*</span> <span style="color:#000000">np</span>.<span style="color:#000000">ones</span>((<span style="color:#000000">OutputOfFeatureMappingLayer</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">0</span>], <span style="color:#116644">1</span>))])
    <span style="color:#770088">if</span> <span style="color:#000000">N1</span> <span style="color:#981a1a">*</span> <span style="color:#000000">N2</span> <span style="color:#981a1a">>=</span> <span style="color:#000000">N3</span>:
        <span style="color:#000000">random</span>.<span style="color:#000000">seed</span>(<span style="color:#116644">67797325</span>)
        <span style="color:#000000">weightOfEnhanceLayer</span> <span style="color:#981a1a">=</span> <span style="color:#000000">LA</span>.<span style="color:#000000">orth</span>(<span style="color:#116644">2</span> <span style="color:#981a1a">*</span> <span style="color:#000000">random</span>.<span style="color:#000000">randn</span>(<span style="color:#000000">N2</span> <span style="color:#981a1a">*</span> <span style="color:#000000">N1</span> <span style="color:#981a1a">+</span> <span style="color:#116644">1</span>, <span style="color:#000000">N3</span>) <span style="color:#981a1a">-</span> <span style="color:#116644">1</span>)
    <span style="color:#770088">else</span>:
        <span style="color:#000000">random</span>.<span style="color:#000000">seed</span>(<span style="color:#116644">67797325</span>)
        <span style="color:#000000">weightOfEnhanceLayer</span> <span style="color:#981a1a">=</span> <span style="color:#000000">LA</span>.<span style="color:#000000">orth</span>(<span style="color:#116644">2</span> <span style="color:#981a1a">*</span> <span style="color:#000000">random</span>.<span style="color:#000000">randn</span>(<span style="color:#000000">N2</span> <span style="color:#981a1a">*</span> <span style="color:#000000">N1</span> <span style="color:#981a1a">+</span> <span style="color:#116644">1</span>, <span style="color:#000000">N3</span>).<span style="color:#000000">T</span> <span style="color:#981a1a">-</span> <span style="color:#116644">1</span>).<span style="color:#000000">T</span>
    <span style="color:#000000">tempOfOutputOfEnhanceLayer</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">dot</span>(<span style="color:#000000">InputOfEnhanceLayerWithBias</span>, <span style="color:#000000">weightOfEnhanceLayer</span>)
    <span style="color:#000000">parameterOfShrink</span> <span style="color:#981a1a">=</span> <span style="color:#000000">s</span> <span style="color:#981a1a">/</span> <span style="color:#000000">np</span>.<span style="color:#000000">max</span>(<span style="color:#000000">tempOfOutputOfEnhanceLayer</span>)
    <span style="color:#000000">OutputOfEnhanceLayer</span> <span style="color:#981a1a">=</span> <span style="color:#000000">tansig</span>(<span style="color:#000000">tempOfOutputOfEnhanceLayer</span> <span style="color:#981a1a">*</span> <span style="color:#000000">parameterOfShrink</span>)
    <span style="color:#770088">return</span> <span style="color:#000000">OutputOfEnhanceLayer</span>, <span style="color:#000000">parameterOfShrink</span>, <span style="color:#000000">weightOfEnhanceLayer</span>
<span style="color:#770088">def</span> <span style="color:#0000ff">BLS_Genfeatures</span>(<span style="color:#000000">train_x</span>, <span style="color:#000000">test_x</span>, <span style="color:#000000">N1</span>, <span style="color:#000000">N2</span>, <span style="color:#000000">N3</span>, <span style="color:#000000">s</span>):
    <span style="color:#000000">u</span> <span style="color:#981a1a">=</span> <span style="color:#116644">0</span>
    <span style="color:#000000">train_x</span> <span style="color:#981a1a">=</span> <span style="color:#000000">preprocessing</span>.<span style="color:#000000">scale</span>(<span style="color:#000000">train_x</span>, <span style="color:#000000">axis</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>)
    <span style="color:#000000">FeatureOfInputDataWithBias</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">hstack</span>([<span style="color:#000000">train_x</span>, <span style="color:#116644">0.1</span> <span style="color:#981a1a">*</span> <span style="color:#000000">np</span>.<span style="color:#000000">ones</span>((<span style="color:#000000">train_x</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">0</span>], <span style="color:#116644">1</span>))])
    <span style="color:#000000">OutputOfFeatureMappingLayer</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">zeros</span>([<span style="color:#000000">train_x</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">0</span>], <span style="color:#000000">N2</span> <span style="color:#981a1a">*</span> <span style="color:#000000">N1</span>])
    <span style="color:#000000">OutputOfFeatureMappingLayer</span>, <span style="color:#000000">Beta1OfEachWindow</span>, <span style="color:#000000">distOfMaxAndMin</span>, <span style="color:#000000">minOfEachWindow</span> <span style="color:#981a1a">=</span> \
        <span style="color:#000000">generate_mappingFeaturelayer</span>(<span style="color:#000000">train_x</span>, <span style="color:#000000">FeatureOfInputDataWithBias</span>, <span style="color:#000000">N1</span>, <span style="color:#000000">N2</span>, <span style="color:#000000">OutputOfFeatureMappingLayer</span>, <span style="color:#000000">u</span>)
    <span style="color:#000000">OutputOfEnhanceLayer</span>, <span style="color:#000000">parameterOfShrink</span>, <span style="color:#000000">weightOfEnhanceLayer</span> <span style="color:#981a1a">=</span> \
        <span style="color:#000000">generate_enhancelayer</span>(<span style="color:#000000">OutputOfFeatureMappingLayer</span>, <span style="color:#000000">N1</span>, <span style="color:#000000">N2</span>, <span style="color:#000000">N3</span>, <span style="color:#000000">s</span>)
    <span style="color:#000000">InputOfOutputLayerTrain</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">hstack</span>([<span style="color:#000000">OutputOfFeatureMappingLayer</span>, <span style="color:#000000">OutputOfEnhanceLayer</span>])
    <span style="color:#000000">test_x</span> <span style="color:#981a1a">=</span> <span style="color:#000000">preprocessing</span>.<span style="color:#000000">scale</span>(<span style="color:#000000">test_x</span>, <span style="color:#000000">axis</span><span style="color:#981a1a">=</span><span style="color:#116644">1</span>)
    <span style="color:#000000">FeatureOfInputDataWithBiasTest</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">hstack</span>([<span style="color:#000000">test_x</span>, <span style="color:#116644">0.1</span> <span style="color:#981a1a">*</span> <span style="color:#000000">np</span>.<span style="color:#000000">ones</span>((<span style="color:#000000">test_x</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">0</span>], <span style="color:#116644">1</span>))])
    <span style="color:#000000">OutputOfFeatureMappingLayerTest</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">zeros</span>([<span style="color:#000000">test_x</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">0</span>], <span style="color:#000000">N2</span> <span style="color:#981a1a">*</span> <span style="color:#000000">N1</span>])
    <span style="color:#770088">for</span> <span style="color:#000000">i</span> <span style="color:#770088">in</span> <span style="color:#3300aa">range</span>(<span style="color:#000000">N2</span>):
        <span style="color:#000000">outputOfEachWindowTest</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">dot</span>(<span style="color:#000000">FeatureOfInputDataWithBiasTest</span>, <span style="color:#000000">Beta1OfEachWindow</span>[<span style="color:#000000">i</span>])
        <span style="color:#000000">OutputOfFeatureMappingLayerTest</span>[:, <span style="color:#000000">N1</span> <span style="color:#981a1a">*</span> <span style="color:#000000">i</span>:<span style="color:#000000">N1</span> <span style="color:#981a1a">*</span> (<span style="color:#000000">i</span> <span style="color:#981a1a">+</span> <span style="color:#116644">1</span>)] <span style="color:#981a1a">=</span> (<span style="color:#000000">outputOfEachWindowTest</span> <span style="color:#981a1a">-</span> <span style="color:#000000">minOfEachWindow</span>[<span style="color:#000000">i</span>]) <span style="color:#981a1a">/</span> \
            <span style="color:#000000">distOfMaxAndMin</span>[<span style="color:#000000">i</span>]
    <span style="color:#000000">InputOfEnhanceLayerWithBiasTest</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">hstack</span>(
        [<span style="color:#000000">OutputOfFeatureMappingLayerTest</span>, <span style="color:#116644">0.1</span> <span style="color:#981a1a">*</span> <span style="color:#000000">np</span>.<span style="color:#000000">ones</span>((<span style="color:#000000">OutputOfFeatureMappingLayerTest</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">0</span>], <span style="color:#116644">1</span>))])
    <span style="color:#000000">tempOfOutputOfEnhanceLayerTest</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">dot</span>(<span style="color:#000000">InputOfEnhanceLayerWithBiasTest</span>, <span style="color:#000000">weightOfEnhanceLayer</span>)
    <span style="color:#000000">OutputOfEnhanceLayerTest</span> <span style="color:#981a1a">=</span> <span style="color:#000000">tansig</span>(<span style="color:#000000">tempOfOutputOfEnhanceLayerTest</span> <span style="color:#981a1a">*</span> <span style="color:#000000">parameterOfShrink</span>)
    <span style="color:#000000">InputOfOutputLayerTest</span> <span style="color:#981a1a">=</span> <span style="color:#000000">np</span>.<span style="color:#000000">hstack</span>([<span style="color:#000000">OutputOfFeatureMappingLayerTest</span>, <span style="color:#000000">OutputOfEnhanceLayerTest</span>])
    <span style="color:#770088">return</span> <span style="color:#000000">InputOfOutputLayerTrain</span>, <span style="color:#000000">InputOfOutputLayerTest</span></span></span>
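The `sparse_bls` helper above is ADMM with ρ = 1 for a lasso problem. Given the mapped features `A` and the biased input `b`, it fits sparse weights that reconstruct the input from the features. The transpose of those weights then serves as the feature-mapping weights, which is the sparse-autoencoder step of BLS. Below is a standalone restatement with ρ exposed as a parameter. The names `soft_threshold` and `lasso_admm`, the toy data and `lam=1.0` are illustrative.

```python
import numpy as np

def soft_threshold(v, k):
    """Proximal operator of k*||.||_1 (equivalent to `shrinkage` in the BLS code)."""
    return np.sign(v) * np.maximum(np.abs(v) - k, 0.0)

def lasso_admm(A, b, lam=0.001, rho=1.0, iters=50):
    """ADMM for min_x 0.5*||A x - b||^2 + lam*||x||_1; rho=1 matches sparse_bls."""
    n = A.shape[1]
    z = np.zeros((n, b.shape[1]))
    u = np.zeros_like(z)
    inv = np.linalg.inv(A.T @ A + rho * np.eye(n))  # reused by every x-update
    Atb = A.T @ b
    for _ in range(iters):
        x = inv @ (Atb + rho * (z - u))       # ridge-like x-update
        z = soft_threshold(x + u, lam / rho)  # sparsifying z-update
        u = u + x - z                         # scaled dual update
    return z

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 10))
x_true = np.zeros((10, 1))
x_true[[1, 4]] = [[2.0], [-3.0]]
b = A @ x_true
x_hat = lasso_admm(A, b, lam=1.0, iters=200)
print(np.round(x_hat.ravel(), 2))  # nonzeros only at indices 1 and 4
```

The z-update sets small coefficients exactly to zero. That is what makes the learned mapping weights sparse, unlike the ridge solution in `pinv`.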
<span style="background-color:#f8f8f8"><span style="color:#333333"><span style="color:#aa5500"># Multi-head attention layer</span>
<span style="color:#770088">import</span> <span style="color:#000000">torch</span>
<span style="color:#770088">import</span> <span style="color:#000000">math</span>
<span style="color:#770088">import</span> <span style="color:#000000">ssl</span>
<span style="color:#770088">import</span> <span style="color:#000000">numpy</span> <span style="color:#770088">as</span> <span style="color:#000000">np</span>
<span style="color:#770088">import</span> <span style="color:#000000">torch</span>.<span style="color:#000000">nn</span>.<span style="color:#000000">functional</span> <span style="color:#770088">as</span> <span style="color:#000000">F</span>
<span style="color:#770088">import</span> <span style="color:#000000">torch</span>.<span style="color:#000000">nn</span> <span style="color:#770088">as</span> <span style="color:#000000">nn</span>
<span style="color:#000000">ssl</span>.<span style="color:#000000">_create_default_https_context</span> <span style="color:#981a1a">=</span> <span style="color:#000000">ssl</span>.<span style="color:#000000">_create_unverified_context</span>
def pos_encoding(OutputOfFeatureMappingLayer, B, N, C):
    # Reshape the BLS feature-mapping output to (batch, seq_len, feature_dim) = (B, N, C);
    # as_tensor also accepts an existing tensor without a copy warning
    OutputOfFeatureMappingLayer = torch.as_tensor(OutputOfFeatureMappingLayer, dtype=torch.float).reshape(B, N, C)
    # Sequence length and feature dimension of the positional encoding
    max_sequence_length = OutputOfFeatureMappingLayer.size(1)
    feature_dim = OutputOfFeatureMappingLayer.size(2)
    # Build the sinusoidal positional-encoding matrix
    position_encodings = torch.zeros(max_sequence_length, feature_dim)
    position = torch.arange(0, max_sequence_length, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, feature_dim, 2).float() * (-math.log(10000.0) / feature_dim))
    position_encodings[:, 0::2] = torch.sin(position * div_term)
    # Slice div_term so that an odd feature_dim does not cause a shape mismatch
    position_encodings[:, 1::2] = torch.cos(position * div_term[: feature_dim // 2])
    # Expand the (N, C) encoding to the input tensor's (B, N, C) shape
    position_encodings = position_encodings.unsqueeze(0).expand_as(OutputOfFeatureMappingLayer)
    OutputOfFeatureMappingLayer = OutputOfFeatureMappingLayer + position_encodings
    return OutputOfFeatureMappingLayer
class MultiHeadSelfAttentionWithResidual(torch.nn.Module):
    def __init__(self, input_dim, output_dim, num_heads=4):
        super(MultiHeadSelfAttentionWithResidual, self).__init__()
        # The residual connection adds the input to the output, so both must have the same size
        assert input_dim == output_dim, "residual connection requires input_dim == output_dim"
        assert output_dim % num_heads == 0, "output_dim must be divisible by num_heads"
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        # Output dimension of each attention head
        self.head_dim = output_dim // num_heads
        # Query, key and value projection matrices
        self.query_weights = torch.nn.Linear(input_dim, output_dim)
        self.key_weights = torch.nn.Linear(input_dim, output_dim)
        self.value_weights = torch.nn.Linear(input_dim, output_dim)
        # Output projection matrix
        self.output_weights = torch.nn.Linear(output_dim, output_dim)

    def forward(self, inputs):
        # Input shape: (batch_size, seq_len, input_dim)
        batch_size, seq_len, _ = inputs.size()
        # Compute the query, key and value vectors
        queries = self.query_weights(inputs)  # (batch_size, seq_len, output_dim)
        keys = self.key_weights(inputs)  # (batch_size, seq_len, output_dim)
        values = self.value_weights(inputs)  # (batch_size, seq_len, output_dim)
        # Split each vector into num_heads heads: (batch_size, num_heads, seq_len, head_dim)
        queries = queries.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Attention scores: (batch_size, num_heads, seq_len, seq_len)
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        # Scale the scores by sqrt(head_dim) (no NumPy dependency needed)
        scores = scores / (self.head_dim ** 0.5)
        # Attention weights: (batch_size, num_heads, seq_len, seq_len)
        attention_weights = F.softmax(scores, dim=-1)
        # Weighted sum of the value vectors: (batch_size, num_heads, seq_len, head_dim)
        attention_output = torch.matmul(attention_weights, values)
        # Concatenate the heads back into (batch_size, seq_len, output_dim)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.output_dim)
        # Output projection: (batch_size, seq_len, output_dim)
        output = self.output_weights(attention_output)
        # Residual connection
        output = output + inputs
        return output
class FeedForwardLayerWithResidual(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(FeedForwardLayerWithResidual, self).__init__()
        # input_dim is the size of the flattened input, i.e. seq_len * feature_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = input_dim
        # First linear layer
        self.linear1 = torch.nn.Linear(input_dim, hidden_dim)
        # Activation function
        self.relu = torch.nn.ReLU()
        # Second linear layer, projecting back to input_dim so the residual shapes match
        self.linear2 = torch.nn.Linear(hidden_dim, self.output_dim)

    def forward(self, inputs):
        # Input shape: (batch_size, seq_len, feature_dim)
        batch_size, seq_len, feature_dim = inputs.size()
        # Flatten to (batch_size, seq_len * feature_dim); reshape also handles non-contiguous tensors
        inputs = inputs.reshape(batch_size, seq_len * feature_dim)
        # First linear layer followed by the activation
        hidden = self.relu(self.linear1(inputs))
        # Second linear layer
        output = self.linear2(hidden)
        # The output is kept flattened, (batch_size, seq_len * feature_dim), rather than reshaped back
        # Residual connection
        output = output + inputs
        return output
class MultiAttn_layer(torch.nn.Module):
    def __init__(self, input_dim, output_dim, in_features, hidden_features, layer_num, num_heads=4):
        super(MultiAttn_layer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        # in_features is the flattened size seen by the feed-forward layer: seq_len * output_dim
        self.norm = nn.LayerNorm(in_features, eps=1e-06, elementwise_affine=True)
        self.FC = nn.Linear(in_features=in_features, out_features=2)
        self.layer_num = layer_num
        if hidden_features is None:
            self.hidden_dim = 4 * self.output_dim
        else:
            self.hidden_dim = hidden_features
        # Multi-head self-attention layer
        self.self_attn = MultiHeadSelfAttentionWithResidual(input_dim, output_dim, num_heads)
        # Feed-forward layer; pass self.hidden_dim so that hidden_features=None uses the default
        self.feed_forward = FeedForwardLayerWithResidual(in_features, self.hidden_dim)

    def forward(self, inputs):
        # Input shape: (batch_size, seq_len, input_dim)
        # First pass through the multi-head self-attention layer
        attn_output = self.self_attn(inputs)
        # The same self_attn module is applied layer_num times, so the stacked layers share weights
        for i in range(self.layer_num - 1):
            attn_output = self.self_attn(attn_output)
        # Then the feed-forward layer: (batch_size, in_features)
        output = self.feed_forward(attn_output)
        output = self.norm(output)
        output = self.FC(output)  # (batch_size, 2)
        return output</span></span>
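The classes above make it easy to get the tensor shapes wrong, so the following framework-free NumPy sketch may help check them. It applies a sinusoidal position encoding and then one multi-head self-attention step with a residual connection. This is an illustration, not the author's code. The function names `sinusoidal_pe` and `multi_head_self_attention`, the random weights, and the sizes `B, N, C, H` are all hypothetical. Biases and the feed-forward layer are omitted.

```python
import numpy as np


def sinusoidal_pe(seq_len, dim):
    """Sinusoidal position encoding of shape (seq_len, dim); also valid for odd dim."""
    pos = np.arange(seq_len, dtype=np.float64)[:, None]             # (seq_len, 1)
    div = np.exp(np.arange(0, dim, 2) * (-np.log(10000.0) / dim))    # (ceil(dim / 2),)
    pe = np.zeros((seq_len, dim))
    pe[:, 0::2] = np.sin(pos * div)                                  # even columns
    pe[:, 1::2] = np.cos(pos * div[: dim // 2])                      # odd columns
    return pe


def multi_head_self_attention(x, wq, wk, wv, wo, num_heads):
    """x: (batch, seq_len, dim); each weight matrix: (dim, dim); biases omitted.

    Returns the residual output (batch, seq_len, dim) and the attention
    weights (batch, num_heads, seq_len, seq_len).
    """
    b, n, d = x.shape
    assert d % num_heads == 0, "dim must be divisible by num_heads"
    hd = d // num_heads

    def split_heads(t):  # (b, n, d) -> (b, num_heads, n, hd)
        return t.reshape(b, n, num_heads, hd).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(hd)               # (b, heads, n, n)
    scores -= scores.max(axis=-1, keepdims=True)                     # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(b, n, d)       # concatenate heads
    # Residual connection: only valid because the output dim equals the input dim
    return out @ wo + x, weights


# Hypothetical sizes: batch B, sequence length N, feature dim C, heads H
B, N, C, H = 2, 10, 16, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((B, N, C)) + sinusoidal_pe(N, C)
ws = [rng.standard_normal((C, C)) * 0.1 for _ in range(4)]            # random W_q, W_k, W_v, W_o
y, attn = multi_head_self_attention(x, *ws, num_heads=H)
print(y.shape, attn.shape)   # (2, 10, 16) (2, 4, 10, 10)
# The feed-forward layer in the listing flattens its input, so its in_features is N * C
flat = y.reshape(B, N * C)
print(flat.shape)            # (2, 160)
```

The run shows the constraints the PyTorch classes depend on:
- The residual connection only works when the attention output dimension equals its input dimension.
- The feature dimension must be divisible by the number of heads.
- After flattening, the feed-forward layer's `in_features` must equal `seq_len * feature_dim`.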
Code Optimization
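In the code above, we built the multi-head attention layer by hand and reproduced the paper's feed-forward layer. We can also reuse modules from popular Transformer models to assemble the model more quickly. These well-established architectures tend to be better designed and less error-prone than hand-written layers. Here I reused part of the Vision Transformer (ViT) architecture from the timm library to improve the multi-head attention layer. The code is as follows: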
<span style="background-color:#f8f8f8"><span style="color:#333333"># Improved multi-head attention layer, built from timm's Vision Transformer blocks
import torch
import math
import ssl
import torch.nn as nn
from timm.models.vision_transformer import vit_base_patch8_224
# Disables HTTPS certificate verification for the whole process, presumably a workaround for
# certificate errors when downloading the pretrained ViT weights; fixing the certificate store is safer
ssl._create_default_https_context = ssl._create_unverified_context
def pos_encoding(OutputOfFeatureMappingLayer, B, N, C):
    # Reshape the BLS feature-mapping output to (batch, seq_len, feature_dim) = (B, N, C);
    # as_tensor also accepts an existing tensor without a copy warning
    OutputOfFeatureMappingLayer = torch.as_tensor(OutputOfFeatureMappingLayer, dtype=torch.float).reshape(B, N, C)
    # Sequence length and feature dimension of the positional encoding
    max_sequence_length = OutputOfFeatureMappingLayer.size(1)
    feature_dim = OutputOfFeatureMappingLayer.size(2)
    # Build the sinusoidal positional-encoding matrix
    position_encodings = torch.zeros(max_sequence_length, feature_dim)
    position = torch.arange(0, max_sequence_length, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, feature_dim, 2).float() * (-math.log(10000.0) / feature_dim))
    position_encodings[:, 0::2] = torch.sin(position * div_term)
    # Slice div_term so that an odd feature_dim does not cause a shape mismatch
    position_encodings[:, 1::2] = torch.cos(position * div_term[: feature_dim // 2])
    # Expand the (N, C) encoding to the input tensor's (B, N, C) shape
    position_encodings = position_encodings.unsqueeze(0).expand_as(OutputOfFeatureMappingLayer)
    OutputOfFeatureMappingLayer = OutputOfFeatureMappingLayer + position_encodings
    return OutputOfFeatureMappingLayer
class MultiAttn_layer(torch.nn.Module):
    def __init__(self, input_dim, output_dim, in_features, hidden_features, layer_num, num_heads=4):
        super(MultiAttn_layer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_head = num_heads
        self.norm = nn.LayerNorm(input_dim, eps=1e-06, elementwise_affine=True)
        self.fc = nn.Linear(in_features=in_features, out_features=2, bias=True)
        self.infeatures = in_features
        self.layer_num = layer_num
        # Load ViT-B/8 and resize its first Transformer block to input_dim.
        # The layers replaced below are freshly initialized and do not keep their pretrained weights.
        model = vit_base_patch8_224(pretrained=True)
        model.blocks[0].norm1 = nn.LayerNorm(input_dim, eps=1e-06, elementwise_affine=True)
        # Fused query/key/value projection: input_dim -> 3 * input_dim.
        # The block keeps ViT-B's 12 heads, so input_dim should be divisible by 12
        # (depending on the timm version, the attention may also assume the original head size).
        model.blocks[0].attn.qkv = nn.Linear(in_features=input_dim, out_features=input_dim*3, bias=True)
        # Output projection of the attention module
        model.blocks[0].attn.proj = nn.Linear(in_features=input_dim, out_features=input_dim, bias=True)
        # timm's Block has no `proj` of its own, so this only registers an extra, unused module
        model.blocks[0].proj = nn.Linear(in_features=input_dim, out_features=input_dim, bias=True)
        model.blocks[0].norm2 = nn.LayerNorm(input_dim, eps=1e-06, elementwise_affine=True)
        # First MLP layer: input_dim -> 3 * input_dim (ViT-B itself uses a 4x hidden width)
        model.blocks[0].mlp.fc1 = nn.Linear(in_features=input_dim, out_features=input_dim*3, bias=True)
<span style="color:#000000">model</span>.<span style="color:#000000">blocks</span>[<span style="color:#116644">0</span>].<span style="color:#000000">mlp</span>.<span style="color:#000000">fc2</span> <span style="color:#981a1a">=</span> <span style="color:#000000">nn</span>.<span style="color:#000000">Linear</span>(<span style="color:#000000">in_features</span><span style="color:#981a1a">=</span><span style="color:#000000">input_dim</span><span style="color:#981a1a">*</span><span style="color:#116644">3</span>, <span style="color:#000000">out_features</span><span style="color:#981a1a">=</span><span style="color:#000000">input_dim</span>, <span style="color:#000000">bias</span><span style="color:#981a1a">=</span><span style="color:#770088">True</span>)
<span style="color:#0055aa">self</span>.<span style="color:#000000">blocks</span> <span style="color:#981a1a">=</span> <span style="color:#000000">model</span>.<span style="color:#000000">blocks</span>[<span style="color:#116644">0</span>]
<span style="color:#770088">def</span> <span style="color:#0000ff">forward</span>(<span style="color:#0055aa">self</span>, <span style="color:#000000">inputs</span>):
<span style="color:#000000">output</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">blocks</span>(<span style="color:#000000">inputs</span>)
<span style="color:#770088">for</span> <span style="color:#000000">i</span> <span style="color:#770088">in</span> <span style="color:#3300aa">range</span>(<span style="color:#0055aa">self</span>.<span style="color:#000000">layer_num</span><span style="color:#981a1a">-</span><span style="color:#116644">1</span>):
<span style="color:#000000">output</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">blocks</span>(<span style="color:#000000">output</span>)
<span style="color:#000000">output</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">norm</span>(<span style="color:#000000">output</span>)
<span style="color:#000000">output</span> <span style="color:#981a1a">=</span> <span style="color:#000000">output</span>.<span style="color:#000000">view</span>(<span style="color:#000000">output</span>.<span style="color:#000000">shape</span>[<span style="color:#116644">0</span>], <span style="color:#0055aa">self</span>.<span style="color:#000000">infeatures</span>)
<span style="color:#000000">output</span> <span style="color:#981a1a">=</span> <span style="color:#0055aa">self</span>.<span style="color:#000000">fc</span>(<span style="color:#000000">output</span>)
<span style="color:#770088">return</span> <span style="color:#000000">output</span>
</span></span>
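The layer above runs `self.blocks`, a single ViT block, `layer_num` times, so every "layer" shares one set of weights. If independent weights per layer are preferred, the following is a minimal sketch that swaps in PyTorch's `nn.TransformerEncoderLayer` for the ViT block; it is not the original author's implementation, and the class name `StackedAttnHead` and all sizes are hypothetical. It keeps the same flow as the code above: attention blocks, a final LayerNorm, flattening to `(batch, tokens * dim)`, then a linear output layer.

```python
import torch
import torch.nn as nn


class StackedAttnHead(nn.Module):
    """Hypothetical sketch: layer_num independent self-attention blocks plus a linear head."""

    def __init__(self, num_tokens: int, dim: int, num_heads: int = 4,
                 layer_num: int = 3, num_outputs: int = 2):
        super().__init__()
        # Multi-head attention splits dim evenly across heads
        if dim % num_heads != 0:
            raise ValueError("dim must be divisible by num_heads")
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=dim, nhead=num_heads, dim_feedforward=dim * 3,
                activation="gelu", layer_norm_eps=1e-6,
                batch_first=True, norm_first=True,  # pre-norm, as in a ViT block
            )
            for _ in range(layer_num)  # a separate set of weights for each layer
        ])
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.fc = nn.Linear(num_tokens * dim, num_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_tokens, dim)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        x = x.reshape(x.shape[0], -1)  # flatten to (batch, num_tokens * dim)
        return self.fc(x)


model = StackedAttnHead(num_tokens=12, dim=16)
out = model(torch.randn(8, 12, 16))
print(out.shape)  # torch.Size([8, 2])
```

Using `nn.ModuleList` registers each layer's parameters separately, which is the main difference from calling one shared block in a loop; the parameter count grows roughly linearly with `layer_num`.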
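Experimental Results
The original authors ran comparative experiments with the proposed model on two datasets; the results are shown in the figure below. Multi-Attn BLS achieves the best value on all four metrics, outperforming ridge regression, the conventional BLS model and LSTM.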
Usage
The reproduction code is wrapped in a single entry point, main.py, which you can call from the terminal with the command shown below.
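Environment Setup
My local environment is macOS on an Apple M2 chip with Python 3.9; avoid Python versions that are much older than this. The reproduction does not rely on complicated or obscure libraries, so the main task is setting up PyTorch, for which plenty of installation guides exist. The ViT-based layer additionally uses the `vit_base_patch8_224` model, which comes from the timm library.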
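A quick way to confirm the setup described above is a small check script. The sketch below uses only the standard library, imports PyTorch only if it is installed, and reports whether the Apple-silicon MPS backend is usable via `torch.backends.mps.is_available()` (present in PyTorch 1.12 and later). The function name `check_env` is hypothetical and not part of the project.

```python
import importlib.util
import sys


def check_env(min_python=(3, 9)):
    """Report the Python version and, if installed, the PyTorch version and MPS support."""
    info = {
        "python": ".".join(map(str, sys.version_info[:3])),
        "python_ok": sys.version_info[:2] >= tuple(min_python),
        "torch": None,
        "mps": False,
    }
    if importlib.util.find_spec("torch") is not None:
        import torch

        info["torch"] = str(torch.__version__)
        # torch.backends.mps only exists from PyTorch 1.12; guard for older builds
        mps = getattr(torch.backends, "mps", None)
        info["mps"] = bool(mps is not None and mps.is_available())
    return info


if __name__ == "__main__":
    for key, value in check_env().items():
        print(f"{key}: {value}")
```

On an M2 Mac with a recent PyTorch build, `mps: True` means tensors can be moved to `torch.device("mps")`; the reproduction itself may still run on the CPU.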
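File Structure
The file structure of the code is shown in the figure above:
Data_Processing.py: preprocessing functions for the data; you can skip this part if you use your own dataset.
MultiAttn_layer.py: the reproduced multi-head attention layer.
VIT_layer.py: an improved multi-head attention layer built from a ViT module.
BLSlayer.py: the BLS mapping layer, which produces the feature layer and the enhancement layer.
main.py: ties these modules together for end-to-end classification.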
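X.npy and y.npy are the sample data I used for this reproduction, holding the data and the labels respectively. To use your own dataset, save its data and labels in the same .npy format in this folder. Their data types and shapes are as follows:
each row of X is one sample, and 545 is the input dimension of each sample; y uses integer label encoding and is a one-dimensional vector.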
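To prepare your own data in the format described above, the sketch below saves a feature matrix as `X.npy` and integer-encoded labels as `y.npy`, checking their shapes on the way. It assumes NumPy; the helper name `save_dataset`, the random placeholder data and the encoding step via `np.unique(..., return_inverse=True)` are illustrative choices, not part of the original project. The 545-column width only mirrors the example above.

```python
import os
import tempfile

import numpy as np


def save_dataset(X, labels, folder="."):
    """Save X.npy (n_samples, n_features) and integer-encoded y.npy (n_samples,) into folder."""
    X = np.asarray(X, dtype=np.float32)
    if X.ndim != 2:
        raise ValueError("X must be 2-D: one row per sample")
    # Map arbitrary class labels (e.g. strings) to integers 0..n_classes-1
    classes, y = np.unique(np.asarray(labels), return_inverse=True)
    y = y.reshape(-1)  # keep y one-dimensional
    if y.shape[0] != X.shape[0]:
        raise ValueError("X and labels must have the same number of samples")
    np.save(os.path.join(folder, "X.npy"), X)
    np.save(os.path.join(folder, "y.npy"), y)
    return classes  # classes[k] is the original label encoded as k


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X_demo = rng.standard_normal((100, 545))  # 100 placeholder samples, 545 features each
    labels_demo = rng.choice(["normal", "fault"], size=100)
    # Demo writes to a temporary folder so existing X.npy / y.npy files are not overwritten
    with tempfile.TemporaryDirectory() as tmp:
        print(save_dataset(X_demo, labels_demo, folder=tmp))
        X_loaded = np.load(os.path.join(tmp, "X.npy"))
        y_loaded = np.load(os.path.join(tmp, "y.npy"))
        print(X_loaded.shape, y_loaded.shape)  # (100, 545) (100,)
```

For real data, pass the project folder as `folder` so that main.py can find `./X.npy` and `./y.npy`.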
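Calling from the terminal
After changing into this folder in a terminal, start model training by calling main.py with the following command: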
<span style="background-color:#f8f8f8"><span style="color:#333333"><span style="color:#000000">python3</span> <span style="color:#981a1a">-</span><span style="color:#000000">c</span> <span style="color:#aa1111">"from main import MultiAttn_BLS; MultiAttn_BLS('./X.npy','./y.npy')"</span></span></span>
The input parameters of MultiAttn_BLS are described below:
<span style="background-color:#f8f8f8"><span style="color:#333333">:<span style="color:#000000">param</span> <span style="color:#000000">datapath</span>: <span style="color:#3300aa">str</span> <span style="color:#000000">path to the data file</span>
:<span style="color:#000000">param</span> <span style="color:#000000">labelpath</span>: <span style="color:#3300aa">str</span> <span style="color:#000000">path to the label file</span>
:<span style="color:#000000">param</span> <span style="color:#000000">N1</span>: <span style="color:#3300aa">int</span> <span style="color:#000000">number of node groups in the feature layer, default 12</span>
:<span style="color:#000000">param</span> <span style="color:#000000">N2</span>: <span style="color:#3300aa">int</span> <span style="color:#000000">number of nodes per group in the feature layer, default 12</span>
:<span style="color:#000000">param</span> <span style="color:#000000">num_heads</span>: <span style="color:#3300aa">int</span> <span style="color:#000000">number of self-attention heads, default 4</span>
<span style="color:#000000">(so that multi-head self-attention fits the node dimension generated by BLS, N1 should be an odd multiple of num_heads)</span>
:<span style="color:#000000">param</span> <span style="color:#000000">layer_num</span>: <span style="color:#3300aa">int</span> <span style="color:#000000">number of multi-head self-attention layers, default 3</span>
:<span style="color:#000000">param</span> <span style="color:#000000">batch_size</span>: <span style="color:#3300aa">int</span> <span style="color:#000000">batch size for training the multi-head self-attention layers, default 16</span>
:<span style="color:#000000">param</span> <span style="color:#000000">s</span>: <span style="color:#3300aa">float</span> <span style="color:#000000">scaling factor, usually set to 0.8</span>
:<span style="color:#000000">param</span> <span style="color:#000000">num_epochs</span>: <span style="color:#3300aa">int</span> <span style="color:#000000">number of training epochs for the multi-head self-attention layers, default 50</span>
:<span style="color:#000000">param</span> <span style="color:#000000">test_size</span>: <span style="color:#3300aa">float</span> <span style="color:#000000">fraction of the dataset held out as the test set, default 0.3</span>
:<span style="color:#770088">return</span>: <span style="color:#000000">the trained model</span></span></span>
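The note on `N1` and `num_heads` above can be checked before launching a long training run. Below is a minimal pure-Python sketch; the helper name `check_heads` is hypothetical, and it encodes only the rule stated in the parameter list (N1 should be an odd multiple of num_heads), not any check performed inside main.py. The commented call assumes that MultiAttn_BLS accepts the documented parameters as keyword arguments under the names listed above.

```python
def check_heads(N1: int, num_heads: int) -> bool:
    """Return True if N1 is an odd multiple of num_heads (the rule stated for MultiAttn_BLS)."""
    if N1 <= 0 or num_heads <= 0:
        return False
    quotient, remainder = divmod(N1, num_heads)
    return remainder == 0 and quotient % 2 == 1


# Defaults from the parameter list: N1 = 12, num_heads = 4, and 12 = 3 * 4 is an odd multiple
print(check_heads(12, 4))  # True
print(check_heads(16, 4))  # False: 16 = 4 * 4 is an even multiple

# Hypothetical keyword call, assuming main.py is importable from the current folder:
# from main import MultiAttn_BLS
# model = MultiAttn_BLS('./X.npy', './y.npy', N1=12, N2=12, num_heads=4,
#                       layer_num=3, batch_size=16, s=0.8, num_epochs=50, test_size=0.3)
```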
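I ran a classification task on a local dataset: as the number of epochs increases, the training loss keeps decreasing, and the test accuracy reaches 77.52%, a fairly good result. You can modify the model's output layer to suit other learning tasks.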
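In the ViT-based layer shown earlier, the output layer is `nn.Linear(in_features, 2)`, i.e. a two-class head. The sketch below shows one way to swap it for other tasks: a head with more classes, or a single-output regression head for one-step prediction of a time series, which is the task in the original paper. It assumes PyTorch; `make_head`, `accuracy` and the task names are hypothetical, and on a real model you would assign the returned layer to its `fc` attribute and train with the matching loss.

```python
import torch
import torch.nn as nn


def make_head(in_features: int, task: str = "classification", num_classes: int = 2):
    """Return (output layer, loss function) for a hypothetical task type."""
    if task == "classification":
        # Logits for CrossEntropyLoss; targets are integer class indices (like y.npy)
        return nn.Linear(in_features, num_classes), nn.CrossEntropyLoss()
    if task == "regression":
        # One continuous output, e.g. the next value of a series; targets shaped (batch, 1)
        return nn.Linear(in_features, 1), nn.MSELoss()
    raise ValueError(f"unknown task: {task}")


def accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of samples whose arg-max class equals the label."""
    return (logits.argmax(dim=1) == labels).float().mean().item()


# Example: a 5-class head on 96 flattened features
head, loss_fn = make_head(96, "classification", num_classes=5)
features = torch.randn(8, 96)
labels = torch.randint(0, 5, (8,))
logits = head(features)
loss = loss_fn(logits, labels)
print(logits.shape)  # torch.Size([8, 5])
print(accuracy(logits, labels))
# model.fc = head  # on an actual model instance
```

The accuracy helper computes the same kind of figure as the 77.52% test accuracy quoted above, but on random data it is only a smoke test.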
Hope this helps! Keep it up!