当前位置: 首页 > article >正文

大模型输出的outputs为什么要取[0](即outputs[0])

目录

      • 引言
      • 1、outputs 内容
      • 2、outputs[0]含义
      • 3、outputs[0] 的具体内容是什么
      • 4、outputs 中其他可能的内容
      • 总结:

引言

在模型推理过程中,我们常常需要编写如下代码:

inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"].cuda()
        outputs = self.model.generate(inputs,do_sample=False,max_length=4096)
        output = self.tokenizer.decode(outputs[0])

可以看出,outputs 是由 self.model.generate() 方法生成的结果,那么outputs中包含了什么呢?为什么在解码时用的是outputs[0]而不是直接用outputs

1、outputs 内容

outputs 包含的是模型生成的token序列,通常是多维张量,由模型调用函数self.model.generate()输出

self.model.generate() 的输出 (outputs)

  • self.model.generate(inputs, do_sample=False, max_length=4096) 调用 Hugging Face 的 AutoModelForCausalLM 模型的 generate 方法。
  • 该方法根据给定的输入 inputs(已被 tokenizer 编码为 token 的张量)生成一段输出。输出的形状通常是一个张量,它的维度是 [batch_size, sequence_length],表示每个输入(batch_size)所对应的生成的 token 序列(sequence_length)。

例如,如果 batch_size=1sequence_length=50,那么 outputs 可能是一个形状为 [1, 50] 的张量,其中每个数字代表生成文本中的一个 token(这些 token 对应模型词汇表中的词汇)。

2、outputs[0]含义

outputs[0] 是模型为当前输入生成的第一个(也是唯一的一个,因为这里没有 batch)token 序列。

  • 在这个代码中,generate() 方法生成的结果是一个二维张量,其中 outputs[0] 代表了第一个输入样本生成的完整序列(一个张量,其中包含一系列 token 的索引)。
  • 举个例子,如果 outputs 的形状是 [1, 50],那么 outputs[0] 就是一个形状为 [50] 的张量,表示生成的 50 个 token 的索引。

3、outputs[0] 的具体内容是什么

outputs[0] 是一个token 索引的列表,每个数字代表一个 token 的 ID,它们与模型的词汇表中的词汇一一对应。

例如,假设 outputs[0][101, 1234, 567, 90, 102],每个数字是模型词汇表中的一个 token ID。这个序列在经过 self.tokenizer.decode() 后会变成一个可读的文本字符串。

因此,outputs[0]一方面是取输出内容的第一条,另一方面是将数据维度降维了。 当输入prompt有多条内容的时候,输出也会对应有多条。

4、outputs 中其他可能的内容

如果 generate()的参数设置不同(比如 num_return_sequences > 1 或者do_sample=True),outputs可能包含多个生成的序列。每个序列代表模型生成的一个可能的输出版本。在这种情况下,outputs 的形状会是 [num_return_sequences, sequence_length],表示多个生成的序列。

总结:

  • outputs 是一个包含模型生成的 token ID 序列的张量,维度是 [batch_size, sequence_length]
  • outputs[0] 是该张量中的第一个生成序列,包含一个 token ID 的列表。
  • 通过 tokenizer.decode(outputs[0]),我们可以将 token ID 转换回自然语言文本。

http://www.kler.cn/news/359828.html

相关文章:

  • STM32G4系列MCU的启动项配置
  • CIM系统:智慧城市的数字基石
  • 鸿蒙--进度条通知
  • Springboot+vue图书商城购物系统【附源码】
  • 列表(list)、元组(tuple)、字典(dictionary)、array(数组)-numpy、DataFrame-pandas 、集合(set)
  • Python知识点:基于Python技术和工具,如何使用Chainlink进行链下数据访问
  • 01 Druid未授权错误及解决方案
  • 命令行工具cURL 的用法
  • Python知识点:基于Python工具,如何使用Web3.py进行以太坊智能合约开发
  • Docker compose 安装Jenkins
  • Spring Boot助力中小型医院网站开发
  • ACM与蓝桥杯竞赛指南 基本输入输出格式三
  • 《Python游戏编程入门》注-第2章3
  • shell脚本每日一练1
  • 【独家:AI编程助手Cursor如何revolutionize Java设计模式学习】
  • 离线电脑 Visual Studio Community 2017:您的许可证已过期
  • PCL 基于FPFH特征描述子获取点云对应关系
  • Maven 项目管理工具
  • 大数据新视界 --大数据大厂之大数据与边缘计算的协同:实时分析的新前沿
  • 低光照图像增强:全局与局部上下文建模