open Term
open Substitution
open Rewriting

(* [filter (flag, rules)] keeps only the left-linear rules of [rules]
   (as decided by [Rule.left_linear]) and passes [flag] through unchanged. *)
let filter (flag, rules) =
  let linear_rules = List.filter Rule.left_linear rules in
  (flag, linear_rules)

(** The rule addition and elimination transformations below are based on
    J. Nagele, B. Felgenhauer, A. Middeldorp:
    - "Improving Automatic Confluence Analysis of Rewrite Systems by Redundant Rules"
    - "CSI: New Evidence"
 **)

(* [rewrite_aux s (l, r) p] tries to narrow [s] at position [p] with the
   rule [l -> r]. If [l] is unifiable with the subterm of [s] at [p], the
   result is [Some (sigma, t)], where [sigma] is the unifier computed by
   [mgu] and [t] is [s] with [r] placed at [p], instantiated by [sigma].
   Otherwise the result is [None].
   NOTE(review): no variable renaming happens here, so this is only a proper
   narrowing step when [l -> r] and [s] share no variables. Presumably the
   callers guarantee that; confirm. *)
let rewrite_aux s (l, r) p =
  let redex = subterm_at p s in
  if not (unifiable l redex) then None
  else begin
    let sigma = mgu l redex in
    Some (sigma, substitute sigma (replace s r p))
  end

(* [getSome opt] unwraps [Some v] to [v].
   @raise Failure ["getSome error"] when [opt] is [None]. *)
let getSome opt =
  match opt with
  | Some v -> v
  | None -> failwith "getSome error"

(* [new_rules (l1, r1) lr2] returns the rule [l1 -> r1] itself, followed by
   one forward-closure rule [l1 sigma -> (r1[r2]_p) sigma] for each position
   [p] in [function_positions r1] (presumably the non-variable positions)
   where [rewrite_aux] succeeds, i.e. where the left-hand side of [lr2]
   unifies with the subterm of [r1] at [p] via [sigma].
   Positions where unification fails are skipped. The order of
   [function_positions r1] is preserved. *)
let new_rules (l1, r1) lr2 =
  (* Single pass with a total match, instead of filtering out [None]s and
     then unwrapping them with the partial [getSome]. *)
  let rec narrowings = function
    | [] -> []
    | p :: ps ->
      (match rewrite_aux r1 lr2 p with
       | None -> narrowings ps
       | Some (sub, t) -> (substitute sub l1, t) :: narrowings ps)
  in
  (l1, r1) :: narrowings (function_positions r1)

(* [forwarding_aux rs1 rs2] computes all forward closures of [rs1] with
   respect to [rs2]. It concatenates [new_rules lr1 lr2] for every [lr1] in
   [rs1] (outer loop) and every [lr2] in [rs2] (inner loop). Duplicates are
   not removed here. *)
let forwarding_aux rs1 rs2 =
  let closures_of lr1 = Listx.concat_map (new_rules lr1) rs2 in
  Listx.concat_map closures_of rs1

(* [forwarding k rs] performs [k] rounds of forward closure. Each round
   replaces the current system by the de-duplicated [forwarding_aux] of it
   with itself. When [k <= 0], [rs] is returned unchanged.
   The flag is always [true]: forwarding preserves rewrite steps. *)
let forwarding k rs =
  let rec loop rounds_left current =
    if rounds_left > 0 then
      loop (rounds_left - 1) (Listx.unique (forwarding_aux current current))
    else
      (true, current)
  in
  loop k rs


(* [joins k rs ((l1, r1), p, (l2, r2), sub)] instantiates the overlap with
   [sub] to obtain the critical pair [left = l2[r1]_p sub] and
   [right = r2 sub]. For every term [c] in the intersection of what
   [reachable k rs] yields from [left] and from [right], it returns the
   two rules [left -> c] and [right -> c].
   A returned rule may be trivial ([c] equal to its left-hand side);
   [add_joins] drops those. *)
