module CPS where

import Data.List
import Term
import TRS
import PCP
import Proof
import SN
import Rule
import SMT

-- | All rewrite sequences of length at most @k@ starting from @s@, keeping
-- one sequence per reachable end term.
--
-- Each element is a pair @([t_i, t_{i-1}, ..., t_1, s], rules)@ describing a
-- sequence @s -> t_1 -> ... -> t_i@ with @i <= k@.  The term list is stored
-- in reverse order (most recent reduct first), and @rules@ collects, without
-- duplicates, the rules used along the sequence.  Only the first entry for
-- each end term @t_i@ survives ('nubBy' on the head of the term list).
-- Shorter sequences are listed before their one-step extensions, so the
-- surviving sequence for each term has minimal length.
--
-- A non-positive @k@ yields only the empty sequence @([s], [])@.  Previously
-- a negative bound recursed forever.
at_most_k_steps :: Int -> TRS -> Term -> [([Term], [Rule])]
at_most_k_steps k trs s
  | k <= 0    = [ ([s], []) ]
  | otherwise =
      nubBy eq (trss ++ [ (u : ts, nub (rule : rules))
                        | (ts@(t : _), rules) <- trss
                        , (u, rule) <- reductsRules trs t ])
  where
    -- Entries are identified by their end term.  Every term list built here
    -- is non-empty, so the fallback equation is unreachable.
    eq ((v : _), _) ((w : _), _) = v == w
    eq _       _       = error "at_most_k_steps"
    trss = at_most_k_steps (k-1) trs s

-- | Find a valley @t ->* v <-* u@ using at most @k@ steps on each side.
--
-- Returns @Just []@ when the two terms are already equal.  Otherwise it
-- returns the terms of the first valley found, listed from @t@ to @u@, with
-- the endpoints @t@ and @u@ removed.  'Nothing' means no common reduct was
-- found within the bound.
join_sequence :: Int -> TRS -> Term -> Term -> Maybe [Term]
join_sequence k trs t u
  | t == u                = Just []
  | (valley : _) <- valleys = Just (stripEnds valley)
  | otherwise             = Nothing
  where
    fromT = at_most_k_steps k trs t
    fromU = at_most_k_steps k trs u
    -- The sequence from t is stored newest-first, so reversing its tail and
    -- appending the (newest-first) sequence from u yields [t, ..., v, ..., u].
    valleys =
      [ reverse before ++ (v : after)
      | (v : before, _) <- fromT
      , (w : after, _)  <- fromU
      , v == w ]
    stripEnds = init . tail

-- | Rule sets of all valleys @t ->* v <-* u@ (at most @k@ steps per side)
-- among the sequences produced by 'at_most_k_steps'.  Each set is the
-- duplicate-free union of the rules used on both sides.  Equal terms yield
-- the single empty rule set.
rules_of_join_sequence :: Int -> TRS -> Term -> Term -> [TRS]
rules_of_join_sequence k trs t u
  | t == u    = [[]]
  | otherwise = do
      (v, rulesT) <- fromT
      (w, rulesU) <- fromU
      if v == w then return (nub (rulesT ++ rulesU)) else []
  where
    -- Pair each reachable end term with the rules used to reach it.
    ends x = [ (end, rules) | (end : _, rules) <- at_most_k_steps k trs x ]
    fromT = ends t
    fromU = ends u

-- | Rule sets of the rewrite sequences @t ->* u@ of length at most @k@
-- produced by 'at_most_k_steps'.  Equal terms yield the single empty rule
-- set.
rules_of_to_sequence :: Int -> TRS -> Term -> Term -> [TRS]
rules_of_to_sequence k trs t u
  | t == u    = [[]]
  | otherwise = map snd (filter endsInU (at_most_k_steps k trs t))
  where
    endsInU (v : _, _) = v == u
    endsInU _          = False

-- | Whether 'join_sequence' finds a common reduct of @t@ and @u@ within @k@
-- steps on each side.
joinable :: Int -> TRS -> Term -> Term -> Bool
joinable k trs t u = maybe False (const True) (join_sequence k trs t u)

