当前位置: 首页 > article >正文

关于torch.nn.Embedding的浅显理解

最近在使用词嵌入向量表示我的数据标签,并且在试图理解torch.nn.Embedding函数。

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)

这里只解释我对前两个参数的理解,这也是我唯二理解的:num_embeddings(int) – size of the dictionary of embeddings,其实就是你给Embedding函数的张量里互不相同的数的个数;embedding_dim (int) – the size of each embedding vector也即生成的词嵌入向量的最后一个维度。For example:

import torch.nn as nn
import torch

known_label_lt = nn.Embedding(3, 10)

label = torch.tensor([
    [1, 0, 1, 0, 1],
    [2, 1, 0, 2, 1],
    [1, 1, 2, 1, 0],
    [1, 1, 0, 1, 2]
]).long() # without .long(), will result in an error. 

state = known_label_lt(label)
print(state.shape)

这里输入的向量label里只能包含三个不同的数:0,1,2 。或者反过来说known_label_lt的第一个参数只能是3,known_label_lt的第二个参数就决定了label的每一个数会被扩展到10维。所以最后生成的词嵌入维度是:

torch.Size([4, 5, 10])

http://www.kler.cn/a/162857.html

相关文章:

  • Redis主从复制(replication)
  • 运行springBlade项目历程
  • Nebula NGQL语言的使用 一
  • Mysql前言
  • MySQL中的事务与锁
  • TCP可靠连接的建立和释放,TCP报文段的格式,UDP简单介绍
  • 初识Linux:权限(1)
  • 手持式安卓主板_PDA安卓板_智能手持终端方案
  • 【C/PTA】结构体专项练习
  • 直面双碳目标,优维科技携手奥意建筑打造绿色低碳建筑数智云平台
  • C++异常剖析
  • C语言精选——选择题Day40
  • 基于AWS Serverless的Glue服务进行ETL(提取、转换和加载)数据分析(二)——数据清洗、转换
  • 创建自定义Docker镜像:一步步指南
  • 一.初始typescript
  • 人大金仓(kingbase)数据库常用sql命令
  • 深度学习之注意力机制
  • Fiddler抓包模拟器(雷电模拟器)
  • 【力扣】160.相交链表
  • 船舶机电设备智能故障诊断系统
  • python3.5安装教程及环境配置,python3.7.2安装与配置
  • 3DCAT+上汽奥迪:打造新零售汽车配置器实时云渲染解决方案
  • Linux 权限管理
  • markdown记录
  • 字符串指令集
  • 渗透测试工具AWVS的全面解析