-- KDTree.hs

 {-# LANGUAGE MultiParamTypeClasses #-}

-- | A 2-d tree (a k-d tree specialised to two dimensions) built by repeated
-- median selection, together with QuickCheck properties for its helpers.
module KDTree where

import Control.Monad (liftM2)
import Data.List (partition, sort)
import Test.QuickCheck

-- | Points with an x- and a y-coordinate of type @n@.
class Point a n where
  x :: a -> n
  y :: a -> n

-- | A 2-d tree over points of type @pt@ whose coordinates have type @n@.
--
-- 'build' creates 'Vertical' nodes by splitting on the median 'x'
-- (children 'left' / 'right') and 'Horizontal' nodes by splitting on the
-- median 'y' (children 'above' / 'below'); 'at' is the split coordinate.
data D2 pt n
  = Leaf pt
  | Vertical {at :: n, left :: D2 pt n, right :: D2 pt n}
  | Horizontal {at :: n, above :: D2 pt n, below :: D2 pt n}

-- | Build a 2-d tree from a non-empty list of points.
--
-- The root splits at the median x-coordinate ('Vertical'); the levels below
-- alternate between y ('Horizontal') and x, so both coordinates are used.
-- Every 'Leaf' holds exactly one point. Calls 'error' on an empty list.
--
-- NOTE(review): 'above' receives the points with the larger y-coordinates,
-- i.e. this assumes y grows upwards; swap the two children if screen
-- coordinates (y growing downwards) are intended.
build :: (Ord n, Point a n) => [a] -> D2 a n
build = splitOnX
  where
    splitOnX [] = error "no empty trees"
    splitOnX [p] = Leaf p
    -- With >= 2 points the median split leaves both halves non-empty and
    -- strictly smaller, so the recursion terminates.
    splitOnX ps = Vertical n (splitOnY lo) (splitOnY hi)
      where
        (lo, n, hi) = findMedianBy x ps

    splitOnY [] = error "no empty trees"
    splitOnY [p] = Leaf p
    splitOnY ps = Horizontal {at = n, above = splitOnX hi, below = splitOnX lo}
      where
        (lo, n, hi) = findMedianBy y ps

-- | Move an element with the smallest key (according to @val@) to the front
-- of the list, in a single right-to-left pass. The other elements may be
-- reordered; this is not a sort (see 'prop_notSorted').
smallestToFrontBy :: Ord n => (a -> n) -> [a] -> [a]
smallestToFrontBy val = foldr step []
  where
    -- The head of the processed suffix is always its minimum; keep the
    -- smaller of the new element and that minimum in front.
    step a [] = [a]
    step a (b : bs)
      | val a < val b = a : b : bs
      | otherwise = b : a : bs

-- | A test case for 'findBy': an index @k@ and a non-empty list @xs@ with
-- @0 <= k < length xs@.
data FindTest = FT Int [Int] deriving (Show)

instance Arbitrary FindTest where
  arbitrary = do
    -- Consing onto an arbitrary list guarantees a non-empty list.
    xs <- liftM2 (:) arbitrary arbitrary
    k <- choose (0, pred $ length xs)
    return $ FT k xs

  -- Shrink the index (keeping it non-negative) or the list (keeping it
  -- non-empty and clamping the index so it stays in range).
  shrink (FT k xs) =
    [FT k' xs | NonNegative k' <- shrink (NonNegative k)]
      ++ [FT (k `min` pred (length xs')) xs' | xs' <- shrink xs, not (null xs')]

-- | 'findBy' puts exactly @k@ elements with keys @<= n@ on the low side and
-- the rest, all with keys @>= n@, on the high side, losing nothing.
prop_findBy :: FindTest -> Bool
prop_findBy (FT k xs) =
  all (n >=) lo
    && all (n <=) hi
    && length lo == k
    && sort (lo ++ hi) == sort xs
  where
    (lo, n, hi) = findBy id k xs

-- | Quickselect. @findBy val k xs@ returns @(lo, n, hi)@ where @n@ is the key
-- of the element of rank @k@ (0-based, ordered by @val@), @lo@ holds exactly
-- @k@ elements whose keys are @<= n@, and @hi@ holds all remaining elements
-- (including the selected one), whose keys are @>= n@.
--
-- The first element of each sublist is used as the pivot, so sorted input
-- degrades to quadratic time. Calls 'error' unless @0 <= k < length xs@.
findBy :: Ord n => (a -> n) -> Int -> [a] -> ([a], n, [a])
findBy val = select
  where
    -- Reached only when k was negative or >= the list length.
    select _ [] = error "findBy: index out of range"
    select k (pivot : rest) =
      case numSmaller `compare` k of
        -- Exactly k elements are smaller: the pivot has rank k.
        EQ -> (smaller, key, pivot : larger)
        -- The target is in 'larger'; the pivot and everything smaller go low.
        LT -> addToLeft (pivot : smaller) (select (k - numSmaller - 1) larger)
        -- The target is in 'smaller'; the pivot and 'larger' go high.
        GT -> addToRight (select k smaller) (pivot : larger)
      where
        key = val pivot
        (smaller, larger) = partition ((< key) . val) rest
        numSmaller = length smaller

-- | Prepend extra elements to the low side of a 'findBy' result.
addToLeft :: [a] -> ([a], n, [a]) -> ([a], n, [a])
addToLeft extra (lo, n, hi) = (extra ++ lo, n, hi)

-- | Prepend extra elements to the high side of a 'findBy' result.
addToRight :: ([a], n, [a]) -> [a] -> ([a], n, [a])
addToRight (lo, n, hi) extra = (lo, n, extra ++ hi)

-- | Select the element of rank @div (length xs) 2@ by key (the upper median
-- for even lengths); see 'findBy'. Calls 'error' on an empty list.
findMedianBy :: Ord n => (a -> n) -> [a] -> ([a], n, [a])
findMedianBy val xs = findBy val (length xs `div` 2) xs

-- random testing

-- | 'smallestToFrontBy' puts a minimal element first.
prop_smallest :: [Int] -> Property
prop_smallest xs =
  length xs >= 2 ==> head (smallestToFrontBy id xs) == minimum xs

-- | 'smallestToFrontBy' is not a sort. QuickCheck is expected to find a
-- counterexample, so the property is wrapped in 'expectFailure' and reports
-- success when it does.
prop_notSorted :: [Int] -> Property
prop_notSorted xs =
  expectFailure (length xs >= 2 ==> sort xs == smallestToFrontBy id xs)