module Main where

import Data.List
import Term
import Rule
import TRS
import Rewriting
import ARIParser
import qualified SN
import PCP
import RL
import CPS
import KH12
import qualified Fork
import qualified TA
import Proof
import CPF
import System.Environment

-- | How results are printed: human-readable text (the default), or a CPF
-- certificate for CeTA (selected with @-cpf@).
data OutputFormat = Text | CPF deriving Eq

-- | One analysis step tried on a TRS.  Both rule-removal methods and
-- non-confluence checks have this shape.  'Nothing' means the step did not
-- apply.  'Just' carries the proof steps together with the resulting status:
-- 'Proved', 'Disproved', or a remaining 'Subgoal' subsystem.
type RuleRemoval = TRS -> IO (Maybe Proof)

-- | Which (non-)confluence criteria are enabled.  Each field is switched on
-- by the command-line flag noted next to it.
data Options = Options {
  _tcap :: Bool, -- ^ @-TCAP@: non-confluence test by tcap
  _ta   :: Bool, -- ^ @-TA@: non-confluence test by tree automata
  _rl   :: Bool, -- ^ @-R@: rule removal by rule labeling
  _cps  :: Bool, -- ^ @-C@: rule removal by critical pair systems
  _kh12 :: Bool, -- ^ @-K@: rule removal by Klein and Hirokawa's criterion
  _kb70 :: Bool  -- ^ @-KB@: Knuth and Bendix' criterion
}

-- | Version string, printed by @-v@ and shown in the help text.
version :: String
version = "0.12"

-- | Print usage information: synopsis, version, the option list, and the
-- criteria that are enabled when no criterion flag is given.
--
-- The default lines must agree with 'default_options' (plain run) and
-- 'cpf_options' (run with @-cpf@).  Both of these enable the tree-automata
-- test, so @-TA@ is listed.
help :: IO ()
help = do
  putStrLn "hakusan [option]* [file.ari]"
  putStrLn ("version " ++ version)
  putStrLn ""
  putStrLn "Options:"
  putStrLn "  -TCAP     non-confluence test by tcap"
  putStrLn "  -TA       non-confluence test by tree automata techniques"
  putStrLn "  -R        rule removal by rule labeling (Shintani and Hirokawa 2024)"
  putStrLn "  -C        rule removal by critical pair systems (Shintani and Hirokawa 2024)"
  putStrLn "  -K        rule removal by Klein and Hirokawa's criterion (2012)"
  putStrLn "  -KB       rule removal by Knuth and Bendix' criterion (1970)"
  putStrLn "  -smt <s>  uses SMT solver <s>"
  putStrLn "  -k <k>    maximum length of join sequences ->^{<= k} . <-^{<= k}"
  putStrLn "  -cpf      outputs a CPF3 certificate for CeTA"
  putStrLn "  -l        relative termination by linear polynomial interpretations"
  putStrLn "  -m        relative termination by matrix interpretations with {0,1}-coefficients"
  putStrLn "  -loop     non-termination check by finding a loop"
  putStrLn "  -u        relative termination by upper triangular matrix interpretations with {0,1}-coefficients"
  putStrLn "  -v        version"
  putStrLn "  -h        help"
  putStrLn ""
  putStrLn "Default is: hakusan -smt z3 -k 5 -TCAP -TA -R -C -K -KB"
  putStrLn "With -cpf, the default is: hakusan -cpf -smt z3 -k 5 -TCAP -TA -R -C"

-- | Render the final answer in the requested output format.
--
-- Text output always uses the human-readable proof.  So does a CPF run that
-- ended with an unresolved 'Subgoal', since there is nothing to certify.
-- Otherwise the status line is followed by a CPF certificate for the input
-- system.
show_result :: OutputFormat -> Signature -> TRS -> Proof -> String
show_result format sig trs proof
  | format == Text      = show_proof proof
  | Subgoal _ <- status = show_proof proof
  | otherwise           =
      Proof.show_result status ++ "\n" ++
      CPF.show_certificate (trs, sig, trs, proof)
  where
    status = snd proof

