module Fork where

import Data.List
import Text.Read (readMaybe)
import Term
import Rule
import TRS
import Rewriting
import PCP

-- | A peak (t, s, u): the source term s sits in the middle, with end
-- terms t and u (see 'reconstruct_fork', which searches for rewrite
-- sequences from s to t and from s to u).
type Peak = (Term, Term, Term)

-- | One rewrite step "-p->_rule t": the position p where the rule was
-- applied, the rule itself, and the resulting term t. The term the step
-- starts from is implicit (the source, or the result of the previous step).
type Step = (Position, Rule, Term)

-- | A rewrite sequence s -> ... -> t: the start term s and its steps.
type Sequence = (Term, [Step])

-- | A fork t *<- s ->* u: the steps from s to t, the source s, and the
-- steps from s to u.
type Fork = ([Step], Term, [Step])

-- | The term reached after following @steps@ from @s@: the result of the
-- final step, or @s@ itself when there are no steps.
last_term :: Term -> [Step] -> Term
last_term s steps =
  case reverse steps of
    []            -> s
    (_, _, t) : _ -> t

-- | The two end terms (t, u) of a fork t *<- s ->* u.
last_terms :: Fork -> (Term, Term)
last_terms (left, s, right) = (end_of left, end_of right)
  where
    end_of = last_term s

-- | All single rewrite steps { (p, rule, t) | s -p->_rule t }.
--
-- For every position p of @s@ and every rule (l, r) of @trs@ whose
-- left-hand side matches the subterm of @s@ at p with substitution sigma,
-- yields the step that replaces that subterm by r sigma. Results are
-- ordered by position first, then by rule order in @trs@.
possible_steps :: TRS -> Term -> [Step]
possible_steps trs s =
  [ (pos, rule, rewrite_at pos sigma rhs)
  | pos <- positions s
  , rule@(lhs, rhs) <- trs
  , Just sigma <- [match lhs (subterm_at s pos)]
  ]
  where
    rewrite_at pos sigma rhs = replace s (Term.substitute rhs sigma) pos

-- | Repeatedly take the first step offered by 'possible_steps', starting
-- from @s@. The list ends when a term has no possible step. It is infinite
-- (but lazy) when this strategy does not terminate.
reduction_sequence :: TRS -> Term -> [Step]
reduction_sequence trs = unfoldr next_step
  where
    next_step term =
      case possible_steps trs term of
        []                     -> Nothing
        step@(_, _, after) : _ -> Just (step, after)

-- | The largest i such that "c" followed by the text of i occurs in @fs@
-- (parsed with 'readMaybe'). Returns 0 when there is none or all such
-- indices are negative. Used to pick fresh constant names.
max_index :: [String] -> Int
max_index fs = foldr max 0 [ n | f <- fs, Just n <- [index_of f] ]
  where
    index_of ('c' : digits) = readMaybe digits
    index_of _              = Nothing

-- | A substitution that maps every distinct variable of @ts@ (in order of
-- first occurrence) to a fresh constant "c<i>". The indices start just
-- above 'max_index' of the function symbols @fs@, so they do not clash
-- with existing "c<i>" symbols.
grounding :: [String] -> [Term] -> Subst
grounding fs ts = zipWith bind vars [first_fresh ..]
  where
    first_fresh = max_index fs + 1
    vars        = nub (concatMap Term.variables ts)
    bind v i    = (v, F ("c" ++ show i) [])

-- | A critical peak (t, _, _, s, _, u) is non-trivial when its source s and
-- its end terms t and u are pairwise distinct.
non_trivial :: CP -> Bool
non_trivial (t, _, _, s, _, u) = not (s == t || s == u || t == u)

-- | Two critical peaks are considered identical when their end-term pairs
-- are related by 'Rule.variant' (presumably equality up to variable
-- renaming), in either orientation.
identical_cp :: CP -> CP -> Bool
identical_cp (t1,_,_,_,_,u1) (t2,_,_,_,_,u2) =
  any (Rule.variant (t1, u1)) [(t2, u2), (u2, t2)]

-- | Two forks are identical when their end-term pairs (see 'last_terms')
-- are syntactically equal, in either order.
identical_fork :: Fork -> Fork -> Bool
identical_fork fork1 fork2 = ends1 `elem` [(t2, u2), (u2, t2)]
  where
    ends1    = last_terms fork1
    (t2, u2) = last_terms fork2

