module LPO where

import Term
import Precedence
import ArgumentFiltering
import ReductionPair
import SMT
import Rule
import Order
import TRS

import Data.List

-- | Parameters of a concrete LPO: a precedence on function symbols and an
-- argument filtering that 'gt' applies to both terms before comparing them.
type Parameter = (Precedence, ArgumentFiltering)

-- | @gt' p s t@ decides whether @s@ is strictly greater than @t@ in the
-- lexicographic path order induced by precedence @p@. The terms are
-- compared exactly as given; no argument filtering is applied here
-- (see 'gt' for the filtered version).
--
-- For @s = f(s1,...,sn)@:
--
--   * a variable is never greater than anything;
--   * @s > x@ (a variable) iff some argument @si@ is greater than or equal
--     to @x@;
--   * @s > t = g(t1,...,tm)@ iff some @si >= t@, or @s > tj@ for every @j@
--     and additionally either @gtPrec p s t@ holds, or @geqPrec p s t@
--     holds and the argument lists decrease lexicographically.
gt' :: Precedence -> Term -> Term -> Bool
gt' _ (V _) _ = False
gt' p (F _ ss) t@(V _) = any (\s -> gt' p s t || s == t) ss
gt' p s@(F _ ss) t@(F _ ts) =
  any (\s' -> gt' p s' t || s' == t) ss ||
  (all (\t' -> gt' p s t') ts &&
    -- Explicit parentheses: '&&' binds tighter than '||', so without them
    -- the lexicographic case would skip the "s > every tj" check above and
    -- (assuming gtPrec is irreflexive and geqPrec reflexive on equal heads)
    -- accept e.g. f(s(x), y) > f(x, f(s(x), y)), whose right-hand side
    -- contains the left-hand side.
    (gtPrec p s t || (geqPrec p s t && lex' ss ts)))
  where
    -- Lexicographic extension of gt' to argument lists: at the first
    -- position where the lists differ the left element must be strictly
    -- greater; a non-empty list beats an exhausted one.
    lex' [] _ = False
    lex' (_ : _) [] = True
    lex' (s' : ss') (t' : ts') = gt' p s' t' || (s' == t' && lex' ss' ts')

-- | Strict LPO comparison under the given parameters: both terms are
-- argument-filtered first and then compared with gt'.
gt :: Parameter -> Term -> Term -> Bool
gt (prec, filtering) s t = gt' prec filteredS filteredT
  where
    filteredS = apply filtering s
    filteredT = apply filtering t

-- | Weak LPO comparison under the given parameters: @s@ and @t@ are
-- identical after argument filtering, or @s@ is strictly greater ('LPO.gt').
-- Equality is checked first, so the strict comparison only runs when needed.
geq :: Parameter -> Term -> Term -> Bool
geq param@(_, af) s t
  | filteredS == filteredT = True
  | otherwise = LPO.gt param s t
  where
    filteredS = apply af s
    filteredT = apply af t

-- encoding

-- Assumption: the precedence is total and strict, i.e. distinct symbols
-- always get distinct precedence values.

-- | SMT expression variable holding the precedence value of symbol @f@ in
-- LPO instance @k@. side_condition makes these values pairwise distinct and
-- encodeGt_copy compares them with SMT.gt.
varPrec :: Int -> String -> SMT.Exp
varPrec k f = SMT.Var name
  where
    name = concat ["lpo_p_", show k, "_", f]

-- | Boolean SMT variable: does the argument filtering of @f@ (instance @k@)
-- keep its @i@-th argument? Positions are 0-indexed.
varRespect :: Int -> String -> Int -> SMT.Formula
varRespect k f i =
  SMT.FVar (concat ["lpo_respect_", show k, "_", show i, "_", f])

-- | Boolean SMT variable: is symbol @f@ collapsing in instance @k@, i.e.
-- filtered down to a single kept argument that replaces the whole term?
varCollapsing :: Int -> String -> SMT.Formula
varCollapsing k f = SMT.FVar name
  where
    name = concat ["lpo_collapsing_", show k, "_", f]

-- | Boolean SMT variable standing for "s and t are equal after argument
-- filtering" in instance @k@; side_condition ties it to encodeEq.
--
-- FIXME: hacky. The name is built from 'show' of both terms joined by "_",
-- so two different pairs may collide if the shown terms contain "_".
varEq :: Int -> Rule -> SMT.Formula
varEq k (lhs, rhs) =
  SMT.FVar (concat ["lpo_eq_", show k, "_", show lhs, "_", show rhs])

-- | Boolean SMT variable standing for "s > t after argument filtering" in
-- instance @k@; side_condition ties it to encodeGt. Unlike the other
-- variables, its prefix is "gtLPO_" rather than "lpo_".
--
-- NOTE(review): the name is built from 'show' of both terms joined by "_",
-- so distinct pairs could in principle share a name.
varGt :: Int -> Rule -> SMT.Formula
varGt k (lhs, rhs) = SMT.FVar name
  where
    name = concat ["gtLPO_", show k, "_", show lhs, "_", show rhs]

-- | Encode "s and t are equal after argument filtering" for instance @k@.
--
-- Two variables are equal iff they have the same name. Otherwise equality
-- can come from collapsing either side onto a kept argument that equals the
-- other side, or (same, non-collapsing head symbol) from all kept argument
-- pairs being equal.
encodeEq :: Int -> Rule -> SMT.Formula
encodeEq k pair =
  case pair of
    (V x, V y) -> if x == y then SMT.top else SMT.bottom
    (F f ss, t@(V _)) -> encodeEq_collapsing k f ss t
    (s@(V _), F g ts) -> encodeEq_collapsing k g ts s
    (s@(F f ss), t@(F g ts)) ->
      disj
        [ encodeEq_collapsing k f ss t
        , encodeEq_collapsing k g ts s
        , argumentwise f ss g ts
        ]
  where
    -- Same non-collapsing head symbol and pairwise-equal kept arguments.
    argumentwise f ss g ts
      | f /= g = bottom
      | otherwise = conj (neg (varCollapsing k f) : keptArgsEqual)
      where
        keptArgsEqual =
          [ implies (varRespect k f i) (varEq k (si, ti))
          | (i, si, ti) <- zip3 [0 ..] ss ts ]

-- | Equality of @f(ss)@ with @t@ through collapsing: @f@ is collapsing and
-- some kept argument of @f(ss)@ equals @t@.
encodeEq_collapsing :: Int -> String -> [Term] -> Term -> SMT.Formula
encodeEq_collapsing k f ss t =
  conj [varCollapsing k f, disj (zipWith keptAndEqual [0 ..] ss)]
  where
    keptAndEqual i si = conj [varRespect k f i, varEq k (si, t)]

-- | @s >= t@ after argument filtering: equal or strictly greater.
encodeGeq :: Int -> Rule -> Formula
encodeGeq k rl = disj [equal, greater]
  where
    equal = varEq k rl
    greater = varGt k rl

-- | Encode @s > t@ (after argument filtering) for instance @k@ as the
-- disjunction of the applicable LPO cases. A variable on the left is never
-- greater. Against a variable, only the collapsing and subterm cases apply.
encodeGt :: Int -> Rule -> Formula
encodeGt k rule =
  case rule of
    (V _, _) -> bottom
    (F f ss, t@(V _)) ->
      disj [encodeGt_collapsing_l k f ss t, encodeGt_sub k f ss t]
    (s@(F f ss), t@(F g ts)) ->
      disj
        [ encodeGt_collapsing_l k f ss t
        , encodeGt_collapsing_r k s g ts
        , encodeGt_sub k f ss t
        , encodeGt_copy k f ss g ts
        , encodeGt_lex k f ss g ts
        ]

-- | Collapsing-on-the-left case for @f(ss) > t@: @f@ is collapsing and some
-- kept argument of @f(ss)@ is strictly greater than @t@.
encodeGt_collapsing_l :: Int -> String -> [Term] -> Term -> Formula
encodeGt_collapsing_l k f ss t =
  conj [varCollapsing k f, disj (zipWith keptAndAbove [0 ..] ss)]
  where
    keptAndAbove i si = conj [varRespect k f i, varGt k (si, t)]

-- | Collapsing-on-the-right case for @s > g(ts)@: @g@ is collapsing and @s@
-- is strictly greater than some kept argument of @g(ts)@.
encodeGt_collapsing_r :: Int -> Term -> String -> [Term] -> Formula
encodeGt_collapsing_r k s g ts =
  conj [varCollapsing k g, disj (zipWith keptAndBelow [0 ..] ts)]
  where
    keptAndBelow j tj = conj [varRespect k g j, varGt k (s, tj)]

-- | Subterm case for @f(ss) > t@: @f@ is not collapsing and some kept
-- argument of @f(ss)@ is greater than or equal to @t@.
encodeGt_sub :: Int -> String -> [Term] -> Term -> Formula
encodeGt_sub k f ss t =
  conj [neg (varCollapsing k f), disj (zipWith keptAndGeq [0 ..] ss)]
  where
    keptAndGeq i si = conj [varRespect k f i, encodeGeq k (si, t)]

-- | Precedence case for @f(ss) > g(ts)@: neither symbol is collapsing, @f@
-- has a strictly larger precedence value than @g@, and @f(ss)@ is greater
-- than every kept argument of @g(ts)@.
encodeGt_copy :: Int -> String -> [Term] -> String -> [Term] -> Formula
encodeGt_copy k f ss g ts =
  conj
    [ neg (varCollapsing k f)
    , neg (varCollapsing k g)
    , SMT.gt (varPrec k f) (varPrec k g)
    , conj (zipWith dominated [0 ..] ts)
    ]
  where
    lhs = F f ss
    dominated j tj = implies (varRespect k g j) (varGt k (lhs, tj))

-- | Lexicographic case for @f(ss) > g(ts)@. It applies only when @f == g@
-- and @f@ is not collapsing. Then, for some kept position @i@:
--
--   * every kept earlier position is equal,
--   * position @i@ is strictly decreasing, and
--   * @s = f(ss)@ is greater than every kept later argument of @ts@.
encodeGt_lex :: Int -> String -> [Term] -> String -> [Term] -> Formula
encodeGt_lex k f ss g ts
  | f == g =
      conj [ neg (varCollapsing k f)
           , disj (map decreasesAt [0 .. length ss - 1]) ]
  | otherwise = bottom
  where
    s = F f ss
    -- Position i is where the kept arguments first differ and decrease.
    decreasesAt i =
      let equalBefore =
            [ implies (varRespect k f j) (varEq k (sj, tj))
            | (j, sj, tj) <- zip3 [0 ..] (take i ss) (take i ts) ]
          dominatedAfter =
            [ implies (varRespect k f j) (varGt k (s, tj))
            | (j, tj) <- zip [i + 1 ..] (drop (i + 1) ts) ]
      in conj [ conj equalBefore
              , varRespect k f i
              , varGt k (ss !! i, ts !! i)
              , conj dominatedAfter ]

-- | Side constraints that give the LPO variables of instance @k@ their
-- meaning. The conjunction contains:
--
--   * pairwise-distinct precedence values (a strict precedence);
--   * for every collapsing symbol, exactly one kept argument;
--   * one-way definitions of 'varEq' for every candidate pair in both
--     orientations, and of 'varGt' for every candidate pair.
--
-- Candidate pairs are (subterm of a lhs, subterm of the same rule's rhs)
-- over all rules of @ps@ and @rs@.
side_condition :: Int -> (TRS, TRS) -> Signature -> Formula
side_condition k (ps, rs) fs =
  conj (strictPrecedence : collapsingKeepsOne ++ eqDefinitions ++ gtDefinitions)
  where
    strictPrecedence = Distinct [ varPrec k f | (f, _) <- fs ]
    collapsingKeepsOne =
      [ implies (varCollapsing k f) (exactlyOne [ varRespect k f i | i <- [0 .. n - 1] ])
      | (f, n) <- fs ]
    -- encodeEq may compare a pair in the reversed orientation, so both
    -- orientations get a definition.
    eqDefinitions =
      [ implies (varEq k rl) (encodeEq k rl)
      | (l, r) <- candidates, rl <- [(l, r), (r, l)] ]
    gtDefinitions = [ implies (varGt k rl) (encodeGt k rl) | rl <- candidates ]
    candidates =
      nub [ (s', t') | (lhs, rhs) <- ps ++ rs, s' <- subterms lhs, t' <- subterms rhs ]

-- | Constraint used as the reduction pair's '_monotone': argument @pos@ of
-- the symbol must be kept by the argument filtering.
monotone :: Int -> (String, Int) -> Int -> Formula
monotone k (sym, _arity) pos = varRespect k sym pos

-- | Constraint used as the reduction pair's '_weakly_simple': argument
-- @pos@ of the symbol must be kept by the argument filtering.
weakly_simple :: Int -> (String, Int) -> Int -> Formula
weakly_simple k (sym, _arity) pos = varRespect k sym pos

-- | Constraint used as the reduction pair's '_strictly_simple': argument
-- @pos@ is kept and the symbol is not collapsing.
strictly_simple :: Int -> (String, Int) -> Int -> Formula
strictly_simple k (sym, _arity) pos = conj [kept, notCollapsing]
  where
    kept = varRespect k sym pos
    notCollapsing = neg (varCollapsing k sym)

-- | Constraint used as the reduction pair's '_invariant': argument @pos@
-- of the symbol is dropped by the argument filtering.
invariant :: Int -> (String, Int) -> Int -> Formula
invariant k (sym, _arity) pos = neg kept
  where
    kept = varRespect k sym pos

-- decoding

-- | Read back the precedence of instance @k@: each signature symbol is
-- paired with the model's value of its precedence variable.
decodePrecedence :: Int -> Signature -> Model -> Precedence
decodePrecedence k sig model = [ (f, valueOf f) | (f, _) <- sig ]
  where
    valueOf f = evalExp model (varPrec k f)

-- | Read back the argument filtering of symbol @f@ (arity @n@) in instance
-- @k@ from an SMT model.
--
-- A collapsing symbol yields @Collapsing i@ for its single kept position
-- @i@. Otherwise the result is @Filtering@ with one keep/drop flag per
-- position. Calls 'error' if a collapsing symbol does not keep exactly one
-- argument.
decodeAF :: Int -> (String, Int) -> Model -> AF
decodeAF k (f, n) m
  | collapsing =
      case [ i | (i, True) <- zip [0 ..] kept ] of
        [i] -> Collapsing i
        _ -> error ("decodeAF: exactly one argument must be respected for collapsing\nmodel:" ++ show m)
  | otherwise = Filtering kept
  where
    collapsing = evalFormula m (varCollapsing k f)
    kept = [ evalFormula m (varRespect k f i) | i <- [0 .. n - 1] ]

-- | Read back the full argument filtering of instance @k@: one entry per
-- signature symbol, decoded by decodeAF.
decodeArgumentFiltering :: Int -> Signature -> Model -> ArgumentFiltering
decodeArgumentFiltering k sig model =
  [ (f, decodeAF k sym model) | sym@(f, _) <- sig ]

-- | Build the concrete LPO order (precedence + argument filtering) of
-- instance @k@ from an SMT model.
decode :: Int -> Model -> Signature -> Order
decode k m sig = LPO precedence filtering
  where
    precedence = decodePrecedence k sig m
    filtering = decodeArgumentFiltering k sig m

-- interface

-- | The LPO with argument filtering, packaged as a reduction pair for
-- instance @k@. Every SMT variable it uses carries @k@ in its name.
-- '_gt' is the bare varGt variable; its meaning comes from the implications
-- that side_condition adds.
reduction_pair :: Int -> ReductionPair.ReductionPair
reduction_pair k =
  ReductionPair.ReductionPair
    { _side_condition = side_condition k
    , _geq = encodeGeq k
    , _gt = varGt k
    , _decode = decode k
    , _monotone = LPO.monotone k
    , _invariant = invariant k
    , _weakly_simple = weakly_simple k
    , _strictly_simple = strictly_simple k
    }
