llama源码学习·model.py[3]ROPE旋转位置编码(3)源码中的广播机制
一.源码注释
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
'''
这个函数的目的是为了确保freqs_cis可以根据广播规则与x进行元素级别的运算,特别是在x的维度数量大于2时。
'''
# 获取x的维度数量
ndim = x.ndim
# 确保x至少有两个维度
assert ndim > 1
# freqs_cis的形状与x的第二和最后一个维度相匹配
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
# 遍历x的每个维度,并为第二和最后一个维度保留其原始大小,而为所有其他维度赋值1。
# 这是为了确保广播时,除了这两个特定维度外,其他所有维度都能自动扩展。
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
# 使用view函数来重塑freqs_cis的形状以匹配新的形状
return freqs_cis.view(*shape)
二、举例说明
freqs_cis = torch.randn(3,4)
print(freqs_cis.shape)
out: torch.Size([3, 4])
x = torch.randn(2, 3, 4)
print(x.shape)
out: torch.Size([2, 3, 4])
# 调用广播函数
reshaped_freqs_cis = reshape_for_broadcast(freqs_cis, x)
print(reshaped_freqs_cis.shape)
out: torch.Size([1, 3, 4])
# 求和
s = reshaped_freqs_cis + x
print(s.shape)
out: torch.Size([2, 3, 4])
原文地址:https://blog.csdn.net/m0_72851153/article/details/146407522
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/594533.html 如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/594533.html 如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!