module SN where

import Data.List
import Term
import Rule
import TRS
import SMT
import Linear
import Matrix
import UpperTriangularMatrix

-- | Selector for the kind of interpretation order that 'terminating_with'
-- should try: each constructor dispatches to the @terminating@ function of
-- the correspondingly named module (Linear, Matrix, UpperTriangularMatrix).
data OrderName = LinearOrder | MatrixOrder | UpperTriangularMatrixOrder

-- | A concrete order, represented by the algebra of the backend module it
-- comes from. Each constructor tags an @Algebra@ value of the module with the
-- same name, so the 'Show' instance of 'SNProof' knows which @show_algebra@
-- to use.
data Order = Linear Linear.Algebra | Matrix Matrix.Algebra | UpperTriangularMatrix UpperTriangularMatrix.Algebra

-- | Certificate for termination of R/S.
--
-- * 'RIsEmpty': produced by 'terminating_with' when the first TRS (R) is the
--   empty list.
-- * 'MonotoneReductionPair': carries the two TRSs (R, then S) together with
--   the 'Order' reported by the chosen backend.
data SNProof = 
    RIsEmpty
  | MonotoneReductionPair TRS TRS Order

-- | Textual rendering of a termination certificate.
--
-- 'RIsEmpty' yields a fixed three-line message. Every 'MonotoneReductionPair'
-- is rendered with the same layout (heading, both TRSs, the algebra, closing
-- sentences); only the heading, the sentence introducing the interpretation
-- and the algebra printer differ between the three kinds of 'Order'.
instance Show SNProof where
  show proof =
    case proof of
      RIsEmpty ->
        unlines
          [ "-- emptiness"
          , ""
          , "Since R is empty, R/S is terminating."
          ]
      MonotoneReductionPair trs_R trs_S (Linear algebra) ->
        report
          "-- linear polynomial interpretations"
          "The linear polynomial interpretation on N with"
          trs_R
          trs_S
          (Linear.show_algebra algebra)
      MonotoneReductionPair trs_R trs_S (Matrix algebra) ->
        report
          "-- 2x2 matrix interpretations"
          "The matrix interpretation on NxN with"
          trs_R
          trs_S
          (Matrix.show_algebra algebra)
      MonotoneReductionPair trs_R trs_S (UpperTriangularMatrix algebra) ->
        report
          "-- 2x2 upper triangular matrix interpretations"
          "The matrix interpretation on NxN with"
          trs_R
          trs_S
          (UpperTriangularMatrix.show_algebra algebra)
    where
      -- Shared layout of all reduction-pair certificates. Arguments: heading
      -- line, sentence introducing the interpretation, R, S, and the
      -- already-rendered algebra.
      report :: String -> String -> TRS -> TRS -> String -> String
      report heading intro rulesR rulesS renderedAlgebra =
        unlines
          [ heading
          , ""
          , "Consider the TRSs R and S"
          , ""
          , "  R:"
          , show_trs rulesR
          , "  S:"
          , show_trs rulesS
          , intro
          , ""
          , renderedAlgebra
          , "strictly orients R and weakly orients S."
          , "Hence, R/S is terminating."
          ]

-- | Side-condition formula for a signature, as produced by the backend that
-- is currently wired in (the Matrix module).
--
-- The backend is selected by editing this definition; the alternatives are:
--
-- > side_condition = Linear.side_condition
-- > side_condition = UpperTriangularMatrix.side_condition
side_condition :: Signature -> Formula
side_condition = Matrix.side_condition


-- | Formula built by the currently selected backend's @gtA@ for the pair of
-- terms @s@ and @t@ (presumably encoding "s is strictly greater than t" in
-- the algebra -- confirm against the Matrix module).
--
-- Alternative backends:
--
-- > gt = Linear.gtA
-- > gt = UpperTriangularMatrix.gtA
gt :: Term -> Term -> Formula
gt = Matrix.gtA

