当前位置: 首页 > article >正文

从预训练的BERT中提取Embedding

文章目录

    • 背景
    • 前置准备
    • 思路
    • 利用Transformer 库实现

背景

假设要执行一项情感分析任务,样本数据如下
在这里插入图片描述
可以看到几个句子及其对应的标签,其中1表示正面情绪,0表示负面情绪。我们可以利用给定的数据集训练一个分类器,对句子所表达的情感进行分类。

前置准备

# 安装modelscope包
pip install modelscope
# 下载 bert-base-uncased 模型
modelscope download --model AI-ModelScope/bert-base-uncased

思路

  1. 分词:以第一句为例,我们使用WordPiece对句子进行分词,并得到标记(单词),如下所示。

    tokens = [I, love, Paris]

  2. 添加标记:在开头添加[CLS]标记,在结尾添加[SEP]标记,如下所示。

    tokens = [ [CLS], I, love, Paris, [SEP] ]

  3. 填充:为了保持所有标记的长度一致,我们将数据集中的所有句子的标记长度设为7。句子I loveParis的标记长度是5,为了使其长度为7,需要添加两个标记来填充,即[PAD]。因此,新标记如下所示。

    tokens = [ [CLS], I, love, Paris, [SEP], [PAD], [PAD] ]

    添加两个[PAD]标记后,标记的长度达到所要求的7。

  4. 注意力掩码:下一步,要让模型理解[PAD]标记只是为了匹配标记的长度,而不是实际标记的一部分。为了做到这一点,我们需要引入一个注意力掩码。我们将所有位置的注意力掩码值设置为1,将[PAD]标记的位置设置为0,如下所示。

    attention_mask = [ 1, 1, 1, 1, 1, 0, 0]

  5. 映射到token id:然后,将所有的标记映射到一个唯一的标记ID。假设映射的标记ID如下所示。

    token_ids = [101, 1045, 2293, 3000, 102, 0, 0]

    ID 101表示标记[CLS],1045表示标记I,2293表示标记love,以此类推。

    现在,我们把token_ids和attention_mask一起输入预训练的BERT模型,并获得每个标记的特征向量(嵌入)。通过代码,我们可以进一步理解以上步骤。下图显示的标记+单词而不是id,但实际传入的是id
    在这里插入图片描述

以上,可以得到每个单词的Embedding,整个句子的Embedding是 R [ C L S ] R_{[CLS]} R[CLS]

利用Transformer 库实现

from transformers import BertModel, BertTokenizer
import torch
# 下载并加载预训练的模型
model = BertModel.from_pretrained('bert-base-uncased')
# 下载并加载用于预训练模型的词元分析器。
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 下面,让我们看看如何对输入进行预处理。
# 0. 对输入进行预处理假设输入句如下所示。
sentence = 'I love Paris'
# 1. 分词
tokens = tokenizer.tokenize(sentence)
print(tokens) # ['i', 'love', 'paris']

# 2. 添加标记
tokens = ['[CLS]'] + tokens + ['[SEP]']
print(tokens) # ['[CLS]', 'i', 'love', 'paris', '[SEP]']

# 3. 填充
tokens = tokens + ['[PAD]'] + ['[PAD]']
print(tokens) #['[CLS]', 'i', 'love', 'paris', '[SEP]', '[PAD]', '[PAD]' ]

# 4. 注意力掩码
attention_mask = [1 if i!= '[PAD]' else 0 for i in tokens]
print(attention_mask) # [1, 1, 1, 1, 1, 0, 0]

# 5. 将所有标记转换为它们的标记ID
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids) # [101, 1045, 2293, 3000, 102, 0, 0]

# 6. 将token_ids和attention_mask转换为张量
token_ids = torch.tensor(token_ids).unsqueeze(0)
attention_mask = torch.tensor(attention_mask).unsqueeze(0)


# 7. 将token_ids和atten-tion_mask送入模型,并得到嵌入向量。
# 需要注意,model返回的输出是一个有两个值的元组。第1个值hidden_rep表示隐藏状态的特征,它包括从顶层编码器(编码器12)获得的所有标记的特征。第2个值cls_head表示[CLS]标记的特征。
hidden_rep, cls_head = model(token_ids, attention_mask = attention_mask)
print(hidden_rep.shape) # torch.Size([1, 7, 768])

'''
数组[1, 7, 768]表示[batch_size, se-quence_length, hidden_size],也就是说,批量大小设为1,序列长度等于标记长度,即7。因为有7个标记,所以序列长度为7。隐藏层的大小等于特征向量(嵌入向量)的大小,在BERT-base模型中,其为768。
* hidden_rep[0][0]给出了第1个标记[CLS]的特征。   
* hidden_rep[0][1]给出了第2个标记I的特征。   
* hidden_rep[0][2]给出了第3个标记love的特征
'''

print(cls_head.shape) # torch.Size([1, 768])

'''
大小[1, 768]表示[batch_size, hid-den_size]。我们知道cls_head持有句子的总特征,所以,可以用cls_head作为句子I love Paris的整句特征。
'''


以上获得的是从顶层编码器(编码器12)获得的特征,如果要获取所有编码器的特征,需要修改以下两个地方。

# 下载并加载预训练的模型时,设置output_hidden_states = True
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 调用模型时,产生的是三元组
last_hidden_state, pooler_output, hidden_states = model(token_ids, attention_mask = attention_mask)

'''
* last_hidden_state,它仅有从最后的编码器(编码器12)中获得的所有标记的特征
* pooler_output表示来自最后的编码器的[CLS]标记的特征,它被一个线性激活函数和tanh激活函数进一步处理。
* hidden_states包含从所有编码器层获得的所有标记的特征。它是一个包含13个值的元组,含有所有编码器层(隐藏层)的特征,即从输入嵌入层h到最后的编码器层h。
'''

http://www.kler.cn/a/500256.html

相关文章:

  • Kotlin面向对象编程
  • 数据在内存的存储
  • Microsoft Sql Server 2019 数据类型
  • 【计算机网络】lab4 Ipv4(IPV4的研究)
  • k8s helm部署kafka集群(KRaft模式)——筑梦之路
  • OpenAI Swarm的使用过程记录
  • ubuntu/kali安装c-jwt-cracker
  • react的statehook useState Hook详细使用
  • 【YOLO】将多类别YOLO格式txt标签数据区分成单类别标签
  • jQuery CSS 类
  • react全局状态管理——redux和zustand,及其区别
  • docker更换镜像源脚本
  • 单元测试概述入门
  • 附加共享数据库( ATTACH DATABASE)的使用场景
  • 微信小程序用的SSL证书有什么要求吗?
  • 设计模式 行为型 解释器模式(Interpreter Pattern)与 常见技术框架应用 解析
  • 基于高斯混合模型的数据分析及其延伸应用(具体代码分析)
  • 【LevelDB 和 Sqlite】
  • 芯片:CPU和GPU有什么区别?
  • 【了解到的与深度学习有关知识】
  • 逆向 易九批 最新版 爬虫逆向 x-sign ......
  • Python 写的 智慧记 进销存 辅助 程序 导入导出 excel 可打印 Pyside6版
  • An FPGA-based SoC System——RISC-V On PYNQ项目复现