-- Rule labeling

module RL where

import Term
import Rule
import TRS
import SMT
import Result
import qualified Proof as Proof

-- A rewrite sequence, recorded as the rules applied (in order of
-- application) together with the term that is reached.
-- ([rule1,...,ruleN], t) stands for
--   s ->_rule1 ... ->_ruleN t for some term s
-- The start term s is not stored.
type Sequence = ([Rule], Term)

-- A pair of rule lists witnessing that two terms are joinable.
-- ([rule1,...,ruleM], [rule1',...,ruleN']) stands for
--   s ->_rule1  ... ->_ruleM  u and
--   t ->_rule1' ... ->_ruleN' u for some terms s, t and u.
-- Only the two rule lists are stored; the common reduct u is not.
type JoinSequence = ([Rule], [Rule])

-- geqEmb xs ys holds iff ys embeds into xs, i.e. ys can be obtained from
-- xs by deleting elements while keeping the order of the remaining ones.
geqEmb :: Eq a => [a] -> [a] -> Bool
geqEmb _ [] = True
geqEmb haystack (needle : rest) =
  -- Skip ahead to the first occurrence of the next wanted element; if there
  -- is none the embedding fails, otherwise continue behind that occurrence.
  case dropWhile (\x -> not (x == needle)) haystack of
    []            -> False
    (_ : remains) -> geqEmb remains rest

-- geqSequence a b holds iff both sequences reach the same term and the
-- rule list of b embeds (in the sense of geqEmb) into the rule list of a.
geqSequence :: Sequence -> Sequence -> Bool
geqSequence (longer, reachedA) (shorter, reachedB)
  | reachedA == reachedB = geqEmb longer shorter
  | otherwise            = False

-- Prunes a list of sequences with respect to geqSequence:
--   * a sequence is dropped if it is >= some later sequence in the list
--     (this also collapses duplicates, keeping the last copy);
--   * when a sequence is kept, all later sequences that are >= it are
--     removed.
-- The relative order of the surviving sequences is preserved.
minimalSequences :: [Sequence] -> [Sequence]
minimalSequences [] = []
minimalSequences (candidate : others)
  | dominatesLater = minimalSequences others
  | otherwise      = candidate : minimalSequences survivors
  where
    dominatesLater = any (geqSequence candidate) others
    survivors      = filter (\other -> not (geqSequence other candidate)) others

-- sequences n trs s computes rewrite sequences of length at most n that
-- start from s, pruned with minimalSequences at every level.
--
-- For n <= 0 only the empty sequence ([], s) is returned.  Treating
-- negative bounds like 0 keeps the recursion from running forever when a
-- caller passes a negative n.
--
-- Each longer sequence extends one from the previous level by a single
-- rule; the new end terms come from `reducts [rule] t` (presumably the
-- one-step reducts of t w.r.t. that rule -- confirm in TRS/Rule).
sequences :: Int -> TRS -> Term -> [Sequence]
sequences n trs s
  | n <= 0    = [([], s)]
  | otherwise = minimalSequences (shorter ++ extended)
  where
    -- Sequences with at most n - 1 steps.
    shorter = sequences (n - 1) trs s
    -- Every one-step extension of a shorter sequence, with the applied
    -- rule appended so the rule list stays in order of application.
    extended =
      [ (rules ++ [rule], u)
      | (rules, t) <- shorter,
        rule <- trs,
        u <- reducts [rule] t ]

-- joinSequences n trs s t lists all pairs (rulesS, rulesT) such that a
-- sequence from s (as computed by `sequences n trs s`) and a sequence
-- from t (as computed by `sequences n trs t`) end in the same term.
-- Pairs are enumerated with the sequences from s as the outer loop.
joinSequences :: Int -> TRS -> Term -> Term -> [JoinSequence]
joinSequences n trs s t = concatMap meetingWith fromS
  where
    fromS = sequences n trs s
    fromT = sequences n trs t
    -- All sequences from t that reach the same term as the given one from s.
    meetingWith (rulesS, reachedS) =
      [ (rulesS, rulesT) | (rulesT, reachedT) <- fromT, reachedS == reachedT ]