-- | Formula built by the currently selected backend's @geqA@ for the pair of
-- terms @s@ and @t@ (presumably the weak counterpart of 'gt' -- confirm
-- against the Matrix module).
--
-- Alternative backends:
--
-- > geq = Linear.geqA
-- > geq = UpperTriangularMatrix.geqA
geq :: Term -> Term -> Formula
geq = Matrix.geqA

-- | Turn a solver model into an 'Order' for the given signature, using the
-- currently selected backend (Matrix) and tagging the result accordingly.
--
-- Alternative backends (note that the constructor has to match the module
-- whose @decode_algebra@ is used, see the definition of 'Order'):
--
-- > decode model sig = Linear (Linear.decode_algebra model sig)
-- > decode model sig = UpperTriangularMatrix (UpperTriangularMatrix.decode_algebra model sig)
decode :: Model -> Signature -> Order
decode model sig = Matrix algebra
  where
    algebra = Matrix.decode_algebra model sig

-- | Try to prove termination of R/S with the requested kind of order.
--
-- Arguments: the order to try, a string that is handed unchanged to the
-- backend's @terminating@ function (presumably identifying the SMT solver --
-- confirm against the backends), the TRS R and the TRS S.
--
-- If R is the empty list the answer is 'RIsEmpty' without consulting any
-- backend. Otherwise the backend's answer is passed on: 'Nothing' stays
-- 'Nothing', and a found algebra is wrapped into a 'MonotoneReductionPair'.
terminating_with :: OrderName -> String -> TRS -> TRS -> IO (Maybe SNProof)
terminating_with _ _ [] _ = return (Just RIsEmpty)
terminating_with order smt rs ss =
  case order of
    LinearOrder ->
      Linear.terminating smt rs ss >>= certify Linear
    MatrixOrder ->
      Matrix.terminating smt rs ss >>= certify Matrix
    UpperTriangularMatrixOrder ->
      UpperTriangularMatrix.terminating smt rs ss >>= certify UpperTriangularMatrix
  where
    -- Wrap a backend result into a proof; the explicit signature keeps the
    -- helper usable at all three algebra types.
    certify :: (algebra -> Order) -> Maybe algebra -> IO (Maybe SNProof)
    certify tag =
      maybe (return Nothing) (return . Just . MonotoneReductionPair rs ss . tag)

-- | 'terminating_with' fixed to 'MatrixOrder'.
terminating :: String -> TRS -> TRS -> IO (Maybe SNProof)
terminating = terminating_with MatrixOrder


-- | Terms whose variables carry an integer index in addition to their name.
--
-- * @NV x i@: the variable named @x@ with index @i@. 'shift' adds to the
--   index; 'phi' maps index 0 to the plain variable @x@ and any other index
--   to a variable named @x__i@.
-- * @NF f ts@: function symbol @f@ applied to the arguments @ts@.
data NablaTerm = NV String Int | NF String [NablaTerm] deriving (Eq, Show)

-- | Substitution on 'NablaTerm's as an association list. An entry
-- @(x, (j, u))@ is applied by 'nabla_substitute' to every @NV x i@ with
-- @i >= j@, which is replaced by @u@ shifted by @i - j@. Only the first entry
-- for a given name is consulted, because the lookup is done with 'lookup'.
type NablaSubst = [(String, (Int, NablaTerm))]


-- | Add @k@ to the index of every variable occurring in the term.
shift :: Int -> NablaTerm -> NablaTerm
shift k term =
  case term of
    NV x i  -> NV x (i + k)
    NF f ts -> NF f (map (shift k) ts)

-- | Apply a substitution to a term.
--
-- A variable @NV x i@ is replaced only if the first entry for @x@ in the
-- substitution is some @(j, u)@ with @i >= j@; the replacement is @u@ shifted
-- by @i - j@. All other variables are left as they are. Function
-- applications are traversed argument by argument.
nabla_substitute :: NablaTerm -> NablaSubst -> NablaTerm
nabla_substitute term sigma =
  case term of
    NF f ts -> NF f (map (`nabla_substitute` sigma) ts)
    NV x i ->
      case lookup x sigma of
        Just (j, u) | i >= j -> shift (i - j) u
        _                    -> term

