-
Notifications
You must be signed in to change notification settings - Fork 3
/
k-d-tree.py
139 lines (106 loc) · 3.63 KB
/
k-d-tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from random import seed, random
from time import time
from operator import itemgetter
from collections import namedtuple
from math import sqrt
from copy import deepcopy
def sqd(p1, p2):
    """Return the squared Euclidean distance between points ``p1`` and ``p2``.

    Coordinates are paired positionally with ``zip``, so coordinates beyond
    the length of the shorter point are ignored.
    """
    squared_deltas = [(a - b) ** 2 for a, b in zip(p1, p2)]
    return sum(squared_deltas)
class KdNode:
    """One node of a k-d tree.

    Attributes (fixed by ``__slots__``; instances carry no ``__dict__``):
        dom_elt: the point stored at this node.
        split: index of the coordinate this node splits on.
        left: left child, stored as given.
        right: right child, stored as given.
    """

    __slots__ = ("dom_elt", "split", "left", "right")

    def __init__(self, dom_elt, split, left, right):
        self.dom_elt, self.split = dom_elt, split
        self.left, self.right = left, right
class Orthotope:
    """Box described by a minimum corner and a maximum corner.

    ``min`` and ``max`` hold the ``mi`` and ``ma`` constructor arguments
    themselves; no copy is made.
    """

    __slots__ = ("min", "max")

    def __init__(self, mi, ma):
        self.min = mi
        self.max = ma
class KdTree:
    """k-d tree built once from a collection of points.

    Attributes:
        n: root node of the tree, or ``None`` when no points were given.
        bounds: the ``bounds`` constructor argument, stored as given.
    """

    __slots__ = ("n", "bounds")

    def __init__(self, pts, bounds):
        """Build the tree from ``pts``.

        Args:
            pts: iterable of points; each point is an indexable, sized
                sequence of coordinates. The container is copied before
                building, so the caller's list is never reordered (the
                point objects themselves are shared, not copied).
            bounds: stored on the tree without being inspected here.
        """

        def nk2(split, exset):
            # Build the subtree for the points in ``exset`` (a list this
            # function is free to reorder), splitting on coordinate ``split``.
            if not exset:
                return None
            exset.sort(key=itemgetter(split))
            m = len(exset) // 2
            split_value = exset[m][split]
            last = len(exset) - 1
            # Move past points that tie with the median on the split
            # coordinate, so everything right of the pivot is strictly
            # greater and all ties end up in the left subtree.
            while m < last and exset[m + 1][split] == split_value:
                m += 1
            d = exset[m]
            s2 = (split + 1) % len(d)  # cycle coordinates
            return KdNode(d, split, nk2(s2, exset[:m]), nk2(s2, exset[m + 1:]))

        # Sort a private copy: the recursive slices are already fresh lists,
        # so only this top-level call could otherwise touch the caller's data.
        self.n = nk2(0, list(pts))
        self.bounds = bounds
# Result record of a nearest-neighbour search: the nearest point found, its
# squared distance to the target, and the number of tree nodes visited.
T3 = namedtuple("T3", ["nearest", "dist_sqd", "nodes_visited"])
def find_nearest(k, t, p):
    """Return the point stored in tree ``t`` that is nearest to ``p``.

    The search prunes using only the distance from the target to each
    node's splitting plane, so the tree's ``bounds`` are not consulted.

    Args:
        k: number of coordinates; only used to size the all-zero placeholder
            point reported for an empty (sub)tree.
        t: tree object whose ``n`` attribute is the root node or ``None``.
        p: target point, an indexable sequence of coordinates.

    Returns:
        A ``T3`` holding the nearest point, its squared distance to ``p``
        and the number of nodes visited. For an empty tree this is
        ``T3([0.0] * k, float("inf"), 0)``.
    """

    def nn(kd, target, max_dist_sqd):
        # Search the subtree rooted at ``kd``; ``max_dist_sqd`` is the best
        # squared distance known so far and is used only for pruning.
        if kd is None:
            return T3([0.0] * k, float("inf"), 0)

        nodes_visited = 1
        s = kd.split
        pivot = kd.dom_elt

        # Descend first into the side of the splitting plane that the target
        # falls on; the other side is explored only if it could still win.
        if target[s] <= pivot[s]:
            nearer_kd, further_kd = kd.left, kd.right
        else:
            nearer_kd, further_kd = kd.right, kd.left

        n1 = nn(nearer_kd, target, max_dist_sqd)
        nearest = n1.nearest
        dist_sqd = n1.dist_sqd
        nodes_visited += n1.nodes_visited

        if dist_sqd < max_dist_sqd:
            max_dist_sqd = dist_sqd

        # The pivot lies on the splitting plane, so when the plane itself is
        # further away than the best distance so far the pivot cannot be
        # closer; the further subtree is skipped too (assumes the tree keeps
        # that subtree entirely on the other side of the plane).
        d = (pivot[s] - target[s]) ** 2
        if d > max_dist_sqd:
            return T3(nearest, dist_sqd, nodes_visited)

        d = sqd(pivot, target)
        if d < dist_sqd:
            nearest = pivot
            dist_sqd = d
            max_dist_sqd = dist_sqd

        n2 = nn(further_kd, target, max_dist_sqd)
        nodes_visited += n2.nodes_visited
        if n2.dist_sqd < dist_sqd:
            nearest = n2.nearest
            dist_sqd = n2.dist_sqd

        return T3(nearest, dist_sqd, nodes_visited)

    return nn(t.n, p, float("inf"))
def show_nearest(k, heading, kd, p):
    """Print a nearest-neighbour report for point ``p`` in tree ``kd``.

    Writes the heading and the query point first, then runs
    ``find_nearest(k, kd, p)`` and prints the nearest point, the distance
    (square root of the squared distance) and the number of nodes visited.
    """
    print(heading + ":")
    print("Point: ", p)
    found = find_nearest(k, kd, p)
    print("Nearest neighbor:", found.nearest)
    print("Distance: ", sqrt(found.dist_sqd))
    print("Nodes visited: ", found.nodes_visited, "\n")
def random_point(k):
    """Return a list of ``k`` coordinates, each drawn with ``random()``."""
    return [random() for _axis in range(k)]
def random_points(k, n):
    """Return a list of ``n`` points, each produced by ``random_point(k)``."""
    cloud = []
    for _index in range(n):
        cloud.append(random_point(k))
    return cloud
if __name__ == "__main__":
    # Demo: a small fixed 2D data set, then a large random 3D one.
    seed(1)  # fixed seed so the random part of the demo is repeatable

    def P(*coords):
        """Build a point (a list of coordinates) from positional arguments."""
        return list(coords)

    kd1 = KdTree(
        [P(2, 3), P(5, 4), P(9, 6), P(4, 7), P(8, 1), P(7, 2)],
        Orthotope(P(0, 0), P(10, 10)),
    )
    show_nearest(2, "Wikipedia example data", kd1, P(9, 2))

    N = 400000
    # The timed interval covers both generating the points and building the
    # tree from them.
    t0 = time()
    kd2 = KdTree(random_points(3, N), Orthotope(P(0, 0, 0), P(1, 1, 1)))
    t1 = time()
    heading = "k-d tree with {} random 3D points (generation time: {}s)".format(
        N, t1 - t0
    )
    # The dimension argument must match the 3D points stored in kd2 and the
    # 3D query point.
    show_nearest(3, heading, kd2, random_point(3))