
Building Your Own GPT from Scratch, Part 3: Model Training 2 (training function walkthrough, model training function walkthrough, line-by-line code explanation)

🚩🚩🚩 Hugging Face hands-on series: master table of contents

Feel free to leave any questions in the comments below.
All of the code in this article was run in PyCharm.
The companion code resources for this article have already been uploaded.

Building Your Own GPT from Scratch, Part 1: Text Data Preprocessing
Building Your Own GPT from Scratch, Part 2: Language Model Training

3 Data Loading Function

def load_dataset(logger, args):
    """
    加载训练集
    """
    logger.info("loading training dataset")
    train_path = args.train_path

    with open(train_path, "rb") as f:
        train_list = pickle.load(f)

    # For a quick smoke test, uncomment to train on only the first 24 samples
    # train_list = train_list[:24]

    train_dataset = CPMDataset(train_list, args.max_len)

    return train_dataset
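The CPMDataset used above is defined in the data-preprocessing part of this series and is not repeated here. As a rough sketch (the exact class may differ; the truncation behavior shown is an assumption), it simply wraps the pickled list of token-id sequences and truncates each sample to max_len:

import torch
from torch.utils.data import Dataset

class CPMDataset(Dataset):
    """Sketch of the dataset wrapper: each element of input_list is
    assumed to be a list of token ids produced by the preprocessing step."""

    def __init__(self, input_list, max_len):
        self.input_list = input_list
        self.max_len = max_len

    def __len__(self):
        return len(self.input_list)

    def __getitem__(self, index):
        # truncate to max_len and return a LongTensor for the model
        input_ids = self.input_list[index][:self.max_len]
        return torch.tensor(input_ids, dtype=torch.long)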

4 Training Function

def train(model, logger, train_dataset, args):
    train_dataloader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn,
        drop_last=True
    )
    logger.info("total_steps:{}".format(len(train_dataloader)* args.epochs))
    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs
    optimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps)
    scheduler = transformers.get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )# 设置warmup

    logger.info('start training')

    train_losses = []   # average loss of each epoch
    # ========== start training ========== #
    for epoch in range(args.epochs):
        train_loss = train_epoch(
            model=model, train_dataloader=train_dataloader,
            optimizer=optimizer, scheduler=scheduler,
            logger=logger, epoch=epoch, args=args)
        train_losses.append(round(train_loss, 4))
        logger.info("train loss list:{}".format(train_losses))

    logger.info('training finished')
    logger.info("train_losses:{}".format(train_losses))

5 Per-Epoch Training Function

def train_epoch(model, train_dataloader, optimizer, scheduler, logger,
                epoch, args):
    model.train()
    device = args.device
    ignore_index = args.ignore_index
    epoch_start_time = datetime.now()

    total_loss = 0  # running sum of the loss over the whole epoch
    epoch_correct_num = 0   # number of correctly predicted tokens in this epoch
    epoch_total_num = 0  # total number of predicted tokens in this epoch

    for batch_idx, (input_ids, labels) in enumerate(train_dataloader):
        # catch CUDA out-of-memory exceptions
        try:
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            outputs = model.forward(input_ids, labels=labels)
            logits = outputs.logits
            loss = outputs.loss
            loss = loss.mean()

            # count correct and total predicted tokens in this batch
            batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)
            # accumulate into the epoch-level counters
            epoch_correct_num += batch_correct_num
            epoch_total_num += batch_total_num
            # accuracy for this batch
            batch_acc = batch_correct_num / batch_total_num

            total_loss += loss.item()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()
            # gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            # after accumulating gradients for the configured number of steps, update the parameters
            if (batch_idx + 1) % args.gradient_accumulation_steps == 0:
                # update parameters
                optimizer.step()
                # update learning rate
                scheduler.step()
                # clear gradients
                optimizer.zero_grad()

            if (batch_idx + 1) % args.log_step == 0:
                logger.info(
                    "batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(
                        batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))

            del input_ids, outputs

        except RuntimeError as exception:
            if "out of memory" in str(exception):
                logger.info("WARNING: ran out of memory")
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                logger.info(str(exception))
                raise exception

    # record the average loss and accuracy of this epoch
    epoch_mean_loss = total_loss / len(train_dataloader)
    epoch_mean_acc = epoch_correct_num / epoch_total_num
    logger.info(
        "epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))

    # save model
    logger.info('saving model for epoch {}'.format(epoch + 1))
    model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))
    if not os.path.exists(model_path):
        os.mkdir(model_path)
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(model_path)
    logger.info('epoch {} finished'.format(epoch + 1))
    epoch_finish_time = datetime.now()
    logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))

    return epoch_mean_loss
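calculate_acc is also defined outside this snippet. A sketch of the usual causal-LM accuracy computation it performs, assuming the logits at position i predict the token at position i + 1, so logits and labels are shifted by one before comparison and positions equal to ignore_index are masked out:

def calculate_acc(logits, labels, ignore_index=-100):
    # shift: logits at position i predict the token at position i + 1
    logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    labels = labels[..., 1:].contiguous().view(-1)

    preds = logits.argmax(dim=-1)
    non_pad_mask = labels.ne(ignore_index)  # ignore padded positions
    n_correct = preds.eq(labels).masked_select(non_pad_mask).sum().item()
    n_total = non_pad_mask.sum().item()
    return n_correct, n_total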