-- | Run the criteria in order on the same TRS and return the result of the
-- first one that yields 'Just'.  Later criteria are not run once one
-- succeeds.  Returns 'Nothing' if every criterion fails.
apply :: [RuleRemoval] -> TRS -> IO (Maybe Proof)
apply criteria trs = foldr attempt (return Nothing) criteria
  where
    -- Run one criterion; on failure fall through to the remaining ones.
    attempt criterion fallback = criterion trs >>= maybe fallback (return . Just)

-- | Apply criteria repeatedly until the system is settled.
--
-- The empty system is trivially 'Proved'.  When no criterion applies, the
-- current system is returned as an unresolved 'Subgoal'.  A criterion that
-- yields a smaller subsystem is followed by a recursive call on that
-- subsystem, and the proof steps are concatenated.
--
-- The 'Bool' controls whether a 'Disproved' result may be reported.  When it
-- may not, the current system is returned as an unresolved subgoal instead.
-- For subsystems, disproofs are only allowed in 'Text' mode.
loop :: OutputFormat -> Bool -> [RuleRemoval] -> TRS -> IO Proof
loop _ _ _ [] = return ([Emptiness], Proved)
loop output_format disprovable criteria trs =
  apply criteria trs >>= maybe stuck proceed
  where
    stuck = return ([], Subgoal trs)

    proceed result@(steps, status) =
      case status of
        Proved -> return result
        Disproved -> if disprovable then return result else stuck
        Subgoal subsystem -> do
          (more_steps, final_status) <-
            loop output_format (output_format == Text) criteria subsystem
          return (steps ++ more_steps, final_status)

-- | Read a problem file and run the enabled confluence criteria on it,
-- printing the result in the requested format.  Problems that declare
-- relative rules are rejected with an error.
analyze_confluence :: OutputFormat -> String -> Int -> Options -> String -> IO ()
analyze_confluence output_format smt k options file = do
  (sig, trs, relative) <- read_file file
  if null relative
    then do
      let criteria = criteria_from smt k options sig
      proof <- loop output_format True criteria trs
      putStr (Main.show_result output_format sig trs proof)
    else error "Relative rewriting is not supported."

-- | Try to prove termination of R relative to S, where both are read from
-- the file, using the given interpretation order.
--
-- Prints @MAYBE@ on failure.  On success it prints @YES@ followed by either
-- a CPF termination certificate or, after a blank line, the textual proof.
prove_termination :: OutputFormat -> String -> SN.OrderName -> String -> IO ()
prove_termination format smt order file = do
  (_, trs_R, trs_S) <- read_file file
  result <- SN.terminating_with order smt trs_R trs_S
  maybe (putStrLn "MAYBE") (report trs_R trs_S) result
  where
    report r s sn_proof = do
      putStrLn "YES"
      putStr $ case format of
        CPF  -> show_terminationCertificate (r ++ s, r, s, sn_proof)
        Text -> "\n" ++ show sn_proof

-- | Try to show non-termination of the system in the file.
--
-- An ill-formed rule is reported immediately as @NO@.  Otherwise the
-- system is searched for a loop with bound @k@.  The answer is @NO@
-- followed by the witness @s ->+ t@, or @MAYBE@ when none is found.
prove_non_termination :: Int -> String -> IO ()
prove_non_termination k file = do
  (_, trs, _) <- read_file file
  putStr (verdict trs)
  where
    verdict rules
      | Just rule <- find (not . Rule.well_formed) rules =
          unlines ["NO", "", "ill-formed rule:" ++ show_rule rule]
      | Just (s, t) <- SN.non_terminating k rules =
          unlines ["NO", "", show s ++ " ->+ " ++ show t]
      | otherwise = "MAYBE\n"



-- Knuth and Bendix' criterion (1970)

-- | A fork diverges when the two end terms reported by 'Fork.last_terms'
-- are different.
diverging :: Fork.Fork -> Bool
diverging = uncurry (/=) . Fork.last_terms

