module Matrix where

import Data.List
import Term
import Rule
import TRS
import SMT

-- A vector with two integer components.
type Vector = (Int, Int)

-- A 2x2 integer matrix stored row by row:
-- (a,b,c,d) stands for
-- [a b]
-- [c d]
type Matrix = (Int, Int, Int, Int)

-- 2x2 matrix interpretations
-- ("f", f0, [F1, F2]) represents f_A(x1,x2) = f0 + F1 x1 + F2 x2,
-- i.e. a symbol name, a constant vector and one coefficient matrix
-- per argument position.
type Interpretation = (String, Vector, [Matrix])

-- A list of interpretations (decode_algebra yields one per signature entry).
type Algebra = [Interpretation]

-- A vector whose two components are symbolic Exp terms.
type VectorExp = (Exp, Exp)

-- A 2x2 matrix of symbolic Exp terms, in the same row-by-row order as Matrix.
type MatrixExp = (Exp, Exp, Exp, Exp)

-- pretty printers

-- Render a vector as "[a, b]".
show_vector :: Vector -> String
show_vector (x, y) = concat ["[", show x, ", ", show y, "]"]

-- Render a matrix as "[[a, b], [c, d]]", one inner list per row.
show_matrix :: Matrix -> String
show_matrix (a,b,c,d) = "[" ++ row a b ++ ", " ++ row c d ++ "]"
  where
    row x y = "[" ++ show x ++ ", " ++ show y ++ "]"

-- Render the summand "M xi"; the identity matrix is omitted so that the
-- summand reads just "xi".
show_cxi :: (Matrix, Int) -> String
show_cxi (m, i)
  | m == (1,0,0,1) = xi
  | otherwise      = show_matrix m ++ " " ++ xi
  where
    xi = "x" ++ show i

-- Render the left-hand side "f_A(x1,...,xn)" for a symbol f of arity n.
show_lhs :: (String, Int) -> String
show_lhs (f, n) = f ++ "_A(" ++ intercalate "," (map argument [1..n]) ++ ")"
  where
    argument i = 'x' : show i

-- Render the right-hand side "F1 x1 + ... + Fn xn + f0".
-- A zero constant vector is dropped unless it would be the only summand.
show_rhs :: (Vector, [Matrix]) -> String
show_rhs (c0, cs) = intercalate " + " (summands ++ constant_part)
  where
    summands = zipWith (\m i -> show_cxi (m, i)) cs [1..]
    constant_part
      | c0 == (0,0) && not (null cs) = []
      | otherwise                    = [show_vector c0]

-- Render one interpretation as "f_A(x1,...,xn) = <right-hand side>";
-- the arity is the number of coefficient matrices.
show_interpretation :: Interpretation -> String
show_interpretation (f, c0, cs) = lhs ++ " = " ++ rhs
  where
    lhs = show_lhs (f, length cs)
    rhs = show_rhs (c0, cs)

-- Render an algebra with one interpretation per line, indented by two spaces.
show_algebra :: Algebra -> String
show_algebra a = unlines (map indented a)
  where
    indented interpretation = "  " ++ show_interpretation interpretation

-- encoding functions

-- Numeric variable named "fi_<f>_<i>_<j>".
var_fij :: String -> Int -> Int -> Exp
var_fij f i j = Var (intercalate "_" ["fi", f, show i, show j])

-- Formula variable named "fi_<f>_<i>_<j>".
fvar_fij :: String -> Int -> Int -> Formula
fvar_fij f i j = FVar name
  where
    name = concat ["fi_", f, "_", show i, "_", show j]

-- Symbolic constant vector of f: the variables with index i = 0 and j = 1, 2.
var_f0 :: String -> VectorExp
var_f0 f = (component 1, component 2)
  where
    component = var_fij f 0

-- Symbolic coefficient matrix of the i-th argument of f: the top-left
-- entry (j = 1) is a numeric variable, the other three entries (j = 2..4)
-- are formula variables.
var_fi :: String -> Int -> (Exp, Formula, Formula, Formula)
var_fi f i = (var_fij f i 1, flag 2, flag 3, flag 4)
  where
    flag = fvar_fij f i

