module Redundancy where

import Data.List
import SMT
import TRS
import Rule
import Term
import Proof


-- | Bounded forward search over one-step rewrites.
--
-- @at_most_k_steps k trs s@ lists pairs @(ts, rules)@ where @ts@ is a
-- reduction sequence @s -> t1 -> ... -> ti@ with @i <= k@, stored in
-- reverse order (@[ti, ..., t1, s]@, most recent term first), and @rules@
-- are the distinct rules applied along it.
--
-- Only one entry is kept per final term @ti@. Shorter sequences are listed
-- first and 'nubBy' keeps the first occurrence, so each kept sequence uses
-- the fewest possible steps. A non-positive @k@ yields only the trivial
-- sequence @[s]@.
at_most_k_steps :: Int -> TRS -> Term -> [([Term], [Rule])]
at_most_k_steps k trs s
  -- Treat negative bounds like 0: matching only on the literal 0 would
  -- recurse forever when k < 0.
  | k <= 0    = [ ([s], []) ]
  | otherwise = nubBy sameLastTerm (shorter ++ extended)
  where
    -- Every sequence of at most (k - 1) steps.
    shorter = at_most_k_steps (k - 1) trs s
    -- Extend each of those by one step from its most recent term.
    -- Presumably 'reductsRules' yields one-step reducts paired with the
    -- rule used (TODO confirm against the TRS module).
    extended =
      [ (u : ts, nub (rule : rules))
      | (ts@(t : _), rules) <- shorter
      , (u, rule) <- reductsRules trs t
      ]
    -- Two entries are duplicates when they end in the same term.
    sameLastTerm (v : _, _) (w : _, _) = v == w
    sameLastTerm _          _          = error "at_most_k_steps"

