open Term
open Ls

(*************************************************)
(*                UNIFICATION                    *)
(*************************************************)

(** basic matching **)
(* Note: raises Invalid_argument when the same function symbol is applied to
   argument lists of different lengths. *)
(** [match_with sub p s] extends [sub] with bindings [(x, t)] such that
    instantiating the pattern [p] with them yields [s].  The new bindings are
    listed in left-to-right order of the variable occurrences in [p], followed
    by [sub] exactly once.  Consistency (one term per variable) is NOT checked
    here; see [matcher].
    @raise Failure "variable clash" when [p] is an application and [s] is a
    variable, or when their root symbols differ.
    @raise Invalid_argument when the same symbol occurs with argument lists
    of different lengths. *)
let rec match_with sub p s = match p, s with
  | V x, _ -> (x, s) :: sub
  | F (f, ps), F (g, ss) when f = g ->
    (* Thread the accumulator right-to-left through the arguments so that
       [sub] appears once at the tail (the old code duplicated it per
       argument and dropped it for constants). *)
    List.fold_right2 (fun pi si acc -> match_with acc pi si) ps ss sub
  | _ -> failwith "variable clash"

(** [matcher p s] computes the bindings of [match_with [] p s] and returns
    them if [consist] accepts them.
    @raise Failure "variable clash" if matching fails or [consist] rejects
    the bindings (presumably a variable bound to two different terms).
    May also propagate [Invalid_argument] from [match_with]. *)
let matcher p s =
  let bindings = match_with [] p s in
  if not (consist bindings) then failwith "variable clash";
  bindings

(** [instance_of s p] is [true] iff [matcher p s] succeeds, i.e. [s] is an
    instance of the pattern [p].  Any exception raised while matching
    (e.g. [Failure] or [Invalid_argument]) is read as "not an instance".
    The function is not recursive, so the spurious [rec] has been removed. *)
let instance_of s p =
  try
    ignore (matcher p s);
    true
  with _ -> false

(** [matchers trs s] pairs each rule [(l, r)] of [trs] whose left-hand side
    matches [s] with the matching bindings [matcher l s], in rule order.
    Matching is performed once per rule.  The previous version called
    [instance_of] (which runs [matcher]) and then [matcher] again.
    Rules whose matching raises any exception are skipped, as before. *)
let matchers trs s =
  let try_rule ((l, _) as rule) =
    try [rule, matcher l s] with _ -> []
  in
  List.concat (List.map try_rule trs)

(** [variant s t] holds when each term is an instance of the other, i.e. the
    terms are equal up to a renaming of variables.  The function is not
    recursive, so the spurious [rec] has been removed. *)
let variant s t = instance_of s t && instance_of t s

(** [variant_rule (l1, r1) (l2, r2)] checks that the two rules are variants
    under one common renaming.  Each rule is packed under the dummy root
    symbol [""] so that both sides are matched together. *)
let variant_rule (l1, r1) (l2, r2) =
  let pair x y = F ("", [x; y]) in
  variant (pair l1 r1) (pair l2 r2)


(****** unification *****)

(* Apply the single binding [s] to both sides of every pending equation. *)
let update_elist s ts =
  let apply t = subst [s] t in
  List.map (fun (a, b) -> (apply a, apply b)) ts

(* Apply the single binding [s] to the right-hand side of every binding of
   [sigma]; the bound variables themselves are left untouched. *)
let update_sigma s sigma =
  List.map (fun (x, t) -> (x, subst [s] t)) sigma

(* [not_mem a list] is [true] when [a] does not occur in [list]. *)
let not_mem a list = List.mem a list |> not

(** [unification sigma es] solves the equations [es] by syntactic
    unification and accumulates solved bindings [(x, t)] in [sigma].
    Before a new binding is recorded, it is applied to the pending
    equations and to the right-hand sides already in [sigma].
    Returns the resulting substitution.
    @raise Failure "not unifiable" on a symbol clash or a failed occurs
    check.
    @raise Invalid_argument when one symbol occurs with argument lists of
    different lengths (from [List.combine]). *)
let rec unification sigma eqs =
  match eqs with
  | [] -> sigma
  | (lhs, rhs) :: rest ->
    (* eliminate a variable: record the binding and propagate it *)
    let bind binding =
      unification
        (binding :: update_sigma binding sigma)
        (update_elist binding rest)
    in
    if lhs = rhs then unification sigma rest
    else begin
      match lhs, rhs with
      | V x, t when not_mem x (vars t) -> bind (x, t)
      | t, V x when not_mem x (vars t) -> bind (x, t)
      | F (f, args1), F (g, args2) when f = g ->
        unification sigma (List.combine args1 args2 @ rest)
      | _ -> failwith "not unifiable"
    end

(* Most general unifier of the equation list [es], starting from the empty
   substitution.  Raises whatever [unification] raises. *)
let mgu es =
  let empty = [] in
  unification empty es

(* [unifiable es] is [true] iff [mgu es] returns without raising. *)
let unifiable es =
  try
    let _ = mgu es in
    true
  with _ -> false


(** TCAP function **)
(** [tcap trs t] replaces every variable of [t] by a fresh variable
    ([new_var]).  It rebuilds applications bottom-up and replaces a rebuilt
    subterm by a fresh variable when it unifies with the left-hand side of
    some rule of [trs]. *)
let rec tcap trs t =
  match t with
  | V _ -> new_var ()
  | F (f, ts) ->
    let capped = F (f, List.map (tcap trs) ts) in
    let root_may_rewrite =
      List.exists (fun (lhs, _) -> unifiable [lhs, capped]) trs
    in
    if root_may_rewrite then new_var () else capped
(* let tcap trs t = let c = tcap trs t in *)
  (* Format.printf "  %a ==> %a@." Term.print t Term.print c; c *)
(*************************************************)
(**************** Critical Pair ******************)
(*************************************************)
(** [simple_cp cps] drops critical pairs whose two sides are variants of each
    other.  It also drops every later pair that is a variant (as a rule,
    under one common renaming) of an earlier kept pair.
    Fix: the old filter compared [(a, c)] against [(b, d)].  That can only
    succeed if [a] is a variant of [b], which the previous clause already
    excludes, so no duplicates were ever removed. *)
let rec simple_cp = function
  | [] -> []
  | (a, b) :: cps when variant a b -> simple_cp cps
  | (a, b) :: cps ->
    let duplicates_kept (c, d) = variant_rule (a, b) (c, d) in
    (a, b) :: List.filter (fun cd -> not (duplicates_kept cd)) (simple_cp cps)

(** functions about overlap **)

(** [is_overlap ((a, b), p, (c, d))] holds when all of the following are true:
    - the rules [a -> b] and [c -> d] share no variables;
    - the subterm of [c] at position [p] is not a variable;
    - that subterm unifies with [a].
    Fix: [subterm_at c p] is now evaluated inside the [try].  An invalid
    position therefore yields [false] as intended, instead of escaping.
    NOTE(review): assumes [subterm_at] signals an invalid position with
    [Failure] — confirm. *)
let is_overlap ((a,b),p,(c,d)) =
  let rule_vars (x, y) = vars x @ vars y in
  try
    (* without common variables *)
    cap (rule_vars (a, b)) (rule_vars (c, d)) = []
    &&
    (let at = subterm_at c p in
     not (is_var at) && unifiable [a, at])
    (* disabled condition (rules at root must not be variants):
       not (p = [] && variant (F ("",[a;b])) (F ("",[c;d]))) *)
  with Failure _ -> false

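(* All overlaps (rule of r, position p, rule of s) accepted by is_overlap at
   the positions poss_fun gives for the left-hand sides of s; s is renamed
   apart first when it shares variables with r. *)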
(** [overlaps r s] lists every triple [(rule1, p, rule2)] with [rule1] from
    [r], [rule2] from [s] (renamed with [Trs.rename] when [r] and [s] share
    variables) and [p] from [poss_fun] of [rule2]'s left-hand side, such that
    [is_overlap (rule1, p, rule2)] holds.  Order: rules of [r], then rules
    of [s], then positions. *)
let overlaps r s =
  let s_apart =
    match cap (Trs.vars r) (Trs.vars s) with
    | [] -> s
    | _ -> Trs.rename s
  in
  let between rule1 ((l2, _) as rule2) =
    poss_fun l2
    |> List.map (fun p -> (rule1, p, rule2))
    |> List.filter is_overlap
  in
  List.concat
    (List.map (fun rule1 -> List.concat (List.map (between rule1) s_apart)) r)


(** Critical Pair **)
(** [cp_overlap laps] turns each overlap [((l1, r1), p, (l2, r2))] whose
    inner redex unifies with [l1] (mgu [s]) into the 4-tuple
    [(replace (s l2) (s r1) p, p, s l2, s r2)].  The components are the term
    rewritten at inner position [p], that position, the peak (root) term,
    and the term rewritten at the root.  Non-unifiable overlaps are dropped.
    Unification now runs once per overlap (it previously ran in
    [unifiable] and again in [mgu]).  Exceptions from [subterm_at] still
    propagate, as before. *)
let cp_overlap laps =
  (* [[sigma]] when [l] and [t] unify with mgu [sigma], [[]] otherwise *)
  let mgu_list l t = try [mgu [l, t]] with _ -> [] in
  let peaks ((l1, r1), p, (l2, r2)) =
    List.map
      (fun s ->
         let root = subst s l2 in
         (replace root (subst s r1) p, p, root, subst s r2))
      (mgu_list l1 (subterm_at l2 p))
  in
  List.concat (List.map peaks laps)

(* remove variant critical pairs *)
(* Critical 4-tuples of [laps] via [cp_overlap], deduplicated by [uniq_by].
   Two tuples count as duplicates when their (inner-rewritten, root-rewritten)
   sides are variant rules; positions and peak terms are ignored. *)
let c_p_ laps =
  let same_pair (l, _, _, r) (l', _, _, r') = variant_rule (l, r) (l', r') in
  cp_overlap laps |> uniq_by same_pair

(** Critical Peak **)
(* extract critical peak *)
(* Critical peaks [(left, peak, right)] of [laps]: the deduplicated tuples of
   [c_p_] with the overlap position dropped. *)
let c_pp laps =
  c_p_ laps |> List.map (fun (left, _, peak, right) -> (left, peak, right))

(* critical peaks for different TRSs *)
let cpp2 r s = overlaps r s |> c_pp

(* Peaks of [cpp2 s r] with left and right swapped, so the left side is the
   one rewritten at the root (by an [r]-rule). *)
let cpp2r r s =
  List.map (fun (b, t, a) -> (a, t, b)) (cpp2 s r)

(* <-R- ><| -S-> U <-R- |>< -S-> : both orientations, combined with [cup]. *)
let cpp2_all r s = cup (cpp2 r s) (cpp2r r s)

(* Critical peaks of a single TRS with itself. *)
let cpp trs = cpp2 trs trs

(* Peaks from overlaps strictly below the root (position <> []). *)
let cpp_in r s =
  overlaps r s
  |> List.filter (fun (_, p, _) -> p <> [])
  |> c_pp

(* Peaks from overlaps at the root (position = []). *)
let cpp_out r s =
  overlaps r s
  |> List.filter (fun (_, p, _) -> p = [])
  |> c_pp


(** Critical Pair **)
(* critical pairs: critical peaks (l, peak, r) with the peak term dropped *)
let c_p xs = List.fold_right (fun (l, _, r) acc -> (l, r) :: acc) xs []

(* critical pairs between two (possibly different) TRSs *)
let cp2 r s = cpp2 r s |> c_p

(* as [cp2], but the left side is the one rewritten at the root *)
let cp2r r s = cpp2r r s |> c_p

(* <-R- ><| -S-> U <-R- |>< -S-> : pairs from both orientations *)
let cp2_all r s = cpp2_all r s |> c_p

(* critical pairs for a trs *)
let cp trs =
  (* pairs are duplicates if variant rules in either orientation *)
  let same_up_to_orientation p (u, v) =
    variant_rule p (u, v) || variant_rule p (v, u)
  in
  cp2 trs trs |> uniq_by same_up_to_orientation

(* critical pairs from overlaps strictly below the root *)
let cp_in r s = cpp_in r s |> c_p

(* critical pairs from overlaps at the root *)
let cp_out r s = cpp_out r s |> c_p

(* variants that drop trivial critical peaks/pairs, i.e. those whose left and
   right sides are structurally equal *)
(* Keep only peaks (x, _, y) with x <> y.  Eta-expanded so the function is
   fully polymorphic instead of weakly polymorphic (value restriction). *)
let rm_ident1 peaks = List.filter (fun (x, _, y) -> x <> y) peaks
(* Each [X_] is [X] with trivial peaks removed by [rm_ident1].
   Fix: [cpp2r_] previously filtered [cpp2_ r s] instead of [cpp2r r s]. *)
let cpp2_ r s     = cpp2 r s     |> rm_ident1
let cpp2r_ r s    = cpp2r r s    |> rm_ident1
let cpp2_all_ r s = cpp2_all r s |> rm_ident1
let cpp_in_ r s   = cpp_in r s   |> rm_ident1
let cpp_out_ r s  = cpp_out r s  |> rm_ident1
let cpp_ r        = cpp2 r r     |> rm_ident1

(* Keep only pairs (x, y) with x <> y.  Eta-expanded so the function is
   fully polymorphic instead of weakly polymorphic (value restriction). *)
let rm_ident2 pairs = List.filter (fun (x, y) -> x <> y) pairs
(* Each [X_] is [X] with trivial pairs removed by [rm_ident2].
   Fix: [cp2r_] previously filtered [cp2_ r s] instead of [cp2r r s].
   NOTE(review): [cp_] filters [cp2 r r], not [cp r], so unlike [cp] it does
   not deduplicate variant pairs — confirm this is intended. *)
let cp2_ r s     = cp2 r s     |> rm_ident2
let cp2r_ r s    = cp2r r s    |> rm_ident2
let cp2_all_ r s = cp2_all r s |> rm_ident2
let cp_in_ r s   = cp_in r s   |> rm_ident2
let cp_out_ r s  = cp_out r s  |> rm_ident2
let cp_ r        = cp2 r r     |> rm_ident2

(** (weakly) Orthogonality **)

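(* weakly mutually orthogonal: r @ s is left-linear and every critical pair
   between r and s (both orientations) has identical sides *)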
let is_WMO r s =
  let trivial (u, v) = u = v in
  Trs.left_linear (r @ s) && List.for_all trivial (cp2_all r s)

(* (weakly) orthogonal: weakly mutually orthogonal with itself, so trivial
   critical pairs are allowed *)
let is_OH trs = is_WMO trs trs



(**********************************************)
(**************** Comparison ******************)
(**********************************************)
(** [compare s t] is [0] when [s] and [t] are variants; otherwise it is the
    structural [Pervasives.compare].
    NOTE(review): this is not a total order.  Two variants compare equal,
    yet a non-variant term can sort strictly between them structurally, so
    it may be unsafe with [Set.Make]/[List.sort] — confirm usage. *)
let compare s t =
  if not (variant s t) then Pervasives.compare s t else 0
