module CompositionalCriteria (
  CompCriterion,
  trans,
  reduction,
  reduceTRS,
  necessarySubsystem,
  closingSystem,
  closingSystems,
  -- emptiness criterion
  emptyTRS,
  -- parallel critical pair closing system criterion
  selfParallelCriticalPairClosing,
  rSelfParallelCriticalPairClosing,
  parallelCriticalPairClosing,
  parallelCriticalPairClosingWithC,
  power
) where

import Prelude hiding (product)
import Data.List hiding (product)
import Term
import Rule
import TRS
import Rewriting
import PCP
import Result
import qualified Proof
import SMT

-- A compositional criterion takes a TRS together with a second TRS (named
-- sub / cs by the functions of this type in this module) and produces a
-- Result in IO.
type CompCriterion = TRS -> TRS -> IO Result

-- Enumerates all subsets (as sublists) of the given list in ascending order
-- with respect to sizes.
--
-- The empty subset comes first and the full list comes last.  The plain
-- recursive enumeration in 'go' is not ordered by size on its own (for
-- [1,2,3] it yields [2,3] before [1]), so the result is sorted by length
-- afterwards.  sortOn is stable, hence subsets of equal size keep the
-- relative order in which 'go' produces them.
power :: [a] -> [[a]]
power = sortOn length . go
  where
    -- All subsets of the list: first those without the head element, then
    -- the same subsets with the head element put in front.
    go []       = [ [] ]
    go (x : xs) = yss ++ [ x : ys | ys <- yss ]
      where yss = go xs

-- Proof text reported by emptyTRS when the given TRS has no rules.
proofEmptyTRS :: String
proofEmptyTRS = "# emptiness\n\nThe empty TRS is confluent.\n"
  
-- Emptiness criterion: answers YES (with the emptiness proof text) for a
-- TRS without rules and MAYBE for every other TRS.
emptyTRS :: Criterion
emptyTRS trs
  | null trs  = return (YES proofEmptyTRS)
  | otherwise = return MAYBE

-- Lifts an ordinary criterion to a compositional one.  The criterion is run
-- on trs only when the second system is empty; any non-empty second system
-- yields MAYBE.
trans :: Criterion -> CompCriterion
trans cr trs sub =
  case sub of
    [] -> cr trs
    _  -> return MAYBE


-----------------------------------------
-- Parallel Critical Pair Closing System
-----------------------------------------

-- Builds the proof text for the parallel critical pair closing system
-- criterion by passing a fixed heading, the two systems, a fixed
-- justification and the proof for the closing system to Proof.subproof.
proofPCPCS :: TRS -> TRS -> String -> String
proofPCPCS trs cs crproof = Proof.subproof heading trs cs justification crproof
  where
    heading =
      "parallel critical pair closing system (Shintani and Hirokawa 2022)"
    justification =
      "The TRS R is left-linear and all parallel critical pairs are joinable by C.\n"

-- Wraps the text carried by a result obtained for the reduced system cc
-- into a proof about trs via Proof.subeqproof.  YES and NO keep their
-- constructor and get their text wrapped; MAYBE is passed through.
proofReduction :: TRS -> TRS -> Result -> Result
proofReduction trs cc result = wrap result
  where
    wrap (YES crproof) = YES (render crproof)
    wrap (NO witness)  = NO (render witness)
    wrap MAYBE         = MAYBE

    render txt = Proof.subeqproof heading trs cc justification txt

    heading =
      "parallel critical pair closing system (Shintani and Hirokawa 2022, Section 8 in LMCS 2023)"
    justification =
      "The TRS R is left-linear and all parallel critical pairs are joinable by C.\n"

-- CR(R) <==> CR(C)

-- Reduce R to C that satisfies CR(R) <==> CR(C), then decide C instead.
--
-- As long as reduceTRS finds a subsystem, the reduction is repeated on that
-- subsystem and the result is wrapped by proofReduction.  Once no subsystem
-- is found, the given criterion cr is applied to the current system.
reduction :: String -> Int -> Criterion -> Criterion
reduction tool k cr trs =
  reduceTRS tool k trs >>= maybe (cr trs) viaSubsystem
  where
    -- Decide the reduced system recursively and lift its result back to trs.
    viaSubsystem cc = fmap (proofReduction trs cc) (reduction tool k cr cc)

