当前位置: 首页 > article >正文

train_args = TrainingArguments()里面的全部参数使用

目录

  • 一、全部参数介绍
    • 1.关于输出保存的设定
    • 2.关于评估策略的设定
    • 3.关于训练批次的设定
    • 4.关于优化器的设定
    • 5.关于log的超参数设定
    • 6.关于训练结果保存参数设定
    • 7.一些关于网络训练的超参数设定
    • 8.评估输出模型的一些参数
    • 9.恢复训练时的一些参数设定
    • 10.其他配置设定
    • 11.关于hub的一些参数设定
  • 二、常用的参数总结

进入到 class TrainingArguments:接下来介绍同Trainer一起使用的train_args里面的一些参数含义
重要超参数用 加粗* 体现

一、全部参数介绍

1.关于输出保存的设定

*1.output_dir (:obj:str):
模型训练后保存的模型地址,例如:output_dir="./checkpoints",


2.overwrite_output_dir (:obj:bool, optional, defaults to :obj:False):
指的是在输出目录output_dir已经存在的情况下将删除该目录并重新创建。
如果:obj: ’ True ',覆盖output_dir中指定的输出目录下的内容。默认值obj: ’ False ’
但是不添加这个参数,网络的训练参数也是覆盖的,一般不使用。

2.关于评估策略的设定

*3.evaluation_strategy (str or [~trainer_utils.IntervalStrategy], optional, defaults to "no"):
三个选项

  • :obj:"no": eval期间不进行评估。
  • :obj:"steps": 每一个eval_steps阶段之后都进行评估
  • :obj:"epoch": 每一个epoch之后进行评估

一般设定为

evaluation_strategy="epoch",     
# 每一轮进行一次评估策略,等同于evaluation_strategy="steps",因为在训练时候调用了eval.

4.prediction_loss_only (:obj:bool, optional, defaults to False):
当执行评估和预测的时候,是否仅仅返回损失
注意因为如果设定为True的话就不能使用其他评估函数了,若使用会报错
在这里插入图片描述

3.关于训练批次的设定

5.per_device_train_batch_size (:obj:int, optional, defaults to 8):
如果有多台设备可以设定每一个GPU训练的批次大小


6.per_device_eval_batch_size (int, optional, defaults to 8):
如果有多台设备可以设定每一个GPU评估的批次大小


*7.gradient_accumulation_steps (:obj:int, optional, defaults to 1):
当模型训练时候显存不够用了,可以设定这个参数。它也是我们上面 *一、2.在减少batch_size开启gradient_accumulation_steps梯度累加,下降至2.4!*中提到的。注意不是对loss累加。


8.eval_accumulation_steps (int, optional):
当然这个设定和7差不多,它是指在将结果移动到CPU之前累积输出张量的预测步骤数。如果不设置,整个预测在移动到CPU之前会累积在GPU上(更快,但需要更多内存)。


*9.num_train_epochs(:obj:float, optional, defaults to 3.0):
要执行的训练epoch的次数(如果不是整数,将执行停止训练前的最后一个epoch的小数部分百分比)。
15.max_steps (int, optional, defaults to -1):
如果设置为正数,则表示要执行的训练step的次数。覆盖`num_train_epochs’。在使用有限可迭代数据集的情况下,训练可能在所有数据还没训练完成时因达到设定的步数而停止。

4.关于优化器的设定


10.adam_beta1 (:obj:float, optional, defaults to 0.9):
11.adam_beta2 (:obj:float, optional, defaults to 0.999):
12.adam_epsilon (float, optional, defaults to 1e-8):
beta1、beta2、epsilon超参数设定


*13.learning_rate (:obj:float, optional, defaults to 5e-5):
AdamW优化器初始化的学习率
一般设定为

learning_rate=2e-5,

*14.weight_decay (float, optional, defaults to 0):

AdamW优化器中,除了bias和LayerNorm权重,如果weight_decay不是零,则应用于所有层
一般设定为

weight_decay=0.01,

15.lr_scheduler_type (str or [SchedulerType], optional, defaults to "linear"):
选择什么类型的学习率调度器来更新模型的学习率。可选的值有:
“linear”、“cosine”、“cosine_with_restarts”、“polynomial”、“constant”、“constant_with_warmup”


16.warmup_ratio (:obj:float, optional, defaults to 0.0):
线性预热从0达到learning_rate时,每步学习率的增长率


17.warmup_steps (int, optional, defaults to 0):
线性预热从0达到learning_rate时,预热阶段的步数,它会覆盖warmup_ratio的设置


18.max_grad_norm (float, optional, defaults to 1.0):
最大梯度范数(用于梯度剪裁)

5.关于log的超参数设定

19.log_level (str, optional, defaults to passive):
设置主进程上使用的日志级别。可选择的值:
‘debug’、‘info’、‘warning’、‘error’ 、‘critical’、‘passive’ (不设置任何值,由应用进行设置)


20.log_level_replica (str, optional, defaults to passive):
控制训练过程中副本节点的日志级别,设置参数和log_level一样


21.log_on_each_node (bool, optional, defaults to True):
在多节点分布式训练中,是每个节点使用“log_level”进行一次日志记录,还是仅在主节点


*22.logging_dir (str, optional):
日志目录,默认记录在:output_dir/runs/,注意output_dir是上面自己设定的名字


23.logging_strategy (str or [~trainer_utils.IntervalStrategy], optional, defaults to "steps"):
训练期间log打印的频率,可选的值有:

  • “no”:训练期间不记录日志
  • “steps”:每一个logging_steps`阶段之后都记录日志
  • “epoch”:每一个epoch之后记录日志

