
Vectorized Operations

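Conventions


In this article, "vector" always means a one-dimensional array/tensor, and "matrix" always means a two-dimensional array/tensor.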
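The loop-based functions below tell vectors and matrices apart with `Tensor.dim()`. Here is a minimal illustration of the convention. The variable names are illustrative only.

```python
import torch

vector = torch.ones(3)       # 1-D tensor: a "vector" in this article
matrix = torch.ones(2, 3)    # 2-D tensor: a "matrix" in this article

print(vector.dim(), vector.shape)  # 1 torch.Size([3])
print(matrix.dim(), matrix.shape)  # 2 torch.Size([2, 3])
```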

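Preface


Vectors and matrices are everywhere in ML. Because of habits carried over from C, I often think about vector and matrix operations in terms of scalars, which means implementing them with for loops. In practice, the for-loop style is both slower and more complicated to write than Python operators (which PyTorch overloads for tensors) or PyTorch functions.

Below, I demonstrate the effect of vectorization using addition and multiplication.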

Addition


First, define a loop-based addition function for later use:

import time

import torch

def tensor_add(a, b):
    '''
    Element-wise addition using explicit for loops.
    Only supports vectors (1-D) and matrices (2-D) of the same shape;
    broadcasting is not supported.
    '''
    if a.shape != b.shape:
        raise Exception('Tensor shapes do not match')
    if a.dim() == 1 and b.dim() == 1: # both are vectors
        length = a.size(0)
        c = torch.zeros(length, dtype=a.dtype)
        for i in range(length):
            c[i] = a[i] + b[i]
        return c
    elif a.dim() == 2 and b.dim() == 2: # both are matrices
        rows, cols = a.size()
        c = torch.zeros(rows, cols, dtype=a.dtype)
        for i in range(rows):
            for j in range(cols):
                c[i, j] = a[i, j] + b[i, j]
        return c
    else:
        raise Exception('Unsupported tensor dimensions')
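`tensor_add` handles only two operands with identical shapes. The vectorized `+` operator also follows PyTorch's broadcasting rules, so it can combine tensors of different but compatible shapes. The sketch below shows both cases. The tensor values are chosen arbitrarily.

```python
import torch

a = torch.arange(3.0)          # tensor([0., 1., 2.])
b = torch.full((3,), 10.0)     # tensor([10., 10., 10.])
print(a + b)                   # tensor([10., 11., 12.])

# Broadcasting: a (3, 1) column plus a (1, 4) row gives a (3, 4) matrix
# whose entry [i, j] is col[i] + row[j]; a loop-based function without
# broadcasting support cannot handle this case.
col = torch.arange(3.0).reshape(3, 1)
row = torch.arange(4.0).reshape(1, 4)
grid = col + row
print(grid.shape)              # torch.Size([3, 4])
print(grid[2, 3])              # tensor(5.)
```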

Next, compare the for loop with the built-in + operator, which PyTorch overloads to call its vectorized addition:

n = 1000
a = torch.ones(n)
b = torch.ones(n)

start_time = time.time()
c = tensor_add(a, b)
end_time = time.time()
print("Time taken(for-loop):", end_time - start_time)

start_time = time.time()
c = a + b
end_time = time.time()
print("Time taken(vectorized):", end_time - start_time)

Output:

Time taken(for-loop): 0.005979299545288086
Time taken(vectorized): 0.0001308917999267578
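A single `time.time()` measurement of one call is noisy. This is especially true for the vectorized version, where the measured time is only tens of microseconds. The sketch below uses the standard-library `timeit` module, which uses `time.perf_counter` by default. It repeats each measurement and keeps the best run. `loop_add` is a hypothetical helper that mirrors the vector branch of `tensor_add`. The printed numbers depend on the machine, so no expected output is shown.

```python
import timeit

import torch

def loop_add(a, b):
    # Element-by-element addition in a Python loop
    # (same idea as the vector branch of tensor_add)
    c = torch.zeros_like(a)
    for i in range(a.size(0)):
        c[i] = a[i] + b[i]
    return c

n = 1000
a = torch.ones(n)
b = torch.ones(n)

# Run each statement `number` times per repeat and take the best repeat,
# which filters out one-off delays (warm-up, other processes, etc.).
loop_best = min(timeit.repeat(lambda: loop_add(a, b), number=5, repeat=3)) / 5
vec_best = min(timeit.repeat(lambda: a + b, number=1000, repeat=3)) / 1000

print(f"for-loop:   {loop_best:.6f} s per call")
print(f"vectorized: {vec_best:.8f} s per call")
print(f"speedup:    ~{loop_best / vec_best:.0f}x")
```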

Matrix addition:

A = torch.ones((200, 200))
B = torch.ones((200, 200))

start_time = time.time()
C = tensor_add(A, B)
end_time = time.time()
print("Time taken(for-loop):", end_time - start_time)

start_time = time.time()
C = A + B
end_time = time.time()
print("Time taken(vectorized):", end_time - start_time)

Output:

Time taken(for-loop): 0.214155912399292
Time taken(vectorized): 0.0006866455078125

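As the tensors get larger, the gap between the for loop and the vectorized version widens. In these runs, the vectorized version was about 46 times faster for a 1,000-element vector and about 310 times faster for a 200×200 matrix (40,000 elements).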
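To check how the gap depends on size on your own machine, the sketch below repeats the matrix comparison for a few sizes. The sizes and the helpers `loop_add_2d` and `best_time` are illustrative choices and are not from the original notebook. The loop time grows roughly in proportion to the number of elements. At these small sizes, the vectorized time is typically dominated by fixed per-call overhead.

```python
import time

import torch

def loop_add_2d(A, B):
    # Nested Python loops over every element (like the matrix branch of tensor_add)
    C = torch.zeros_like(A)
    rows, cols = A.shape
    for i in range(rows):
        for j in range(cols):
            C[i, j] = A[i, j] + B[i, j]
    return C

def best_time(fn, repeat=3):
    # Best of several runs, measured with the high-resolution perf_counter
    times = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return min(times)

results = {}
for size in (10, 50, 100):
    A = torch.ones(size, size)
    B = torch.ones(size, size)
    t_loop = best_time(lambda: loop_add_2d(A, B))
    t_vec = best_time(lambda: A + B)
    results[size] = (t_loop, t_vec)
    print(f"{size}x{size}: loop {t_loop:.5f} s, vectorized {t_vec:.6f} s, "
          f"ratio ~{t_loop / t_vec:.0f}x")
```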

Multiplication


For multiplication, we again start by defining a function that multiplies using for loops:

def tensor_multiply(a, b):
    '''
    Only supports up to 2-D tensors.
    For vector-vector input, the outer product is computed.
    For matrix-matrix input, the matrix product is computed.
    Matrix-vector multiplication is not supported.
    '''
    if a.dim() == 1 and b.dim() == 1: # both are vectors
        if a.size(0) != b.size(0):
            raise Exception('Vector dimensions do not match')
        length = a.size(0)
        c = torch.zeros((length, length), dtype=a.dtype)
        for i in range(length):
            for j in range(length):
                c[i, j] = a[i] * b[j]
        return c
    elif a.dim() == 2 and b.dim() == 2: # both are matrices
        if a.size(1) != b.size(0):
            raise Exception('Matrix dimensions do not match')
        rows = a.size(0)
        cols = b.size(1)
        c = torch.zeros(rows, cols, dtype=a.dtype)
        for i in range(rows):
            for j in range(cols):
                for k in range(a.size(1)): # 3 loops for matrix multiplication
                    c[i, j] += a[i, k] * b[k, j]
        return c
    else:
        raise Exception('Unsupported tensor dimensions')
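The vector branch of `tensor_multiply` computes `c[i, j] = a[i] * b[j]`. That is the outer product, which is not the same as the cross product. The sketch below contrasts the two, using arbitrarily chosen unit vectors. It relies on `torch.outer` and `torch.linalg.cross`, the latter of which requires vectors of length 3.

```python
import torch

a = torch.tensor([1.0, 0.0, 0.0])
b = torch.tensor([0.0, 1.0, 0.0])

# Outer product: every a[i] * b[j], giving a 3x3 matrix
outer = torch.outer(a, b)
print(outer.shape)            # torch.Size([3, 3])

# Cross product: defined for 3-element vectors, result is another vector
cross = torch.linalg.cross(a, b)
print(cross)                  # tensor([0., 0., 1.])

# The same double loop as the vector branch of tensor_multiply
manual = torch.zeros(3, 3)
for i in range(3):
    for j in range(3):
        manual[i, j] = a[i] * b[j]
print(torch.equal(manual, outer))  # True
```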

Then compare the for-loop and vectorized versions, using the 1,000-element vectors a and b defined above:

start_time = time.time()
c = tensor_multiply(a, b)
end_time = time.time()
print("Time taken(for-loop):", end_time - start_time)

start_time = time.time()
c = torch.outer(a, b) # outer product is different from cross product https://blog.csdn.net/Dust_Evc/article/details/127502272
end_time = time.time()
print("Time taken(vectorized):", end_time - start_time)

Output:

Time taken(for-loop): 3.7471888065338135
Time taken(vectorized): 0.00025653839111328125
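`torch.outer` is not the only vectorized way to compute the outer product. Broadcasting gives the same result, and it is the general technique for replacing a double loop over `i` and `j`. The sketch below uses small illustrative vectors.

```python
import torch

a = torch.arange(1.0, 4.0)    # tensor([1., 2., 3.])
b = torch.arange(1.0, 5.0)    # tensor([1., 2., 3., 4.])

# Reshape a to a column (3, 1) and b to a row (1, 4);
# broadcasting then multiplies every pair a[i] * b[j].
by_broadcast = a[:, None] * b[None, :]
by_outer = torch.outer(a, b)

print(by_broadcast.shape)                   # torch.Size([3, 4])
print(torch.equal(by_broadcast, by_outer))  # True
print(by_outer[2, 3])                       # tensor(12.)
```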

For matrices (the 200×200 matrices A and B defined above):

start_time = time.time()
C = tensor_multiply(A, B)
end_time = time.time()
print("Time taken(for-loop):", end_time - start_time)

start_time = time.time()
C = A@B
end_time = time.time()
print("Time taken(vectorized):", end_time - start_time)

Output:

Time taken(for-loop): 54.04433226585388
Time taken(vectorized): 0.011148929595947266
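Between the triple loop and `A @ B` there are intermediate steps, and each one moves work from Python into PyTorch. The sketch below shows two of them. The first replaces only the innermost `k` loop with `torch.dot`. The second vectorizes fully through broadcasting and a sum over the shared dimension. The matrix sizes and seed are arbitrary.

```python
import torch

torch.manual_seed(0)
A = torch.rand(4, 3)
B = torch.rand(3, 5)

# Partially vectorized: keep the loops over i and j, but replace the
# innermost loop over k with a single dot product.
C_dot = torch.zeros(4, 5)
for i in range(A.size(0)):
    for j in range(B.size(1)):
        C_dot[i, j] = torch.dot(A[i, :], B[:, j])

# Fully vectorized via broadcasting: (4, 3, 1) * (1, 3, 5) -> (4, 3, 5),
# then sum over the shared dimension k (dim=1) to get (4, 5).
C_bcast = (A[:, :, None] * B[None, :, :]).sum(dim=1)

C_matmul = A @ B
print(torch.allclose(C_dot, C_matmul))    # True
print(torch.allclose(C_bcast, C_matmul))  # True
```

The broadcasting version builds a temporary tensor of shape (m, k, n). For large matrices it therefore uses far more memory than `A @ B`, which dispatches to an optimized matrix-multiplication routine. It is shown here only to make the vectorization step explicit.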

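Summary


Using Python's built-in operators (such as + and @, which PyTorch overloads for tensors) or PyTorch functions greatly speeds up vector and matrix operations. In these measurements, the speedup ranged from roughly 46× for vector addition to several thousand times or more for the outer product and matrix multiplication.

Resources


The Jupyter notebook for this blog post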


http://www.kler.cn/a/308824.html