-- Tries to find a subsystem to which trs can be reduced.
--
-- Returns Nothing when trs is not left-linear or when closingSystem finds
-- no closing system within the bound k.  Otherwise the first closing system
-- is handed to necessarySubsystem, which queries the SAT/SMT tool named by
-- tool.
reduceTRS :: String -> Int -> TRS -> IO (Maybe TRS)
reduceTRS tool k trs
  | TRS.leftLinear trs
  , (cs : _) <- closingSystem k trs trs = necessarySubsystem tool k trs cs
  | otherwise = return Nothing

-- taken from RL.hs
-- A rewrite sequence: the rules collected along the way, paired with the
-- term that was reached.
type Sequence = ([Rule], Term)

-- The rule lists of two sequences that reach a common term (left component
-- for the first term, right component for the second term).
type JoinSequence = ([Rule], [Rule])

-- Holds when both sequences reach the same term and the FIRST sequence uses
-- at least as many rules as the second one.  Note the direction: despite
-- the name, a True result means the first sequence is the longer (or
-- equally long) one, i.e. the one minimalSequences may discard.
shorterThanEq :: Sequence -> Sequence -> Bool
shorterThanEq (rules1, s) (rules2, t)
  | s /= t    = False
  | otherwise = length rules2 <= length rules1

-- Prunes a list of sequences so that each reached term is kept only with a
-- sequence that no other entry beats in length.
--
-- An entry is dropped when a later entry reaches the same term with no more
-- rules; when an entry is kept, all later entries that reach the same term
-- with at least as many rules are removed.
minimalSequences :: [Sequence] -> [Sequence]
minimalSequences candidates =
  case candidates of
    [] -> []
    s : rest
      | dominated s rest -> minimalSequences rest
      | otherwise ->
          s : minimalSequences (filter (not . (`shorterThanEq` s)) rest)
  where
    dominated x = any (shorterThanEq x)

-- sequences n trs s lists the rewrite sequences of at most n steps that
-- start from s.  Each entry pairs the rules applied, in application order,
-- with the term obtained (each step comes from reducts applied to a single
-- rule; presumably these are the one-step reducts -- confirm against
-- Rewriting).
--
-- A bound of zero or less yields only the empty sequence.  Guarding with
-- n <= 0 (instead of matching the literal 0) keeps a negative bound from
-- recursing forever.
--
-- NOTE: the result is the list for n - 1 followed by all one-step
-- extensions of that list, so for n >= 2 sequences shorter than n are
-- listed more than once.
sequences :: Int -> TRS -> Term -> [Sequence]
sequences n trs s
  | n <= 0    = [([], s)]
  | otherwise =
      ss ++
      [ (rules ++ [rule], u)
      | (rules, t) <- ss,
        rule <- trs,
        u <- reducts [rule] t ]
  where
    ss = sequences (n - 1) trs s