-- | Join every parallel critical peak within @k@ steps on each side.
--
-- Produces one 'CritPairInfo' per peak, in input order.  Returns 'Nothing'
-- if any peak cannot be joined by 'join_sequence'.
diagrams :: Int -> TRS -> [PCP] -> Maybe [CritPairInfo]
diagrams k trs = mapM diagram
  where
    diagram (t, _, ps, _, s, _, u) =
      fmap (\vs -> (t, ps, s, u, Nothing, vs)) (join_sequence k trs t u)

-- | Worker for 'pcps': classify each parallel critical peak against the
-- candidate subsystem @trsC@.  Each peak is handled as follows:
--
-- * If all of its rules are subsumed by @trsC@, it is recorded with the
--   join sequence @[s]@.
-- * If it is joinable in @trsC@ within @k@ steps, it is recorded with that
--   join sequence.
-- * Otherwise the peak @t <- s -> u@ contributes the rules @s -> t@ and
--   @s -> u@ to the second component.  Duplicates are not removed here.
pcps' :: Int -> TRS -> [PCP] -> ([CritPairInfo], TRS)
pcps' k trsC = foldr classify ([], [])
  where
    -- The lazy pattern keeps the result as lazy as the original explicit
    -- recursion with a where-bound tuple.
    classify (t, rulesL, ps, _, s, ruleR, u) ~(infos, extra)
      | TRS.subsume trsC (ruleR : rulesL) =
          ((t, ps, s, u, Nothing, [s]) : infos, extra)
      | Just vs <- join_sequence k trsC t u =
          ((t, ps, s, u, Nothing, vs) : infos, extra)
      | otherwise =
          (infos, (s, t) : (s, u) : extra)

-- | PCPS(R, C): classify the parallel critical peaks of R against C.
--
-- The parallel critical peaks of R must be passed in, not R itself.
-- Returns the peak records produced by 'pcps'' together with the generated
-- rules, with duplicates removed.
pcps :: Int -> [PCP] -> TRS -> ([CritPairInfo], TRS)
pcps k peaks trsC =
  let (joinSequences, trsP) = pcps' k trsC peaks
  in  (joinSequences, nub trsP)

-- | Position of the first rule in @rules@ that is a variant of @rule@
-- ('Rule.variant'), counted from the offset @i@.  Returns 'Nothing' if no
-- rule matches.
rule_index' :: Int -> TRS -> Rule -> Maybe Int
rule_index' i rules rule = fmap (i +) (findIndex (Rule.variant rule) rules)

-- | Zero-based position in @table@ of the first variant of a rule.
rule_index :: TRS -> Rule -> Maybe Int
rule_index = rule_index' 0

-- | SMT variable stating that @rule@ belongs to S.  The variable is named
-- after the rule's position in @table@.  Calls 'error' if @rule@ is not a
-- variant of any rule in @table@.
in_S :: TRS -> Rule -> Formula
in_S table rule =
  case rule_index table rule of
    Just i  -> FVar ("in_S_" ++ show i)
    Nothing -> error "in_S"

-- | SMT variable stating that function symbol @f@ belongs to FunS.
in_FunS :: String -> Formula
in_FunS = FVar . ("inFunS_" ++)

-- | SMT variable marking that the generated system P is empty.
p_is_empty :: Formula
p_is_empty = FVar varName
  where
    varName = "p_is_empty"

-- | Conjunction stating that every rule of the given system is in S.
subset_of_S :: TRS -> TRS -> Formula
subset_of_S table = conj . map (in_S table)

-- | Conjunction stating that every given function symbol is in FunS.
subset_of_FunS :: [String] -> Formula
subset_of_FunS = conj . map in_FunS

-- | For every peak @t <- s -> u@ (left rules @rules@, right rule @rule@),
-- require at least one of:
--
-- * all rules of the peak are in S;
-- * the rule set of some valley of @t@ and @u@ (at most @k@ steps per
--   side) is contained in S;
-- * @gt s t@ and @gt s u@ hold, and P is marked non-empty.
encode_peaks :: Int -> TRS -> [PCP] -> Formula
encode_peaks k trs peaks = conj (map peakCondition peaks)
  where
    peakCondition (t, rules, _, _, s, rule, u) =
      disj [ subset_of_S trs (rule : rules)
           , disj (map (subset_of_S trs) (rules_of_join_sequence k trs t u))
           , conj [gt s t, gt s u, neg p_is_empty] ]

