深度学习-16-深入理解BERT基于本地数据微调训练文本分类模型的流程
文章目录
- 1 加载库和设置通用参数
-
- 1.1 DistilBert
- 1.2 模型库
- 1.3 微调任务
- 2 准备数据
-
- 2.1 加载数据
- 2.2 切分数据
- 2.3 数据分词
- 2.4 制作数据集
- 3 使用Trainer API微调transformer
-
- 3.1 加载预训练模型
- 3.2 定义训练器
- 3.3 执行训练
- 3.4 评估性能
- 3.5 保存模型
- 4 使用训练好的模型
- 5 参考附录
1 加载库和设置通用参数
import pandas as pd
import torch
import transformers
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification
torch.backends.cudnn.deterministic = True # 用于固定cuda的随机数种子
RANDOM_SEED = 123
torch.manual_seed(RANDOM_SEED)