Numba 官网:http://numba.pydata.org/。
Numba 对 for 循环有很好的加速效果,推荐使用。Numba 支持的函数有限,直接使用容易产生错误,需要多做测试后再应用。
一、简单的例子
一个简单的使用例子(在函数前增加 @jit):
from numba import jit
@jit
def f(x):
return x*2
print(f(3))
二、求和的例子
以下代码是这些的时间对比:
- for 循环求和
- sum()
- numpy.sum()
- numba + for 循环求和
- numba + numpy.sum()
- numba (nopython) + numpy.sum()
- numba (nopython, parallel) + numpy.sum()
from numba import jit
import numpy as np
import time
numpy_array = np.arange(0,1e5,1)
times = 1000
def for_sum(numpy_array):
sum = 0
for number in numpy_array:
sum += number
return sum
start = time.time()
for _ in range(times):
result = for_sum(numpy_array)
end = time.time()
print('for循环求和时间:', end - start)
start = time.time()
for _ in range(times):
result = sum(numpy_array)
end = time.time()
print('sum()函数求和时间:', end - start)
start = time.time()
for _ in range(times):
result = np.sum(numpy_array)
end = time.time()
print('numpy.sum()函数求和时间:', end - start)
print()
@jit
def numba_for_sum(numpy_array):
sum = 0
for number in numpy_array:
sum += number
return sum
@jit
def numba_np_sum(numpy_array):
result = np.sum(numpy_array)
return result
@jit(nopython=True)
def numba_nopython_np_sum(numpy_array):
result = np.sum(numpy_array)
return result
@jit(nopython=True, parallel=True)
def numba_nopython_parallel_np_sum(numpy_array):
result = np.sum(numpy_array)
return result
start = time.time()
for _ in range(times):
result = numba_for_sum(numpy_array)
end = time.time()
print('numba + for循环求和时间:', end - start)
start = time.time()
for _ in range(times):
result = numba_np_sum(numpy_array)
end = time.time()
print('numba + numpy.sum()函数求和时间:', end - start)
start = time.time()
for _ in range(times):
result = numba_nopython_np_sum(numpy_array)
end = time.time()
print('numba(nopython) + numpy.sum()函数求和时间:', end - start)
start = time.time()
for _ in range(times):
result = numba_nopython_parallel_np_sum(numpy_array)
end = time.time()
print('numba(nopython,parallel) + numpy.sum()函数求和时间:', end - start)
运行结果:
for循环求和时间: 6.970077037811279
sum()函数求和时间: 5.682592391967773
numpy.sum()函数求和时间: 0.025031328201293945
numba + for循环求和时间: 0.4277036190032959
numba + numpy.sum()函数求和时间: 0.24217891693115234
numba(nopython) + numpy.sum()函数求和时间: 0.1281888484954834
numba(nopython,parallel) + numpy.sum()函数求和时间: 0.7088756561279297
在这里的例子中,直接使用 numpy.sum() 的计算速度是最快的。
对于求和的过程明显可以直接使用 numpy.sum(),但 Numba 加速可适用于普遍的 for 循环处理 Numpy 数组的情况。在实际应用中,Numba 是否能起到有效加速需要做具体的代码测试和时间测试,同时也可查阅官方的文档说明。
三、不同 jit 参数在循环求和中的时间对比
@jit、@jit(nopython=True)、@jit(nopython=True, parallel=True) 在循环求和中的时间对比:
from numba import jit
from numba import prange
import time
import numpy as np
numpy_array = np.arange(0,1e5,1)
times = 1000
def for_sum(numpy_array):
sum = 0
for number in numpy_array:
sum += number
return sum
@jit
def numba_for_sum_1(numpy_array):
sum = 0
for number in numpy_array:
sum += number
return sum
@jit(nopython=True)
def numba_for_sum_2(numpy_array):
sum = 0
for number in numpy_array:
sum += number
return sum
@jit(nopython=True, parallel=True)
def numba_for_sum_3(numpy_array):
sum = 0
for i in prange(len(numpy_array)):
sum += numpy_array[i]
return sum
start = time.time()
for _ in range(times):
result = for_sum(numpy_array)
end = time.time()
print('for循环时间:', end - start)
start = time.time()
for _ in range(times):
result = numba_for_sum_1(numpy_array)
end = time.time()
print('@jit时间:', end - start)
start = time.time()
for _ in range(times):
result = numba_for_sum_2(numpy_array)
end = time.time()
print('@jit(nopython=True)时间:', end - start)
start = time.time()
for _ in range(times):
result = numba_for_sum_3(numpy_array)
end = time.time()
print('@jit(nopython=True, parallel=True)时间:', end - start)
运行结果:
for循环时间: 6.5412514209747314
@jit时间: 0.4498889446258545
@jit(nopython=True)时间: 0.1052699089050293
@jit(nopython=True, parallel=True)时间: 0.4585683345794678
这里的例子用 @jit(nopython=True) 效果最好。如果是计算时间比较久,且有多核的情况,那么用 @jit(nopython=True, parallel=True) 效果最好。
四、报错的例子
这里记录一个报错的例子,供参考。
(1)这样写会报错:
@jit
def numba_for_sum():
sum = 0
for i0 in range(2):
sum = sum + np.array([1, 2])
return sum
print(numba_for_sum())
(2)这样写是运行正确:
@jit
def numba_for_sum():
sum = np.zeros(2)
for i0 in range(2):
sum = sum + np.array([1, 2])
return sum
print(numba_for_sum())
(3)而在这种没有循环的情况下,数组的广播能力又恢复了,不会报错:
@jit
def numba_for_sum():
sum = 0
sum = sum + np.array([1, 2])
return sum
print(numba_for_sum())
这应该跟 numba 的运行机制有关,得一步步代码做测试。如果直接把原先的代码加一个 @jit,大概率会出问题。
参考资料:
[3] numba从入门到精通(1)—为什么numba能够加速
【说明:本站主要是个人的一些笔记和代码分享,内容可能会不定期修改。为了使全网显示的始终是最新版本,这里的文章未经同意请勿转载。引用请注明出处:https://www.guanjihuan.com】