-- | Pairs obtained by composing two rewrite steps of @trs@, only for
-- overlaps that involve a non-left-linear rule. In both parts, rule1's
-- variables are renamed with prefix "x" and rule2's with prefix "y"
-- (presumably to rename the two rules apart).
--
-- Part 1: rule1 is applied before a non-left-linear rule2. For each
-- function position p of l2 where r1 unifies with the subterm of l2 at p
-- (mgu sigma), this emits (l2[l1]_p sigma, r2 sigma), i.e. the composition
-- l2[l1]_p sigma -> l2 sigma -> r2 sigma.
--
-- Part 2: a non-left-linear rule1 is applied before rule2. For each
-- non-root function position p of r1 whose subterm unifies with l2 (mgu
-- sigma), this emits (l1 sigma, r1[r2]_p sigma), i.e. the composition
-- l1 sigma -> r1 sigma -> r1[r2]_p sigma.
--
-- NOTE(review): the root position is excluded only in part 2. Root
-- overlaps where rule1 is non-left-linear are therefore produced (by part
-- 1) only if rule2 is non-left-linear as well. Confirm this is intended.
forward_pairs :: TRS -> TRS
forward_pairs trs =
  -- Part 1: overlap of r1 into l2 (rule2 non-left-linear).
  [ (Term.substitute (replace l2 l1 p) sigma,
     Term.substitute r2 sigma)
  | rule1 <- trs,
    let (l1, r1) = Rule.rename "x" 1 rule1,
    rule2 <- trs,
    not (Rule.left_linear rule2),
    let (l2, r2) = Rule.rename "y" 1 rule2,
    p <- function_positions l2,
    Just sigma <- [mgu r1 (subterm_at l2 p)] ] ++
  -- Part 2: overlap of l2 into r1 below the root (rule1 non-left-linear).
  [ (Term.substitute l1 sigma,
     Term.substitute (replace r1 r2 p) sigma)
  | rule1 <- trs,
    not (Rule.left_linear rule1),
    let (l1, r1) = Rule.rename "x" 1 rule1,
    rule2 <- trs,
    let (l2, r2) = Rule.rename "y" 1 rule2,
    p <- function_positions r1, p /= [],
    Just sigma <- [mgu (subterm_at r1 p) l2] ]
    
-- | Apply 'forward_pairs' @n@ times. Each round appends the new pairs to
-- the system and removes duplicates with 'TRS.unique'.
--
-- A non-positive @n@ returns @trs@ unchanged. Matching only on 0 would
-- make a negative count recurse forever.
unfold :: Int -> TRS -> TRS
unfold n trs
  | n <= 0    = trs
  | otherwise = unfold (n - 1) (TRS.unique (trs ++ forward_pairs trs))

