open Term
open Ls

(*************************************************)
(*                UNIFICATION                    *)
(*************************************************)

(** basic matching **)
(* [var_clash (x, t) bindings] holds when [bindings] also binds the variable
   [x] to something structurally different from [t]. *)
let var_clash (x, t) bindings =
  List.exists (fun (y, s) -> x = y && t <> s) bindings

(* [consist sigma] is true when no variable is bound in [sigma] to two
   structurally different terms. *)
let consist sigma =
  not (List.exists (fun binding -> var_clash binding sigma) sigma)

(* [match_with sub p s] collects the bindings that make pattern [p] equal to
   term [s]; every variable leaf of [p] contributes its binding consed onto
   [sub].  The same variable may be bound several times; [matcher] checks
   the result for consistency.
   Raises [Failure "variable clash"] when [p] is an application that does not
   fit [s] (different symbol, or [s] is a variable), and [Invalid_argument]
   from [List.combine] when the argument counts differ. *)
let rec match_with sub p s =
  match p, s with
  | V x, _ -> (x, s) :: sub
  | F (f, ps), F (g, ss) when f = g ->
    List.concat
      (List.map (fun (pi, si) -> match_with sub pi si) (List.combine ps ss))
  | _ -> failwith "variable clash"

(* [matcher p s] returns the binding list that instantiates pattern [p] to
   the term [s].
   Raises [Failure "variable clash"] when [p] does not fit [s] or a repeated
   variable would be bound inconsistently, and [Invalid_argument] on an
   arity mismatch (from [List.combine] in [match_with]). *)
let matcher p s =
  let bindings = match_with [] p s in
  if not (consist bindings) then failwith "variable clash";
  bindings

(* [instance_of s p] is true iff the term [s] is an instance of the pattern
   [p], i.e. [matcher p s] succeeds.
   Only the failures [matcher] is expected to raise are mapped to [false]:
   [Failure] for a symbol or variable clash, and [Invalid_argument] from
   [List.combine] on an arity mismatch.  Anything else (for instance
   [Stack_overflow] or [Out_of_memory]) now propagates instead of being
   silently read as "not an instance". *)
let instance_of s p =
  try
    ignore (matcher p s);
    true
  with Failure _ | Invalid_argument _ -> false

(* [matchers trs s] pairs every rule of [trs] whose left-hand side matches
   [s] with the corresponding matching bindings. *)
let matchers trs s =
  trs
  |> List.filter (fun (l, _) -> instance_of s l)
  |> List.map (fun ((l, _) as rule) -> (rule, matcher l s))

(* [variant s t] holds when [s] and [t] are instances of each other. *)
let variant s t =
  instance_of s t && instance_of t s

(* [variant_rule (l1, r1) (l2, r2)] compares two rules as a whole, so that a
   single consistent matching must relate both sides at once. *)
let variant_rule (l1, r1) (l2, r2) =
  let pair lhs rhs = F ("", [lhs; rhs]) in
  variant (pair l1 r1) (pair l2 r2)


(****** unification *****)

(* apply the single binding [s] to both sides of every pending equation *)
let update_elist s ts =
  List.map (fun (a, b) -> (substitute [s] a, substitute [s] b)) ts

(* apply the single binding [s] to every term bound in [sigma] *)
let update_sigma s sigma =
  List.map (fun (x, t) -> (x, substitute [s] t)) sigma

(* [not_mem a list] is true when [a] does not occur in [list] *)
let not_mem elt xs = not @@ List.mem elt xs

(* [unification sigma es] solves the equations [es] one by one, extending the
   binding list [sigma]:
   - an equation whose sides are structurally equal is dropped;
   - [x = t] or [t = x] with [x] not occurring in [t] adds the binding
     [(x, t)], which is also applied to [sigma] and to the pending equations;
   - [f(ts1) = f(ts2)] is decomposed into its argument equations.
   Raises [Failure "not unifiable"] on a symbol clash or a failed occurs
   check, and [Invalid_argument] (from [List.combine]) on an arity mismatch. *)
let rec unification sigma equations =
  match equations with
  | [] -> sigma
  | (lhs, rhs) :: rest ->
    let bind binding =
      unification
        (binding :: update_sigma binding sigma)
        (update_elist binding rest)
    in
    if lhs = rhs then unification sigma rest
    else begin
      match lhs, rhs with
      | V x, t when not_mem x (variables t) -> bind (x, t)
      | t, V x when not_mem x (variables t) -> bind (x, t)
      | F (f, ts1), F (g, ts2) when f = g ->
        unification sigma (List.combine ts1 ts2 @ rest)
      | _ -> failwith "not unifiable"
    end

