module WPO where

import Data.List

import Term
import Rule
import Order
import SMT
import ReductionPair (ReductionPair (..))
import TRS
import Precedence
import ArgumentFiltering

-- variables for (efficient) encoding

-- | Partial status (no collapsing, no permutation) of symbol @f@ in WPO
-- instance @k@: a Boolean variable that is true when argument @i@
-- (0-indexed) of @f@ is kept. 'decodeAF' reads these back from the model.
varStatus :: Int -> String -> Int -> SMT.Formula
varStatus k f i =
  SMT.FVar $ concat ["status_", show k, "_", f, "_", show i, "_wpo"]

-- | Boolean variable meaning "@f@ keeps no argument". In 'side_condition' it
-- implies that every status variable of @f@ is false.
varEmptyStatus :: Int -> String -> SMT.Formula
varEmptyStatus k f = SMT.FVar $ concat ["empty_status_", show k, "_", f, "_wpo"]

-- | SMT variable holding the precedence value of symbol @f@ in instance @k@.
-- It is compared with 'SMT.geq' / 'SMT.gt' / 'SMT.eq' and decoded by
-- 'decodePrecedence'.
varPrec :: Int -> String -> SMT.Exp
varPrec k f = SMT.Var $ concat ["precedence_", show k, "_", f, "_wpo"]

-- | Boolean variable standing for @s >=_WPO t@. In 'side_condition' it
-- implies the structural encoding 'WPO.geq'.
varGeq :: Int -> Rule -> SMT.Formula
varGeq k (s, t) = SMT.FVar $ concat ["geqWPO_", show k, "_", show s, "_", show t]

-- | Boolean variable standing for @s >_WPO t@. In 'side_condition' it
-- implies the structural encoding 'WPO.gt'.
varGt :: Int -> Rule -> SMT.Formula
varGt k (s, t) = SMT.FVar $ concat ["gtWPO_", show k, "_", show s, "_", show t]

-- | Boolean variable standing for @s >=_A t@ in the underlying algebra. In
-- 'side_condition' it implies the algebra's own @_geq@.
varGeqA :: Int -> Rule -> SMT.Formula
varGeqA k (s, t) =
  SMT.FVar $ concat ["geqA_", show k, "_", show s, "_", show t, "_wpo"]

-- | Boolean variable standing for @s >_A t@ in the underlying algebra. In
-- 'side_condition' it implies the algebra's own @_gt@.
varGtA :: Int -> Rule -> SMT.Formula
varGtA k (s, t) =
  SMT.FVar $ concat ["gtA_", show k, "_", show s, "_", show t, "_wpo"]

-- | Boolean variable: the status-filtered argument lists of @s@ (from
-- position @k@) and @t@ (from position @l@) compare lexicographically with
-- @>=@. It is defined by @lexGeq'@ through 'side_condition'.
varLexGeq :: Int -> Rule -> Int -> Int -> SMT.Formula
varLexGeq rp_id (s, t) k l =
  SMT.FVar $ concat
    [ "lex_geq_", show rp_id, "_", show s, "_", show t
    , "_", show k, "_", show l, "_wpo"
    ]

-- | Boolean variable: the status-filtered argument lists of @s@ (from
-- position @k@) and @t@ (from position @l@) compare lexicographically with
-- @>@. It is defined by @lexGt'@ through 'side_condition'.
varLexGt :: Int -> Rule -> Int -> Int -> SMT.Formula
varLexGt rp_id (s, t) k l =
  SMT.FVar $ concat
    [ "lex_gt_", show rp_id, "_", show s, "_", show t
    , "_", show k, "_", show l, "_wpo"
    ]

-- | Boolean variable enabling the case @x >= f(...)@ for a variable @x@.
-- 'side_condition' requires @f@ to keep no argument and to be least in the
-- precedence.
varLeftVar :: Int -> String -> SMT.Formula
varLeftVar k f = SMT.FVar $ concat ["left_var_", show k, "_", f, "_wpo"]