24.logging_first_step (bool, optional, defaults to False):
global_step 表示训练的全局步数。当训练开始时,global_step 被初始化为 0,每次更新模型时,global_step 会自动递增。是否打印日志和评估第一个global_step


*25.logging_steps (int, optional, defaults to 500):
如果 logging_strategy="steps",则在日志中更新step的数量。和 *23 搭配使用。
一般设定为:

logging_steps=10, 

26.logging_nan_inf_filter (bool, optional, defaults to True):

是否在日志中过滤掉 naninf 损失,如果设置为 True,每步的损失如果是 nan或者inf将会被过滤,将会使用平均损失记录在日志当中。

6.关于训练结果保存参数设定

27.save_strategy (str or [~trainer_utils.IntervalStrategy], optional, defaults to "steps"):
训练过程中,checkpoint的保存策略,可选择的值有:

  • :obj:"no": 训练过程中,不保存checkpoint.
  • :obj:"epoch": 每个epoch完成之后保存checkpoint.
  • :obj:"steps": 每个save_steps完成之后checkpoint`.

一般设定为

save_strategy="epoch",           # 每一轮都保存一个模型

28.save_steps (int, optional, defaults to 500):
如果`save_strategy=“steps”,则通过设定save_steps 保存的更新步骤数


29.save_total_limit (int, optional):
如果设置了值,则将限制checkpoint的总数量,output_dir里面超过数量的老的checkpoint将会被删掉
一般设定为

 save_total_limit=3,              # 最大保存模型的数量

30.save_on_each_node (:obj:bool, optional, defaults to :obj:False):
在进行多节点分布式训练时,是否在每个节点上保存模型和权重,还是仅在主节点上保存。当不同节点使用相同的存储时,不激活此超参数,文件将以相同的名称保存在每个节点上。

7.一些关于网络训练的超参数设定

  1. no_cuda (:obj:bool, optional, defaults to :obj:False):
    是否不使用CUDA,默认false

