@@ -19,7 +19,7 @@ class Point p where
1919 -- | dist2 returns the squared distance between two points.
2020 dist2 :: p -> p -> Double
2121 dist2 a b = sum . map diff2 $ [0 .. dimension a - 1 ]
22- where diff2 i = (coord i a - coord i b)^ 2
22+ where diff2 i = (coord i a - coord i b)^ 2
2323
2424-- | compareDistance p a b compares the distances of a and b to p.
2525compareDistance :: (Point p ) => p -> p -> p -> Ordering
@@ -37,9 +37,9 @@ instance Point Point3d where
3737
3838
3939data KdTree point = KdNode { kdLeft :: KdTree point ,
40- kdPoint :: point ,
40+ kdPoint :: point ,
4141 kdRight :: KdTree point ,
42- kdAxis :: Int }
42+ kdAxis :: Int }
4343 | KdEmpty
4444 deriving (Eq , Ord , Show )
4545
@@ -50,8 +50,8 @@ instance Functor KdTree where
5050instance F. Foldable KdTree where
5151 foldr f init KdEmpty = init
5252 foldr f init (KdNode l x r _) = F. foldr f init3 l
53- where init3 = f x init2
54- init2 = F. foldr f init r
53+ where init3 = f x init2
54+ init2 = F. foldr f init r
5555
5656fromList :: Point p => [p ] -> KdTree p
5757fromList points = fromListWithDepth points 0
@@ -62,16 +62,16 @@ fromListWithDepth [] _ = KdEmpty
6262fromListWithDepth points depth = node
6363 where axis = axisFromDepth (head points) depth
6464
65- -- Sort point list and choose median as pivot element
66- sortedPoints =
67- L. sortBy (\ a b -> coord axis a `compare` coord axis b) points
68- medianIndex = length sortedPoints `div` 2
69-
70- -- Create node and construct subtrees
71- node = KdNode { kdLeft = fromListWithDepth (take medianIndex sortedPoints) (depth+ 1 ),
72- kdPoint = sortedPoints !! medianIndex,
73- kdRight = fromListWithDepth (drop (medianIndex+ 1 ) sortedPoints) (depth+ 1 ),
74- kdAxis = axis }
65+ -- Sort point list and choose median as pivot element
66+ sortedPoints =
67+ L. sortBy (\ a b -> coord axis a `compare` coord axis b) points
68+ medianIndex = length sortedPoints `div` 2
69+
70+ -- Create node and construct subtrees
71+ node = KdNode { kdLeft = fromListWithDepth (take medianIndex sortedPoints) (depth+ 1 ),
72+ kdPoint = sortedPoints !! medianIndex,
73+ kdRight = fromListWithDepth (drop (medianIndex+ 1 ) sortedPoints) (depth+ 1 ),
74+ kdAxis = axis }
7575
7676axisFromDepth :: Point p => p -> Int -> Int
7777axisFromDepth p depth = depth `mod` k
@@ -90,18 +90,18 @@ nearestNeighbor (KdNode KdEmpty p KdEmpty _) probe = Just p
9090nearestNeighbor (KdNode l p r axis) probe =
9191 if xProbe <= xp then doStuff l r else doStuff r l
9292 where xProbe = coord axis probe
93- xp = coord axis p
93+ xp = coord axis p
9494 doStuff tree1 tree2 =
95- let candidates1 = case nearestNeighbor tree1 probe of
96- Nothing -> [p]
97- Just best1 -> [best1, p]
98- sphereIntersectsPlane = (xProbe - xp)^ 2 <= dist2 probe p
99- candidates2 = if sphereIntersectsPlane
100- then candidates1 ++ maybeToList (nearestNeighbor tree2 probe)
101- else candidates1 in
102- Just . L. minimumBy (compareDistance probe) $ candidates2
103-
104- -- | invariant tells whether the KD tree property holds for a given tree and
95+ let candidates1 = case nearestNeighbor tree1 probe of
96+ Nothing -> [p]
97+ Just best1 -> [best1, p]
98+ sphereIntersectsPlane = (xProbe - xp)^ 2 <= dist2 probe p
99+ candidates2 = if sphereIntersectsPlane
100+ then candidates1 ++ maybeToList (nearestNeighbor tree2 probe)
101+ else candidates1 in
102+ Just . L. minimumBy (compareDistance probe) $ candidates2
103+
104+ -- | invariant tells whether the K-D tree property holds for a given tree and
105105-- all its subtrees.
106106-- Specifically, it tests that all points in the left subtree lie to the left
107107-- of the plane, p is on the plane, and all points in the right subtree lie to
@@ -110,16 +110,33 @@ invariant :: Point p => KdTree p -> Bool
110110invariant KdEmpty = True
111111invariant (KdNode l p r axis) = leftIsGood && rightIsGood
112112 where x = coord axis p
113- leftIsGood = all ((<= x) . coord axis) (toList l)
114- rightIsGood = all ((>= x) . coord axis) (toList r)
113+ leftIsGood = all ((<= x) . coord axis) (toList l)
114+ rightIsGood = all ((>= x) . coord axis) (toList r)
115115
116+ -- | invariant' tells whether the K-D tree property holds for all subtrees.
116117invariant' :: Point p => KdTree p -> Bool
117118invariant' = all invariant . subtrees
118119
120+ kNearestNeighbors :: (Eq p , Point p ) => KdTree p -> Int -> p -> [p ]
121+ kNearestNeighbors KdEmpty _ _ = []
122+ kNearestNeighbors _ k _ | k <= 0 = []
123+ kNearestNeighbors tree k probe = nearest : kNearestNeighbors tree' (k- 1 ) probe
124+ where nearest = fromJust $ nearestNeighbor tree probe
125+ tree' = tree `remove` nearest
126+
127+ remove :: (Eq p , Point p ) => KdTree p -> p -> KdTree p
128+ remove KdEmpty _ = KdEmpty
129+ remove (KdNode l p r axis) pKill =
130+ if p == pKill
131+ then fromListWithDepth (toList l ++ toList r) axis
132+ else if coord axis pKill <= coord axis p
133+ then KdNode (remove l pKill) p r axis
134+ else KdNode l p (remove r pKill) axis
135+
119136instance Arbitrary Point3d where
120137 arbitrary = do
121- x <- arbitrary
122- y <- arbitrary
123- z <- arbitrary
124- return (Point3d x y z)
138+ x <- arbitrary
139+ y <- arbitrary
140+ z <- arbitrary
141+ return (Point3d x y z)
125142
0 commit comments