-- Certifiable rule labeling based on parallel critical pairs
module RL where

import Data.List
import Term
import Rule
import TRS
import Rewriting
import PCP
import Labeling
import SMT
import Proof

-- | One rewrite step @_ ->_{l -> r} t@: the rule applied and the term it
--   produced.  In the sequences built by 'reverse_sequences' the term is the
--   reduct (target) of the step, not its source.
type Step = (Rule, Term)

-- | A rewrite sequence @s ->_{l1 -> r1} s1 ->_{l2 -> r2} ... ->_{ln -> rn} t@,
--   stored as its steps and its final term @t@.  The start term @s@ is not
--   stored.  New steps are consed onto the front (see 'reverse_sequences'),
--   so the step list is most-recent-first and must be reversed to read it
--   in chronological order.
type Sequence = ([Step], Term)

-- | @s ->* -||-> ->* t@: a sequence from @s@, the rules of one parallel step,
--   the term reached by that parallel step, and a sequence from that term to
--   @t@.  An empty rule list stands for "no parallel step".
type ParallelSequence = (Sequence, [Rule], Term, Sequence)

-- | A join @t ->* -||-> ->* *<- <-||- *<- u@: the first component starts at
--   @t@, the second at @u@, and both end in the same term.
type JoinParallelSequence = (ParallelSequence, ParallelSequence)

-- | Orientation of a peak: 'PhiPsi' labels the left step with phi and the
--   right step with psi; 'PsiPhi' uses the two labelings the other way round.
data PeakType = PhiPsi | PsiPhi deriving Show

-- | A parallel critical peak @t <-||- s -> u@ together with every candidate
--   join @t ->* -||-> ->* *<- <-||- *<- u@ found for it.
type CandidateDiagrams = (PCP, [JoinParallelSequence])

-- | A peak paired with a list of terms.
--   NOTE(review): not referenced by the definitions in this module; presumably
--   used elsewhere — confirm.
type Diagram = (PCP, [Term])

-- Generating diagrams.

-- | @geq_emb xs ys@ holds when @ys@ embeds into @xs@, i.e. @ys@ is a (not
--   necessarily contiguous) subsequence of @xs@.  Used to compare the rule
--   lists of two sequences: a list that embeds another one is "greater".
--
--   This is exactly 'isSubsequenceOf' with its arguments flipped, so the
--   standard-library function is used instead of a hand-rolled recursion.
geq_emb :: Eq a => [a] -> [a] -> Bool
geq_emb xs ys = ys `isSubsequenceOf` xs

-- | The rules applied in a list of steps, in the same order as the steps.
rules_of_steps :: [Step] -> [Rule]
rules_of_steps = map fst

-- | @geq_sequence a b@ holds when both sequences end in the same term and the
--   rules of @b@ embed into the rules of @a@ (see 'geq_emb').
--   'minimal_sequences' uses it to discard sequences that are above another.
geq_sequence :: Sequence -> Sequence -> Bool
geq_sequence (stepsA, endA) (stepsB, endB)
  | endA == endB = geq_emb (rules_of_steps stepsA) (rules_of_steps stepsB)
  | otherwise    = False

-- | Drop every sequence that is 'geq_sequence'-above some other sequence of
--   the list.  Of several mutually equivalent sequences one survives.  The
--   surviving sequences keep their relative order.
minimal_sequences :: [Sequence] -> [Sequence]
minimal_sequences = go
  where
    go [] = []
    go (cand : rest)
      -- some later sequence is below cand: cand is not minimal
      | any (geq_sequence cand) rest = go rest
      -- keep cand and discard everything later that lies above it
      | otherwise = cand : go (filter (not . above cand) rest)
    above lower other = geq_sequence other lower

