module Linear where

import Data.List
import Term
import Rule
import TRS
import SMT

-- | Interpretation of one function symbol as a linear polynomial.
-- The triple @(f, c0, [c1, ..., cn])@ holds the symbol name, the constant
-- coefficient, and one coefficient per argument position (in argument
-- order).  It denotes @f_A(x1,...,xn) = c1 x1 + ... + cn xn + c0@.
type Interpretation = (String, Int, [Int])

-- | An algebra: one 'Interpretation' per function symbol of a signature.
type Algebra = [Interpretation]

-- pretty printers

-- | Render the left-hand side of an interpretation.  For example,
-- @("f", 2)@ becomes @f_A(x1,x2)@.  A nullary symbol renders as @f_A()@.
show_lhs :: (String, Int) -> String
show_lhs (f, n) = f ++ "_A(" ++ params ++ ")"
  where
    params = intercalate "," (map (\k -> "x" ++ show k) [1 .. n])

-- | Render one monomial @c * x_i@.  A coefficient of exactly 1 is omitted,
-- so @(1, 2)@ becomes @x2@ and @(3, 2)@ becomes @3 x2@.
show_cxi :: (Int, Int) -> String
show_cxi (c, i)
  | c == 1    = "x" ++ index
  | otherwise = show c ++ " x" ++ index
  where
    index = show i

-- | Render the right-hand side @c1 x1 + ... + cn xn + c0@ of a linear
-- interpretation.  The argument monomials come first, in argument order.
-- The constant term is left out when it is zero, unless there are no
-- argument coefficients at all; in that case the constant is always shown.
show_rhs :: (Int, [Int]) -> String
show_rhs (c0, cs) = intercalate " + " (monomials ++ constantTerm)
  where
    monomials = zipWith (\i m -> show_cxi (m, i)) [1 ..] cs
    constantTerm
      | c0 == 0, not (null cs) = []
      | otherwise              = [show c0]

-- | Render a single interpretation as an equation.  For example,
-- @("f", 1, [1, 2])@ becomes @f_A(x1,x2) = x1 + 2 x2 + 1@.
show_interpretation :: Interpretation -> String
show_interpretation (f, c0, cs) = lhs ++ " = " ++ rhs
  where
    lhs = show_lhs (f, length cs)
    rhs = show_rhs (c0, cs)

-- | Render a whole algebra with one interpretation per line.  Each line is
-- indented by two spaces and terminated by a newline.
show_algebra :: Algebra -> String
show_algebra = unlines . map (("  " ++) . show_interpretation)

-- encoding

-- | SMT variable for coefficient @i@ of symbol @f@'s interpretation.
-- Index 0 is the constant term; indices 1..n are the argument
-- coefficients.  The generated name has the shape @fi_<f>_<i>_@.
var_fi :: String -> Int -> Exp
var_fi f i = Var (concat ["fi_", f, "_", show i, "_"])

