当前位置: 首页 > article >正文

datasets.Dataset.map方法学习笔记

Dataset.map 方法概要

可以将datasets中的Dataset实例看做是一张数据表。map方法会将输入的function按照指定的方式应用在每一行(每一行称为一个example)上。本文采用一下示例进行说明:

from datasets import Dataset  #  datasets.__version__ = '2.13.0'
x = [{"text": "good", "label": 1}, {"text": "bad", "label": 0}, {"text": "great", "label": 2}]
ds = Dataset.from_list(x)

function是map方法的核心,其介绍单独放在下列章节。其它常用参数的说明如下:

  • remove_columns:当function可调用对象处理完数据表中的全部样例之后,可以将指定的列删掉。在文本处理中,一般会将编码之后的text原文删除掉。
  • fn_kwargs: 传入可调用对象function的关键词参数。
  • desc: 自定义的描述性信息。
     

function位置参数

function位置参数接受一个可调用对象,本质是该可调用对象对数据表中的每行进行处理。按照布尔型位置参数with_indices, with_rank, batched的取值, function有8种签名。其中batched表示可调用对象一次处理一行还是多行,with_indices表示是否将样本的索引编号传入可调用对象, with_rank表示是否将进程rank传入可调用对象。

单样本处理(batched=False)

  • 当设置batched=False时,可调用对象依次处理每一行。在默认情况下,可调用对象的签名为function(example: Dict[str, Any]) -> Dict[str, Any],example中的键为数据表中的所有列。可使用input_columns指定哪些列进入example中。可调用对象应该返回一个Dict对象,并使用返回结果更新数据表。

    from typing import Dict, Any
    
    def func(exam: Dict[str, Any]) -> Dict[str, Any]:
    	return {"text_length": len(exam["text"])}
    
    print(ds.map(func).data)
    
  • 如果指定with_indices=True,则可调用对象中应该指定一个位置参数用于接受样本索引编号。

    def func_with_indices(exam, index):
    	return {"text_length": len(exam["text"]), "row_index": int(index)}
    
    print(ds.map(func_with_indices, with_indices=True, batched=False).data)
    
  • 如果指定with_rank=True,则可调用对象中应指定一个位置参数用于接受进程rank。并且应当设置num_proc 大于1,否则进程rank的值是相同的。

    def func_with_rank(exam, rank):
    	return {"text_length": len(exam["text"]), "proc_rank": int(rank)}
    
    print(ds.map(func_with_rank, with_rank=True, batched=False, num_proc=2).data)
    
  • 如果同时指定with_indices=Truewith_rank=True,则可调用对象中接受样本索引编号的位置参数应置于接受进程rank位置参数之前。如def function(example, idx, rank): ...

    def func_with_index_rank(exam, index, rank):
    	return {"text_length": len(exam["text"]), "row_index": int(index), "proc_rank": int(rank)}
    	
    	print(ds.map(func_with_rank, with_indices=True, with_rank=True, batched=False, num_proc=2).data)
    

样本批处理(Batched=True)

当设置batched=True时,可调用对象会对样本进行批处理,批的大小可以通过batch_size控制,默认一个批为1000条样本。此情况下签名应满足function(batch: Dict[str, List]) -> Dict[str, List]。batch中的键仍然是数据表中的列名称,值为多行数据组成的列表。

def batched_func(batch):
	return {"text_length": [len(text) for text in batch["text]]}

print(ds.map(batched_func, batched=True, batch_size=2).data)

 

map方法返回的torch.tensor会被转换为list

x = [{'text': 'good', 'label': torch.tensor(1, dtype=torch.long)},
	 {'text': 'bad', 'label': torch.tensor(0, dtype=torch.long)},
	 {'text': 'great', 'label': torch.tensor(2, dtype=torch.long)}]
ds = Dataset.from_list(x)
print(type(x[0]["label"]))  # torch.Tensor

print(ds.map().data)

def to_tensor(exam):
	return {"label" : torch.tensor(exam["label"], dtype=torch.long)}

print(ds.map(to_tensor).data)  # 结果一致,数据表中的label均为整数型

http://www.kler.cn/a/148528.html

相关文章:

  • 万字长文分析函数式编程
  • 基于yolov8、yolov5的番茄成熟度检测识别系统(含UI界面、训练好的模型、Python代码、数据集)
  • Spring Cloud Contract快速入门Demo
  • Rust 整数
  • SQL练习(2)
  • 3DTiles之i3dm介绍
  • vscode在Windows上安装插件提示错误xhr failed
  • 编程语言发展史:Ruby语言的发展和应用
  • Docker 镜像使用
  • sqlserver写入中文乱码问题
  • Java中的mysql——面试题+答案(数据库设计)——第25期
  • 机器学习的复习笔记2-回归
  • 如何获取高质量的静态住宅IP代理?常见误区与注意事项
  • C语言--每日选择题--Day28
  • 开关电源低温启动测试条件是什么?如何测试开关电源?
  • Micropython STM32F4外部中断实验
  • 【闲读 1】量子论引出对认知的思考
  • docker compose搭建渗透测试vulstudy靶场示例
  • 11.28 知识回顾(Web框架、路由控制、视图层)
  • java基础-IO
  • Jquery动画特效
  • vue项目门店官网页面, 根据视口大小自动跳转页面逻辑(pc --> mobile / mobile -->pc)
  • 【算法】七大经典排序(插入,选择,冒泡,希尔,堆,快速,归并)(含可视化算法动图,清晰易懂,零基础入门)
  • MongoDB安装教程
  • 51单片机制作数字频率计
  • 跨标签页通信的8种方式(下)