-- Certifiable rule labeling based on parallel critical pairs
module PRLCert where

import Data.List
import Term
import Rule
import TRS
import Rewriting
import PCP
import RL (Labeling, sequences, gt2, labelRule, labelRules,
           showLabeling, labeling)
import SMT
import CompositionalCriteria
import qualified Proof
import qualified XML

-- | Small bound that 'quickTest' puts in front of larger bounds.
quickTestBound :: Int
quickTestBound = 2

-- | List of bounds derived from @k@: @[quickTestBound, k]@ when @k@ exceeds
-- 'quickTestBound', and just @[k]@ otherwise.  Presumably meant as the bound
-- list of the labeling searches below, which try their bounds in order.
quickTest :: Int -> [Int]
quickTest k =
  if k <= quickTestBound
    then [k]
    else [quickTestBound, k]

-- | Proof text for compositional parallel rule labeling
-- (Shintani and Hirokawa 2022), built with 'Proof.subproof' from the TRS,
-- the subsystem @cs@, a caption followed by the rendered @labels@, and
-- @crProof@ (presumably the confluence proof for @cs@), which is passed
-- through unchanged as the last argument.
proofSH22 :: TRS -> TRS -> String -> String -> String
proofSH22 trs cs labels crProof =
  Proof.subproof title trs cs body crProof
  where
    title   = "Compositional parallel rule labeling (Shintani and Hirokawa 2022)."
    caption = "All parallel critical peaks (except C's) are decreasing wrt rule labeling:"
    body    = unlines [caption, "", labels]

-- | Same proof text as 'proofSH22', but with an empty string as the final
-- (confluence proof) argument of 'Proof.subproof'.
proofSH22withC :: TRS -> TRS -> String -> String
proofSH22withC trs cs labels = proofSH22 trs cs labels ""

-- | Proof text for (non-compositional) parallel rule labeling
-- (Zankl et al. 2015), built with 'Proof.proof' from the TRS and a caption
-- followed by the rendered @labels@.
--
-- This criterion has no subsystem C (the function takes no C argument, and
-- the ordinary criterion, cf. 'ordinaryParallelRuleLabelingConstraint',
-- constrains every parallel critical peak).  The caption therefore must not
-- mention C's peaks the way the compositional captions do.
proofZFM15 :: TRS -> String -> String
proofZFM15 trs labels =
  Proof.proof
    "Parallel rule labeling (Zankl et al. 2015)."
    trs
    (unlines [caption, "", labels])
  where
    caption = "All parallel critical peaks are decreasing wrt rule labeling:"

-- | A sequence starting at some term @s@, split around at most one restricted
-- parallel step:
--
--   (rules1, rules2, rules3, t) stands for
--     s -(rules1)-> ... -||-(rules2)-> ... -(rules3)-> t
--
-- rules1 and rules3 are the rule lists of sequences produced by 'sequences'
-- before and after the parallel step, and rules2 is the rule list of the
-- parallel step.  A sequence without a parallel step is represented with
-- rules1 = rules2 = [] and all of its rules in rules3 (see
-- 'parallelSequences').
type ParallelSequence = ([Rule], [Rule], [Rule], Term)

-- | Two 'ParallelSequence's with a common end term, without that term:
-- (rulesR1, rulesR2, rulesR3, rulesL1, rulesL2, rulesL3).  In
-- 'joinParallelSequences' the ...R lists come from the sequence leaving @s@
-- and the ...L lists from the sequence leaving @t@.
type JoinParallelSequence = ([Rule], [Rule], [Rule], [Rule], [Rule], [Rule])

-- | All 'ParallelSequence's starting at @s@ for bound @n@.
--
-- The result first lists every sequence from @sequences n trs s@ without a
-- parallel step.  It then lists every combination of:
--
--   * a prefix from @sequences n trs s@ with fewer than @n@ rules;
--   * a step from @restrictedParallelStep trs xs@ on the prefix's end term,
--     with a non-empty rule list that is not a single rule already covered
--     by the prefix (according to 'subset');
--   * a suffix from @sequences (n - length prefix - 1) trs@.
parallelSequences :: Int -> TRS -> [String] -> Term -> [ParallelSequence]
parallelSequences n trs xs s = plain ++ withParallelStep
  where
    prefixes = sequences n trs s

    -- sequences without any parallel step
    plain = [ ([], [], rs, t) | (rs, t) <- prefixes ]

    withParallelStep =
      [ (pre, par, post, v)
      | (pre, t) <- prefixes
      , length pre < n
      , (par, u) <- restrictedParallelStep trs xs t
      , not (null par)
      , not (length par == 1 && subset par pre)
      , (post, v) <- sequences (n - length pre - 1) trs u
      ]

