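-- SMT encodings for matrix interpretations on N^d with 0/1 coefficient
-- matrices: unrestricted ones (Matrix.Standard) or lower triangular ones
-- compared lexicographically (Matrix.Lex).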

module EncodingMatrix where

import Data.List
import Term
import Rule
import SMT
import qualified Matrix
import Order
import ReductionPair (ReductionPair (..))
import TRS

-- Symbolic counterparts of the concrete vectors and matrices of module
-- Matrix: entries are SMT expressions (Exp) or SMT formulas (Formula)
-- instead of numbers. Matrices are stored as lists of rows.

-- | Vector of SMT expressions.
type EVector = [Exp]

-- | Matrix of SMT expressions (list of rows).
type EMatrix = [[Exp]]

-- | Vector of SMT formulas. Used for 0/1 matrix entries: decode_Fkij reads
-- a satisfied formula as 1 and an unsatisfied one as 0.
type FVector = [Formula]

-- | Matrix of SMT formulas (list of rows).
type FMatrix = [[Formula]]

-- | Symbolic linear form: constant part and list of coefficients.
type ELinear = (Exp, [Exp])

-- | Symbolic form of one symbol: constant vector plus a list of coefficient
-- matrices (same shape as the Matrix.Form built by decode_form).
type SForm = (EVector, [EMatrix])

-- | Symbolic form paired with the name of its function symbol.
type SInterpretation = (String, SForm)

-- | Int paired with symbolic interpretations; presumably the Int is the
-- dimension d — confirm against users of this type.
type SAlgebra = (Int, [SInterpretation])

-- | Symbolic zero vector of dimension d (empty when d <= 0).
zero_vector :: Int -> EVector
zero_vector d = [ Val 0 | _ <- [1 .. d] ]

-- | Componentwise sum of two symbolic vectors: component i becomes
-- plus [a_i, b_i]. The result is as long as the shorter input.
add_vector :: EVector -> EVector -> EVector
add_vector v1 v2 = [ plus [a, b] | (a, b) <- zip v1 v2 ]

-- | Componentwise sum of the vectors vs, starting from the zero vector of
-- dimension d, so the empty sum is the zero vector.
-- Uses the strict left fold foldl' rather than foldl so the accumulator is
-- not built up as a chain of unevaluated thunks. The resulting expression
-- tree is unchanged.
sum_vector :: Int -> [EVector] -> EVector
sum_vector d vs = foldl' add_vector (zero_vector d) vs

-- | Symbolic d x d zero matrix (empty when d <= 0).
zero_matrix :: Int -> EMatrix
zero_matrix d = [ zero_vector d | _ <- [1 .. d] ]

-- | Symbolic d x d identity matrix: Val 1 on the diagonal, Val 0 elsewhere.
unit_matrix :: Int -> EMatrix
unit_matrix d = map row indices
  where
    indices = [0 .. d - 1]
    row i   = [ Val (if i == j then 1 else 0) | j <- indices ]

-- | Entrywise sum of two symbolic matrices, row by row via add_vector.
add_matrix :: EMatrix -> EMatrix -> EMatrix
add_matrix m1 m2 = [ add_vector r1 r2 | (r1, r2) <- zip m1 m2 ]

-- | Entrywise sum of the matrices ms, starting from the d x d zero matrix,
-- so the empty sum is the zero matrix.
-- Uses the strict left fold foldl' rather than foldl so the accumulator is
-- not built up as a chain of unevaluated thunks. The resulting expression
-- tree is unchanged.
sum_matrix :: Int -> [EMatrix] -> EMatrix
sum_matrix d ms = foldl' add_matrix (zero_matrix d) ms

-- | Symbolic inner product of a formula vector with an expression vector:
-- plus over times01 b e for each paired component (truncated to the
-- shorter vector). times01 comes from module SMT; presumably it yields e
-- when b holds and 0 otherwise — confirm there.
inner_product :: FVector -> EVector -> Exp
inner_product v1 v2 = plus [ times01 b e | (b, e) <- zip v1 v2 ]

-- | Matrix-vector product m * v: component i is the inner product of row i
-- of m with v.
mul_matrix_vector :: FMatrix -> EVector -> EVector
mul_matrix_vector m v = map (\row -> inner_product row v) m

-- | Inner product of u with every ROW of m, in row order.
-- Despite the name, this is m * u rather than the row-vector product u * m.
-- mul_matrix therefore passes the transpose (the columns) of its right
-- operand.
mul_vector_matrix :: FVector -> EMatrix -> EVector
mul_vector_matrix u m = map (inner_product u) m

-- | Matrix product m1 * m2 of a formula matrix with an expression matrix:
-- entry (i, j) is the inner product of row i of m1 with column j of m2.
mul_matrix :: FMatrix -> EMatrix -> EMatrix
mul_matrix m1 m2 = map (`mul_vector_matrix` columns) m1
  where
    columns = transpose m2

-- | Boolean SMT variable for entry (i, j) of F_k, the square coefficient
-- matrix of the k-th argument of symbol f. A satisfied variable decodes to
-- 1 and an unsatisfied one to 0 (see decode_Fkij).
--
-- With base order Matrix.Lex, every entry above the diagonal (i < j) is
-- fixed to bottom, so F_k is lower triangular; for d = 2:
--
--   [fk00    0]
--   [fk10 fk11]
--
-- rp_id is the number of the reduction pair this variable belongs to. It
-- keeps the variables of different reduction pairs apart when they are
-- combined lexicographically.
_Fkij :: Matrix.BaseOrder -> Int -> String -> Int -> Int -> Int -> Formula
_Fkij o rp_id f k i j
  | above_diagonal = bottom
  | otherwise      = FVar name
  where
    -- Only the Lex base order forces a lower triangular shape.
    above_diagonal = case o of
      Matrix.Lex -> i < j
      _          -> False
    name = concat
      [ "Fkij_", show rp_id, "_", f, "_", "_"
      , show k, "_", show i, "_", show j ]

-- | Symbolic d x d coefficient matrix F_k of the k-th argument of f, built
-- row by row from the entry variables of _Fkij.
_Fk :: Matrix.BaseOrder -> Int -> Int -> String -> Int -> FMatrix
_Fk o d rp_id f k = map row indices
  where
    indices = [0 .. d - 1]
    row i   = map (_Fkij o rp_id f k i) indices

-- | SMT expression variable for component i of the constant vector f_0 of
-- symbol f in reduction pair rp_id.
f0_i :: Int -> String -> Int -> Exp
f0_i rp_id f i = Var (concat [ "f0_i_", show rp_id, "_", f, "_", show i ])

-- | Symbolic constant vector f_0 of symbol f (d components, from f0_i).
f0 :: Int -> Int -> String -> EVector
f0 d rp_id f = map (f0_i rp_id f) [0 .. d - 1]

-- | coefficient o d rp_id t x computes C_A(t, x), the coefficient matrix of
-- alpha(x) in [alpha]_A(t):
--
--   C_A(x, x) = E   (identity matrix)
--   C_A(y, x) = O   (zero matrix), for y /= x
--   C_A(f(t_0, ..., t_{n-1}), x)
--     = F_0 C_A(t_0, x) + ... + F_{n-1} C_A(t_{n-1}, x)
--
-- where f_A(x_0, ..., x_{n-1}) = f_0 + F_0 x_0 + ... + F_{n-1} x_{n-1}.
coefficient :: Matrix.BaseOrder -> Int -> Int -> Term -> String -> EMatrix
coefficient _ d _ (V y) x
  | y /= x    = zero_matrix d
  | otherwise = unit_matrix d
coefficient o d rp_id (F f ts) x =
  sum_matrix d (zipWith contribution [0 ..] ts)
  where
    -- Contribution of the k-th argument: F_k * C_A(t_k, x).
    contribution k tk =
      mul_matrix (_Fk o d rp_id f k) (coefficient o d rp_id tk x)

-- | constant o d rp_id t computes c_A(t), the constant part of [alpha]_A(t):
--
--   c_A(x) = 0
--   c_A(f(t_0, ..., t_{n-1}))
--     = f_0 + F_0 c_A(t_0) + ... + F_{n-1} c_A(t_{n-1})
--
-- where f_A(x_0, ..., x_{n-1}) = f_0 + F_0 x_0 + ... + F_{n-1} x_{n-1}.
-- Each argument's constant part is multiplied by its coefficient matrix F_k.
constant :: Matrix.BaseOrder -> Int -> Int -> Term -> EVector
constant _ d _ (V _) = zero_vector d
constant o d rp_id (F f ts) =
  sum_vector d (f0 d rp_id f : zipWith contribution [0 ..] ts)
  where
    -- Constant part contributed by the k-th argument: F_k * c_A(t_k).
    contribution k tk =
      mul_matrix_vector (_Fk o d rp_id f k) (constant o d rp_id tk)

-- | Weak comparison of two linear forms, position by position: the
-- constants and every pair of coefficients must satisfy geq. Extra
-- coefficients on the longer side are ignored.
geq_linear :: ELinear -> ELinear -> Formula
geq_linear (a0, as) (b0, bs) = conj (zipWith geq (a0 : as) (b0 : bs))

-- | Strict comparison of two linear forms: the constants must satisfy gt,
-- and every pair of coefficients must still satisfy geq.
gt_linear :: ELinear -> ELinear -> Formula
gt_linear (a0, as) (b0, bs) = conj (gt a0 b0 : zipWith geq as bs)

-- | Lexicographic weak comparison of two lists of linear forms:
-- * two empty lists give top;
-- * otherwise the heads compare strictly (gt_linear), or they compare
--   weakly (geq_linear) and the tails compare lexicographically.
-- Lists of different length raise an error.
geq_linear_lex :: [ELinear] -> [ELinear] -> Formula
geq_linear_lex ls1 ls2 =
  case (ls1, ls2) of
    ([], []) -> top
    (l1 : rest1, l2 : rest2) ->
      disj [ gt_linear l1 l2
           , conj [ geq_linear l1 l2, geq_linear_lex rest1 rest2 ] ]
    _ -> error "geq_linear_lex"

-- | Lexicographic strict comparison of two lists of linear forms:
-- * two empty lists give bottom;
-- * otherwise the heads compare strictly (gt_linear), or they compare
--   weakly (geq_linear) and the tails compare strictly lexicographically.
-- Lists of different length raise an error.
gt_linear_lex :: [ELinear] -> [ELinear] -> Formula
gt_linear_lex ls1 ls2 =
  case (ls1, ls2) of
    ([], []) -> bottom
    (l1 : rest1, l2 : rest2) ->
      disj [ gt_linear l1 l2
           , conj [ geq_linear l1 l2, gt_linear_lex rest1 rest2 ] ]
    _ -> error "gt_linear_lex"

-- | Weak comparison of two interpreted vectors, given as one linear form
-- per component.
-- * Standard: every component pair must satisfy geq_linear.
-- * Lex: the lists are compared with geq_linear_lex.
geq_linears :: Matrix.BaseOrder -> [ELinear] -> [ELinear] -> Formula
geq_linears order ls1 ls2 =
  case order of
    Matrix.Standard -> conj (zipWith geq_linear ls1 ls2)
    Matrix.Lex      -> geq_linear_lex ls1 ls2

-- | Strict comparison of two interpreted vectors, given as one linear form
-- per component.
-- * Standard: the first components must satisfy gt_linear and the
--   remaining ones geq_linear; empty lists raise an error.
-- * Lex: the lists are compared with gt_linear_lex.
gt_linears :: Matrix.BaseOrder -> [ELinear] -> [ELinear] -> Formula
gt_linears Matrix.Lex ls1 ls2 = gt_linear_lex ls1 ls2
gt_linears Matrix.Standard ls1 ls2 =
  case (ls1, ls2) of
    (l1 : rest1, l2 : rest2) ->
      conj (gt_linear l1 l2 : zipWith geq_linear rest1 rest2)
    _ -> error "gt_linears"

-- | Interprets t as one linear form per vector component. The form for
-- component i is (c_i, r_i):
-- * c_i is component i of the constant part c_A(t);
-- * r_i concatenates row i of C_A(t, x) for every x in xs, in xs order.
-- geq_A and gt_A interpret both sides of a rule over the same xs, so the
-- coefficient lists of the two sides line up position by position.
interpret :: Matrix.BaseOrder -> Int -> Int -> [String] -> Term -> [ELinear]
interpret o d rp_id xs t = zip v rows
  where
    v   = constant o d rp_id t
    _As = [ coefficient o d rp_id t x | x <- xs ]
    -- Row i of every coefficient matrix, concatenated across _As. Zipping
    -- the rows replaces the partial, O(i) (!!) indexing. The replicate seed
    -- keeps one (empty) coefficient row per component when xs is empty
    -- (ground terms).
    rows = foldr (zipWith (++)) (replicate (length v) []) _As

-- | Weak comparison s >=_A t of the two sides of a rule. Both sides are
-- interpreted over the variables of the whole rule.
geq_A :: Matrix.BaseOrder -> Int -> Int -> Rule -> Formula
geq_A o d rp_id rule@(s, t) = geq_linears o (lin s) (lin t)
  where
    xs    = Rule.variables rule
    lin u = interpret o d rp_id xs u

-- | Strict comparison s >_A t of the two sides of a rule. Both sides are
-- interpreted over the variables of the whole rule.
gt_A :: Matrix.BaseOrder -> Int -> Int -> Rule -> Formula
gt_A o d rp_id rule@(s, t) = gt_linears o (lin s) (lin t)
  where
    xs    = Rule.variables rule
    lin u = interpret o d rp_id xs u

-- Definition:
-- A matrix M is in column echelon form if
-- for every i >= 0 and j > 0 with M_ij > 0
-- there exists i' < i with M_{i',j-1} > 0.
-- Example:
-- (0 0 0 0 0)
-- (1 0 0 0 0)
-- (2 0 0 0 0)
-- (3 4 0 0 0)
-- (5 6 7 0 0)

-- | Constraint that F_k of symbol f (Lex entry variables) is in column
-- echelon form. Whenever entry (i, j) with j > 0 is 1, some entry
-- (i', j - 1) with i' < i must also be 1.
echelon_form :: Int -> Int -> String -> Int -> Formula
echelon_form d rp_id f k = conj (map condition positions)
  where
    entry     = _Fkij Matrix.Lex rp_id f k
    positions = [ (i, j) | i <- [0 .. d - 1], j <- [1 .. d - 1] ]
    condition (i, j) =
      implies (entry i j) (disj [ entry i' (j - 1) | i' <- [0 .. i - 1] ])

-- | Extra shape constraints on the coefficient matrices.
-- * Standard: none (top).
-- * Lex: every F_k of every symbol (f, n) in the signature, for
--   k = 0 .. n-1, must be in column echelon form.
weak_monotonicity :: Matrix.BaseOrder -> Int -> Int -> Signature -> Formula
weak_monotonicity o d rp_id sig =
  case o of
    Matrix.Standard -> top
    Matrix.Lex ->
      conj (concat [ map (echelon_form d rp_id f) [0 .. n - 1]
                   | (f, n) <- sig ])

-- | Side constraints of the encoding: weak_monotonicity, plus every
-- component of each symbol's constant vector f_0 must be >= 0.
-- The (TRS, TRS) argument is not used.
side_condition :: Matrix.BaseOrder -> Int -> Int -> (TRS, TRS) -> Signature -> Formula
side_condition o d rp_id _ sig =
  conj (weak_monotonicity o d rp_id sig : map non_negative sig)
  where
    non_negative (f, _) = conj [ geq (f0_i rp_id f i) (Val 0) | i <- [0 .. d - 1] ]

-- | Monotonicity condition for argument k of f (the _monotone field of
-- reduction_pair).
-- * Standard: entry (0, 0) of F_k must be 1.
-- * Lex: every diagonal entry of F_k must be 1.
monotone :: Matrix.BaseOrder -> Int -> Int -> (String, Int) -> Int -> Formula
monotone o d rp_id (f, _) k =
  case o of
    Matrix.Standard -> entry 0 0
    Matrix.Lex      -> conj [ entry i i | i <- [0 .. d - 1] ]
  where
    entry = _Fkij o rp_id f k

-- | Formula requiring every entry of F_k to be 0, so the k-th argument of f
-- contributes nothing to f's interpretation.
invariant :: Matrix.BaseOrder -> Int -> Int -> (String, Int) -> Int -> Formula
invariant o d rp_id (f, _) k = conj (map neg entries)
  where
    indices = [0 .. d - 1]
    entries = [ _Fkij o rp_id f k i j | i <- indices, j <- indices ]

-- | Formula requiring every diagonal entry of F_k of symbol f to be 1.
weakly_simple :: Matrix.BaseOrder -> Int -> Int -> (String, Int) -> Int -> Formula
weakly_simple o d rp_id (f, _) k =
  conj (map (\i -> _Fkij o rp_id f k i i) [0 .. d - 1])

-- | Like weakly_simple (every diagonal entry of F_k is 1), and in addition
-- the constant vector f_0 must be positive.
-- * Standard: component 0 of f_0 must be > 0.
-- * Lex: some component of f_0 must be > 0.
strictly_simple :: Matrix.BaseOrder -> Int -> Int -> (String, Int) -> Int -> Formula
strictly_simple o d rp_id (f, _) k =
  conj (constant_positive o : [ _Fkij o rp_id f k i i | i <- [0 .. d - 1] ])
  where
    positive i = gt (f0_i rp_id f i) (Val 0)
    constant_positive Matrix.Standard = positive 0
    constant_positive Matrix.Lex      = disj (map positive [0 .. d - 1])

-- | Concrete constant vector f_0 of symbol f: each f0_i variable evaluated
-- in the model.
decode_f0 :: Int -> Int -> Model -> String -> Matrix.Vector
decode_f0 d rp_id model f = map (evalExp model . f0_i rp_id f) [0 .. d - 1]

-- | Concrete entry (i, j) of F_k of symbol f: 1 if the model satisfies the
-- corresponding _Fkij formula, 0 otherwise.
decode_Fkij :: Matrix.BaseOrder -> Int -> Model -> String -> Int -> Int -> Int -> Int
decode_Fkij o rp_id model f k i j =
  if evalFormula model (_Fkij o rp_id f k i j) then 1 else 0

-- | Concrete d x d coefficient matrix F_k of symbol f, row by row.
decode_Fk :: Matrix.BaseOrder -> Int -> Int -> Model -> String -> Int -> Matrix.Matrix
decode_Fk o d rp_id model f k =
  [ map (decode_Fkij o rp_id model f k i) indices | i <- indices ]
  where
    indices = [0 .. d - 1]

-- | Concrete interpretation of an n-ary symbol f: its constant vector and
-- the coefficient matrices F_0 .. F_{n-1}.
decode_form :: Matrix.BaseOrder -> Int -> Int -> Model -> String -> Int -> Matrix.Form
decode_form o d rp_id model f n = (constant_part, matrices)
  where
    constant_part = decode_f0 d rp_id model f
    matrices      = map (decode_Fk o d rp_id model f) [0 .. n - 1]

-- | Concrete interpretation of symbol f, paired with its name.
decode_interpretation :: Matrix.BaseOrder -> Int -> Int -> Model -> String -> Int -> Matrix.Interpretation
decode_interpretation o d rp_id model f n = (f, form)
  where
    form = decode_form o d rp_id model f n

-- | Concrete algebra: base order, dimension, and the decoded interpretation
-- of every (symbol, arity) pair of the signature, in signature order.
decode_algebra :: Matrix.BaseOrder -> Int -> Int -> Model -> Signature -> Matrix.Algebra
decode_algebra o d rp_id model sig = (o, d, map decode_symbol sig)
  where
    decode_symbol (f, n) = decode_interpretation o d rp_id model f n

-- | Decoded order: the concrete algebra wrapped in the Matrix constructor
-- of Order.
decode :: Matrix.BaseOrder -> Int -> Int -> Model -> Signature -> Order
decode o d rp_id model sig = Matrix algebra
  where
    algebra = decode_algebra o d rp_id model sig

-- | Matrix-interpretation reduction pair for base order o, dimension d and
-- reduction pair number rp_id. Each field is the matching encoding or
-- decoding function of this module, partially applied to o, d and rp_id.
reduction_pair :: Matrix.BaseOrder -> Int -> Int -> ReductionPair
reduction_pair o d rp_id =
  ReductionPair
    { _geq             = geq_A o d rp_id
    , _gt              = gt_A o d rp_id
    , _monotone        = monotone o d rp_id
    , _invariant       = invariant o d rp_id
    , _weakly_simple   = weakly_simple o d rp_id
    , _strictly_simple = strictly_simple o d rp_id
    , _side_condition  = side_condition o d rp_id
    , _decode          = decode o d rp_id
    }
