-- |
-- Module: System.Random.Shuffle
--
-- Random shuffling of lists, driven by a balanced binary tree whose inner
-- nodes cache the number of leaves below them.  Elements are removed from
-- the tree one at a time by index, and the removal order is the result.
module System.Random.Shuffle
  ( shuffle
  , shuffle'
  , shuffleM
  ) where

import Control.Monad (liftM2)
import Data.Function (fix)

import Control.Monad.Random (MonadRandom, getRandomR)
import System.Random (RandomGen, randomR)

-- | A non-empty binary tree holding the elements still to be drawn.
--
-- A 'Leaf' holds one element.  A 'Node' holds the number of leaves in its
-- two subtrees together with the subtrees themselves.
data Tree a
  = Leaf !a
  | Node !Int !(Tree a) !(Tree a)
  deriving Show

-- | Build a tree from a non-empty list by repeatedly pairing up
-- neighbouring trees until a single tree is left.
--
-- The tree type has no empty case, so an empty list is reported with
-- 'error'; 'shuffle' handles the empty list itself and never passes it here.
buildTree :: [a] -> Tree a
buildTree = fix growLevel . map Leaf
  where
    growLevel _ [] = error "[buildTree] called with an empty list"
    growLevel _ [node] = node
    growLevel self l = self $ inner l

    -- Pair up neighbours; an odd element at the end is carried over as is.
    inner [] = []
    inner [e] = [e]
    inner (e1 : e2 : es) = e1 `seq` e2 `seq` (join e1 e2) : inner es

    -- Combine two trees, computing the leaf count of the new node.
    join l@(Leaf _) r@(Leaf _) = Node 2 l r
    join l@(Node ct _ _) r@(Leaf _) = Node (ct + 1) l r
    join l@(Leaf _) r@(Node ct _ _) = Node (ct + 1) l r
    join l@(Node ctl _ _) r@(Node ctr _ _) = Node (ctl + ctr) l r

-- | Permute a list according to a list of extraction indices.
--
-- For @n@ elements the index list must have exactly @n - 1@ entries.  Each
-- entry is the zero-based position, among the elements not yet drawn, of the
-- next element of the result; the last remaining element is appended
-- without consuming an index.
--
-- An empty element list with an empty index list yields @[]@.  A length
-- mismatch between the two lists calls 'error'.
shuffle :: [a] -> [Int] -> [a]
shuffle [] [] = []
shuffle [] _ = error "[shuffle] called with lists of different lengths"
shuffle elements indices = shuffleTree (buildTree elements) indices
  where
    shuffleTree (Leaf e) [] = [e]
    shuffleTree tree (r : rs) =
      let (b, rest) = extractTree r tree
       in b : shuffleTree rest rs
    shuffleTree _ _ = error "[shuffle] called with lists of different lengths"

    -- Remove the n-th leaf (zero-based, left to right) and return it
    -- together with the tree that remains.  Clause order matters: the
    -- guarded clause falls through to the general case when its guard fails.
    extractTree 0 (Node _ (Leaf e) r) = (e, r)
    extractTree 1 (Node 2 (Leaf l) (Leaf r)) = (r, Leaf l)
    extractTree n (Node c (Leaf l) r) =
      let (e, r') = extractTree (n - 1) r
       in (e, Node (c - 1) (Leaf l) r')
    extractTree n (Node n' l (Leaf e))
      | n + 1 == n' = (e, l)
    extractTree n (Node c l@(Node cl _ _) r)
      | n < cl =
          let (e, l') = extractTree n l
           in (e, Node (c - 1) l' r)
      | otherwise =
          let (e, r') = extractTree (n - cl) r
           in (e, Node (c - 1) l r')
    extractTree _ _ = error "[extractTree] impossible"

-- | Shuffle a list using a pure random generator.
--
-- The second argument must be the length of the list; it determines how
-- many indices are drawn.  A length of one or less draws no indices.
shuffle' :: RandomGen gen => [a] -> Int -> gen -> [a]
shuffle' elements len = shuffle elements . rseq len
  where
    -- Draw n - 1 indices, the first from [0, n - 1] and the last from [0, 1].
    rseq :: RandomGen gen => Int -> gen -> [Int]
    rseq n = go (n - 1)
      where
        go :: RandomGen g => Int -> g -> [Int]
        go i gen
          | i <= 0 = []
          | otherwise =
              let (j, gen') = randomR (0, i) gen
               in j : go (i - 1) gen'

-- | Shuffle a list inside a 'MonadRandom' monad.
--
-- The empty list is returned unchanged without drawing any random numbers.
shuffleM :: (MonadRandom m) => [a] -> m [a]
shuffleM [] = return []
shuffleM elements = do
    indices <- rseqM (length elements - 1)
    return (shuffle elements indices)
  where
    -- Draw i indices, the first from [0, i] and the last from [0, 1].
    rseqM :: (MonadRandom m) => Int -> m [Int]
    rseqM i
      | i <= 0 = return []
      | otherwise = liftM2 (:) (getRandomR (0, i)) (rseqM (i - 1))