open Ls

open Term
open Result
open Step
open Join

(* Commutation checking for pairs of TRSs; [cr] checks a TRS against itself. *)

(* Certificate of a successful commutation proof.  Each criterion
   constructor carries the pair (R, S) it was established for;
   [thm_description] prints the pair under --- R / --- S headings. *)
type theorem =
  | KB of (trs*trs)   (* Knuth and Bendix' criterion, built by [kb2] *)
  | JK of (trs*trs)   (* Jouannaud and Kirchner's criterion, built by [thm_JK] *)
  | RL of (trs*trs)   (* (weighted) rule labeling, built by [thm_RL] *)
  | MO of (trs*trs)   (* mutual orthogonality, built by [thm_MO] *)
  | DC of (trs*trs)   (* development closedness, built by [thm_DC] *)
  | NC of unit        (* placeholder YES payload of [kwcomm]; [thm_description] rejects it *)
  | SC of (trs*trs)   (* strict commutation; [thm_SC] is currently disabled *)
  | COM of (theorem list) (* commutation lemma: proofs for the parts of a decomposition *)

(* Render one criterion entry: the criterion name [thm_s] followed by the
   rules of R and of S, each under its own heading.  [pf] is a short alias. *)
let proof_format r s thm_s =
  let rules_r = Trs.sprint_rules r in
  let rules_s = Trs.sprint_rules s in
  Format.sprintf "%@%s\n--- R\n%s\n--- S\n%s@." thm_s rules_r rules_s
let pf = proof_format

(* Human-readable text for a proof certificate.  A [COM] node prints a
   commutation-lemma header followed by the descriptions of its parts.
   @raise Invalid_argument on [NC ()], which is not a printable proof. *)
let rec thm_description thm =
  let criterion name (r, s) = pf r s name in
  match thm with
  | KB rs -> criterion "Knuth and Bendix' criterion" rs
  | JK rs -> criterion "Jouannaud and Kirchner's criterion" rs
  | RL rs -> criterion "Rule Labeling" rs
  | MO rs -> criterion "Mutually Orthogonal" rs
  | DC rs -> criterion "Development Closedness" rs
  | NC () -> invalid_arg "thm_description"
  | SC rs -> criterion "Strictly Commuting" rs
  | COM thms ->
      let header = Format.sprintf "%@Commutation Lemma\n\n" in
      header ^ String.concat "" (List.map thm_description thms)

(* Print the description of a proof certificate on stdout. *)
let show_proof thm =
  let text = thm_description thm in
  print_endline text

(*=================== THEOREMS ==================== *)
(* Run the criterion thunk [thm] only when its switch [b] is on;
   a disabled criterion answers MAYBE without doing any work. *)
let visible b thm =
  match b with
  | true -> thm ()
  | false -> MAYBE

(* Mutual orthogonality: [Trs.left_linear (r @ s)] holds and
   [Match.cpp2_all r s] reports no peaks.  The peaks are only computed
   when the linearity test passes. *)
let thm_MO r s =
  let linear = Trs.left_linear (r @ s) in
  if linear && Match.cpp2_all r s = [] then YES (MO (r,s)) else MAYBE

(* Jouannaud and Kirchner 1986 *)
(* -----------------------------------------*)
(* Local commutation check with bound [k].  Every peak (t, top, u) from
   [Match.cpp2_all r s] has its outer terms handed to [jj_light k s r].
   A NO witness (x, y) is reported as (x, top, y), i.e. with the peak's
   top term inserted.  A YES from [sequence_with] becomes YES (NC ()). *)
let kwcomm k r s =
  let join_peak (t, top, u) =
    match jj_light k s r (t, u) with
    | YES () -> YES ()
    | MAYBE -> MAYBE
    | NO (x, y) -> NO (x, top, y)
  in
  let peaks = Match.cpp2_all r s in
  match sequence_with join_peak peaks with
  | YES _ -> YES (NC ())
  | NO witness -> NO witness
  | MAYBE -> MAYBE

(* Local commutation: try bound 3 first.  Retry with bound [!Ref.k] only
   when the first attempt is undecided and [!Ref.k] is larger than 3. *)
let wcomm r s =
  match kwcomm 3 r s with
  | MAYBE -> if !Ref.k > 3 then kwcomm !Ref.k r s else MAYBE
  | decided -> decided

(* Knuth-Bendix style criterion, gated by [!Ref.kb].  It applies only when
   [Acrpo.terminate [] (r @ s)] succeeds (presumably a termination proof
   for R u S); then the answer of [wcomm] decides, and a YES is recorded as
   a [KB] certificate. *)
let kb2 r s =
  if !Ref.kb && Acrpo.terminate [] (r @ s) then
    match wcomm r s with
    | YES _ -> YES (KB (r, s))
    | NO witness -> NO witness
    | MAYBE -> MAYBE
  else
    MAYBE

(* Single-system variants: check a TRS against itself. *)
let kwcr k rules = kwcomm k rules rules
let wcr rules = wcomm rules rules
let kb1 rules = kb2 rules rules

(* Jouannaud-Kirchner criterion via [Ac_cr.jk86], gated by [!Ref.jk].
   When it is disabled or undecided, fall back to [kb2]. *)
let thm_JK r s =
  let fallback () = kb2 r s in
  if not !Ref.jk then
    fallback ()
  else
    match Ac_cr.jk86 r s with
    | YES () -> YES (JK (r, s))
    | NO witness -> NO witness
    | MAYBE -> fallback ()


(* Development Closedness: Yoshida Aoto Toyama 2009 *)
(* -----------------------------------------*)
(* First half of development closedness: every pair (a, b) from
   [Match.cp_in r s] satisfies [goto 1 (multi_step s) a b]. *)
let thm_DC1 r s =
  let pairs = Match.cp_in r s in
  let closed (a, b) = goto 1 (multi_step s) a b in
  List.for_all closed pairs

(* Second half of development closedness: every pair from
   [Match.cp2 s r] is joinable by [join_by] using multi-steps of [r] on one
   side and of [s] on the other. *)
let thm_DC2 r s =
  let pairs = Match.cp2 s r in
  List.for_all (join_by [multi_stepS r] [multi_stepS s]) pairs

(* Development closedness, gated by [!Ref.dc].  Both halves must hold in one
   orientation, tried as (r, s) and then (s, r); the certificate always
   records the pair as (r, s). *)
let thm_DC r s =
  let closed a b = thm_DC1 a b && thm_DC2 a b in
  visible !Ref.dc (fun () ->
      if closed r s || closed s r then YES (DC (r,s)) else MAYBE)


(* weighted Rule Labeling: V Oostrom 2008, Aoto 2010 *)
(* -----------------------------------------*)
(* Rule labeling via [Rl.solve_rlv], gated by [!Ref.rl].  Try bound 4
   first, then bound [!Ref.k] when it is larger than 4. *)
let thm_RL r s =
  let labeled depth = Rl.solve_rlv depth r s in
  visible !Ref.rl (fun () ->
      let proved = function
        | Some _ -> YES (RL (r,s))
        | None -> MAYBE
      in
      match labeled 4 with
      | Some _ as found -> proved found
      | None -> if !Ref.k <= 4 then MAYBE else proved (labeled !Ref.k))


(* Geser 90 *)
(* -----------------------------------------*)
(* [jj_plus r s (t, u0)]: MAYBE when [Step.rewrite s u0] returns no reducts.
   Otherwise every pair (t, u), for u among those reducts (in order), is
   checked with [jj_light !Ref.k r s], and the outcomes are combined by
   [sequence_with].
   Written with [List.map] rather than list-comprehension syntax, which
   requires a camlp4 extension. *)
let jj_plus r s (t,u0) =
  match Step.rewrite s u0 with
  | [] -> MAYBE
  | us -> sequence_with (jj_light !Ref.k r s) (List.map (fun u -> (t, u)) us)

(* SN(R) & NE(S) & ->* . +<- ==> COMM(R,S)

   Currently disabled: the thunk always answers MAYBE, so [thm_SC] never
   produces an [SC] certificate whatever the value of [!Ref.sc].
   Earlier draft, kept for reference:

     if not (Trs.non_erase s || Acrpo.terminate r) then
       MAYBE
     else
       match sequence_with (jj_plus s r)
           (Match.cp2_all r s) with
       | YES _ -> YES (SC (r,s))
       | NO nf -> NO nf
       | MAYBE -> MAYBE

   NOTE(review): before re-enabling, the guard needs [&&] instead of [||]
   (the premise requires both SN(R) and NE(S)).  Also, [Acrpo.terminate]
   is called as [Acrpo.terminate [] rules] elsewhere in this file, so the
   one-argument call above presumably needs the extra [] argument. *)
let thm_SC r s =
  visible !Ref.sc (fun () -> MAYBE)


(* TCAP approximation: Zankl, H., Felgenhauer, B., Middeldorp, 2011*)
(* -----------------------------------------*)
(* Non-commutation check, gated by [!Ref.tcap]: only a NO from
   [kwcomm 4 r s] is passed on; anything else becomes MAYBE. *)
let noncomm r s =
  match !Ref.tcap with
  | false -> MAYBE
  | true ->
      (match kwcomm 4 r s with
       | NO witness -> NO witness
       | _ -> MAYBE)




(* &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&& *) 
(*                     ALGORITHM                       *) 
(* &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&& *) 



(* from commit: 1e29a93f2b5d408afe9cfa929ec4c3783237daa1 *)
open Num


(* Number of non-empty subsets of [xs], 2^|xs| - 1.  This is also the index
   of the full list in [decode]'s numbering. *)
let bound xs =
  let size = List.length xs in
  sub_num (power_num (Int 2) (Int size)) (Int 1)

(* Subset of [xs] selected by index [k] in its power set, i.e.
   decode (Int k) xs = nth k (powerset xs).  Bit n of [k], counting from
   the least significant bit, decides whether the n-th element of [xs] is
   kept.

   examples:
     decode (Int 0) [1;2;3] = []
     decode (Int 2) [1;2;3] = [2]
     decode (Int 7) [1;2;3] = [1;2;3]
*)
let rec decode k xs =
  match xs with
  | [] -> []
  | x :: rest ->
      let k_half = floor_num (k // Int 2) in
      if mod_num k (Int 2) =/ Int 0 then decode k_half rest
      else x :: decode k_half rest


(* caching *)
(* Three terms: the payload type used for NO results stored in [table]
   ([kwcomm] builds NO (x, top, y) with [top] taken from a critical peak). *)
type peak = Term.t * Term.t * Term.t
(* Result cache keyed by the pair of [Trs.sort]-ed systems; written by
   [entry], read by [lookup].  It is an association list, so a lookup is
   linear in the number of cached pairs. *)
let table : ((Term.trs * Term.trs) * (theorem,peak) result) list ref
  = ref []
(* let table = ref (Hashtbl.create 100) *)
(* Size hook for a hash-table cache.  The cache is currently a list, so the
   hint is ignored.  Hashtbl version: table := Hashtbl.create n *)
let create_table _size_hint = ()

(* Global flag for result caching; in this file only [entry] reads it. *)
let cache_switch : bool ref = ref false
let cache_on () : unit = cache_switch := true
let cache_off () : unit = cache_switch := false

(* Cached result for (r, s), keyed by the sorted rule lists, or None when
   [entry] has not recorded the pair.
   Hashtbl version: Some (Hashtbl.find !table key) *)
let lookup r s =
  let key = (Trs.sort r, Trs.sort s) in
  try Some (List.assoc key !table) with Not_found -> None

(* Return [result] unchanged.  When caching is on, first record it under the
   sorted key (r, s).
   Hashtbl version: Hashtbl.add !table (sort r, sort s) result *)
let entry r s result =
  if !cache_switch then
    table := ((Trs.sort r, Trs.sort s), result) :: !table;
  result


(* Initial search state for [rs]: upper bound [bound rs], indices at zero. *)
let deflt rs =
  let zero = Int 0 in
  (rs, (bound rs, zero, zero))

(* Split the system [r] according to its search state (_, i, p):
   A = [decode i r], and B = [Ls.del r A] (presumably r minus A) extended
   with [decode p A].
   Returns None when [p] is out of range for A, when A or B is empty, or
   when either part already contains all of [r] (a useless split).
   Otherwise returns both parts with fresh search states.
   Indices are compared with Num's [>/]: polymorphic comparison is not
   reliable on [Num.num] values outside the small-int representation. *)
let divide (r,(_,i,p)) =
  let rs = decode i r in
  let ss = Ls.del r rs @+ decode p rs in
  if p >/ bound rs || rs = [] || ss = [] ||
     subseteq r ss || subseteq r rs then
    None
  else
    Some (deflt rs, deflt ss)


(* debug *)
(* Debug counter: [get_count] returns the next sequence number (from 1). *)
let count = ref 1
let get_count () =
  count := !count + 1;
  !count - 1

(* Debug trace on stderr (only when [!Ref.debug] is set): announce the
   numbered pair (R, S) about to be checked. *)
let begin_degmsg r s =
  match !Ref.debug with
  | false -> ()
  | true ->
      Format.eprintf "[%d]Checking R,S\nR =\n%sS =\n%s@."
        (get_count ()) (Trs.sprint_rules r) (Trs.sprint_rules s)
(* Debug trace on stderr (only when [!Ref.debug] is set): report the outcome
   of the check announced by [begin_degmsg]. *)
let end_degmsg result =
  if !Ref.debug then begin
    match result with
    | YES _ -> Format.eprintf "> Commute\n\n@."
    | NO _ -> Format.eprintf "> Not Commute\n\n@."
    | MAYBE -> Format.eprintf "> MAYBE\n\n@."
  end


(* Apply the direct (non-decomposing) criteria to (r, s) in this order,
   combining their answers with [Result.exists]. *)
let direct0 r s =
  let criteria = [ noncomm; thm_MO; thm_JK; thm_DC; thm_RL ] in
  Result.exists (fun criterion -> criterion r s) criteria

(* Run [direct0] on (r, s) between the two debug traces.  The outcome is
   passed through [entry], which caches it when caching is on. *)
let logging r s =
  let () = begin_degmsg r s in
  let outcome = direct0 r s in
  let outcome = entry r s outcome in
  end_degmsg outcome;
  outcome

(* API for external modules: answer from the cache when the pair is known,
   otherwise run the direct criteria through [logging]. *)
let direct r s =
  match lookup r s with
  | None -> logging r s
  | Some cached -> cached

(* Raised by [direct_raise] when the pair is cached as MAYBE, i.e. it has
   already been searched. *)
exception FFF
let direct_raise r s =
  match lookup r s with
  | None -> logging r s
  | Some MAYBE -> raise FFF
  | Some decided -> decided

  
(* Successor on Num search indices.  Note: this shadows any [incr] in scope
   before this point (e.g. the standard int ref -> unit one). *)
let incr i = add_num i (Int 1)

(* Decomposition search for commutation of the TRSs [r] and [s].

   Each system travels with a search state (l, i, p) of nums:
   - l: upper bound for i ([deflt] sets it to [bound] of the system);
   - i: selects the first part R1 = [decode i r];
   - p: selects the subset [decode p R1] that [divide] adds to the rest of
     [r] to form the second part R2, so that R1 u R2 = r.
   The same holds for [s] with (l2, j, q).

   [commute'] first tries the direct criteria ([direct_raise]).  On MAYBE it
   moves on to splitting [r] ([check_div1]); when a split of [r] fails, that
   function also tries splitting [s] ([check_div2]).  When both parts of a
   split commute with the other system, the two proofs are combined as
   [COM [thm1; thm2]].  A pair already cached as MAYBE makes
   [direct_raise] raise [FFF], and the search for it stops with MAYBE.

   Indices are compared with Num's [>=/]: polymorphic comparison is not
   reliable on [Num.num] values outside the small-int representation.

   @raise Invalid_argument if [r] or [s] is empty. *)
let rec commute' (r,(l1,i,p)) (s,_ as s0) =
  if r = [] || s = [] then
    invalid_arg "commute': empty"
  else
    try
      match direct_raise r s with
      | YES thm -> YES thm
      | NO nf   -> NO nf
      | MAYBE   -> check_div1 (r,(l1,incr i,p)) s0
    with FFF ->
      (* pair cached as MAYBE: already explored, do not search it again *)
      MAYBE
and check_div1 (r,(l1,i,p) as r0) (s,(l2,j,q) as s0) =
  (* every index for R1 has been tried *)
  if i >=/ l1 then
    MAYBE
  (* p has reached i: advance to the next R1 and restart p at zero *)
  else if p >=/ i then
    check_div1 (r,(l1,incr i,Int 0)) (deflt s)
  else
    match divide r0 with
    (* TODO: improve algorithm *)
    (* NOTE(review): this advances i but keeps the current p, whereas the
       branch above resets p to [Int 0] when advancing i.  Smaller p values
       for the next R1 are therefore skipped; confirm this is intended
       before changing it. *)
    | None -> check_div1 (r,(l1,incr i,p)) (deflt s)
    | Some (r1, r2) ->
        (match commute' r1 s0, commute' r2 s0 with
         | YES thm1, YES thm2 -> YES (COM [thm1;thm2])
         | _ ->
             (match check_div2 r0 (s,(l2,incr j,q)) with
              | MAYBE -> check_div1 (r,(l1,i,incr p)) (deflt s)
              | yesno -> yesno))
and check_div2 r0 (s,(l2,j,q) as s0) =
  (* every index for S1 has been tried *)
  if j >=/ l2 then
    MAYBE
  (* q has reached j: advance to the next S1 and restart q at zero *)
  else if q >=/ j then
    check_div2 r0 (s,(l2,incr j,Int 0))
  else
    match divide s0 with
    (* TODO: improve algorithm *)
    | None -> check_div2 r0 (s,(l2,j,incr q))
    | Some (s1, s2) ->
        (match commute' r0 s1, commute' r0 s2 with
         | YES thm1, YES thm2 -> YES (COM [thm1;thm2])
         | _ -> check_div2 r0 (s,(l2,j,incr q)))

(* Top-level commutation check for the TRSs [r] and [s].

   Answers MAYBE unless [Trs.left_linear (r @ s)] holds.  Otherwise it
   answers with a witness found by [noncomm], or runs the decomposition
   search [commute'] from [deflt r] and [deflt s].  An empty system is
   answered by the direct criteria ([direct]) instead, because [commute']
   rejects empty input with Invalid_argument. *)
let commute r s =
  (* Size hint for [create_table], computed on nums and capped before
     [int_of_num].  With enough rules, [bound] no longer fits in an int
     ([int_of_num] would raise Failure before any criterion runs), and the
     product of two bounds overflows native ints much earlier. *)
  let size_hint = min_num (bound r */ bound s) (Int 1_000_000) in
  create_table (int_of_num size_hint);
  if not (Trs.left_linear (r @ s)) then
    MAYBE
  else if r = [] || s = [] then
    direct r s
  else
    match noncomm r s with
    | NO nf -> NO nf
    | _ -> commute' (deflt r) (deflt s)

(* Check a single TRS against itself (presumably confluence / Church-Rosser,
   as self-commutation). *)
let cr rules = commute rules rules
(*
example
let rs = [ a,b; a,d; b,c; c,d; f[x],f[x] ];;
*)