-- | Try to join @t@ and @u@ within @k@ steps on each side.
--
-- Returns @Just []@ when the terms are already equal. Otherwise it returns
-- @Just mids@, where @mids@ are the terms strictly between @t@ and @u@ on a
-- path @t ->* v <-* u@ through the first common reduct found. It returns
-- 'Nothing' when no common reduct is reachable within the bound.
joinSequence :: Int -> TRS -> Term -> Term -> Maybe [Term]
joinSequence k trs t u
  | t == u              = Just []
  | (path : _) <- paths = Just (interior path)
  | otherwise           = Nothing
  where
    fromT = at_most_k_steps k trs t
    fromU = at_most_k_steps k trs u
    -- Glue the forward half (t first) onto the backward half (meeting
    -- point first, u last), giving t, ..., v, ..., u.
    paths =
      [ reverse before ++ (meet : after)
      | (meet : before, _) <- fromT
      , (meet' : after, _) <- fromU
      , meet == meet'
      ]
    -- Drop the endpoints t and u. This is safe here: both halves can only
    -- be trivial when t == u, which the first guard already handled, so
    -- every path has at least two terms.
    interior = init . tail

-- | Candidate rule sets for joining @t@ and @u@ within @k@ steps per side.
--
-- Each element is the duplicate-free union of the rules used on the two
-- halves of a path @t ->* v <-* u@. There is one element for every matching
-- pair of reachable terms. When @t == u@ the single candidate is the empty
-- rule set.
rules_of_join_sequence :: Int -> TRS -> Term -> Term -> [TRS]
rules_of_join_sequence k trs t u
  | t == u    = [[]]
  | otherwise = concatMap unionsAt fromT
  where
    fromT = at_most_k_steps k trs t
    fromU = at_most_k_steps k trs u
    -- All rule-set unions for sequences from u that end where this one does.
    unionsAt (meet : _, leftRules) =
      [ nub (leftRules ++ rightRules)
      | (meet' : _, rightRules) <- fromU
      , meet == meet'
      ]
    unionsAt _ = []

-- | Candidate rule sets for rewriting @t@ into @u@ in at most @k@ steps.
--
-- Yields @[[]]@ when @t == u@. Otherwise it yields the rule sets recorded
-- for sequences from @t@ whose final term is @u@, which is empty when @u@
-- is not reached.
rulesOfToSequence :: Int -> TRS -> Term -> Term -> [TRS]
rulesOfToSequence k trs t u
  | t == u    = [[]]
  | otherwise = map snd (filter (reaches . fst) (at_most_k_steps k trs t))
  where
    -- Sequences are stored most-recent-term first.
    reaches (v : _) = v == u
    reaches []      = False

-- | Position of the first rule in the table that is a variant of @rule@
-- (as decided by 'Rule.variant'), counted from the starting offset @i@.
-- Returns 'Nothing' when no table entry matches.
rule_index' :: Int -> TRS -> Rule -> Maybe Int
rule_index' i table rule = fmap (i +) (findIndex (Rule.variant rule) table)

-- | Zero-based position of @rule@ (up to variants) in the table, if present.
rule_index :: TRS -> Rule -> Maybe Int
rule_index = rule_index' 0

-- | Formula requiring every rule of the second argument to be selected
-- into S. Selection variables are indexed by position in @table@.
subset_of_S :: TRS -> TRS -> Formula
subset_of_S table = conj . map (in_S table)

-- | Formula requiring every listed function symbol to belong to FunS.
subset_of_FunS :: [String] -> Formula
subset_of_FunS = conj . map in_FunS

-- | Propositional variable stating that a function symbol is in FunS.
in_FunS :: String -> Formula
in_FunS = FVar . ("inFunS_" ++)

-- | Propositional variable stating that @rule@ is selected into S.
-- The variable is named after the rule's position in @table@. Calls
-- 'error' if the rule (up to variants) does not occur in the table.
in_S :: TRS -> Rule -> Formula
in_S table rule =
  case rule_index table rule of
    Just i  -> FVar ("in_S_" ++ show i)
    Nothing -> error "in_S"

-- | Constraints under which the selected rules S may replace @trs@.
--
-- The formula conjoins three families of constraints:
--
-- * S is a proper subset of @trs@, so not every rule is selected.
-- * Every selected rule has all of its function symbols
--   (per 'Rule.functions') in FunS.
-- * Every rule @l -> r@ of @trs@ whose left-hand side uses only FunS
--   symbols has some candidate derivation @l ->* r@ of at most @k@ steps
--   whose rules all lie in S.
encode_rule_removal :: Int -> TRS -> Formula
encode_rule_removal k trs =
  conj (properSubset : (symbolsInFunS ++ derivableInS))
  where
    properSubset = neg (subset_of_S trs trs)
    symbolsInFunS = map usesFunS trs
    derivableInS = map derivable trs
    usesFunS rule =
      implies (in_S trs rule) (subset_of_FunS (Rule.functions rule))
    derivable (l, r) =
      implies (subset_of_FunS (Term.functions l))
              (disj (map (subset_of_S trs) (rulesOfToSequence k trs l r)))

-- | Requires that, for every rule @l -> r@ of @trs@, @l@ and @r@ can be
-- joined within @k@ steps per side using only rules selected into S.
encode_nfm15 :: Int -> TRS -> Formula
encode_nfm15 k trs = conj (map joinableInS trs)
  where
    joinableInS (l, r) =
      disj (map (subset_of_S trs) (rules_of_join_sequence k trs l r))

-- | Full redundancy encoding: the joinability constraints together with
-- the rule-removal constraints, both with step bound @k@.
encode :: Int -> TRS -> Formula
encode k trs = conj [joinability, removal]
  where
    joinability = encode_nfm15 k trs
    removal     = encode_rule_removal k trs

-- | Rules of @table@ whose selection variable evaluates to true in the
-- model, preserving table order.
decode_S :: TRS -> Model -> TRS
decode_S table model = filter (eval_formula model . in_S table) table

-- | Ask the SMT backend @smt@ for a redundant-rule selection of @trsR@
-- with step bound @k@.
--
-- When the encoding is satisfiable, the result is a single 'Redundancy'
-- step recording the original rules and the selected subset S, with S as
-- the remaining subgoal. When it is unsatisfiable, the result is 'Nothing'.
reduce :: String -> Int -> TRS -> IO (Maybe Proof)
reduce smt k trsR =
  -- 'maybe' inspects the solver result inside this IO action, just as a
  -- case on the bound value would.
  sat smt (Redundancy.encode k trsR)
    >>= maybe (return Nothing) (return . Just . proofFor)
  where
    proofFor model = ([Redundancy { _trsR = trsR, _trsS = trsS }], Subgoal trsS)
      where
        trsS = decode_S trsR model