-- | @occur_check x i t@ holds iff @t@ contains a variable with name @x@ whose
-- index is at least @i@.
occur_check :: String -> Int -> NablaTerm -> Bool
occur_check x i term =
  case term of
    NV y j  -> x == y && i <= j
    NF _ ts -> or (map (occur_check x i) ts)

-- | Decides whether the variable @x@ with index @i@ may be bound to the given
-- term (this is the guard used by @semi_unify'@ before it eliminates a
-- variable).
--
-- Against another variable the comparison is lexicographic: first by name
-- (string order), then by index. Against a function application it holds
-- exactly when 'occur_check' fails.
gtdot :: String -> Int -> NablaTerm -> Bool
gtdot x i term =
  case term of
    NV y j ->
      case compare x y of
        GT -> True
        EQ -> i > j
        LT -> False
    NF _ _ -> not (occur_check x i term)

-- | The (name, index) pairs bound by a substitution, in order of first
-- occurrence and without duplicates.
nabla_domain :: NablaSubst -> [(String, Int)]
nabla_domain sigma = nub (map key sigma)
  where
    key (x, (i, _)) = (x, i)

-- | Composition of two substitutions: for every (name, index) pair bound by
-- either argument, the result binds it to the corresponding variable after
-- applying @sigma@ first and @tau@ second.
nabla_compose :: NablaSubst -> NablaSubst -> NablaSubst
nabla_compose sigma tau = map bind keys
  where
    keys = nub (nabla_domain sigma ++ nabla_domain tau)
    bind (x, i) = (x, (i, image (NV x i)))
    image v = nabla_substitute (nabla_substitute v sigma) tau

-- | Translate an indexed term into an ordinary 'Term'. A variable with index
-- 0 keeps its name; a variable with any other index @i@ is renamed by
-- appending @"__"@ and the decimal rendering of @i@.
phi :: NablaTerm -> Term
phi term =
  case term of
    NV x 0  -> V x
    NV x i  -> V (concat [x, "__", show i])
    NF f ts -> F f (map phi ts)

-- | Read a list of term pairs as an ordinary substitution. Only pairs whose
-- left component is a variable with index 0 contribute a binding (name to
-- the 'phi'-translation of the right component); every other pair is
-- dropped.
phi_subst :: [(NablaTerm, NablaTerm)] -> Subst
phi_subst = concatMap binding
  where
    binding (NV x 0, t) = [(x, phi t)]
    binding _           = []

-- | Worker of 'semi_unify', operating on two lists of term pairs.
--
-- The first list (@ps1@) holds the pairs that have already been turned into
-- variable bindings; the second list holds the pairs still to be processed.
-- When the second list is empty, the bindings in @ps1@ are converted to a
-- 'Subst' with 'phi_subst'. 'Nothing' signals failure.
--
-- The clauses are tried top to bottom and several of them have guards
-- without an @otherwise@ branch, so a failing guard falls through to the
-- clauses below. The order of the clauses is therefore significant.
semi_unify' :: [(NablaTerm, NablaTerm)] -> [(NablaTerm, NablaTerm)] -> Maybe Subst
-- Nothing left to process: report the collected bindings.
semi_unify' ps1 [] = Just (phi_subst ps1)
-- Two identical variables: drop the pair. All collected bindings are put
-- back into the work list and processed again.
semi_unify' ps1 ((NV x i, NV y j) : ps2)
  | x == y && i == j = semi_unify' [] (ps2 ++ ps1)
-- Two function applications: equal symbols are decomposed into argument
-- pairs (zip stops at the shorter argument list), different symbols fail.
-- Again all collected bindings are re-queued.
semi_unify' ps1 ((NF f ss, NF g ts) : ps2)
  | f == g = semi_unify' [] (zip ss ts ++ ps2 ++ ps1)
  | otherwise = Nothing
