Brainpy的jit编译环境基础
文章目录
- JIT编译加速
- 对于函数的加速
- 对于类对象的加速
- 自动化的测试,关闭jit
- 数据操作
- 数组Array
- 数组属性
- 类型转换
- 动态变量
- 控制流
- 条件语句
- 非变量控制语句
- 基于变量的控制语句
- brainpy.math.where
- brainpy.math.ifelse
- 循环语句
- Pythonic 循环语法
- brainpy.math.for_loop()
- `brainpy.math.for_loop()` 参数和说明
- 参数说明
- 使用示例
- brainpy.math.while_loop()
- 参数说明
这个是书本的第一个章节的第2部分,对应开源文档的 brainpy的jit的编译部分
JIT编译环境下的编程基础,本文不详细记录文档,对于课本提到的进行记录。
说明:该文档版本是JIT1.1部分。目前使用的版本是2.6。文档最高更新到了2.3.8。
最新的版本编译说明是面向对象的一个例子2.3.8版本的描述
查看版本
import brainpy as bp
import brainpy.math as bm
bp.__version__
即时编译
JIT是一种运行计算机代码的方式,使程序在运行时而不是运行前完成编译。
JIT编译继承了解释器的灵活性和编译器的高效性。
使动态编译的一种形式,允许自适应优化,如动态重新编译和针对微架构的加速。
边执行边编译和流水线一样,对于重复的和即时响应要求高的函数非常适用。
BrainPy 的核心理念是即时编译(JIT)。 JIT 编译可将 Python 代码 "即时 "编译成机器代码执行。 随后,这些转换后的代码可以以本地机器代码的速度运行!Python 中提供了优秀的 JIT 编译器,如 JAX 和 Numba。 不过,这些编译器只能用于纯 Python 函数。 在计算神经科学中,大多数模型都有太多的参数和变量,仅使用函数很难管理和控制模型逻辑。 相反,Python 中基于类的面向对象编程(OOP)将使您的编码更可读、更可控、更灵活、更模块化。 因此,在脑建模编程中,有必要支持类对象的 JIT 编译。 在 BrainPy 中,我们在 JAX 和 Numba 的基础上为类对象提供了 JIT 编译接口。
import brainpy as bp
import brainpy.math as bm
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
bm.set_platform('cpu')
设置计算的平台,提供了brainpy.math.set_platform()
可以通过该函数确定船舰的数据存储在那种设备上,如果设置为cpu,则新数据存储在内存中,并使用cpu进行计算.如果是gpu和tpu平台也可以。
JIT编译加速
对于函数的加速
JIT编译可以理解为加速机制。把目标函数或类用brainpy.math.jit()包装,以指示BrainPy将Python代码转化为机器码。从函数入手,假设实现了一个高斯误差线性单元(GELU)
def gelu(x):
sqrt = bm.sqrt(2/bm.pi)
cdf = 0.5*(1.0 + bm.tanh(sqrt*(x + 0.044715 * bm.power(x, 3))))
y = x * cdf
return y
G E L U ( x ) = x ⋅ Φ ( x ) \mathrm{GELU}(x)=x\cdot\Phi(x) GELU(x)=x⋅Φ(x)
加速时间的测试
x = bm.random.random (100000)
%timeit gelu(x)
220 μs ± 3.17 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
37.3 μs ± 581 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
JIT对于函数的编译加速比较明显,但是对脑动力学编程而言。只进行JIT编程是不够的。动力学系统中,动态变量和微分方程使计算过程非常复杂。Brainpy支持比函数更上一层的类对象的JIT编译。
对于类对象的加速
(1) 该类对象要满足必须是brainpy.BrainPyObject的子类。brainpy.BrainPyObject的Object是BrainPy的基类,其中所有的方法都可以被JIT编译,因此所有继承基类的子类都可以是被JIT编译。
(2)动态变量必须被定义为brainpy.math.Variable
下面以逻辑回归(Logistic Regression)分类器为例进行介绍,由于权重w需要在训练过程中修改,所以定义为brainpy.math.Variable。其余参数在编译过程中被视为静态变量,值不会被修改。
class LogisticRegression(bp.Base):
def __init__(self, dimension):
super(LogisticRegression, self).__init__()
# parameters
self.dimension = dimension
# variables
#动态变量的定义
self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)
def __call__(self, X, Y):
u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
self.w.value = self.w - u
对于这个逻辑回归的类进行执行时间的测试。
import time
def benckmark(model, points, labels, num_iter=30, name=''):
t0 = time.time()
for i in range(num_iter):
model(points, labels)
print(f'{name} used time {time.time() - t0} s')
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
自动化的测试,关闭jit
jit可以带来非常明显的加速。
在实际的脑动力学编程中,一个大规模动力学系统往往包含大量神经元和突出模型。如果显示地将每个对象都包装到brainpy.math.jit()中,会显得有些繁琐。为了简化编程逻辑,BrainPy实现了自动JIT编译。
BrainPy提供了一个brainpy.Runner类。该类也是模拟,训练积分等运行器的基类。在初始化时,运行器会收到名为jit的参数,默认设置为True。这表明Runner会自动编译目标工程(只要目标工程被传入Runner)。为便于理解,举一个动力学仿真的例子。调用BrainPy生成一个HH模型。将模型作为参数传入运行器中。模型自动编译执行。
model = bp.neurons.HH(1000)
runner= bp.DSRunner(target=model, inputs=('input', 10.))
runner(duration=1000,eval_time=True)
上面的例子中,没有涉及brainpy.math.jit()操作,BrainPy将所有显示JIT操作都封装在BrainPyObject中。
JIT编译有效缩减了运行时间。但是无法使用原方法进行调试。JIT编译后,底层代码会进行结构优化。为了调试只能关闭JIT编译。
#关闭JIt调试
model=bp.neurons.HH(1000)
runner= bp.DSRunner(target=model, inputs=('input', 10.),jit=False)
runner(duration=1000,eval_time=True)
关闭之后,执行时间显著增加。
Predict 10000 steps: : 24%|██▍ | 2406/10000 [01:08<03:22, 37.59it/s]
本节小结:加速效果的展示,函数,类,自动使用的方法。
数据操作
除了Python包含的数据类型,BrainPy还包含两个特殊的数据类型:数组(Array)和动态变量(Variable)。
数组类似于numpy的多维数组(ndarray)。动态变量是
BrainPy框架中应用于JIT编译的一种新型数据结构,为了可以在JIT编译环境对数组的值进行原地更新。使用动态变量来代替数组。支持自动求梯度的功能。
数组Array
数组是在brainpy.math中的,包含所有支持数组的操作。使用brainpy.math.array()创建一个一维数组。并与NumPy创建数组进行对比。
bm_array = bm.array([0.,1.,2.,3.,4.,5.,6.,7.,8.,9.])
np_array = np.array([0.,1.,2.,3.,4.,5.,6.,7.,8.,9.])
bm_array
#Array(value=Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), dtype=float32)
数组属性
可以访问的数组属性。
在 brainpy.math.array()
创建的数组中,四个常用的属性包括:
-
nidm:数组的轴数
-
shape:数组在每个维度上的长度,定义数组的结构。例如,二维数组的形状可能是
(3, 3)
。 -
size:数组中的总元素数,即所有维度的元素数相乘。例如,形状为
(3, 3)
的数组,其size
为 9。 -
dtype:数组的元素数据类型,如
float32
、int32
等,决定数组的存储类型和计算精度。 -
device:表示存储数组数据的设备,通常是
cpu
或gpu
。设置device
可以利用 GPU 加速计算。 -
requires_grad:指定数组是否需要梯度计算。在自动微分(automatic differentiation)中使用此属性很重要,特别是在深度学习和其他需要梯度信息的领域中。
这些属性在数据操作和计算中具有重要作用。
import brainpy.math as bm
# 创建数组
array = bm.array([[1, 2, 3], [4, 5, 6]], dtype=bm.float32)
# 查看数组的四个属性
print('array.dim:{}'.format(array.ndim))
print('array.size:{}'.format(array.size))
print('array.shape:{}'.format(array.shape))
print('array.dtype:{}'.format(array.dtype))
print('array.device:{}'.format(array.device))
array.dim:2 array.size:6 array.shape:(2, 3) array.dtype:float32
array.device:<bound method Array.device of Array(value=Array([[1., 2.,
3.],
[4., 5., 6.]]),
dtype=float32)>
brainpy.math创建的数组被存储在JaxArray中,其内部存储了JAX中的数据类型DeviceArra
#得到
t1 = bm.arange(10)
t1
#Array(value=Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), dtype=int32)
t1.value
类型转换
和书中的内容不太一致了。
Numpy: ndarray
JAX:DeviceArray
BrainPy:JaxArray
三者怎么进行类型转换。
1.JaxArray和ndarray
JaxArray->ndarray
bm_array = bm.array([0.,1.,2.,3.,4.,5.,6.,7.,8.,9.])
bm_array.to_numpy()
bm.as_numpy(bm_array)
ndarray->JaxArray
np_array=np.array([0.,1.,2.,3.,4.,5.,6.,7.,8.,9.])
bm.asarray(np_array)
2.DeviceArray和JaxArray
to_jax
JaxArray ->DeviceArray
bm_array = bm.array([0.,1.,2.,3.,4.,5.,6.,7.,8.,9.])
bm_array.to_jax()
bm.as_jax(bm_array)
DeviceArray->JaxArray
jnp_array = jnp.array([0.,1.,2.,3.,4.,5.,6.,7.,8.,9.])
bm.asarray(jnp_array)
调用 bm_array.to_jax()
后表面上可能看不到区别,但实际上 to_jax()
方法将 bm_array
转换为了 JAX 的 DeviceArray
对象。这种转换的优势主要体现在:
-
加速计算:JAX 支持自动微分和 GPU 加速,通过将数据转换成 JAX 的
DeviceArray
类型,你可以直接利用 JAX 的高效计算能力。 -
JAX 函数兼容性:将数据转换为
DeviceArray
后,可以直接调用 JAX 中的优化函数、数值计算和自动微分函数,从而享受 JAX 的高效特性。 -
分布式计算支持:如果使用 TPU 或者多 GPU,
DeviceArray
在 JAX 中可以支持分布式计算,这在大规模深度学习任务中非常有用。
尽管表面上数据和 bm.array
没有明显差异,但 DeviceArray
能让你更充分地利用 JAX 的特性,尤其在深度学习和科学计算中效果明显。
Brainpy提供了大部分算子,并补充了JAX没有提供的算子。涉及就地更新和随机生成。
动态变量
在JIT编译环境下,数组作为数据类型编程时会产生问题。一旦数组被交给JIT编译器。数组中的值就不能被修改了。使用Variable。
动态变量是一个指向数组的指针,指向内存空间中存储的数组的值(DeviceArray)
在JIT编译过程中,动态变量重点数据可以被修改。
随时间动态变化的数组。
包装一下即可
t= bm.arange(4)
v= bm.Variable(t)
v
#t= bm.arange(4)
v= bm.Variable(t)
v
从JaxArray变成Variable。可以动态修改。仍然可以通过.value属性获取动态变量中的值(DeviceArray)。
没有标记为动态变量的数组将作为静态数组进行JIT编译,对静态数组的修改在JIT编译环境中是无效的。
v=v+1
#Array(value=Array([2, 3, 4, 5]), dtype=int32)
返回了一个JaxArray。
其中的数据结果是正确的。如果直接修改一个动态变量,该动态变量所指向的内存空间并没有被修改。而是开辟了一个新的内存空间来存储新的结果,并以数组的形式返回。
为了真正的做到更新动态变量,用户需要使用就地更新操作。可以修改动态变量内部的值。
1.索引和切片操作
v=bm.Variable(bm.arange(4))
v[0]=10
v
#Variable(value=Array([10, 1, 2, 3]), dtype=int32)
#使用切片也可以修改动态变量的数据。
v[1:3] = 1
v
.value赋值
变量就地更新中最常用的操作之一。在更新动态变量时经常需要将数组赋值给某动态变量。在动力学系统迭代更新的过程中重置动态变量的值。可以直接访问JaxArray中的数据。
v.value = bm.arange(4)
v
覆盖时要保证数组的形状,元素类型,动态变量完全一致。
try:
v.value = bm.arange(5)
except Exception as e:
print(e)
#The shape of the original data is (4,), while we got (5,) with batch_axis=None.
.update方法
该方法的功能有.value类似
v.update(bm.arange(5))
v
BrainPy除了可以实现模型的模拟,分析,还可以对模型进行训练。在机器学习领域,训练与测试中的数据会有一个新维度:批处理大小(Batch Size).即每次传给网络的样本数量。批处理大小往往是动态变化,为了适应动态变化的数组形状。在初始化动态变量时,要申明批处理维度。
#表示数组中批处理的维度,这里是参数中的1
dyn_var = bm.Variable(bm.zeros((1,100)),batch_axis=0)
dyn_var.shape
#批处理大小变成10
dyn_var = bm.ones((10, 100))
进一步学习查看文档Variables
控制流
此处记录一下控制语句的区别。
控制流1.1
控制流2.3.8
循环语句和条件语句在JIT编译环境下会受到一定限制,因为在编译过程中,只记录动态变量的形状和类型。在遇到依赖动态变量的条件语句时,会因为系统没有追踪动态变量的真实值而无法继续编译。
import brainpy as bp
import brainpy.math.jax as bm
bp.math.use_backend('jax')
在 JAX 中,控制流语法并不好用。 用户必须将直观的 Python 控制流转化为结构化的控制流。
条件语句
在 Python 中,选择语句也被称为决策控制语句或分支语句。 选择语句允许程序测试多个条件,并根据哪个条件为真执行指令。 常用的控制语句包括:
- if-else
- nested if
- if-elif-else
课本参照的应该是2.2的文档。所以此处使用这个版本的进行笔记记录。
非变量控制语句
实际上,当条件语句依赖于非变量实例时,BrainPy(基于 JAX)可以像您熟悉的 Python 程序一样正常编写控制流。
class OddEven(bp.BrainPyObject):
def __init__(self, type_=1):
super(OddEven, self).__init__()
self.type_ = type_
self.a = bm.Variable(bm.zeros(1))
def __call__(self):
if self.type_ == 1:
self.a += 1
elif self.type_ == 2:
self.a -= 1
else:
raise ValueError(f'Unknown type: {self.type_}')
return self.a
在上例中,目标语句if(statement)语法依赖于一个标量,不是brainpy.math.Variable的实例。在这种情况下,条件语句可以任意复杂。 您可以用普通的 Python 代码编写模型。 这些模型在 JIT 编译时会运行得很好。
编译测试:
model = bm.jit(OddEven(type_=2))
model()
基于变量的控制语句
第二种条件语句依赖动态变量。
但是,如果 if … else … 语法中的语句目标依赖于 brainpy.math.Variable 实例,那么在使用 JIT 编译时,编写 Pythonic 控制流将导致错误。
#判断变量是一个变量类型的bm.Variable
class OddEvenCauseError(bp.BrainPyObject):
def __init__(self):
super(OddEvenCauseError, self).__init__()
self.rand = bm.Variable(bm.random.random(1))
self.a = bm.Variable(bm.zeros(1))
def __call__(self):
if self.rand < 0.5: self.a += 1
else: self.a -= 1
return self.a
wrong_model = bm.jit(OddEvenCauseError())
try:
wrong_model()
except Exception as e:
print(f"{e.__class__.__name__}: {str(e)}")
直接进行会出现错误。
报错和之前出现的不是很一致
ConcretizationTypeError: This problem may be caused by several ways:
- Your if-else conditional statement relies on instances of brainpy.math.Variable.
- Your if-else conditional statement relies on functional arguments which do not set in “static_argnames” when applying JIT compilation.
More details please see
https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError- The static variables which set in the “static_argnames” are provided as arguments, not keyword arguments, like “jit_f(v1, v2)” [<-
wrong]. Please write it as “jit_f(static_k1=v1, static_k2=v2)” [<-
right].
当条件语句依赖动态变量,会出现编译错误。提供了两个代替if-else语句的条件语句。
brainpy.math.where: return element-wise conditional comparison results.
brainpy.math.ifelse: Conditional statements of if-else, or if-elif-else, … for a scalar-typed value.
brainpy.math.where
where(condition, x, y)函数根据条件返回从 x 或 y 中选择的元素。 它可以很好地处理标量、向量和高维数组。
a = 1.
bm.where(a < 0, 0., 1.)
#Array(1., dtype=float32, weak_type=True)
a = bm.random.random(5)
bm.where(a < 0.5, 0., 1.)
#Array(value=Array([0., 1., 1., 0., 1.]), dtype=float32)
a = bm.random.random((3, 3))
bm.where(a < 0.5, 0., 1.)
Array([[0., 0., 1.],
[0., 0., 1.],
[0., 0., 0.]], dtype=float32, weak_type=True)
语句为真,选择x,为假,选择第二个值。
上述判断奇偶的代码可以使用上面的函数进行修改。
class OddEvenWhere(bp.BrainPyObject):
def __init__(self):
super(OddEvenWhere, self).__init__()
self.rand = bm.Variable(bm.random.random(1))
self.a = bm.Variable(bm.zeros(1))
def __call__(self):
self.a += bm.where(self.rand < 0.5, 1., -1.)
return self.a
#test语句
model = bm.jit(OddEvenWhere())
model()
brainpy.math.ifelse
基于 JAX 的控制流语法 jax.lax.cond,BrainPy 提供了更通用的条件语句,可实现多重分支。
老版本1.1的还是调用的jax的语句make_cond()。此处应该是对代码进行了一些封装。
在最简单的情况下,
brainpy.math.ifelse(condition, branches, operands, dyn_vars=None)
相当于
条件,分支,操作符,动态变量。使用
def ifelse(condition, branches, operands, dyn_vars=None):
true_fun, false_fun = branches
if condition:
return true_fun(operands)
else:
return false_fun(operands)
一个修改的实例
class OddEvenCond(bp.BrainPyObject):
def __init__(self):
super(OddEvenCond, self).__init__()
self.rand = bm.Variable(bm.random.random(1))
self.a = bm.Variable(bm.zeros(1))
def __call__(self):
self.a += bm.ifelse(self.rand[0] < 0.5,
[lambda _: 1., lambda _: -1.])
return self.a
brainpy.math.ifelse()函数的参数如下:
conditions代表所有的条件语句。
branches代表所有的分支语句,分支语句可以是数字,可以是函数。
operands代表所有分支语句(如果是函数的话)需要的参数。因为从程序优化的角度,如果因为包含乱序和预执行,内存分配是耗时操作,为了避免回撤代价,所以强制要求分配空间的一致。
所有分支语句的返回值必须具有相同的形状(shape)和数据类型(dtype)
def ifelse(conditions, branches, operands, dyn_vars=None):
pred1, pred2, ... = conditions
func1, func2, ..., funcN = branches
if pred1:
return func1(operands)
elif pred2:
return func2(operands)
...
else:
return funcN(operands)
def f(a):
if a > 10:
return 1.
elif a > 5:
return 2.
elif a > 0:
return 3.
elif a > -5:
return 4.
else:
return 5.
def f(a):
return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
branches=[1., 2., 3., 4., 5.])
一个复杂的例子
分支是一些简单的函数
def f2(a, x):
return bm.ifelse(conditions=[a > 10, a > 5, a > 0, a > -5],
branches=[lambda x: x*2,
2.,
lambda x: x**2 -1,
lambda x: x - 4.,
5.],
operands=x)
如果bm的变量实例在分支函数中应用,你可以在 dyn_vars 参数中声明它们。
a = bm.Variable(bm.zeros(2))
b = bm.Variable(bm.ones(2))
def true_f(x): a.value += 1
def false_f(x): b.value -= 1
bm.ifelse(True, [true_f, false_f], dyn_vars=[a, b])
bm.ifelse(False, [true_f, false_f], dyn_vars=[a, b])
print('a:', a)
print('b:', b)
循环语句
重复语句用于重复一组(块)编程指令。 在 Python 中,我们通常有两个循环/重复语句:
for 循环: for 循环:对序列中的每个项目执行一组语句一次。
while 循环:重复执行一组语句,直到给定条件满足为止: 重复执行一组语句,直到满足给定条件。
Pythonic 循环语法
实际上,JAX 支持编写 Pythonic 循环。 您只需迭代序列数据,然后在迭代项上应用逻辑。 这种 Pythonic 循环语法可以与 JIT 编译兼容,但会导致较长的跟踪和编译时间。 例如
class LoopSimple(bp.BrainPyObject):
def __init__(self):
super(LoopSimple, self).__init__()
rng = bm.random.RandomState(123)
self.seq = bm.Variable(rng.random(1000))
self.res = bm.Variable(bm.zeros(1))
def __call__(self):
for s in self.seq:
self.res += s
return self.res.value
import time
def measure_time(f, return_res=False, verbose=True):
t0 = time.time()
r = f()
t1 = time.time()
if verbose:
print(f'Result: {r}, Time: {t1 - t0}')
return r if return_res else None
model = bm.jit(LoopSimple())
# First time will trigger compilation
measure_time(model)
第一次编译的时间长,语句逻辑复杂的话,时间就会很久。
当模型复杂且迭代时间较长时,第一次运行时的编译将变得难以忍受。 JAX 提供了几种重要的循环语法,包括
- jax.lax.fori_loop
- jax.lax.scan
- jax.lax.while_loop
BrainPy 还提供了自己的循环语法,尤其适用于用户使用 brainpy.math.Variable 的情况。 具体来说,它们是
- brainpy.math.for_loop
- brainpy.math.while_loop
brainpy.math.for_loop()
当你使用 Variable 时,brainpy.math.make_loop() 用于生成一个 for 循环函数。假设你正在使用多个 JaxArray(以 dyn_vars 为一组)来实现 body 函数 “body_fun”,并且你想收集其中几个(以 out_vars 为一组)的历史值。 有时,主体函数已经返回了一些值,您也想收集返回值。 用 Python 的语法,可以这样实现
brainpy.math.for_loop()
是 BrainPy 中的一个工具函数,用于在给定的范围内执行自定义的 for
循环。与 Python 自带的 for
循环类似,但这个函数可以在神经动力学模拟等场景中支持自动微分和符号计算。for_loop
常用于需要依次对某些变量执行重复性计算的场景,尤其适合在 BrainPy 的神经网络和动力学模型中使用。
brainpy.math.for_loop()
参数和说明
brainpy.math.for_loop(
body_fn, # 循环体函数
loop_vars, # 初始循环变量
elems, # 循环元素
unroll=None, # 是否展开循环,默认不展开
name=None # 操作名称
)
参数说明
-
body_fn
(必需):- 类型:
Callable
- 说明:定义循环主体的函数。
body_fn
会在每个迭代中执行,用于描述循环内的具体操作。该函数接收loop_vars
和elems[i]
(即当前迭代元素)作为输入,返回更新后的loop_vars
。
- 类型:
-
loop_vars
(必需):- 类型:
Any
- 说明:初始循环变量,可以是单个变量或多个变量的集合(如列表、字典)。它会在每次循环迭代后更新,以便用于下一次迭代。
- 类型:
-
elems
(必需):- 类型:
Iterable
- 说明:循环范围内的元素集合,
for_loop
会依次遍历每个元素。每次循环中,elems[i]
会作为body_fn
的输入,以用于计算。
- 类型:
-
unroll
(可选):- 类型:
bool
- 说明:是否将循环展开。展开后会占用更多内存,但可能会提高计算速度。默认值为不展开。
- 类型:
-
name
(可选):- 类型:
str
- 说明:循环操作的名称,主要用于调试和日志记录,可以留空。
- 类型:
使用示例
以下示例展示了 for_loop
的基本用法,每次迭代将变量累加 elems
中的每个元素:
import brainpy as bp
# 定义循环体函数:每次迭代将循环变量加上当前元素
def body_fn(loop_vars, elem):
x, = loop_vars
return (x + elem,)
# 初始循环变量
initial_loop_vars = (0,)
# 循环元素
elems = [1, 2, 3, 4, 5] # 累加1到5
# 执行 for_loop
result = bp.math.for_loop(body_fn, initial_loop_vars, elems)
print("结果:", result)
运行结果将是 x = 15
,因为 1 + 2 + 3 + 4 + 5 = 15
。
#In BrainPy, you can define this logic using brainpy.math.for_loop()
import brainpy.math
hist_of_out_vars = brainpy.math.for_loop(body_fun, dyn_vars, operands)
对于上面的例子,我们可以使用 brainpy.math.for_loop 重写为:
class LoopStruct(bp.BrainPyObject):
def __init__(self):
super(LoopStruct, self).__init__()
rng = bm.random.RandomState(123)
self.seq = rng.random(1000)
self.res = bm.Variable(bm.zeros(1))
def __call__(self):
def add(s):
self.res += s
return self.res.value
return bm.for_loop(body_fun=add, dyn_vars=[self.res], operands=self.seq)
model = bm.jit(LoopStruct())
r = measure_time(model, verbose=False, return_res=True)
r.shape
本质上,body_fun 定义了变量更新的一步更新规则。 dyn_vars 定义了 body_fun 中使用的所有动态变量。 operands 指定了 body_fun 的输入。 它将在第一个轴上循环运行。
brainpy.math.while_loop()
以下是 brainpy.math.while_loop()
的主要参数及说明:
brainpy.math.while_loop(
cond_fn, # 循环条件函数
body_fn, # 循环体函数
loop_vars, # 初始循环变量
max_iterations=None, # 最大迭代次数,默认无限制
unroll=None, # 是否对循环展开,默认不展开
name=None # 操作名称
)
参数说明
-
cond_fn
(必需):- 类型:
Callable
- 说明:定义循环条件的函数,接收当前的
loop_vars
作为输入,返回一个布尔值。如果返回True
,循环继续,否则停止。
- 类型:
-
body_fn
(必需):- 类型:
Callable
- 说明:定义循环主体的函数,描述每次迭代中要执行的操作。
body_fn
接收当前loop_vars
并返回更新后的loop_vars
。
- 类型:
-
loop_vars
(必需):- 类型:
Any
- 说明:初始的循环变量,可以是单个变量或变量的集合,如列表或字典。
loop_vars
会在每次迭代中传递给body_fn
,并更新以用于下一次迭代。
- 类型:
-
max_iterations
(可选):- 类型:
int
- 说明:设定循环最大迭代次数,避免无限循环。默认不限制。
- 类型:
-
unroll
(可选):- 类型:
bool
- 说明:是否展开循环。默认不展开,展开循环可能提高性能,但占用更多内存。
- 类型:
-
name
(可选):- 类型:
str
- 说明:循环操作的名称,可以为空,主要用于调试和日志记录。
- 类型:
当使用 Varible 时,brainpy.math.while_loop() 用于生成一个 while 循环函数。 它支持以下循环逻辑:
while condition:
statements
使用 brainpy.math.while_loop() 时,条件应封装为一个返回布尔值的 cond_fun 函数,语句应封装为一个接收最近一步旧值并返回当前一步更新值的 body_fun 函数。
while cond_fun(x):
x = body_fun(x)
请注意 brainpy.math.for_loop 和 brainpy.math.while_loop 之间的区别:
1.brainpy.math.for_loop 的返回值是作为历史值收集的值。 而 brainpy.math.while_loop 的返回值应该与输入值的形状和类型相同,因为它们代表的是更新后的值。
2.brainpy.math.for_loop 可以接收任何内容,而无需明确要求返回值。 但是,brainpy.math.while_loop 应该返回它接收到的内容。
i = bm.Variable(bm.zeros(1))#变量
counter = bm.Variable(bm.zeros(1))#变量
#条件语句
def cond_f():
return i[0] < 10#i.value < 10
#执行函数
def body_f():
i.value += 1.
counter.value += i
bm.while_loop(body_f, cond_f, dyn_vars=[i, counter], operands=())
另一种形式
i = bm.Variable(bm.zeros(1))
def cond_f(counter):
return i[0] < 10
def body_f(counter):
i.value += 1.
return counter + i[0]
bm.while_loop(body_f, cond_f, dyn_vars=[i], operands=(1., ))