-- Rule labeling based on parallel critical pairs
module PRL where

import Data.List
import Term
import Rule
import TRS
import Rewriting
import PCP
import RL (Labeling, sequences, gt2, labelRule, labelRules,
           showLabeling, labeling)
import SMT
import Result
import qualified Proof as Proof

-- | Small step bound tried before a larger user-supplied bound.
--
-- When the requested bound k exceeds this value, the callers in this module
-- (sh22, parallelRuleLabeling, ordinaryParallelRuleLabeling) try the bounds
-- [magicK, k] in order; otherwise only [k]. Presumably this lets a cheaper
-- search succeed first — TODO confirm intent.
magicK :: Int
magicK = 2

-- | Proof text for composable parallel rule labeling (Shintani and Hirokawa
-- 2022).
--
-- The rendered labeling @labels@ is placed below a caption and a blank line.
-- That text is handed to 'Proof.subproof' together with the TRS, the rule
-- set @cs@ (C) and the proof @crProof@ that the criterion produced for C.
proofSH22 :: TRS -> TRS -> String -> String -> String
proofSH22 trs cs labels crProof = Proof.subproof title trs cs body crProof
  where
    title   = "Composable parallel rule labeling (Shintani and Hirokawa 2022)."
    body    = unlines [caption, "", labels]
    caption = "All parallel critical peaks (except C's) are decreasing wrt rule labeling:"

-- | Proof text for (plain) parallel rule labeling (Zankl et al. 2015).
--
-- The rendered labeling @labels@ is placed below a caption and a blank line,
-- and the result is handed to 'Proof.proof' together with the TRS.
proofZFM15 :: TRS -> String -> String
proofZFM15 trs labels =
  Proof.proof
    "Parallel rule labeling (Zankl et al. 2015)."
    trs
    (unlines [caption, "", labels])
  where
    -- This proof is only produced without C-rules: parallelRuleLabeling
    -- passes C = [] and the ordinary labeling has no C at all. The caption
    -- therefore must not mention an exception for C's peaks.
    caption = "All parallel critical peaks are decreasing wrt rule labeling:"

-- (rules1, rules2, rules3, t) stands for
--   s ->* . -||-> . ->* t
-- where rules1 are the rules of the leading sequence, rules2 the rules used
-- in the single parallel step, and rules3 the rules of the trailing
-- sequence. When no parallel step is taken, rules1 and rules2 are both []
-- and the whole sequence is recorded in rules3 (see parallelSequences).
type ParallelSequence = ([Rule], [Rule], [Rule], Term)

-- Rule components of a join: the first three lists belong to the sequence
-- from the "->" side of the peak (...R), the last three to the sequence from
-- the "<-" side (...L); see joinParallelSequences.
type JoinParallelSequence = ([Rule], [Rule], [Rule], [Rule], [Rule], [Rule])

-- | Rewrite sequences from @s@ that contain at most one parallel step.
--
-- The result has two groups:
--
-- * Every result of @sequences n trs s@, with no parallel step. Its rules
--   are stored in the third component.
-- * Every such prefix shorter than @n@ rules, extended by one step of
--   'restrictedParallelStep' (restricted by @xs@), then followed by a suffix
--   from 'sequences' with the remaining budget @n - length prefix - 1@.
--
-- Empty parallel steps are skipped. A single-rule parallel step is also
-- skipped when @subset step prefix@ holds (presumably: its rule already
-- occurs in the prefix).
parallelSequences :: Int -> TRS -> [String] -> Term -> [ParallelSequence]
parallelSequences n trs xs s = noParallel ++ withParallel
  where
    prefixes = sequences n trs s
    noParallel = [ ([], [], rs, t) | (rs, t) <- prefixes ]
    withParallel =
      [ (pre, par, post, v)
      | (pre, t) <- prefixes
      , let budget = n - length pre - 1
      , budget >= 0
      , (par, u) <- restrictedParallelStep trs xs t
      , keep par pre
      , (post, v) <- sequences budget trs u ]
    -- Filter on the parallel step: never empty, and a lone rule must not
    -- already be covered by the prefix.
    keep par pre = case par of
      []  -> False
      [_] -> not (subset par pre)
      _   -> True

-- | Joining sequences for the peak ends @s@ and @t@.
--
-- @joinParallelSequences n trs s xs t@ pairs every parallel sequence from
-- @s@ with every parallel sequence from @t@ that ends in the same term, and
-- keeps only their rule components. The @s@ side restricts its parallel step
-- by @Term.variables s@, while the @t@ side uses the given @xs@.
joinParallelSequences :: Int -> TRS -> Term -> [String] -> Term -> [JoinParallelSequence]
joinParallelSequences n trs s xs t =
  [ (r1, r2, r3, l1, l2, l3)
  | (r1, r2, r3, u) <- fromS
  , (l1, l2, l3, v) <- fromT
  , u == v ]
  where
    fromS = parallelSequences n trs (Term.variables s) s
    fromT = parallelSequences n trs xs t

