Efficiently compute sum of N smallest numbers in an array
I have code that first sorts an array and then sums its first 10 elements. I would like to use the Numba package to speed it up, but it isn't helping: the Numba version runs slower than plain NumPy.
My first test, just for sum:
import numpy as np
import numba

np.random.seed(0)

def SumNumpy(x):
    return np.sum(x[:10])

@numba.jit()
def SumNumpyNumba(x):
    return np.sum(x[:10])
My test:
x = np.random.rand(1000000000)
%timeit SumNumpy(x)
%timeit SumNumpyNumba(x)
The results:
100000 loops, best of 3: 6.8 µs per loop
1000000 loops, best of 3: 715 ns per loop
Here it is fine: Numba does a good job. But when I combine np.sort and np.sum:
def sumSortNumpy(x):
    y = np.sort(x)
    return np.sum(y[:10])

@numba.jit()
def sumSortNumpyNumba(x):
    y = np.sort(x)
    return np.sum(y[:10])
and test:
x = np.random.rand(100000)
%timeit sumSortNumpy(x)
%timeit sumSortNumpyNumba(x)
Results:
100 loops, best of 3: 14.6 ms per loop
10 loops, best of 3: 20.6 ms per loop
The Numba/NumPy version is slower than plain NumPy. So my question is: is there something we could do to improve the function "sumSortNumpyNumba"?
I appreciate help.
Thanks.
We only sum the N=10 smallest elements, so their order relative to each
other doesn't matter. Hence, we can use np.argpartition, which avoids the
full sort and simply gives us the indices of the N smallest numbers (in no
particular order), which can then be summed, like so -
def sumSortNumPyArgpartition(x, N=10):
    return x[np.argpartition(x, N)[:N]].sum()
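As a quick sanity check, here is a minimal sketch that compares the partition-based sum against the sort-based one. The function names below are made up for this illustration and are not from the question. It also shows np.partition, which returns the values directly rather than the indices. All variants assume N is smaller than the array length.

```python
import numpy as np

def sum_smallest_sort(x, n=10):
    # Baseline: full sort, then sum the first n values.
    return np.sort(x)[:n].sum()

def sum_smallest_argpartition(x, n=10):
    # After argpartition with kth=n, the first n indices point at the
    # n smallest values, in no particular order. Requires n < len(x).
    idx = np.argpartition(x, n)[:n]
    return x[idx].sum()

def sum_smallest_partition(x, n=10):
    # np.partition returns the partitioned values themselves,
    # so no extra index lookup is needed. Requires n < len(x).
    return np.partition(x, n)[:n].sum()

rng = np.random.default_rng(0)
x = rng.random(100_000)

expected = sum_smallest_sort(x)
assert np.isclose(sum_smallest_argpartition(x), expected)
assert np.isclose(sum_smallest_partition(x), expected)
```

The results are compared with np.isclose rather than ==, because the values may be summed in a different order and floating-point sums can then differ in the last bits.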
Timings on various datasets -
In [39]: np.random.seed(0)
...: x = np.random.rand(1000000)
In [40]: %timeit sumSortNumpy(x)
...: %timeit sumSortNumPyArgpartition(x)
10 loops, best of 3: 78.6 ms per loop
100 loops, best of 3: 12.3 ms per loop
In [41]: np.random.seed(0)
...: x = np.random.rand(10000000)
In [42]: %timeit sumSortNumpy(x)
...: %timeit sumSortNumPyArgpartition(x)
1 loop, best of 3: 920 ms per loop
10 loops, best of 3: 153 ms per loop
In [43]: np.random.seed(0)
...: x = np.random.rand(100000000)
In [44]: %timeit sumSortNumpy(x)
...: %timeit sumSortNumPyArgpartition(x)
1 loop, best of 3: 10.6 s per loop
1 loop, best of 3: 978 ms per loop
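To reproduce this kind of comparison outside IPython, a rough sketch using the standard timeit module could look like the following. The helper best_time and the array sizes are assumptions chosen for illustration. Absolute numbers will depend on the machine and the NumPy version, so no expected output is shown.

```python
import timeit
import numpy as np

def sum_sort(x, n=10):
    # Same approach as sumSortNumpy in the question.
    return np.sort(x)[:n].sum()

def sum_argpartition(x, n=10):
    # Same approach as sumSortNumPyArgpartition above.
    return x[np.argpartition(x, n)[:n]].sum()

def best_time(func, x, repeat=3, number=5):
    # Hypothetical helper: best of `repeat` runs, each making `number`
    # calls; returns seconds per call.
    runs = timeit.repeat(lambda: func(x), repeat=repeat, number=number)
    return min(runs) / number

rng = np.random.default_rng(0)
for size in (10_000, 100_000, 1_000_000):
    x = rng.random(size)
    t_sort = best_time(sum_sort, x)
    t_part = best_time(sum_argpartition, x)
    print(f"size={size:>9,}: sort {t_sort * 1e3:8.3f} ms, "
          f"argpartition {t_part * 1e3:8.3f} ms")
```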