module KH12 where

import Term
import Rule
import TRS
import SN
import Proof
import CPS
import SMT

-- | @compatible s t@ holds when @ren(s)@ and @ren(t)@ are unifiable, where
-- @ren@ gives every variable occurrence its own fresh variable. Variables
-- therefore behave as wildcards. Two function applications are compatible
-- only if they have the same head symbol and the same number of arguments,
-- and their arguments are pairwise compatible.
compatible :: Term -> Term -> Bool
compatible (V _) _           = True
compatible _ (V _)           = True
compatible (F f ss) (F g ts) =
  f == g
    -- zipWith silently truncates the longer list, so without this check
    -- f(a) and f(a, b) would count as compatible although they never unify.
    && length ss == length ts
    && and (zipWith compatible ss ts)

-- | Negation of 'compatible': @ren(s)@ and @ren(t)@ are not unifiable.
incompatible :: Term -> Term -> Bool
incompatible s = not . compatible s

-- | Index of the first rule in @table@ for which @Rule.variant rule@ holds.
-- Indices are counted from @start@. Returns 'Nothing' when no entry matches.
-- NOTE(review): 'Rule.variant' presumably means "equal up to variable
-- renaming"; confirm in the Rule module.
rule_index' :: Int -> TRS -> Rule -> Maybe Int
rule_index' start table rule =
  case [ j | (candidate, j) <- zip table [start ..], Rule.variant rule candidate ] of
    j : _ -> Just j
    []    -> Nothing

-- | Zero-based index in @table@ of the first entry matched by 'Rule.variant'.
rule_index :: TRS -> Rule -> Maybe Int
rule_index = KH12.rule_index' 0

-- | Propositional variable @in_S_i@ asserting that the rule found at index
-- @i@ of @table@ belongs to S. The variable name depends on the position in
-- @table@, so every caller must pass the same table that was encoded.
-- Calls 'error' when @rule@ has no match in @table@.
in_S :: TRS -> Rule -> Formula
in_S table rule =
  maybe (error "in_S") (\i -> FVar ("in_S_" ++ show i)) (KH12.rule_index table rule)

-- | Membership in R is encoded as non-membership in S: the negation of 'in_S'.
in_R :: TRS -> Rule -> Formula
in_R table = neg . KH12.in_S table

-- | Propositional variable @inFunS_f@ asserting that symbol @f@ belongs to
-- the symbol set FunS.
in_FunS :: String -> Formula
in_FunS = FVar . ("inFunS_" ++)

-- | Every rule of @trs@ is in S. Variable indices are taken from @table@.
subset_of_S :: TRS -> TRS -> Formula
subset_of_S table trs = conj (map (KH12.in_S table) trs)

-- | Every rule of @trs@ is in R. Variable indices are taken from @table@.
subset_of_R :: TRS -> TRS -> Formula
subset_of_R table trs = conj (map (KH12.in_R table) trs)

-- | Every listed function symbol is in FunS.
subset_of_FunS :: [String] -> Formula
subset_of_FunS = conj . map KH12.in_FunS

-- | Constraint for overlapping @rule1@ into the left-hand side of @rule2@
-- at position @p@.
--
-- First @rule2@ is passed through @Rule.rename "y" 1@ (presumably renaming
-- it apart from @rule1@; confirm). Then @l1@ is unified with @l2|_p@:
--
-- * If a unifier @sigma@ exists, the critical pair
--   @(l2[r1]_p sigma, r2 sigma)@ must satisfy @joinable k trs@.
-- * If unification fails the result is the constant False. This is a
--   conservative choice: callers only reach this point for compatible
--   positions.
s_critical_pair' :: Int -> TRS -> Rule -> Rule -> Position -> Formula
s_critical_pair' k trs (l1, r1) rule2 p =
  case mgu l1 (subterm_at l2 p) of
    Just sigma ->
      let left  = Term.substitute (replace l2 r1 p) sigma
          right = Term.substitute r2 sigma
      in  formula (joinable k trs left right)
    Nothing -> formula False
  where
    (l2, r2) = Rule.rename "y" 1 rule2

