NLP-transformer学习:(8)trainer 使用方法
NLP-transformer学习:(8)trainer 使用方法
11月工作996压力较大,任务完成后,目前休息了一个月,2025年新的一天继续开始补基础。
本章节是单独的 NLP-transformer学习 章节,主要实践了evaluate。同时,最近将学习代码传到:https://github.com/MexWayne/mexwayne_transformers-code,
作者的代码版本有些细节我发现到目前不能完全行的通,为了尊重原作者,我这里保持了大部分的内容,并标明了来源,欢迎大家一起学习。
文章目录
- NLP-transformer学习:(8)trainer 使用方法
- 一、整体代码
- 二、遇到的问题
- 问题1:如下图
- 问题2:如下图
- 问题3:如下图
一、整体代码
这里没什么好讲的说实话,我这里将整体代码附上:
# import the related package
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import evaluate
from transformers import DataCollatorWithPadding
def process_function(examples):
tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
tokenized_examples["labels"] = examples["label"]
return tokenized_examples
def eval_metric(eval_predict):
predictions, labels = eval_predict
predictions = predictions.argmax(axis=-1)
acc = acc_metric.compute(predictions=predictions, references=labels)
f1 = f1_metric.compute(predictions=predictions, references=labels)
acc.update(f1)
return acc
if __name__ == "__main__":
# download the dataset
dataset = load_dataset("csv", data_files="/home/mex/Desktop/learn_transformer/mexwayne_transformers_NLP/01-Getting_Started/07-trainer/ChnSentiCorp_htl_all.csv", split="train")
dataset = dataset.filter(lambda x: x["review"] is not None)
print("load dataset:")
print(dataset)
# split the dataset into 0.1 and 0.9, the 0.1 for test_dataset
datasets = dataset.train_test_split(test_size=0.1)
print("split dataset:")
print(dataset)
# build tokenize the data
tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")
tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)
print(tokenized_datasets)
# load the model
model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")
print("model config:")
print(model.config)
# build the evaluate module
acc_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
# build train args
train_args = TrainingArguments(output_dir="./checkpoints", # train model output path
per_device_train_batch_size=64, # train batch_size
per_device_eval_batch_size=128, # test batch_size
logging_steps=10, # log
eval_strategy="epoch", # evaluation strategy
save_strategy="epoch", # save the model every epoch
save_total_limit=3, # only keep 3 model save
learning_rate=2e-5, #
weight_decay=0.01, #
metric_for_best_model="f1", #
load_best_model_at_end=True) # save the best model after train
print("train_args")
print(train_args)
# build trainer
trainer = Trainer(model=model,
args=train_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["test"],
data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
compute_metrics=eval_metric)
# start trainnig
trainer.train()
# evaluation
trainer.evaluate(tokenized_datasets["test"])
#####################################################
#####################################################
# try one new, to predict a results
trainer.predict(tokenized_datasets["test"])
from transformers import pipeline
model.config.id2label = id2_label
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=0)
sen = "我觉得还是蛮不错的哈!"
print(pipe(sen))
其中依赖的 csv 在我的github 仓库下:
https://github.com/MexWayne/mexwayne_transformers-code/blob/master/01-Getting_Started/07-trainer/ChnSentiCorp_htl_all.csv
当你正确的配置环境后可以看到数据被正确加载
以及相关的trainer 可以训练数据
二、遇到的问题
问题1:如下图
遇到这个问题
conda install accelerate 不解决问题
真正的解决方案:
要用最新的
pip install -U accelerate
pip install -U transformers
问题2:如下图
注意这里 我用了 一个 库
要手动键入,1 和 2 都不行,3可以,这是最终结果
问题3:如下图
笔者之前一直能提交的仓库突然不行了
后来查了下,用 ssh -T git@github.com 可以次测试通断
发现笔者的 22port 用不了
后来采取如下方法
当需要恢复22port 时,只需要删除这个config 即可