module Main where

import Data.List
import Term
import Rule
import TRS
import Rewriting
import PCP
import Result
import Parser
import qualified RL as RL
import qualified PRL as PRL
import CPS
import qualified PCPS as PCPS
import qualified RTermination as RT
import qualified Proof as Proof
import System.Environment

-- Strong closedness (Huet 1980)

-- | One orientation of strong joinability: the terms produced by
-- @at_most_k_steps k@ from @s@ and by @at_most_k_steps 1@ from @t@ share at
-- least one element.
stronglyJoinable :: Int -> TRS -> Term -> Term -> Bool
stronglyJoinable k trs s t = not (null meeting)
  where
    meeting = fromS `intersect` fromT
    fromS = at_most_k_steps k trs s
    fromT = at_most_k_steps 1 trs t

-- | Every critical pair is strongly joinable in both orientations, with
-- bound @k@ on the longer side.
stronglyClosed :: Int -> TRS -> Bool
stronglyClosed k trs =
  and [ stronglyJoinable k trs s t && stronglyJoinable k trs t s
      | (s, t) <- cp trs ]

-- | Strong closedness criterion (Huet 1980).
--
-- Answers YES when the TRS is linear and strongly closed within bound @k@;
-- otherwise it answers MAYBE.
strongClosednessCriterion :: Int -> Criterion
strongClosednessCriterion k trs = return verdict
  where
    verdict
      | applicable = YES proof
      | otherwise  = MAYBE
    applicable = TRS.linear trs && stronglyClosed k trs
    proof =
      Proof.proof
        "strong closedness (Huet 1980)"
        trs
        "R is linear and every critical pair is strongly closed"

-- Parallel closedness (Huet 1980)

-- | For every critical pair @(s, t)@, @t@ is one of the parallel-step
-- reducts of @s@.
parallelClosed :: TRS -> Bool
parallelClosed trs =
  and [ t `elem` parallelStep trs s | (s, t) <- cp trs ]

-- | Parallel closedness criterion (Huet 1980).
--
-- Answers YES when the TRS is left-linear and every critical pair is
-- parallel closed (see 'parallelClosed'); otherwise it answers MAYBE.
parallelClosednessCriterion :: Criterion
parallelClosednessCriterion trs
  | TRS.leftLinear trs && parallelClosed trs =
    return (YES proof)
  | otherwise = return MAYBE
  where
    -- The title uses the same format as the other criteria (no trailing
    -- period). The justification names the property actually checked:
    -- left-linearity.
    proof =
      Proof.proof
        "parallel closedness (Huet 1980)"
        trs
        "R is left-linear and every critical pair is parallel closed"

-- Almost parallel closedness (Toyama 1988)

-- | Critical pairs annotated with the position of the overlap.
--
-- Each triple is @(l2σ[r1σ]_p, p, r2σ)@. Here @p@ ranges over
-- @functionPositions l2@, and @σ@ is the most general unifier of the
-- subterm of @l2@ at @p@ with @l1@. The first component replaces that
-- subterm with @r1σ@; the last component is @r2σ@.
-- 'almostParallelClosed' and 'almostDevelopmentClosed' use the position to
-- tell root overlaps (@p == []@) apart from inner ones.
cpWithPositions :: TRS -> [(Term, Position, Term)]
cpWithPositions trs =
  [ (replace (Term.substitute l2 sigma) (Term.substitute r1 sigma) p, p,
     Term.substitute r2 sigma)
  | rule2 <- trs,
    -- presumably renames variables apart (prefix "y" vs "x") so the two
    -- rules share no variables -- NOTE(review): confirm Rule.rename
    let (l2, r2) = Rule.rename "y" 1 rule2,
    rule1 <- trs,
    let (l1, r1) = Rule.rename "x" 1 rule1,
    p <- functionPositions l2,
    let l2p = subtermAt l2 p,
    -- only unifiable overlaps yield a pair
    Just sigma <- [mgu l2p l1],
    -- drop the trivial root overlap of a rule with a variant of itself
    p /= [] || not (Rule.variant rule1 rule2) ]

