module TypeInference where

import Term
import Rule
import UnionFind
import Data.List
import qualified Data.Map as M

-- | Per-symbol type declarations: each function symbol name is paired with
-- its argument types and its result type.
type TypeDeclare a = [(String, ([a], a))]
-- | Type assigned to each variable name.
type VariableType a = [(String, a)]
-- | Pairs of type indices that must end up in the same class; every pair is
-- passed to 'UnionFind.union' by @type_inference'@.
type TypeConstraints = [(Int, Int)]
-- | A list of rules (each a left-hand/right-hand term pair).
-- NOTE(review): "ES" presumably stands for equation/equational system — confirm.
type ES = [Rule]
-- | NOTE(review): only produced by 'un_typed' (as @[0]@); the intended meaning
-- of the list is not visible in this module — confirm.
type Type = [Int]
-- | Function signature: each symbol name paired with its arity.
type Signature = [(String, Int)]

-- | Total number of type slots a signature needs: a symbol of arity @n@
-- occupies @n@ argument slots plus one result slot.
number_of_type :: TypeInference.Signature -> Int
number_of_type sig = sum [ arity + 1 | (_, arity) <- sig ]

-- | Build a substitution that renames every distinct variable of a term to a
-- fresh variable named @x ++ show i@, numbering from @k + m@ in order of
-- first occurrence.
--
-- Variables are de-duplicated before numbering (as 'rename_rule' does), so
-- each variable is bound exactly once. The indices handed out are then
-- consecutive. Without this, a repeated variable would receive several
-- bindings and consume indices that a caller may assume are still free.
rename :: String -> Int -> Int -> Term -> Subst
rename x k m t = [ (y, V (x ++ show i)) | (y, i) <- zip ys [k + m ..] ]
  where
    ys = nub (TypeInference.variables t)

-- | Build a substitution that renames each distinct variable of a rule to
-- @x ++ show i@, where @i@ counts up from @k + m@. Variables are numbered in
-- order of first occurrence: left-hand side first, then right-hand side.
rename_rule :: String -> Int -> Int -> Rule -> Subst
rename_rule x k m (l, r) = zipWith fresh distinctVars [k + m ..]
  where
    distinctVars = nub (TypeInference.variables l ++ TypeInference.variables r)
    fresh v i    = (v, V (x ++ show i))

-- | Apply the fresh-variable substitution computed by 'rename' to the term.
renamed_term :: String -> Int -> Int -> Term -> Term
renamed_term x k m t = Term.substitute t (TypeInference.rename x k m t)

-- | Rename both sides of a rule with a single substitution from
-- 'rename_rule', so variables shared by the two sides stay identified.
renamed_rule :: String -> Int -> Int -> Rule -> Rule
renamed_rule x k m rule@(l, r) = (apply l, apply r)
  where
    sigma   = rename_rule x k m rule
    apply u = Term.substitute u sigma

-- | Rename all rules apart from each other. Each rule is renamed starting at
-- the current offset. The offset then advances by the number of distinct
-- variables in that rule, so the next rule's fresh names follow on directly.
rename_es :: String -> Int -> Int -> ES -> ES
rename_es x k m = go k
  where
    go _ [] = []
    go offset (rule@(l, r) : rest) =
      renamed_rule x offset m rule : go (offset + width) rest
      where
        -- distinct variables consumed by this rule
        width = length (nub (TypeInference.variables l ++ TypeInference.variables r))

-- | Flatten rules into the list of their sides: @[l1, r1, l2, r2, ...]@.
es_to_term :: ES -> [Term]
es_to_term = concatMap (\(l, r) -> [l, r])

-- | Every variable occurrence in a term, left to right. Repeated occurrences
-- are kept, so the result may contain duplicates.
variables :: Term -> [String]
variables t = case t of
  V x    -> [x]
  F _ ts -> concatMap TypeInference.variables ts

-- | Allocate type indices for one symbol of arity @m@, starting at @n@. The
-- arguments get @n .. n + m - 1@ and the result gets @n + m@.
signature_to_type' :: (String, Int) -> Int -> TypeDeclare Int
signature_to_type' (f, m) n = [(f, ([n .. resultIx - 1], resultIx))]
  where
    resultIx = n + m

