digit_eye开发记录(2): Python读取MNIST数据集
在上一篇博客 digit_eye开发记录(1): C++读取MNIST数据集 中解读了 IDX 文件格式,并使用 C++ 语言完成了 MNIST 数据集的解析,第6小节给出的完整代码有146行之多。使用 Python 读取则可以省略70%的代码,只用不到50行代码完成相同功能。
读取 buffer
np.frombuffer(buf, dtype, count, offset)
说明:
- buf: buffer,从文件读出来的
- dtype: 从buf读取时,按什么类型读取数据,或者说,读取的基本单位是什么
- count: 从buf读取时,读取多少个基本单位
- offset: 从buf读取时,指针首先偏移多少个字节
读取 magic number
magic number 是 mnist 文件的前4个字节。 以二进制形式打开后,读取4字节即可:
import numpy as np
with open(filename, 'rb') as fin:
buf = bytearray(fin.read())
magic = np.frombuffer(buf, np.uint8, count=4)
print(magic)
读取维度信息
回忆一下 magic numbers 的构成: 前两个字节是0,第三个字节是类型,第四个字节是维度数量 num_dims。
根据 num_dims 的取值,读取对应数量的字节,得到对应的维度信息。每个维度都是一个 int32 大小。
注意 MSB 到 LSB 的转换,通过 dtype=np.dtype('>u4')
指定, >u4
意思是:以MSB序,读取4个byte.
对于图像数据:
num_dims = magic[3]
dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
对于label数据:
dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
num_labels = dims[0]
读取图像像素
很容易想到使用 OOP 方式,定义 DataSet 类,在成员 self.images 中保存图像;于是乎,很“毛躁”的写出如下糟糕代码:
class DataSet:
def __init__(self):
self.images = []
self.labels = []
def load_images(self, filename):
...
for i in range(num_images):
self.images.append(...)
存在的问题:
- self.images 的类型一定是 list 吗?其实可以是 numpy 数组
- self.images 的每个元素,和其他元素,一定是独立的吗? 可以是同一个内存上连续的分布
- self.images 的每个元素,内存可以和读取文件得到的 buffer 复用吗?可以!
class DataSet:
def __init__(self):
self.images = None
self.labels = None
self.buf = None
def load_images(self, filename):
with open(filename, 'rb') as fin:
self.buf = bytearray(fin.read())
magic = np.frombuffer(self.buf, np.uint8, count=4)
num_dims = magic[3]
dims = np.frombuffer(self.buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
num_images, rows, cols = dims
self.images = np.frombuffer(self.buf, dtype=np.uint8, offset=4+4*num_dims).reshape(num_images, rows, cols)
...
train_set = DataSet()
train_set.load_images('data/train-images.idx3-ubyte')
print("Images and buffer share memory:", np.shares_memory(train_set.images, train_set.buf))
解释:self.buf 的类型,如果直接用 fin.read() 则得到 bytes 类型,是不可变的;转为 bytearray 类型后,是可变的,就可以保持和 self.images( ) 共享。
遗憾的是, self.buf = bytearray(fin.read())
这句本身就发生了内存拷贝。
改进 - 避免内存拷贝
with open(filename, 'rb') as fin:
self.buf = bytearray(fin.read()) # 当前实现,存在两次内存分配
改为
with open(filename, 'rb') as fin:
self.buf = fin.read() # 读取为 bytes
self.buf = memoryview(self.buf) # 直接使用 memoryview
就可以避免 bytes 对象的中间拷贝过程。
完整代码
import numpy as np
import cv2
class DataSet:
def __init__(self):
self.images = None
self.labels = None
self.buf = None
def load_images(self, filename):
with open(filename, 'rb') as fin:
#self.buf = bytearray(fin.read())
self.buf = fin.read()
self.buf = memoryview(self.buf)
magic = np.frombuffer(self.buf, np.uint8, count=4)
num_dims = magic[3]
dims = np.frombuffer(self.buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
num_images, rows, cols = dims
self.images = np.frombuffer(self.buf, dtype=np.uint8, offset=4+4*num_dims).reshape(num_images, rows, cols)
def load_labels(self, filename):
with open(filename, 'rb') as fin:
buf = fin.read()
magic = np.frombuffer(buf, np.uint8, count=4)
num_dims = magic[3]
dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
num_labels = dims[0]
assert num_labels == len(self.images)
self.labels = np.frombuffer(buf, dtype=np.uint8, offset=4+4*num_dims)
def show_image(self, index):
cv2.imshow('image', self.images[index])
print('label:', self.labels[index])
cv2.waitKey(0)
cv2.destroyAllWindows()
def main():
train_set = DataSet()
train_set.load_images('data/train-images.idx3-ubyte')
train_set.load_labels('data/train-labels.idx1-ubyte')
# train_set.show_image(0)
# train_set.show_image(2)
# train_set.show_image(5)
print("Images and buffer share memory:", np.shares_memory(train_set.images, train_set.buf))
if __name__ == '__main__':
main()
总结
在前一篇,我们解析了MNIST数据集的IDX格式并用C++做了文件读取的实现,在本篇则切换到 Python 语言,在降低70%代码量的情况下实现了相同功能,并且避免了不必要的内存拷贝。这份工程之美,建立在对 IDX 格式有所了解的前提之下,对于 Python 的熟悉也是必不可少的,对于C++的经验也促使了复用内存这一条件的达成。