当前位置: 首页 > article >正文

NLP-transformer学习:(8)trainer 使用方法

NLP-transformer学习:(8)trainer 使用方法

在这里插入图片描述
11月工作996压力较大,任务完成后,目前休息了一个月,2025年新的一天继续开始补基础。
本章节是单独的 NLP-transformer学习 章节,主要实践了evaluate。同时,最近将学习代码传到:https://github.com/MexWayne/mexwayne_transformers-code,
作者的代码版本有些细节我发现到目前不能完全行的通,为了尊重原作者,我这里保持了大部分的内容,并标明了来源,欢迎大家一起学习。

文章目录

  • NLP-transformer学习:(8)trainer 使用方法
  • 一、整体代码
  • 二、遇到的问题
    • 问题1:如下图
    • 问题2:如下图
    • 问题3:如下图


一、整体代码

这里没什么好讲的说实话,我这里将整体代码附上:

# import the related package
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

import torch
import evaluate
from transformers import DataCollatorWithPadding


def process_function(examples):
    tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
    tokenized_examples["labels"] = examples["label"]
    return tokenized_examples


def eval_metric(eval_predict):
    predictions, labels = eval_predict
    predictions = predictions.argmax(axis=-1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels)
    acc.update(f1)
    return acc


if __name__ == "__main__":
    # download the dataset
    dataset = load_dataset("csv", data_files="/home/mex/Desktop/learn_transformer/mexwayne_transformers_NLP/01-Getting_Started/07-trainer/ChnSentiCorp_htl_all.csv", split="train")
    dataset = dataset.filter(lambda x: x["review"] is not None)
    print("load dataset:")
    print(dataset)

    # split the dataset into 0.1 and 0.9, the 0.1 for test_dataset
    datasets = dataset.train_test_split(test_size=0.1)
    print("split dataset:")
    print(dataset)

    # build tokenize the data
    tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")
    tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)
    print(tokenized_datasets)

    # load the model
    model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")
    print("model config:")
    print(model.config)

    # build the evaluate module
    acc_metric = evaluate.load("accuracy")
    f1_metric = evaluate.load("f1")


    # build train args 
    train_args = TrainingArguments(output_dir="./checkpoints",  # train model output path 
                               per_device_train_batch_size=64,  # train batch_size
                               per_device_eval_batch_size=128,  # test batch_size
                               logging_steps=10,                # log 
                               eval_strategy="epoch",     # evaluation strategy
                               save_strategy="epoch",           # save the model every epoch 
                               save_total_limit=3,              # only keep 3 model save
                               learning_rate=2e-5,              #  
                               weight_decay=0.01,               #  
                               metric_for_best_model="f1",      #  
                               load_best_model_at_end=True)     # save the best model after train 
    print("train_args")
    print(train_args)


    # build trainer
    trainer = Trainer(model=model, 
                      args=train_args, 
                      train_dataset=tokenized_datasets["train"], 
                      eval_dataset=tokenized_datasets["test"], 
                      data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                      compute_metrics=eval_metric)
    
    # start trainnig
    trainer.train()

    # evaluation
    trainer.evaluate(tokenized_datasets["test"])
    #####################################################


    #####################################################
    # try one new, to predict a results
    trainer.predict(tokenized_datasets["test"])
    from transformers import pipeline
    
    model.config.id2label = id2_label
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=0)
    sen = "我觉得还是蛮不错的哈!"
    print(pipe(sen))

其中依赖的 csv 在我的github 仓库下:
https://github.com/MexWayne/mexwayne_transformers-code/blob/master/01-Getting_Started/07-trainer/ChnSentiCorp_htl_all.csv

当你正确的配置环境后可以看到数据被正确加载
在这里插入图片描述
在这里插入图片描述
以及相关的trainer 可以训练数据

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

二、遇到的问题

问题1:如下图

在这里插入图片描述
遇到这个问题
conda install accelerate 不解决问题
真正的解决方案:

要用最新的
pip install -U accelerate
pip install -U transformers

问题2:如下图

注意这里 我用了 一个 库
在这里插入图片描述
要手动键入,1 和 2 都不行,3可以,这是最终结果
在这里插入图片描述

问题3:如下图

笔者之前一直能提交的仓库突然不行了
在这里插入图片描述
后来查了下,用 ssh -T git@github.com 可以次测试通断
在这里插入图片描述
发现笔者的 22port 用不了
后来采取如下方法在这里插入图片描述
当需要恢复22port 时,只需要删除这个config 即可


http://www.kler.cn/a/468183.html

相关文章:

  • 【Uniapp-Vue3】swiper滑块视图容器的用法
  • Objective-C 是一种面向对象的编程语言
  • pdf预览兼容问题- chrome浏览器105及一下预览不了
  • 外网访问本地部署的 VMware ESXi 服务
  • 力扣hot100——栈
  • Transformer知识梳理
  • 抖音评论地区分布可视化期末项目
  • 【微服务】【Sentinel】认识Sentinel
  • JODConverter结合LibreOffice如何转换ppt pptx成图片
  • 谷粒商城-高级篇-Sentinel-分布式系统的流量防卫兵
  • Arduino 小白的 DIY 空气质量检测仪(5)- OLED显示模块、按钮模块
  • 微信小程序校园自助点餐系统实战:从设计到实现
  • CSS系列(50)-- View Transitions详解 系列总结
  • 应用Docker快速实现 JMeter + InfluxDB + Grafana 监控方案
  • 虚拟机图像界面打不开了
  • NLP初识
  • leetcode中简单题的算法思想
  • 计算机网络•自顶向下方法:网络安全、RSA算法
  • react报错解决
  • 1、pycharm、python下载与安装
  • 服务器信息整理:用途、操作系统安装日期、设备序列化、IP、MAC地址、BIOS时间、系统
  • 什么是Kafka的重平衡机制?
  • 小红书怎么看ip所属地?小红书ip属地为什么可以变
  • 基于Spring Boot的健康饮食管理系统
  • 开发培训:慧集通(DataLinkX)iPaaS集成平台-基于接口的组件开发
  • WebSocket 基础入门:协议原理与实现