32.seed (:obj:int, optional, defaults to 42):
将在训练开始时设置的随机种子。为了确保跨运行的再现性,请使用:func:`~transformer . trainer。如果模型有一些随机初始化的参数,模型初始化函数来实例化模型。


33.bf16 (:obj:bool, optional, defaults to :obj:False):
是否使用bf16 16位(混合)精度训练代替32位训练。
34.fp16 (:obj:bool, optional, defaults to :obj:False):
是否使用fp16 16位(混合)精度训练代替32位训练。
35. fp16_opt_level (:obj:str, optional, defaults to ‘O1’):
对于:obj: ‘ fp16 ’训练,Apex AMP优化水平在[‘O0’, ‘O1’, ‘O2’, and ‘O3’]选择.
37. fp16_backend (:obj:str, optional, defaults to :obj:"auto"):
使用‘半精度后端’。
38.half_precision_backend (:obj:str, optional, defaults to :obj:"auto"):
后端用于混合精度训练。必须是:obj: ‘auto’、obj: ‘amp’或:obj: ‘apex’中的一个。:obj: ‘auto’将根据检测到的PyTorch版本使用AMP或APEX,而其他选择将强制请求后端。
39.bf16_full_eval (:obj:bool, optional, defaults to :obj:False):
是否使用full bfloat16计算而不是32位。这将更快并节省内存,但可能会损害度量值。
40.fp16_full_eval (:obj:bool, optional, defaults to :obj:False):
是否使用完整的float16计算而不是32位。这将更快并节省内存,但可能会损害度量值。
41.tf32 (:obj:bool, optional):
是否启用tf32模式,仅在Ampere和更新的GPU架构中使用。


42.local_rank (:obj:int, optional, defaults to -1):
分布式训练过程的等级。
43.xpu_backend (:obj:str, optional):
用于xpu分布式训练的后端。必须是:obj: ‘ “mpi“ ’或:obj: ‘ ”ccl” ’之一。
44.tpu_num_cores (:obj:int, optional):
在tpu上训练时,tpu核数(由启动器脚本自动传递)。


45.dataloader_drop_last (:obj:bool, optional, defaults to :obj:False):
是否删除最后一个不完整的批处理(如果数据集的长度不能被批处理大小整除)。


46.eval_steps (:obj:int, optional):
如果:obj: ’评估策略= ’ steps ',则两次评估之间的更新步数。如果没有设置,默认值与:obj: ‘ logging steps ’相同。


47.dataloader_num_workers (:obj:int, optional, defaults to 0):
用于数据加载的子进程数(仅限PyTorch)。0表示数据将在主进程中加载。


48.past_index (:obj:int, optional, defaults to -1):
一些模型如:doc:TransformerXL <../model_doc/transformerxl> or :doc:XLNet <../model_doc/xlnet> ,他们能够利用过去的隐藏状态进行预测。如果此参数设置为正int,则‘Trainer’将使用相应的输出(通常是索引2)作为过去的状态,并在关键字参数‘mems’下的下一个训练步骤中将其提供给模型。


49.run_name (:obj:str, optional):
运行的描述符。通常用于‘ wandb https://www.wandb.com/ ’日志记录。
50.disable_tqdm (:obj:bool, optional):
是否禁用由:class:`~transformer .notebook生成的tqdm进度条和度量表。NotebookTrainingTracker '在Jupyter笔记本。如果日志级别设置为警告或更低(默认),则默认为:obj: ‘ True ‘,否则:obj: ’ False ’。


51.remove_unused_columns (:obj:bool, optional, defaults to :obj:True):
如果使用:obj: ’ datassets,是否自动删除模型forward未使用的列。


52.label_names (:obj:List[str], optional):
与标签对应的输入字典中的键列表。默认为:obj: [“labels”] ,除非使用的模型是:obj: ’ xxxforquestionanswer ’中的一个,在这种情况下,它将默认为:obj: [“start positions”, “end positions”] 。

8.评估输出模型的一些参数

*53.load_best_model_at_end (:obj:bool, optional, defaults to :obj:False):
是否在训练结束时加载训练过程中找到的最佳模型。
注意:当设置obj: ‘ True ’时,参数:obj: ‘ save_strategy ’需要与:obj: ‘ eval_strategy ’相同,在为“steps”的情况下,obj: ‘ save_steps ’必须是:obj: ‘ eval_steps ’的整数倍。
在这里插入图片描述
*54.metric_for_best_model (:obj:str, optional):
与:obj: ‘ load best model at end ’结合使用,指定用于比较两个不同模型的度量。必须是由求值返回的度量的名称,带或不带前缀:obj: ‘ "eval " ’。如果未指定,将默认为:obj: ’ “loss” ‘,并且:obj: ’ load best model at end=True '(使用评估损失)。
如果您设置此值,obj:greater_is_better 将default to :obj:True。但如果希望评估指标越低代表优秀,此时要将:obj:greater_is_better 设置为False

*55.greater_is_better (:obj:bool, optional):
结合使用:obj: ‘ load best model at end ’和:obj: ‘ metric for best model ’来指定更好的模型是否应该有更大的度量。
obj:True 当 :obj:metric_for_best_model 设定的评估指标不是obj:"loss" or obj:"eval_loss".
obj:False 当 :obj:metric_for_best_model 未设定, 或设定为 :obj:"loss" or :obj:"eval_loss".

load_best_model_at_end=True, #是否在最后打印输出
metric_for_best_model="loss",
greater_is_better=False,

或者

load_best_model_at_end=True, #是否在最后打印输出
metric_for_best_model="f1",
greater_is_better=True,#或者注释此横

9.恢复训练时的一些参数设定

*56.ignore_data_skip (:obj:bool, optional, defaults to :obj:False):
恢复训练时,是否跳过epoch和batch以获得与上一次训练相同阶段的数据加载。如果设置为:obj: ’ True ',训练将更快地开始(因为跳过步骤可能需要很长时间),但不会产生与中断训练相同的结果。


