Closest Pair Implementation in Python

I am trying to implement the closest pair problem in Python using divide and conquer, everything seems to work fine except that in some input cases, there is a wrong answer. My code is as follows:

def closestSplitPair(Px,Py,d):
    """Search the vertical strip of half-width d for a pair closer than d.

    Px, Py: the same points, presumably sorted by x and by y respectively
            (that is how the caller describes them) -- TODO confirm.
    d:      smallest distance found so far; only pairs closer than this count.

    Returns (p3, q3, best); p3 and q3 stay None when no pair beats d.
    """
    # The strip is centred on the x of the LAST element of Px.
    # NOTE(review): the caller passes the whole of Px, so this looks like the
    # largest x rather than the dividing line between the halves -- confirm.
    X = Px[len(Px)-1][0]
    # Points within d of X, kept in Py's order.
    Sy = [item for item in Py if item[0]>=X-d and item[0]<=X+d]
    best,p3,q3 = d,None,None
    # NOTE(review): with these bounds i + j never reaches len(Sy) - 1, so the
    # last element of Sy is never part of a compared pair.
    for i in xrange(0,len(Sy)-2):
        for j in xrange(1,min(7,len(Sy)-1-i)):
            if dist(Sy[i],Sy[i+j]) < best:
                # NOTE(review): this stores the pair of points in `best`, not
                # their distance, so later comparisons and the returned third
                # element are no longer a number.
                best = (Sy[i],Sy[i+j])
                p3,q3 = Sy[i],Sy[i+j]
    return (p3,q3,best)

I am calling the above function through a recursive function which is as follows:

def closestPair(Px, Py):
    """Px and Py are input arrays sorted according to
    their x and y coordinates respectively.

    Returns a (p, q, distance) tuple: the closest pair found and its distance.
    Lists of three or fewer points are delegated to min_dist; larger lists are
    halved, solved recursively and combined with closestSplitPair.
    """
    if len(Px) <= 3:
        return min_dist(Px)
    # Floor division keeps the midpoint an int on Python 2 and Python 3 alike.
    mid = len(Px) // 2
    Qx = Px[:mid]  # x-sorted left side of P
    # NOTE(review): Py is ordered by y, so Py[:mid] / Py[mid:] are the low-y and
    # high-y halves, not the y-sorted points of Qx / Rx. Kept exactly as posted
    # because this is the slicing the discussion below refers to.
    Qy = Py[:mid]  # intended: y-sorted left side of P
    Rx = Px[mid:]  # x-sorted right side of P
    Ry = Py[mid:]  # intended: y-sorted right side of P
    (p1, q1, d1) = closestPair(Qx, Qy)
    (p2, q2, d2) = closestPair(Rx, Ry)
    d = min(d1, d2)
    (p3, q3, d3) = closestSplitPair(Px, Py, d)
    # Pick whichever candidate carries the smallest distance (third field).
    return min((p1, q1, d1), (p2, q2, d2), (p3, q3, d3), key=lambda tup: tup[2])

where min_dist(P) is the brute force implementation of the closest pair algorithm for a list P having 3 or less elements and returns a tuple containing the pair of closest points and their distance.

If my input is P = [(0,0),(7,6),(2,20),(12,5),(16,16),(5,8),(19,7),(14,22),(8,19),(7,29),(10,11),(1,13)] , then my output is ((5,8),(7,6),2.8284271) which is the correct output. But when my input is P = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2), (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)] the output I get is ((78, 2), (94, 5), 16.278820596099706) whereas the correct output should be ((94, 5), (99, -8), 13.92838827718412)

You have two problems. First, you are forgetting to call dist when you update the best distance. But the main problem is that more than one recursive call is happening, so when you find a closer split pair you can end up overwriting it with the default, best,p3,q3 = d,None,None. I passed the best pair from closest_pair as an argument to closest_split_pair so that I would not potentially overwrite the value.