-- | Knuth and Bendix' criterion: for a terminating TRS, inspect the forks
-- built from all critical peaks.
--
-- * No termination proof: the criterion does not apply ('Nothing').
-- * No fork diverges: a single step recording the termination proof (with
--   an empty relative part) is emitted, and the result is @Subgoal []@.
--   'loop' then turns this into 'Proved'.
-- * Some fork diverges: that fork is returned as a 'TCAP' witness with
--   status 'Disproved'.
--
-- NOTE(review): this assumes 'Fork.fork_from' rewrites both ends of the peak
-- to normal forms, so that differing end terms imply non-confluence. Confirm
-- in the Fork module.
kb70 :: String -> TRS -> IO (Maybe Proof)
kb70 smt trs = SN.terminating smt trs [] >>= maybe (return Nothing) judge
  where
    forks = [ Fork.fork_from trs cp | cp <- critical_peaks trs ]

    judge sn_proof =
      case find diverging forks of
        Just fork -> return (Just ([TCAP trs fork], Disproved))
        Nothing   -> return (Just ([termination_step sn_proof], Subgoal []))

    termination_step sn_proof =
      KH12 {
        _trsR = trs,
        _trsS = [],
        _relativeTerminationProof = sn_proof
      }

-- Proving non-confluence
{-
non_joinable_by_tcap :: Int -> Options -> TRS -> Fork.Peak -> Maybe Proof
non_joinable_by_tcap k o trs peak
  | _tcap o && Fork.non_joinable_by_tcap trs peak =
      case Fork.reconstruct_fork k trs peak of
        Nothing -> error "disprove_by_tcap"
        Just fork -> Just ([TCAP trs fork], Disproved)
  | otherwise = Nothing

non_joinable_by_ta :: Int -> Options -> TRS -> Fork.Peak -> Maybe Proof
non_joinable_by_ta k o trs peak@(t, _, u)
  | _ta o, Just (ta1, ta2) <- TA.non_joinable trs t u =
      case Fork.reconstruct_fork k trs peak of
        Nothing -> error "non_joinable_by"
        Just fork -> Just ([TA trs fork ta1 ta2], Disproved)
  | otherwise = Nothing

disprove :: Int -> Options -> TRS -> IO (Maybe Proof)
disprove k o trs =
  return
    (asum [ non_joinable peak
          | peak <- peaks, non_joinable <- list ])
  where
    peaks = Fork.peak_candidates k trs
    list = [non_joinable_by_tcap k o trs,
            non_joinable_by_ta   k o trs]
-}

-- | Combined size of the two terms in a pair of (term, automaton) entries.
-- The automata are ignored.
size_pair :: ((Term, TA.TA Int), (Term, TA.TA Int)) -> Int
size_pair ((t, _), (u, _)) = sum (map Term.size [t, u])


-- | Cheap over-approximation of unifiability.
--
-- A variable on either side is treated as matching anything, and repeated
-- variables are not checked for consistency.  The result is therefore
-- 'False' only when the two terms clash at some common position: different
-- function symbols, or the same symbol applied to a different number of
-- arguments.  A 'False' answer thus means the terms are definitely not
-- unifiable.
--
-- The argument counts are compared explicitly because 'zipWith' silently
-- truncates.  Without that check, @f(x)@ vs @f(a, b)@ would wrongly be
-- reported as unifiable.
unifiable_shape :: Term -> Term -> Bool
unifiable_shape (V _) _ = True
unifiable_shape _ (V _) = True
unifiable_shape (F f ss) (F g ts) =
  f == g && length ss == length ts && and (zipWith unifiable_shape ss ts)

-- | Search the candidate source terms in order for a non-confluence witness
-- based on tree automata.
--
-- For each source, the reducts from 'k_step_reducts' are each paired with
-- an automaton from 'TA.rewrite_closure' (with bound @m@).  Reducts for
-- which no automaton is produced are skipped.  If two such automata are
-- disjoint, the fork between the two reducts is reconstructed and returned
-- as a 'TA' disproof.  Function symbols of the source that are missing from
-- the signature are added with arity 0.
-- Calls 'error' if the fork cannot be reconstructed.
disprove_ta' :: Int -> Int -> TRS -> [Term] -> Signature -> Maybe Proof
disprove_ta' _ _ _ [] _ = Nothing
disprove_ta' k m trs (source : rest) sig =
  case find separated (Fork.unordered_pairs automata) of
    Just ((t, dta1), (u, dta2)) ->
      maybe (error "disprove")
            (\fork -> Just ([TA trs fork dta1 dta2], Disproved))
            (Fork.reconstruct_fork k trs (t, source, u))
    Nothing -> disprove_ta' k m trs rest sig
  where
    separated ((_, a1), (_, a2)) = TA.disjoint a1 a2
    extended_sig =
      sig ++ [ (f, 0) | f <- Term.functions source, f `notElem` map fst sig ]
    automata =
      [ (t, dta)
      | t <- k_step_reducts k trs source
      , Just dta <- [TA.rewrite_closure m trs t extended_sig]
      ]