-- | Candidate joins for the terms @s@ and @t@ with bound @n@.
--
-- @joinParallelSequences n trs s xs t@ pairs every 'ParallelSequence' from
-- @s@ with every 'ParallelSequence' from @t@ that ends in the same term, and
-- drops that common end term.  Both come from 'parallelSequences': the one
-- from @s@ gets @Term.variables s@ as its restriction argument, and the one
-- from @t@ gets @xs@.  The ...R components of each result come from @s@, and
-- the ...L components from @t@.
joinParallelSequences :: Int -> TRS -> Term -> [String] -> Term -> [JoinParallelSequence]
joinParallelSequences n trs s xs t =
  [ (rulesR1, rulesR2, rulesR3, rulesL1, rulesL2, rulesL3)
  | (rulesR1, rulesR2, rulesR3, u) <- fromS
  , (rulesL1, rulesL2, rulesL3, v) <- fromT
  , u == v
  ]
  where
    fromS = parallelSequences n trs (Term.variables s) s
    -- Does not depend on the outer generator: build it once and share it,
    -- instead of recomputing it for every sequence from s.
    fromT = parallelSequences n trs xs t

-- | @geq2 ks ms@ requires that, for every label @m@ in @ms@, some label @k@
-- in @ks@ satisfies @Geq k m@.  Duplicate labels in both lists are dropped
-- first.
geq2 :: [Exp] -> [Exp] -> Formula
geq2 ks ms = conj (map boundedBy (nub ms))
  where
    candidates = nub ks
    boundedBy m = disj [ Geq k m | k <- candidates ]

