-- tree automata

module TA where

import Data.List
import qualified Data.List as L
import Term
import TRS
import TypeInference
import qualified Data.Set as Set
import qualified Data.Map as M

-- Tree automaton whose states have type a.
data TA a = TA {
  -- all states of the automaton
  _Q :: [a],
  -- states paired with an Int tag (printed under "typed states:";
  -- presumably a type/sort index obtained from type inference)
  _typedQ :: [(a, Int)], 
  -- final states (printed with an "F" suffix by the Show instance)
  _Qf :: [a],
  -- ordinary transitions: ((f, [p1, ..., pn]), q) stands for f(p1,...,pn) -> q
  _non_epsilon :: [((String, [a]), a)],
  -- epsilon transitions: (p, q) stands for p -> q
  _epsilon :: [(a, a)],
  -- relation on states, printed as p >> q
  _relation :: [(a, a)]
}


-- Function symbols paired with an Int; signature_of fills the Int with the
-- number of arguments of the symbol (its arity).
type Signature = [(String, Int)]

-- Association list from variable names to states; eval and state_of_term
-- look variables up in it.
type StateSubst a = [(String, a)]

-- A term, a state substitution for its variables, and a state.
-- normalize_config extends an automaton so that the instantiated term
-- reaches that state.
type Config a = (Term, StateSubst a, a)

-- pretty printers

-- Render a transition as "f(p1,...,pn) -> q"; a symbol without arguments
-- is rendered as "f -> q".
show_transition :: Show a => ((String, [a]), a) -> String
show_transition ((f, args), target) = lhs ++ " -> " ++ show target
  where
    lhs
      | null args = f
      | otherwise = f ++ "(" ++ intercalate "," (map show args) ++ ")"

-- Render a state relation as "p >> q, ..."; the empty relation is shown
-- as "(empty)".
show_relation :: Show a => [(a,a)] -> String
show_relation pairs
  | null pairs = "(empty)"
  | otherwise = intercalate ", " (map render pairs)
  where
    render (p, q) = show p ++ " >> " ++ show q

-- Render a typed state as "(state, type)".
show_type :: Show a => (a, Int) -> String
show_type (s, m) = concat ["(", show s, ", ", show m, ")"]

-- Multi-line textual dump of an automaton: states (final ones carry an "F"
-- suffix), typed states, the state relation, and finally all ordinary and
-- epsilon transitions, one per line.
instance (Show a, Eq a) => Show (TA a) where
  show ta = unlines (header ++ rules ++ epsilons)
    where
      indent = ("    " ++)
      header =
        [ "  states:"
        , indent (intercalate ", " (map show_state (_Q ta)))
        , "  typed states:"
        , indent (intercalate ", " (map show_type (_typedQ ta)))
        , "  state relation:"
        , indent (show_relation (_relation ta))
        , "  transitions:"
        ]
      rules = map (indent . show_transition) (_non_epsilon ta)
      epsilons =
        [ indent (show p ++ " -> " ++ show q) | (p, q) <- _epsilon ta ]
      show_state q
        | q `elem` _Qf ta = show q ++ "F"
        | otherwise = show q

-- { q in Q | p in P and p ->* q }
-- Repeatedly appends the targets of epsilon transitions that leave the
-- current set and are not yet part of it, until nothing new is found.
epsilon_closure :: Eq a => TA a -> [a] -> [a]
epsilon_closure ta ps
  | null fresh = ps
  | otherwise = epsilon_closure ta (ps ++ fresh)
  where
    fresh = nub (map snd (filter leaves (_epsilon ta)))
    leaves (p, q) = p `elem` ps && q `notElem` ps