-- | Boolean variable enabling the case @f(...) >= x@ for a variable @x@.
-- 'side_condition' ties it to strict simplicity and to @f@ being greatest
-- in the precedence.
varRightVar :: Int -> String -> SMT.Formula
varRightVar k f = SMT.FVar $ concat ["right_var_", show k, "_", f, "_wpo"]

-- | Boolean variable: instance @k@ is strictly simple with respect to the
-- status, i.e. every argument position kept by the status is strictly simple
-- in the algebra (enforced in 'side_condition').
varStrictlySimple :: Int -> SMT.Formula
varStrictlySimple k = SMT.FVar $ concat ["strictly_simple_", show k, "_wpo"]

-- | Boolean variable: whether @(>=_WPO, >_WPO) = (>=_A, >_A)@. When true,
-- 'side_condition' empties every status and equalises all precedences.
varAlgebraSimulation :: Int -> SMT.Formula
varAlgebraSimulation k = SMT.FVar $ concat ["algebra_simulation_", show k, "_wpo"]

-- encoding of WPO

-- | Precedence constraint @f >= g@ in instance @k@.
geqPrec :: Int -> String -> String -> Formula
geqPrec k f g = varPrec k f `SMT.geq` varPrec k g

-- | Precedence constraint @f > g@ in instance @k@.
gtPrec :: Int -> String -> String -> Formula
gtPrec k f g = varPrec k f `SMT.gt` varPrec k g

-- | Precedence constraint @f == g@ (equivalent symbols) in instance @k@.
eqPrec :: Int -> String -> String -> Formula
eqPrec k f g = varPrec k f `SMT.eq` varPrec k g

-- | Subterm case for @s = f(s_0, ..., s_{n-1})@: some argument @s_i@ kept by
-- the status of @f@ satisfies @s_i >=_WPO t@. This is 'SMT.bottom' when @s@
-- is not a function application.
sub :: Int -> Rule -> Formula
sub k (F f ss, t) = SMT.disj (zipWith viaArg [0 ..] ss)
  where
    viaArg i si = SMT.conj [varStatus k f i, varGeq k (si, t)]
sub _ _ = SMT.bottom

-- | Argument-dominance condition for @s = f(...)@ and @t = g(t_0, ...)@:
-- @s >_WPO t_j@ for every argument @t_j@ kept by the status of @g@. This is
-- 'SMT.bottom' unless both terms are function applications.
args :: Int -> Rule -> Formula
args k (s@(F _ _), F g ts) = SMT.conj (zipWith dominated [0 ..] ts)
  where
    dominated j tj = SMT.implies (varStatus k g j) (varGt k (s, tj))
args _ _ = SMT.bottom

-- | Precedence case: the root of @s@ is strictly above the root of @t@. This
-- is 'SMT.bottom' unless both terms are function applications.
prec :: Int -> Rule -> Formula
prec k rl = case rl of
  (F f _, F g _) -> WPO.gtPrec k f g
  _ -> SMT.bottom

-- | Lexicographic case for @>=@: the roots have equal precedence and the
-- status-filtered argument lists compare with @>=@ from position 0. This is
-- 'SMT.bottom' unless both terms are function applications.
lexGeq :: Int -> Rule -> Formula
lexGeq rp_id rl = case rl of
  (F f _, F g _) -> SMT.conj [WPO.eqPrec rp_id f g, varLexGeq rp_id rl 0 0]
  _ -> SMT.bottom