let joins k rs ((_l1, r1), p, (l2, r2), sub) =
  let inst = substitute sub in
  let left = replace (inst l2) (inst r1) p in
  let right = inst r2 in
  let commons =
    Listset.inter (reachable k rs [left]) (reachable k rs [right])
  in
  Listx.concat_map (fun c -> [left, c; right, c]) commons

(* [add_joins k rs] adds to [rs] the non-trivial joining rules that [joins]
   produces for every overlap returned by [Overlap.overlap rs].
   The result is [(added = [], Listx.unique (rs @ added))]. If there are no
   overlaps at all, [(true, rs)] is returned as-is. Adding join sequences
   does not preserve rewrite steps; the flag is [true] only when nothing
   was added. *)
let add_joins k rs =
  match Overlap.overlap rs with
  | [] -> (true, rs)
  | overlaps ->
    let nontrivial (l, r) = l <> r in
    let added = List.filter nontrivial (Listx.concat_map (joins k rs) overlaps) in
    (added = [], Listx.unique (rs @ added))
  

(* [redundant_with step1 step2 rr (l, r)] removes [(l, r)] from [rr], then
   applies [step1] to [l] and [step2] to [r] using the remaining rules. It
   returns the boolean [Listset.intersect] reports for the two results
   (presumably: "they share a term"). It is used as a [List.find]
   predicate, hence [bool]. *)
let redundant_with step1 step2 rr (l, r) =
  let others = Listset.remove (l, r) rr in
  let from_rhs = step2 others r in
  let from_lhs = step1 others l in
  Listset.intersect from_lhs from_rhs

(* [reduce k rs] repeatedly deletes a rule [l -> r] of [rs] whose
   right-hand side [r] is among the terms [reachable k] yields from [l]
   using the other rules. It stops when no such rule remains.
   Redundant-rule elimination w.r.t. rewrite steps preserves rewrite steps,
   so the flag is always [true].
   Only the [Not_found] raised by [List.find] (no redundant rule left) ends
   the loop. Other exceptions (e.g. [Sys.Break], [Out_of_memory], failures
   inside [reachable]) now propagate instead of being silently reported as
   success. The recursive call is also outside the handler, so it is a
   genuine tail call. *)
let rec reduce k rs =
  let step1 rr s = reachable k rr [s] in
  let step2 _ s = [s] in
  let redundant =
    try Some (List.find (redundant_with step1 step2 rs) rs)
    with Not_found -> None
  in
  match redundant with
  | Some lr -> reduce k (Listset.remove lr rs)
  | None -> (true, rs)


(* [reduce_joins k rs] repeatedly deletes a rule [l -> r] of [rs] for which
   [reachable k], using the other rules, yields intersecting results from
   [l] and from [r] (a bounded join). It stops when no such rule remains.
   Elimination w.r.t. conversion does not preserve rewrite steps, so the
   flag is always [false].
   Only the [Not_found] raised by [List.find] ends the loop. The previous
   catch-all also swallowed unrelated exceptions (e.g. [Sys.Break]) and
   returned a partial result as if it had finished normally. *)
let rec reduce_joins k rs =
  let step rr s = reachable k rr [s] in
  let redundant =
    try Some (List.find (redundant_with step step rs) rs)
    with Not_found -> None
  in
  match redundant with
  | Some lr -> reduce_joins k (Listset.remove lr rs)
  | None -> (false, rs)


(* [cpp2_pos r s] computes critical peaks between [r] and [s], skipping
   overlaps that involve a redundant rule. [r] and [s] must be equal as
   sets.
   For each overlap [((l1, r1), p, (l2, r2))] from [Match.overlaps r s]
   whose [l1] unifies with the subterm of [l2] at [p] (unifier [s]), it
   yields [(t, p, l2 s, u)], where [t = l2 s[r1 s]_p] and [u = r2 s]. The
   overlap is kept only if NEITHER participating rule is [joinable] (bound
   4, presumably a step/depth limit) using the rules of R u S other than
   the two involved.
   The result is de-duplicated by [Ls.uniq_by]. Two entries are considered
   equal when their [(t, u)] pairs are [Rule.variant]s; positions and peaks
   are ignored for this comparison.
   @raise Failure ["Trans.cpp2_pos: detected R <> S"] if [r] and [s] differ
   as sets.
   NOTE: this body uses a list-comprehension syntax extension, not plain
   OCaml. *)
