当前位置: 首页 > article >正文

创建Dataloader基础篇【一】

概述

在transformers trainer训练、评估模型中,大致根据以下过程加载与处理训练、评估数据集:

  1. 使用dataset.Dataset加载数据
  2. 使用Dataset.map与自定义的convert_examples_to_features函数处理Dataset中的每一行数据
  3. 定义sampler,在迭代Dataloader过程中,本质是迭代sampler。默认auto-batch模式下,sampler在每次迭代过程中,会返回一个batch的索引数值(indices),然后根据indices从Dataloader.dataset中取数据(fetch)。e.g. [self.dataset[index] for index in batch_indices]
  4. 将第三步取到的数据,喂到collator_fn中,组装成tensor类型,并返回组装后的结果。

因为trainer默认是以aotu_batch方式加载与处理数据,因此本部分仅记录aotu_batch方式。另外本文仅记录trainer中创建dataloader的基础过程。对于一些个性化加载与处理、如长文档文本分类,如有必要,会另起一篇文章再进行记录。

实例

# set up
from typing import List, Dict, Union

from datasets import Dataset
from transformers import default_data_collator
from transformers import BertTokenizer
from torch.utils.data import DataLoader, RandomSampler, BatchSampler, SequentialSampler

from config import CKP  # huggingface 中预训练模型下载到本地的地址

# emotion classification demo
x = [{"texts": "我爱中国。", "labels": 1}, {"texts": "今天天气真糟糕!", "labels": 0}] * 3

# 可以使用datasets.load_dataset函数,将样本数据存储为json格式,每一条样本占据一行
examples: Dataset = Dataset.from_list(x)
tokenizer: BertTokenizer = BertTokenizer.from_pretrained(CKP)

def convert_examples_to_features(exams: Dict[str, List[Union[str, int]]]):
    return tokenizer(exams["texts"], padding=True, max_length=20, truncation=True)

# map函数中的batch=True并不影响最终结果,只是影响convert_examples_to_features的签名|定义
dataset = examples.map(convert_examples_to_features, with_indices=False, with_rank=False, batched=True,
                       batch_size=1, remove_columns=["texts"])

# 验证sampler
sequence_sampler = SequentialSampler(dataset)
print(f"sequence sampler: {list(sequence_sampler)}")

random_sampler = RandomSampler(dataset)
print(f"random sampler: {list(random_sampler)}")

batch_sampler = BatchSampler(random_sampler, batch_size=2, drop_last=False)
print(f"batch sampler: {list(batch_sampler)}")

# 在convert_examples_to_features已经对input_ids进行了pad,所以使用default_data_collator
# 如果仅进行编码,即padding=False, 此处使用transformers.DataCollatorWithPadding
dataloader = DataLoader(dataset, batch_size=1, collate_fn=default_data_collator)

# add breakpoint in here, you will see
# step1. get next batch indices
# step2. fetch data according batch indices
# step3. collator data by collator_fn and return batch
for batch in dataloader:
    print(batch)

参考资料

datasets.Dataset.map方法学习笔记
transformers中的data_collator
【pytorch】Dataloader学习笔记


http://www.kler.cn/news/148849.html

相关文章:

  • 拆解按摩器:有意思的按键与LED控制电路,学习借鉴一下!
  • <Linux>(极简关键、省时省力)《Linux操作系统原理分析之Linux 进程管理 9》(13)
  • IELTS学习笔记_grammar_新东方
  • 基于MBC调制方法的准Z源三相逆变器Simulink建模与仿真
  • 目标检测YOLO系列从入门到精通技术详解100篇-【目标检测】特征点检测与匹配
  • MySQL慢查询
  • Flink Flink中的合流
  • Python---lambda表达式
  • 交换机的VRRP主备配置例子
  • 计网Lesson3 - 计算机网络评价指标与封包解包
  • 别再让假的fiddler教程毒害你了,来看这套最全最新的fiddler全工具讲解
  • 基于C#实现Kruskal算法
  • DGL在异构图上的GraphConv模块
  • 【Redisson】基于自定义注解的Redisson分布式锁实现
  • 堆的应用(堆排序、Top-K问题)
  • 大模型的开源闭源
  • linux -系统通用命令查询
  • viple模拟器使用(四):unity模拟器中实现沿右墙迷宫算法
  • 门面模式-C++实现
  • java中IO知识点概念
  • GoLong的学习之路,进阶,RabbitMQ (消息队列)
  • Jmeter-分布式压测(远程启动服务器,windows)
  • 代码随想录-刷题第九天
  • 通义千问 Qwen-7B-Chat-Int4 模型本地化部署
  • 机器人规划算法——movebase导航框架源码分析
  • Linux的软件安装
  • linaro交叉编译工具链下载与使用笔记
  • Nacos 端口偏移量说明
  • java文件上传以及使用阿里云OSS
  • 【ArcGIS Pro微课1000例】0038:基于ArcGIS Pro的人口密度分析与制图