-- A labeling as an association list from rules to the expression (Exp)
-- that stands for the rule's label.  Lookups go through labelRule, which
-- matches rules up to Rule.variant.
type Labeling = [(Rule, Exp)]

-- split xs lists every way of singling out one element of xs:
-- each result (before, x, after) satisfies before ++ [x] ++ after == xs.
-- Results are ordered by the position of the chosen element.
split :: [a] -> [([a], a, [a])]
split xs =
  [ (take i xs, x, drop (i + 1) xs)
  | (i, x) <- zip [0 :: Int ..] xs ]

-- gt2 ks ms builds the formula
--   AND over m in ms of (OR over k in ks of Gt k m),
-- i.e. every m must be dominated (Gt) by at least one k.
gt2 :: [Exp] -> [Exp] -> Formula
gt2 ks ms = conj (map dominated ms)
  where
    dominated m = disj (map (\k -> Gt k m) ks)

-- decreasing1 k m ns constrains the labels ns = [n1,...,nN] of a rewrite
-- sequence ->_n1 ->_n2 ... ->_nN to be of one of the forms
--   ->_(<km)^*                          (every label below k or below m)
--   ->_(<k)^* ->_(<=m) ->_(<km)^*       (one distinguished step with Geq m n)
-- The second form is tried for every position of the distinguished step.
decreasing1 :: Exp -> Exp -> [Exp] -> Formula
decreasing1 k m ns = disj (belowEither ns : map aroundStep (split ns))
  where
    belowEither = gt2 [k, m]
    aroundStep (before, step, after) =
      conj [gt2 [k] before, Geq m step, belowEither after]

-- decreasing k m nsR nsL is the constraint for a peak <-_k ->_m that is
-- joined by
--   ->_nsR ... ->   and   <- ... <-_nsL
-- where nsR are the labels of the sequence constrained by
-- decreasing1 k m and nsL the labels of the sequence constrained by
-- decreasing1 m k (the roles of k and m are swapped).
decreasing :: Exp -> Exp -> [Exp] -> [Exp] -> Formula
decreasing k m nsR nsL = conj [rightPart, leftPart]
  where
    rightPart = decreasing1 k m nsR
    leftPart  = decreasing1 m k nsL

-- labeling x trs pairs the i-th rule of trs (counting from 0) with the
-- variable expression named x ++ show i.
labeling :: String -> TRS -> Labeling
labeling x trs = zipWith labelled trs [0 :: Integer ..]
  where
    labelled rule i = (rule, Var (x ++ show i))

-- lookupBy p x table returns the value of the first entry (y, z) of the
-- association list for which p x y holds, or Nothing if there is none.
-- It generalizes Prelude.lookup from (==) to an arbitrary relation p.
lookupBy :: (a -> a -> Bool) -> a -> [(a,b)] -> Maybe b
lookupBy p x table =
  case [ value | (key, value) <- table, p x key ] of
    (firstHit : _) -> Just firstHit
    []             -> Nothing

-- labelRule phi rule returns the label expression that phi assigns to
-- rule.  Rules are compared with Rule.variant rather than (==), and the
-- first matching entry of phi wins.
--
-- Calls `error` if no entry of phi matches; the message names this
-- function so the failure can be located (the former text "varRule"
-- referred to a name that does not exist in this module).
labelRule :: Labeling -> Rule -> Exp
labelRule phi rule =
  case lookupBy Rule.variant rule phi of
    Just e  -> e
    Nothing -> error "RL.labelRule: rule is not a variant of any labeled rule"

-- labelRules phi rules looks up the label of every rule, in order.
-- Like labelRule, it fails with `error` on a rule without a label.
labelRules :: Labeling -> [Rule] -> [Exp]
labelRules phi rules = map (labelRule phi) rules

-- mirror swaps the two sides of a peak: both the pair of terms and the
-- two rules exchange places.
mirror :: ((Term, Term), Rule, Rule) -> ((Term, Term), Rule, Rule) 
mirror peak =
  case peak of
    ((left, right), leftRule, rightRule) -> ((right, left), rightRule, leftRule)

