module RTermination
  ( RTermination
  , sn
  , looping
  , WST
  , toWST
  , showWSTInput
  , theory
  ) where

import Data.List
import System.IO
import System.Process
import Term
import Rule
import TRS
import Result

-- | A relative-termination check: given the strict rules R and the weak
-- rules S, produce a 'Result' for R/S (possibly by running external tools).
type RTermination = TRS -> (TRS -> IO Result)


-- R/S is non-terminating if R is non-empty and S contains a rule of the form
-- l -> C[l sigma] such that x in Var(x sigma) cap Var(C) for some variable x.
-- See Lemma 10 of Hirokawa and Middeldorp (JAR 2011).

-- | @surroundingVariables t p@ lists, without duplicates, the variables of
-- the context left over when the subterm of @t@ at position @p@ is cut out,
-- i.e. the variables of @C@ where @t = C[t|p]@.  The root position yields
-- no variables.  Argument indices are counted from 0 here
-- (NOTE(review): confirm this matches the convention of 'positions' and
-- 'subtermAt' in "Term").  Calls 'error' if @p@ descends into a term that
-- is not a function application.
surroundingVariables :: Term -> Position -> [String]
surroundingVariables _ [] = []
surroundingVariables (F _ args) (k : rest) =
  nub (concat (zipWith varsOutside [0 ..] args))
  where
    -- On the path: recurse; off the path: the whole argument is context.
    varsOutside idx arg
      | idx == k  = surroundingVariables arg rest
      | otherwise = Term.variables arg
surroundingVariables _ _ = error "surroundingVariables"

-- | @loopInstance (l, r, p)@ holds when @l@ matches the subterm of @r@ at
-- position @p@ with some substitution @sigma@ (presumably @l sigma = r|p@,
-- as returned by 'match'), and some variable @x@ of the surrounding context
-- also occurs in @x sigma@.  That is, the rule has the shape
-- @l -> C[l sigma]@ with @x@ in @Var(x sigma) cap Var(C)@.
loopInstance :: (Term, Term, Position) -> Bool
loopInstance (l, r, p) =
  maybe False reachesContext (match l (subtermAt r p))
  where
    reachesContext sigma =
      any (\x -> x `elem` Term.variables (Term.substitute (V x) sigma))
          (surroundingVariables r p)

-- | Cheap, incomplete non-termination test for R/S.
-- 'True' means that R/S is non-terminating; 'False' means "don't know".
-- An empty R is never reported.  Otherwise the test succeeds if some rule
-- @l -> r@ of R satisfies @'Term.subsume' l r@, or if some rule of S has
-- the shape accepted by 'loopInstance' at one of its positions.
looping :: TRS -> TRS -> Bool
looping rs ss
  | null rs   = False
  | otherwise = any selfInstance rs || any loopInstance candidates
  where
    selfInstance (l, r) = Term.subsume l r
    candidates = [ (l, r, p) | (l, r) <- ss, p <- positions r ]

-- | @sn tt@ checks termination of R/S (strict rules @rs@, weak rules @ss@).
-- The cheap 'looping' test runs first.  Only if it is inconclusive is the
-- problem handed to the external prover command @tt@.
sn :: String -> RTermination
sn tt rs ss =
  if looping rs ss
    then return (NO "R/S loops")
    else run tt (toWST rs ss)

-- Building blocks of the input passed to an external termination prover.
type Variable = String
-- | A theory declaration: its kind ("AC", "A" or "C" as produced by
-- 'theory') and the function symbols it applies to.
type Theory   = (String, [String])
-- | A weak rule of S; printed with @->=@ by 'showWSTInput'.
type WeakRule = Rule

-- | Input handed to the external termination prover (rendered by
-- 'showWSTInput'): declared variables, theory declarations, strict rules
-- and weak rules.
type WST = ([Variable], [Theory], [Rule], [WeakRule])

