np.expand_dims函数
在 NumPy 中,np.expand_dims
是一个用于增加数组维度的函数。它可以在指定的位置插入一个新的维度,使数组的形状增加一个维度。
语法
np.expand_dims(a, axis)
a
:要操作的数组。axis
:指定插入新轴的位置。可以是负数,表示从数组的末尾开始计算位置。
功能
np.expand_dims
将数组 a
在指定的 axis
位置插入一个新的维度。这在操作多维数组时非常有用,比如在需要进行广播或确保数组维度匹配时。
示例
示例 1:增加一个新的轴
假设有一个一维数组,我们想将其变成二维数组:
import numpy as np
a = np.array([1, 2, 3])
print("原始数组:", a)
print("原始数组的形状:", a.shape)
# 在第 0 轴增加一个新轴
b = np.expand_dims(a, axis=0)
print("增加维度后的数组:", b)
print("新数组的形状:", b.shape)
输出:
原始数组: [1 2 3]
原始数组的形状: (3,)
增加维度后的数组: [[1 2 3]]
新数组的形状: (1, 3)
在这里,np.expand_dims(a, axis=0)
在第 0 轴插入了一个新轴,使数组从一维变成了二维。
示例 2:在第 1 轴增加一个新轴
c = np.expand_dims(a, axis=1)
print("在第 1 轴增加新维度后的数组:", c)
print("新数组的形状:", c.shape)
输出:
在第 1 轴增加新维度后的数组:
[[1]
[2]
[3]]
新数组的形状: (3, 1)
这里 np.expand_dims(a, axis=1)
在第 1 轴插入了一个新轴,使数组从 (3,)
变成了 (3, 1)
。
应用场景
-
图像数据预处理:在机器学习中,图像通常表示为 3D 数组(高度、宽度、通道),而神经网络的输入通常需要 4D(批次大小、高度、宽度、通道),这时可以用
np.expand_dims
增加批次维度。image = np.random.rand(224, 224, 3) # 单张彩色图片 batch_image = np.expand_dims(image, axis=0) # 增加批次维度 print(batch_image.shape) # 输出:(1, 224, 224, 3)
-
广播:在计算中,为了使数组形状符合广播机制,可以使用
np.expand_dims
添加单独的轴,以匹配其他数组的维度。
总结
np.expand_dims
是用于在指定轴插入一个新维度的函数,这对于处理多维数组、确保形状兼容性和准备数据进行深度学习非常有用。