-- Componentwise sum of a list of symbolic vectors.
plus_vector :: [VectorExp] -> VectorExp
plus_vector vs = (plus (map fst vs), plus (map snd vs))

-- Entrywise sum of a list of symbolic matrices.
plus_matrix :: [MatrixExp] -> MatrixExp
plus_matrix ms =
  ( plus [ a | (a, _, _, _) <- ms ]
  , plus [ b | (_, b, _, _) <- ms ]
  , plus [ c | (_, _, c, _) <- ms ]
  , plus [ d | (_, _, _, d) <- ms ]
  )

{-
times_matrix :: (Exp, Formula, Formula, Formula) -> MatrixExp -> MatrixExp
times_matrix (a,b,c,d) (a',b',c',d') =
  (plus [times [a, a'], times [b, c']],
   plus [times [a, b'], times [b, d']],
   plus [times [c, a'], times [d, c']],
   plus [times [c, b'], times [d, d']])
-}

-- Product of a formula-valued entry with an expression: e if the formula
-- holds and 0 otherwise (assuming ite is the usual if-then-else).
times_bool :: Formula -> Exp -> Exp
times_bool condition e = ite condition e zero
  where
    zero = Val 0

-- Matrix product of a coefficient matrix (numeric top-left entry,
-- formula-valued remaining entries) with a symbolic matrix.
times_matrix :: (Exp, Formula, Formula, Formula) -> MatrixExp -> MatrixExp
times_matrix (a,b,c,d) (a',b',c',d') =
  (top a' c', top b' d', bottom a' c', bottom b' d')
  where
    -- row (a b) applied to the column (x y)
    top x y    = plus [times [a, x], times_bool b y]
    -- row (c d) applied to the column (x y)
    bottom x y = plus [times_bool c x, times_bool d y]

-- Product of a coefficient matrix (numeric top-left entry, formula-valued
-- remaining entries) with a symbolic vector.
times_matrix_vector :: (Exp, Formula, Formula, Formula) -> VectorExp -> VectorExp
times_matrix_vector (a,b,c,d) (x, y) = (first_row, second_row)
  where
    first_row  = plus [times [a, x], times_bool b y]
    second_row = plus [times_bool c x, times_bool d y]