-- | Package strict rules R and weak rules S as prover input.  The variable
-- list covers the variables of both rule sets.
--
-- Theory detection is currently switched off: no theory declarations are
-- produced and every weak rule is passed through unchanged.  An earlier
-- draft computed @(weak', theories) = theory weak@ and passed
-- @weak \\ (strict ++ weak')@ as the weak rules.
toWST :: TRS -> TRS -> WST
toWST strict weak = (allVars, [], strict, weak)
  where
    allVars = TRS.variables (weak ++ strict)

-- | Split off the rules of a TRS that encode AC, A or C axioms and describe
-- them as theory declarations; see 'associativeCommutativeTheory'.
theory :: TRS -> (TRS, [Theory])
theory = associativeCommutativeTheory

-- | Detect AC, A and C axioms among the rules of @trs@.
--
-- Returns the detected axiom rules (deduplicated), together with theory
-- declarations ("AC", "A", "C") naming the symbols involved.  Kinds with no
-- symbols are omitted.  Only symbols at the root of a left-hand side with
-- exactly two arguments are considered.  A and C axioms are searched for
-- only among the rules not already claimed by an AC theory.
associativeCommutativeTheory :: TRS -> (TRS, [Theory])
associativeCommutativeTheory trs =
  (nub (acAxioms ++ aAxioms ++ cAxioms), decls)
  where
    binarySyms = nub [ g | (F g [_, _], _) <- trs ]
    remaining  = trs \\ acAxioms
    acAxioms   = acTheory trs binarySyms
    aAxioms    = aTheory remaining binarySyms
    cAxioms    = cTheory remaining binarySyms
    decls      = filter (not . null . snd)
      [ ("AC", definedSymbols acAxioms)
      , ("A",  definedSymbols aAxioms)
      , ("C",  definedSymbols cAxioms)
      ]

-- | Collect the rules of @trs@ that are variants of AC, A or C axioms for
-- the given symbols.  For AC and A, a candidate axiom set counts only when
-- @trs@ subsumes the whole set ('TRS.subsume'), and the result may contain
-- duplicates.  C axioms are single rules and are matched directly.
acTheory, aTheory, cTheory :: TRS -> [String] -> TRS
acTheory trs syms = subsumedRuleSetVariants trs (acRules syms)
aTheory  trs syms = subsumedRuleSetVariants trs (aRules syms)
cTheory  trs syms =
  [ rule | rule <- trs, [axiom] <- candidates, Rule.variant rule axiom ]
  where
    candidates = cRules syms

-- | For each candidate rule set that @trs@ subsumes as a whole, collect (in
-- order) the rules of @trs@ that are variants of some rule in that set.
subsumedRuleSetVariants :: TRS -> [TRS] -> TRS
subsumedRuleSetVariants trs candidates =
  concatMap variantsIn (filter (TRS.subsume trs) candidates)
  where
    variantsIn set =
      [ rule | rule <- trs, axiom <- set, Rule.variant rule axiom ]

acRules, aRules, aRRules, aLRules, cRules :: [String] -> [TRS]

-- AC: for each symbol f, three candidate axiom sets: the C-rule of f
-- combined with the left A-rule, the right A-rule, or both A-rules.
acRules sym =
  [ cs ++ concat ass
  | f <- sym,
    cs <- cRules [f],
    let aLs = aLRules [f],
    let aRs = aRRules [f],
    ass <- [aLs, aRs, aLs++aRs]]

-- A, one direction: x+(y+z) -> (x+y)+z; one singleton set per symbol.
aLRules sym =
  [[(f[ x, f[y,z]], f[f[x,y], z])] | s <- sym, let f = F s]
    where
      (x,y,z) = (V "x", V "y", V "z")

-- A, other direction: (x+y)+z -> x+(y+z); one singleton set per symbol.
aRRules sym =
  [[(f[f[x,y], z], f[ x, f[y,z]])] | s <- sym, let f = F s]
    where
      (x,y,z) = (V "x", V "y", V "z")

-- A: per symbol, both directions
-- {x+(y+z) -> (x+y)+z, (x+y)+z -> x+(y+z)}.
-- aLRules and aRRules each produce exactly one set per input symbol, in
-- the same order, so zipping pairs the two directions of the same symbol.
-- (The former length-mismatch guard could never fire and only forced both
-- lists.)
aRules sym = zipWith (++) (aLRules sym) (aRRules sym)

-- C: x+y -> y+x; one singleton set per symbol.
cRules sym =
  [[(f[x,y], f[y,x])] | s <- sym, let f = F s]
  where
    (x,y) = (V "x", V "y")


-- | Render prover input as parenthesised sections: @(VAR ...)@, an optional
-- @(THEORY ...)@ (omitted when there are no theories), and @(RULES ...)@.
-- Weak rules are printed first with @->=@; the strict rules follow, as
-- printed by 'TRS.showTRS'.
showWSTInput :: WST -> String
showWSTInput (vs, ts, rs, wrs) =
  unlines [section "VAR" [unwords vs], theorySection, rulesSection]
  where
    section name body = unlines (("(" ++ name) : body ++ [")"])
    theorySection
      | null ts   = ""
      | otherwise = section "THEORY" (map (uncurry section) ts)
    weakRuleLines = map (\(l, r) -> "  " ++ show l ++ " ->= " ++ show r) wrs
    rulesSection = section "RULES" (weakRuleLines ++ lines (TRS.showTRS rs))

-- | Run the external command @tool@ on @input@ and return its
-- (stdout, stderr).  The input is fed through a pipe on the tool's standard
-- input, and the tool receives @/dev/stdin@ as its file argument.
--
-- 'readCreateProcessWithExitCode' writes stdin and drains stdout and stderr
-- concurrently, and waits for the process to finish.  This avoids the
-- deadlock that occurs when the tool fills a pipe buffer while we are still
-- writing its input (or reading the other stream).  It also leaves no zombie
-- process behind, and does not crash if the tool exits without reading all
-- of its input.
execute :: String -> String -> IO (String, String)
execute tool input = do
  -- The exit code is deliberately ignored: the answer is judged from the
  -- tool's output.
  (_exitCode, out, err) <-
    readCreateProcessWithExitCode (proc tool ["/dev/stdin"]) input
  return (out, err)

-- | Print the message @s@ (plus a newline) to standard error, then yield @x@.
writeError :: a -> String -> IO a
writeError x s = x <$ hPutStrLn stderr s

-- | Hand the problem to the external prover @tool@ and interpret the first
-- word of the first line of its standard output: "YES", "NO" or "MAYBE".
--
-- If the first output line contains no word (no output at all, or a
-- blank or whitespace-only first line), the tool's standard error is
-- reported on our standard error and the answer is 'MAYBE'.  An
-- unrecognised answer is likewise reported and treated as 'MAYBE'.
-- Previously these cases ended in @error ""@, or in a @head []@ crash for a
-- whitespace-only line, which aborted the whole run.
run :: String -> WST -> IO Result
run tool input = do
  -- putStrLn ("-- input --\n" ++ showWSTInput input ++ "\n")
  (s, e) <- execute tool (showWSTInput input)
  -- putStrLn ("-- output --\n" ++ s ++ "\n")
  case firstWord s of
    Just "YES"   -> return (YES tool)
    Just "NO"    -> return (NO tool)
    Just "MAYBE" -> return MAYBE
    Just other   ->
      writeError MAYBE
        (unlines [tool ++ " gave an unexpected answer: " ++ other, e])
    Nothing      -> writeError MAYBE (unlines [tool ++ " failed:", e])
  where
    -- First word of the first output line, if that line has one.
    firstWord out = case lines out of
      (l : _) | (w : _) <- words l -> Just w
      _                            -> Nothing