57.sharded_ddp (:obj:bool, :obj:str or list of :class:~transformers.trainer_utils.ShardedDDPOption, optional, defaults to :obj:False):
使用来自‘ FairScale https://github.com/facebookresearch/fairscale ’的Sharded DDP训练(仅在分布式训练中)
下列选项的列表:

  • :obj:"simple": to use first instance of sharded DDP released by fairscale (:obj:ShardedDDP) similar
    to ZeRO-2.
  • :obj:"zero_dp_2": to use the second instance of sharded DPP released by fairscale
    (:obj:FullyShardedDDP) in Zero-2 mode (with :obj:reshard_after_forward=False).
  • :obj:"zero_dp_3": to use the second instance of sharded DPP released by fairscale
    (:obj:FullyShardedDDP) in Zero-3 mode (with :obj:reshard_after_forward=True).
  • :obj:"offload": to add ZeRO-offload (only compatible with :obj:"zero_dp_2" and :obj:"zero_dp_3").

10.其他配置设定

58.deepspeed (:obj:str or :obj:dict, optional):
查看“深度速度https://github.com/microsoft/deepspeed”。该值要么是DeepSpeed json配置文件的位置(例如,‘ ’ ds config.json ‘ ‘),要么是已经加载的json文件:obj: ’ dict ’”。


59.label_smoothing_factor (:obj:float, optional, defaults to 0.0):
要使用的标签平滑因子。0表示没有标签平滑,否则底层的onehot-encoded标签将从0s和1s分别变为:obj:label_smoothing_factor/num_labels 和:obj:1 - label_smoothing_factor + label_smoothing_factor/num_labels(obj: ‘标签平滑因子/个数标签’和:obj: ‘ 1 -标签平滑因子+标签平滑因子/个数标签’。)


60.debug (:obj:str or list of :class:~transformers.debug_utils.DebugOption, optional, defaults to :obj:""):
启用一个或多个调试特性。
可能的选项有:

  • :obj:"underflow_overflow": detects overflow in model’s input/outputs and reports the last frames that
    led to the event
  • :obj:"tpu_metrics_debug": print debug metrics on TPU

61.adafactor (:obj:bool, optional, defaults to :obj:False):
是否使用:class:~转换器。用:class:transformers.AdamW`代替:class:`transformers.AdamW`。


62.group_by_length (:obj:bool, optional, defaults to :obj:False):
是否将训练数据集中大致相同长度的样本分组在一起(以最小化填充应用并提高效率)。仅在应用动态填充时有用。
63.length_column_name (:obj:str, optional, defaults to :obj:"length"):
预先计算长度的列名。如果列存在,按长度分组将使用这些值,而不是在列车启动时计算它们。忽略,除非:obj: ‘ group by length ’为:obj: ‘ True ’且数据集是:obj: ‘ dataset ’的实例。


64.report_to (:obj:str or :obj:List[str], optional, defaults to :obj:"all"):
要向其报告结果和日志的集成列表。支持的平台有:obj: ’ ’ azure ml ’ ',:obj: ’ ’ comet ml ’ ',:obj: ’ ’ mlflow ’ ',:obj: ’ ’ tensorboard ‘ ’和:obj: ’ ’ wandb ‘ ’。使用:obj: ‘ “all“ ’报告安装的所有集成,使用:obj: ‘ ”none” ’报告没有集成。


65.ddp_find_unused_parameters (:obj:bool, optional):
当使用分布式训练时,标志:obj: ‘查找未使用的参数’的值传递给:obj: ‘ distributeddataparparallel’。如果使用梯度检查点默认为:obj: ‘ False ‘,否则默认为:obj: ’ True ’。


66.dataloader_pin_memory (:obj:bool, optional, defaults to :obj:True):
是否要将内存固定在数据加载器中。默认为:obj: ‘ True ’。


67.skip_memory_metrics (:obj:bool, optional, defaults to :obj:True):
是否跳过向指标添加内存分析器报告。这在默认情况下被跳过,因为它减慢了训练和评估的速度。

11.关于hub的一些参数设定

68.push_to_hub (:obj:bool, optional, defaults to :obj:False):
训练后是否将训练好的模型上传到hub。如果激活了此功能,并且存在:obj: ’ output dir ‘,则它需要是:class: ’ ~转换到的存储库的本地克隆。Trainer’将被推。


69.resume_from_checkpoint (:obj:str, optional):
指向具有模型有效检查点的文件夹的路径。此参数不被:class:`~transformer直接使用。