-- | Constraints on the choice of S:
--
-- * Not every rule of R is in S.
-- * Every function symbol of a rule in S is in FunS.
-- * For each rule @l -> r@ of R whose left-hand side uses only symbols in
--   FunS, S contains the rule set of some sequence @l ->* r@ of at most @k@
--   steps.
encode_rule_removal :: Int -> TRS -> Formula
encode_rule_removal k trs = conj (properSubset : symbolsOfS ++ simulated)
  where
    properSubset = neg (subset_of_S trs trs)
    symbolsOfS =
      [ implies (in_S trs rule) (subset_of_FunS (Rule.functions rule))
      | rule <- trs ]
    simulated =
      [ implies (subset_of_FunS (Term.functions l))
                (disj (map (subset_of_S trs) (rules_of_to_sequence k trs l r)))
      | (l, r) <- trs ]

-- | Full SMT encoding of the search for S.  It conjoins:
--
-- * the side conditions for the signature of @trs@;
-- * @geq l r@ for every rule of R, required only when P is non-empty;
-- * the peak constraints;
-- * the constraints on S.
encode :: Int -> TRS -> [PCP] -> Formula
encode k trs peaks =
  conj [sideConditions, weakOrientation, peakConstraints, removalConstraints]
  where
    sideConditions     = side_condition (TRS.signature_of trs)
    weakOrientation    = implies (neg p_is_empty) (conj [ geq l r | (l, r) <- trs ])
    peakConstraints    = encode_peaks k trs peaks
    removalConstraints = encode_rule_removal k trs

-- | Read S back from a model: the rules of @table@ whose @in_S@ variable
-- evaluates to true, in table order.
decode_S :: TRS -> Model -> TRS
decode_S table model = filter (eval_formula model . in_S table) table

-- | Build the relative-termination proof from a model.  If the model marks
-- P as empty the proof is 'RIsEmpty'.  Otherwise it is a
-- 'MonotoneReductionPair' for @trs_P@ and @trs_R@, using the ordering
-- decoded by 'SN.decode'.
decode_SN_proof :: Model -> Signature -> TRS -> TRS -> SNProof
decode_SN_proof model sig trs_P trs_R =
  if eval_formula model p_is_empty
    then RIsEmpty
    else MonotoneReductionPair trs_P trs_R (SN.decode model sig)

-- | Try to reduce the problem for @trs@ to a proper subsystem S found by the
-- SMT solver @smt@, with joinability bound @k@.
--
-- Requires @trs@ to be left-linear and all its parallel critical peaks to be
-- joinable within @k@ steps.  Returns 'Nothing' if either requirement fails
-- or the solver finds no model.  On success the single proof step records
-- R, S, the generated system P, both sets of join sequences and the
-- relative-termination proof, and S becomes the remaining subgoal.
reduce :: String -> Int -> TRS -> IO (Maybe Proof)
reduce smt k trs
  | not (TRS.left_linear trs) = return Nothing
  | Just joinSequences <- diagrams k trs peaks =
      fmap (fmap (proofFrom joinSequences)) (sat smt (CPS.encode k trs peaks))
  | otherwise = return Nothing
  where
    peaks = pcpeak trs
    -- Assemble the proof from a satisfying model of the encoding.
    proofFrom joinSequences model = ([cpsStep], Subgoal subsystem)
      where
        subsystem = decode_S trs model
        (joinSequencesForS, trsP) = pcps k peaks subsystem
        cpsStep =
          CPS { _trsR = trs
              , _trsS = subsystem
              , _trsP = trsP
              , _joinSequencesForS = joinSequencesForS
              , _joinSequencesForR = joinSequences
              , _relativeTerminationProof =
                  decode_SN_proof model (TRS.signature_of trs) trsP trs
              }