-- | Rewrite sequences of at most @n@ steps starting from @s@ (the empty
--   sequence included), pruned with 'minimal_sequences' at every length.
--   Steps are stored most-recent-first, each carrying the term it produced.
--   A step that leads straight back to @s@ is never taken.
reverse_sequences :: Int -> TRS -> Term -> [Sequence]
reverse_sequences 0 _ origin = [([], origin)]
reverse_sequences depth rs origin =
  minimal_sequences (([], origin) : concatMap extend shorter)
  where
    shorter = reverse_sequences (depth - 1) rs origin
    extend (path, current)
      -- NOTE(review): never fires, since every stored target was already
      -- required to differ from the origin by the guard below.
      | origin `elem` map snd path = []
      | otherwise =
          [ ((rule, next) : path, next)
          | rule <- rs
          , next <- reducts [rule] current
          , origin /= next ]

-- | Parallel sequences @s ->* -||-> ->* _@ with at most @n@ steps, where the
--   parallel step counts as one step.  The parallel step is computed by
--   'restrictedParallelStep' with the variable list @xs@.  The variants
--   without a parallel step come first.  A parallel step consisting of a
--   single rule that already occurs in the preceding sequence is skipped.
parallel_sequences :: Int -> TRS -> [String] -> Term -> [ParallelSequence]
parallel_sequences n trs xs s = without_parallel ++ concatMap with_parallel prefixes
  where
    prefixes = reverse_sequences n trs s
    -- empty parallel step: the "term after the parallel step" is s itself
    without_parallel = [ (([], s), [], s, suffix) | suffix <- prefixes ]
    with_parallel prefix@(steps, t)
      | length steps >= n = []
      | otherwise =
          [ (prefix, prules, u, suffix)
          | (prules, u) <- restrictedParallelStep trs xs t
          , not (null prules)
          , not (length prules == 1 && subset prules (rules_of_steps steps))
          , suffix <- reverse_sequences (n - length steps - 1) trs u ]

-- | @join_parallel_sequences n trs s xs t@: all pairs of parallel sequences
--   that end in the same term.  The first starts at @s@ and uses all
--   variables of @s@ for its parallel step.  The second starts at @t@ and
--   passes @xs@ to 'restrictedParallelStep'.
join_parallel_sequences :: Int -> TRS -> Term -> [String] -> Term -> [JoinParallelSequence]
join_parallel_sequences n trs s xs t =
  [ (fromS, fromT)
  | fromS <- all_from_s
  , fromT <- all_from_t
  , final fromS == final fromT ]
  where
    all_from_s = parallel_sequences n trs (Term.variables s) s
    all_from_t = parallel_sequences n trs xs t
    final (_, _, _, (_, end)) = end

-- | Pair every parallel critical peak of @rs@ (from 'pcpeak') with all
--   candidate joins of its two reducts, using step bound @k@.
candidate_diagrams :: TRS -> Int -> [CandidateDiagrams]
candidate_diagrams rs k = map attach_joins (pcpeak rs)
  where
    attach_joins peak@(t, _, _, xs, _, _, u) =
      (peak, join_parallel_sequences k rs t xs u)

-- Labels

-- | Formula: @rule@ belongs to C, i.e. its label under @phi@ is 0.
in_C :: Labeling Exp -> Rule -> Formula
in_C phi rule = Eq lbl (Val 0)
  where
    lbl = label_of_rule phi rule

-- | Formula: every rule of the given system belongs to C.
subset_of_C :: Labeling Exp -> TRS -> Formula
subset_of_C phi = conj . map (in_C phi)

-- | Formula: function symbol @f@ is in Fun(C), encoded as the boolean
--   SMT variable @y_f@.
in_FunC :: String -> Formula
in_FunC = FVar . ("y_" ++)

-- | Formula: every function symbol in the list is in Fun(C).
subset_of_FunC :: [String] -> Formula
subset_of_FunC = conj . map in_FunC