def closest_split_pair(p_x, p_y, delta, best_pair):  # <- best_pair is the extra parameter
    """Return the closest pair straddling the median x, or best_pair unchanged.

    p_x:       the points sorted by x; its middle element fixes the dividing line.
    p_y:       the same points sorted by y.
    delta:     smallest distance found inside either half so far.
    best_pair: the pair achieving delta; returned as-is when no pair in the
               strip is strictly closer, so a good result is never overwritten.
    """
    if not p_x:
        return best_pair
    mid_x = p_x[len(p_x) // 2][0]
    # Points within delta of the dividing line, still in y order.
    strip = [p for p in p_y if mid_x - delta <= p[0] <= mid_x + delta]
    best = delta
    for i, p in enumerate(strip):
        # In a y-sorted strip only the next 7 points can be closer than delta,
        # so the window is a constant size instead of growing with i.
        for q in strip[i + 1:i + 8]:
            dst = dist(p, q)
            if dst < best:
                best_pair = p, q
                best = dst
    return best_pair

The end of closest_pair looks like the following:

    # Recurse on the two halves (presumably the x- and y-sorted left and right
    # halves -- the lines that build them are not shown); each call yields a pair.
    p_1, q_1 = closest_pair(srt_q_x, srt_q_y)
    p_2, q_2 = closest_pair(srt_r_x, srt_r_y)
    # Smaller of the two within-half distances; passed on as the strip width.
    closest = min(dist(p_1, q_1), dist(p_2, q_2))
    # get min of both and then pass that as an arg to closest_split_pair
    mn = min((p_1, q_1), (p_2, q_2), key=lambda x: dist(x[0], x[1]))
    p_3, q_3 = closest_split_pair(p_x, p_y, closest,mn)
    # either return mn or we have a closer split pair
    return min(mn, (p_3, q_3), key=lambda x: dist(x[0], x[1]))

You also have some other logic issues: your slicing logic is not correct. I made some changes to your code, where brute is just a simple brute-force double loop:

def closestPair(Px, Py):
    """Return the closest pair (p, q) among the given points.

    Px: the points sorted by x coordinate.
    Py: the same points sorted by y coordinate.
    Points must be hashable (e.g. tuples) so the y-ordered list can be split
    to match the x-ordered halves.

    Lists of three or fewer points are delegated to brute.
    """
    from collections import Counter

    if len(Px) <= 3:
        return brute(Px)

    # Floor division keeps the midpoint an int on Python 2 and Python 3 alike.
    mid = len(Px) // 2
    # Left and right halves of Px; slices of an x-sorted list stay x-sorted.
    Qx, Rx = Px[:mid], Px[mid:]
    # Split Py into the same two halves while preserving y order. Counting
    # (rather than a plain set) keeps duplicate points in the right half.
    left_remaining = Counter(Qx)
    Qy, Ry = [], []
    for point in Py:
        if left_remaining[point] > 0:
            left_remaining[point] -= 1
            Qy.append(point)
        else:
            Ry.append(point)
    (p1, q1) = closestPair(Qx, Qy)
    (p2, q2) = closestPair(Rx, Ry)
    # Best pair found inside either half, and its distance.
    mn = min((p1, q1), (p2, q2), key=lambda pair: dist(pair[0], pair[1]))
    d = dist(mn[0], mn[1])
    (p3, q3) = closest_split_pair(Px, Py, d, mn)
    return min(mn, (p3, q3), key=lambda pair: dist(pair[0], pair[1]))

I just did the algorithm today so there are no doubt some improvements to be made but this will get you the correct answer.

Here is a recursive divide-and-conquer Python implementation of the closest-pair problem based on the heap data structure. It also handles negative integer coordinates. It can return the k closest pairs by popping k nodes from the heap using heappop().

from __future__ import division
from collections import namedtuple
from random import randint
import math as m
import heapq as hq

def get_key(item):
    """Sort key for a point: its first field (the x coordinate)."""
    # NOTE(review): the body was missing from the posted snippet. Sorting by x
    # is presumed, since the caller sorts the points before splitting them at
    # a midpoint index -- confirm against the original post.
    return item[0]

def closest_point_problem(points):
    """Build the heap of candidate pairs for `points` and return it.

    points: iterable of (x, y) pairs.

    Returns the heap list filled by find_min, whose entries are
    (distance, ((x1, y1), (x2, y2))) tuples; pop it with heapq.heappop to get
    candidates in order of increasing distance. Previously the heap was built
    and then discarded, so callers had nothing to pop.
    """
    pt = namedtuple('pt', 'x y')
    # Wrap every input pair in a named tuple and order them with get_key.
    point = sorted((pt(p[0], p[1]) for p in points), key=get_key)
    heap = []
    visited_index = []
    find_min(0, len(point) - 1, point, heap, visited_index)
    return heap

def find_min(start, end, point, heap, visited_index):
    """Recursively push distances between neighbouring points onto heap.

    start, end:    inclusive index range of `point` to examine.
    point:         sequence of objects with .x and .y attributes.
    heap:          list used with heapq; receives
                   (distance, ((x1, y1), (x2, y2))) entries.
    visited_index: list of indices appended to as the recursion proceeds.

    Returns None; results are delivered through `heap`.
    """
    if len(point[start:end + 1]) & 1:
        # NOTE(review): `mid` is assigned twice in a row, so the first value is
        # always overwritten, and for an even-length range `mid` is never
        # assigned at all (the reads below would then fail). It looks like an
        # `else:` line was lost when the snippet was posted -- confirm.
        mid = start + ((len(point[start:end + 1]) + 1) >> 1)
        mid = start + (len(point[start:end + 1]) >> 1)
        # Shrink the range past endpoints that were already recorded.
        if start in visited_index:
            start = start + 1
        if end in visited_index:
            end = end - 1
    if len(point[start:end + 1]) > 3:
        if start < mid - 1:
            # Left part: distance of the first two points and of the two
            # points ending at mid.
            distance1 = m.sqrt((point[start].x - point[start + 1].x) ** 2 + (point[start].y - point[start + 1].y) ** 2)
            distance2 = m.sqrt((point[mid].x - point[mid - 1].x) ** 2 + (point[mid].y - point[mid - 1].y) ** 2)
            # NOTE(review): both distances are pushed when distance1 is the
            # smaller one and neither otherwise; possibly another lost `else:`.
            if distance1 < distance2:
                hq.heappush(heap, (distance1, ((point[start].x, point[start].y), (point[start + 1].x, point[start + 1].y))))
                hq.heappush(heap, (distance2, ((point[mid].x, point[mid].y), (point[mid - 1].x, point[mid - 1].y))))
            visited_index.append(start + 1)
            visited_index.append(mid - 1)
            find_min(start, mid, point, heap, visited_index)
        if mid + 1 < end:
            # Right part: distance of the two points starting at mid and of
            # the last two points.
            distance1 = m.sqrt((point[mid].x - point[mid + 1].x) ** 2 + (point[mid].y - point[mid + 1].y) ** 2)
            distance2 = m.sqrt((point[end].x - point[end - 1].x) ** 2 + (point[end].y - point[end - 1].y) ** 2)
            if distance1 < distance2:
                hq.heappush(heap, (distance1, ((point[mid].x, point[mid].y), (point[mid + 1].x, point[mid + 1].y))))
                hq.heappush(heap, (distance2, ((point[end].x, point[end].y), (point[end - 1].x, point[end - 1].y))))
            visited_index.append(end - 1)
            visited_index.append(mid + 1)
            find_min(mid, end, point, heap, visited_index)

# Sample input: num_points random integer points, each coordinate drawn from
# [-(num_points << 2), num_points << 2], i.e. -40..40 for ten points.
num_points = 10
x = [
    (randint(- num_points << 2, num_points << 2),
     randint(- num_points << 2, num_points << 2))
    for _ in range(num_points)
]


Brute force can work faster with stdlib functions. Therefore, it can be effectively applied to more than 3 points.

from itertools import combinations

def closest(points_list):
    """Brute-force closest pair: return (distance, p1, p2) for the nearest two points.

    With fewer than two points there is no pair, so a sentinel
    (largest finite float, None, None) is returned instead of letting min()
    raise on an empty sequence; the sentinel loses any later min() comparison.
    """
    import sys  # local import: this snippet's own import line only brings in itertools

    if len(points_list) < 2:
        return (sys.float_info.max, None, None)
    return min((dist(p1, p2), p1, p2)
               for p1, p2 in combinations(points_list, r=2))

The most effective way to divide the points is to divide them into tiles. If you don't have outliers, you can just split your space into equal parts and compare points only within the same tile or in neighbouring tiles. The number of tiles should be as large as possible. But, to avoid isolated tiles, where the points of a tile have no points in any neighbouring tile, you must limit the number of tiles by the number of points. Full listing:

from math import sqrt
from itertools import combinations, product
from collections import defaultdict
import sys

# Largest finite float; used below as the "no pair found" distance so that a
# placeholder result never wins a min() comparison against a real one.
max_float = sys.float_info.max

def dist(p1, p2):
    """Return the Euclidean distance between two 2-D points given as (x, y) pairs."""
    # Unpack inside the body: tuple parameters in the signature are
    # Python-2-only syntax and a SyntaxError on Python 3.
    (x1, y1), (x2, y2) = p1, p2
    return sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)

