cython numpy累积功能

我需要实现一个函数来对可变段长度的数组的元素进行求和。 所以,

a = np.arange(10)
section_lengths = np.array([3, 2, 4])
out = accumulate(a, section_lengths)
print out
array([  3.,   7.,  35.])

我尝试在这里使用cython实现:

https://gist.github.com/2784725

为了提高性能,我将其与section_lengths全部相同的情况下的纯numpy解决方案进行比较:

LEN = 10000
b = np.ones(LEN, dtype=np.int) * 2000
a = np.arange(np.sum(b), dtype=np.double)
out = np.zeros(LEN, dtype=np.double)

%timeit np.sum(a.reshape(-1,2000), axis=1)
10 loops, best of 3: 25.1 ms per loop

%timeit accumulate.accumulate(a, b, out)
10 loops, best of 3: 64.6 ms per loop

你会有任何改善性能的建议吗?


您可以尝试以下一些方法:

  • 除了@cython.boundscheck(False)编译器指令外,还可以尝试添加@cython.wraparound(False)

  • 在您的setup.py脚本中,尝试添加一些优化标志:

    ext_modules = [Extension("accumulate", ["accumulate.pyx"], extra_compile_args=["-O3",])]

  • 看一下由cython -a accumulate.pyx生成的.html文件cython -a accumulate.pyx以查看是否有部分缺少静态类型或严重依赖Python C-API调用:

    http://docs.cython.org/src/quickstart/cythonize.html#determining-where-to-add-types

  • 在方法的末尾添加一个return语句。 目前它正在i_el += 1紧密循环中执行一堆不必要的错误检查。

  • 不知道它是否会有所作为,但我倾向于使循环计数器cdef unsigned int而不仅仅是int

  • section_lengths不相等时,您也可以将代码与numpy进行比较,因为它可能需要的不仅仅是简单的sum


    在for循环更新中out[i_bas]较慢,您可以创建一个临时变量来执行归并,并在nest for循环结束时更新out[i_bas] 。 以下代码将与numpy版本一样快:

    import numpy as np
    cimport numpy as np
    
    ctypedef np.int_t DTYPE_int_t
    ctypedef np.double_t DTYPE_double_t
    
    cimport cython
    @cython.boundscheck(False)
    @cython.wraparound(False)
    def accumulate(
           np.ndarray[DTYPE_double_t, ndim=1] a not None,
           np.ndarray[DTYPE_int_t, ndim=1] section_lengths not None,
           np.ndarray[DTYPE_double_t, ndim=1] out not None,
           ):
        cdef int i_el, i_bas, sec_length, lenout
        cdef double tmp
        lenout = out.shape[0]
        i_el = 0
        for i_bas in range(lenout):
            tmp = 0
            for sec_length in range(section_lengths[i_bas]):
                tmp += a[i_el]
                i_el+=1
            out[i_bas] = tmp
    
    链接地址: http://www.djcxy.com/p/10791.html

    上一篇: cython numpy accumulate function

    下一篇: how to serialize boost::uuids::uuid