深度学习:广播机制
广播机制(Broadcasting)是 PyTorch(以及其他深度学习框架如 NumPy)中的一种强大功能,它允许不同形状的张量进行逐元素操作,而不需要显式地扩展张量的维度。广播机制通过自动扩展较小的张量来匹配较大张量的形状,从而使得逐元素操作能够顺利进行。
广播机制的基本规则
- 维度对齐:从后往前比较两个张量的维度。如果两个张量在某个维度上的大小相等,或者其中一个维度的大小为1,则这两个维度是兼容的。
- 扩展维度:如果两个张量在某个维度上的大小不相等且都不为1,则无法广播。否则,大小为1的维度会被扩展以匹配另一个张量在该维度上的大小。
具体例子
假设我们有两个张量 A 和 B,它们的形状分别为 (3, 1, 4) 和 (1, 5, 4)。我们希望对这两个张量进行逐元素加法操作。
-
维度对齐
从后往前比较两个张量的维度:
最后一个维度:A 和 B 的最后一个维度都是 4,所以它们是兼容的。
倒数第二个维度:A 的维度是 1,B 的维度是 5。由于 A 的维度为 1,可以广播到 5。
倒数第三个维度:A 的维度是 3,B 的维度是 1。由于 B 的维度为 1,可以广播到 3。 -
扩展维度
根据上述规则,A 和 B 的维度会被扩展为:
A 的形状从 (3, 1, 4) 扩展为 (3, 5, 4)。
B 的形状从 (1, 5, 4) 扩展为 (3, 5, 4)。
扩展后的张量形状相同,因此可以进行逐元素加法操作。 -
逐元素加法:
扩展后的 A 和 B 形状相同,可以进行逐元素加法操作,结果 C = A + B 的形状为 (3, 5, 4)。
总结
广播机制通过自动扩展较小的张量来匹配较大张量的形状,从而使得逐元素操作能够顺利进行。这种机制避免了显式地扩展张量的维度,提高了代码的简洁性和效率。