当前位置: 首页 > article >正文

digit_eye开发记录(2): Python读取MNIST数据集

在上一篇博客 digit_eye开发记录(1): C++读取MNIST数据集 中解读了 IDX 文件格式,并使用 C++ 语言完成了 MNIST 数据集的解析,第6小节给出的完整代码有146行之多。使用 Python 读取则可以省略70%的代码,只用不到50行代码完成相同功能。

读取 buffer

np.frombuffer(buf, dtype, count, offset)

说明:

  • buf: buffer,从文件读出来的
  • dtype: 从buf读取时,按什么类型读取数据,或者说,读取的基本单位是什么
  • count: 从buf读取时,读取多少个基本单位
  • offset: 从buf读取时,指针首先偏移多少个字节

读取 magic number

magic number 是 mnist 文件的前4个字节。 以二进制形式打开后,读取4字节即可:

import numpy as np

with open(filename, 'rb') as fin:
    buf = bytearray(fin.read())
magic = np.frombuffer(buf, np.uint8, count=4)
print(magic)

读取维度信息

回忆一下 magic numbers 的构成: 前两个字节是0,第三个字节是类型,第四个字节是维度数量 num_dims。
根据 num_dims 的取值,读取对应数量的字节,得到对应的维度信息。每个维度都是一个 int32 大小。

注意 MSB 到 LSB 的转换,通过 dtype=np.dtype('>u4') 指定, >u4 意思是:以MSB序,读取4个byte.

对于图像数据:

num_dims = magic[3]
dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)

对于label数据:

dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
num_labels = dims[0]

读取图像像素

很容易想到使用 OOP 方式,定义 DataSet 类,在成员 self.images 中保存图像;于是乎,很“毛躁”的写出如下糟糕代码:

class DataSet:
    def __init__(self):
        self.images = []
        self.labels = []
    def load_images(self, filename):
        ...
        for i in range(num_images):
            self.images.append(...)

存在的问题:

  • self.images 的类型一定是 list 吗?其实可以是 numpy 数组
  • self.images 的每个元素,和其他元素,一定是独立的吗? 可以是同一个内存上连续的分布
  • self.images 的每个元素,内存可以和读取文件得到的 buffer 复用吗?可以!
class DataSet:
    def __init__(self):
        self.images = None
        self.labels = None
        self.buf = None

    def load_images(self, filename):
        with open(filename, 'rb') as fin:
            self.buf = bytearray(fin.read())

        magic = np.frombuffer(self.buf, np.uint8, count=4)
        num_dims = magic[3]
        dims = np.frombuffer(self.buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
        num_images, rows, cols = dims
        self.images = np.frombuffer(self.buf, dtype=np.uint8, offset=4+4*num_dims).reshape(num_images, rows, cols)
   ...

train_set = DataSet()
train_set.load_images('data/train-images.idx3-ubyte')
print("Images and buffer share memory:", np.shares_memory(train_set.images, train_set.buf))

解释:self.buf 的类型,如果直接用 fin.read() 则得到 bytes 类型,是不可变的;转为 bytearray 类型后,是可变的,就可以保持和 self.images( ) 共享。

遗憾的是, self.buf = bytearray(fin.read()) 这句本身就发生了内存拷贝。

改进 - 避免内存拷贝

with open(filename, 'rb') as fin:  
    self.buf = bytearray(fin.read())  # 当前实现,存在两次内存分配  

改为

with open(filename, 'rb') as fin:  
    self.buf = fin.read()  # 读取为 bytes  
    self.buf = memoryview(self.buf)  # 直接使用 memoryview  

就可以避免 bytes 对象的中间拷贝过程。

完整代码

import numpy as np
import cv2

class DataSet:
    def __init__(self):
        self.images = None
        self.labels = None
        self.buf = None

    def load_images(self, filename):
        with open(filename, 'rb') as fin:
            #self.buf = bytearray(fin.read())
            self.buf = fin.read()
            self.buf = memoryview(self.buf)

        magic = np.frombuffer(self.buf, np.uint8, count=4)
        num_dims = magic[3]
        dims = np.frombuffer(self.buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
        num_images, rows, cols = dims
        self.images = np.frombuffer(self.buf, dtype=np.uint8, offset=4+4*num_dims).reshape(num_images, rows, cols)

    def load_labels(self, filename):
        with open(filename, 'rb') as fin:
            buf = fin.read()

        magic = np.frombuffer(buf, np.uint8, count=4)
        num_dims = magic[3]
        dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
        num_labels = dims[0]
        assert num_labels == len(self.images)
        self.labels = np.frombuffer(buf, dtype=np.uint8, offset=4+4*num_dims)

    def show_image(self, index):
        cv2.imshow('image', self.images[index])
        print('label:', self.labels[index])
        cv2.waitKey(0)
        cv2.destroyAllWindows()

def main():
    train_set = DataSet()
    train_set.load_images('data/train-images.idx3-ubyte')
    train_set.load_labels('data/train-labels.idx1-ubyte')
    # train_set.show_image(0)
    # train_set.show_image(2)
    # train_set.show_image(5)
    print("Images and buffer share memory:", np.shares_memory(train_set.images, train_set.buf))

if __name__ == '__main__':
    main()

总结

在前一篇,我们解析了MNIST数据集的IDX格式并用C++做了文件读取的实现,在本篇则切换到 Python 语言,在降低70%代码量的情况下实现了相同功能,并且避免了不必要的内存拷贝。这份工程之美,建立在对 IDX 格式有所了解的前提之下,对于 Python 的熟悉也是必不可少的,对于C++的经验也促使了复用内存这一条件的达成。


http://www.kler.cn/a/412356.html

相关文章:

  • elasticsearch报错fully-formed single-node cluster with cluster UUID
  • kafka生产者和消费者命令的使用
  • 计算机基础(下)
  • 基于SpringBoot的工程教育认证的计算机课程管理系统【附源码】
  • Python基础学习-11函数参数
  • 基于混合ABC和A*算法复现
  • 渗透测试笔记—window基础
  • 蓝桥杯每日真题 - 第24天
  • 27加餐篇:gRPC框架的优势与不足之处
  • Apache Zeppelin:一个基于Web的大数据可视化分析平台
  • 前端 设置 div 标签内子多个子 div 内容,在一行展示,并且可以字段自动换行
  • Flink 实现超速监控:从 Kafka 读取卡口数据写入 MySQL
  • 浏览器开发工具
  • java——SpringBoot中常用注解及其底层原理
  • SSM之AOP与事务
  • 缓存雪崩、击穿、穿透深度解析与实战应对
  • 使用OpenCV实现视频背景减除与目标检测
  • 【QT】背景,安装和介绍
  • 【云计算网络安全】解析 Amazon 安全服务:构建纵深防御设计最佳实践
  • docker-compose文件的简介及使用
  • Git 使用技巧
  • 鸿蒙开发异步与线程
  • 使用Cmake导入OpenCV库的大坑记录
  • 如何将 GitHub 私有仓库(private)转换为公共仓库(public)
  • 反爬虫机制
  • 【大数据学习 | Spark-SQL】SparkSession对象