使用广播机制将for循环转为矩阵运算
文章目录
- 1 问题描述
- 2 求解
- 2.1 问题分析
- 2.2 Numpy广播机制定义
- 2.3 矩阵运算
- 3 总结
1 问题描述
有两个点云 N 1 ∈ ( n , 3 ) N_1 \in(n, 3) N1∈(n,3)和 N 2 ∈ ( m , 3 ) N2\in(m,3) N2∈(m,3)。现在要统计 N 2 N_2 N2到 N 1 N_1 N1距离小于2的点的数量。
2 求解
2.1 问题分析
这个问题属于一类典型的for
循环遍历问题,求解的方法如下:
import numpy as np
n1 = np.array([[1,0,0],
[2,0,0],
[4,0,0]])
n2 = np.array([[1,0,0],
[2,0,0],
[8,0,0],
[0,8,0],
[0,0,9]])
m, _ = n2.shape
n, _ = n1.shape
counter = 0
for i in range(0, m):
for j in range(0,n):
distances = np.sqrt((n1[j,0]-n2[i,0])**2 + (n1[j,1]-n2[i,1])**2 + (n1[j,2]-n2[i,2])**2)
if distances < 2:
counter = counter + 1
break
print(counter)
这种求解方法的时间复杂度是
O
(
N
2
)
O(N^2)
O(N2)。但是在实际情况尤其是深度学情况下,需要尽量避免使用for
循环,以此提升程序速度。可以考虑使用numpy矩阵运算来代替for
循环。
2.2 Numpy广播机制定义
Numpy广播(broadcasting)是NumPy中用于处理形状不匹配的数组进行逐元素运算的一种机制。当进行二进制运算(如加法、减法、乘法等)或一元运算(如取负、取指数等)时,NumPy会自动调整数组的形状,使其能够逐元素地进行运算,而无需显式地进行形状的扩充或重复。
一般情况下,想利用Numpy将for
运算简化成矩阵运算需要反向思维,将两个待求解的矩阵/数组变换成形状不匹配的状态,引发广播机制实现逐元素运算。
2.3 矩阵运算
使用numpy的广播机制可以将上述问题转化为矩阵运算:
import numpy as np
n1 = np.array([[1,0,0],
[2,0,0],
[4,0,0]])
n2 = np.array([[1,0,0],
[2,0,0],
[8,0,0],
[0,8,0],
[0,0,9]])
#1 对n1扩充维度 (n,3)->(n,1,3)
n1_expand = n1[:, np.newaxis]
#2 使用广播机制计算两个点云之间的距离矩阵
distances = np.sqrt(np.sum((n1_expand - n2) ** 2, axis=2))
#3 设置判断条件(即: 距离小于2)
condition = np.any(distances < 2, axis=0)
#4 统计满足条件个数
counter = np.sum(condition)
print(counter)
程序解析:
-
n1_expand = n1[:, np.newaxis]
是指将N1的形状从(n, 3)扩展为(n, 1, 3)。其中np.newaxis
是在NumPy中用于增加数组的维度的特殊索引。它用于在指定的位置插入一个新的维度。
此时n1_expand的shape为(n, 1, 3),n2的shape为(m,3),为了实现对应元素的减法操作,Numpy广播机制会将n1_epxand和n2都扩充至(n, m, 3)。- 对于n1_expand来说,它在dim=1维度复制m次,在对应做减法时等同于
for i in range(0, m):
- 对于n2来说,它等于先扩充了dim=0维度,再从dim=0上复制了n次,等同于
for j in range(0, n):
因为最终的shape是(n, m, 3),减法的操作对象只作用到dim=2维,所以此次广播相当于对dim=0和dim=1的逐元素叠加遍历,即双重for循环。
矩阵代替循环的设计技巧:先确定运算符作用维度,然后每代替一层循环就需要给矩阵增加一个维度,即:广播后的矩阵维度=替代的for循环层数+运算作用维度。
在这个示例中,距离运算只局限在(n, 3)中的dim=1这一个维度上,想要替换的是两个点云矩阵间相对元素计算的两层for循环,所以广播后的矩阵的shape=2+1,即:3。
- 对于n1_expand来说,它在dim=1维度复制m次,在对应做减法时等同于
-
distances = np.sqrt(np.sum((n1_expand - n2) ** 2, axis=2))
这里n1_expand - n2为对应元素减法获得各坐标轴上的差值,整体表达式的shape仍为(n, m, 3);**2
指定对所有差值做平方;axis=2
指定对dim=2,即最后一维的3个元素做np.sum
相加操作,这正好实现了三个差值平方的相加。因为加法操作发生在dim=2
上,所以这个对应维度在完成加法操作后消去了,此时的shape变为(n,m)。最后外围的np.sqrt实现开根号。
至此计算出点云 N 1 N_1 N1中任意一点到点云 N 2 N_2 N2任意一点的距离,将它们以矩阵的形式保存,形成distance
,shape为(n, m)。例如,distance[1,2]
表示点云 N 1 N_1 N1中的第1个点到点云 N 2 N_2 N2中的第2个点的距离,反之也成立。 -
condition = np.any(distances < 2, axis=0)
这句话主要实现判断 N 2 N_2 N2中是否有满足距离小于2的点。distance[x,y]
表示x到y的距离,其中 y ∈ N 2 y\in N_2 y∈N2。因此要对distance逐列筛选,即:只要distance[:,j]
中的n个元素里有任一个及以上小于2就认为当前 N 2 N_2 N2点云中的第j个元素符合条件。由此规定axis=0
按列索引,np.any
只要有一个符合条件就将它置为True
,最后condtion为一个shape是(1, m)的bool
矩阵。 -
最后统计condtion中True的数量,记为最终所求。
3 总结
一般情况下,当for
循环内的操作为仅为简单的四则运算,且操作对象为2个时,就可以考虑采用矩阵运算替代for
循环。由于Numpy矩阵运算是并行的,所以将for
循环转为矩阵运算会大量节约计算时间。此外,该类问题还有一种更为通俗的解法:采用哈希表。