-- | Definition of @varLexGeq rp_id rl k l@: the arguments of @s@ from
-- position @k@ and of @t@ from position @l@, restricted to positions kept by
-- the status, are lexicographically @>=@. Positions dropped by the status are
-- skipped one at a time.
lexGeq' :: Int -> Rule -> Int -> Int -> Formula
lexGeq' rp_id rl@(F f ss, F g ts) k l
  | k == length ss || l == length ts =
      -- One list is exhausted: every remaining argument of t must be dropped
      -- (an empty conjunction when t's list is the exhausted one).
      conj [ neg (varStatus rp_id g j) | j <- [l .. length ts - 1] ]
  | otherwise = SMT.disj [compareBoth, skipLeft, skipRight]
  where
    keptF = varStatus rp_id f k
    keptG = varStatus rp_id g l
    pair = (ss !! k, ts !! l)
    -- both heads kept: strictly greater, or weakly greater and the rest >=
    compareBoth =
      SMT.conj
        [ keptF
        , keptG
        , SMT.disj
            [ varGt rp_id pair
            , SMT.conj [varGeq rp_id pair, varLexGeq rp_id rl (k + 1) (l + 1)]
            ]
        ]
    skipLeft = SMT.conj [neg keptF, varLexGeq rp_id rl (k + 1) l]
    skipRight = SMT.conj [neg keptG, varLexGeq rp_id rl k (l + 1)]
lexGeq' _ _ _ _ = error "lexGeq': undefined case"

-- | Lexicographic case for @>@: the roots have equal precedence and the
-- status-filtered argument lists compare with @>@ from position 0. This is
-- 'SMT.bottom' unless both terms are function applications.
lexGt :: Int -> Rule -> Formula
lexGt rp_id rl = case rl of
  (F f _, F g _) -> SMT.conj [WPO.eqPrec rp_id f g, varLexGt rp_id rl 0 0]
  _ -> SMT.bottom

-- | Definition of @varLexGt rp_id rl k l@: the arguments of @s@ from
-- position @k@ and of @t@ from position @l@, restricted to positions kept by
-- the status, are lexicographically @>@. Positions dropped by the status are
-- skipped one at a time.
lexGt' :: Int -> Rule -> Int -> Int -> Formula
lexGt' rp_id rl@(F f ss, F g ts) k l
  | k == length ss || l == length ts =
      -- One list is exhausted: s wins only if it still keeps some argument
      -- (an empty disjunction when s's list is the exhausted one).
      disj [ varStatus rp_id f i | i <- [k .. length ss - 1] ]
  | otherwise = SMT.disj [compareBoth, skipLeft, skipRight]
  where
    keptF = varStatus rp_id f k
    keptG = varStatus rp_id g l
    pair = (ss !! k, ts !! l)
    -- both heads kept: strictly greater, or weakly greater and the rest >
    compareBoth =
      SMT.conj
        [ keptF
        , keptG
        , SMT.disj
            [ varGt rp_id pair
            , SMT.conj [varGeq rp_id pair, varLexGt rp_id rl (k + 1) (l + 1)]
            ]
        ]
    skipLeft = SMT.conj [neg keptF, varLexGt rp_id rl (k + 1) l]
    skipRight = SMT.conj [neg keptG, varLexGt rp_id rl k (l + 1)]
lexGt' _ _ _ _ = error "lexGt': undefined case"

-- | Variable-on-the-left case of @>=_WPO@. @x >= x@ holds trivially, and
-- @x >= g(...)@ is delegated to 'varLeftVar' for @g@. Every other shape
-- (including two distinct variables) yields 'bottom'.
leftVar :: Int -> Rule -> Formula
leftVar rp_id rl = case rl of
  (V x, V y) | x == y -> top
  (V _, F g _) -> varLeftVar rp_id g
  _ -> bottom

-- | Variable-on-the-right case of @>=_WPO@: @f(...) >= x@ is delegated to
-- 'varRightVar' for @f@. Every other shape yields 'bottom'.
rightVar :: Int -> Rule -> Formula
rightVar rp_id rl = case rl of
  (F f _, V _) -> varRightVar rp_id f
  _ -> bottom

-- | Structural encoding of @s >=_WPO t@. Either @s >_A t@, or @s >=_A t@
-- together with one of:
--
-- * the subterm case ('sub'),
-- * argument dominance ('args') plus precedence ('prec') or the
--   lexicographic @>=@ ('lexGeq'),
-- * the variable cases 'leftVar' / 'rightVar'.
geq :: Int -> Rule -> Formula
geq k rl = SMT.disj [varGtA k rl, SMT.conj [varGeqA k rl, SMT.disj structural]]
  where
    structural =
      [ sub k rl
      , SMT.conj [args k rl, SMT.disj [prec k rl, lexGeq k rl]]
      , leftVar k rl
      , rightVar k rl
      ]

-- | Structural encoding of @s >_WPO t@. Either @s >_A t@, or @s >=_A t@
-- together with the subterm case ('sub'), or argument dominance ('args')
-- plus precedence ('prec') or the lexicographic @>@ ('lexGt').
gt :: Int -> Rule -> Formula
gt k rl = SMT.disj [varGtA k rl, SMT.conj [varGeqA k rl, SMT.disj structural]]
  where
    structural =
      [ sub k rl
      , SMT.conj [args k rl, SMT.disj [prec k rl, lexGt k rl]]
      ]

-- | Monotonicity of argument @i@ of @fn@. Either the status keeps it, or
-- algebra simulation is active and the algebra @rp@ is monotone in it.
monotone :: ReductionPair -> Int -> (String, Int) -> Int -> Formula
monotone rp k fn@(sym, _) i = disj [kept, conj [varAlgebraSimulation k, inAlgebra]]
  where
    kept = varStatus k sym i
    inAlgebra = _monotone rp fn i

-- | Weak simplicity of argument @i@ of @fn@. Either the status keeps it, or
-- algebra simulation is active and the algebra @rp@ is weakly simple in it.
weakly_simple :: ReductionPair -> Int -> (String, Int) -> Int -> Formula
weakly_simple rp k fn@(sym, _) i = disj [kept, conj [varAlgebraSimulation k, inAlgebra]]
  where
    kept = varStatus k sym i
    inAlgebra = _weakly_simple rp fn i

-- | Strict simplicity of argument @i@ of @fn@. Either the status keeps it,
-- or algebra simulation is active and the algebra @rp@ is strictly simple
-- in it.
strictly_simple :: ReductionPair -> Int -> (String, Int) -> Int -> Formula
strictly_simple rp k fn@(sym, _) i = disj [kept, conj [varAlgebraSimulation k, inAlgebra]]
  where
    kept = varStatus k sym i
    inAlgebra = _strictly_simple rp fn i

-- | Invariance in argument @i@ of @fn@. The algebra @rp@ must be invariant
-- in it, and the status must drop it.
invariant :: ReductionPair -> Int -> (String, Int) -> Int -> Formula
invariant rp k fn@(sym, _) i = conj [_invariant rp fn i, neg kept]
  where
    kept = varStatus k sym i

-- | Side condition of WPO instance @rp_id@ built on the algebra @rp@, for the
-- rule sets @(ps, rs)@ over signature @sig@.
--
-- Besides the algebra's own side condition, each auxiliary variable of this
-- module is constrained here by an implication "variable ==> definition".
-- Comparison variables are only defined for the pairs in @possible@: a
-- subterm of a left-hand side against a subterm of the same rule's
-- right-hand side.
side_condition :: Int -> ReductionPair -> (TRS, TRS) -> Signature -> Formula
side_condition rp_id rp (ps, rs) sig = conj $
  [ _side_condition rp (ps, rs) sig,
    implies (varStrictlySimple rp_id) strictlySimple',
    -- algebra simulation: every status is empty and all precedences coincide
    implies (varAlgebraSimulation rp_id) (conj ([ varEmptyStatus rp_id f | (f, _) <- sig ] ++ samePrecedence)) ] ++
  -- a kept argument position must be weakly simple in the algebra
  [ implies (varStatus rp_id f i) (_weakly_simple rp fn i) | fn@(f, n) <- sig, i <- [0..(n-1)] ] ++
  -- WPO comparisons
  [ implies (varGeq rp_id rl) (WPO.geq rp_id rl) | rl <- possible ] ++
  [ implies (varGt rp_id rl) (WPO.gt rp_id rl) | rl <- possible ] ++
  -- algebra comparisons
  [ implies (varGtA rp_id rl) (_gt rp rl) | rl <- possible ] ++
  [ implies (varGeqA rp_id rl) (_geq rp rl) | rl <- possible ] ++
  -- lexicographic comparisons for every pair of start positions, including
  -- the exhausted positions (length ss, length ts)
  [ implies (varLexGeq rp_id rl i j) (lexGeq' rp_id rl i j) | rl@(F _ ss, F _ ts) <- possible, i <- [0..(length ss)], j <- [0..(length ts)] ] ++
  [ implies (varLexGt rp_id rl i j) (lexGt' rp_id rl i j) | rl@(F _ ss, F _ ts) <- possible, i <- [0..(length ss)], j <- [0..(length ts)] ] ++
  [ implies (varEmptyStatus rp_id f) (conj [ neg (varStatus rp_id f i) | i <- [0 .. n-1] ]) | (f, n) <- sig ] ++
  [ implies (varLeftVar rp_id f) (leftVar' f) | (f, _) <- sig ] ++
  [ implies (varRightVar rp_id f) (rightVar' f) | (f, _) <- sig ]
  where
    possible = nub [ (s', t') | (lhs, rhs) <- ps ++ rs, s' <- subterms lhs, t' <- subterms rhs ]
    symbols = [ f | (f, _) <- sig ]
    -- Equating each symbol with its successor is equivalent (by transitivity)
    -- to equating every pair, but needs |sig| - 1 constraints instead of
    -- |sig|^2.
    samePrecedence = zipWith (eqPrec rp_id) symbols (drop 1 symbols)
    -- g is least in the precedence and keeps no argument
    leftVar' g =
      conj (varEmptyStatus rp_id g : [ WPO.geqPrec rp_id f g | (f, _) <- sig ])
    -- the instance is strictly simple, and f is greatest: every symbol g is
    -- either strictly below f, or equivalent to f and keeps no argument.
    -- NOTE(review): g also ranges over f itself; assuming gtPrec f f is
    -- unsatisfiable, this forces f's own status to be empty. Confirm this is
    -- intended.
    rightVar' f =
      conj (varStrictlySimple rp_id : [
        disj [ WPO.gtPrec rp_id f g, conj [ WPO.eqPrec rp_id f g, varEmptyStatus rp_id g ] ] | (g, _) <- sig
      ])
    -- every argument position kept by the status is strictly simple in the algebra
    strictlySimple' =
      conj [ implies (varStatus rp_id f i) (_strictly_simple rp fn i) | fn@(f, n) <- sig, i <- [0 .. n-1] ]

-- | Reads back the precedence value of every symbol of @sig@ from the model.
decodePrecedence :: Int -> Signature -> Model -> Precedence
decodePrecedence k sig model = map weigh sig
  where
    weigh (f, _) = (f, evalExp model (varPrec k f))

-- | Reads back the partial status of an @n@-ary symbol @f@ as a 'Filtering',
-- with one Boolean per argument position (0-indexed).
decodeAF :: Int -> (String, Int) -> Model -> AF
decodeAF k (f, n) m =
  Filtering (map (evalFormula m . varStatus k f) [0 .. n - 1])

-- | Reads back the partial status of every symbol of @sig@, keyed by symbol
-- name.
decodeArgumentFiltering :: Int -> Signature -> Model -> ArgumentFiltering
decodeArgumentFiltering k sig model =
  map (\sym@(f, _) -> (f, decodeAF k sym model)) sig

-- | Builds the 'WPO' order found by the solver from three parts: the decoded
-- algebra of @rp@, the precedence, and the partial status of instance @k@.
decode :: ReductionPair -> Int -> Model -> Signature -> Order
decode rp k m sig = WPO algebra precedence filtering
  where
    algebra = _decode rp m sig
    precedence = decodePrecedence k sig m
    filtering = decodeArgumentFiltering k sig m

-- | Wraps the reduction pair @rp@ (the underlying algebra) into a WPO
-- reduction pair. @k@ is the instance id embedded in every SMT variable name
-- of this module.
wpo :: Int -> ReductionPair -> ReductionPair
wpo k rp =
  ReductionPair
    { _side_condition = side_condition k rp
    , _geq = varGeq k
    , _gt = varGt k
    , _monotone = monotone rp k
    , _invariant = invariant rp k
    , _decode = decode rp k
    , _weakly_simple = weakly_simple rp k
    , _strictly_simple = strictly_simple rp k
    }
