Python调用最小二乘法
文章目录
- numpy实现
- scipy封装
- 速度对比
所谓线性最小二乘法,可以理解为是解方程的延续,区别在于,当未知量远小于方程数的时候,将得到一个无解的问题。最小二乘法的实质,是保证误差最小的情况下对未知数进行赋值。
最小二乘法是非常经典的算法,而且这个名字我们在高中的时候就已经接触了,属于极其常用的算法。此前曾经写过线性最小二乘法的原理,并用Python实现:最小二乘法及其Python实现;以及scipy
中非线性最小二乘法的调用方式:非线性最小二乘法;还有稀疏矩阵的最小二乘法:稀疏矩阵最小二乘法。
下面讲对numpy
和scipy
中实现的线性最小二乘法进行说明,并比较二者的速度。
numpy实现
numpy
中便实现了最小二乘法,即lstsq(a,b)
用于求解类似于a@x=b
中的x
,其中,a
为
M
×
N
M\times N
M×N的矩阵;则当b
为
M
M
M行的向量时,刚好相当于求解线性方程组。对于
A
x
=
b
Ax=b
Ax=b这样的方程组,如果
A
A
A是满秩仿真,那么可以表示为
x
=
A
−
1
b
x=A^{-1}b
x=A−1b,否则可以表示为
x
=
(
A
T
A
)
−
1
A
T
b
x=(A^{T}A)^{-1}A^{T}b
x=(ATA)−1ATb。
当b
为
M
×
K
M\times K
M×K的矩阵时,则对每一列,都会计算一组x
。
其返回值共有4个,分别是拟合得到的x
、拟合误差、矩阵a
的秩、以及矩阵a
的单值形式。
import numpy as np
np.random.seed(42)
M = np.random.rand(4,4)
x = np.arange(4)
y = M@x
xhat = np.linalg.lstsq(M,y)
print(xhat[0])
#[0. 1. 2. 3.]
scipy封装
scipy.linalg
同样提供了最小二乘法函数,函数名同样是lstsq
,其参数列表为
lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False, check_finite=True, lapack_driver=None)
其中a, b
即
A
x
=
b
Ax=b
Ax=b,二者均提供可覆写开关,设为True
可以节省运行时间,此外,函数也支持有限性检查,这是linalg
中许多函数都具备的选项。其返回值与numpy
中的最小二乘函数相同。
cond
为浮点型参数,表示奇异值阈值,当奇异值小于cond
时将舍弃。
lapack_driver
为字符串选项,表示选用何种LAPACK
中的算法引擎,可选'gelsd'
, 'gelsy'
, 'gelss'
。
import scipy.linalg as sl
xhat1 = sl.lstsq(M, y)
print(xhat1[0])
# [0. 1. 2. 3.]
速度对比
最后,对着两组最小二乘函数做一个速度上的对比
from timeit import timeit
N = 100
A = np.random.rand(N,N)
b = np.arange(N)
timeit(lambda:np.linalg.lstsq(A, b), number=10)
# 0.015487500000745058
timeit(lambda:sl.lstsq(A, b), number=10)
# 0.011151800004881807
这一次,二者并没有拉开太大的差距,即使将矩阵维度放大到500,二者也是半斤八两。
N = 500
A = np.random.rand(N,N)
b = np.arange(N)
timeit(lambda:np.linalg.lstsq(A, b), number=10)
0.389679799991427
timeit(lambda:sl.lstsq(A, b), number=10)
0.35642060000100173