-- | Side conditions on C, the rules labeled 0 by @phi@:
--
--   * C /= R;
--   * Fun(l -> r) is a subset of Fun(C) for every rule l -> r in C;
--   * for every rule l -> r in R: if Fun(l) is a subset of Fun(C), then
--     l ->*_C r, witnessed by one of the sequences of at most @k@ steps
--     produced by 'reverse_sequences'.
constraint_on_C :: Labeling Exp -> Int -> TRS -> Formula
constraint_on_C phi k trs_R =
  conj [c_is_proper, c_closed_under_symbols, c_simulates_r]
  where
    c_is_proper = neg (subset_of_C phi trs_R)
    c_closed_under_symbols =
      conj [ in_C phi rule `implies` subset_of_FunC (Rule.functions rule)
           | rule <- trs_R ]
    c_simulates_r = conj (map reachable_in_C trs_R)
    reachable_in_C (l, r) =
      implies (subset_of_FunC (Term.functions l))
              (disj [ subset_of_C phi (rules_of_steps steps)
                    | (steps, end) <- reverse_sequences k trs_R l
                    , r == end ])

-- | Labels under @phi@ of the rules used by a sequence, in the order the
--   steps are stored.
labels_of_sequence :: Labeling a -> Sequence -> [a]
labels_of_sequence phi = map (label_of_rule phi . fst) . fst


-- Encoding functions.

-- | Symbolic labeling: the rule at index i of @trs@ gets the SMT variable
--   named @x ++ show i@.
labeling :: String -> TRS -> Labeling Exp
labeling x trs = zipWith label_rule [0 :: Integer ..] trs
  where
    label_rule i rule = (rule, Var (x ++ show i))

-- | SMT variable holding the index of the candidate join chosen for peak
--   number @i@ in the given orientation.
var_diagram :: PeakType -> Int -> Exp
var_diagram peakType i = Var (prefix ++ show i)
  where
    prefix = case peakType of
      PhiPsi -> "phi_psi_"
      PsiPhi -> "psi_phi_"

-- | @gt2 ks ms@: every label in @ms@ is strictly smaller than some label in
--   @ks@ (an empty conjunction when @ms@ is empty).
--
--   Duplicate labels are removed first, as 'geq2' does.  Label lists often
--   repeat (the same rule used at several parallel positions, or several
--   times in a sequence), and without 'nub' the formula gets redundant
--   conjuncts and disjuncts.
gt2 :: [Exp] -> [Exp] -> Formula
gt2 ks ms = conj [ disj [ Gt k m | k <- bounds ] | m <- nub ms ]
  where
    bounds = nub ks

-- | @geq2 ks ms@: every distinct label in @ms@ is less than or equal to some
--   distinct label in @ks@.
geq2 :: [Exp] -> [Exp] -> Formula
geq2 ks ms = conj (map bounded (nub ms))
  where
    bounds = nub ks
    bounded m = disj [ Geq k m | k <- bounds ]

