当前位置: 首页 > article >正文

【自用】NLP算法面经(5)

一、L1、L2正则化

正则化是机器学习中用于防止过拟合并提高模型泛化能力的技术。当模型过拟合时,它已经很好地学习了训练数据,甚至是训练数据中的噪声,所以可能无法在新的、未见过的数据上表现良好。
比如:
在这里插入图片描述
其中,x1和x2为特征,f为拟合模型,w1和w2为模型权重,b为模型偏执。左图拟合模型公式最高阶次为1,即一条直线,对应欠拟合;中间拟合模型公式最高阶次为2,即一条简单的曲线;右图拟合模型公式最高阶次为4甚至更高,即一条复杂的曲线,对应过拟合。
可以看出,欠拟合时模型未有效学习数据中的信息,错分样本很多;过拟合时模型学习过于充分,甚至连包裹在红色类中的两个明显的噪声都被学习了,训练样本不会被错分,但无法保证对没见过的测试样本进行有效划分。所以更希望得到中间的这种模型。
对比中图和右图,容易发现模型阶次越高,分类超平面弯曲越多,高阶项前系数越大,弯曲程度越大,所以减少过拟合其实就是要减少模型的高阶次或弱化高阶项

1、L1正则化

也称为LASSO正则化,将模型权重系数的绝对值之和添加到损失函数中。它能够将模型中某些权重置0使对应的特征不再发挥作用以达到降低模型复杂度的作用。
假设原本模型的分类损失为标准交叉熵,则加上L1正则化项后如下:
在这里插入图片描述
其中,yi和pi是样本xi的真实标签和预测概率,wj为权重系数,λ用来平衡学习和正则化程度。
有了后面这一项,在优化损失时,|wj|会一定程度地减小,从而达到弱化高阶项的作用(其实低阶项也会被弱化,但分类超平面的复杂度主要受高阶项控制)。其实,L1正则化能达到使模型稀疏化的作用,即有些权重被置0。
简化上面的损失函数只有一个权重系数,写作J(w)=L(w)+λ|w|,假设L(w)在w=0时的导数如下:
在这里插入图片描述
于是有:
在这里插入图片描述
如果λ较大,会使损失函数的导数在w=0的左右两侧异号,则该点极可能是一个极小值点,在优化时,很可能将w优化至0。对于多个w的情况,与之类似,但只是一部分w取0即可达到极小值。部分w置0,则对应的特征不在发挥作用,从而使模型稀疏化。
L1正则化能够通过使模型稀疏化达到降低模型复杂度的作用。这种稀疏化特性使它能够作为一种特征选择策略,适合在高维且特征相关性不强的场景中使用。

2、L2正则化

也称为Ridge正则化,将模型系数的平方值之和添加到损失函数中。与L1正则化不同,L2正则化不会强制系数恰好为0,而是鼓励系数变小。
仍然假设原本模型的分类损失为标准交叉熵,则加上L2正则化项后如下:
在这里插入图片描述
同样地,简化上面的损失函数只有一个权重系数,写作J(w)=L(w)+λw²,假设L(w)在w=0时的导数如下:
在这里插入图片描述
有:
在这里插入图片描述
可见,L2正则项的加入不影响w=0处损失函数的导数,也就不容易在w=0处形成极小值。响应地,w就不容易被优化为0。对于多个w的情况,所有wj都不为0,却又希望损失J(w)小,就会将各个wj优化的很小,也就使高阶项发挥的作用变小,从而降低模型复杂度。
L2正则化能够通过将各项权重系数优化得很小以达到降低模型复杂度的目的。它能够减少单个特征在模型中的作用,避免某个特征主导整个预测方向。L2正则化项是可微的,优化计算效率更高,适合处理低维且特征间具有强相关性的场景。

二、self-attention

class SelfAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(SelfAttention, self).__init__()

        self.query_matrix = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_matrix = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_matrix = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.dropout = nn.Dropout(0.1)
        self.scale = torch.sqrt(torch.FloatTensor([hidden_dim])).to(device)

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.size()
        Q = self.query_matrix(x)
        K = self.key_matrix(x)
        V = self.value_matrix(x)
        
        scores = torch.matmul(Q, K.transpose(1,2))
        
        scaled_scores = scores / self.scale
        
        attn_weights = torch.softmax(scaled_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        attn_output = torch.matmul(attn_weights, V)
        
        return attn_output, attn_weights

三、稳定版softmax

import math

def softmax(x):
    max_x = max(x)
    exp_x = [math.exp(i - max_x) for i in x]
    sum_exp_x = sum(exp_x)
    return [i / sum_exp_x for i in exp_x]

四、chatglm3-6b推理

from transformers import AutoTokenizer, AutoModel
import torch

model_path = "THUDM/chatglm3-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path)

model.eval()

def chatglm3_inference(query, history=None, max_length=4096, temperature=0.8, top_p=0.9):
    inputs = tokenizer.build_chat_input(
        query,
        history=history,
        role="user",
        max_length=max_length,
        padding="max_length" if len(history) > 0 else "do_not_pad"
    ).to(model.device)

    gen_kwargs = {
        "max_length": max_length,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "pad_token_id": tokenizer.eos_token_id
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    response = tokenizer.decode(
        outputs[0][len(input["input_ids"][0]):],
        skip_special_token=True,
        clean_up_tokenization_spaces=True
    )

    updated_history = []
    if history:
        updated_history.append(history)
    updated_history.append((query, response))
    return response, updated_history

if __name__ == '__main__':
    history = []

    user_input = "xxx"
    response, history = chatglm3_inference(user_input, history)

http://www.kler.cn/a/593476.html

相关文章:

  • MySQL InnoDB引擎中Redo Log、Binlog、Undo Log的原理、执行顺序
  • 【css酷炫效果】纯CSS实现火焰文字特效
  • 【病毒分析】伪造微软官网+勒索加密+支付威胁,CTF中勒索病毒解密题目真实还原!
  • Spring Boot配置与注解的使用
  • 力扣105. 从前序与中序遍历序列构造二叉树(Java实现)
  • 内网渗透练习-红日靶场一
  • pytorch3d学习(四)——批处理(实现多obj文件读取)
  • Unity3D开发AI桌面精灵/宠物系列 【二】 语音唤醒 ivw 的两种方式-Windows本地或第三方讯飞等
  • 【AI知识】导数,偏导,梯度的解释
  • 内网安全-横向移动Kerberos 攻击SPN 扫描WinRMWinRSRDP
  • 详细解析GetOpenFileName()
  • 如何打造安全稳定的亚马逊采购测评自养号下单系统?
  • Solana笔记案例:写一个SOL转账程序
  • C#:使用UDP协议实现数据的发送和接收
  • 量子计算专业书籍,做个比较
  • C++ 核心编程 ——4.9 文件操作
  • 主流NoSQL数据库类型及选型分析
  • 接口测试工具:Jmeter
  • UI设计中的模态对话框:合理使用指南
  • 举例说明 牛顿法 Hessian 矩阵