-- | SMT constraint for compositional parallel rule labeling.
--
-- The constraint covers both orientations @(phi, psi)@ and @(psi, phi)@.
-- For each one, every parallel critical peak from 'pcp' must have at least
-- one candidate join from @joinParallelSequences n@ that satisfies the label
-- conditions below.  Peaks whose rules are all covered by @c@ (per
-- 'TRS.subsume') are exempt.  The result is conjoined with
-- 'bottomConstraint'.
--
-- ...L stands for "<-"
-- ...R stands for "->"
parallelRuleLabelingConstraint :: Int -> Labeling -> Labeling -> TRS -> TRS -> Formula
parallelRuleLabelingConstraint n phi psi trs c =
  conj [peakConstraints, bottomConstraint phi psi trs c]
  where
    peakConstraints =
      conj [ decreasingJoin phi' psi' peak
           | (phi', psi') <- [(phi, psi), (psi, phi)]
           , peak@(_, (rulesL, ruleR), _, _) <- pcp trs
           -- discard parallel peaks that consist only of C-rules
           , not (TRS.subsume c (ruleR : rulesL))
           ]

    -- Some candidate join of the peak must satisfy the label conditions
    -- w.r.t. the labels of rulesL (ks) and of ruleR (ms).
    decreasingJoin phi' psi' (s, (rulesL, ruleR), xs, t) =
      disj [ conj [ gt2  ks  (labelRules psi' rs1)
                  , gt2  ms  (labelRules phi' ls1)
                  , geq2 ms  (labelRules psi' rs2)
                  , geq2 ks  (labelRules phi' ls2)
                  , gt2  kms (labelRules psi' rs3 ++ labelRules phi' ls3)
                  ]
           | (rs1, rs2, rs3, ls1, ls2, ls3) <- joinParallelSequences n trs s xs t
           ]
      where
        ks  = labelRules phi' rulesL
        ms  = labelRules psi' [ruleR]
        kms = nub (ks ++ ms)


-- | Label side conditions for the subsystem @c@:
--
--   * every rule of @c@ is labeled 0 by both @phi@ and @psi@;
--   * every rule of @trs@ not in @c@ (i.e. @trs \\ c@) gets a label greater
--     than 0 from both.
bottomConstraint :: Labeling -> Labeling -> TRS -> TRS -> Formula
bottomConstraint phi psi trs c =
  conj [ allLabeled Eq (labelRule phi) c
       , allLabeled Eq (labelRule psi) c
       , allLabeled Gt (labelRule phi) others
       , allLabeled Gt (labelRule psi) others
       ]
  where
    others = trs \\ c
    -- relate the label of every rule in @rules@ to the constant 0
    allLabeled rel label rules = conj [ rel (label r) (Val 0) | r <- rules ]

-- 1. apply PRL,
-- 2. apply underlying confluence criterion,
-- 3. try decomposition
-- | Candidate subsystems C.  The list starts with the empty TRS, then @trs@
-- itself, then every element of @power trs@ that is neither empty nor as
-- long as @trs@.
initializeCs :: TRS -> [TRS]
initializeCs trs = [] : trs : filter isProper (power trs)
  where
    size = length trs
    -- c is not empty and c =/= trs
    isProper c = not (null c) && length c /= size

-- | Search for a compositional parallel rule labeling relative to @c@.
--
-- The bounds are tried in order.  For bound @k@, the formula
-- @parallelRuleLabelingConstraint k phi psi trs c@ is passed to 'sat' with
-- the solver @tool@.  The first satisfiable bound yields the rendered phi
-- and psi labelings.  If no bound succeeds, the result is 'Nothing'.
--
-- Assumes that the TRS C is
-- + left-linear,
-- + confluent, and
-- + a subset of TRS
parallelRuleLabeling0 :: String -> [Int] -> TRS -> TRS -> IO (Maybe String)
parallelRuleLabeling0 tool bounds trs c = go bounds
  where
    phi = labeling "phi_" trs
    psi = labeling "psi_" trs

    go [] = return Nothing
    go (k:ks) =
      sat tool (parallelRuleLabelingConstraint k phi psi trs c)
        >>= maybe (go ks) (return . Just . render)

    render model =
      unlines [ showLabeling ("phi", phi, model)
              , showLabeling ("psi", psi, model)
              ]

-- Ordinary parallel rule labeling (Zankl et al. JAR 2014)

-- | SMT constraint for ordinary parallel rule labeling with the single
-- labeling @phi@.
--
-- Every parallel critical peak from 'pcp' must have some candidate join from
-- @joinParallelSequences n@ that satisfies the label conditions below.
-- Unlike 'parallelRuleLabelingConstraint', no peaks are exempt and no
-- 'bottomConstraint' is added.
--
-- ...L stands for "<-"
-- ...R stands for "->"
ordinaryParallelRuleLabelingConstraint :: Int -> Labeling -> TRS -> Formula
ordinaryParallelRuleLabelingConstraint n phi trs =
  conj (map peakConstraint (pcp trs))
  where
    peakConstraint (s, (rulesL, ruleR), xs, t) =
      disj [ conj [ gt2  ks  (labelRules phi rs1)
                  , gt2  ms  (labelRules phi ls1)
                  , geq2 ms  (labelRules phi rs2)
                  , geq2 ks  (labelRules phi ls2)
                  , gt2  kms (labelRules phi (rs3 ++ ls3))
                  ]
           | (rs1, rs2, rs3, ls1, ls2, ls3) <- joinParallelSequences n trs s xs t
           ]
      where
        ks  = labelRules phi rulesL
        ms  = labelRules phi [ruleR]
        kms = nub (ks ++ ms)


-- certifiable version

-- | Evaluate every label expression of @phi@ in the SMT @model@, keeping the
-- rule each label belongs to.
extractLabeling :: Labeling -> Model -> XML.Labeling
extractLabeling phi model = map evalEntry phi
  where
    evalEntry (rule, e) = (rule, evalExp model e)

-- | Certifiable parallel rule labeling with an empty subsystem C.
--
-- Returns 'Nothing' when @trs@ is not left-linear (checked once, up front).
-- Otherwise the bounds are tried in order: for bound @k@ the formula
-- @parallelRuleLabelingConstraint k phi psi trs []@ is passed to 'sat' with
-- the solver @smt@.  The first satisfiable bound yields the certificate
-- @(trs, XML.Composition phi' psi' (k * 2) [] XML.EmptyTRS)@, where phi' and
-- psi' are the labelings evaluated in the model.  If no bound succeeds, the
-- result is 'Nothing'.
--
-- NOTE(review): the certificate bound @k * 2@ presumably covers both sides
-- of a join; confirm against the certifier's expectations.
parallelRuleLabeling :: String -> [Int] -> TRS -> IO (Maybe XML.Certificate)
parallelRuleLabeling smt bounds trs
  | not (TRS.leftLinear trs) = return Nothing
  | otherwise                = search bounds
  where
    phi = labeling "phi_" trs
    psi = labeling "psi_" trs

    search [] = return Nothing
    search (k:ks) = do
      m <- sat smt (parallelRuleLabelingConstraint k phi psi trs [])
      case m of
        Nothing    -> search ks
        Just model ->
          -- reuse extractLabeling instead of duplicating its comprehension
          return (Just (trs, XML.Composition (extractLabeling phi model)
                                             (extractLabeling psi model)
                                             (k * 2) [] XML.EmptyTRS))
