Deep Learning: Save Best and Early Stop
1. Save Best
Training a large model today can be interrupted at any point, but training is resumable: if a GPU dies, you can continue from where you left off. Training is just repeatedly updating the weights w, so if w has been checkpointed along the way, you can resume from a previously saved, well-performing set of weights.
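This save-then-resume flow can be sketched as follows. A minimal sketch assuming PyTorch; the temporary path and the `nn.Linear` model are illustrative only:

```python
import os
import tempfile

import torch
from torch import nn

ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")  # illustrative path

model = nn.Linear(4, 2)
torch.save(model.state_dict(), ckpt_path)   # checkpoint the current weights w

resumed = nn.Linear(4, 2)                   # fresh model with random weights
resumed.load_state_dict(torch.load(ckpt_path))  # restore w; training can continue from here
```

After `load_state_dict`, the resumed model's parameters are identical to the checkpointed ones, so further optimizer steps pick up from the saved state.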
The code below saves the weights every 500 steps by default; the `save_best_only` parameter controls whether only the best weights are kept.
import os
import torch

class SaveCheckpointsCallback:
    def __init__(self, save_dir, save_step=500, save_best_only=True):
        """
        Save a checkpoint every `save_step` training steps.
        Training scripts in PyTorch usually evaluate the model and save
        checkpoints by step, which is what this implementation does.
        Args:
            save_dir (str): directory to save checkpoints to
            save_step (int, optional): how often (in steps) to save a checkpoint. Defaults to 500.
            save_best_only (bool, optional): if True, keep only the best model;
                otherwise save a separate checkpoint every save_step steps.
        """
        self.save_dir = save_dir              # where checkpoints are written
        self.save_step = save_step            # save frequency, in steps
        self.save_best_only = save_best_only  # whether to keep only the best model
        self.best_metrics = -1                # best metric so far; metrics are non-negative, so initialize to -1
        # create the save directory if it does not exist
        if not os.path.exists(self.save_dir):
            os.mkdir(self.save_dir)

    def __call__(self, step, state_dict, metric=None):
        if step % self.save_step > 0:  # only act once every save_step steps
            return
        if self.save_best_only:
            assert metric is not None  # a metric must be provided in save_best_only mode
            if metric >= self.best_metrics:
                # overwrite the previous best checkpoint; only the model's state_dict
                # (the parameters) is saved -- no step counter, no optimizer state
                torch.save(state_dict, os.path.join(self.save_dir, "best.ckpt"))
                # update the best metric
                self.best_metrics = metric
        else:
            # save one checkpoint per step, without overwriting earlier ones;
            # again only the state_dict is saved, not the optimizer state
            torch.save(state_dict, os.path.join(self.save_dir, f"{step}.ckpt"))
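The save-best branch above can be exercised end to end. A minimal sketch, assuming a hypothetical sequence of validation accuracies and using a plain dict as a stand-in for a real `state_dict`:

```python
import os
import tempfile

import torch

save_dir = tempfile.mkdtemp()  # stand-in for a real checkpoint directory
best_metric = -1               # same initialization as in the callback above

# hypothetical validation accuracy measured every 500 steps
for step, acc in [(500, 0.71), (1000, 0.78), (1500, 0.75)]:
    if acc >= best_metric:  # same comparison the callback uses
        # overwrite the previous best; a plain dict stands in for model.state_dict()
        torch.save({"step": step}, os.path.join(save_dir, "best.ckpt"))
        best_metric = acc

best = torch.load(os.path.join(save_dir, "best.ckpt"))
```

Because the accuracy at step 1500 is worse than at step 1000, `best.ckpt` still holds the step-1000 checkpoint at the end of the loop.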
2. Early Stop
If, as training goes on, validation accuracy starts to drop or validation loss starts to rise, early stopping is needed:
class EarlyStopCallback:
    def __init__(self, patience=5, min_delta=0.01):
        """
        Args:
            patience (int, optional): number of evaluations with no improvement
                after which training will be stopped. Defaults to 5.
            min_delta (float, optional): minimum change in the monitored quantity
                to qualify as an improvement, i.e. an absolute change of less than
                min_delta counts as no improvement. Defaults to 0.01.
        """
        self.patience = patience    # how many steps without improvement before stopping
        self.min_delta = min_delta  # minimum improvement that counts
        self.best_metric = -1
        self.counter = 0            # counts consecutive steps without improvement

    def __call__(self, metric):
        if metric >= self.best_metric + self.min_delta:  # monitoring accuracy here
            # update best metric
            self.best_metric = metric
            # reset counter
            self.counter = 0
        else:
            self.counter += 1  # increment; used by the patience check below

    @property  # @property lets callers write obj.early_stop instead of obj.early_stop()
    def early_stop(self):
        return self.counter >= self.patience
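Driving the callback with a hypothetical accuracy curve shows the patience mechanism in action. A minimal sketch; the class is repeated here only so the snippet runs on its own, and the accuracy values are made up:

```python
class EarlyStopCallback:
    """Same logic as the class above, repeated so this sketch is self-contained."""
    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        self.best_metric = -1
        self.counter = 0

    def __call__(self, metric):
        if metric >= self.best_metric + self.min_delta:
            self.best_metric = metric
            self.counter = 0
        else:
            self.counter += 1

    @property
    def early_stop(self):
        return self.counter >= self.patience

early = EarlyStopCallback(patience=3, min_delta=0.01)

# hypothetical validation accuracy per epoch: improves, then plateaus below best + min_delta
accs = [0.60, 0.70, 0.705, 0.703, 0.708]
stopped_at = None
for epoch, acc in enumerate(accs):
    early(acc)                 # update counter after each evaluation
    if early.early_stop:       # three evaluations in a row without improvement
        stopped_at = epoch
        break
```

With `patience=3`, the counter reaches 3 on the three sub-threshold epochs after 0.70, so the loop breaks at the last epoch instead of running to exhaustion.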
3. TensorBoard
To visualize training with TensorBoard, first install the package:

pip install tensorboard

During training, start the TensorBoard server with the command below. Note that --logdir should be an absolute path; otherwise TensorBoard may fail to find the logs and report an error.
```shell
tensorboard --logdir="D:\PycharmProjects\pythondl\chapter_2_torch\runs" --host 0.0.0.0 --port 8848
```
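For anything to appear under that `runs` directory, the training loop has to log metrics first. A minimal sketch, assuming PyTorch's `torch.utils.tensorboard.SummaryWriter`; the log directory and loss values are illustrative:

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

log_dir = os.path.join(tempfile.mkdtemp(), "runs", "demo")  # illustrative log directory
writer = SummaryWriter(log_dir=log_dir)

# log a made-up training loss curve; the third argument is the global step
for step, loss in enumerate([1.2, 0.9, 0.7, 0.6]):
    writer.add_scalar("train/loss", loss, global_step=step)
writer.close()
```

Pointing `tensorboard --logdir` at the parent `runs` directory then plots the `train/loss` curve in the browser.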