def closest(points_list):
    """Return (distance, p1, p2) for the nearest two points of points_list.

    When the list holds fewer than two points there is no pair to report and
    the placeholder (max_float, None, None) is returned instead.
    """
    if len(points_list) >= 2:
        scored_pairs = ((dist(a, b), a, b)
                        for a, b in combinations(points_list, 2))
        return min(scored_pairs)
    # default value compatible with min function
    return (max_float, None, None)

def closest_between(pnt_lst1, pnt_lst2):
    """Return (distance, p1, p2) for the nearest pair with p1 from the first
    list and p2 from the second.

    If either argument is empty or None, the placeholder
    (max_float, None, None) is returned.
    """
    if pnt_lst1 and pnt_lst2:
        cross_pairs = product(pnt_lst1, pnt_lst2)
        return min((dist(a, b), a, b) for a, b in cross_pairs)
    # default value compatible with min function
    return (max_float, None, None)

def divide_on_tiles(points_list):
    """Bucket points into a roughly sqrt(n) x sqrt(n) grid of tiles.

    points_list: sequence of (x, y) pairs.

    Returns a defaultdict mapping (x_tile, y_tile) to the list of points that
    fall in that tile; an empty mapping for an empty input.
    """
    tiles = defaultdict(list)
    if not points_list:
        return tiles
    side = int(sqrt(len(points_list)))  # number of tiles on one side of square
    xs = [x for x, _ in points_list]
    ys = [y for _, y in points_list]
    min_x, max_x = min(xs), max(xs)
    # The y extent must come from the y coordinates (it was taken from x before).
    min_y, max_y = min(ys), max(ys)
    # When all points share a coordinate the extent is 0; fall back to a size
    # of 1.0 so everything lands in tile 0 instead of dividing by zero.
    tile_x_size = (float(max_x - min_x) / side) or 1.0
    tile_y_size = (float(max_y - min_y) / side) or 1.0
    for x, y in points_list:
        x_tile = int((x - min_x) / tile_x_size)
        y_tile = int((y - min_y) / tile_y_size)
        tiles[(x_tile, y_tile)].append((x, y))
    return tiles