-- | Almost-parallel-closedness of one positioned critical pair @(s, pos, t)@.
--
-- * Root overlap (@pos@ empty): some term is both a @k@-bounded reduct of
--   @s@ and a parallel-step reduct of @t@.
-- * Any other overlap: @t@ is a parallel-step reduct of @s@.
almostParallelClosed1 :: Int -> TRS -> (Term, Position, Term) -> Bool
almostParallelClosed1 k trs (s, pos, t)
  | null pos  = not (null (at_most_k_steps k trs s `intersect` parallelStep trs t))
  | otherwise = t `elem` parallelStep trs s

-- | Every positioned critical pair passes 'almostParallelClosed1'.
almostParallelClosed :: Int -> TRS -> Bool
almostParallelClosed k trs =
  and (map (almostParallelClosed1 k trs) (cpWithPositions trs))

-- | Almost parallel closedness criterion (Toyama 1988).
--
-- Answers YES when the TRS is left-linear and 'almostParallelClosed' holds
-- for bound @k@; otherwise it answers MAYBE.
almostParallelClosednessCriterion :: Int -> Criterion
almostParallelClosednessCriterion k trs =
  return $
    if TRS.leftLinear trs && almostParallelClosed k trs
      then YES proof
      else MAYBE
  where
    proof =
      Proof.proof
        "almost parallel closedness (Toyama 1988)"
        trs
        "R is left-linear and every critical pair is almost parallel closed"

-- Development closedness (van Oostrom 1997)

-- | For every critical pair @(s, t)@, @t@ is one of the multistep reducts
-- of @s@.
developmentClosed :: TRS -> Bool
developmentClosed trs =
  and [ t `elem` multistep trs s | (s, t) <- cp trs ]

-- | Development closedness criterion (van Oostrom 1997).
--
-- Answers YES when the TRS is left-linear and 'developmentClosed'
-- holds; otherwise it answers MAYBE.
developmentClosednessCriterion :: Criterion
developmentClosednessCriterion trs
  | applicable = return (YES proof)
  | otherwise  = return MAYBE
  where
    applicable = TRS.leftLinear trs && developmentClosed trs
    proof =
      Proof.proof
        "development closedness (van Oostrom 1997)"
        trs
        "R is left-linear and every critical pair is development closed"

-- Almost development closedness (van Oostrom 1997)

-- | Almost-development-closedness of one positioned critical pair
-- @(s, pos, t)@.
--
-- * Root overlap (@pos@ empty): some term is both a @k@-bounded reduct of
--   @s@ and a multistep reduct of @t@.
-- * Any other overlap: @t@ is a multistep reduct of @s@.
almostDevelopmentClosed1 :: Int -> TRS -> (Term, Position, Term) -> Bool
almostDevelopmentClosed1 k trs (s, pos, t)
  | null pos  = not (null (at_most_k_steps k trs s `intersect` multistep trs t))
  | otherwise = t `elem` multistep trs s

-- | Every positioned critical pair passes 'almostDevelopmentClosed1'.
almostDevelopmentClosed :: Int -> TRS -> Bool
almostDevelopmentClosed k trs =
  and (map (almostDevelopmentClosed1 k trs) (cpWithPositions trs))

-- | Almost development closedness criterion (van Oostrom 1997).
--
-- Answers YES when the TRS is left-linear and 'almostDevelopmentClosed'
-- holds for bound @k@; otherwise it answers MAYBE.
almostDevelopmentClosednessCriterion :: Int -> Criterion
almostDevelopmentClosednessCriterion k trs =
  return $
    if TRS.leftLinear trs && almostDevelopmentClosed k trs
      then YES proof
      else MAYBE
  where
    proof =
      Proof.proof
        "almost development closedness (van Oostrom 1997)"
        trs
        "R is left-linear and every critical pair is almost development closed"


-- | The first element of @ts@ that also occurs in @us@, if there is one.
common :: [Term] -> [Term] -> Maybe Term
common ts us = foldr (\v _ -> Just v) Nothing (ts `intersect` us)

-- | A witness @v@ that is both a @k@-bounded reduct of @t@ and a target of
-- 'restrictedParallelStep' (restricted by @xs@) from @u@.
closed1 :: Int -> TRS -> (Term, [String], Term) -> Maybe Term
closed1 k trs (t, xs, u) = common fromT fromU
  where
    fromT = at_most_k_steps k trs t
    fromU = map snd (restrictedParallelStep trs xs u)