-- { q in Q | t sigma ->* q }
-- States reachable from the term t after replacing its variables by the
-- states that sigma assigns to them, closed under epsilon transitions.
--
-- A transition f(p1,...,pn) -> q fires when n equals the number of
-- arguments and every p_i is among the states of the i-th argument.  This
-- is the same condition as membership in the cartesian product of the
-- argument state sets, but avoids materialising that (exponentially large)
-- product for every transition.
eval :: Eq a => TA a -> Term -> StateSubst a -> [a]
eval ta (V x) sigma = 
  epsilon_closure ta [ q | (y, q) <- sigma, x == y ]
eval ta (F f ts) sigma = epsilon_closure ta (nub qs)
  where 
    -- argument state sets, computed once rather than once per transition
    arg_states = [ eval ta t sigma | t <- ts ]
    arity = length ts
    qs = [ q
         | ((f', ps), q) <- _non_epsilon ta,
           f == f', 
           length ps == arity,
           and (zipWith elem ps arg_states) ]

-- Function symbols occurring in the ordinary transitions, each paired with
-- the number of argument states of the transition it occurs in.
signature_of :: TA a -> TA.Signature
signature_of ta = nub (map symbol_arity (_non_epsilon ta))
  where
    symbol_arity ((f, ps), _) = (f, length ps)

-- renaming states

-- Position of the first occurrence of x in xs.
-- Calls error "index" when x does not occur in xs.
index :: Eq a => a -> [a] -> Int
index x xs =
  case elemIndex x xs of
    Just n  -> n
    Nothing -> error "index"

-- Positions in ys of every element of xs, in the order of xs.
indexes :: Eq a => [a] -> [a] -> [Int]
indexes xs ys = map (`index` ys) xs

-- Replace every state by its position in the state list _Q, producing an
-- automaton over Int states 0 .. length _Q - 1.
rename :: Eq a => TA a -> TA Int
rename ta =
  TA {
    _Q = take (length states) [0 ..],
    _typedQ = [ (number q, i) | (q, i) <- _typedQ ta ],
    _Qf = map number (_Qf ta),
    _non_epsilon =
      [ ((f, map number ps), number q) | ((f, ps), q) <- _non_epsilon ta ],
    _epsilon = map number_pair (_epsilon ta),
    _relation = map number_pair (_relation ta)
  }
  where
    states = _Q ta
    -- position of a state in the original state list
    number q = index q states
    number_pair (p, q) = (number p, number q)

-- Determinization.


-- Epsilon-closed set of states q that have a transition f(p1,...,pn) -> q
-- whose i-th argument state lies in the i-th set of pss.
image :: Eq a => TA a -> String -> [[a]] -> [a]
image nta f pss = epsilon_closure nta (nub targets)
  where
    targets = [ q | ((g, ps), q) <- _non_epsilon nta, f == g, fits ps ]
    fits ps = and (zipWith elem ps pss)

-- One round of the subset construction, iterated until no new state set
-- appears.  Every symbol of sig is applied to all argument tuples over the
-- state sets found so far; each non-empty (sorted) image becomes the target
-- of a transition.  State sets not seen before are added as states, tagged
-- with the Int that ds records for the symbol producing them, and marked
-- final when they contain a final state of nta.
-- Calls error "determinize'" when a symbol producing a new state is missing
-- from ds.
determinize' :: TA.Signature -> TA Int -> TA [Int] -> M.Map String ([Int], Int) -> TA [Int]
determinize' sig nta dta ds
  | null new_states = dta'
  | otherwise = determinize' sig nta dta' ds
  where
    known = _Q dta
    is_new qs = qs `notElem` known
    delta =
      [ ((f, pss), sort qs)
      | (f, n) <- sig,
        pss <- sequence (replicate n known),
        let qs = image nta f pss,
        not (null qs)
      ]
    typed f qs =
      case M.lookup f ds of
        Just (_, n) -> (qs, n)
        Nothing -> error "determinize'"
    typed_states =
      nub [ typed f qs | ((f, _), qs) <- delta, is_new qs ]
    new_states =
      nub [ qs | (_, qs) <- delta, is_new qs ]
    has_final qs = any (`elem` _Qf nta) qs
    dta' = dta {
      _Q = known ++ new_states,
      _typedQ = _typedQ dta ++ typed_states,
      _Qf = _Qf dta ++ filter has_final new_states,
      _non_epsilon = delta
    }


-- The automaton without states, transitions or relation.
empty_ta :: TA a
empty_ta =
  TA {
    _Q = [],
    _typedQ = [],
    _Qf = [],
    _non_epsilon = [],
    _epsilon = [],
    _relation = []
  }

-- Subset construction for nta over its own signature, starting from the
-- empty automaton; the resulting state sets are renamed to Int states.
determinize :: TA Int -> M.Map String ([Int], Int) -> TA Int
determinize nta ds = TA.rename subset_automaton
  where
    subset_automaton = determinize' (TA.signature_of nta) nta empty_ta ds

-- Computing a minimal state-compatible and state-coherence relation (>>).

-- State reached by a term under sigma, following the first matching
-- ordinary transition at every node.  Nothing when a variable is unbound
-- in sigma or some node has no matching transition.
state_of_term :: Eq a => TA a -> Term -> StateSubst a -> Maybe a
state_of_term _ (V x) sigma = L.lookup x sigma
state_of_term dta (F f ts) sigma = do
  ps <- mapM (\t -> state_of_term dta t sigma) ts
  L.lookup (f, ps) (_non_epsilon dta)


-- All substitutions that map every typed variable (x, m) to a state p
-- such that (p, m) occurs among the typed states qs.
--
-- The type check is performed before the substitutions for the remaining
-- variables are enumerated, and those are computed once and shared; the
-- produced list (including its order) is unchanged.
typed_subst :: [(String, Int)] -> [(Int, Int)] -> [StateSubst Int]
typed_subst [] _             = [[]]
typed_subst ((x, m) : xs) qs =
   [ (x, p) : s 
   | (p, n) <- nub qs, 
     n == m,
     s <- rest ]
  where
    -- substitutions for the remaining variables, independent of (p, n)
    rest = typed_subst xs qs


-- All substitutions assigning a state from qs to each of the given
-- variables; the state of the first variable varies slowest.
state_substitutions :: [String] -> [Int] -> [StateSubst Int]
state_substitutions vars qs = foldr extend [[]] vars
  where
    extend x rest = [ (x, q) : sigma | q <- qs, sigma <- rest ]

-- compute a minimal state-compatible relation
-- For every rule (l, r) and every typed substitution under which l reaches
-- a state p, the pair (p, q) is collected, where q is the state reached by
-- r.  The result is Nothing as soon as some such r reaches no state.
compatible_relation' :: TRS -> TA Int -> M.Map String ([Int], Int) -> Maybe [(Int, Int)]
compatible_relation' trs dta ds = sequence (concatMap rule_pairs trs)
  where
    states = nub (_typedQ dta)
    rule_pairs (l, r) =
      [ fmap (\q -> (p, q)) (state_of_term dta r sigma)
      | sigma <- typed_subst (nub (variable_type_term l ds)) states,
        Just p <- [state_of_term dta l sigma] ]

-- compatible_relation' with duplicate pairs removed.
compatible_relation :: TRS -> TA Int -> M.Map String ([Int], Int) -> Maybe [(Int, Int)]
compatible_relation trs dta ds = fmap nub (compatible_relation' trs dta ds)

-- Every way of picking one element y out of the second list, returned as
-- (prefix, y, suffix); the first argument is prepended to every prefix.
partitions' :: [a] -> [a] -> [([a], a, [a])]
partitions' xs ys =
  [ (xs ++ before, y, after) | (before, y : after) <- zip (inits ys) (tails ys) ]

-- Every split of a list into (elements before, one element, elements after).
partitions :: [a] -> [([a], a, [a])]
partitions = partitions' []

-- compute a minimal state-coherence extension
-- For every transition f(.., p1, ..) -> q1 and every related pair (p1, p2),
-- look up the target q2 of the transition with p1 replaced by p2 and
-- collect (q1, q2).  Nothing when one of those transitions does not exist.
coherence_pairs :: TA Int -> [(Int, Int)] -> Maybe [(Int, Int)] 
coherence_pairs dta relation =
  sequence
    [ fmap (\q2 -> (q1, q2)) (coherence_pairs' f before [p2] after table)
    | ((f, ps), q1) <- _non_epsilon dta,
      (before, p1, after) <- partitions ps,
      (p1', p2) <- relation,
      p1' == p1 ]
  where
    table = M.fromList (_non_epsilon dta)

-- Target of the transition for symbol f whose argument states are the
-- concatenation of the three given lists, if the map contains one.
coherence_pairs' :: String -> [Int] -> [Int] -> [Int] -> M.Map (String, [Int]) Int -> Maybe Int
coherence_pairs' f ps1 p2 ps2 relation = M.lookup key relation
  where
    key = (f, concat [ps1, p2, ps2])

-- True when every element of xs also occurs in ys.
subset :: Eq a => [a] -> [a] -> Bool
subset xs ys = all (`elem` ys) xs

-- Extend the relation with the pairs demanded by coherence_pairs until no
-- new pair appears.  Nothing when coherence_pairs fails at some round.
coherence_completion :: TA Int -> [(Int, Int)] -> Maybe [(Int, Int)] 
coherence_completion ta relation =
  case coherence_pairs ta relation of
    Nothing -> Nothing
    Just pairs
      -- pairs' holds only pairs that are absent from relation, so it is
      -- contained in relation exactly when it is empty; no set
      -- construction is needed to detect the fixed point.
      | null pairs' -> Just relation
      | otherwise -> coherence_completion ta (relation ++ pairs')
      where 
        pairs' = nub [ pair | pair <- pairs, notElem pair relation ]


-- A minimal state-compatible and state-coherence relation (>>).
-- Nothing when either construction step fails, or when the resulting
-- relation links a final state to a non-final one.
state_relation :: TRS -> TA Int -> M.Map String ([Int], Int) -> Maybe [(Int,Int)]
state_relation trs ta ds = do
  compatible <- compatible_relation trs ta ds
  coherent <- coherence_completion ta compatible
  if all keeps_final coherent then Just coherent else Nothing
  where
    finals = _Qf ta
    keeps_final (p, q) = p `notElem` finals || q `elem` finals


-- rewrite closure by tree automata completion

-- Pick the state following the largest existing one (at least 1), capped
-- at m, and make sure it is part of the state list.  Returns the extended
-- automaton together with the chosen state.
new_state :: Int -> TA Int -> (TA Int, Int)
new_state m nta = (extended, fresh)
  where
    largest = foldr max 0 (_Q nta)
    fresh = min m (largest + 1)
    extended = nta { _Q = nub (fresh : _Q nta) }

-- Make sure the term reaches some state under sigma and return that state
-- together with the (possibly extended) automaton.  A variable yields the
-- state sigma assigns to it.  For f(t1,...,tn) the arguments are normalized
-- first; an existing transition on the resulting states is reused,
-- otherwise a state from new_state becomes the target of a new transition
-- and is tagged with the Int that ds records for f.
-- Calls error "normalize_term0" for a variable missing from sigma and
-- error "normalize_term1" for a symbol missing from ds.
normalize_term :: Int -> TA Int -> Term -> StateSubst Int -> M.Map String ([Int], Int) -> (TA Int, Int)
normalize_term _ nta (V x) sigma _ =
  maybe (error "normalize_term0") (\q -> (nta, q)) (L.lookup x sigma)
normalize_term m nta (F f ts) sigma ds =
  case M.lookup f ds of
    Nothing -> error "normalize_term1"
    Just (_, ty) ->
      case L.lookup (f, args) (_non_epsilon args_ta) of
        Just q -> (args_ta, q)
        Nothing ->
          ( fresh_ta { _non_epsilon = nub (((f, args), fresh) : _non_epsilon fresh_ta)
                     , _typedQ = nub ((fresh, ty) : _typedQ fresh_ta) }
          , fresh )
  where
    (args_ta, args) = normalize_terms m nta ts sigma ds
    (fresh_ta, fresh) = new_state m args_ta

-- Normalize the terms from left to right, threading the automaton through
-- the calls; returns the final automaton and the state of each term.
normalize_terms :: Int -> TA Int -> [Term] -> StateSubst Int -> M.Map String ([Int], Int) -> (TA Int, [Int])
normalize_terms m ta ts sigma ds = mapAccumL step ta ts
  where
    step acc t = normalize_term m acc t sigma ds

-- Extend the automaton so that the term of the configuration, instantiated
-- by sigma, reaches the state q.
--
--  * f(t1,...,tn) with q not final: the arguments are normalized to states
--    ps; if a transition f(ps) -> p already exists an epsilon transition
--    p -> q is added, otherwise the transition f(ps) -> q itself.
--  * f(t1,...,tn) with q final: the whole term is normalized to a state p
--    and an epsilon transition p -> q is added.
--  * a variable: an epsilon transition from its state under sigma to q.
--
-- In the two function cases q is additionally tagged with the Int that ds
-- records for f.  Calls error "normalize_config" when f is missing from ds.
--
-- All updates are based on nta', the automaton returned by normalizing the
-- subterms.  Rebuilding _non_epsilon from the input automaton instead would
-- silently drop the transitions just created for the arguments, leaving
-- their new states unreachable.
normalize_config :: Int -> M.Map String ([Int], Int) -> TA Int -> Config Int -> TA Int
normalize_config m ds nta (F f ts, sigma, q)
  | notElem q (_Qf nta) =
    case M.lookup f ds of
      Just (_, n) ->
        case L.lookup (f, ps) (_non_epsilon nta') of
          Nothing ->
            nta' { _non_epsilon = nub (((f, ps), q) : _non_epsilon nta'),
                   _typedQ = nub ((q, n) : _typedQ nta') }
          Just p ->
            nta' { _epsilon = nub ((p, q) : _epsilon nta'),
                   _typedQ = nub ((q, n) : _typedQ nta') }
      Nothing -> error "normalize_config"
    where
      (nta', ps) = normalize_terms m nta ts sigma ds
normalize_config m ds nta (F f ts, sigma, q) = 
  case M.lookup f ds of
    Just (_, n) -> 
      nta' { _epsilon = nub ((p, q) : _epsilon nta'),
             _typedQ = nub ((q, n) : _typedQ nta') }
    Nothing -> error "normalize_config"
  where (nta', p) = normalize_term m nta (F f ts) sigma ds
normalize_config m ds nta (V x, sigma, q) =  
      nta' { _epsilon = nub ((p, q) : _epsilon nta') }
  where (nta', p) = normalize_term m nta (V x) sigma ds

-- Apply normalize_config to every configuration in turn, each one starting
-- from the automaton produced by the previous one.  The strict left fold
-- avoids building a chain of suspended normalize_config calls before the
-- first one is evaluated.
normalize_configs :: Int -> TA Int -> [Config Int] -> M.Map String ([Int], Int) -> TA Int
normalize_configs m nta configs ds = 
  L.foldl' (normalize_config m ds) nta configs

-- configurations that violate compatibility
-- For every rule (l, r), every assignment of states to the variables of l
-- and every state p reached by l: report (r, sigma, p) when r does not
-- reach p as well.
violations :: TRS -> TA Int -> [Config Int]
violations trs ta = do
  (l, r) <- trs
  sigma <- state_substitutions (Term.variables l) (_Q ta)
  p <- eval ta l sigma
  if p `elem` eval ta r sigma then [] else [(r, sigma, p)]

-- Like violations, but the substitutions range over the typed states of
-- the automaton and respect the variable types computed from ds.
typed_violations :: TRS -> TA Int -> M.Map String ([Int], Int) -> [Config Int]
typed_violations trs ta ds = do
  (l, r) <- trs
  sigma <- typed_subst (nub (variable_type_term l ds)) (_typedQ ta)
  p <- eval ta l sigma
  if p `elem` eval ta r sigma then [] else [(r, sigma, p)]

-- an over-approximation of { t | t ->_R^* u } by
-- tree automata completion
-- Normalizes all typed violations of the current automaton and repeats
-- until none is left.
complete :: Int -> TRS -> TA Int -> M.Map String ([Int], Int) -> TA Int
complete m trs nta ds
  | null pending = nta
  | otherwise = complete m trs (normalize_configs m nta pending ds) ds
  where
    pending = typed_violations trs nta ds

-- Determinize nta and attach the relation computed by state_relation;
-- Nothing when no such relation is found.
compatible_coherent_dta :: TRS -> TA Int -> M.Map String ([Int], Int) -> Maybe (TA Int)
compatible_coherent_dta trs nta ds = fmap attach (state_relation trs dta ds)
  where
    dta = determinize nta ds
    attach relation = dta { _relation = relation }

-- Start from an automaton with the single final state 0 that accepts t,
-- complete it with respect to trs (m is handed to the completion), then
-- determinize it and compute its state relation.
rewrite_closure :: Int -> TRS -> Term -> TA.Signature -> Maybe (TA Int)
rewrite_closure m trs t sig = compatible_coherent_dta trs completed ds
  where
    ds = M.fromList (type_inference trs t sig)
    start = empty_ta { _Q = [0], _Qf = [0] }
    seeded = normalize_config m ds start (t, [], 0)
    completed = complete m trs seeded ds

-- True unless q1 is final in the first automaton and q2 is final in the
-- second one.
nonfinal :: TA Int -> TA Int -> (Int, Int) -> Bool
nonfinal nta1 nta2 (q1, q2) =
  not (q1 `elem` _Qf nta1 && q2 `elem` _Qf nta2)

-- Iterate next on the set of state pairs.  The answer is False as soon as
-- a computed pair consists of two final states, and True once a round
-- yields only pairs that were already present.
disjoint' :: [(Int, Int)] -> M.Map String [([Int], Int)] -> TA Int -> TA Int -> Bool
disjoint' states tr2 dta1 dta2
  | not (all (nonfinal dta1 dta2) reached) = False
  | subset reached states = True
  | otherwise = disjoint' reached tr2 dta1 dta2
  where
    reached = nub (next states (_non_epsilon dta1) tr2)

-- Group transitions by function symbol.  Within a group, transitions that
-- occur later in the input come first.
list_to_map :: [((String, [Int]), Int)] -> M.Map String [([Int], Int)]
list_to_map xs = foldl add M.empty xs
  where
    add acc ((f, ps), q) = M.insertWith (++) f [(ps, q)] acc

-- Pairs (q1, q2) such that the first transition list has f(ps1) -> q1, the
-- map has f(ps2) -> q2 for the same symbol f, and every position-wise pair
-- of argument states is contained in states.
next :: [(Int, Int)] -> [((String, [Int]), Int)] -> M.Map String [([Int], Int)] -> [(Int, Int)]
next states xs ys =
  [ (q1, q2)
  | ((f1, ps1), q1) <- xs,
    (ps2, q2) <- M.findWithDefault [] f1 ys,
    all (`elem` states) (zip ps1 ps2) ]

-- L(M1) cap L(M2) = { }
-- Runs disjoint' from the empty set of state pairs, with the transitions
-- of the second automaton grouped by function symbol.
disjoint :: TA Int -> TA Int -> Bool
disjoint dta1 dta2 = disjoint' [] by_symbol dta1 dta2
  where
    by_symbol = list_to_map (_non_epsilon dta2)