-- | One single-step fork per critical peak of @trs@ (duplicates removed
-- with 'identical_cp').
--
-- Each peak (t, rule_L, p, s, rule_R, u) becomes the fork
-- [p, rule_L, t'] *<- s' -> [root, rule_R, u']. The primed terms are the
-- originals with all their variables replaced by fresh constants (see
-- 'grounding').
forks' :: TRS -> [Fork]
forks' trs = map to_fork (nubBy identical_cp (critical_peaks trs))
  where
    fs = TRS.functions trs
    to_fork (t, rule_L, p, s, rule_R, u) =
      ([(p, rule_L, ground t)], ground s, [([], rule_R, ground u)])
      where
        sigma    = grounding fs [s, t, u]
        ground v = Term.substitute v sigma

-- | The end-term pairs (t, u) of the critical peaks of @trs@, with
-- duplicates removed by 'identical_cp'. Variables are not grounded.
forks'' :: TRS -> [(Term,Term)]
forks'' trs =
  map (\(t, _, _, _, _, u) -> (t, u)) (nubBy identical_cp (critical_peaks trs))
 
-- | True when the lists returned by @at_most_k_steps k trs@ for @t@ and for
-- @u@ share an element. Presumably this means t and u have a common reduct
-- within k steps; at_most_k_steps is defined elsewhere.
joinable :: Int -> TRS -> Term -> Term -> Bool
joinable k trs t u = not (null common)
  where
    reachable = at_most_k_steps k trs
    common    = reachable t `intersect` reachable u

-- | Search depth-first (trying 'possible_steps' in order) for a rewrite
-- sequence from @s@ to @t@ of at most @k@ steps.
--
-- Returns @Just []@ whenever @s == t@, whatever the bound. Returns
-- 'Nothing' when no such sequence is found within the bound.
reconstruct_steps :: Int -> TRS -> Term -> Term -> Maybe [Step]
reconstruct_steps _ _ s t
  | s == t = Just []
-- A non-positive bound stops the search. Matching only 0 would let a
-- negative bound recurse without limit.
reconstruct_steps k _ _ _
  | k <= 0 = Nothing
reconstruct_steps k trs s t =
  reconstruct_steps' k trs (possible_steps trs s) t

-- | Try the candidate steps in order. For the first step from whose result
-- @t@ is reachable within @k - 1@ further steps, prepend that step to the
-- found sequence. 'Nothing' if no candidate succeeds.
reconstruct_steps' :: Int -> TRS -> [Step] -> Term -> Maybe [Step]
reconstruct_steps' k trs candidates t = foldr attempt Nothing candidates
  where
    attempt step@(_, _, next) fallback =
      maybe fallback (Just . (step :)) (reconstruct_steps (k - 1) trs next t)

-- | Turn a peak (t, s, u) into a fork by finding rewrite sequences of at
-- most @k@ steps from s to t and from s to u. 'Nothing' if either search
-- fails.
reconstruct_fork :: Int -> TRS -> Peak -> Maybe Fork
reconstruct_fork k trs (t, s, u) = do
  left  <- reconstruct_steps k trs s t
  right <- reconstruct_steps k trs s u
  return (left, s, right)


-- | Build a fork from a critical peak (t, rule_L, p, s, rule_R, u).
-- The left branch starts with the step at p via rule_L to t, and the right
-- branch with the root step via rule_R to u. Each branch is then extended
-- by 'reduction_sequence'.
--
-- trs must be terminating: the extensions end only at a normal form, so
-- they are infinite otherwise.
fork_from :: TRS -> CP -> Fork
fork_from trs (t, rule_L, p, s, rule_R, u) =
  (branch p rule_L t, s, branch [] rule_R u)
  where
    branch pos rule first = (pos, rule, first) : reduction_sequence trs first

-- | Term cap with an explicit fresh-variable counter.
--
-- Variables always become fresh variables @x ++ show k@. For @F f ts@, the
-- arguments are capped first (threading the counter). If the resulting
-- term unifies with some left-hand side of @trs@, it is replaced by a fresh
-- variable; otherwise it is kept. Returns the next unused counter value
-- together with the capped term.
--
-- NOTE(review): left-hand sides are passed to 'unifiable' as-is. If rule
-- variables can share names with the fresh variables, they are not renamed
-- apart here; confirm that 'unifiable' or the callers handle this.
tcap' :: String -> Int -> TRS -> Term -> (Int, Term)
tcap' x k trs term =
  case term of
    V _ -> fresh
    F f ts ->
      let (m, args) = tcap_list x k trs ts
          capped    = F f args
      in if any (unifiable capped . fst) trs then fresh else (m, capped)
  where
    fresh = (k + 1, V (x ++ show k))

-- | Combined 'Term.size' of the two end terms of a peak. The source term
-- is ignored.
peak_size :: Peak -> Int
peak_size (t, _, u) = sum (map Term.size [t, u])

-- | Every pair (a, b) where a occurs strictly before b in the input, in
-- order of a's position, then b's.
unordered_pairs :: [a] -> [(a,a)]
unordered_pairs xs = [ (a, b) | a : rest <- tails xs, b <- rest ]

-- | Ground source terms of overlaps, without duplicates.
--
-- Takes each left-hand side (variables renamed with prefix "y") and each
-- of its function positions. Where a pattern unifies with the subterm at
-- that position, the instantiated left-hand side is renamed (prefix "x")
-- and its variables are replaced by fresh constants (see 'grounding').
-- Patterns are the left-hand sides of rules that are 'Rule.well_formed',
-- plus the right-hand sides r of rules (l, r) for which (r, l) is
-- 'Rule.well_formed'. They are renamed with prefix "x".
sources :: TRS -> [Term]
sources trs = nub (map ground overlaps)
  where
    sigs     = TRS.functions trs
    ground s = Term.substitute s (grounding sigs [s])
    targets  = nub [ Term.rename "y" 1 lhs | (lhs, _) <- trs ]
    patterns =
      nub [ Term.rename "x" 1 side
          | (lhs, rhs) <- trs
          , (side, other) <- [(lhs, rhs), (rhs, lhs)]
          , Rule.well_formed (side, other) ]
    overlaps =
      [ Term.rename "x" 1 (Term.substitute target mu)
      | target <- targets
      , q <- function_positions target
      , let sub = subterm_at target q
      , pat <- patterns
      , Just mu <- [mgu pat sub] ]


-- | Cap each term of a list with 'tcap'', threading the fresh-variable
-- counter left to right. Returns the final counter and the capped terms.
tcap_list :: String -> Int -> TRS -> [Term] -> (Int, [Term])
tcap_list x k trs ts = mapAccumL step k ts
  where
    step counter t = tcap' x counter trs t

-- | The cap of @t@ with respect to @trs@. Fresh variables are named
-- @x ++ show i@, with i counting up from 0.
tcap :: String -> TRS -> Term -> Term
tcap x trs = snd . tcap' x 0 trs

-- | True when the caps of the peak's end terms are not unifiable. The two
-- caps use distinct fresh-variable prefixes ("x" and "y"). Presumably this
-- serves as a sufficient criterion for non-joinability of t and u.
non_joinable_by_tcap :: TRS -> Peak -> Bool
non_joinable_by_tcap trs (t, _, u) = not (unifiable capped_t capped_u)
  where
    capped_t = tcap "x" trs t
    capped_u = tcap "y" trs u
