Skip to content

Commit 8fe9a66

Browse files
committed
Add nearestNeighbor.
1 parent 3a904f6 commit 8fe9a66

File tree

1 file changed

+41
-3
lines changed

1 file changed

+41
-3
lines changed

KdTree.hs

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
{-# LANGUAGE TemplateHaskell #-}
22

3-
-- Based on
43
-- http://en.wikipedia.org/wiki/K-d_tree
5-
-- Translated by Issac Trotts
4+
-- Issac Trotts
5+
6+
import Data.Maybe
67

78
import qualified Data.Foldable as F
89
import qualified Data.List as L
@@ -11,11 +12,21 @@ import Test.QuickCheck
1112
import Test.QuickCheck.All
1213

1314
class Point p where
15+
-- |dimension returns the number of coordinates of a point.
1416
dimension :: p -> Int
1517

16-
-- Get the k'th coordinate, starting from 0
18+
-- |coord gets the k'th coordinate, starting from 0.
1719
coord :: Int -> p -> Double
1820

21+
-- |dist2 returns the squared distance between two points.
22+
dist2 :: p -> p -> Double
23+
dist2 a b = sum . map diff2 $ [0..dimension a - 1]
24+
where diff2 i = (coord i a - coord i b)^2
25+
26+
-- |compareDistance p a b compares the distances of a and b to p.
27+
compareDistance :: (Point p) => p -> p -> p -> Ordering
28+
compareDistance p a b = dist2 p a `compare` dist2 p b
29+
1930
data Point3d = Point3d { p3x :: Double, p3y :: Double, p3z :: Double }
2031
deriving (Eq, Ord, Show)
2132

@@ -75,6 +86,23 @@ subtrees :: KdTree p -> [KdTree p]
7586
subtrees KdEmpty = [KdEmpty]
7687
subtrees t@(KdNode l x r axis) = subtrees l ++ [t] ++ subtrees r
7788

89+
nearestNeighbor :: Point p => KdTree p -> p -> Maybe p
90+
nearestNeighbor KdEmpty probe = Nothing
91+
nearestNeighbor (KdNode KdEmpty p KdEmpty _) probe = Just p
92+
nearestNeighbor (KdNode l p r axis) probe =
93+
if xProbe <= xp then doStuff l r else doStuff r l
94+
where xProbe = coord axis probe
95+
xp = coord axis p
96+
doStuff tree1 tree2 =
97+
let candidates1 = case nearestNeighbor tree1 probe of
98+
Nothing -> [p]
99+
Just best1 -> [best1, p]
100+
sphereIntersectsPlane = (xProbe - xp)^2 <= dist2 probe p
101+
candidates2 = if sphereIntersectsPlane
102+
then candidates1 ++ maybeToList (nearestNeighbor tree2 probe)
103+
else candidates1 in
104+
Just . L.minimumBy (compareDistance probe) $ candidates2
105+
78106
-- Testing -----
79107

80108
-- |invariant tells whether the KD tree property holds for a given tree and
@@ -105,5 +133,15 @@ prop_invariant points = invariant' . fromList $ points
105133
prop_samePoints :: [Point3d] -> Bool
106134
prop_samePoints points = L.sort points == (L.sort . toList . fromList $ points)
107135

136+
prop_nearestNeighbor :: [Point3d] -> Point3d -> Bool
137+
prop_nearestNeighbor points probe =
138+
nearestNeighbor tree probe == bruteNearestNeighbor points probe
139+
where tree = fromList points
140+
141+
bruteNearestNeighbor :: [Point3d] -> Point3d -> Maybe Point3d
142+
bruteNearestNeighbor [] _ = Nothing
143+
bruteNearestNeighbor points probe =
144+
Just . head . L.sortBy (compareDistance probe) $ points
145+
108146
main = $quickCheckAll
109147

0 commit comments

Comments
 (0)