当前位置: 首页 > article >正文

Bert完形填空

转载自:| 03_language_model/03_Bert完形填空.ipynb | 基于transformers使用Bert模型做完形填空 |Open In Colab |

完形填空

利用语言模型,可以完成完形填空(fill mask),预测缺失的单词。
当前,效果最好的语言模型是Bert系列的预训练语言模型。

!pip install transformers
import os

from transformers import pipeline

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
model_name = "hfl/chinese-macbert-base"

nlp = pipeline("fill-mask",
               model=model_name,
               tokenizer=model_name,
               device=-1,  # gpu device id
               )
from pprint import pprint

pprint(nlp(f"明天天{nlp.tokenizer.mask_token}很好?"))
print("*" * 42)
pprint(nlp(f"明天心{nlp.tokenizer.mask_token}很好?"))
print("*" * 42)
pprint(nlp(f"张亮在哪里任{nlp.tokenizer.mask_token}?"))
print("*" * 42)
pprint(nlp(f"少先队员{nlp.tokenizer.mask_token}该为老人让座位。"))

模型默认保存在:~/.cache/huggingface/transformers

不通过pipeline,可以自己写预测逻辑:

from transformers import AutoModelWithLMHead, AutoTokenizer
import torch

# tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
# model = AutoModelWithLMHead.from_pretrained("distilbert-base-cased")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelWithLMHead.from_pretrained(model_name)

sequence = f"明天天{nlp.tokenizer.mask_token}很好."
input = tokenizer.encode(sequence, return_tensors="pt")
mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]
token_logits = model(input).logits
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
for token in top_5_tokens:
    print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))

http://www.kler.cn/a/383076.html

相关文章:

  • 接口测试Day-02-安装postman项目推送Gitee仓库
  • Hadoop yarn安装
  • 使用 AI 辅助开发一个开源 IP 信息查询工具:一
  • MySQL -- 库的相关操作
  • ChatGPT生成接口文档的方法与实践
  • SpringBoot 启动类 SpringApplication 二 run方法
  • Java基础使用①Java特点+环境安装+IDEA使用
  • 求猫用宠物空气净化器推荐,有没有吸毛强、噪音小的产品
  • Linux awk命令详解-参数-选项-内置变量-内置函数-脚本(多图、多示例)
  • 我们来学mysql -- EXPLAIN之ID(原理篇)
  • 爱普生 SG - 8201CJA 可编程振荡器成为电子应用的解决方案
  • 【LeetCode】【算法】142. 环形链表II
  • 开放寻址法、链式哈希数据结构详细解读
  • Vue3 + Element Plus简单使用案例及【eslint】报错处理
  • 【漏洞复现】Apache Druid RCE (CVE-2023-25194) 漏洞
  • Linux与Windows中的流量抓取工具:wireshark与tcpdump
  • 防火墙|WAF|漏洞|网络安全
  • 【LeetCode】【算法】215. 数组中的第K个最大元素
  • 内外连接【MySQL】
  • 机器学习(三)——决策树(附核心思想、重要算法、概念(信息熵、基尼指数、剪枝处理)及Python源码)
  • Flutter UI构建渲染(4)
  • Windows10/11下python脚本自动连接WiFi热点
  • STM32启动文件分析
  • Axure是什么软件?全方位解读助力设计入门
  • 实践是认识的来源
  • GPU的内存是什么?