-- cpWithRules trs computes the critical pairs of trs together with the
-- two rules that cause them.  An entry ((s, t), rule1, rule2) arises from
-- unifying the left-hand side of rule1 with the subterm at a function
-- position p of the left-hand side of rule2:
--   s is obtained by replacing that subterm with rule1's right-hand side,
--   t is rule2's right-hand side,
-- both under the unifier.  An overlap at the root position of a rule with
-- a variant of itself is skipped.  The two rules are renamed apart
-- (prefixes "x" and "y") before unification and the resulting pair is
-- renamed again.
--
-- Enumeration order: outer rule, then position, then inner rule.
cpWithRules :: TRS -> [((Term, Term), Rule, Rule)]
cpWithRules trs = concatMap peaksInto trs
  where
    -- All critical pairs in which outerRule is the overlapped rule.
    peaksInto outerRule =
      [ peak
      | pos <- functionPositions lhsOuter,
        innerRule <- trs,
        peak <- peakAt pos innerRule ]
      where
        (lhsOuter, rhsOuter) = Rule.rename "y" 0 outerRule
        -- Zero or one critical pair for a fixed position and inner rule.
        peakAt pos innerRule
          | pos == [] && Rule.variant innerRule outerRule = []
          | otherwise =
              case mgu lhsInner (subtermAt lhsOuter pos) of
                Nothing    -> []
                Just sigma ->
                  [ ( Rule.rename "x" 1
                        ( Term.substitute (replace lhsOuter rhsInner pos) sigma
                        , Term.substitute rhsOuter sigma )
                    , innerRule
                    , outerRule ) ]
          where
            (lhsInner, rhsInner) = Rule.rename "x" 0 innerRule

-- Decreasingness constraint for one peak and one candidate join.
-- ruleL is labeled with phi and ruleR with psi; the rules of the first
-- join component (rulesR) are labeled with psi and those of the second
-- (rulesL) with phi.
ruleLabelingConstraint1 :: Labeling -> Labeling -> Rule -> Rule -> JoinSequence -> Formula
ruleLabelingConstraint1 phi psi ruleL ruleR (rulesR, rulesL) =
  decreasing peakLabelL peakLabelR joinLabelsR joinLabelsL
  where
    peakLabelL  = labelRule  phi ruleL
    peakLabelR  = labelRule  psi ruleR
    joinLabelsR = labelRules psi rulesR
    joinLabelsL = labelRules phi rulesL

-- Naming convention: a suffix L stands for "<-" steps and a suffix R for
-- "->" steps.
--
-- ruleLabelingConstraint k phi psi trs requires, for every critical pair
-- from cpWithRules and for both assignments (phi, psi) and (psi, phi) of
-- the two labelings, that at least one join sequence found by
-- joinSequences k is decreasing (ruleLabelingConstraint1).
ruleLabelingConstraint :: Int -> Labeling -> Labeling -> TRS -> Formula
ruleLabelingConstraint k phi psi trs =
  conj (concatMap peakConstraints (cpWithRules trs))
  where
    peakConstraints ((s, t), ruleL, ruleR) =
      [ disj (map (ruleLabelingConstraint1 phi' psi' ruleL ruleR) joins)
      | (phi', psi') <- [(phi, psi), (psi, phi)] ]
      where
        -- Shared between both labeling assignments.
        joins = joinSequences k trs s t
    
-- showSequence (a, to, from, rules) renders each rule as
--   to ++ <index of the rule in a> ++ from
-- and joins the pieces with single spaces.  Rules without an entry in the
-- association list a are silently omitted.
showSequence :: ([(Rule,Int)], String, String, [Rule]) -> String
showSequence (a, to, from, rules) = unwords (concatMap arrow rules)
  where
    arrow rule =
      case lookup rule a of
        Just i  -> [to ++ show i ++ from]
        Nothing -> []