let cpp2_pos r s =
  if not (Listset.equal r s) then
    failwith "Trans.cpp2_pos: detected R <> S"
  else
    (* [others trs]: all rules of R u S except those in [trs].
       This closure captures the TRS [s], not the substitution bound below. *)
    let others trs = Listset.diff (Listset.union r s) trs in
    [t,
     p,
     substitute s l2,
     u
    | ((l1,r1),p,(l2,r2)) <- Match.overlaps r s;
      Match.unifiable [l1, subterm_at p l2];
      (* From here on [s] is the unifier, shadowing the TRS parameter [s]. *)
      s <- [Match.mgu [l1, subterm_at p l2]];
      t <- [replace (substitute s l2) (substitute s r1) p];
      u <- [substitute s r2];
      (* Drop the overlap if either rule is joinable by the remaining rules. *)
      not (joinable 4 (others [l1,r1; l2,r2]) (l1,r1));
      not (joinable 4 (others [l1,r1; l2,r2]) (l2,r2))
    ]
    |> Ls.uniq_by
         (fun (a,p1,_,b) (c,p2,_,d) ->
           Rule.variant (a,b) (c,d))

(* [cpp2 r s] returns the critical peaks of [cpp2_pos r s] as
   [(t, peak, u)] triples, dropping the overlap position. *)
let cpp2 r s =
  let drop_position (t, _, peak, u) = (t, peak, u) in
  List.map drop_position (cpp2_pos r s)

(* [cpp2r s r] returns the peaks of [cpp2_pos r s] with the two ends
   swapped, i.e. [(u, peak, t)]. Note that the parameter order is reversed
   relative to the call: the first argument becomes [cpp2_pos]'s second. *)
let cpp2r s r =
  let flip (t, _, peak, u) = (u, peak, t) in
  List.map flip (cpp2_pos r s)

(* [cp2 r s] returns the critical pairs [(t, u)] of [cpp2 r s], dropping
   the peak. *)
let cp2 r s =
  List.map (fun (t, _, u) -> (t, u)) (cpp2 r s)

(* [cp2r r s] returns the reversed critical pairs [(u, t)] taken from
   [cpp2r r s], dropping the peak. *)
let cp2r r s =
  List.map (fun (t, _, u) -> (t, u)) (cpp2r r s)

(* [cpp2_all r s] returns the [Listset.union] of the peaks of [cpp2 r s]
   and their reversed counterparts from [cpp2r r s]. *)
let cpp2_all r s =
  let backward = cpp2r r s in
  let forward = cpp2 r s in
  Listset.union forward backward

(* [cp2_all r s] returns the [Listset.union] of the critical pairs of
   [cp2 r s] and the reversed pairs of [cp2r r s]. *)
let cp2_all r s =
  let backward = cp2r r s in
  let forward = cp2 r s in
  Listset.union forward backward

(* [cp_in r s] returns the critical pairs [(t, u)] of [cpp2_pos r s] whose
   overlap position is not the root, i.e. [p <> []]. The order of
   [cpp2_pos] is preserved. *)
let cp_in r s =
  List.fold_right
    (fun (t, p, _, u) acc -> if p = [] then acc else (t, u) :: acc)
    (cpp2_pos r s)
    []

(*
let _ =
    let r = [f[a],f[b]; f[a],f[d]; d,c;] in
    let s = [f[a],f[b]; f[a],f[d]; b,d; d,c;] in
    let rr = [f[d],b; a,b] in
    let ss = [d,a; g[a],c] in
    cp2_all s s;;
*)

(* output helper *)

(* [is_transformed org_rs new_rs] prints a note followed by the rules of
   [org_rs] when [new_rs] is not a [Trs.variant] of [org_rs]. Otherwise it
   does nothing. *)
let is_transformed org_rs new_rs =
  if not (Trs.variant org_rs new_rs) then begin
    Format.printf "NOTE: original TRS is@.";
    Trs.print_rules org_rs
  end


