当前位置: 首页 > article >正文

llama源码学习·model.py[3]ROPE旋转位置编码(3)源码中的广播机制

一.源码注释

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    '''
       这个函数的目的是为了确保freqs_cis可以根据广播规则与x进行元素级别的运算,特别是在x的维度数量大于2时。
       '''
    # 获取x的维度数量
    ndim = x.ndim
    
    # 确保x至少有两个维度
    assert ndim > 1
    
    # freqs_cis的形状与x的第二和最后一个维度相匹配
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    
    # 遍历x的每个维度,并为第二和最后一个维度保留其原始大小,而为所有其他维度赋值1。
    # 这是为了确保广播时,除了这两个特定维度外,其他所有维度都能自动扩展。
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    
    # 使用view函数来重塑freqs_cis的形状以匹配新的形状
    return freqs_cis.view(*shape)

二、举例说明

freqs_cis = torch.randn(3,4)
print(freqs_cis.shape)

out: torch.Size([3, 4])

x = torch.randn(2, 3, 4)
print(x.shape)

out: torch.Size([2, 3, 4])

# 调用广播函数
reshaped_freqs_cis = reshape_for_broadcast(freqs_cis, x)
print(reshaped_freqs_cis.shape)

out: torch.Size([1, 3, 4])

# 求和
s = reshaped_freqs_cis + x
print(s.shape)

out: torch.Size([2, 3, 4])

原文地址:https://blog.csdn.net/m0_72851153/article/details/146407522
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/594533.html

相关文章:

  • C++ 之多态 ------ C++面向对象三大特性之一【三大特性:封装、继承、多态】
  • WIFI p2p连接总结
  • MySQL 入门大全:查询语言分类
  • HTML 写一个计算器
  • 基于YOLOv8深度学习的智能小麦害虫检测识别系统
  • LabVIEW界面布局优化
  • Elasticsearch text字段检索方法
  • 使用Python SDK在亚马逊云科技上调用AI模型构建生成式AI应用
  • 【SpringMVC】SpringMVC拦截器,统一异常处理,文件上传与下载
  • 【Json—RPC框架】:宏定义不受命名空间限制,续行符的错误使用造成的bug
  • 【JVM】内存区域划分,类加载机制和垃圾回收机制
  • 【RabbitMQ】RabbitMQ中死信交换机是什么?延迟队列呢?有哪些应用场景?
  • Gitlab服务器数据迁移及版本升级
  • Web-Machine-N7靶机:渗透测试与漏洞挖掘的实战利器
  • Java+Html实现前后端客服聊天
  • cmake --build . --config Release和make是1个意思吗
  • Spring Boot + Spring Integration整合MQTT打造双向通信客户端
  • Unity TextMeshPro中显示建筑特殊符号
  • 全局上下文网络GCNet:创新架构提升视觉识别性能
  • 游戏引擎学习第170天