-- | @geq2 ks ms@: every label in @ms@ is bounded by some label in @ks@.
--
-- Builds the conjunction, over the distinct @m@ in @ms@, of the disjunction,
-- over the distinct @k@ in @ks@, of @Geq k m@.
geq2 :: [Exp] -> [Exp] -> Formula
geq2 ks ms = conj (map boundedBySome (nub ms))
  where
    bounds = nub ks
    boundedBySome m = disj [ Geq k m | k <- bounds ]

-- | SMT constraint for composable parallel rule labeling with the labelings
-- @phi@ and @psi@, relative to the rule set @c@.
--
-- For each of the two orientations (phi, psi) and (psi, phi), every parallel
-- critical peak from 'pcp' that is not discarded as a C-peak must have at
-- least one joining sequence (length bound @n@) whose labels satisfy the
-- five gt2/geq2 conditions below. The result is conjoined with
-- 'bottomConstraint'.
--
-- ...L stands for "<-"
-- ...R stands for "->"
parallelRuleLabelingConstraint :: Int -> Labeling -> Labeling -> TRS -> TRS -> Formula
parallelRuleLabelingConstraint n phi psi trs c =
  conj [
    conj [ disj [ conj [gt2  ks  (labelRules psi' rs1),
                        gt2  ms  (labelRules phi' ls1),
                        geq2 ms  (labelRules psi' rs2),
                        geq2 ks  (labelRules phi' ls2),
                        gt2  kms (labelRules psi' rs3 ++
                                  labelRules phi' ls3)]
                | (rs1,rs2,rs3,ls1,ls2,ls3) <- joins ]
         | (phi', psi') <- [(phi, psi), (psi, phi)],
           (rulesL, ruleR, joins) <- peaks,
           let ks = labelRules phi' rulesL,
           let ms = labelRules psi' [ruleR],
           let kms = nub (ks ++ ms)
         ],
    bottomConstraint phi psi trs c]
  where
    -- Peaks that are kept, each paired with its joining sequences. The joins
    -- depend only on the peak, not on the labeling. Binding them here lets
    -- both orientations share one run of the (expensive) rewriting search
    -- instead of recomputing pcp and the joins per orientation. Peak order,
    -- and hence the resulting formula, is unchanged.
    peaks =
      [ (rulesL, ruleR, joinParallelSequences n trs s xs t)
      | (s, (rulesL, ruleR), xs, t) <- pcp trs,
        -- discard parallel peaks consisting only of C-rules
        not (TRS.subsume c (ruleR:rulesL)) ]


-- | Pin every rule of @c@ to label 0 and force every other rule of @trs@
-- to a label greater than 0. Both conditions are imposed under @phi@ and
-- under @psi@.
bottomConstraint :: Labeling -> Labeling -> TRS -> TRS -> Formula
bottomConstraint phi psi trs c =
  conj [ allOf Eq phi c
       , allOf Eq psi c
       , allOf Gt phi others
       , allOf Gt psi others ]
  where
    others = trs \\ c
    -- Relate the label of each rule in rs (under lab) to the constant 0.
    allOf rel lab rs = conj [ rel (labelRule lab r) (Val 0) | r <- rs ]

-- | Composable parallel rule labeling (Shintani and Hirokawa 2022) with the
-- underlying criterion @thm@.
--
-- Only left-linear systems are handled; anything else yields MAYBE. The
-- step bounds are [magicK, k] when k exceeds magicK, and [k] otherwise. The
-- candidate C sets come from 'initializeCs'.
sh22 :: String -> Int -> Criterion -> TRS -> IO Result
sh22 tool k thm trs
  | not (TRS.leftLinear trs) = return MAYBE
  | otherwise                = sh22' tool bounds thm trs (initializeCs trs)
  where
    bounds = if magicK < k then [magicK, k] else [k]

-- | Candidate rule sets C, in the order sh22' tries them.
--
-- For each candidate, sh22' (1) applies parallel rule labeling relative to
-- C, (2) applies the underlying confluence criterion to C, and (3) moves on
-- to the next candidate (the next decomposition) when either step fails.
--
-- Order: C = [] first, then C = trs, then every other subset of the rules
-- not in c0, extended with c0 = makeC0 trs (currently always empty). Results
-- that are empty or contain all of trs are skipped; while c0 is empty, those
-- are exactly the two candidates already listed.
initializeCs :: TRS -> [TRS]
initializeCs trs = [] : trs : filter proper candidates
  where
    c0         = makeC0 trs
    total      = length trs
    candidates = [ c1 ++ c0 | c1 <- power (trs \\ c0) ]
    proper c   = not (null c) && length c /= total

-- | Rules that every non-trivial candidate C from 'initializeCs' is extended
-- with. Currently the result is always empty, whatever the input.
makeC0 :: TRS -> TRS
makeC0 = const []

-- | All sublists of the input (elements keep their relative order), i.e.
-- its power set with 2^n entries.
--
-- Sublists without the head come first, followed by the same sublists with
-- the head prepended.
power :: [a] -> [[a]]
power = foldr (\x acc -> acc ++ map (x :) acc) [[]]

-- | Try the candidate rule sets C in order. R (@trs@) is assumed to be
-- left-linear.
--
-- For a candidate C, first look for a parallel rule labeling relative to C.
-- If one is found, run the criterion @thm@ on C. A proof is returned only
-- when both succeed; otherwise the next candidate is tried. The result is
-- MAYBE once the candidates are exhausted.
sh22' :: String -> [Int] -> Criterion -> TRS -> [TRS] -> IO Result
sh22' _ _ _ _ [] = return MAYBE
sh22' tool ks thm trs (cs:css) =
  parallelRuleLabelingWithC tool ks trs cs >>= \labeled ->
    case labeled of
      YES labels ->
        thm cs >>= \verdict ->
          case verdict of
            YES crProof -> return (YES (proofSH22 trs cs labels crProof))
            _           -> tryNext
      _ -> tryNext
  where
    tryNext = sh22' tool ks thm trs css

-- | Search for labelings phi and psi satisfying
-- 'parallelRuleLabelingConstraint' relative to C, trying the step bounds in
-- order.
--
-- The first SMT model found is rendered (phi, then psi) and returned as YES.
-- The result is MAYBE when no bound yields a model.
--
-- Assume that the TRS C is
-- + left-linear,
-- + confluent, and
-- + a subset of TRS
parallelRuleLabelingWithC :: String -> [Int] -> TRS -> TRS -> IO Result
parallelRuleLabelingWithC _ [] _ _ =
  return MAYBE
parallelRuleLabelingWithC tool (k:ks) trs c =
  sat tool constraint >>= maybe retry (return . YES . render)
  where
    constraint = parallelRuleLabelingConstraint k phi psi trs c
    retry      = parallelRuleLabelingWithC tool ks trs c
    phi        = labeling "phi_" trs
    psi        = labeling "psi_" trs
    render model =
      unlines [ showLabeling ("phi", phi, model)
              , showLabeling ("psi", psi, model) ]

-- | Parallel rule labeling (Zankl et al. 2015), run as the composable
-- variant with an empty C.
--
-- Non-left-linear systems yield MAYBE. A YES result is wrapped into the
-- proofZFM15 proof text; any other result is passed through unchanged.
parallelRuleLabeling :: String -> Int -> TRS -> IO Result
parallelRuleLabeling tool k trs
  | not (TRS.leftLinear trs) = return MAYBE
  | otherwise = wrap <$> parallelRuleLabelingWithC tool bounds trs []
  where
    bounds
      | k > magicK = [magicK, k]
      | otherwise  = [k]
    wrap (YES labels) = YES (proofZFM15 trs labels)
    wrap other        = other

-- Ordinary parallel rule labeling (Zankl et al. JAR 2014)

-- | SMT constraint for ordinary parallel rule labeling with the single
-- labeling @phi@.
--
-- For every parallel critical peak from 'pcp', at least one joining sequence
-- (length bound @n@) must satisfy the five gt2/geq2 conditions below.
--
-- ...L stands for "<-"
-- ...R stands for "->"
ordinaryParallelRuleLabelingConstraint :: Int -> Labeling -> TRS -> Formula
ordinaryParallelRuleLabelingConstraint n phi trs =
  conj (map peakConstraint (pcp trs))
  where
    lab rs = labelRules phi rs
    peakConstraint (s, (rulesL, ruleR), xs, t) =
      disj (map joinable (joinParallelSequences n trs s xs t))
      where
        ks  = lab rulesL
        ms  = lab [ruleR]
        kms = nub (ks ++ ms)
        joinable (rs1, rs2, rs3, ls1, ls2, ls3) =
          conj [ gt2  ks  (lab rs1)
               , gt2  ms  (lab ls1)
               , geq2 ms  (lab rs2)
               , geq2 ks  (lab ls2)
               , gt2  kms (lab (rs3 ++ ls3)) ]

-- | Ordinary parallel rule labeling. Non-left-linear systems yield MAYBE.
--
-- The step bounds tried are [magicK, k] when k exceeds magicK, and [k]
-- otherwise.
ordinaryParallelRuleLabeling :: String -> Int -> TRS -> IO Result
ordinaryParallelRuleLabeling tool k trs
  | TRS.leftLinear trs = ordinaryParallelRuleLabeling0 tool bounds trs
  | otherwise          = return MAYBE
  where
    bounds = if k > magicK then [magicK, k] else [k]

-- | Try the step bounds in order until the SMT tool finds a model for
-- 'ordinaryParallelRuleLabelingConstraint'.
--
-- The model's phi labeling is rendered into the proofZFM15 proof text. The
-- result is MAYBE when no bound yields a model.
ordinaryParallelRuleLabeling0 :: String -> [Int] -> TRS -> IO Result
ordinaryParallelRuleLabeling0 _ [] _ =
  return MAYBE
ordinaryParallelRuleLabeling0 tool (k:ks) trs =
  sat tool smt >>= maybe (ordinaryParallelRuleLabeling0 tool ks trs) found
  where
    phi = labeling "phi_" trs
    smt = ordinaryParallelRuleLabelingConstraint k phi trs
    found model =
      return (YES (proofZFM15 trs (showLabeling ("phi", phi, model))))