-- Calculate join sequences by breadth-first search.
--
-- In every round the current sequences of both sides are compared for
-- common terms (commonReduct); unless the bound n is exhausted, both sides
-- are then extended by rewriteOnce and the search continues with n - 1.
-- Joins found in earlier rounds come first in the result.
--
-- NOTE: for performance, sequences are filtered by rewriteOnce:
-- (rules, t) is removed if the same list contains (rules', t) with
-- |rules| >= |rules'|.
-- NOTE: only n == 0 stops the recursion, so a negative n never reaches the
-- base case and the result list has no end.
joinSequencesFromList :: Int -> TRS -> [Sequence] -> [Sequence] -> [JoinSequence]
joinSequencesFromList n trs sStep tStep
  | n == 0    = joinsHere
  | otherwise = joinsHere ++ joinsLater
  where
    joinsHere  = commonReduct sStep tStep
    joinsLater =
      joinSequencesFromList
        (n - 1)
        trs
        (rewriteOnce trs sStep)
        (rewriteOnce trs tStep)

-- Pairs up the rule lists of every sequence from the first list with every
-- sequence from the second list that reaches the same term.  The first
-- list is traversed in the outer loop.
commonReduct :: [Sequence] -> [Sequence] -> [JoinSequence]
commonReduct steps1 steps2 = do
  (rules1, s) <- steps1
  (rules2, t) <- steps2
  if s == t then [(rules1, rules2)] else []

-- Extends every sequence by one further step obtained from stepWithRule
-- (presumably the one-step reducts together with the rule used -- confirm
-- against Rewriting), keeps the unextended sequences as well, and prunes
-- the combined list with minimalSequences.
--
-- The new rule is put in FRONT of the rule list, so the rule lists built
-- here are in reverse application order.
rewriteOnce :: TRS -> [Sequence] -> [Sequence]
rewriteOnce trs steps = minimalSequences (steps ++ concatMap extend steps)
  where
    extend (rules, x) = [ (rule : rules, y) | (rule, y) <- stepWithRule trs x ]

-- Searches for joins of the terms s and t with bound n, starting both sides
-- of joinSequencesFromList from the empty sequence.
joinSequences :: Int -> TRS -> Term -> Term -> [JoinSequence]
joinSequences n trs s t = joinSequencesFromList n trs (start s) (start t)
  where
    -- The sequence that has not rewritten anything yet.
    start u = [([], u)]

-- Given one list of alternatives per position, builds every way of picking
-- one alternative from each position and returns, for each choice, the
-- duplicate-free concatenation of the picked lists.
--
-- No positions at all give the single empty combination; a position
-- without alternatives gives no combination.  The alternatives of the
-- first position vary fastest.
product :: Eq a => [[[a]]] -> [[a]]
product = foldr combine [[]]
  where
    -- Put each alternative of the current position in front of each
    -- combination already built for the remaining positions.
    combine xss rest =
      concatMap (\ysss -> map (\xs -> nub (xs ++ ysss)) xss) rest

-- Picks the first element of every inner list and returns the picks without
-- duplicates (first occurrences are kept).  The result is Nothing as soon
-- as one inner list is empty, and Just [] when there are no inner lists.
productHead :: Eq a => [[a]] -> Maybe [a]
productHead = foldr pick (Just [])
  where
    pick (x : _) rest = fmap (nub . (x :)) rest
    pick []      _    = Nothing
  
-- The first and last components of every tuple returned by pcp for the
-- TRS, restricted to the pairs whose two terms differ.
pcpPair :: TRS -> [(Term,Term)]
pcpPair trs = filter (uncurry (/=)) [ (t, u) | (t, _, _, u) <- pcp trs ]

-- closingSystem generates [C] such that PCP(R) are joinable by C.
--
-- For every pair in pcpPair trs, joins are searched with the rules of jtrs
-- and bound k.  The FIRST join found for each pair is taken and the rules
-- of all these joins are united (without duplicates) into one system, so
-- the result holds at most one system.  The result is empty when some pair
-- has no join within the bound.
closingSystem :: Int -> TRS -> TRS -> [TRS]
closingSystem k trs jtrs =
  maybe [] (\trsList -> [nub (concat trsList)]) (productHead joins)
  where
    joins = map joinsFor (pcpPair trs)
    -- Rule sets of all joins found for one pair.
    joinsFor (t, u) = [ nub (xs ++ ys) | (xs, ys) <- joinSequences k jtrs t u ]

-- closingSystems k trs jtrs enumerates candidate systems C such that the
-- pairs in pcpPair trs are joinable by C.
--
-- This is the all-solutions counterpart of closingSystem and uses its
-- arguments in the same roles: the pairs come from trs and the joins are
-- searched with the rules of jtrs and bound k.  Every way of choosing one
-- join per pair contributes the duplicate-free union of the rules of the
-- chosen joins.  The result is empty when some pair has no join within the
-- bound.
closingSystems :: Int -> TRS -> TRS -> [TRS]
closingSystems k trs jtrs =
  product
    [ [ nub (xs ++ ys) | (xs, ys) <- joinSequences k jtrs t u ]
    | (t, u) <- pcpPair trs ]

-- CR(R) <== CR(C)

-- Parallel critical pair closing system criterion applied recursively to
-- subsystems: rSelfParallelCriticalPairClosing with the identity as the
-- wrapper around the recursive criterion.
selfParallelCriticalPairClosing :: Int -> Criterion -> Criterion
selfParallelCriticalPairClosing k noncr =
  rSelfParallelCriticalPairClosing k id noncr

-- Parallel critical pair closing system criterion that calls itself on
-- subsystems.
--
--   * The empty TRS is answered with YES right away.
--   * A TRS that is not left-linear is answered with MAYBE.
--   * Otherwise noncr is run first; a NO result is returned unchanged.  For
--     any other result the subsets of trs with fewer rules are tried as
--     closing systems, each one being decided by this criterion again,
--     wrapped by pp.
rSelfParallelCriticalPairClosing :: Int -> (Criterion -> Criterion) -> Criterion -> Criterion
rSelfParallelCriticalPairClosing k pp noncr trs
  | null trs =
      return (YES "# emptiness\n\nThe empty TRS is confluent.\n")
  | not (TRS.leftLinear trs) =
      return MAYBE
  | otherwise = do
      verdict <- noncr trs
      case verdict of
        NO _ -> return verdict
        _    -> parallelCriticalPairClosing0 k recurse trs (pcpPair trs) smaller
  where
    -- The criterion used to decide a candidate closing system.
    recurse = pp (rSelfParallelCriticalPairClosing k pp noncr)
    size    = length trs
    -- All subsets except those as long as trs itself.
    smaller = filter ((/= size) . length) (power trs)

-- Runs cr on trs first.  Only when that gives MAYBE are the subsets of trs
-- with fewer rules tried as closing systems, each one being decided by cr.
-- A YES or NO from the first run is returned unchanged.
--
-- NOTE: left-linearity of trs is not checked here; it has to be ensured by
-- the caller (see the remark at parallelCriticalPairClosing0).
parallelCriticalPairClosing :: Int -> Criterion -> Criterion
parallelCriticalPairClosing k cr trs = cr trs >>= continue
  where
    continue MAYBE = parallelCriticalPairClosing0 k cr trs (pcpPair trs) candidates
    continue other = return other
    candidates = [ rs | rs <- power trs, length rs /= length trs ]

-- Compositional variant with a given closing system cs: answers YES (with
-- an empty proof for cs) when every pair in pcpPair trs is joinable by cs
-- within the bound k, and MAYBE otherwise.
--
-- NOTE: neither left-linearity of trs nor any property of cs is checked
-- here.
parallelCriticalPairClosingWithC :: Int -> CompCriterion
parallelCriticalPairClosingWithC k trs cs =
  return (if closed then YES (proofPCPCS trs cs "") else MAYBE)
  where
    closed = all (joinable k cs) (pcpPair trs)

-- Left-linearity of trs and css is needed!
--
-- Walks through the candidate systems in order.  A candidate is accepted
-- when it joins every pair in ps within the bound k and cr answers YES for
-- it; the answer is then YES with the combined proof.  Candidates failing
-- either test are skipped.  MAYBE is returned when no candidate is left.
parallelCriticalPairClosing0 :: Int -> Criterion -> TRS -> [(Term,Term)] -> [TRS] -> IO Result
parallelCriticalPairClosing0 k cr trs ps = go
  where
    go [] = return MAYBE
    go (cs : css)
      | all (joinable k cs) ps = do
          m <- cr cs
          case m of
            YES x -> return (YES (proofPCPCS trs cs x))
            _     -> go css
      | otherwise = go css

-- Holds when the term lists that at_most_k_steps returns for s and for t
-- have an element in common (presumably these are the terms reachable
-- within k steps of trs -- confirm against Rewriting).
joinable :: Int -> TRS -> (Term,Term) -> Bool
joinable k trs (s,t) = not (null (fromS `intersect` fromT))
  where
    fromS = at_most_k_steps k trs s
    fromT = at_most_k_steps k trs t


-- CR(R) ==> CR(C)

-- Association list that assigns a formula to each item (labeling builds
-- such lists with one FVar per item).
type Labeling a = [(a, Formula)]

-- Assigns to the i-th item (counting from 0) the variable FVar whose name
-- is the prefix x followed by i.
labeling :: String -> [a] -> Labeling a
labeling x ys = zipWith attach ys [0 :: Integer ..]
  where
    attach y i = (y, FVar (x ++ show i))

-- Like lookup, but with a user-supplied test instead of (==): returns the
-- value of the first entry whose key y satisfies p x y, or Nothing when no
-- entry does.
lookupBy :: (a -> a -> Bool) -> a -> [(a,b)] -> Maybe b
lookupBy p x = foldr pick Nothing
  where
    pick (y, z) later
      | p x y     = Just z
      | otherwise = later

-- Looks up the formula assigned to x, comparing keys with eq.  An item
-- without an entry in the labeling is a programming error and aborts with
-- error.
label :: (a -> a -> Bool) -> Labeling a -> a -> Formula
label eq phi x =
  case lookupBy eq x phi of
    Just e  -> e
    Nothing -> error "varRule"

-- SAT encoding.
-- Given a subset C0 of R, the formula describes the systems C with
-- C0 <= C < R that are candidates for the reduction.  phi assigns a
-- variable to every rule of R (true = the rule belongs to C); one further
-- variable per function symbol of R is created here with prefix "f_".
--
-- The formula is the conjunction of four parts:
--   1. every rule of C0 is selected;
--   2. at least one rule of R is not selected (C /= R);
--   3. whenever a rule is selected, the variables of all function symbols
--      of that rule are true;
--   4. for every rule l -> r of R outside C0: either some sequence from l
--      to r listed by sequences k trs l has all its rules selected, or the
--      variable of some function symbol of l is false.
necessaryConstraint :: Labeling Rule -> Int -> TRS -> TRS -> Formula
necessaryConstraint phi k trs cc0 =
  conj [includesSeed, properSubsystem, signatureClosed, simulatesRest]
  where
    ruleL = label Rule.variant phi
    sigL = label (==) (labeling "f_" (TRS.functions trs))

    -- C should subsume C0
    includesSeed = conj (map ruleL cc0)

    -- C /= R
    properSubsystem = disj (map (neg . ruleL) trs)

    -- rule in C ==> all function symbols of the rule are marked
    signatureClosed =
      conj
        [ disj [neg (ruleL rl), conj (map sigL (Rule.functions rl))]
        | rl <- trs ]

    -- every l -> r in R \ C0: l ->* r by selected rules, or some function
    -- symbol of l is not marked
    simulatesRest =
      conj
        [ disj [derivable l r, neg (conj (map sigL (Term.functions l)))]
        | (l, r) <- trs \\ cc0 ]

    -- Some listed sequence from l that ends in r uses selected rules only.
    derivable l r =
      disj [ conj (map ruleL ss) | (ss, u) <- sequences k trs l, r == u ]

-- Reads a system off a model: for every model entry with a positive value
-- whose variable occurs in the labeling, the labeled item is returned.
-- Entries with other values and variables unknown to the labeling are
-- skipped.  The result follows the order of the model.
assemble :: Labeling a -> Model -> [a]
assemble phi model = concatMap pick model
  where
    -- The labeling turned around: from formula to item.
    byFormula = [ (f, rl) | (rl, f) <- phi ]
    pick (a, b)
      | b > 0     = maybe [] (: []) (lookup (FVar a) byFormula)
      | otherwise = []

-- Asks the solver named by tool for a model of necessaryConstraint (rules
-- of trs labeled with prefix "rule_") and, if there is one, returns the
-- rules that the model selects.  Nothing means the solver returned no
-- model.
necessarySubsystem :: String -> Int -> TRS -> TRS -> IO (Maybe TRS)
necessarySubsystem tool k trs cc =
  fmap (fmap (assemble phi)) (sat tool (necessaryConstraint phi k trs cc))
  where
    phi = labeling "rule_" trs