(* [mgu es] is the unifier that [unification] computes for the equations [es]
   starting from the empty substitution.
   Raises [Failure] if [es] has no unifier ([Invalid_argument] on an arity
   mismatch). *)
let mgu equations = unification [] equations

(* [unifiable es] is true iff [mgu es] returns; any exception counts as
   "not unifiable". *)
let unifiable es =
  try
    let _ = mgu es in
    true
  with _ -> false


(** TCAP function **)
(* [tcap trs t] replaces every variable of [t] by a fresh variable, and
   likewise every subterm which, after capping its arguments, unifies with
   some left-hand side of [trs]. *)
let rec tcap trs t =
  match t with
  | V _ -> new_var ()
  | F (f, ts) ->
    let capped = F (f, List.map (tcap trs) ts) in
    let unifies_with_lhs (l, _) = unifiable [l, capped] in
    if List.exists unifies_with_lhs trs then new_var () else capped


(*************************************************)
(**************** Critical Pair ******************)
(*************************************************)
(* [simple_cp cps] drops trivial pairs, whose two sides are variants of each
   other.  Among pairs that are variants of one another as rules (see
   [variant_rule]), it keeps only the first.
   The earlier filter compared [F (_, [a; c])] with [F (_, [b; d])].  That
   could only succeed if [a] and [b] were variants, which the previous
   clause already rules out, so duplicates were never removed. *)
let rec simple_cp = function
  | [] -> []
  | (a, b) :: cps when variant a b -> simple_cp cps
  | (a, b) :: cps ->
    (a, b)
    :: List.filter
         (fun (c, d) -> not (variant_rule (a, b) (c, d)))
         (simple_cp cps)

(** functions about overlap **)

(* [is_overlap ((a, b), p, (c, d))] is true when the rule [a -> b] overlaps
   the rule [c -> d] at position [p] of [c].  That requires all of:
   - the two rules have no common variables;
   - the subterm of [c] at [p] is not a variable;
   - that subterm unifies with [a].
   An invalid [p] yields [false], assuming [subterm_at] signals it with
   [Failure] (TODO confirm against [Term.subterm_at]). *)
let is_overlap ((a, b), p, (c, d)) =
  let var (x, y) = variables x @ variables y in
  try
    (* [subterm_at] must run inside the handler.  It used to be bound before
       [try], so the exception for an invalid position escaped. *)
    let at = subterm_at p c in
    cap (var (a, b)) (var (c, d)) = [] (* without common variables *)
    && not (is_variable at)
    && unifiable [a, at]
    (* disabled check: if p is root, then the rules are not variants
       && not (p = [] && variant (F ("",[a;b])) (F ("",[c;d]))) *)
  with Failure _ -> false

(* [overlaps r s] lists every triple [(rule_r, p, rule_s)] such that
   [is_overlap] accepts a rule of [r] at position [p] (taken from
   [function_positions]) of the left-hand side of a rule of [s].
   When the two systems share variables, [s] is renamed first, so the
   returned rules of [s] may be renamed copies. *)
