当前位置: 首页 > article >正文

pytorch基础 nn.embedding

nn.Embedding 是 PyTorch 中的一个模块,用于创建嵌入层(embedding layer),它将离散的索引(例如词汇表中的单词索引)映射为固定大小的稠密向量。这是许多 NLP 模型(包括 Transformer)中的基本组件。


示例用法:

import torch
import torch.nn as nn

# 定义一个嵌入层
vocab_size = 10000  # 词汇表大小
embedding_dim = 512  # 嵌入向量的维度
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

# 示例输入:一批词索引序列
# 形状:(batch_size, sequence_length)
input_indices = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 2 个序列,每个序列长度为 3

# 获取嵌入向量
output_embeddings = embedding_layer(input_indices)
print(output_embeddings.shape)  # 输出形状:(batch_size, sequence_length, embedding_dim)

输出:

输出的张量形状为 (batch_size, sequence_length, embedding_dim)。对于上面的例子:

  • batch_size = 2

  • sequence_length = 3

  • embedding_dim = 512

因此,输出形状为 (2, 3, 512)


主要参数:

  1. num_embeddings:词汇表的大小(唯一词的数量), 通常比较大。

  2. embedding_dim:嵌入向量的维度。

  3. padding_idx:(可选)用于填充词的索引。如果提供,该索引对应的嵌入向量将固定为零向量,并且在训练过程中不会更新。

  4. max_norm:(可选)如果指定,嵌入向量将被重新归一化,使其范数不超过该值。

  5. norm_type:(可选)如果指定了 max_norm,则使用该范数类型(默认是 2,表示 L2 范数)。

  6. scale_grad_by_freq:(可选)如果为 True,梯度将根据批次中词的频率进行缩放。


在 Transformer 中的使用:

在 Transformer 类中,nn.Embedding 通常用于创建:

  1. 源嵌入:用于输入(源)序列。

  2. 目标嵌入:用于输出(目标)序列。

  3. 位置编码:为嵌入添加位置信息(可选,因为位置编码也可以单独添加)。

以下是在 Transformer 类中的使用示例:

class Transformer(nn.Module):
    def __init__(self, n_src_vocab, n_trg_vocab, src_pad_idx, trg_pad_idx, d_word_vec=512):
        super().__init__()
        
        # 源和目标嵌入层
        self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=src_pad_idx)
        self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=trg_pad_idx)
        
        # Transformer 的其他组件...

这是对 nn.Embedding 及其在 Transformer 模型中的作用的简要介绍。如果你需要进一步的解释或示例,请告诉我!


http://www.kler.cn/a/563238.html

相关文章:

  • H3C商场无线零售解决方案技术资料汇总集
  • 使用Python开发PDF文本提取工具
  • 从零开始用react + tailwindcss + express + mongodb实现一个聊天程序(四) 实现注册功能
  • 基于Matlab的语音识别系统设计
  • nginx基础篇 - 什么是代理服务器?正向代理和反向代理的区别?
  • 【语法】C++的string
  • Linux 权限系统和软件安装(二):深入理解 Linux 权限系统
  • Redis 高可用性:如何让你的缓存一直在线,稳定运行?
  • HTTP非流式请求 vs HTTP流式请求
  • Linux系统之DHCP网络协议
  • 深入探讨K8s资源管理和性能优化
  • Python 网络编程全攻略:核心知识与实战应用、高级应用场景、问题剖析、行业未来趋势等全解析
  • SpringBoot接入DeepSeek(硅基流动版)+ 前端页面调试
  • 【论文笔记-ECCV 2024】AnyControl:使用文本到图像生成的多功能控件创建您的艺术作品
  • 二十三种设计模式详解
  • 一周掌握Flutter开发--4、导航与路由
  • 清华大学DeepSeek赋能职场教程下载,清华大学DeepSeek文档下载(完成版下载)
  • 银河麒麟高级服务器操作系统通用rsync禁止匿名访问操作指南
  • RIP-AV:使用上下文感知网络进行视网膜动脉/静脉分割的联合代表性实例预训练
  • 【深度学习神经网络学习笔记(三)】向量化编程