-- Summands of the constant part of a term's interpretation: a variable
-- contributes nothing; f(t1,...,tn) contributes f0 and, for each i,
-- Fi times every summand of ti.
constant' :: Term -> [VectorExp]
constant' (V _)    = []
constant' (F f ts) = var_f0 f : concat (zipWith below [1..] ts)
  where
    below i ti = map (times_matrix_vector (var_fi f i)) (constant' ti)

-- Constant part of a term's interpretation: the sum of its summands.
constant :: Term -> VectorExp
constant = plus_vector . constant'

-- The 2x2 identity matrix as a symbolic matrix.
unit_matrix :: MatrixExp
unit_matrix = (one, zero, zero, one)
  where
    one  = Val 1
    zero = Val 0

-- Summands of the coefficient matrix of variable x in a term: one summand
-- per occurrence of x, each the product of the coefficient matrices along
-- the path from the root to that occurrence.
coefficient' :: String -> Term -> [MatrixExp]
coefficient' x (V y)    = [ unit_matrix | x == y ]
coefficient' x (F f ts) = concat (zipWith below [1..] ts)
  where
    below i ti = map (times_matrix (var_fi f i)) (coefficient' x ti)

-- Coefficient matrix of variable x in a term: the sum of its summands.
coefficient :: String -> Term -> MatrixExp
coefficient x = plus_matrix . coefficient' x

-- Componentwise >= on symbolic vectors.
geq_vector :: VectorExp -> VectorExp -> Formula
geq_vector (l1, l2) (r1, r2) = conj (zipWith Geq [l1, l2] [r1, r2])

-- Strict order on symbolic vectors: strictly greater in the first
-- component and at least as large in the second.
gt_vector :: VectorExp -> VectorExp -> Formula
gt_vector (l1, l2) (r1, r2) = conj [strict_first, weak_second]
  where
    strict_first = Gt l1 r1
    weak_second  = Geq l2 r2

-- Entrywise >= on symbolic matrices.
geq_matrix :: MatrixExp -> MatrixExp -> Formula
geq_matrix (a,b,c,d) (a',b',c',d') =
  conj (zipWith Geq [a, b, c, d] [a', b', c', d'])

-- Weak orientation of s against t: the constant parts are compared with
-- geq_vector, and the coefficient matrices of every variable returned by
-- Rule.variables (s, t) are compared with geq_matrix.
geqA :: Term -> Term -> Formula
geqA s t = conj (constants : coefficients)
  where
    constants      = geq_vector (constant s) (constant t)
    coefficients   = map per_variable (Rule.variables (s, t))
    per_variable x = geq_matrix (coefficient x s) (coefficient x t)

-- Strict orientation of s against t: the constant parts are compared with
-- gt_vector, and the coefficient matrices of every variable returned by
-- Rule.variables (s, t) are compared with geq_matrix.
gtA :: Term -> Term -> Formula
gtA s t = conj (constants : coefficients)
  where
    constants      = gt_vector (constant s) (constant t)
    coefficients   = map per_variable (Rule.variables (s, t))
    per_variable x = geq_matrix (coefficient x s) (coefficient x t)

-- Constraints on the interpretation variables of a signature: every
-- constant vector is componentwise >= 0, and the top-left entry of every
-- coefficient matrix is >= 1.
side_condition :: Signature -> Formula
side_condition sig = conj (nonnegative_constants ++ positive_top_left)
  where
    nonnegative_constants =
      map (\(f, _) -> geq_vector (var_f0 f) (Val 0, Val 0)) sig
    positive_top_left =
      concatMap (\(f, n) -> map (at_least_one f) [1..n]) sig
    at_least_one f i = Geq (var_fij f i 1) (Val 1)

-- Full constraint: the side condition for the signature, strict
-- orientation (gtA) of every rule in rs and weak orientation (geqA) of
-- every rule in ss.
encode :: Signature -> TRS -> TRS -> Formula
encode sig rs ss = conj (side_condition sig : strict ++ weak)
  where
    strict = map (uncurry gtA) rs
    weak   = map (uncurry geqA) ss

-- Evaluate both components of a symbolic vector under a model.
eval_vector_exp :: Model -> VectorExp -> Vector
eval_vector_exp model (e1, e2) = (value e1, value e2)
  where
    value = eval_exp model

-- Evaluate a coefficient matrix under a model: the top-left entry is
-- evaluated as an expression, each formula entry becomes 1 if it holds
-- in the model and 0 otherwise.
eval_matrix_exp :: Model -> (Exp, Formula, Formula, Formula) -> Matrix
eval_matrix_exp model (e1, f2, f3, f4) =
  (eval_exp model e1, indicator f2, indicator f3, indicator f4)
  where
    indicator formula
      | eval_formula model formula = 1
      | otherwise                  = 0

-- Read the constant vector of symbol f from a model.
decode_f0 :: Model -> String -> Vector
decode_f0 model = eval_vector_exp model . var_f0

-- Read the coefficient matrix of the i-th argument of symbol f from a model.
decode_Fi :: Model -> String -> Int -> Matrix
decode_Fi model f = eval_matrix_exp model . var_fi f

-- Read the whole interpretation of a symbol f of arity n from a model:
-- its constant vector and the coefficient matrices for arguments 1..n.
decode_interpretation :: Model -> String -> Int -> Interpretation
decode_interpretation model f n = (f, constant_part, coefficients)
  where
    constant_part = decode_f0 model f
    coefficients  = map (decode_Fi model f) [1..n]

-- Read one interpretation per (symbol, arity) entry of the signature.
decode_algebra :: Model -> Signature -> Algebra
decode_algebra model = map (uncurry (decode_interpretation model))


-- Pass the encoding of rs (strict rules) and ss (weak rules) to sat,
-- forwarding smt unchanged. If sat yields a model, decode it into an
-- algebra over the signature of rs ++ ss; otherwise return Nothing.
terminating :: String -> TRS -> TRS -> IO (Maybe Algebra) 
terminating smt rs ss =
  fmap (fmap decode) (sat smt (encode sig rs ss))
  where
    sig          = TRS.signature_of (rs ++ ss)
    decode model = decode_algebra model sig