-- | Conjunction of 's_critical_pair'' over every function position @p@ of
-- @l2@ at which @l1@ is 'compatible' with @l2|_p@. All other positions
-- are skipped.
s_critical_pair :: Int -> TRS -> Rule -> Rule -> Formula
s_critical_pair k trs rule1@(l1, _) rule2@(l2, _) =
  conj (map (s_critical_pair' k trs rule1 rule2) overlapping)
  where
    overlapping = filter (compatible l1 . subterm_at l2) (function_positions l2)

-- | For every ordered pair of rules (rule1, rule2), at least one of these
-- holds:
--
-- * both rules are in R;
-- * both rules are in S;
-- * @l1@ is compatible with no function-position subterm of @l2@.
--
-- Both orders of each pair are enumerated, so no rule of R may overlap a
-- rule of S in either direction. The compatibility test runs in Haskell,
-- so the third disjunct is a constant for each pair.
encode_sno :: TRS -> Formula
encode_sno trsRS =
  conj [ disj [ KH12.subset_of_R table pair
              , KH12.subset_of_S table pair
              , conj [ neg (formula (compatible lhs1 (subterm_at lhs2 pos)))
                     | pos <- function_positions lhs2 ] ]
       | rule1@(lhs1, _) <- table
       , rule2@(lhs2, _) <- table
       , let pair = [rule1, rule2] ]
  where
    table = trsRS

-- | Orientation constraints for the relative termination of R/S:
--
-- * @side_condition@ on the signature of the whole system;
-- * every rule in R must satisfy @gt l r@;
-- * every rule in S must satisfy @geq l r@.
--
-- (Presumably a monotone reduction pair, cf. @MonotoneReductionPair@ in
-- 'decode_SN_proof'.)
encode_sn :: TRS -> Formula
encode_sn trsRS =
  conj (side_condition (TRS.signature_of trsRS) : map orient trsRS)
  where
    table = trsRS
    orient rule@(l, r) =
      conj [ implies (KH12.in_R table rule) (gt l r)
           , implies (KH12.in_S table rule) (geq l r) ]

-- | Whenever both rules of an ordered pair are in R, their critical pairs
-- (as built by 's_critical_pair') must be joinable with bound @k@, using
-- all rules of @trsRS@.
encode_cp :: Int -> TRS -> Formula
encode_cp k trsRS = conj (concatMap constraintsFor trsRS)
  where
    table = trsRS
    constraintsFor rule1 =
      [ implies (subset_of_R table [rule1, rule2]) (s_critical_pair k trsRS rule1 rule2)
      | rule2 <- trsRS ]

-- | Constraints linking S to the symbol set FunS:
--
-- * S is not the whole system, so R is non-empty;
-- * every function symbol of a rule in S belongs to FunS;
-- * a rule whose left-hand side uses only FunS symbols must be in S.
encode_rule_removal :: TRS -> Formula
encode_rule_removal trsRS =
  conj ([neg (KH12.subset_of_S table trsRS)] ++ symbolsOfS ++ lhsForcesS)
  where
    table = trsRS
    symbolsOfS =
      [ implies (KH12.in_S table rule) (KH12.subset_of_FunS (Rule.functions rule))
      | rule <- trsRS ]
    lhsForcesS =
      [ implies (KH12.subset_of_FunS (Term.functions lhs)) (KH12.in_S table rule)
      | rule@(lhs, _) <- trsRS ]

-- | Full encoding: the conjunction of the no-overlap, relative-termination,
-- critical-pair and rule-removal constraints. @k@ is the bound passed on
-- to 'encode_cp'.
encode :: Int -> TRS -> Formula
encode k trsRS =
  conj (map ($ trsRS) [encode_sno, encode_sn, encode_cp k, KH12.encode_rule_removal])

-- | Every rule of @trs@ lies outside R, i.e. in S.
-- NOTE(review): the @in_S_i@ indices are computed relative to @trs@ itself.
-- The result therefore refers to the encoded variables only when @trs@ is
-- the full table that was passed to 'encode'.
r_is_empty :: TRS -> Formula
r_is_empty trs = conj (map (neg . in_R trs) trs)

-- | Rules of @table@ that the model places in R, i.e. whose @in_S_i@ is false.
decode_R :: TRS -> Model -> TRS
decode_R table model = filter (eval_formula model . KH12.in_R table) table

-- | Rules of @table@ that the model places in S, i.e. whose @in_S_i@ is true.
decode_S :: TRS -> Model -> TRS
decode_S table model = filter (eval_formula model . KH12.in_S table) table

-- | Build the relative-termination proof for R/S from a satisfying model.
--
-- If R is empty there is nothing to prove. Otherwise the proof is the
-- monotone reduction pair decoded from @model@ for signature @sig@.
--
-- @trsR@ is the already-decoded R, so emptiness is checked structurally.
-- Re-encoding membership against @trsR@ would be wrong: 'in_S' numbers its
-- SMT variables by position in the table it is given. Using the sub-list
-- @trsR@ as that table would read the variables of the first
-- @length trsR@ rules of the encoded system, not those of R.
decode_SN_proof :: Model -> Signature -> TRS -> TRS -> SNProof
decode_SN_proof model sig trsR trsS
  | null trsR = RIsEmpty
  | otherwise = MonotoneReductionPair trsR trsS (decode model sig)


-- | One KH12 step on @trs@ = R ∪ S.
--
-- Asks the SMT solver @smt@ for a model of 'encode' with bound @k@, i.e. a
-- split of the rules into R and S satisfying the criterion. On success the
-- step records R, S and the relative-termination proof, and confluence of
-- S remains as the subgoal: CR(R ∪ S) <=> CR(S).
-- Returns 'Nothing' when the solver finds no model.
reduce :: String -> Int -> TRS -> IO (Maybe Proof)
reduce smt k trs = fmap (fmap step) (sat smt (KH12.encode k trs))
  where
    step model =
      let trsR       = KH12.decode_R trs model
          trsS       = KH12.decode_S trs model
          sn_proof   = KH12.decode_SN_proof model (TRS.signature_of trs) trsR trsS
          proof_step = KH12
            { _trsR = trsR
            , _trsS = trsS
            , _relativeTerminationProof = sn_proof
            }
      in  ([proof_step], Subgoal trsS)
    