-- | Allocate disjoint, consecutive type-index ranges for every symbol in the
-- signature, beginning at @n@. A symbol of arity @a@ takes @a + 1@ indices;
-- the next symbol starts right after them.
signature_to_type :: TypeInference.Signature -> Int -> TypeDeclare Int
signature_to_type sig n = concat (zipWith signature_to_type' sig starts)
  where
    -- starting index of each symbol (one extra trailing value, dropped by zipWith)
    starts = scanl (\next (_, arity) -> next + arity + 1) n sig

-- | Give the variables consecutive type indices in list order, starting at
-- @m@.
variables_to_type :: [String] -> Int -> VariableType Int
variables_to_type names start = zip names [start ..]

-- | Initial type-index assignment for a rule list.
--
-- The distinct variables (all left-hand-side variables first, then the
-- right-hand-side ones) get indices @0 .. v - 1@. The signature's slots are
-- allocated immediately after them, starting at @v@.
es_term_to_type :: ES -> TypeInference.Signature -> (TypeDeclare Int, VariableType Int)
es_term_to_type x t = (signature_to_type t (length vs), variables_to_type vs 0)
  where
    vs = nub (concatMap (TypeInference.variables . fst) x
              ++ concatMap (TypeInference.variables . snd) x)

-- | Pair every index with itself.
trivial_constraint' :: [Int] -> TypeConstraints
trivial_constraint' = map (\i -> (i, i))

-- | Self-pairs @(i, i)@ for every argument and result index of every
-- declaration, in declaration order.
-- NOTE(review): a self-pair cannot merge two distinct classes; presumably it
-- exists to mention every declared index to the union-find — confirm.
trivial_constraint :: TypeDeclare Int -> TypeConstraints
trivial_constraint decls =
  concat [ trivial_constraint' args ++ [(res, res)] | (_, (args, res)) <- decls ]

-- | Constraints from the arguments of a term.
--
-- For @f(t1, ..., tn)@, each @ti@ is linked to @f@'s i-th argument type via
-- @type_inference_term'@, which also recurses into nested applications. A
-- bare variable yields no constraints. Fails with @error@ when @f@ has no
-- declaration.
type_inference_term :: Term ->  M.Map String ([Int], Int) -> M.Map String Int -> TypeConstraints
type_inference_term t decls vars = case t of
  V _ -> []
  F f ts -> case M.lookup f decls of
    Nothing          -> error "type_inference_term"
    Just (argTys, _) ->
      concat (zipWith (\u ty -> type_inference_term' u ty decls vars) ts argTys)

-- | Constrain term @t@ to have the expected type index @m@.
--
-- Emits @(actual, m)@, where @actual@ is the variable's type or the head
-- symbol's result type. For applications it then adds the constraints of
-- the arguments. Fails with @error@ on an undeclared symbol or variable.
type_inference_term' :: Term -> Int -> M.Map String ([Int], Int) -> M.Map String Int -> TypeConstraints
type_inference_term' t expected decls vars = case t of
  V x -> case M.lookup x vars of
    Just n  -> [(n, expected)]
    Nothing -> error "type_inference_term'2"
  F f _ -> case M.lookup f decls of
    Just (_, res) -> (res, expected) : type_inference_term t decls vars
    Nothing       -> error "type_inference_term'1"

-- | Constraints contributed by one rule @(l, r)@, in this order:
--
-- * the root types of both sides must agree (a variable's type, or the
--   head symbol's result type);
-- * the argument constraints of each side (none for a bare variable);
-- * the self-pairs of 'trivial_constraint' for all declarations.
--
-- Root types are looked up in the association lists; argument constraints
-- go through maps built from the same lists. Fails with
-- @error "type_inference_rule"@ if a root symbol or variable is undeclared.
type_inference_rule :: Rule -> TypeDeclare Int -> VariableType Int -> TypeConstraints
type_inference_rule (l, r) xs ys =
  rootPair
    : (  type_inference_term l declMap varMap
      ++ type_inference_term r declMap varMap
      ++ trivial_constraint xs )
  where
    declMap = M.fromList xs
    varMap  = M.fromList ys

    rootType (V x)   = lookup x ys
    rootType (F f _) = case lookup f xs of
      Just (_, res) -> Just res
      Nothing       -> Nothing

    rootPair = case (rootType l, rootType r) of
      (Just m, Just n) -> (m, n)
      _                -> error "type_inference_rule"
-- | All constraints for a rule list plus the argument constraints of term @t@.
-- NOTE(review): variables occurring inside @t@ must appear in @zs@, otherwise
-- @type_inference_term'@ fails — confirm @t@ is expected to be ground.
type_inference_es :: ES -> Term -> TypeDeclare Int -> VariableType Int -> TypeConstraints
type_inference_es rules t decls vars =
  concatMap (\rule -> type_inference_rule rule decls vars) rules
    ++ type_inference_term t (M.fromList decls) (M.fromList vars)


-- | Merge every constrained pair of type indices into the union-find
-- structure, then return the resulting equivalence classes.
--
-- A strict left fold forces each 'UnionFind.union' result (to WHNF) as it
-- goes. Threading the structure through lazily would instead build a chain
-- of unevaluated unions as long as the constraint list.
type_inference' :: UnionFind -> TypeConstraints -> [[Int]]
type_inference' uf constraints =
  UnionFind.partition (foldl' (\acc (l, r) -> UnionFind.union acc l r) uf constraints)

-- | Solve the constraints of a rule list and a term, returning the classes
-- of type indices that must be equal.
--
-- The union-find is sized from the signature's slot count plus the number of
-- variables. NOTE(review): when called from 'type_inference' the indices in
-- use are @0 .. slots - 1@; the extra @+ 1@ looks like spare capacity —
-- confirm against the semantics of 'UnionFind.empty'.
type_inference2 :: ES -> Term -> TypeDeclare Int -> VariableType Int -> TypeInference.Signature -> [[Int]]
type_inference2 xs t ys zs sig =
  type_inference' (UnionFind.empty (slots + 1)) (type_inference_es xs t ys zs)
  where
    slots = number_of_type sig + length zs

-- | Replace every type index in the declarations by the position of its
-- equivalence class in @classes@ (see 'type_index').
type_declare :: TypeDeclare Int -> [[Int]] -> TypeDeclare Int
type_declare decls classes = map retype decls
  where
    classOf i = type_index i classes 0
    retype (f, (args, res)) = (f, (map classOf args, classOf res))

-- | Position (offset by @l@) of the first class in the list that contains
-- @m@. Fails with @error "type_index"@ if no class contains it.
type_index :: Int -> [[Int]] -> Int -> Int
type_index m classes l =
  case findIndex (elem m) classes of
    Just i  -> l + i
    Nothing -> error "type_index"

-- | Infer type declarations for the signature from a rule list and a term.
--
-- The rules are first renamed apart (fresh variables @x0, x1, ...@). Then an
-- index is assigned to every variable and signature slot, and the
-- constraints are solved. Each declaration index is finally replaced by the
-- number of its equivalence class.
type_inference :: ES -> Term -> TypeInference.Signature -> TypeDeclare Int
type_inference xs t s = type_declare decls classes
  where
    renamed         = rename_es "x" 0 0 xs
    (decls, varTys) = es_term_to_type renamed s
    classes         = type_inference2 renamed t decls varTys s


-- | Types of the variables occurring in an application, read from the
-- declared argument types of the enclosing symbols.
--
-- A constant (no arguments) yields @[]@ without consulting the map. A bare
-- variable, or an application of an undeclared symbol, fails with @error@.
variable_type_term :: Term -> M.Map String ([Int], Int) -> VariableType Int
variable_type_term t ms = case t of
  V _      -> error "variable_type_term1"
  F _ []   -> []
  F f args -> case M.lookup f ms of
    Nothing          -> error "variable_type_term2"
    Just (argTys, _) -> variable_type_term_list args argTys ms

-- | Pair argument terms with their expected types.
--
-- A variable argument takes the expected type directly. An application
-- argument is handled recursively by 'variable_type_term'. Processing stops
-- at the shorter of the two lists.
variable_type_term_list :: [Term] -> [Int] -> M.Map String ([Int], Int) -> VariableType Int
variable_type_term_list terms tys ms = concat (zipWith visit terms tys)
  where
    visit (V x) ty      = [(x, ty)]
    visit (F f args) _  = variable_type_term (F f args) ms

-- | Single-sorted declaration for a symbol of arity @n@: every argument and
-- the result get type @0@.
un_typed' :: (String, Int) -> (String, ([Int], Int))
un_typed' (f, n) = (f, (replicate n 0, 0))

-- | Single-sorted typing of a whole signature: every symbol is declared over
-- type @0@ only, and the second component is @[0]@ (presumably the list of
-- available types — TODO confirm).
un_typed :: TypeInference.Signature -> (TypeDeclare Int, Type)
un_typed sig = (map un_typed' sig, [0])
 
 