def closest_for_tile(tiles, coord):
    """Return the best (distance, p1, p2) involving the tile at coord.

    tiles: mapping from (x_tile, y_tile) to the list of points in that tile.
    coord: the (x_tile, y_tile) key of the tile to examine.

    Considers pairs inside the tile and pairs between the tile and four of
    its neighbours.
    """
    # Unpack inside the body: a tuple parameter in the signature is
    # Python-2-only syntax and a SyntaxError on Python 3.
    x_tile, y_tile = coord
    points = tiles[(x_tile, y_tile)]
    return min(closest(points),
               # use dict.get to avoid creating empty tiles
               # we compare current tile only with half of neighbours (right/top),
               # because another half (left/bottom) make it in another iteration by themselves
               closest_between(points, tiles.get((x_tile+1, y_tile))),
               closest_between(points, tiles.get((x_tile, y_tile+1))),
               closest_between(points, tiles.get((x_tile+1, y_tile+1))),
               closest_between(points, tiles.get((x_tile-1, y_tile+1))))

def find_closest_in_tiles(tiles):
    """Return the best (distance, p1, p2) over every tile in tiles.

    tiles: mapping from (x_tile, y_tile) to the list of points in that tile.

    With no tiles at all there is nothing to compare, so the placeholder
    (max_float, None, None) is returned instead of letting min() raise on an
    empty sequence.
    """
    # Snapshot the keys so the mapping is never iterated while it is in use.
    candidates = [closest_for_tile(tiles, coord) for coord in list(tiles)]
    if not candidates:
        return (max_float, None, None)
    return min(candidates)

# The two sample inputs from the question.
P1 = [(0,0),(7,6),(2,20),(12,5),(16,16),(5,8),(19,7),(14,22),(8,19),(7,29),(10,11),(1,13)]
P2 = [(94, 5), (96, -79), (20, 73), (8, -50), (78, 2), (100, 63), (-14, -69), (99, -8), (-11, -7), (-78, -46)]

# print(...) with a single argument behaves the same on Python 2 and Python 3.
# Each result is (distance, point, point); the two points of a pair may be
# reported in either order.
print(find_closest_in_tiles(divide_on_tiles(P1)))  # (2.8284271247461903, (7, 6), (5, 8))
print(find_closest_in_tiles(divide_on_tiles(P2)))  # (13.92838827718412, (94, 5), (99, -8))
print(find_closest_in_tiles(divide_on_tiles(P1 + P2)))  # (2.8284271247461903, (7, 6), (5, 8))

