Efficiently checking Euclidean distance for a large number of objects in Python

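In a route-planning algorithm, I'm trying to filter a list of nodes based on their distance to another node. I'm actually pulling the lists from a crude scene graph. I use the term "cell" to refer to a volume within that simple scene graph, from which we've already fetched a list of nodes that are close to each other.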

Right now, I'm implementing it like this:

# SSCCE version of the core function
def nodes_in_range(src, cell, maxDist):
    srcX, srcY, srcZ = src.x, src.y, src.z
    maxDistSq = maxDist ** 2
    for node in cell:
        distSq = (node.x - srcX) ** 2
        if distSq > maxDistSq: continue
        distSq += (node.y - srcY) ** 2
        if distSq > maxDistSq: continue
        distSq += (node.z - srcZ) ** 2
        if distSq <= maxDistSq:
            yield node, distSq ** 0.5  # fast sqrt

from collections import namedtuple
class Node(namedtuple('Node', ('ID', 'x', 'y', 'z'))):
    # actual class has assorted other properties
    pass

# 1, 3 and 9 are <= 4.2 from Node(1)
cell = [
    Node(1, 0, 0, 0),
    Node(2, -2, -3, 4),
    Node(3, .1, .2, .3),
    Node(4, 2.3, -3.3, -4.5),
    Node(5, -2.5, 4.5, 5),
    Node(6, 4, 3., 2.),
    Node(7, -2.46, 2.46, -2.47),
    Node(8, 2.45, -2.46, -2.47),
    Node(9, .5, .5, .1),
    Node(10, 5, 6, 7),
    # In practice, cells have up to 600 entries
]

if __name__ == "__main__":
    for node, dist in nodes_in_range(cell[0], cell, 4.2):
        print("{:3n} {:5.2f}".format(node.ID, dist))

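This routine gets called a lot (10^7+ times in some queries), so every bit of performance matters, and the early-exit conditionals are there to avoid unnecessary member lookups.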

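What I'm trying to do is switch to numpy and organize the cells so that the calculation can be vectorized. What I want to achieve is something like this: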

import numpy
import numpy.linalg
contarry = numpy.ascontiguousarray
float32 = numpy.float32

# The "np_cell" has two arrays: one is the list of nodes and the
# second is a vectorizable array of their positions.
# np_cell[N][1] == numpy array position of np_cell[N][0]

def make_np_cell(cell):
    return (
        cell,
        contarry([contarry((node.x, node.y, node.z), float32) for node in cell]),
     )

# This version fails because norm returns a single value.
def np_nodes_in_range1(srcPos, np_cell, maxDist):
    distances = numpy.linalg.norm(np_cell[1] - srcPos)

    for (node, dist) in zip(np_cell[0], distances):
        if dist <= maxDist:
            yield node, dist

# This version fails because the distances it returns are wrong.
def np_nodes_in_range2(srcPos, np_cell, maxDist):
    # this will fail because the distances are wrong
    distances = numpy.linalg.norm(np_cell[1] - srcPos, ord=1, axis=1)
    for (node, dist) in zip(np_cell[0], distances):
        if dist <= maxDist:
            yield node, dist

# This version doesn't vectorize and so performs poorly
def np_nodes_in_range3(srcPos, np_cell, maxDist):
    norm = numpy.linalg.norm
    for (node, pos) in zip(np_cell[0], np_cell[1]):
        dist = norm(srcPos - pos)
        if dist <= maxDist:
            yield node, dist

if __name__ == "__main__":
    np_cell = make_np_cell(cell)
    srcPos = np_cell[1][0]  # Position column [1], first node [0]
    print("v1 - fails because it gets a single distance")
    try:
        for node, dist in np_nodes_in_range1(srcPos, np_cell, float32(4.2)):
            print("{:3n} {:5.2f}".format(node.ID, dist))
    except TypeError:
        print("distances was a single value")

    print("v2 - gets the wrong distance values")
    for node, dist in np_nodes_in_range2(srcPos, np_cell, float32(4.2)):
        print("{:3n} {:5.2f}".format(node.ID, dist))

    print("v3 - slower")
    for node, dist in np_nodes_in_range3(srcPos, np_cell, float32(4.2)):
        print("{:3n} {:5.2f}".format(node.ID, dist))

The combined script also includes a v4, which tries enumerate instead of zip; it turned out to be about 12 µs slower.

Sample output:

  1  0.00
  3  0.37
  9  0.71
v1 - fails because it gets a single distance
distances was a single value
v2 - gets the wrong distance values
  1  0.00
  3  0.60
  9  1.10
v3 - slower
  1  0.00
  3  0.37
  9  0.71
v4 - v2 using enumerate
  1  0.00
  3  0.60
  9  1.10

As for performance, we can test this with timeit. I'll artificially increase the number of nodes in the cell with a simple multiplication:

In [2]: from sscce import *
In [3]: cell = cell * 32   # increase to 320 nodes
In [4]: len(cell)
Out[4]: 320
In [5]: %timeit -n 1000 -r 7 sum(1 for _ in nodes_in_range(cell[0], cell, 4.2))
1000 loops, best of 7: 742 µs per loop
In [6]: np_cell = make_np_cell(cell)
In [7]: srcPos = np_cell[1][0]
In [8]: %timeit -n 1000 -r 7 sum(1 for _ in np_nodes_in_range2(srcPos, np_cell, numpy.float32(4.2)))
1000 loops, best of 7: 136 µs per loop
In [9]: %timeit -n 1000 -r 7 sum(1 for _ in np_nodes_in_range3(srcPos, np_cell, numpy.float32(4.2)))
1000 loops, best of 7: 3.64 ms per loop

Highlights:

nodes_in_range
    1000 loops, best of 7: 742 µs per loop

np_nodes_in_range2
    1000 loops, best of 7: 136 µs per loop

np_nodes_in_range3
    1000 loops, best of 7: 3.64 ms per loop # OUCH

Questions:

  • What am I doing wrong with the vectorized distance calculation?

    distances = numpy.linalg.norm(np_cell[1] - srcPos)
    

    VS

    distances = numpy.linalg.norm(np_cell[1] - srcPos, ord=1, axis=1)
    
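  • Is this the best approach?

  • Cell population varies between a few nodes and several hundred. I currently iterate over cells one at a time, but it seems likely that I want to marshal a full set of candidates (nodes[], positions[]), despite the potential extra cost of building that list. (I could always use a batch accumulator, so that I'm always trying to fill the accumulator with, say, at least 1024 positions before draining it.) But I suspect this thinking is shaped by my use of contiguous arrays. Should I instead be looking at something like:

    nodes_in_range(src, chain.from_iterable(cell.nodes for cell in scene if cell_in_range(boundingBox)))
    
    rather than worrying about trying to flatten the whole thing?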


  • What am I doing wrong with the vectorized distance calculation?

    distances = numpy.linalg.norm(np_cell[1] - srcPos)
    

    VS

    distances = numpy.linalg.norm(np_cell[1] - srcPos, ord=1, axis=1)
    
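  • Firstly, if axis=None, np.linalg.norm computes either a vector norm (if the input is 1D) or a matrix norm (if the input is multidimensional). In both cases the result is a single scalar, which is why your first version gets one value back instead of one distance per node.

    Secondly, ord=1 means the L1 norm (i.e. Manhattan distance), not the Euclidean distance mentioned in your title. That is why your second version runs but reports the wrong distances.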
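
Putting those two points together, a corrected vectorized version might look like the sketch below. It assumes the positions are kept as an (N, 3) array alongside the node list, as make_np_cell does. The name nodes_in_range_vectorized is hypothetical, and plain integer IDs stand in for the Node objects.

```python
import numpy as np

def nodes_in_range_vectorized(src_pos, nodes, positions, max_dist):
    """Yield (node, distance) for each node within max_dist of src_pos.

    Hypothetical helper: `nodes` is a sequence and `positions` is the
    matching (N, 3) array of coordinates.
    """
    # axis=1 gives one norm per row; the default ord is the L2 (Euclidean) norm.
    distances = np.linalg.norm(positions - src_pos, axis=1)
    # Only loop in Python over the rows that passed the distance test.
    for i in np.flatnonzero(distances <= max_dist):
        yield nodes[i], distances[i]

# Positions of the ten sample nodes from the question; IDs stand in for Node objects.
node_ids = list(range(1, 11))
positions = np.array([
    (0, 0, 0), (-2, -3, 4), (.1, .2, .3), (2.3, -3.3, -4.5), (-2.5, 4.5, 5),
    (4, 3., 2.), (-2.46, 2.46, -2.47), (2.45, -2.46, -2.47), (.5, .5, .1),
    (5, 6, 7),
])

results = list(nodes_in_range_vectorized(positions[0], node_ids, positions, 4.2))
for node_id, dist in results:
    print("{:3d} {:5.2f}".format(node_id, dist))
```

On the sample data this selects nodes 1, 3 and 9, the same set as the pure-Python nodes_in_range. Whether it is actually faster for small cells would need to be timed on your own data.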


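  • Is this the best approach?
  • A k-D tree would probably be much faster. You can use scipy.spatial.cKDTree to do ball searches, which find the nodes that lie within some threshold distance of a query point: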

    import numpy as np
    from scipy.spatial import cKDTree
    
    # it will be much easier (and faster) to deal with numpy arrays here (you could
    # always look up the corresponding node objects by index if you wanted to)
    X = np.array([(n.x, n.y, n.z) for n in cell])
    
    # construct a k-D tree
    tree = cKDTree(X)
    
    # query it with the first point, find the indices of all points within a maximum
    # distance of 4.2 of the query point
    query_point = X[0]
    idx = tree.query_ball_point(query_point, r=4.2, p=2)
    
    # these indices are one out from yours, since they start at 0 rather than 1
    print(idx)
    # [0, 2, 8]
    
    # query_ball_point doesn't return the distances, but we can easily compute these
    # using broadcasting
    neighbor_points = X[idx]
    
    d = np.sqrt(((query_point[None, :] - neighbor_points) ** 2).sum(1))
    print(d)
    # [ 0.          0.37416574  0.71414284]
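
If you want to keep the generator interface from the question, the index lookup can be wrapped up along the following lines. This is only a sketch: the TreeCell class and its attribute names are hypothetical, and it assumes the tree is built once per cell and reused across many queries.

```python
import numpy as np
from scipy.spatial import cKDTree

class TreeCell:
    """Hypothetical wrapper pairing a list of nodes with a k-D tree of their positions."""

    def __init__(self, nodes, positions):
        self.nodes = list(nodes)
        self.positions = np.asarray(positions, dtype=float)
        # Build the tree once per cell and reuse it for every query.
        self.tree = cKDTree(self.positions)

    def nodes_in_range(self, src_pos, max_dist):
        """Yield (node, distance) for nodes within max_dist of src_pos."""
        src_pos = np.asarray(src_pos, dtype=float)
        hits = self.tree.query_ball_point(src_pos, r=max_dist)
        # Sort so results come back in the original node order.
        idx = np.array(sorted(hits), dtype=int)
        # Distances are computed only for the matches.
        dists = np.linalg.norm(self.positions[idx] - src_pos, axis=1)
        for i, d in zip(idx, dists):
            yield self.nodes[i], d

# Sample positions from the question; integer IDs stand in for Node objects.
sample_positions = [
    (0, 0, 0), (-2, -3, 4), (.1, .2, .3), (2.3, -3.3, -4.5), (-2.5, 4.5, 5),
    (4, 3., 2.), (-2.46, 2.46, -2.47), (2.45, -2.46, -2.47), (.5, .5, .1),
    (5, 6, 7),
]
tree_cell = TreeCell(range(1, 11), sample_positions)
found = list(tree_cell.nodes_in_range(sample_positions[0], 4.2))
print([node_id for node_id, _ in found])
```

Building the tree has its own cost, so this only pays off when each cell is queried many times between rebuilds.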
    

    Benchmarking:

    Querying a cKDTree is very fast, even for a very large number of points:

    X = np.random.randn(10000000, 3)
    tree = cKDTree(X)
    
    %timeit tree.query_ball_point(np.random.randn(3), r=4.2)
    # 1 loops, best of 3: 229 ms per loop
    

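    As you mentioned in the comments, the example above is a much harsher performance test than your data. Because of the distance tolerance chosen and the fact that the data are Gaussian (and therefore clustered around 0), it matches about 99% of the 10 million points.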
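
The match fraction can be estimated without building a tree at all. The sketch below makes two assumptions that differ from the benchmark: it uses a smaller sample and it places the query point at the origin rather than at a random location, so the number it prints is only indicative.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen arbitrarily for repeatability
sample = rng.standard_normal((100_000, 3))  # same distribution as np.random.randn
r = 4.2

# Fraction of Gaussian points whose Euclidean distance from the origin is <= r.
fraction = np.mean(np.linalg.norm(sample, axis=1) <= r)
print("fraction within r of the origin: {:.4f}".format(fraction))
```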

    Here is a test on uniform data with a stricter distance cutoff, which matches roughly 30% of the points, as in your example:

    %timeit tree.query_ball_point((0., 0., 0.), r=1.2)
    # 10 loops, best of 3: 86 ms per loop
    

    Obviously, this is still a far larger number of points than you will be using. For your example data:

    tree = cKDTree(np_cell[1])
    %timeit tree.query_ball_point(np_cell[1][0], r=4.2)
    # The slowest run took 4.26 times longer than the fastest. This could mean that an intermediate result is being cached 
    # 100000 loops, best of 3: 16.9 µs per loop
    

    This easily beats your np_nodes_in_range2 function on my machine:

    %timeit sum(1 for _ in np_nodes_in_range2(srcPos, np_cell, numpy.float32(4.2)))
    # The slowest run took 7.77 times longer than the fastest. This could mean that an intermediate result is being cached 
    # 10000 loops, best of 3: 84.4 µs per loop
    

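    Some other things to consider:

    If you need to query multiple points simultaneously, it is more efficient to construct a second tree and use query_ball_tree rather than query_ball_point: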

    X = np.random.randn(100, 3)
    Y = np.random.randn(10, 3)
    tree1 = cKDTree(X)
    tree2 = cKDTree(Y)
    
    # indices contains a list-of-lists, where the ith sublist contains the indices
    # of the neighbours of Y[i] in X
    indices = tree2.query_ball_tree(tree1, r=4.2)
    

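    If you don't care about the indices and only need the number of points within the ball, it is probably faster to use count_neighbors. Note that for a scalar r it returns a single total count of pairs across all query points, not one count per point: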

    n_neighbors = tree2.count_neighbors(tree1, r=4.2)
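
To make the difference between the pair total and per-point counts concrete, here is a small sketch. The random data, the seed and the radius are arbitrary choices for illustration.

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)  # arbitrary seed for repeatability
X = rng.standard_normal((100, 3))
Y = rng.standard_normal((10, 3))
tree1 = cKDTree(X)
tree2 = cKDTree(Y)
r = 1.0

# Total number of (Y[i], X[j]) pairs with distance <= r, as a single integer.
total_pairs = tree2.count_neighbors(tree1, r)

# Per-query counts: query_ball_point accepts an array of points and returns
# one list of indices per row of Y.
per_point = [len(ix) for ix in tree1.query_ball_point(Y, r=r)]

print(total_pairs, per_point)
```

The per-point counts sum to the pair total, so count_neighbors is the cheaper choice only when the aggregate is all you need.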
    