let overlaps r s =
  let s' =
    if cap (Trs.variables r) (Trs.variables s) <> [] then Trs.rename s else s
  in
  let overlaps_of a b =
    let (l2, _) = b in
    List.filter (fun p -> is_overlap (a, p, b)) (function_positions l2)
    |> List.map (fun p -> (a, p, b))
  in
  List.concat (List.map (fun a -> List.concat (List.map (overlaps_of a) s')) r)


(** Critical Pair **)
(* [cp_overlap laps] turns each overlap [((l1, r1), p, (l2, r2))] into a
   critical peak [(t, p, u, v)], where [s = mgu [l1, l2|p]]:
   - [u = s(l2)] is the peak term;
   - [t] is [u] with the inner redex at [p] replaced by [s(r1)];
   - [v = s(r2)] is the result of the root (outer) step.
   Overlaps whose subterm does not unify with [l1] are dropped. *)
let cp_overlap laps =
  let peak ((l1, r1), p, (l2, r2)) =
    (* Kept outside the handler, as before: an invalid position still
       raises instead of silently dropping the overlap. *)
    let inner = subterm_at p l2 in
    (* The unifier is computed once; previously [mgu] ran twice (once via
       [unifiable]).  Treating every exception as "no unifier" mirrors
       [unifiable]. *)
    match (try Some (mgu [l1, inner]) with _ -> None) with
    | None -> []
    | Some s ->
      let root = substitute s l2 in
      [ ( replace root (substitute s r1) p, (* rewrite inner redex *)
          p,                                 (* inner redex position *)
          root,                              (* root term *)
          substitute s r2 ) ]                (* rewrite outer redex *)
  in
  List.concat (List.map peak laps)

(* [c_pp laps] builds the critical peaks of [laps] with [cp_overlap] and
   deduplicates them with [uniq_by].  Two peaks count as the same when their
   end points [(t, v)] are variants as rules. *)
let c_pp laps =
  let same_ends (t1, _, _, v1) (t2, _, _, v2) =
    variant_rule (t1, v1) (t2, v2)
  in
  uniq_by same_ends (cp_overlap laps)

(** Critical Peak **)

(* critical peaks between two TRSs: a rule of [r] rewrites at position [p]
   inside a left-hand side of [s], and that rule of [s] rewrites at the root *)
let cpp2 r s = overlaps r s |> c_pp

(* the peaks of [cpp2 s r] with their end points swapped, so that the left
   end comes from the root step by a rule of [r] *)
let cpp2r r s =
  cpp2 s r
  |> List.map (fun (inner, p, peak, outer) -> (outer, p, peak, inner))

(* <-R- ><| -S-> U <-R- |>< -S->  (union of both orientations) *)
let cpp2_all r s = cup (cpp2 r s) (cpp2r r s)

(* critical peaks of a single TRS with itself *)
let cpp trs = cpp2 trs trs

(* critical peaks from overlaps strictly below the root *)
let cpp_in r s =
  overlaps r s |> List.filter (fun (_, p, _) -> p <> []) |> c_pp

(* critical peaks from overlaps at the root *)
let cpp_out r s =
  overlaps r s |> List.filter (fun (_, p, _) -> p = []) |> c_pp


(** Critical Pair **)
(* project critical peaks [(t, p, u, v)] onto critical pairs [(t, v)]
   (order is preserved; nothing is sorted) *)
let c_p peaks = List.map (fun (t, _, _, v) -> (t, v)) peaks

(* critical pairs for different TRSs *)
let cp2 r s = cpp2 r s |> c_p

(* as [cp2 s r], but the left component comes from the root step of [r] *)
let cp2r r s = cpp2r r s |> c_p

(* <-R- ><| -S-> U <-R- |>< -S->  (union of both orientations) *)
let cp2_all r s = cpp2_all r s |> c_p

(* critical pairs of a single TRS.  Pairs that are variants of each other in
   either orientation are identified by [uniq_by]. *)
let cp trs =
  let equivalent (s, t) (u, v) =
    variant_rule (s, t) (u, v) || variant_rule (s, t) (v, u)
  in
  uniq_by equivalent (cp2 trs trs)

(* critical pairs from overlaps strictly below the root *)
let cp_in r s = cpp_in r s |> c_p

(* critical pairs from overlaps at the root *)
let cp_out r s = cpp_out r s |> c_p

(* The [*_] variants below drop trivial critical peaks, i.e. those whose two
   end points are structurally identical.  [rm_ident1] takes its list
   explicitly so that it is generalised rather than given a weak type. *)
let rm_ident1 peaks = List.filter (fun (x, _, _, y) -> x <> y) peaks
let cpp2_ r s     = cpp2 r s     |> rm_ident1
(* must filter [cpp2r] like its siblings; it used to re-filter [cpp2_],
   returning the peaks in the wrong orientation *)
let cpp2r_ r s    = cpp2r r s    |> rm_ident1
let cpp2_all_ r s = cpp2_all r s |> rm_ident1
let cpp_in_ r s   = cpp_in r s   |> rm_ident1
let cpp_out_ r s  = cpp_out r s  |> rm_ident1
let cpp_ r        = cpp2 r r     |> rm_ident1

(* Critical-pair counterparts of the functions above: trivial pairs (two
   structurally identical sides) are dropped.  [rm_ident2] takes its list
   explicitly so that it is generalised rather than given a weak type. *)
let rm_ident2 pairs = List.filter (fun (x, y) -> x <> y) pairs
let cp2_ r s     = cp2 r s     |> rm_ident2
(* must filter [cp2r] like its siblings; it used to re-filter [cp2_],
   returning the pairs in the wrong orientation *)
let cp2r_ r s    = cp2r r s    |> rm_ident2
let cp2_all_ r s = cp2_all r s |> rm_ident2
let cp_in_ r s   = cp_in r s   |> rm_ident2
let cp_out_ r s  = cp_out r s  |> rm_ident2
let cp_ r        = cp2 r r     |> rm_ident2

(** Simultaneous Critical Pairs **)

(* [unconnected_with (l, r) ss] is true when none of the positions in the
   redex-pattern list [ss] is among [function_positions l], so the patterns
   do not overlap the function symbols of [l]. *)
let unconnected_with (l, _) ss =
  (* computed once instead of once per element of [ss] *)
  let fpos = function_positions l in
  List.for_all (fun (p, _) -> not (List.mem p fpos)) ss

(* prefix the position of every redex pattern, in every set, with [p] *)
let add_position p rdxps =
  let shift (q, rule) = (p @ q, rule) in
  List.map (fun set -> List.map shift set) rdxps

(* [simultaneous_sets rs t] enumerates simultaneous sets of redex patterns
   of [t] with respect to the rules [rs].
   - A set is a list of [(position, rule)] pairs.  A pair is produced only
     when the rule's left-hand side unifies with the (possibly renamed)
     subterm at that position.
   - Variable subterms contribute nothing.
   - Empty sets are dropped and duplicates removed with [Listx.unique]. *)
let rec simultaneous_sets rs t =
  (* test unifiability on a renamed copy when [t] shares variables with [rs] *)
  let tt = if Listset.inter (variables t) (Trs.variables rs) <> []
           then Term.rename t else t in
  match tt with
    V x -> []
  | F (f,ss) as t ->
     (* from here on [t] names the possibly renamed term [tt] *)
     (* Sets strictly below the root.  Each argument's sets get the
        argument's index (from [Listx.index]) prepended to their positions.
        Arguments without any set are dropped, and the rest are combined by
        [Listx.pi] and flattened.
        NOTE(review): if [Listx.pi] is the cartesian product, every element
        takes one set from EACH remaining argument, so sets using only some
        of those arguments are not produced here.  Confirm this is
        intended. *)
     let sss = Listx.index ss
               |> List.map (fun (i,s) -> add_position [i] (simultaneous_sets rs s))
               |> List.filter ((<>) [])
               |> Listx.pi
               |> List.map List.concat in
     (* a root pattern extended with every argument-level set that avoids the
        function positions of the root rule's left-hand side *)
     let merge_redexp (p,rl) =
       List.map
         (fun xs -> (p,rl) :: xs)
         (List.filter (fun xs -> unconnected_with rl xs) sss) in
     (* root patterns: every rule whose left-hand side unifies with [t] *)
     let rdxps = List.map
                   (fun rl -> [],rl)
                   (List.filter
                      (fun (l,_) -> unifiable [t,l])
                      rs) in
     (* the result consists of *)
     Listx.unique @@ List.filter ((<>) [])
       (List.map Listset.singleton rdxps      (* root redex positions, *)
        @ Listx.concat_map merge_redexp rdxps (* joint redex positions, *)
        @ sss) (* and redex positions under arguments *)

(* [sort_redex_patterns rdxps] orders redex patterns by decreasing position
   length, so the deepest come first.  The sort is stable: patterns at
   equally deep positions keep their relative order.
     sort_redex_patterns [([1],x); ([],y); ([0],z)]
     = [([1],x); ([0],z); ([],y)]
   Comparing lengths by integer difference avoids [Pervasives.compare],
   which was removed in OCaml 5.  Position lengths are small, so the
   difference cannot overflow. *)
let sort_redex_patterns rdxps =
  List.stable_sort (fun (p, _) (q, _) -> List.length q - List.length p) rdxps

(* [sumup_rdxps t rdxps] instantiates [t] by each redex pattern in turn.  For
   a pattern [(p, (l, _))], the current term is specialised by
   [mgu [l, subterm_at p u]].
   Assumes that
   (1) t is linear, (2) rdxps is simultaneous,
   and (3) t overlaps l at p for all (p,(l,_)) in rdxps. *)
let sumup_rdxps t rdxps =
  let step u (p, (l, _)) = substitute (mgu [l, subterm_at p u]) u in
  List.fold_left step t rdxps

(* print each redex pattern as "(position, rule)" to stderr, one per line *)
let print_rdxps rdxps =
  let print_one (p, rl) =
    Format.eprintf "(%a, %a)@." print_position p Rule.print rl
  in
  List.iter print_one rdxps

(* Drop overlaps whose set is a single redex pattern at the overlap position
   [p] with a rule that [Rule.variant] relates to the other rule [lr1]. *)
let remove_variant_overlaps rdxps =
  let trivial (set, p, _, lr1) =
    match set with
    | [(q, lr2)] -> p = q && Rule.variant lr1 lr2
    | _ -> false
  in
  List.filter (fun overlap -> not (trivial overlap)) rdxps


(** Simultaneous critical pairs *)
(* [Rule] -> Rule -> [[Pos,Rule],Pos,Term,Rule]
   Outer simultaneous overlaps: [(l, r)] is applied at the root.  Each
   simultaneous set of [l] (w.r.t. [rs]) is kept when its topmost pattern
   still unifies with the matching subterm of [l]. *)
let simultaneous_overlap_outer rs (l,r) =
  let keep rdxps =
    (* after sorting deepest-first, the last pattern is the topmost one *)
    let (p, (l2, _)) = List.hd (List.rev (sort_redex_patterns rdxps)) in
    if unifiable [l2, subterm_at p l]
    then [(rdxps, [], sumup_rdxps l rdxps, (l, r))]
    else []
  in
  simultaneous_sets rs l
  |> Listx.concat_map keep
  |> remove_variant_overlaps

(* outer simultaneous overlaps of every rule of [s] against redex patterns
   from [r]; both systems are renamed when they share variables *)
let simultaneous_overlaps_outer r s =
  let disjoint = Listset.inter (Trs.variables r) (Trs.variables s) = [] in
  let rr, ss = if disjoint then r, s else Trs.rename r, Trs.rename s in
  Listx.concat_map (simultaneous_overlap_outer rr) ss

(* [Rule] -> Rule -> [[Pos,Rule],Pos,Term,Rule]
   Inner simultaneous overlaps: [(l, r)] overlaps a (possibly renamed) rule
   [(l2, r2)] of [rs] strictly below the root, at [p].  For the instantiated
   term [t], every simultaneous set of [t] whose patterns include [(l2, r2)]
   at the root (up to [Rule.variant]) is reported. *)
let simultaneous_overlap_inner rs (l,r) =
  let has_pattern (p, lr) rdxps =
    List.exists (fun (q, lr2) -> p = q && Rule.variant lr lr2) rdxps
  in
  let sets_for (_, p, (l2, r2)) =
    (* look up the simultaneous sets of the overlapped term t *)
    let t = substitute (mgu [l, subterm_at p l2]) l2 in
    simultaneous_sets rs t
    |> List.filter (has_pattern ([], (l2, r2)))
    |> List.map (fun rdxps -> (rdxps, p, sumup_rdxps t rdxps, (l, r)))
  in
  overlaps [l, r] rs
  |> List.filter (fun (_, p, _) -> p <> [])
  |> Listx.concat_map sets_for
  |> remove_variant_overlaps

(* inner simultaneous overlaps of every rule of [s] into rules of [r]; both
   systems are renamed when they share variables *)
let simultaneous_overlaps_inner r s =
  let disjoint = Listset.inter (Trs.variables r) (Trs.variables s) = [] in
  let rr, ss = if disjoint then r, s else Trs.rename r, Trs.rename s in
  Listx.concat_map (simultaneous_overlap_inner rr) ss

(* [rewrite_redex_patterns t rdxps] contracts every redex pattern of [rdxps]
   in [t], from the bottom to the top.  Contracting a deeper redex does not
   move the positions above it, which ensures a /complete/ development. *)
let rewrite_redex_patterns t rdxps =
  let contract u (p, lr) = Rewriting.rewrite_with lr p u in
  List.fold_left contract t (sort_redex_patterns rdxps)

(* [Rule] -> [Rule] -> [[Pos,Rule],Pos,Term,Rule]
   all inner simultaneous overlaps followed by all outer ones *)
let simultaneous_overlaps r s =
  List.append
    (simultaneous_overlaps_inner r s)
    (simultaneous_overlaps_outer r s)
  
(* Simultaneous critical pairs.  For each simultaneous overlap, the result of
   contracting all its redex patterns is paired with the result of the single
   step by [lr] at [p].  Duplicates are removed with [Listx.unique]. *)
let simultaneous_cp2 r s =
  let to_pair (rdxps, p, t, lr) =
    (* t and redex patterns may have common variables *)
    let fresh = rename t in
    (rewrite_redex_patterns fresh rdxps, Rewriting.rewrite_with lr p fresh)
  in
  Listx.unique (List.map to_pair (simultaneous_overlaps r s))

(* simultaneous critical pairs of a single TRS with itself *)
let simultaneous_cp r = simultaneous_cp2 r r


(** Parallel critical pairs **)

(* [parallel p q] holds when positions [p] and [q] differ at some index
   before either runs out, i.e. neither is a prefix of the other. *)
let rec parallel p q =
  match p, q with
  | i :: p', j :: q' -> if i = j then parallel p' q' else true
  | _ -> false
                
(* true when every two positions in [positions] are equal or parallel *)
let parallel_positions positions =
  let compatible p q = p = q || parallel p q in
  List.for_all (fun p -> List.for_all (compatible p) positions) positions

(* Select the simultaneous overlaps that are parallel overlaps and normalise
   them to [(rdxps, t, rule)], where [rule] is applied at the root of [t].
   - An inner overlap (p <> []) qualifies only when its set is the single
     root pattern [([], rl2)].  The roles are then swapped: the inner rule
     becomes the set [[(p, rl)]] and [rl2] becomes the root rule.
   - An outer overlap (p = []) qualifies when its positions are pairwise
     equal-or-parallel, or when its set is a single root pattern. *)
let parallel_overlaps_aux r s =
  let positions rdxps = List.map fst rdxps in
  let is_parallel (rdxps, p, _, _) =
    match rdxps with
    | [([], _)] -> true
    | _ -> p = [] && parallel_positions (positions rdxps)
  in
  let normalise (rdxps, p, t, rl) =
    if p = [] then (rdxps, t, rl)
    else
      match rdxps with
      | [([], rl2)] -> ([(p, rl)], t, rl2)
      | _ -> invalid_arg "argument is not parallel overlap"
  in
  simultaneous_overlaps r s
  |> List.filter is_parallel
  |> List.map normalise

(* For every parallel overlap [(rdxps, t, lr)] and every non-empty element of
   [Listx.power rdxps], build the 5-tuple
     (t' with that subset contracted,
      the subset as (rule, position) pairs,
      t',
      (lr, []),
      t' rewritten by lr at the root)
   where t' is a renamed copy of t.  Duplicates are removed with
   [Listx.unique]. *)
let parallel_overlaps r s =
  let make_peak t lr rdxps =
    let swapped = List.map (fun (pos, rule) -> (rule, pos)) rdxps in
    (rewrite_redex_patterns t rdxps,
     swapped,
     t,
     (lr, []),
     Rewriting.rewrite_with lr [] t)
  in
  let peaks_of (rdxps, t, lr) =
    (* t and redex patterns may have common variables *)
    let fresh = rename t in
    let subsets = Listset.remove [] (Listx.power rdxps) in
    List.map (make_peak fresh lr) subsets
  in
  Listx.unique (Listx.concat_map peaks_of (parallel_overlaps_aux r s))

(* [Rule] -> [Rule] -> [Term,[Pos],Term,Term]
   parallel critical peaks: (parallel end, redex positions, peak, root end) *)
let parallel_cpp2 r s =
  let strip (t, pats, peak, _, u) = (t, List.map snd pats, peak, u) in
  Listx.unique (List.map strip (parallel_overlaps r s))

(* parallel critical peaks of a single TRS with itself *)
let parallel_cpp r = parallel_cpp2 r r

(* parallel critical pairs: the two end points of each peak *)
let parallel_cp2 r s =
  parallel_cpp2 r s |> List.map (fun (t, _, _, u) -> (t, u))

(* parallel critical pairs of a single TRS with itself *)
let parallel_cp r = parallel_cp2 r r


(** (weakly) Orthogonality **)

(* [is_WMO r s]: [r] and [s] are (weakly) mutually orthogonal.  That is,
   [r @ s] is left-linear and every critical pair of [cp2_all r s] is
   trivial. *)
let is_WMO r s =
  Trs.left_linear (r @ s)
  && List.for_all (fun (t, u) -> t = u) (cp2_all r s)

(* [is_OH trs]: [trs] is (weakly) orthogonal, checked as weak mutual
   orthogonality with itself *)
let is_OH trs = is_WMO trs trs



(**********************************************)
(**************** Comparison ******************)
(**********************************************)
(* Order terms structurally, except that terms which are variants of each
   other compare equal.
   NOTE(review): this is not a consistent total order.  Two variants
   compare equal, yet a third term may fall strictly between them
   structurally, so sorting or Set/Map use with it may misbehave. *)
let compare s t =
  if not (variant s t) then Pervasives.compare s t else 0