-- | Close every parallel critical pair with 'closed1'.
--
-- Returns the valleys @(t, v, u)@ in input order, or Nothing as soon as one
-- pair cannot be closed.
allClosed1 :: Int -> TRS -> [(Term, (TRS, Rule), [String], Term)] -> Maybe [(Term, Term, Term)]
allClosed1 k trs = mapM valley
  where
    valley (t, _, xs, u) = fmap (\v -> (t, v, u)) (closed1 k trs (t, xs, u))

-- | A witness @v@ that is both a parallel-step reduct of @t@ and a
-- @k@-bounded reduct of @u@.
closed2 :: Int -> TRS -> (Term, Term) -> Maybe Term
closed2 k trs (t, u) = common viaParallel viaSteps
  where
    viaParallel = parallelStep trs t
    viaSteps = at_most_k_steps k trs u

-- | Close every ordinary critical pair with 'closed2'.
--
-- Returns the valleys @(t, v, u)@ in input order, or Nothing as soon as one
-- pair cannot be closed.
allClosed2 :: Int -> TRS -> [(Term, Term)] -> Maybe [(Term, Term, Term)]
allClosed2 k trs = mapM valley
  where
    valley (t, u) = fmap (\v -> (t, v, u)) (closed2 k trs (t, u))

-- | Parallel closedness based on parallel critical pairs (Toyama 1981).
--
-- For a left-linear TRS this answers YES when both of the following hold;
-- otherwise it answers MAYBE.
--
-- * Every ordinary critical pair @(t, u)@ is closed by 'closed2': some @v@
--   is a parallel-step reduct of @t@ and a @k@-bounded reduct of @u@.
-- * Every parallel critical pair @(t, _, xs, u)@ is closed by 'closed1':
--   some @v@ is a @k@-bounded reduct of @t@ and a restricted parallel-step
--   reduct of @u@.
--
-- The proof text prints each valley with arrows in the direction in which
-- it was actually found. Earlier versions had the two templates swapped.
toyama81 :: Int -> Criterion
toyama81 k trs
  | TRS.leftLinear trs,
    Just valleys2 <- allClosed2 k trs (cp trs),
    Just valleys1 <- allClosed1 k trs (pcp trs) =
    return (YES (proof valleys1 valleys2))
  | otherwise = return MAYBE
  where
    proof v1 v2 =
      Proof.proof
        "parallel closedness based on parallel critical pair (Toyama 1981)"
        trs
        (unlines
          [ "R is left linear and"
          , "The parallel critical pairs <-||- -root-> are closed as follows:\n"
          -- closed1: t reaches v in at most k steps, u reaches v in one
          -- (restricted) parallel step
          , unlines [ "  " ++ show t ++ " ->^* " ++ show v ++ " <-||- " ++ show u
                    | (t, v, u) <- v1 ]
          , ""
          , "The ordinary critical pairs <- -root-> are closed as follows:\n"
          -- closed2: t reaches v in one parallel step, u reaches v in at
          -- most k steps
          , unlines [ "  " ++ show t ++ " -||-> " ++ show v ++ " <-^* " ++ show u
                    | (t, v, u) <- v2 ]])


-- temporary definition: criterion for reversible systems
-- | Heuristically select a reversible subset of @trs@.
--
-- 1. For each rule @l -> r@, every entry @(u, rs)@ of
--    @at_most_k_walks k trs r@ with @u == l@ is recorded. Here @rs@ is
--    presumably the set of rules used by that walk -- NOTE(review): confirm.
-- 2. An entry is discarded when the same rule has another entry whose rule
--    set is a proper subset (via 'subset').
-- 3. 'reversibleFixed' then prunes the remaining entries.
reversible :: Int -> TRS -> TRS
reversible k trs = reversibleFixed (filter isMinimal dependMap)
  where
    -- (rule, rules used to get from its rhs back to its lhs)
    dependMap = [((l, r), rs) | (l, r) <- trs, rs <- walksBack r l]
    walksBack src dst = [rs | (u, rs) <- at_most_k_walks k trs src, u == dst]
    -- keep an entry unless the same rule has a strictly smaller witness
    isMinimal (lr, rs) = not (any (improves lr rs) dependMap)
    improves lr rs (lr', rs') = lr == lr' && strictlyInside rs' rs
    strictlyInside xs ys = xs `subset` ys && not (ys `subset` xs)