-- Variable on the left that may be bound to t (see gtdot): apply the binding
-- tau to every other pair.
semi_unify' ps1 (p@(NV x i, t) : ps2)
  | gtdot x i t =
    -- If the collected bindings are unaffected by tau, p simply joins them.
    -- Otherwise p becomes the only collected binding and the rewritten ones
    -- are re-queued behind the remaining work.
    if ps1 == ps1' then 
      semi_unify' (ps1 ++ [p]) ps2'
    else
      semi_unify' [p] (ps2' ++ ps1')
    where
      tau = [(x, (i, t))]
      ps1' = [ (nabla_substitute u tau, nabla_substitute v tau) 
            | (u, v) <- ps1 ]
      ps2' = [ (nabla_substitute u tau, nabla_substitute v tau) 
            | (u, v) <- ps2 ]
-- Variable against a function application that contains the variable (with
-- an index at least as large): failure.
semi_unify' _ ((NV x i, t@(NF _ _)) : _)
  | occur_check x i t = Nothing
-- Anything else: swap the two sides and try again.
semi_unify' ps1 ((s, t) : ps2) = semi_unify' ps1 ((t, s) : ps2)
    
-- | Embed an ordinary term into 'NablaTerm': every variable gets index 0,
-- function applications are translated argument by argument.
nabla_term :: Term -> NablaTerm
nabla_term term =
  case term of
    V x    -> NV x 0
    F f ts -> NF f (map nabla_term ts)


-- | Semi-unification of @s@ against @t@.
--
-- Both terms are embedded with 'nabla_term'; the variables of @s@ are then
-- shifted to index 1 while those of @t@ stay at index 0, and the single
-- resulting pair is handed to @semi_unify'@.
--
-- The result is a single substitution @sigma@ (not a pair). Intended
-- property, not checked here: if @semi_unify s t = Just sigma@ then there is
-- a substitution @tau@ such that @s sigma tau = t sigma@.
semi_unify :: Term -> Term -> Maybe Subst
semi_unify s t = semi_unify' [] [(lhs, rhs)]
  where
    lhs = shift 1 (nabla_term s)
    rhs = nabla_term t


-- | Forward unfolding of a TRS with itself.
--
-- Two copies of the rules are made with @TRS.rename@ (prefixes @"x"@ and
-- @"y"@). For every rule @l1 -> r1@ of the first copy, every position @p@
-- of @r1@ and every rule @l2 -> r2@ of the second copy such that the subterm
-- of @r1@ at @p@ has a most general unifier @sigma@ with @l2@, the result
-- contains the rule with left-hand side @l1 sigma@ and right-hand side
-- @(Term.replace r1 r2 p) sigma@ (presumably @r1@ with @r2@ placed at @p@ --
-- confirm against the Term module).
forward_unfolding :: TRS -> TRS
forward_unfolding trs =
  concat [ unfoldAt rule p | rule@(_, r1) <- copyX, p <- positions r1 ]
  where
    copyX = TRS.rename "x" 1 trs
    copyY = TRS.rename "y" 1 trs
    -- All unfoldings of one rule at one position of its right-hand side.
    unfoldAt (l1, r1) p =
      [ (Term.substitute l1 sigma, Term.substitute (Term.replace r1 r2 p) sigma)
      | (l2, r2) <- copyY
      , Just sigma <- [mgu r1_p l2]
      ]
      where
        r1_p = subterm_at r1 p

-- | Backward unfolding of a TRS with itself.
--
-- Two copies of the rules are made with @TRS.rename@ (prefixes @"x"@ and
-- @"y"@). For every rule @l2 -> r2@ of the first copy, every position @p@
-- of @l2@ and every rule @l1 -> r1@ of the second copy such that @r1@ has a
-- most general unifier @sigma@ with the subterm of @l2@ at @p@, the result
-- contains the rule with left-hand side @(Term.replace l2 l1 p) sigma@ and
-- right-hand side @r2 sigma@.
backward_unfolding :: TRS -> TRS
backward_unfolding trs =
  concat [ unfoldAt rule p | rule@(l2, _) <- copyX, p <- positions l2 ]
  where
    copyX = TRS.rename "x" 1 trs
    copyY = TRS.rename "y" 1 trs
    -- All unfoldings of one rule at one position of its left-hand side.
    unfoldAt (l2, r2) p =
      [ (Term.substitute (Term.replace l2 l1 p) sigma, Term.substitute r2 sigma)
      | (l1, r1) <- copyY
      , Just sigma <- [mgu r1 l2_p]
      ]
      where
        l2_p = subterm_at l2 p

