module UnionFind where

import Data.IntMap
import Data.List

-- | A union-find (disjoint-set) structure, represented as a pair of maps
-- keyed by element.
type UnionFind = 
  (IntMap Int,   -- parent of each element; a representative is its own parent
   IntMap Int)   -- rank of each element, compared between roots when merging

-- | Build a structure over the elements @0 .. n - 1@ in which every element
-- is its own parent and has rank 0, i.e. every class is a singleton.
-- For @n <= 0@ the element range is empty, so both maps are empty.
empty :: Int -> UnionFind
empty n = (selfParents, zeroRanks)
  where
    ids         = [0 .. n - 1]
    selfParents = Data.IntMap.fromList (zip ids ids)
    zeroRanks   = Data.IntMap.fromList (zip ids (repeat 0))

-- | Find the representative (root) of the class containing @i@.
--
-- Returns the root together with an updated structure in which every element
-- visited on the way up, including @i@ itself, points directly at the root
-- (path compression). Ranks are passed through unchanged.
--
-- The lookup uses the partial '!' operator, so @i@ must be a key of the
-- parent map.
find :: Int -> UnionFind -> (Int, UnionFind)
find i uf@(parents, _)
  | i == j    = (i, uf)  -- an element that is its own parent is the root
  | otherwise =
      let (k, (parents', ranks')) = UnionFind.find j uf
      in  -- Re-point i itself at the root. The recursive call has already
          -- compressed j and everything above it, so rewriting j here would
          -- leave i hanging off its old parent.
          (k, (Data.IntMap.insert i k parents', ranks'))
  where j = parents ! i

-- | Merge the classes containing @i@ and @j@ (union by rank).
--
-- Both elements are first resolved to their roots with 'find', threading the
-- updated structure through both lookups. If the roots coincide, that
-- structure is returned as is. Otherwise the root of lower rank is attached
-- beneath the other; on a tie the root of @j@ is attached beneath the root of
-- @i@, whose rank is then incremented.
union :: UnionFind -> Int -> Int -> UnionFind
union uf i j
  | rootI == rootJ = compressed
  | otherwise =
      case compare rankI rankJ of
        LT -> (link rootI rootJ, ranks)
        GT -> (link rootJ rootI, ranks)
        EQ -> (link rootJ rootI, Data.IntMap.insert rootI (rankI + 1) ranks)
  where
    (rootI, afterFirst) = UnionFind.find i uf
    (rootJ, compressed@(parents, ranks)) = UnionFind.find j afterFirst
    rankI = ranks ! rootI
    rankJ = ranks ! rootJ
    -- Make @child@ point at @parent@ in the parent map.
    link child parent = Data.IntMap.insert child parent parents

-- | Number of elements in the structure, i.e. the number of keys in the
-- parent map.
size :: UnionFind -> Int
size = Data.IntMap.size . fst

-- | Run 'find' on element @i@ and keep only the updated structure, discarding
-- the root. The structure comes first so this can be used as the step
-- function of a left fold.
flatten' :: UnionFind -> Int -> UnionFind
flatten' uf i = snd (UnionFind.find i uf)

-- | Run 'find' on every element @0 .. size - 1@ in turn, threading the
-- updated structure through, so that the compression done by each lookup
-- is accumulated in the result.
--
-- A strict left fold is used so each intermediate structure is forced as the
-- fold advances, instead of building a chain of suspended 'flatten'' calls
-- that is only evaluated at the end.
flatten :: UnionFind -> UnionFind
flatten uf = Data.List.foldl' flatten' uf [0 .. UnionFind.size uf - 1]


-- | Split a list into groups using the given relation.
--
-- The first remaining element starts a group and pulls in every later element
-- related to it (tested as @equal first other@); the process then repeats on
-- what is left. Groups therefore appear in order of their first member, and
-- members keep their original relative order.
partitionBy :: (a -> a -> Bool) -> [a] -> [[a]]
partitionBy equal items = Data.List.unfoldr nextGroup items
  where
    nextGroup []            = Nothing
    nextGroup (lead : rest) =
      let (members, others) = Data.List.partition (equal lead) rest
      in  Just (lead : members, others)

-- | List the equivalence classes of the structure.
--
-- Elements are taken to be @0 .. size - 1@. Classes appear in order of their
-- smallest member, and each class lists its members in ascending order.
partition :: UnionFind -> [[Int]]
partition uf = partitionBy sameClass [0 .. UnionFind.size uf - 1]
  where
    -- Apply 'flatten' once so every lookup below starts from an
    -- already-compressed structure.
    flat = flatten uf
    -- Only the representative decides class membership. The structure that
    -- 'find' returns is dropped, so it is not compared map against map on
    -- every test.
    rootOf i = fst (UnionFind.find i flat)
    sameClass i j = rootOf i == rootOf j

-- >>> uf0 = UnionFind.empty 10
-- >>> uf1 = UnionFind.union uf0 2 4
-- >>> uf2 = UnionFind.union uf1 4 6
-- >>> uf3 = UnionFind.union uf2 1 3
-- >>> uf4 = UnionFind.union uf3 8 9
-- >>> UnionFind.partition uf4
-- [[0],[1,3],[2,4,6],[5],[7],[8,9]]
--