-- | Summands of the constant part of a term's interpretation, with the
-- coefficients left as SMT unknowns.  A variable contributes nothing.
-- @f(t1,...,tn)@ contributes @f_0@, plus @f_i * e@ for every summand @e@
-- of the constant part of @ti@.
constant' :: Term -> [Exp]
constant' (V _) = []
constant' (F f ts) = var_fi f 0 : concat (zipWith scaled [1 ..] ts)
  where
    -- summands contributed by argument position k (subterm u)
    scaled k u = map (\e -> times [var_fi f k, e]) (constant' u)

-- | Constant part of a term's interpretation as one SMT expression: the
-- 'plus' of the summands from 'constant''.
constant :: Term -> Exp
constant = plus . constant'

-- | Summands of the coefficient of variable @x@ in a term's interpretation.
-- The variable @x@ itself contributes the single summand 1, and any other
-- variable contributes nothing.  @f(t1,...,tn)@ contributes @f_i * e@ for
-- every summand @e@ of the coefficient of @x@ in @ti@.
coefficient' :: String -> Term -> [Exp]
coefficient' x (V y) = [ Val 1 | x == y ]
coefficient' x (F f ts) =
  concat
    [ map (\e -> times [var_fi f k, e]) (coefficient' x u)
    | (k, u) <- zip [1 ..] ts ]

-- | Coefficient of variable @x@ in a term's interpretation as one SMT
-- expression: the 'plus' of the summands from 'coefficient''.
coefficient :: String -> Term -> Exp
coefficient x = plus . coefficient' x

-- | Shared encoding for orienting a rule @s -> t@.  The constant parts of
-- @s@ and @t@ are compared with @rel@.  For every variable of the rule,
-- the coefficient in @s@ is required to be at least the coefficient in
-- @t@.  'geqA' and 'gtA' differ only in @rel@.
orientA :: (Exp -> Exp -> Formula) -> Term -> Term -> Formula
orientA rel s t =
  conj (rel (constant s) (constant t) :
        [ Geq (coefficient x s) (coefficient x t)
        | x <- Rule.variables (s, t) ])

-- | Weak orientation constraint: constant part of @s@ >= constant part of
-- @t@, and every variable coefficient of @s@ >= that of @t@.
geqA :: Term -> Term -> Formula
geqA = orientA Geq

-- | Strict orientation constraint: constant part of @s@ > constant part of
-- @t@, and every variable coefficient of @s@ >= that of @t@.
gtA :: Term -> Term -> Formula
gtA = orientA Gt

-- | Constraints on the unknown coefficients for every symbol @(f, n)@ of
-- the signature.  The constant term @f_0@ must be >= 0, and every argument
-- coefficient @f_1..f_n@ must be >= 1.  Argument coefficients of at least
-- 1 make each linear interpretation strictly monotone in its arguments.
side_condition :: Signature -> Formula
side_condition sig = conj (constantBounds ++ argumentBounds)
  where
    constantBounds = [ Geq (var_fi f 0) (Val 0) | (f, _) <- sig ]
    argumentBounds =
      [ Geq (var_fi f i) (Val 1) | (f, n) <- sig, i <- [1 .. n] ]

-- | Complete constraint passed to the solver.  It conjoins the side
-- conditions for @sig@, strict orientation ('gtA') of every rule in @rs@,
-- and weak orientation ('geqA') of every rule in @ss@.
encode :: Signature -> TRS -> TRS -> Formula
encode sig rs ss = conj (side_condition sig : strictRules ++ weakRules)
  where
    strictRules = map (uncurry gtA) rs
    weakRules   = map (uncurry geqA) ss

-- | Read back the interpretation of an @n@-ary symbol @f@ by evaluating its
-- coefficient variables in @model@.  The result holds the constant
-- coefficient (index 0) followed by the @n@ argument coefficients.
decode_interpretation :: Model -> String -> Int -> Interpretation
decode_interpretation model f n = (f, c0, cs)
  where
    coeff = eval_exp model . var_fi f
    c0    = coeff 0
    cs    = map coeff [1 .. n]

-- | Read back one interpretation per @(symbol, arity)@ entry of the
-- signature, in signature order.
decode_algebra :: Model -> Signature -> Algebra
decode_algebra model = map (uncurry (decode_interpretation model))

-- | Look for a linear interpretation that orients every rule of @rs@
-- strictly and every rule of @ss@ weakly.  The signature is collected from
-- both rule sets.  @smt@ is passed straight to 'sat' (presumably the name
-- of the solver to run; confirm in the SMT module).  Returns the decoded
-- algebra when 'sat' yields a model, and 'Nothing' otherwise.
terminating :: String -> TRS -> TRS -> IO (Maybe Algebra)
terminating smt rs ss =
  fmap (\model -> decode_algebra model sig) <$> sat smt (encode sig rs ss)
  where
    sig = TRS.signature_of (rs ++ ss)