KDTree.hs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
{-# LANGUAGE MultiParamTypeClasses #-}

module KDTree
where

import Control.Monad (liftM2)
import Data.List (sort)
import Test.QuickCheck

-- Despite the module name, this is really a 2-D tree: points only have x and y.

-- | Two-dimensional points of type @a@ whose coordinates have type @n@.
--
-- No functional dependency ties @n@ to @a@, so a use site has to fix the
-- coordinate type itself (for example through a type signature).
class Point a n where
  -- | The point's @x@ component.
  x :: a -> n
  -- | The point's @y@ component.
  y :: a -> n


-- | A tree over points of type @pt@ whose split positions have type @n@.
--
-- 'at' is shared by both inner-node constructors. 'left', 'right', 'above'
-- and 'below' are partial record selectors: each one exists for a single
-- constructor only and fails when applied to any other.
data D2 pt n = Leaf pt
             -- Leaf: a single point.
             -- Vertical: 'build' fills this from a median split on x, so
             -- 'left' holds points with x <= at and 'right' points with
             -- x >= at.
             | Vertical   { at :: n, left :: D2 pt n, right :: D2 pt n }
             -- Horizontal: a split at 'at' with subtrees 'above' and 'below'.
             -- NOTE(review): presumably a split on y -- confirm which side
             -- is meant to hold the larger values.
             | Horizontal { at :: n, above :: D2 pt n, below :: D2 pt n }

-- | Build a 2-D tree over a non-empty list of points.
--
-- The splitting axis alternates with depth, starting at the root with a
-- 'Vertical' split on 'x' and continuing with a 'Horizontal' split on 'y',
-- and so on. Every split is made at the median key found by 'findMedianBy',
-- so both halves are non-empty and the recursion ends in 'Leaf' nodes.
--
-- * 'Vertical': 'left' receives the points with @x <= at@, 'right' those
--   with @x >= at@.
-- * 'Horizontal': 'below' receives the points with @y <= at@, 'above' those
--   with @y >= at@ (y is taken to grow upwards).
--
-- Calls 'error' on an empty list, because the tree type has no empty case.
build :: (Ord n, Point a n) => [a] -> D2 a n
build = go True
  where
    -- The flag says whether the current level splits on x (True) or y (False).
    go _ [] = error "no empty trees"
    go _ [a] = Leaf a
    go splitOnX as
      | splitOnX = Vertical nx (go False lx) (go False rx)
      -- 'Horizontal' takes the 'above' subtree first, i.e. the larger keys.
      | otherwise = Horizontal ny (go True ry) (go True ly)
      where
        -- Both splits are lazy; only the one the guard selects is computed.
        (lx, nx, rx) = findMedianBy x as
        (ly, ny, ry) = findMedianBy y as

-- | Make one right-to-left pass over the list that carries an element with
-- the smallest key to the front.
--
-- The remaining elements keep the order this single pass leaves them in; the
-- result is generally not sorted. When two keys are equal, the later element
-- is the one that moves forward.
smallestToFrontBy :: Ord n => (a -> n) -> [a] -> [a]
smallestToFrontBy _ [] = []
smallestToFrontBy key (first : rest)
  | null rest = [first]
  | key first < key champion = first : champion : others
  | otherwise = champion : first : others
  where
    -- The smallest element of the tail, already moved to its front.
    (champion : others) = smallestToFrontBy key rest

-- | Test input for 'prop_findBy': a selection index paired with the list it
-- indexes. The 'Arbitrary' instance generates a non-empty list and an index
-- within its bounds.
data FindTest = FT Int [Int] deriving (Show)

-- | Generates a non-empty list together with a valid zero-based index into it.
instance Arbitrary FindTest where
  arbitrary = do
    -- Consing one extra element onto an arbitrary list rules out the empty list.
    values <- liftM2 (:) arbitrary arbitrary
    index <- choose (0, length values - 1)
    return (FT index values)

  -- Candidates with a smaller index (list unchanged) come first, then
  -- candidates with a shrunk, still non-empty list whose index is clamped so
  -- that it stays in range.
  shrink (FT index values) = smallerIndices ++ shorterLists
    where
      smallerIndices =
        map (\(NonNegative i) -> FT i values) (shrink (NonNegative index))
      shorterLists =
        map clamp (filter (not . null) (shrink values))
      clamp vs = FT (min index (length vs - 1)) vs

-- | 'findBy' must split the list into exactly @k@ elements no greater than
-- the selected value and a remainder no smaller than it, without gaining or
-- losing any element.
prop_findBy :: FindTest -> Bool
prop_findBy (FT k xs) =
  and
    [ all (<= pivot) smaller
    , all (>= pivot) larger
    , length smaller == k
    , sort (smaller ++ larger) == sort xs
    ]
  where
    (smaller, pivot, larger) = findBy id k xs

-- | Quickselect-style split of a list around its @k@-th smallest key.
--
-- @findBy val k xs@ returns @(smaller, pivot, larger)@ such that
--
-- * @smaller@ holds exactly @k@ elements, each with key @<= pivot@;
-- * @larger@ holds the remaining elements, each with key @>= pivot@
--   (the element that supplied @pivot@ is one of them);
-- * together the two lists are a permutation of @xs@.
--
-- @k@ is zero-based and must satisfy @0 <= k < length xs@. Any other index
-- (which includes every index into an empty list) is rejected up front with
-- a descriptive 'error' instead of an opaque pattern-match failure deep
-- inside the recursion.
findBy :: Ord n => (a -> n) -> Int -> [a] -> ([a], n, [a])
findBy val k0 xs0
  | k0 < 0 || k0 >= len =
      error ("findBy: index " ++ show k0
             ++ " out of range for a list of length " ++ show len)
  | otherwise = select k0 xs0
  where
    len = length xs0

    -- Invariant: 0 <= k < length of the list, so the list is never empty.
    select _ [] = error "findBy: impossible, selection from an empty list"
    select k (pivot : rest) =
      case compare nSmaller k of
        -- Exactly k keys lie below the pivot: the pivot is the answer.
        EQ -> (smaller, val pivot, pivot : notSmaller)
        -- Too few keys below: the pivot joins the low side and the search
        -- continues in the high side with the index shifted accordingly.
        LT -> prepend (pivot : smaller)
                      (select (k - nSmaller - 1) notSmaller)
        -- Too many keys below: the answer is among them; the pivot and
        -- everything not smaller than it belong to the high side.
        GT -> append (select k smaller) (pivot : notSmaller)
      where
        (smaller, notSmaller) = split rest [] []
        nSmaller = length smaller

        -- Accumulating partition: keys strictly below the pivot go to the
        -- first accumulator, everything else to the second.
        split [] lo hi = (lo, hi)
        split (b : bs) lo hi
          | val b < val pivot = split bs (b : lo) hi
          | otherwise = split bs lo (b : hi)

    prepend extra (lo, n, hi) = (extra ++ lo, n, hi)
    append (lo, n, hi) extra = (lo, n, extra ++ hi)

-- | Split a list around its median key: 'findBy' applied at the middle
-- index, i.e. half the list length rounded down.
findMedianBy :: Ord n => (a -> n) -> [a] -> ([a], n, [a])
findMedianBy val as = findBy val middle as
  where
    middle = length as `div` 2

-- QuickCheck properties for smallestToFrontBy

-- | For lists of at least two elements, 'smallestToFrontBy' puts a minimum
-- element first. Shorter lists are discarded by the precondition.
prop_smallest :: [Int] -> Property
prop_smallest xs = longEnough ==> front == minimum xs
  where
    longEnough = length xs >= 2
    front = head (smallestToFrontBy id xs)

-- | Claims that 'smallestToFrontBy' sorts any list of at least two elements.
-- This does not hold in general: @[3, 2, 1]@ becomes @[1, 3, 2]@.
-- NOTE(review): presumably kept as a deliberately failing property, as the
-- name suggests -- confirm.
prop_notSorted :: [Int] -> Property
prop_notSorted xs = longEnough ==> fullySorted == onePass
  where
    longEnough = length xs >= 2
    fullySorted = sort xs
    onePass = smallestToFrontBy id xs