-- | One round of unfolding: the original rules followed by their forward
-- unfoldings (each renamed with @Rule.rename "x" 1@), with later rules that
-- are @Rule.variant@s of earlier ones removed.
--
-- Backward unfoldings are currently not included; to add them, append
-- @backward_unfolding trs@ to the list that is renamed below.
unfold :: TRS -> TRS
unfold trs = nubBy Rule.variant (trs ++ derived)
  where
    derived = map (Rule.rename "x" 1) (forward_unfolding trs)

-- | Result of 'semi_unify' for @s@ and the first term of the list for which
-- it succeeds; 'Nothing' if it fails for every term (or the list is empty).
-- Terms after the first success are not examined.
find_semi_unifiable_term :: Term -> [Term] -> Maybe Subst
find_semi_unifiable_term s = foldr firstSuccess Nothing
  where
    firstSuccess t later = maybe later Just (semi_unify s t)


-- | Search a work list of triples @(k, s, t)@ for a pair whose first
-- component semi-unifies with one of the @subterms@ of the second.
--
-- Arguments: the TRS used to compute reducts, the pairs @(s, t)@ that have
-- already been expanded, and the work list. In a triple, @k@ is the number
-- of further expansion steps allowed for that entry.
--
-- For the entry at the head of the work list:
--
-- * if @s@ semi-unifies with a subterm of @t@ via @sigma@, the answer is
--   @(s sigma, t sigma)@ (to inspect the substitution itself, return
--   @(s, t, sigma)@ and adjust the type);
-- * otherwise, if @k@ is 0 or the pair was expanded before, the entry is
--   dropped;
-- * otherwise the pair is recorded as expanded and one entry @(k - 1, s, u)@
--   per reduct @u@ of @t@ is appended to the end of the work list.
--
-- 'Nothing' means the work list ran empty.
find_loop :: TRS -> [(Term,Term)] -> [(Int,Term,Term)] -> Maybe (Term, Term)
find_loop _ _ [] = Nothing
find_loop trs checked ((k, s, t) : pairs) =
  case find_semi_unifiable_term s (subterms t) of
    Just sigma -> Just (Term.substitute s sigma, Term.substitute t sigma)
    Nothing
      | skip      -> find_loop trs checked pairs
      | otherwise -> find_loop trs expanded queue
  where
    skip = k == 0 || (s, t) `elem` checked
    expanded = (s, t) : checked
    queue = pairs ++ map (\u -> (k - 1, s, u)) (reducts trs t)

-- | Bounded search for a loop in the TRS.
--
-- With bound 0 the answer is 'Nothing' immediately. Otherwise the rules of
-- the twice-unfolded TRS are used as starting pairs for 'find_loop', each
-- with a step budget of @k - 1@.
--
-- NOTE(review): for a negative @k@ the budget never reaches 0, so the search
-- is limited only by 'find_loop' skipping pairs it has already expanded --
-- confirm whether callers can pass negative bounds.
non_terminating :: Int -> TRS -> Maybe (Term,Term)
non_terminating k trs
  | k == 0    = Nothing
  | otherwise = find_loop trs [] seeds
  where
    seeds = [ (k - 1, l, r) | (l, r) <- unfold (unfold trs) ]

-- >>> trs = read_trs "f(s(x), y) -> f(x,s(x)) f(x,s(y)) -> f(y,x)"
-- >>> non_terminating 1 trs
-- Just (f(x1,s(f(s(x2),x3))),f(f(x2,s(x2)),x1),[("x1",s(f(s(x2__1),x3__1)))])
-- >>> s = shift 1 $ nabla_term $ read_term "f(x,b())"
-- >>> t = nabla_term $ read_term "f(a(),x)"
-- >>> (s, t)
-- >>> semi_unify' [] [(s, t)]
-- (NF "f" [NV "x" 1,NF "b" []],NF "f" [NF "a" [],NV "x" 0])
-- Nothing
--