70.hub_model_id (:obj:str, optional):
要与本地“输出目录”保持同步的存储库的名称。它可以是一个简单的模型ID,在这种情况下,模型将被推入您的名称空间。否则,它应该是整个存储库名称,例如:obj: ‘ “user name/model“ ‘,这允许您使用:obj: ’ ”organization name/model” ’向您所属的组织推送。默认为:obj: ‘用户名/输出目录名’,其中‘输出目录名’是:obj: ‘输出目录’的名称。

71.hub_strategy (:obj:str or :class:~transformers.trainer_utils.HubStrategy, optional, defaults to :obj:"every_save"):
定义推送到Hub的内容的范围以及何时推送。

  • :obj:"end": push the model, its configuration, the tokenizer (if passed along to the
    :class:~transformers.Trainer) and a draft of a model card at the end of training.
  • :obj:"every_save": push the model, its configuration, the tokenizer (if passed along to the
    :class:~transformers.Trainer) and a draft of a model card each time there is a model save. The pushes are asynchronous to not block training, and in case the save are very frequent, a new push is only attempted if the previous one is finished. A last push is made with the final model at the end of training.
  • :obj:"checkpoint": like :obj:"every_save" but the latest checkpoint is also pushed in a subfolder named last-checkpoint, allowing you to resume training easily with :obj:trainer.train(resume_from_checkpoint="last-checkpoint").
  • :obj:"all_checkpoints": like :obj:"checkpoint" but all checkpoints are pushed like they appear in the output folder (so you will get one checkpoint folder per folder in your final repository)

72.hub_token (:obj:str, optional):
用于将模型推送到Hub的令牌。将默认使用:obj: ‘ huggingface-cli login ’获取的缓存文件夹中的令牌。


73.gradient_checkpointing (:obj:bool, optional, defaults to :obj:False):
如果为True,则使用梯度检查点以牺牲较慢的向后传递来节省内存。

二、常用的参数总结

train_args = TrainingArguments(
    output_dir="./checkpoints",
    evaluation_strategy="epoch",
    per_device_train_batch_size=32,#一般根据显存可以设置大一点
    per_device_eval_batch_size=32,#一般根据显存可以设置大一点
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=2,#当然这个可以设置200,这里只是举例所以设置的比较小
    logging_steps=10, #没几次打印一次log日志
    save_strategy="epoch",#每个批次保存一下
    save_total_limit=3,#最多保存参数的数量
    load_best_model_at_end=True, #是否在最后打印输出
    metric_for_best_model="f1",
)
trainer =Trainer(
    model=model,
    args=train_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=eval_metric
)

http://www.kler.cn/a/470789.html

相关文章:

  • conda指定路径安装虚拟python环境
  • linux-27 发行版以及跟内核的关系
  • mv指令详解
  • 【Linux】文件的压缩与解压
  • 卸载wps后word图标没有变成白纸恢复
  • 网络协议安全的攻击手法
  • 中电金信携手华为发布“全链路实时营销解决方案”,重塑金融营销数智新生态
  • 设计模式-结构型-适配器模式
  • flutter 专题二十四 Flutter性能优化在携程酒店的实践
  • 计算机毕业设计Python+Vue.js游戏推荐系统 Steam游戏推荐系统 Django Flask 游 戏可视化 游戏数据分析 游戏大数据 爬虫
  • AI巡检系统在安全生产管理中的创新应用
  • 游戏引擎学习第74天
  • Redis 数据库源码分析
  • Opencv实现Sobel算子、Scharr算子、Laplacian算子、Canny检测图像边缘
  • stm32 移植RTL8201F(正点原子例程为例)
  • Easyexcel-4.0.3读取文件内容时遇到“java.lang.ClassNotFoundException”
  • 《从入门到精通:蓝桥杯编程大赛知识点全攻略》(二)-递归实现组合型枚举、带分数问题
  • libaom 源码分析线程结构
  • uni-app 页面生命周期及组件生命周期汇总(Vue2、Vue3)
  • 特征点检测与匹配——MATLAB R2022b
  • 2025资源从哪里来!
  • vue3-dom-diff算法
  • Postman接口测试02|接口用例设计
  • 云原生周刊:K8s 生态系统的五大趋势预测
  • IDEA中Lombok不能使用,找不到get方法
  • 乾元通渠道商中标玉溪市自然灾害应急能力提升项目