-- showJoinSequence (a, (rulesR, rulesL)) prints the first component as
-- steps "-i->" and the second as steps "<-i-", separated by one space.
-- NOTE(review): rulesL is printed in the order stored, not reversed; if
-- the list is in order of application (as produced by joinSequences) the
-- "<-" part reads right-to-left -- confirm this is the intended output.
showJoinSequence :: ([(Rule,Int)], ([Rule], [Rule])) -> String
showJoinSequence (a, (rulesR, rulesL)) = unwords [rightward, leftward]
  where
    rightward = showSequence (a, "-", "->", rulesR)
    leftward  = showSequence (a, "<-", "-", rulesL)

-- showLabel (s, rule, n) renders one label assignment as "s(rule) = n",
-- where s is the name of the labeling function.
showLabel :: (String, Rule, Int) -> String
showLabel (s, rule, n) = concat [s, "(", showRule rule, ") = ", show n]

-- showLabeling (s, phi, model) prints one indented line per rule of phi,
-- showing the value of the rule's label expression under the model.
showLabeling :: (String, Labeling, Model) -> String
showLabeling (s, phi, model) = unlines (map line phi)
  where
    line (rule, e) = "  " ++ showLabel (s, rule, evalExp model e)

-- ruleLabeling tool k trs runs the rule labeling criterion with two
-- labeling functions phi and psi.
--   * If trs is not linear the answer is MAYBE without calling the solver.
--   * Otherwise the constraint for join sequences of length at most k is
--     passed to `sat tool` (tool presumably selects the SMT solver --
--     confirm in SMT).  A model yields YES with a proof that lists both
--     labelings; no model yields MAYBE.
ruleLabeling :: String -> Int -> TRS -> IO Result
ruleLabeling tool k trs
  | TRS.linear trs = fmap (maybe MAYBE (YES . proof)) (sat tool constraint)
  | otherwise      = return MAYBE
  where
    phi = labeling "phi_" trs
    psi = labeling "psi_" trs
    constraint = ruleLabelingConstraint k phi psi trs
    explanation model =
      unlines
        [ "All critical peaks are decreasing wrt rule labeling:"
        , ""
        , showLabeling ("phi", phi, model)
        , showLabeling ("psi", psi, model) ]
    proof model =
      Proof.proof "Rule labeling (van Oostrom 2008)." trs (explanation model)
  
-- Ordinary rule labeling based on a single labeling function. 

-- ordinaryRuleLabelingConstraint k phi trs requires, for every critical
-- pair from cpWithRules, that at least one join sequence found by
-- joinSequences k is decreasing, with the single labeling phi used for
-- the peak and for both sides of the join.
ordinaryRuleLabelingConstraint :: Int -> Labeling -> TRS -> Formula
ordinaryRuleLabelingConstraint k phi trs =
  conj (map peakConstraint (cpWithRules trs))
  where
    label  = labelRule  phi
    labels = labelRules phi
    peakConstraint ((s, t), ruleL, ruleR) =
      disj
        [ decreasing (label ruleL) (label ruleR) (labels rulesR) (labels rulesL)
        | (rulesR, rulesL) <- joinSequences k trs s t ]

-- ordinaryRuleLabeling tool k trs runs the rule labeling criterion with a
-- single labeling function phi.
--   * If trs is not linear the answer is MAYBE without calling the solver.
--   * Otherwise the constraint for join sequences of length at most k is
--     passed to `sat tool`; a model yields YES with a proof that lists the
--     labeling, no model yields MAYBE.
ordinaryRuleLabeling :: String -> Int -> TRS -> IO Result
ordinaryRuleLabeling tool k trs
  | TRS.linear trs = fmap (maybe MAYBE (YES . proof)) (sat tool constraint)
  | otherwise      = return MAYBE
  where
    phi = labeling "phi_" trs
    constraint = ordinaryRuleLabelingConstraint k phi trs
    explanation model =
      unlines
        [ "All critical peaks are decreasing wrt rule labeling:\n"
        , showLabeling ("phi", phi, model) ]
    proof model =
      Proof.proof "Rule labeling (van Oostrom 2008)." trs (explanation model)