-- | Prune the dependency entries to a fixpoint.
--
-- Each round drops every entry whose used rules are not all among the rules
-- that still appear as keys. When a round removes nothing, the surviving
-- keys are returned without duplicates.
reversibleFixed :: [(Rule,TRS)] -> TRS
reversibleFixed depend
  | pruned == depend = nub keys
  | otherwise        = reversibleFixed pruned
  where
    keys = map fst depend
    pruned = filter (\(_, used) -> used `subset` keys) depend

-- | Condition 1 of the Aoto–Toyama check for one critical pair @(t, u)@ of
-- @s@.
--
-- Some term is a @k@-bounded @s ++ p@ reduct of @t@ and also a @p@
-- parallel-step reduct of some @k@-bounded @s ++ p@ reduct of @u@.
at12Criterion1 :: Int -> TRS -> TRS -> (Term, Term) -> Bool
at12Criterion1 k s p (t, u) = not (null (fromT `intersect` fromU))
  where
    sp = s ++ p
    fromT = at_most_k_steps k sp t
    fromU = concatMap (parallelStep p) (at_most_k_steps k sp u)

-- | Condition 2 of the Aoto–Toyama check.
--
-- Every overlap reported by @overlap2 p s@ is at the root position.
at12Criterion2 :: TRS -> TRS -> Bool
at12Criterion2 s p =
  all (\(_, pos, _, _) -> null pos) (overlap2 p s)

-- | Condition 3 of the Aoto–Toyama check for one pair @(t, u)@ from
-- @cp2 s p@.
--
-- Let @tSide@ be the @p@ parallel-step reducts of @k@-bounded @s ++ p@
-- reducts of @t@. Either @u@ is in @tSide@, or @tSide@ meets the
-- @k@-bounded @s ++ p@ reducts of the one-step @s@ reducts of @u@.
at12Criterion3 :: Int -> TRS -> TRS -> (Term, Term) -> Bool
at12Criterion3 k s p (t, u) =
  any (== u) tSide || not (null (tSide `intersect` uSide))
  where
    sp    = s ++ p
    tSide = concatMap (parallelStep p) (at_most_k_steps k sp t)
    uSide = concatMap (at_most_k_steps k sp) (reducts s u)

-- | All three Aoto–Toyama conditions, checked left to right with
-- short-circuiting.
at12Criterion :: Int -> TRS -> TRS -> Bool
at12Criterion k s p =
  and [ all (at12Criterion1 k s p) (cp s)
      , at12Criterion2 s p
      , all (at12Criterion3 k s p) (cp2 s p) ]

-- | Criterion for TRSs with a reversible part (Aoto and Toyama 2012).
--
-- For a left-linear @trs0@:
--
-- * @P = reversible k trs0@ and @R = trs0 \\\\ P@.
-- * The termination tool @tt@ is asked about @R@ and @P@.
-- * Only if it answers YES is 'at12Criterion' checked.
--
-- Every other outcome yields MAYBE.
aotoToyama12 :: RT.RTermination -> Int -> Criterion
aotoToyama12 tt k trs0
  | not (TRS.leftLinear trs0) = return MAYBE
  | otherwise = do
      m <- tt trs ps
      case m of
        YES _ | at12Criterion k trs ps -> return (YES proof)
        YES _ -> return MAYBE
        NO _  -> return MAYBE
        MAYBE -> return MAYBE
  where
    ps = reversible k trs0
    trs = trs0 \\ ps
    proof =
      Proof.proof
       "reversible TRS (Aoto and Toyama 2012)"
       trs0
       (unlines
         [ "Let P be the following reversible subset of R:\n"
         , TRS.showTRS ps
         , "R and P satisfies conditions of Theorem 3.18" ])