-- | Tree-automata non-confluence check in 'RuleRemoval' shape.  It searches
-- the sources given by 'Fork.sources' with reduct bound @k@ and automaton
-- bound @m@.
disprove_ta :: Int -> Int -> Signature -> TRS -> IO (Maybe Proof)
disprove_ta k m sig trs = pure $ disprove_ta' k m trs (Fork.sources trs) sig

-- | Search the candidate source terms in order for a non-confluence witness
-- based on tcap.
--
-- For each source, its reducts from 'k_step_reducts' are paired with their
-- tcap (duplicates removed).  If two of the capped terms clash according to
-- 'unifiable_shape', the fork between the corresponding reducts is
-- reconstructed and returned as a 'TCAP' disproof.
-- Calls 'error' if the fork cannot be reconstructed.
disprove_tcap' :: Int -> TRS -> [Term] -> Maybe Proof
disprove_tcap' _ _ [] = Nothing
disprove_tcap' k trs (source : rest) =
  maybe (disprove_tcap' k trs rest) witness
        (find clash (Fork.unordered_pairs capped))
  where
    clash ((_, c1), (_, c2)) = not (unifiable_shape c1 c2)
    capped = nub [ (t, Fork.tcap "x" trs t) | t <- k_step_reducts k trs source ]
    witness ((t, _), (u, _)) =
      case Fork.reconstruct_fork k trs (t, source, u) of
        Just fork -> Just ([TCAP trs fork], Disproved)
        Nothing   -> error "disprove"

-- | tcap non-confluence check in 'RuleRemoval' shape.  It searches the
-- sources given by 'Fork.sources' with reduct bound @k@.
disprove_tcap :: Int -> TRS -> IO (Maybe Proof)
disprove_tcap k trs = pure $ disprove_tcap' k trs (Fork.sources trs)



-- command line options

-- | Every criterion switched off.  This is the starting point once any
-- explicit criterion flag appears on the command line.
empty_options :: Options
empty_options =
  Options
    { _tcap = False
    , _ta   = False
    , _rl   = False
    , _cps  = False
    , _kh12 = False
    , _kb70 = False
    }

-- | Every criterion switched on.  Used for a plain run that names only the
-- input file.
default_options :: Options
default_options =
  Options
    { _tcap = True
    , _ta   = True
    , _rl   = True
    , _cps  = True
    , _kh12 = True
    , _kb70 = True
    }

-- | Defaults for @-cpf@ runs.  These are the full defaults minus Klein and
-- Hirokawa's criterion and Knuth and Bendix' criterion.
-- NOTE(review): presumably those two are left out because their proofs
-- cannot be emitted as CPF certificates. Confirm against the CPF module.
cpf_options :: Options
cpf_options = default_options { _kh12 = False, _kb70 = False }

-- | True when rule labeling is enabled but none of the other rule-removal
-- methods (critical pair systems, KH12, KB70) are.  The non-confluence
-- tests (tcap, tree automata) are not considered.  'criteria_from' adds
-- 'RL.reduce' only when this is False.
only_rl :: Options -> Bool
only_rl o = _rl o && not (or [_cps o, _kh12 o, _kb70 o])

-- | Keep the list when the condition holds; otherwise return the empty list.
guard :: Bool -> [a] -> [a]
guard enabled xs
  | enabled   = xs
  | otherwise = []

-- | Build the ordered list of criteria to try from the enabled options.
--
-- Order matters, because 'apply' uses the first criterion that returns
-- 'Just'.  The sequence is:
--
-- 1. disproofs with bound 1;
-- 2. disproofs with bound @k@ (when @k > 1@);
-- 3. Knuth-Bendix;
-- 4. the rule-removal methods.
--
-- For rule labeling and critical pair systems, a run with bound 2 is tried
-- before the run with bound @k@ when @k > 2@.
criteria_from :: String -> Int -> Options -> Signature -> [RuleRemoval]
criteria_from smt k o sig =
  [ criterion | (enabled, criterion) <- table, enabled ]
  where
    table =
      [ (_tcap o,                  disprove_tcap 1)
      , (_ta o,                    disprove_ta 1 10 sig)
      , (_tcap o && k > 1,         disprove_tcap k)
      , (_ta o && k > 1,           disprove_ta k 10 sig)
      , (_kb70 o,                  kb70 smt)
      , (_rl o && k > 2,           RL.reduce_all smt 2)
      , (_rl o,                    RL.reduce_all smt k)
      , (_cps o && k > 2,          CPS.reduce smt 2)
      , (_cps o,                   CPS.reduce smt k)
      , (_kh12 o,                  KH12.reduce smt k)
      , (_rl o && not (only_rl o), RL.reduce smt k)
      ]


-- | Interpret the command line.
--
-- The output format, SMT solver and bound @k@ are threaded through as the
-- first three arguments.  @-cpf@, @-smt@ and @-k@ may be interleaved with
-- the criterion flags.  The 'Maybe Options' stays 'Nothing' until a
-- criterion flag is seen.  A bare file with no criterion flag is analysed
-- with 'default_options', or with 'cpf_options' under @-cpf@.  Once a
-- criterion flag appears, only the explicitly listed criteria are enabled.
-- @-l@, @-m@, @-u@ and @-loop@ must be followed directly by the file.
-- Anything unrecognised prints the help text.
parse_args :: OutputFormat -> String -> Int -> Maybe Options -> [String] -> IO ()
parse_args _ _ _ _ ["-h"] = help
parse_args _ _ _ _ ["-v"] = putStrLn version
parse_args format smt _ _ ["-l", file] =
  prove_termination format smt SN.LinearOrder file
parse_args format smt _ _ ["-m", file] =
  prove_termination format smt SN.MatrixOrder file
parse_args format smt _ _ ["-u", file] =
  prove_termination format smt SN.UpperTriangularMatrixOrder file
parse_args _ smt k m ("-cpf" : args) = 
  parse_args CPF smt k m args
parse_args format _ k m ("-smt" : smt : args) =
  parse_args format smt k m args
parse_args format smt _ m ("-k" : s : args) =
  -- 'reads' instead of the partial 'read': a non-numeric bound used to
  -- crash later with "Prelude.read: no parse".  It is now treated like any
  -- other malformed command line.
  case reads s of
    [(k, "")] -> parse_args format smt k m args
    _         -> help
parse_args _ _ k _ ["-loop", file] =
  prove_non_termination k file
parse_args format smt k (Just o) ("-TCAP" : args) =
  parse_args format smt k (Just o{_tcap = True}) args
parse_args format smt k (Just o) ("-TA" : args) =
  parse_args format smt k (Just o{_ta = True}) args
parse_args format smt k (Just o) ("-R" : args) =
  parse_args format smt k (Just o{_rl = True}) args
parse_args format smt k (Just o) ("-C" : args) =
  parse_args format smt k (Just o{_cps = True}) args
parse_args format smt k (Just o) ("-K" : args) =
  parse_args format smt k (Just o{_kh12 = True}) args
parse_args format smt k (Just o) ("-KB" : args) =
  parse_args format smt k (Just o{_kb70 = True}) args
parse_args Text smt k Nothing [file] =
  analyze_confluence Text smt k default_options file
parse_args CPF smt k Nothing [file] =
  analyze_confluence CPF smt k cpf_options file
parse_args format smt k (Just o) [file] =
  analyze_confluence format smt k o file
-- Not a bare file: switch to explicit-flag mode and retry the same args.
parse_args format smt k Nothing args =
  parse_args format smt k (Just empty_options) args
parse_args _ _ _ _ _ = help

-- | Entry point.  Parse the command line, starting from Text output, the
-- @z3@ SMT solver, bound 5, and no explicit criterion selection.
main :: IO ()
main = getArgs >>= parse_args Text "z3" 5 Nothing
