open Format
open Formatx
open Term
open Substitution
open Overlap
open Rewriting
open Iroiro
open Arg
open Result

(* proof tree *)

(* A confluence proof.  Each constructor is named after the criterion that
   was applied and carries the rule sets that print_proof displays for it.
   KH12 and HM13 additionally carry the sub-proofs they depend on. *)
type proof = 
  | KH12   of Rules.t (* R *) * Rules.t (* S *) * Rules.t (* CP_S(R) *)  * proof
  | HM11   of Rules.t (* R *) * Rules.t (* CPS'(R) *) * Rules.t (* CP(R) *)
  | HM13   of Rules.t (* R *) * Rules.t (* S *) * Rules.t (* CPS'(R,S) *)  * proof * proof
  | HM13CR of Rules.t (* R *) * Rules.t (* non-closed-CPS(R) *) 
  | VO94   of Rules.t (* R *) * Rules.t (* CP(R) *) (* development closedness *)
  | AYT09  of Rules.t (* R *) * Rules.t (* CP(R) *) (* development closedness a la Toyama *)
  | VO08   of Rules.t (* R *) * (Rule.t * int) list (* rule label *)
  | JK86   of Rules.t (* R *) * Rules.t (* CP_AC(R) *)

(* Prints one entry of a rule labeling: the rule in brackets followed by
   its integer label. *)
let print_label ppf entry =
  let rule, label = entry in
  fprintf ppf "[%a] = %d" Rule.print rule label

(* Pretty-prints a proof tree on [ppf].  Each node prints the statement
   that was proved, the reference for the criterion, and the rule sets
   stored in the node.  For KH12 and HM13 nodes the sub-proofs are printed
   recursively after the node itself. *)
let rec print_proof ppf tree =
  match tree with
  | KH12 (rs, ss, ecp, sub) ->
      fprintf ppf
        "---- CR(R cup S)@.\
         ---- [Klein and Hirokawa, 2012]@.\
         R:@.%a@.\
         S:@.%a@.\
         CP_S(R):@.%a@.\
         CR(S):@.see below.@.@.%a"
        Rules.print rs
        Rules.print ss
        Rules.print ecp
        print_proof sub
  | HM11 (rs, cps, cp) ->
      fprintf ppf
        "---- CR(R)@.\
         ---- [Hirokawa and Middeldorp, 2011]@.\
         R:@.%a@.\
         CPS'(R):@.%a@.\
         CP(R):@.%a@."
        Rules.print rs
        Rules.print cps
        Rules.print cp
  | HM13 (rs, ss, cps, sub_r, sub_s) ->
      fprintf ppf
        "---- CR(R cup S)@.\
         ---- [Hirokawa and Middeldorp, 2013]@.\
         R:@.%a@.\
         S:@.%a@.\
         CPS(R,S):@.%a@.\
         CR(R):@.see below.@.\
         CR(S):@.see below.@.@.%a%a"
        Rules.print rs
        Rules.print ss
        Rules.print cps
        print_proof sub_r
        print_proof sub_s
  | HM13CR (rs, cps) ->
      fprintf ppf
        "---- CR(R)@.\
         ---- [Hirokawa and Middeldorp, 2013]@.\
         R:@.%a@.\
         non-closed-CPS(R):@.%a"
        Rules.print rs
        Rules.print cps
  | VO94 (rs, cp) ->
      fprintf ppf
        "---- CR(R)@.\
         ---- Development Closedness Theorem [van Oostrom, TCS 1994].@.\
         R:@.%a@.@.\
         CP(R):@.%a@.@."
        Rules.print rs
        Rules.print cp
  | AYT09 (rs, cp) ->
      fprintf ppf
        "---- CR(R)@.\
         ---- Extended Development Closedness Theorem [Aoto, Yoshida and Toyama, RTA 2009].@.\
         R:@.%a@.@.\
         CP(R):@.%a@.@."
        Rules.print rs
        Rules.print cp
  | VO08 (rs, labels) ->
      fprintf ppf
        "---- CR(R)@.\
         ---- Rule Labeling [van Oostrom, RTA 2008].@.\
         R:@.%a@.@.\
         Label:@.%a@.@."
        Rules.print rs
        (print_list print_label "@.") labels
  | JK86 (rs, cp) ->
      fprintf ppf
        "---- CR(R)@.\
         ---- [Jouannaud and Kirchner, 1986]@.\
         R:@.%a@.\
         CP_AC(R):@.%a"
        Rules.print rs
        Rules.print cp

(* Prints a non-confluence certificate: the rewrite system followed by the
   two terms reported as an unjoinable conversion. *)
let print_disproof ppf witness =
  let rs, left, right = witness in
  fprintf ppf
    "---- Non-confluence check@.\
     R:@.%a@.@.\
     Unjoinable conversion:@.%a <->*@.%a@.@."
    Rules.print rs
    Term.print left
    Term.print right

(* Auxiliary functions *)

(* Raised by a criterion or a search helper when it does not apply or
   finds nothing; try_find uses it to move on to the next candidate. *)
exception Failed
(* Raised by jk86 with a pair of terms whose AC normal forms are not
   AC-equal; the main program reports the pair as a disproof. *)
exception Disproved of Term.t * Term.t

(* [try_find f xs] returns [f x] for the first [x] of [xs] on which [f]
   does not raise [Failed].  Raises [Failed] once the list is exhausted.
   Any other exception raised by [f] propagates. *)
let rec try_find f candidates =
  match candidates with
  | [] -> raise Failed
  | first :: rest ->
      (try f first with Failed -> try_find f rest)

(* [pairwise_find p xs] returns the first pair [(x, y)] such that [y]
   occurs after [x] in [xs] and [p x y] holds.  Raises [Failed] when no
   such pair exists. *)
let rec pairwise_find p items =
  match items with
  | [] -> raise Failed
  | x :: later ->
      (try (x, List.find (p x) later)
       with Not_found -> pairwise_find p later)

(* Global variables *)

(* Positional command-line arguments, in the order given. *)
let filenames  = ref []
(* Optional timeout in seconds (option -t); None until the option is used. *)
let t          = ref None 
(* Termination tool passed to Sn.sn (option -tt). *)
let tt         = ref "ttt2"
(* AC termination tool passed to Sn.sn by jk86 (option -actt). *)
let actt       = ref "muterm -i"
(* SMT solver passed as ~tool to Rlv.solve and Rl.solve (option -smt). *)
let smt        = ref "minismt -m"
(* Bound handed to reachable, commute, joinable and non_confluent. *)
let wcrk       = ref 5
(* Integer parameter handed to Rlv.solve by rlv. *)
let rlvk       = ref 5
(* Integer parameter handed to Rl.solve by rlc. *)
let rlck       = ref 5

(* [tcap rs t] replaces every variable of [t] by a fresh variable and,
   working bottom-up, replaces every function application whose capped
   form unifies with the left component of some rule in [rs] by a fresh
   variable as well.  The comprehension is kept as is because [fresh ()]
   is called while the arguments are traversed. *)
let rec tcap rs term =
  match term with
  | V _ -> V (fresh ())
  | F (f, args) ->
      let capped = F (f, [ tcap rs arg | arg <- args ]) in
      let unifies_with_lhs (lhs, _) = unifiable capped lhs in
      if List.exists unifies_with_lhs rs then V (fresh ()) else capped

(* [intersect xs ys] is true when some element of [xs] also occurs in
   [ys], compared with structural equality. *)
let intersect xs ys =
  let occurs_in_ys x = List.mem x ys in
  List.exists occurs_in_ys xs

(* [commute1 n rs ss (s, t)] checks whether the terms computed by
   [reachable n rs [s]] and by [reachable n ss [t]] share an element:
   [s] is explored with [rs] and [t] with [ss]. *)
let commute1 n rs ss (s, t) =
  let from_s = reachable n rs [s] in
  let from_t = reachable n ss [t] in
  intersect from_s from_t

(* [commute n rs ss es] holds when [commute1 n rs ss] succeeds on every
   pair of [es]. *)
let commute n rs ss es =
  let commutes_on pair = commute1 n rs ss pair in
  List.for_all commutes_on es

(* [joinable n rs es] is [commute] with the same system on both sides:
   every pair of [es] must pass [commute1 n rs rs]. *)
let joinable n rs es =
  List.for_all (commute1 n rs rs) es

(** Confluence criteria *)

(* Knuth-Bendix criterion (1970).  Succeeds when Rsn.not_sn does not
   report rs, the termination tool answers YES for rs, and both sides of
   every critical pair have the same result under nf.  The proof is
   recorded as a KH12 node with an empty S and an empty HM11 sub-proof.
   Raises Failed otherwise. *)
let kb70 rs =
  let critical_pairs = cp rs in
  let same_normal_form (s, t) = nf rs s = nf rs t in
  if Rsn.not_sn rs [] then raise Failed
  else if Sn.sn !tt rs [] <> YES then raise Failed
  else if not (List.for_all same_normal_form critical_pairs) then raise Failed
  else KH12 (rs, [], critical_pairs, HM11 ([], [], []))

(* Criterion of Hirokawa and Middeldorp (JAR 2011).  Requires rs to be
   left-linear, its critical pairs to pass joinable within !wcrk, and the
   termination checks on the critical pair steps (restricted to pairs with
   distinct sides) together with rs to succeed.  Raises Failed otherwise. *)
let hm11 rs =
  let nontrivial (s, t) = s <> t in
  let steps = cps ~p:nontrivial rs in
  let pairs = cp rs in
  let applicable =
    Rules.left_linear rs
    && not (Rsn.not_sn steps rs)
    && joinable !wcrk rs pairs
    && Sn.sn !tt steps rs = YES
  in
  if applicable then HM11 (rs, steps, pairs) else raise Failed

(* Development closedness theorem (van Oostrom, TCS 1994).  Succeeds when
   rs is left-linear and, for every critical pair (s, t), the term t is
   among the results of multistep rs s.  Raises Failed otherwise. *)
let vo94 rs =
  if not (Rules.left_linear rs) then raise Failed
  else
    let pairs = cp rs in
    let development_closed (s, t) = List.mem t (multistep rs s) in
    if List.for_all development_closed pairs then VO94 (rs, pairs)
    else raise Failed

(* Extended development closedness theorem (Aoto, Yoshida and Toyama, RTA 2009) *)
(* An overlap is a root overlap when its position component (the second
   field of the tuple) is the empty list. *)
let root_overlap (_, position, _, _) =
  match position with
  | [] -> true
  | _ :: _ -> false

(* Closedness test for one overlap with critical pair (s, t).  For a root
   overlap some result of multistep rs s must be reachable from t within
   !wcrk; for any other overlap t itself must be a result of
   multistep rs s. *)
let closed rs o =
  let s, t = cp_of_overlap o in
  match root_overlap o with
  | true -> intersect (multistep rs s) (reachable !wcrk rs [t])
  | false -> List.mem t (multistep rs s)

(* Extended development closedness: succeeds when rs is left-linear and
   every overlap of rs passes [closed].  Raises Failed otherwise. *)
let ayt09 rs =
  if not (Rules.left_linear rs) then raise Failed
  else if List.exists (fun o -> not (closed rs o)) (overlap rs) then
    raise Failed
  else AYT09 (rs, cp rs)

(* Van Oostrom's rule labeling (RTA 2008) *)

(* Rule labeling through Rlv.solve, run with the SMT tool !smt and the
   parameter !rlvk.  A returned labeling yields a VO08 proof; otherwise
   Failed is raised. *)
let rlv rs =
  let labeling = Rlv.solve ~tool:!smt !rlvk rs in
  match labeling with
  | None -> raise Failed
  | Some a -> VO08 (rs, a)

(* Rule labeling through Rl.solve, run with the SMT tool !smt and the
   parameter !rlck.  A returned labeling yields a VO08 proof; otherwise
   Failed is raised. *)
let rlc rs =
  let labeling = Rl.solve ~tool:!smt !rlck rs in
  match labeling with
  | None -> raise Failed
  | Some a -> VO08 (rs, a)

(* Non-confluence check *)

(* [unjoinable_by_tcap rs s t] is true when the tcap approximations of
   [s] and [t] are not unifiable. *)
let unjoinable_by_tcap rs s t =
  if Substitution.unifiable (tcap rs s) (tcap rs t) then false else true

(*
let unjoinable_by_closure k rs s t = 
  let ss1 = reachable k rs [s]
  and ts1 = reachable k rs [t] in
  let ss2 = reachable 1 rs ss1
  and ts2 = reachable 1 rs ts1 in
  Listset.subset ss2 ss1 &&
  Listset.subset ts2 ts1 && 
  not (intersect ss2 ts2)

let unjoinable k rs s t =
  unjoinable_by_tcap rs s t ||
  unjoinable_by_closure 1 rs s t
*)

(* Turns a term into a variable-free one by replacing every variable [x]
   with the nullary application [F (x, [])]. *)
let rec ground term =
  match term with
  | V x -> F (x, [])
  | F (f, args) -> F (f, List.map ground args)

(* For every overlap of [rs], instantiates the left-hand side of the
   third component with the substitution stored in the overlap and makes
   the result variable-free with [ground]. *)
let critical_peak rs =
  let peak_of ((_, _), _, (l2, _), mu) = ground (substitute mu l2) in
  List.map peak_of (overlap rs)

(* Searches for a non-confluence witness.  The candidate start terms are
   the grounded critical peaks of [rs] extended with its reversed rules.
   For each start term, the terms computed by [reachable k rs] are
   searched for a pair that [unjoinable_by_tcap rs] separates.  Returns
   the first such pair and raises Failed when there is none. *)
let non_confluent k rs =
  let reversed = List.map (fun (l, r) -> (r, l)) rs in
  let witness_from t =
    pairwise_find (unjoinable_by_tcap rs) (reachable k rs [t]) in
  try_find witness_from (critical_peak (rs @ reversed))

(* Klein and Hirokawa's criterion (LPAR 2012) *)

(* joinability of S-overlap *)

(* [kh12_overlap_aux rule1 rule2] lists the overlaps of [rule1] into
   [rule2].  Both rules are first passed through Rule.rename.  A function
   position p of the left-hand side l2 is kept when the ren-processed l1
   and the ren-processed subterm of l2 at p are unifiable, and, for the
   root position, only when the two rules are not variants of each other.
   Each kept position yields the tuple (rule1, p, rule2, mgu).
   NOTE(review): unifiability is tested on the ren-processed terms while
   mgu is computed on the unprocessed ones; kh12 catches Not_unifiable
   around kh12_cp, presumably for the case where this mgu does not
   exist -- confirm against the definitions of ren and mgu. *)
let kh12_overlap_aux rule1 rule2 =
  let l1, r1 = Rule.rename rule1
  and l2, r2 = Rule.rename rule2 in
  [ (l1, r1), p, (l2, r2), mgu (subterm_at p l2) l1
  | p <- function_positions l2; 
    unifiable (ren l1) (ren (subterm_at p l2));
    p <> [] || not (Rule.variant (l1, r1) (l2, r2)) ] 

(* Collects the overlaps of every ordered pair of rules of [rs], outer
   loop over the first rule and inner loop over the second, and removes
   duplicates with Listset.unique. *)
let kh12_overlap rs =
  let overlaps_into a = List.concat (List.map (fun b -> kh12_overlap_aux a b) rs) in
  Listset.unique (List.concat (List.map overlaps_into rs))

(* Critical pairs obtained from the overlaps computed by kh12_overlap. *)
let kh12_cp rs =
  List.map cp_of_overlap (kh12_overlap rs)

(* Criterion of Klein and Hirokawa for the split R = rs, S = ss.  The
   pairs from kh12_cp rs must be joinable in the union within !wcrk, the
   check sno rs ss must hold, and the termination tool must answer YES on
   rs with ss.  The confluence of ss is delegated to [cr], whose proof is
   stored in the KH12 node.  Raises Failed when a check fails or when
   kh12_cp raises Not_unifiable. *)
let kh12 cr rs ss =
  let ecp = try kh12_cp rs with Not_unifiable -> raise Failed in
  let applicable =
    sno rs ss
    && joinable !wcrk (rs @ ss) ecp
    && Sn.sn !tt rs ss = YES
  in
  if not applicable then raise Failed
  else KH12 (rs, ss, ecp, cr ss)

(* Tries kh12 on every split of [rs] into [Listset.diff rs ss] and [ss],
   where [ss] ranges over the elements of Listx.power rs that are
   non-empty and for which Listset.subset rs ss is false.  Returns the
   first proof found and raises Failed when no split works. *)
let kh12_auto cr rs =
  let admissible ss = ss <> [] && not (Listset.subset rs ss) in
  let candidates = List.filter admissible (Listx.power rs) in
  try_find (fun ss -> kh12 cr (Listset.diff rs ss) ss) candidates

(* Jouannaud and Kirchner's criterion (1986) *)

(* Criterion of Jouannaud and Kirchner (1986) for systems with AC symbols.
   Raises Failed when Ac_rewriting.find_ac finds no AC structure in [rr]
   or when the AC termination tool does not answer YES.  Otherwise both
   sides of every flattened AC critical pair are normalised with
   Ac_rewriting.nf.  If all pairs become AC-equal the result is a JK86
   proof; if not, Disproved is raised with the first offending pair. *)
let jk86 rr =
  match Ac_rewriting.find_ac rr with
  | None -> raise Failed
  | Some (rs, ac_sym, _ac, ac_all) ->
      if Sn.sn !actt ~ac_sym rs [] <> YES then raise Failed
      else begin
        let cps =
          Rules.flatten ac_sym
            (Ac_overlap.cp ac_sym rs @ Ac_overlap.cp2 ac_sym rs ac_all) in
        let normalise_pair (s, t) =
          (Ac_rewriting.nf ac_sym rs s, Ac_rewriting.nf ac_sym rs t) in
        let cps_nf = List.map normalise_pair cps in
        let differs (s, t) = not (Ac_rewriting.equal ac_sym s t) in
        let nj = List.filter differs (Variant.rename_rules cps_nf) in
        match nj with
        | (s, t) :: _ -> raise (Disproved (s, t))
        | [] -> JK86 (rr, Variant.rename_rules cps)
      end

(* Variant of jk86 that never reports a disproof: a Disproved exception
   is turned into Failed. *)
let jk86_maybe rr =
  try jk86 rr with
  | Disproved _ -> raise Failed

(* Hirokawa and Middeldorp's criterion (2013) *)

(*
let hm13 cr rs ss =
  let p (s, t) = true (* s <> t *) in
  let union = Listx.unique (rs @ ss) in
  if Rules.left_linear rs && 
     Rules.left_linear ss &&
     commute !wcrk ss rs (mutual_cp rs ss) then
     let cps = mutual_cps ~p rs ss in
     if Sn.sn !tt cps union = YES then
       HM13 (rs, ss, cps, cr rs, cr ss)
     else 
       raise Failed
  else
   raise Failed

let hm13_auto cr rs =
  if Rules.left_linear rs then
    let f (rs1, rs2) = hm13 cr rs1 rs2 in
    let l = Listx.power rs in
    try_find f
      [ rs1, rs2
      | rs1 <- l; rs2 <- l;
        not (Listset.equal rs1 rs2);
        Listset.equal rs (rs1 @ rs2) ]
  else
    raise Failed

let hm13b_non_closed_overlap rs (s, t) =
  not (List.mem t (multistep rs s))

let hm13b_cps rs ss =
  Listx.unique
    (cps2 ~p:(hm13b_non_closed_overlap ss) rs ss @
     cps2 ~p:(hm13b_non_closed_overlap rs) ss rs)

let hm13b cr rs ss =
  let union = Listx.unique (rs @ ss) in
  if Rules.left_linear rs && 
     Rules.left_linear ss &&
     commute !wcrk ss rs (mutual_cp rs ss) then
     let cps = hm13b_cps rs ss in
     if Sn.sn !tt cps union = YES then
       if Listset.equal rs ss then
         HM13CR (rs, cps)
       else
         HM13 (rs, ss, cps, cr rs, cr ss)
     else 
       raise Failed
  else
   raise Failed

let hm13b_auto cr rs =
  if Rules.left_linear rs then
    let f (rs1, rs2) = hm13b cr rs1 rs2 in
    let l = Listx.power rs in
    try_find f
      [ rs1, rs2
      | rs1 <- l; rs2 <- l;
        not (Listset.equal rs1 rs2);
        Listset.equal rs (rs1 @ rs2) ]
  else
    raise Failed

let hm13b_cr rs =
  let cr rs = raise Failed in
  hm13b cr rs rs
*)

(* Computes the rule set built from the non-closed critical peaks between
   rs and ss.  Two families of peaks (t, s, u) are collected:
   - from every overlap in overlap2 rs ss, the peaks for which no result
     of multistep ss t is reachable from u with rs within !wcrk;
   - from every non-root overlap in overlap2 ss rs, the peaks for which u
     is not a result of multistep rs t.
   Each collected peak contributes the two pairs (s, t) and (s, u);
   duplicates are removed with Listx.unique. *)
let hm13c_cps rs ss =
  let peaks =
    [ peak | o <- overlap2 rs ss; 
	     (t, s, u) as peak <- [critical_peak_of_overlap o];
	     not (intersect (multistep ss t) (reachable !wcrk rs [u])) ] @
    [ peak | o <- overlap2 ss rs; not (root_overlap o);
	     (t, s, u) as peak <- [critical_peak_of_overlap o];
	     not (List.mem u (multistep rs t)) ] in
  Listx.unique [ rule | t, s, u <- peaks; rule <- [s, t; s, u] ]

(* Criterion of Hirokawa and Middeldorp (2013) for the pair rs, ss.  Both
   systems must be left-linear, their mutual critical pairs must pass
   commute within !wcrk, and the termination tool must answer YES on
   hm13c_cps rs ss with the union of both systems.  When rs and ss are
   equal as sets the result is an HM13CR node; otherwise it is an HM13
   node whose sub-proofs are obtained from [cr].  Raises Failed when a
   check fails. *)
let hm13c cr rs ss =
  let union = Listx.unique (rs @ ss) in
  let preconditions =
    Rules.left_linear rs
    && Rules.left_linear ss
    && commute !wcrk ss rs (mutual_cp rs ss)
  in
  if not preconditions then raise Failed
  else
    let cps = hm13c_cps rs ss in
    if Sn.sn !tt cps union <> YES then raise Failed
    else if Listset.equal rs ss then HM13CR (rs, cps)
    else HM13 (rs, ss, cps, cr rs, cr ss)

(* Tries hm13c on every pair (rs1, rs2) of elements of Listx.power rs that
   differ as sets and whose concatenation equals rs as a set.  Requires
   rs to be left-linear.  Returns the first proof found and raises Failed
   otherwise. *)
let hm13c_auto cr rs =
  if not (Rules.left_linear rs) then raise Failed
  else
    let subsets = Listx.power rs in
    let partners rs1 =
      let fits rs2 =
        not (Listset.equal rs1 rs2) && Listset.equal rs (rs1 @ rs2) in
      List.map (fun rs2 -> (rs1, rs2)) (List.filter fits subsets) in
    let splits = List.concat (List.map partners subsets) in
    try_find (fun (rs1, rs2) -> hm13c cr rs1 rs2) splits

(* hm13c applied to rs against itself.  The sub-proof callback always
   raises Failed, so only a result that needs no sub-proofs can be
   returned. *)
let hm13 rs =
  hm13c (fun _ -> raise Failed) rs rs

(* Confluence prover used for sub-systems.  The criteria are tried in
   order and the first proof is returned.  It uses jk86_maybe, so a failed
   AC check counts as Failed instead of raising Disproved. *)
let rec confluent1 rs =
  let criteria =
    [ kb70; hm13; rlv; rlc; jk86_maybe; kh12_auto confluent1 ] in
  try_find (fun criterion -> criterion rs) criteria

(* Top-level confluence prover.  The criteria are tried in order and the
   first proof is returned; Failed is raised when none applies.  It uses
   jk86 directly, so Disproved may propagate to the caller.  Sub-systems
   created by kh12_auto are handled by confluent1. *)
let confluent rs =
  let criteria =
    [ kb70; hm13; rlv; rlc; jk86; kh12_auto confluent1 ] in
  try_find (fun criterion -> criterion rs) criteria

(*
module M = Map.Make
  (struct
     type t = Rules.t
     let compare rs ss =
       Pervasives.compare
	 (List.sort Pervasives.compare rs)
	 (List.sort Pervasives.compare ss)
   end)

let memo = ref M.empty

let rec confluent rs = 
  try 
    M.find rs !memo
  with Not_found ->
    let b =
      try_find (fun f -> f rs)
	[ hm13c_cr; rlv; hm13c_auto confluent ] in
    (memo := M.add rs b !memo; b)
*)

(* Command-line options, aligned for the help output.  Each option stores
   its argument in the corresponding global reference.  The default shown
   in a help text is the value of that reference when this list is
   built. *)
let options = align [ 
  "-t",  Float  (fun x -> t := Some x), "<f> second timeout"; 
  "-tt", String (fun x -> tt := x),
     sprintf "<tool> termination tool (default: %s)" !tt;
  "-actt", String (fun x -> actt := x),
     sprintf "<tool> AC termination tool (default: %s)" !actt;
  "-smt", String (fun x -> smt := x),
     sprintf "<tool> QF_NIA SMT solver (default: %s)" !smt;
]

(* main *)
     
(* One-line usage summary, printed on stderr when the number of file
   arguments is not exactly one. *)
let short_usage = "Usage: saigawa [options] <file>\n"
(* Full usage header handed to the argument parser: a version banner
   followed by the short usage line. *)
let usage = "Confluence tool Saigawa v1.5\n" ^ short_usage ^ "Options are:"

(* Entry point.  Parses the command line and requires exactly one input
   file, otherwise it prints the short usage on stderr and exits with
   status 1.  The file is read as a TRS; a system that fails
   Rules.variable_condition is rejected with exit status 1.  The
   non-confluence check runs first and prints NO with its witness.  If it
   raises Failed, the confluence criteria are tried and the program prints
   YES with the proof, NO with a disproof, or MAYBE.  All of this runs
   under the optional timeout, which prints TIMEOUT when it fires. *)
let () =
  parse options (fun it -> filenames := !filenames @ [it]) usage;
  match !filenames with
  | [filename] ->
      let f_timeout () = printf "TIMEOUT@." in
      let prove_confluence rs =
        try
          let proof = confluent rs in
          printf "YES@.@.%a@." print_proof proof
        with
        | Disproved (s, t) -> printf "NO@.@.%a@." print_disproof (rs, s, t)
        | Failed -> printf "MAYBE@."
      in
      let f_prove () =
        let rs = Read.read_trs filename in
        if not (Rules.variable_condition rs) then begin
          eprintf "Cannot handle TRS with extra variables.@.";
          exit 1
        end;
        try
          let s, t = non_confluent !wcrk rs in
          printf "NO@.@.%a@." print_disproof (rs, s, t)
        with Failed -> prove_confluence rs
      in
      Alarm.try_with_timeout ?timeout:!t f_timeout f_prove ()
  | _ ->
      eprintf "%s%!" short_usage;
      exit 1