-- | Reduction-preserving completion variant of 'aotoToyama12'.
--
-- The non-reversible part @trs0 \\\\ P@ is first transformed by
-- 'reductionPreserving'. The result is then checked like 'aotoToyama12':
-- the termination query, then 'at12Criterion'.
aotoToyama12Completion :: RT.RTermination -> Int -> Criterion
aotoToyama12Completion tt k trs0
  | not (TRS.leftLinear trs0) = return MAYBE
  | otherwise = do
      m <- tt trsC ps
      case m of
        YES _ | at12Criterion k trsC ps -> return (YES proof)
        YES _ -> return MAYBE
        NO _  -> return MAYBE
        MAYBE -> return MAYBE
  where
    ps = reversible k trs0
    trsC = reductionPreserving k (trs0 \\ ps) ps
    proof =
      Proof.proof
       "reduction preserving completion (Aoto and Toyama 2012)"
       trs0
       (unlines
         [ "Let P be the following reversible TRS reduced from R:\n"
         , TRS.showTRS ps
         , "R and P satisfies conditions of Theorem 3.18" ])

-- | Rewrite both sides of the rules in @trs@ by parallel @ps@-steps while
-- keeping only rules that preserve reduction.
--
-- * Replacements: each @l -> r@ gives @l -> r'@ for every @ps@
--   parallel-step reduct @r'@ of @r@.
-- * For each replacement @l -> r@, a candidate @l' -> r@ is formed for
--   every @ps@ parallel-step reduct @l'@ of @l@.
-- * A candidate is kept when it is itself a replacement, or when @r@ is
--   among the @k@-bounded @trs@ reducts of @l'@.
--
-- Duplicates are removed.
reductionPreserving :: Int -> TRS -> TRS -> TRS
reductionPreserving k trs ps = nub (filter preserved candidates)
  where
    replacement = [(l, r') | (l, r) <- trs, r' <- parallelStep ps r]
    candidates = [(l', r) | (l, r) <- replacement, l' <- parallelStep ps l]
    preserved rule@(l', r) =
      rule `elem` replacement || r `elem` at_most_k_steps k trs l'

-- parallel critical pair closing system
-- | The two end terms of each parallel critical pair, with the middle
-- components discarded.
pcpPair :: TRS -> [(Term, Term)]
pcpPair trs = map (\(t, _, _, u) -> (t, u)) (pcp trs)

-- | Parallel-critical-pair-closing systems, with this same criterion used
-- recursively to prove each candidate subsystem confluent.
--
-- * The empty TRS is answered directly.
-- * A non-left-linear TRS yields MAYBE.
-- * Otherwise every entry of @PRL.power trs@ whose length differs from the
--   TRS's is tried as a closing system.
selfParallelCriticalPairClosing :: Int -> Criterion
selfParallelCriticalPairClosing _ [] =
  return (YES "# emptiness\n\nThe empty TRS is confluent.\n")
selfParallelCriticalPairClosing k trs =
  if TRS.leftLinear trs
    -- 'parallelCriticalPairClosing' would first apply the sub-criterion to
    -- trs itself. Here that sub-criterion is this very function, so it would
    -- loop forever; the candidate search is therefore started directly.
    then parallelCriticalPairClosing0 k self trs (pcpPair trs) candidates
    else return MAYBE
  where
    self = selfParallelCriticalPairClosing k
    -- (an earlier variant built candidates from parallelClosingSystems)
    candidates = filter ((/= size) . length) (PRL.power trs)
    size = length trs

-- | Parallel-critical-pair-closing systems with sub-criterion @cr@.
--
-- @cr@ is first tried on the whole TRS. If it does not answer YES, every
-- entry of @PRL.power trs@ whose length differs from the TRS's is searched
-- as a closing system.
parallelCriticalPairClosing :: Int -> Criterion -> Criterion
parallelCriticalPairClosing k cr trs =
  cr trs >>= \m ->
    case m of
      YES _ -> return m
      _     -> parallelCriticalPairClosing0 k cr trs (pcpPair trs) candidates
  where
    -- (an earlier variant built candidates from parallelClosingSystems)
    candidates = [rs | rs <- PRL.power trs, length rs /= size]
    size = length trs

-- | Search the candidate closing systems in order.
--
-- A candidate @cs@ is accepted when two conditions hold:
--
-- * every pair in @ps@ is 'joinable' within @k@ steps using @cs@, and
-- * @cr@ proves @cs@ confluent.
--
-- The first accepted candidate yields YES with a sub-proof. A
-- non-left-linear TRS, or an exhausted candidate list, yields MAYBE.
--
-- Left-linearity does not depend on the candidate, so it is checked once up
-- front instead of once per candidate. The candidate list is typically a
-- power set.
parallelCriticalPairClosing0 :: Int -> Criterion -> TRS -> [(Term, Term)] -> [TRS] -> IO Result
parallelCriticalPairClosing0 k cr trs ps candidates
  | TRS.leftLinear trs = search candidates
  | otherwise          = return MAYBE
  where
    search [] = return MAYBE
    search (cs : css)
      | all (joinable k cs) ps = do
          m <- cr cs
          case m of
            YES x -> return (YES (proof cs x))
            _     -> search css
      | otherwise = search css
    proof cs crproof =
      Proof.subproof
        "parallel critical pair closing system (Shintani and Hirokawa 2022)"
        trs
        cs
        "Since R is left-linear and all parallel critical pairs are joinable by confluent C,"
        crproof

-- | @s@ and @t@ have a common reduct within @k@ steps of @trs@ from each
-- side.
joinable :: Int -> TRS -> (Term,Term) -> Bool
joinable k trs (s,t) = not (null (reachS `intersect` reachT))
  where
    reachS = at_most_k_steps k trs s
    reachT = at_most_k_steps k trs t

-- parallelClosingSystems :: Int -> TRS -> [TRS]
-- parallelClosingSystems k trs =
--   mergeRules [ joinRules k trs t u | (t, _, _, u) <- pcp trs ]

-- mergeRules :: [[TRS]] -> [TRS]
-- mergeRules [] = []
-- mergeRules ([rss]) = rss
-- mergeRules (rss:rsss) =
--   [nub (rs ++ trs) | rs <- rss, trs <- mergeRules rsss]
-- 
-- joinRules :: Int -> TRS -> Term -> Term -> [TRS]
-- joinRules k trs t u =
--   nub [nub (trs1++trs2) |
--     (v,trs1) <- minimal_at_most_k_walks k trs t,
--     (w,trs2) <- minimal_at_most_k_walks k trs u,
--     v == w]
-- 
-- minimal_at_most_k_walks :: Int -> TRS -> Term -> [(Term,[Rule])]
-- minimal_at_most_k_walks k trs t =
--   minimalBy f (at_most_k_walks k trs t)
--   where
    -- f (x,xs) (y,ys) = x == y && subset xs ys

-- minimalBy :: (a -> a -> Bool) -> [a] -> [a]
-- minimalBy _ [] = []
-- minimalBy f (xs:xss)
--   | any (\ys -> f ys xs) yss = yss
--   | otherwise = xs : yss
--   where
--     yss = minimalBy f [ys | ys <- xss, not (f xs ys)]

-- Menu
-- | Look up a confluence criterion by the name used in composite options
-- such as @-cps-<X>@.
--
-- @tt@ is the termination tool, @smt@ the SMT solver name and @k@ the step
-- bound passed to the chosen criterion. An unknown name raises an error
-- that lists 'supportedCriteria'.
criterion :: String -> RT.RTermination -> String -> Int -> Criterion
criterion thm tt smt k = case thm of
  "rl"         -> RL.ruleLabeling smt k
  "prl"        -> PRL.parallelRuleLabeling smt k
  "cps"        -> CPS.hm11 tt k
  "pcps"       -> PCPS.hm11 tt k
  "at12"       -> aotoToyama12 tt k
  "at12C"      -> aotoToyama12Completion tt k
  "sc"         -> strongClosednessCriterion k
  "orthogonal" -> orthogonal
  "empty"      -> emptyTRS
  _            -> error emsg
  where
    emsg = unlines
      [ "unknown confluence criterion: " ++ thm
      , "supported criteria:"
      ,  unwords supportedCriteria ]

-- | Criterion names accepted by 'criterion', in the order shown to users.
supportedCriteria :: [String]
supportedCriteria =
  words "rl prl cps pcps at12 at12C sc orthogonal empty"

-- | Trivial criterion: the empty TRS is confluent; any other TRS gets
-- MAYBE.
emptyTRS :: Criterion
emptyTRS trs
  | null trs  = return (YES "# emptiness\n\nThe empty TRS is confluent.\n")
  | otherwise = return MAYBE

-- | Orthogonality criterion.
--
-- Answers YES when the TRS is left-linear and 'TRS.cp' yields no critical
-- pairs; otherwise it answers MAYBE.
orthogonal :: Criterion
orthogonal trs
  -- 'null' tests for emptiness directly; '== []' would also demand an Eq
  -- instance on the critical-pair elements.
  | TRS.leftLinear trs && null (TRS.cp trs)
    = return (YES proof)
  | otherwise
    = return MAYBE
  where
    proof =
      Proof.proof "orthogonality" trs "The left-linear and non-overlapping TRS is confluent"

-- | Critical pair system ('CPS.sh22') with the criterion named @arg@ as its
-- sub-criterion.
cpsWith :: RT.RTermination -> String -> String -> Int -> Criterion
cpsWith tt smt arg k = CPS.sh22 tt k inner
  where
    inner = criterion arg tt smt k

-- | Parallel critical pair system ('PCPS.sh22') with the criterion named
-- @arg@ as its sub-criterion.
pcpsWith :: RT.RTermination -> String -> String -> Int -> Criterion
pcpsWith tt smt arg k = PCPS.sh22 tt k inner
  where
    inner = criterion arg tt smt k

-- | Parallel-critical-pair-closing systems with the criterion named @arg@
-- as the sub-criterion for candidate subsystems.
pcpcsWith :: RT.RTermination -> String -> String -> Int -> Criterion
pcpcsWith tt smt arg k = parallelCriticalPairClosing k inner
  where
    inner = criterion arg tt smt k

-- | Parallel rule labeling ('PRL.sh22') with the criterion named @arg@ as
-- its sub-criterion.
prlWith :: RT.RTermination -> String -> String -> Int -> Criterion
prlWith tt smt arg k = PRL.sh22 smt k inner
  where
    inner = criterion arg tt smt k


-- main
-- | Print the usage text.
--
-- Lists every option 'parseArgs' accepts (including @-pcps@, @-at12@ and
-- @-at12C@, which were previously missing) and the criterion names allowed
-- for @<X>@.
help :: IO ()
help = do
  putStrLn "hakusan [option] <file.trs>"
  putStrLn "Options:"
  putStrLn "  -orthogonal    orthogonality (Rosen 1973)"
  putStrLn "  -sc <k>        strong closedness (Huet 1980)"
  putStrLn "  -pc            parallel closedness (Huet 1980)"
  putStrLn "  -apc <k>       almost parallel closedness (Toyama 1988)"
  putStrLn "  -dc            development closedness (van Oostrom 1997)"
  putStrLn "  -adc <k>       almost development closedness (van Oostrom 1997)"
  putStrLn "  -toyama81 <k>  enhanced parallel closedness (Toyama 1981)"
  putStrLn "  -orl <k>       rule labeling with one labeling function (van Oostrom 2008)"
  putStrLn "  -rl <k>        rule labeling (van Oostrom 2008)"
  putStrLn "  -oprl <k>      parallel rule labeling with one labeling function (Zankl et al. 2015)"
  putStrLn "  -prl <k>       parallel rule labeling (Zankl et al. 2015)"
  putStrLn "  -cps <k>       critical pair system (Hirokawa and Middeldorp 2011)"
  putStrLn "  -pcps <k>      parallel critical pair system"
  putStrLn "  -at12 <k>      reversible TRS (Aoto and Toyama 2012)"
  putStrLn "  -at12C <k>     reduction preserving completion (Aoto and Toyama 2012)"
  putStrLn "  -pcpcs <k>     parallel-critical-pair-closing system (Shintani and Hirokawa 2022)"
  putStrLn "  -prl-<X> <k>   parallel rule labeling with criterion <X>"
  putStrLn "  -cps-<X> <k>   critical pair system with criterion <X>"
  putStrLn "  -pcps-<X> <k>  parallel critical pair system with criterion <X>"
  putStrLn "  -pcpcs-<X> <k> parallel-critical-pair-closing system with criterion <X>"
  putStrLn "  -tt <s>        uses termination tool <s>"
  putStrLn "  -smt <s>       uses SMT solver <s>"
  putStrLn ("where <X> ::= " ++ intercalate " | " supportedCriteria)

-- | Command-line settings accumulated by 'parseArgs'.
data Config = Config
  { _criterion   :: Criterion  -- ^ criterion run on the input TRS
  , _smt         :: String     -- ^ SMT solver name passed to the labeling criteria
  , _termination :: String     -- ^ termination tool name passed to 'RT.sn'
  }

-- | Parse the command line and run the selected confluence criterion on the
-- TRS file given as the last argument.
--
-- * @-smt <s>@ and @-tt <s>@ are collected from the whole argument list
--   before any criterion is built, and the last occurrence wins. They
--   therefore take effect whether they appear before or after the
--   criterion option. Previously, a setting given after the criterion was
--   silently ignored.
-- * The usage text is printed for any of these:
--   - a criterion option whose @<k>@ is not an integer (this used to crash
--     inside 'read'),
--   - an unknown option,
--   - a missing file argument.
parseArgs :: Config -> [String] -> IO ()
parseArgs c0 args0 = go configured args0
  where
    configured = settings c0 args0
    smt = _smt configured
    tt = RT.sn (_termination configured)

    -- First pass: apply every "-smt <s>" / "-tt <s>" pair.
    settings :: Config -> [String] -> Config
    settings c ("-smt" : s : args) = settings (c { _smt = s }) args
    settings c ("-tt"  : s : args) = settings (c { _termination = s }) args
    settings c (_ : args)          = settings c args
    settings c []                  = c

    -- Second pass: select the criterion, then run it on the file.
    go :: Config -> [String] -> IO ()
    go c [file] = do
      spec <- readSpecFile file
      case spec of
        Left e    -> putStrLn (show e)
        Right trs -> do
          result <- _criterion c trs
          putStrLn (show result)
    -- already handled by the first pass
    go c ("-smt" : _ : args) = go c args
    go c ("-tt"  : _ : args) = go c args
    -- options without a bound
    go c ("-orthogonal" : args) = go (c { _criterion = orthogonal }) args
    go c ("-pc" : args) = go (c { _criterion = parallelClosednessCriterion }) args
    go c ("-dc" : args) = go (c { _criterion = developmentClosednessCriterion }) args
    -- options followed by an integer bound <k>
    go c (opt : s : args)
      | [(k, "")] <- reads s,
        (mk : _) <- withBound opt =
        go (c { _criterion = mk k }) args
    go _ _ = help

    -- Criteria selected by @opt@. Exact names are tried first, then the
    -- composable "-<thm>-<X>" prefixes, in the same order as before.
    withBound :: String -> [Int -> Criterion]
    withBound opt =
      [ mk | (name, mk) <- indexed, name == opt ] ++
      [ thm tt smt (drop (length prefix) opt)
      | (prefix, thm) <- composable, prefix `isPrefixOf` opt ]

    indexed :: [(String, Int -> Criterion)]
    indexed =
      [ ("-sc",       strongClosednessCriterion)
      , ("-apc",      almostParallelClosednessCriterion)
      , ("-adc",      almostDevelopmentClosednessCriterion)
      , ("-toyama81", toyama81)
      , ("-orl",      RL.ordinaryRuleLabeling smt)
      , ("-rl",       RL.ruleLabeling smt)
      , ("-prl",      PRL.parallelRuleLabeling smt)
      , ("-oprl",     PRL.ordinaryParallelRuleLabeling smt)
      , ("-cps",      CPS.hm11 tt)
      , ("-at12",     aotoToyama12 tt)
      , ("-at12C",    aotoToyama12Completion tt)
      , ("-pcps",     PCPS.hm11 tt)
      , ("-pcpcs",    selfParallelCriticalPairClosing)
      ]

    composable :: [(String, RT.RTermination -> String -> String -> Int -> Criterion)]
    composable =
      [ ("-prl-",   prlWith)
      , ("-cps-",   cpsWith)
      , ("-pcps-",  pcpsWith)
      , ("-pcpcs-", pcpcsWith)
      ]

-- | Entry point: run 'parseArgs' on the process arguments.
--
-- The starting configuration is strong closedness with bound 5, the @z3@
-- SMT solver and @NaTT.exe@ as the termination tool.
main :: IO ()
main = getArgs >>= parseArgs defaults
  where
    defaults =
      Config
        { _criterion   = strongClosednessCriterion 5
        , _smt         = "z3"
        , _termination = "NaTT.exe"
        }