-- | Decreasingness constraint for all candidate diagrams.
--
--   For every peak @t <-||- s -> u@ (index i in @diagrams@) and for both
--   orientations (phi', psi') = (phi, psi) ('PhiPsi') and (psi, phi)
--   ('PsiPhi'):
--   - ks = phi'-labels of the rules of the left parallel step;
--   - ms = psi'-label of the right rule;
--   - kms = ks and ms combined.
--   Then one of the following holds:
--
--   * all left rules and the right rule are in C, or
--   * the SMT variable @var_diagram peakType i@ equals the index j of a
--     candidate join whose labels are decreasing:
--       side from t (labeled by psi'): ->* below some k, -||-> at most some m,
--                                      ->* below some label in kms;
--       side from u (labeled by phi'): ->* below some m, -||-> at most some k,
--                                      ->* below some label in kms.
--
--   Naming: "...R" components belong to the join side starting at t, which
--   is labeled with psi' ("->psi").  "...L" components belong to the side
--   starting at u, which is labeled with phi' ("phi<-").
decreasingness :: Labeling Exp -> Labeling Exp -> [CandidateDiagrams] -> Formula
decreasingness phi psi diagrams =
  conj [ disj (conj [subset_of_C phi' rulesL, in_C psi' ruleR] :
               [ conj [Eq (var_diagram peakType i) (Val j),
                       gt2  ks  (labels_of_sequence psi' stepsR1),   -- t ->* : below some k
                       geq2 ms  (labels_of_rules    psi' rulesR2),   -- -||-> : at most some m
                       gt2  kms (labels_of_sequence psi' stepsR3),   -- ->*   : below some k or m
                       gt2  kms (labels_of_sequence phi' stepsL3),   -- *<-   : below some k or m
                       geq2 ks  (labels_of_rules    phi' rulesL2),   -- <-||- : at most some k
                       gt2  ms  (labels_of_sequence phi' stepsL1)]   -- *<- u : below some m
               | (j, ((stepsR1, rulesR2, _, stepsR3), 
                 (stepsL1, rulesL2, _, stepsL3))) <- zip [0..] jss ])
       | (i, ((_, rulesL, _, _, _, ruleR, _), jss)) <- zip [0..] diagrams,
         (peakType, phi', psi') <- [(PhiPsi, phi, psi), (PsiPhi, psi, phi)],
         let ks = labels_of_rules phi' rulesL,
         let ms = [label_of_rule psi' ruleR],
         let kms = nub (ks ++ ms) ]

-- | The complete rule-labeling constraint for step bound @k@, system
--   @trs_R@, labelings @phi@ and @psi@ and candidate diagrams @ds@:
--
--   * 'decreasingness' of every peak,
--   * the side conditions on C from 'constraint_on_C',
--   * phi and psi agree on which rules are in C,
--   * every label under phi and psi is non-negative.
--
--   If some peak has no candidate join at all, the result is 'bottom'.
constraint :: Int -> TRS -> Labeling Exp -> Labeling Exp -> [CandidateDiagrams] -> Formula
constraint k trs_R phi psi ds
  | any unjoinable ds = bottom
  | otherwise         =
      conj (decreasingness phi psi ds :
            constraint_on_C phi k trs_R :
            [ iff (in_C phi rule) (in_C psi rule) | rule <- trs_R ] ++
            [ Geq (label_of_rule chi rule) (Val 0)
            | rule <- trs_R, chi <- [phi, psi] ])
  where
    -- 'null' rather than (== []): it tests emptiness directly, with no
    -- structural Eq comparison of joins.
    unjoinable (_, js) = null js


-- | Instantiate a symbolic labeling with the values assigned by @model@.
eval_labeling :: Model -> Labeling Exp -> Labeling Int
eval_labeling model = map (\(rule, e) -> (rule, eval_exp model e))

-- | Collapse each run of consecutive equal elements into a single element
--   (the last of the run).  Lazy: each output element is produced after
--   looking one element ahead.
compact :: Eq a => [a] -> [a]
compact (x : rest@(y : _))
  | x == y    = compact rest
  | otherwise = x : compact rest
compact short = short

-- | Terms visited by a parallel sequence, in chronological order:
--   - the reducts of the first sequence,
--   - the term after the parallel step,
--   - the reducts of the last sequence.
--   The start term itself is not listed.  However, for the no-parallel-step
--   entries built by 'parallel_sequences', the middle term is the start term.
terms_in_parallel_sequence :: ParallelSequence -> [Term]
terms_in_parallel_sequence ((before, _), _, mid, (after, _)) =
  visited before ++ [mid] ++ visited after
  where
    -- steps are stored most-recent-first and carry their reduct
    visited steps = reverse (map snd steps)

-- | Terms along a join, from the t side to the u side, with consecutive
--   duplicates (such as the meeting term of both sides) merged by 'compact'.
intermediates :: JoinParallelSequence -> [Term]
intermediates (seqR, seqL) =
  let fromT = terms_in_parallel_sequence seqR
      fromU = terms_in_parallel_sequence seqL
  in compact (fromT ++ reverse fromU)

-- | The rules that @model@ labels with 0 under @phi@, i.e. the rules of C.
decode_C :: Model -> Labeling Exp -> TRS
decode_C model phi = map fst (filter in_c phi)
  where
    in_c (_, e) = eval_exp model e == 0

-- | Certificate entries for the peaks of orientation @peakType@ that have a
--   positive label.  Each entry holds:
--   - the peak's terms and positions;
--   - the maximal left label and the right label;
--   - the intermediate terms of the join that @model@ selected through
--     @var_diagram peakType i@.
--
--   If the selected index lies outside the candidate list (the constraint
--   from 'decreasingness' should rule this out), the function fails with a
--   descriptive error rather than the bare failure of '!!'.  The lookup is
--   only forced when the join is actually used.
decode_diagrams :: PeakType -> Labeling Int -> Labeling Int -> [CandidateDiagrams] -> Model -> [CritPairInfo]
decode_diagrams peakType phi' psi' ds model =
  [ (t, ps, s, u, Just (maxLeft, right), intermediates (selected_join i jss))
  | (i, ((t, rulesL, ps, _, s, ruleR, u), jss)) <- zip [0..] ds,
    let maxLeft = maximum (0 : labels_of_rules phi' rulesL),
    let right = label_of_rule psi' ruleR,
    -- skip peaks where no label is positive (all rules in C, given the
    -- non-negativity constraint on labels)
    maxLeft > 0 || right > 0
  ]
  where
    selected_join i jss =
      case drop j jss of
        (js : _) | j >= 0 -> js
        _ -> error ("RL.decode_diagrams: model selects join " ++ show j
                    ++ " of peak " ++ show i ++ ", but only "
                    ++ show (length jss) ++ " candidate joins exist")
      where
        j = eval_exp model (var_diagram peakType i)

-- | Build the rule-labeling proof step from a satisfying @model@.  It holds:
--   - the original system;
--   - the decoded C;
--   - both labelings evaluated in the model;
--   - the decoded joins for both peak orientations.
rl_proof :: TRS -> Model -> Labeling Exp -> Labeling Exp -> [CandidateDiagrams] -> ProofStep
rl_proof trs model phi psi diagrams =
  let phiVals = eval_labeling model phi
      psiVals = eval_labeling model psi
      joins orientation leftLab rightLab =
        decode_diagrams orientation leftLab rightLab diagrams model
  in RL { _trsR    = trs
        , _trsC    = decode_C model phi
        , _phi     = phiVals
        , _psi     = psiVals
        , _joinsRS = joins PhiPsi phiVals psiVals
        , _joinsSR = joins PsiPhi psiVals phiVals
        }

-- | Try to handle the whole system with one rule-labeling step.  Every rule
--   is forced to a positive label, so C is empty and no subgoal remains.
--   Only left-linear systems are attempted.  @smt@ is handed to 'sat'
--   unchanged, and @k@ bounds the candidate joins.
reduce_all :: String -> Int -> TRS -> IO (Maybe Proof)
reduce_all smt k trs
  | TRS.left_linear trs = do
      result <- sat smt (conj (constraint k trs' phi psi diagrams : all_positive))
      case result of
        Just model -> return (Just ([rl_proof trs model phi psi diagrams], Subgoal []))
        Nothing    -> return Nothing
  | otherwise = return Nothing
  where
    trs' = TRS.unique trs
    phi = labeling "phi_" trs'
    -- psi uses the very same "phi_" variables as phi, so both labelings
    -- coincide here (a separate "psi_" labeling is used by 'reduce').
    psi = phi
    diagrams = candidate_diagrams trs' k
    all_positive = [ Gt (label_of_rule phi rule) (Val 0) | rule <- trs' ]

-- | Try one rule-labeling step in which at least one rule gets a positive
--   label.  The rules labeled 0 form C, which is returned as the remaining
--   subgoal.  Only left-linear systems are attempted.
reduce :: String -> Int -> TRS -> IO (Maybe Proof)
reduce smt k trs
  | TRS.left_linear trs = do
      result <- sat smt (conj [constraint k trs' phi psi diagrams, disj some_positive])
      case result of
        Just model ->
          let step = rl_proof trs model phi psi diagrams
          in return (Just ([step], Subgoal (_trsC step)))
        Nothing -> return Nothing
  | otherwise = return Nothing
  where
    trs' = TRS.unique trs
    phi = labeling "phi_" trs'
    psi = labeling "psi_" trs'
    diagrams = candidate_diagrams trs' k
    some_positive = [ Gt (label_of_rule phi rule) (Val 0) | rule <- trs' ]
