open Ls

open Term
open Result
open Step
open Join

(** Debug utility: print [s] on stderr followed by a newline and a flush,
    but only when [Ref.verbose] is set; otherwise do nothing. *)
let logger s =
  match !Ref.verbose with
  | true -> Format.eprintf "%s@." s
  | false -> ()

(* evaluate commutativity *)

(* Proof certificates produced by the criteria in this file.  Every
   constructor except [NC] and [COM] records the pair (R, S) the named
   criterion was applied to; [thm_description] gives the criterion name. *)
type theorem =
  | KB of (Trs.t*Trs.t)   (* Knuth and Bendix' criterion ([kb2]) *)
  | JK of (Trs.t*Trs.t)   (* Jouannaud and Kirchner's criterion ([thm_JK]) *)
  | RL of (Trs.t*Trs.t)   (* rule labeling ([thm_RL]) *)
  | MO of (Trs.t*Trs.t)   (* mutual orthogonality ([thm_MO]) *)
  | DC of (Trs.t*Trs.t)   (* development closedness ([thm_DC]) *)
  | SCP of (Trs.t*Trs.t)  (* simultaneous critical pairs ([thm_SCP]) *)
  | PC of (Trs.t*Trs.t)   (* parallel-closedness variants *)
  | UC of (Trs.t*Trs.t)   (* upside closedness *)
  | OC of (Trs.t*Trs.t)   (* outside closedness *)
  | CPC of (Trs.t*Trs.t)  (* CP-closing systems; not produced in this file *)
  | NC of unit            (* placeholder result of [kwcomm] on success *)
  | SC of (Trs.t*Trs.t)   (* strong commutation ([thm_SC]) *)
  | COM of (theorem list) (* combination of the proofs for a split of R or S *)

(* [rules_of_theorem thm] returns the pair (R, S) recorded in [thm].
   @raise Failure "Commute.rules_of_theorem" on [NC] and [COM], which carry
   no TRS pair.  Every constructor is listed explicitly, with no catch-all,
   so adding a constructor to [theorem] triggers an exhaustiveness warning. *)
let rules_of_theorem = function
  | KB x | JK x | RL x | MO x | DC x | SCP x
  | PC x | UC x | OC x | CPC x | SC x -> x
  | NC _ | COM _ -> failwith "Commute.rules_of_theorem"

(* Render one proof step: the criterion name [thm_s] after an '@', then
   the rules of R and of S under "--- R" and "--- S" headers. *)
let proof_format r s thm_s =
  let rules_r = Trs.sprint_rules r
  and rules_s = Trs.sprint_rules s in
  Format.sprintf "%@%s\n--- R\n%s\n--- S\n%s@." thm_s rules_r rules_s

(* [theorem_to_string thm] renders a proof certificate as text.  A [COM]
   node prints its own heading and then every sub-proof in order.  Any
   other node goes through [proof_format] with its rules and criterion
   name.  Raises (via [rules_of_theorem]) on [NC]. *)
let rec theorem_to_string thm =
  match thm with
  | COM sub_proofs ->
      let heading = Format.sprintf "%@%s\n\n" (thm_description thm) in
      let body = String.concat "" (List.map theorem_to_string sub_proofs) in
      heading ^ body
  | _ ->
      let r, s = rules_of_theorem thm in
      proof_format r s (thm_description thm)

(* Human-readable name of the criterion a certificate stands for.
   @raise Invalid_argument on [NC], which names no criterion. *)
and thm_description thm =
  match thm with
  | KB _  -> "Knuth and Bendix' criterion"
  | JK _  -> "Jouannaud and Kirchner's criterion"
  | RL _  -> "Rule Labeling"
  | MO _  -> "Mutually Orthogonal"
  | DC _  -> "Development Closedness"
  | SCP _ -> "Simultaneous Critical Pair"
  | PC _  -> "Parallel Closedness"
  | UC _  -> "Upside Closedness"
  | OC _  -> "Outside Closedness"
  | CPC _ -> "CP-closing systems"
  | NC _  -> invalid_arg "thm_description"
  | SC _  -> "Strongly Commuting"
  | COM _ -> "Commutation Theorem"

(* Print the rendered certificate [thm] on stdout, followed by a newline. *)
let show_proof thm =
  thm |> theorem_to_string |> print_endline

(*=================== THEOREMS ==================== *)
(* [visible enabled thm] forces the delayed criterion [thm] only when
   [enabled] holds (every caller passes a [Ref] flag).  A disabled
   criterion answers MAYBE without running. *)
let visible b thm =
  match b with
  | false -> MAYBE
  | true -> thm ()


(* Mutual orthogonality: YES when R u S is left-linear and R and S have no
   critical peaks ([Match.cpp2_all r s] is empty); MAYBE otherwise.
   The peak computation is skipped when left-linearity already fails. *)
let thm_MO r s =
  let check () =
    logger ("applying " ^ thm_description (MO (r,s)));
    match Trs.left_linear (r @ s) && Match.cpp2_all r s = [] with
    | true -> YES (MO (r,s))
    | false -> MAYBE
  in
  visible !Ref.mo check

(* Jouannaud and Kirchner 1986 *)
(* -----------------------------------------*)
(* [kwcomm k r s] tries to close every critical peak (t, _, source, u) of
   R and S ([Match.cpp2_all r s]) with [jj_light k s r (t, u)]; [k] is
   forwarded to [jj_light] (presumably a search bound).  The per-peak
   answers are combined by [sequence_with]: YES becomes [YES (NC ())], and
   a NO (x, y) from [jj_light] is reported as NO (x, source, y). *)
let kwcomm k r s =
  let close_peak (left, _, source, right) =
    match jj_light k s r (left, right) with
    | YES () -> YES ()
    | NO (x, y) -> NO (x, source, y)
    | MAYBE -> MAYBE
  in
  match sequence_with close_peak (Match.cpp2_all r s) with
  | YES _ -> YES (NC ())
  | NO nfp -> NO nfp
  | MAYBE -> MAYBE

(* Weak commutation: try [kwcomm] with the cheap bound 3 first.  Retry
   with [!Ref.k] only when that is inconclusive and [!Ref.k] is larger. *)
let wcomm r s =
  match kwcomm 3 r s with
  | MAYBE -> if !Ref.k > 3 then kwcomm !Ref.k r s else MAYBE
  | decided -> decided

(* Knuth-Bendix style criterion.  If [Acrpo.terminate] proves R u S
   terminating, the answer is that of weak commutation ([wcomm]), with YES
   wrapped as [KB (r,s)].  Otherwise MAYBE. *)
let kb2 r s =
  visible !Ref.kb (fun () ->
    logger ("applying " ^ thm_description (KB (r,s)));
    if Acrpo.terminate [] (r @ s) then
      (match wcomm r s with
       | YES _ -> YES (KB (r,s))
       | NO nf -> NO nf
       | MAYBE -> MAYBE)
    else MAYBE)

(* Confluence instances of the commutation checks above: a TRS is checked
   for commutation with itself. *)
let kwcr k rs = kwcomm k rs rs
let wcr rs = wcomm rs rs
let kb1 rs = kb2 rs rs

(* Jouannaud and Kirchner's criterion via [Ac_cr.jk86], used only when
   [Ref.jk] is set.  When the switch is off, or [jk86] is inconclusive,
   fall back to [kb2]. *)
let thm_JK r s =
  let try_jk () =
    logger ("applying " ^ thm_description (JK (r,s)));
    match Ac_cr.jk86 r s with
    | YES () -> YES (JK (r,s))
    | NO nf -> NO nf
    | MAYBE -> kb2 r s
  in
  if !Ref.jk then try_jk () else kb2 r s


(* Development Closedness: Yoshida Aoto Toyama 2009 *)
(* -----------------------------------------*)
(* Every pair (t, u) of [Match.cp_in r s] (presumably the inner critical
   pairs) has u among the multistep S-reducts of t. *)
let thm_DC1 r s =
  let closed (t, u) = List.mem u (Rewriting.multistep s t) in
  List.for_all closed (Match.cp_in r s)

(* Every critical pair of [Match.cp2 s r] is joinable, with R-multisteps
   on one side and S-multisteps on the other. *)
let thm_DC2 r s =
  let joinable =
    Rewriting.joinable_by (Rewriting.multistep r) (Rewriting.multistep s) in
  Match.cp2 s r |> List.for_all joinable

(* Development closedness: YES if [thm_DC1] and [thm_DC2] both hold for
   (R, S), or both hold for (S, R). *)
let thm_DC r s =
  let dc_closed a b = thm_DC1 a b && thm_DC2 a b in
  visible !Ref.dc (fun () ->
    logger ("applying " ^ thm_description (DC (r,s)));
    match dc_closed r s || dc_closed s r with
    | true -> YES (DC (r,s))
    | false -> MAYBE)


(* weighted rule labeling: van Oostrom 2008, Aoto 2010 *)
(* -----------------------------------------*)
(* Rule labeling through [Coll_rl.solve_rlv].  Try bound 4 first.  If that
   fails and [!Ref.k] exceeds 4, retry with [!Ref.k]. *)
let thm_RL r s =
  let solved_with k = Coll_rl.solve_rlv k r s in
  visible !Ref.rl (fun () ->
    logger ("applying " ^ thm_description (RL (r,s)));
    if solved_with 4 || (!Ref.k > 4 && solved_with !Ref.k) then
      YES (RL (r,s))
    else
      MAYBE)


(* Geser 90 *)
(* -----------------------------------------*)
(* For a pair (t, u0): if u0 has no one-step S-reducts the answer is
   MAYBE.  Otherwise every (t, u), with u such a reduct, is checked with
   [jj_light !Ref.k r s], combined by [sequence_with].  Standard
   [List.map] replaces the camlp4 list comprehension. *)
let jj_plus r s (t,u0) =
  match rewrite s u0 with
  | [] -> MAYBE
  | reducts_of_u0 ->
      reducts_of_u0
      |> List.map (fun u -> (t, u))
      |> sequence_with (jj_light !Ref.k r s)

(* SN(R) & NE(S) & ->* . +<- ==> COMM(R,S)
   NOTE(review): this stub is shadowed by the [thm_SC] defined right after
   it, so it is never called.  When enabled it always answers MAYBE.
   Draft of the intended check, kept for reference.  Per the header, the
   guard must require BOTH conditions, i.e. use [&&] where the earlier
   draft used [||].  [Acrpo.terminate] takes an extra first argument, as
   in [kb2]:
     if not (Trs.non_erase s && Acrpo.terminate [] r) then MAYBE
     else match sequence_with (jj_plus s r) (Match.cp2_all r s) with
       | YES _ -> YES (SC (r,s))
       | NO nf -> NO nf
       | MAYBE -> MAYBE *)
let thm_SC _r _s =
  visible !Ref.sc (fun () -> MAYBE)

(* Strong closedness, only for linear R and S.
   - Every critical pair of R over S ([Match.cp2_all r s]) is checked in
     both orientations (a, b).  Some term must be reachable from a with
     [Rewriting.reachable 1] by R and from b within [!Ref.k] S-steps.
   - The pairs of S over R are checked the same way with R and S swapped. *)
let thm_SC r s =
  let check () =
    logger ("applying " ^ thm_description (SC (r,s)));
    if Trs.linear r && Trs.linear s then begin
      let meets one_step many_steps (a, b) =
        Listset.intersect
          (Rewriting.reachable 1 one_step [a])
          (Rewriting.reachable !Ref.k many_steps [b]) in
      let both_ways cps =
        Listx.unique (cps @ List.map (fun (a, b) -> (b, a)) cps) in
      let closed rr ss =
        List.for_all (meets rr ss) (both_ways (Match.cp2_all rr ss)) in
      if closed r s && closed s r then YES (SC (r,s)) else MAYBE
    end else
      MAYBE
  in
  visible !Ref.sc check

(* Okui 1998 *)
(* Simultaneous critical pairs.  YES when R is left-linear and
   - every simultaneous critical pair of R over S satisfies
     [Rewriting.joinable_by (step s) (multi r)], and
   - every one of S over R satisfies [Rewriting.joinable_by (step r) (multi s)].
   Here [step] is reachability within [!Ref.k] steps and [multi] a multistep.
   When R and S are variants ([Trs.variant r s]), only the pairs of R over R
   are checked (a shortcut for the confluence case).
   Left-linearity is tested first and the pair searches run lazily, so the
   expensive joinability search is skipped when it cannot matter.
   NOTE(review): only R is tested for left-linearity; confirm S need not be. *)
let thm_SCP r s = visible !Ref.sic
  begin
    fun () ->
      logger ("applying " ^ thm_description (SCP (r,s)));
      let step rs t = Rewriting.reachable !Ref.k rs [t]
      and multi rs t = Rewriting.multistep rs t in
      let closed_by reach_rs multi_rs scps =
        List.for_all
          (Rewriting.joinable_by (step reach_rs) (multi multi_rs)) scps in
      let joinable () =
        if Trs.variant r s then
          (* performance for confluence *)
          closed_by s r (Match.simultaneous_cp2 r r)
        else
          closed_by s r (Match.simultaneous_cp2 r s)
          && closed_by r s (Match.simultaneous_cp2 s r) in
      if Trs.left_linear r && joinable () then YES (SCP (r,s)) else MAYBE
  end


(* Oyamaguchi and Hirokawa 2014 *)
(* Placeholder for this criterion: not implemented.  It ignores both TRSs
   and returns unit. *)
let cpc _r _s = ()


(* TCAP approximation: Zankl, H., Felgenhauer, B., Middeldorp, 2011*)
(* -----------------------------------------*)
(* Non-commutation check.  Look for a critical peak (t, _, src, u) of R and
   S ([Match.cpp2_all r s]) for which [Join.unjoin s r (t,u)] holds.  The
   first such peak is reported as NO (t, src, u); if there is none the
   answer is MAYBE. *)
let noncomm r s =
  let search () =
    logger ("checking non commutation");
    let peaks = Match.cpp2_all r s in
    let unjoinable (t,_,_,u) = Join.unjoin s r (t,u) in
    match (try Some (List.find unjoinable peaks) with Not_found -> None) with
    | Some (t,_,src,u) -> NO (t,src,u)
    | None -> MAYBE
  in
  visible !Ref.tcap search


(* temporary theorems! these are NOT commutation criteria *)
(* [pto r (s,t)]: is t among the parallel-step R-reducts of s? *)
let pto r (s,t) =
  Rewriting.parallel_step r s |> Listset.intersect [t]
  
(* [pto_mfrom r pair]: [pair] is joinable with one parallel R-step on the
   first side and at most [!Ref.k] R-steps on the second. *)
let pto_mfrom r pair =
  let parallel x = Rewriting.parallel_step r x
  and many x = Rewriting.reachable !Ref.k r [x] in
  Rewriting.joinable_by parallel many pair

(* Terms reachable from s by one parallel rs-step that are also reachable
   from t within [!Ref.k] rs-steps. *)
let reducts_of_pto_mfrom rs (s,t) =
  let from_t = Rewriting.reachable !Ref.k rs [t] in
  let from_s = Rewriting.parallel_step rs s in
  Listset.inter from_s from_t


(* Parallel closedness, a confluence criterion: MAYBE unless R and S are
   equal as sets.  YES when R is left-linear and every critical pair (s,t)
   of R ([Match.cp2 r r]) satisfies [pto r], i.e. s reaches t in one
   parallel step.  Logs its application like the sibling criteria do. *)
let parallel_closed r s = visible !Ref.pc
  begin
    fun () ->
      logger ("applying " ^ thm_description (PC (r,r)));
      (* confluence criterion *)
      if not (Listset.equal r s) then MAYBE
      else if Trs.left_linear r && List.for_all (pto r) (Match.cp2 r r)
      then YES (PC (r,r))
      else MAYBE
  end


(* Toyama 1988 *)
(* YES when R is left-linear and both hold:
   (b1) every critical pair of [Match.cp2 r s] is joinable, with one
        parallel S-step on the first side and at most [!Ref.k] R-steps
        on the second;
   (b2) every pair of [Match.cp_in s r] satisfies [pto r].
   The conditions are evaluated lazily, cheapest first, so the critical
   pair searches are skipped as soon as the answer is known. *)
let almost_parallel_closed r s = visible !Ref.almost_pc
  begin
    fun () ->
      logger ("applying " ^ thm_description (PC (r,s)));
      let b1 () =
        List.for_all
          (Rewriting.joinable_by
             (Rewriting.parallel_step s)
             (fun x -> Rewriting.reachable !Ref.k r [x]))
          (Match.cp2 r s) in
      let b2 () = List.for_all (pto r) (Match.cp_in s r) in
      if Trs.left_linear r && b1 () && b2 () then YES (PC (r,s)) else MAYBE
  end

                  
(* Parallel critical pairs (t, u) of R ([Match.parallel_cpp r]), in order.
   Peaks whose position list is exactly the root-only list are dropped. *)
let pcp_in r =
  List.fold_right
    (fun (t, ps, _, u) acc -> if ps = [[]] then acc else (t, u) :: acc)
    (Match.parallel_cpp r) []


(* Gramlich 1996: a variant of parallel closedness *)
(* Confluence criterion: MAYBE unless R and S are equal as sets.  YES when
   R is left-linear, every critical pair of R satisfies [pto_mfrom r], and
   every pair (a, b) of [pcp_in r] has b reachable from a within [!Ref.k]
   R-steps.  The conditions are evaluated lazily, cheapest first. *)
let gramlich96 r s = visible !Ref.gramlich96
  begin
    fun () ->
      logger ("applying " ^ thm_description (PC (r,r)));
      (* confluence criterion *)
      if not (Listset.equal r s) then MAYBE
      else
        let b1 () = List.for_all (pto_mfrom r) (Match.cp2 r r) in
        let b2 () =
          List.for_all
            (fun (a,b) -> List.mem b (Rewriting.reachable !Ref.k r [a]))
            (pcp_in r) in
        if Trs.left_linear r && b1 () && b2 () then YES (PC (r,r)) else MAYBE
  end


(* For a parallel critical peak (t, ps, s, u), holds when some term is both
   - reachable from t within [!Ref.k] R-steps, and
   - a reduct of u by the parallel R-steps that [including_variables]
     keeps for the variables of s at ps ([variables_at ps s]). *)
let toyama81_condition r (t,ps,s,u) =
  let from_u =
    including_variables (variables_at ps s) (parallel_stepS r) u in
  let from_t = Rewriting.reachable !Ref.k r [t] in
  Listset.intersect from_t from_u


(* Toyama 1981: a variant of parallel closedness *)
(* Confluence criterion: MAYBE unless R and S are equal as sets.  YES when
   R is left-linear, every critical pair of R satisfies [pto_mfrom r], and
   every parallel critical peak of R satisfies [toyama81_condition r].
   The conditions are evaluated lazily, cheapest first. *)
let thm_toyama81 r s = visible !Ref.toyama81
  begin
    fun () ->
      logger ("applying " ^ thm_description (PC (r,r)));
      (* confluence criterion *)
      if not (Listset.equal r s) then MAYBE
      else
        let b1 () = List.for_all (pto_mfrom r) (Match.cp2 r r) in
        let b2 () =
          List.for_all (toyama81_condition r) (Match.parallel_cpp r) in
        if Trs.left_linear r && b1 () && b2 () then YES (PC (r,r)) else MAYBE
  end


(***** Upside/Outside closedness *****)

(* Reducts u of t by a parallel rs-step every one of whose positions q
   satisfies [Positions.upside_closed q p]. *)
let upside_closed_step rs p t =
  let all_upside qs = List.for_all (fun q -> Positions.upside_closed q p) qs in
  reducts (position_is all_upside (parallel_stepS rs t))

(* Reducts u of t by a parallel rs-step (t -||-> u) every one of whose
   positions q satisfies [Positions.outside_closed q p]. *)
let outside_closed_step rs p t =
  let all_outside qs = List.for_all (fun q -> Positions.outside_closed q p) qs in
  reducts (position_is all_outside (parallel_stepS rs t))

(* Reducts of t by a single rs-step ([stepS]) at a position q with
   [Positions.outside_closed q p].
   @raise Invalid_argument if a step does not have exactly one position. *)
let single_outside_closed_step rs p t =
  let outside_single = function
    | [q] -> Positions.outside_closed q p
    | _ -> invalid_arg "single_outside_closed_step" in
  stepS rs t |> position_is outside_single |> reducts

(* Reducts of t by a parallel rs-step all of whose positions q satisfy
   [Positions.above_of p q] (presumably: q lies below p). *)
let parallel_step_below_of rs p t =
  let below qs = List.for_all (fun q -> Positions.above_of p q) qs in
  reducts (position_is below (parallel_stepS rs t))

(* [weight (t, ps)]: total [term_size] of the subterms of t at positions ps.
   (An earlier variant used [term_size t] when ps = [].) *)
let weight (t,ps) =
  List.map term_size (subterms_at ps t) |> Listx.sum

(* Parallel rs-steps from u whose reduct, weighed at the step's own
   positions, is strictly lighter than [w].
   NOTE(review): an older header also mentioned a variable-inclusion
   condition; no such check is made here. *)
let parallel_stepS_smaller_than rs w u =
  let lighter step = weight (reduct step, Step.positions [step]) < w in
  such_that lighter (parallel_stepS rs u)

(* Reducts of the steps selected by [parallel_stepS_smaller_than]. *)
let parallel_step_smaller_than rs w u =
  parallel_stepS_smaller_than rs w u |> reducts

(* Parallel rs-steps from t whose position set is contained in
   [Positions.strongly_outermost_positions rs t]. *)
let strongly_outermost_parallel_stepS rs t =
  let allowed = Positions.strongly_outermost_positions rs t in
  parallel_stepS rs t |> position_is (fun qs -> Listset.subset qs allowed)

(* Parallel rs-steps from t whose positions all lie among the prefix-stable
   positions of the step's third component (its reduct). *)
let parallel_stepS_stable_to rs t =
  let stable_in_reduct ((_, _, v) as step) =
    Listset.subset
      (Step.positions [step])
      (Positions.prefix_stable_positions rs v) in
  such_that stable_in_reduct (parallel_stepS rs t)

(* Reducts of [strongly_outermost_parallel_stepS]. *)
let strongly_outermost_parallel_step rs t =
  strongly_outermost_parallel_stepS rs t |> reducts

(* Reducts of [parallel_stepS_stable_to]. *)
let parallel_step_stable_to rs t =
  parallel_stepS_stable_to rs t |> reducts

(* Terms s with s -||->_rs t at strongly outermost positions of s.  They
   are found as parallel steps of the inverse system from t, keeping those
   whose positions are strongly outermost in the step's third component. *)
let rev_strongly_outermost_parallel_step rs t =
  let inverse_steps = parallel_stepS (Trs.inverse rs) t in
  let outermost_in_target ((_, _, v) as step) =
    Listset.subset
      (Step.positions [step])
      (Positions.strongly_outermost_positions rs v) in
  reducts (such_that outermost_in_target inverse_steps)

(* Terms s with s -||->_rs t at positions that are all prefix-stable
   positions of t ([Positions.prefix_stable_positions rs t]).  They are
   found as parallel steps of the inverse system from t. *)
let rev_parallel_step_stable_to rs t =
  let stable = Positions.prefix_stable_positions rs t in
  let within ps = Listset.subset ps stable in
  reducts (position_is within (parallel_stepS (Trs.inverse rs) t))

(* Is t a reduct of u by a parallel rs-step at positions upside-closed
   relative to p (see [upside_closed_step])? *)
let is_upside_closed rs (t,p,u) =
  List.exists ((=) t) (upside_closed_step rs p u)

(* Is t a reduct of u by a single rs-step at a position outside-closed
   relative to p (see [single_outside_closed_step])? *)
let is_single_outside_closed rs (t,p,u) =
  List.exists ((=) t) (single_outside_closed_step rs p u)

(* Is t a reduct of u by a parallel rs-step whose positions are all
   outside-closed relative to p (see [outside_closed_step])? *)
let is_outside_closed rs (t,p,u) =
  List.exists ((=) t) (outside_closed_step rs p u)

(* Reflexive closure of a multi-valued step function: [eq_closure arr t]
   is t itself followed by everything in [arr t]. *)
let eq_closure arr t =
  List.append [t] (arr t)
  
(* upside closed theorem by Ohta and Oyamaguchi (97) *)
(* Confluence criterion: MAYBE unless R and S are equal as sets.  YES when
   R is left-linear and every critical peak (t, p, _, u) of [Match.cpp r]
   satisfies:
   - inner peak (p <> []): condition (i) or (ii) of [inner_case];
   - overlay (p = []): one of conditions (i)-(iv) of [overlay_case].
   The logged criterion name is the one of the certificate actually
   returned (UC).  Left-linearity is tested before the peak search. *)
let upsideClosed97 r s = visible !Ref.uc
  begin
    fun () ->
      logger ("applying " ^ thm_description (UC (r,r)));
      (* confluence criterion *)
      if not (Listset.equal r s) then MAYBE
      else
        let root_eq rs = eq_closure (rewrite_at rs []) in
        let join_by (a,x) (b,y) = Rewriting.joinable_by a b (x,y) in
        let reach_to (a,x) y = join_by (a,x) ((fun z -> [z]), y) in
        (* condition for inner critical pairs *)
        let inner_case (t,p,u) =
          (* (i) *)
          reach_to (rewrite_at r [], t) u
          || (* (ii) *)
          is_upside_closed r (t,p,u) in
        (* condition for overlay critical pairs *)
        let overlay_case (t,u) =
          let ir = Trs.inverse r in
          (* (i) *)
          join_by (parallel_step_below_of ir [], t) (root_eq ir, u)
          || (* (ii) *)
          reach_to (parallel_step r >> parallel_step_below_of r [], u) t
          || (* (iii) *)
          join_by (root_eq r, t) (parallel_step r, u)
          || (* (iv) *)
          reach_to (root_eq r >> root_eq r, t) u in
        let closed_peak (t,p,_,u) =
          if p <> [] then inner_case (t,p,u) else overlay_case (t,u) in
        (* finally: the cheap left-linearity test guards the peak search *)
        if Trs.left_linear r && List.for_all closed_peak (Match.cpp r)
        then YES (UC (r,r))
        else MAYBE
  end

(* outside closed theorem by Ohta and Oyamaguchi (97) *)
(* NOTE(review): the function name says 04 while this header says 97;
   confirm the intended citation. *)
(* Confluence criterion: MAYBE unless R and S are equal as sets.  YES when
   R is left-linear and every critical peak (t, p, _, u) of [Match.cpp r]
   satisfies:
   - inner peak (p <> []): (i), or (ii) [is_single_outside_closed];
   - overlay (p = []): one of conditions (i)-(iv) of [overlay_case].
   Left-linearity is tested before the peak search. *)
let outsideClosed04 r s = visible !Ref.oc
  begin
    fun () ->
      logger ("applying " ^ thm_description (OC (r,r)));
      (* confluence criterion *)
      if not (Listset.equal r s) then MAYBE
      else
        let root_eq rs = eq_closure (rewrite_at rs []) in
        let join_by (a,x) (b,y) = Rewriting.joinable_by a b (x,y) in
        let reach_to (a,x) y = join_by (a,x) ((fun z -> [z]), y) in
        (* condition for inner critical pairs *)
        let inner_case (t,p,u) =
          (* (i) *)
          reach_to (rewrite_at r [], t) u
          || (* (ii) *)
          is_single_outside_closed r (t,p,u) in
        (* condition for overlay critical pairs *)
        let overlay_case (t,u) =
          let ir = Trs.inverse r in
          (* (i) *)
          join_by (parallel_step_below_of ir [], t) (root_eq ir, u)
          || (* (ii) *)
          reach_to (parallel_step r >> parallel_step_below_of r [], u) t
          || (* (iii) *)
          join_by (root_eq r, t) (parallel_step r, u)
          || (* (iv) *)
          reach_to (root_eq r >> root_eq r, t) u in
        let closed_peak (t,p,_,u) =
          if p <> [] then inner_case (t,p,u) else overlay_case (t,u) in
        (* finally: the cheap left-linearity test guards the peak search *)
        if Trs.left_linear r && List.for_all closed_peak (Match.cpp r)
        then YES (OC (r,r))
        else MAYBE
  end



(* [replace_all (s, pattern, u)]: for every (_, p) in [pattern], in order,
   replace the subterm of s at p by the subterm of u at p. *)
let replace_all (s,pattern,u) =
  let plug acc (_, p) = replace acc (subterm_at p u) p in
  List.fold_left plug s pattern


(* Does t <-Q-||-P-> u hold?  Checked in this order:
   - the positions Q of [to_t] and P of [to_u] are parallel;
   - [to_u] starts from t and [to_t] starts from u;
   - Q are positions of t and P are positions of u;
   - [replace_all] yields the same term for both steps. *)
let parallel_convertible (t,(to_t,to_u),u) =
  let qs = Step.positions [to_t]
  and ps = Step.positions [to_u] in
  Positions.parallel_list ps qs
  && source to_u = t
  && source to_t = u
  && Listset.subset qs (Term.positions t)
  && Listset.subset ps (Term.positions u)
  && replace_all to_u = replace_all to_t


(* For each q of qs in order, the positions p of ps with
   [Positions.below_equal p q]; the results are concatenated. *)
let below_equal_list ps qs =
  let under q = List.filter (fun p -> Positions.below_equal p q) ps in
  Listx.concat_map under qs


(* Side conditions of parallel outside-closedness for one candidate
   conversion t <-O1'-||-O2-> v <-O1-||-O2'-> u.
   - The peak is (t, ps, s, qs, u); ps and qs are the position lists used
     on the t side and the u side.
   - The conversion is given as ((o1',o2), v, (o1,o2')).
   The intermediate term v itself is not inspected; the check that used it
   is commented out below.  The conjuncts are evaluated left to right and
   short-circuit. *)
let outside_parallel_closed rs (t,ps,s,qs,u) ((o1',o2),v,(o1,o2')) =
  (* variable positions of t kept by [Positions.parallel_list [q] ps]
     (presumably: those parallel to every position in ps) *)
  let vp t ps =
    List.filter (fun q -> Positions.parallel_list [q] ps)
                (variable_positions t) in
  (* prefix-stable positions of a term with respect to rs *)
  let psps t = Positions.prefix_stable_positions rs t in
  (* positions q of t kept by [Positions.not_below_of_list ps [q]]
     (presumably: not below any position in ps) *)
  let outside_positions (t,ps) = 
    List.filter
      (fun q -> Positions.not_below_of_list ps [q])
      (Term.positions t) in
  (* prefix-stable positions of u, plus those from
     [Positions.positions_under psps u qs] *)
  let prefix_stables_u =
    psps u @ Positions.positions_under psps u qs
  in
  (* condition (a): at most 2 [Positions.chains] among all positions
     involved, including the variable positions kept by [vp] on both sides *)
  2 >= Positions.chains (o1'@o2@o1@o2'@(vp t ps)@(vp u qs))
  (* condition (b): placement of O2 and O2' *)
  && Positions.not_below_of_list o2 o2'
  (* && Listset.subset o2 (Positions.prefix_stable_positions rs v) *)
  && Listset.subset o2 (psps u @ psps t)
  && Listset.subset o2'(prefix_stables_u @ outside_positions (u,qs))
  (* condition (c): O1' is empty or lighter than ps in t ([weight]).
     Otherwise the tie-break: O2' nonempty, O1 empty, equal weights in t,
     and O2' strictly lighter than qs in u *)
  && (o1' = [] || weight (t,ps) > weight (t,o1')
     ||
     (o2' <> [] && o1 = []
       && weight (t,ps) = weight (t,o1')
       && weight (u,qs) > weight (u,o2')))


(* return intermediate of steps if exists.
     t -||-P1-> v1 and v2 <-P2-||- u ==> v = v1[v2|_p]_{p \in P2}
       when v1[v2|_p]_{p \in P2} = v2[v1|_p]_{p \in P1}
   Note this is an approximation: only the case P1 || P2 is calculated.

   Details: let v1, v2 be the reducts of step1, step2 and ps, qs their
   positions.  Take the maximal positions of ps u qs.  If they are not all
   positions of both v1 and v2, return [].  Otherwise rebuild v1' from v1
   and v2' from v2 by copying, at each such position, the subterm from the
   reduct of the step that owns the position.  The candidate triples
   (step1, v, step2) are returned. *)
let intermediate_of_conversion_aux step1 step2 =
  let v1 = reduct step1 and v2 = reduct step2 in
  let ps = Step.positions [step1] and qs = Step.positions [step2] in
  let replaceables =
    Listset.maximal ~leq:Positions.prefix_order (Listset.unique (ps @ qs)) in
  let fits v = Listset.subset replaceables (Term.positions v) in
  if not (fits v1 && fits v2) then []
  else begin
    (* subterm donor for position p: [own] when p is one of its positions *)
    let donor p (own, own_ps) (other, _) =
      if List.mem p own_ps then own else other in
    let rebuild base own other =
      List.fold_left
        (fun acc p -> replace acc (subterm_at p (donor p own other)) p)
        base replaceables in
    let v1' = rebuild v1 (v1, ps) (v2, qs) in
    let v2' = rebuild v2 (v2, qs) (v1, ps) in
    match ps, qs with
    | _ :: _, [] -> [(step1, v1', step2)]
    | [], _ :: _ -> [(step1, v2', step2)]
    | _ ->
        if v1' = v2' then [(step1, v1', step2)]
        else [(step1, v1', step2); (step1, v2', step2)]
  end

(* return intermediates: candidates for every pair (step1, step2) of
   steps1 x steps2 (via [Listx.product]), flattened *)
let intermediate_of_conversion steps1 steps2 =
  Listx.product intermediate_of_conversion_aux steps1 steps2 |> List.concat


(* return parallel outside-closed tuples ((o1',o2),v,(o1,o2'))
   where t <-O1'-||-O2-> v <-O1-||-O2'-> u.
   [pats1] and [pats2] are the peak's two lists of (rule, position) pairs
   (presumably); only their positions are used.  Candidate intermediates v
   come from [intermediate_of_conversion] on the parallel steps from t and
   from u.  Each is extended by [one_more_step] and kept only if both
   halves pass [parallel_convertible]. *)
let rec parallel_convertion rs (t,pats1,s,pats2,u) =
  (* setup parameters *)
  let ps,qs = List.map snd pats1, List.map snd pats2 in
  (* get candidates of intermediate v *)
  let ioc = intermediate_of_conversion
              (parallel_stepS rs t)
              (parallel_stepS rs u) in
  (* filter and trim them *)
  Listx.concat_map (one_more_step rs (t,ps,s,qs,u)) ioc
  |> List.filter (fun (conv1,v,conv2) ->
      parallel_convertible (t,conv1,v)
      &&
      parallel_convertible (v,conv2,u))
(* take O2 (positions of from_t) and O1 (positions of from_u), and find
   O1' and O2' such that t <-O1'-||-O2-> v <-O1-||-O2'-> u.  Candidates are
   kept only if [outside_parallel_closed] accepts them for [peak]. *)
and one_more_step rs peak (from_t,v,from_u) =
  (* t? <-||- v *)
  (* let steps_to_t = parallel_stepS rs v in *)
  let steps_to_t = (* optimization: keep only steps parallel to from_t *)
    position_is
      (Positions.parallel_list (Step.positions [from_t]))
      (parallel_stepS rs v) in
  (* v -||-> u? *)
  (* let steps_to_u = parallel_stepS rs v in *)
  let steps_to_u = (* optimization: keep only steps parallel to from_u *)
    position_is
      (Positions.parallel_list (Step.positions [from_u]))
      (parallel_stepS rs v) in
  (* make two conversions *)
  Listx.product
    (fun to_t to_u ->
      (* t <-O1'-||-O2-> v <-O1-||-O2'-> u *)
      ((to_t,from_t),v,(from_u,to_u)))
    steps_to_t steps_to_u
  (* filter non parallel outside closed peak *)
  |> List.filter
    (fun ((to_t,from_t),v,(from_u,to_u)) ->
      let o1',o1 = positions [to_t], positions [from_u] in
      let o2,o2' = positions [from_t], positions [to_u] in
      outside_parallel_closed rs peak ((o1',o2),v,(o1,o2')))


(* Parallel outside-closedness (confluence of R).  Every parallel overlap
   of R, taken in both directions, must admit a conversion found by
   [parallel_convertion].  R must also be left-linear.  YES yields an
   [OC (r,r)] certificate. *)
let parallelOutsideClosed r =
  let check () =
    logger ("applying " ^ thm_description (OC (r,r)));
    let overlaps = Match.parallel_overlaps r r in
    let forward =
      List.map (fun (t,pats,s,(lr,p),u) -> (t,pats,s,[(lr,p)],u)) overlaps
    and backward =
      List.map (fun (t,pats,s,(lr,p),u) -> (u,[(lr,p)],s,pats,t)) overlaps in
    let convertible peak = parallel_convertion r peak <> [] in
    if Trs.left_linear r
       && List.for_all convertible (Listx.unique (forward @ backward))
    then YES (OC (r,r))
    else MAYBE
  in
  visible !Ref.opc check

(* Shintani and Hirokawa 2020 *)
(* Disabled: always answers MAYBE.  Previous body, kept for reference:
     (* confluence criterion *)
     if Listset.equal r s then parallelOutsideClosed r else MAYBE *)
let thm_OPC _r _s = MAYBE


(* &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&& *) 
(*                     ALGORITHM                       *) 
(* &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&& *) 


(* check directly: hand the criteria, in this order, to [Result.exists]
   applied to (R, S).  The TCAP non-commutation check comes first, then
   the commutation criteria, then the confluence theorems.
   [thm_OPC] is left out (unpublished). *)
let direct0 r s =
  let commutation =
    [ noncomm; thm_MO; thm_SC; thm_SCP; thm_JK; thm_DC; thm_RL ]
  and confluence =
    [ thm_toyama81; parallel_closed; almost_parallel_closed; gramlich96;
      upsideClosed97; outsideClosed04 ] in
  Result.exists (fun thm -> thm r s) (commutation @ confluence)


(* from commit: 1e29a93f2b5d408afe9cfa929ec4c3783237daa1 *)
open Num


(* [bound xs] = 2^|xs| - 1: the size of the power set of [xs] minus one,
   i.e. the largest subset index accepted by [decode].  It is computed as
   an arbitrary-precision [num].  Callers converting it to [int] (e.g.
   [int_of_num]) may get [Failure] for long lists. *)
let bound xs =
  let n = List.length xs in
  sub_num (power_num (Int 2) (Int n)) (Int 1)

(* [decode k xs] selects the subset of [xs] encoded by the bits of index k:
   bit n (least significant first) keeps the n-th element of [xs].
   i.e. decode (Int k) xs = nth k (powerset xs)

   example:
     decode (Int 0) [1;2;3] = []
     decode (Int 2) [1;2;3] = [2]
     decode (Int 7) [1;2;3] = [1;2;3]
*)
let rec decode k = function
  | [] -> []
  | x :: xs ->
      let rest = decode (floor_num (k // Int 2)) xs in
      if mod_num k (Int 2) =/ Int 0 then rest else x :: rest


(* caching *)
(* A counterexample peak, in the shape reported as NO (t, source, u) by
   [noncomm] and [kwcomm]. *)
type peak = Term.t * Term.t * Term.t

(* Memo of results keyed by the sorted pair (R, S); see [lookup] and
   [entry].  It is an association list, so lookups are linear. *)
let table : ((Trs.t * Trs.t) * (theorem, peak) Result.t) list ref = ref []

(* Fallback size used by [commute] when the real bound does not fit an int. *)
let _DEFAULT_SIZE = 100

(* No-op: the table is a plain list and needs no sizing, so the requested
   size is ignored (a Hashtbl-based version was left disabled). *)
let create_table _n = ()

(* Whether [entry] records results in [table]; off initially. *)
let cache_switch = ref false
let cache_on = fun () -> cache_switch := true
let cache_off = fun () -> cache_switch := false

(* Cached result for (R, S), if any.  The key sorts both systems with
   [Trs.sort], so rule order does not matter.  Keys are compared with
   [compare ... = 0], exactly as [List.assoc] does. *)
let lookup r s =
  let key = (Trs.sort r, Trs.sort s) in
  let rec find = function
    | [] -> None
    | (k, v) :: rest -> if compare k key = 0 then Some v else find rest in
  find !table

(* Return [result], first recording it in [table] under the sorted key
   (R, S) when caching is switched on. *)
let entry r s result =
  if !cache_switch then
    table := ((Trs.sort r, Trs.sort s), result) :: !table;
  result


(* Initial search state for a TRS: (rs, (upper bound, subset index 0,
   sub-subset index 0)). *)
let deflt rs =
  let upper = bound rs in
  (rs, (upper, Int 0, Int 0))

(* divide set S into subsystem A and subsystems B1,B2,..
    such that Bi contains S \ A

   [divide (r, (_, i, p))] builds
   - A = the i-th subset of r ([decode i r]), and
   - B = [Ls.del r A] (presumably r minus A) united with the p-th subset
     of A.
   It returns [Some (deflt A, deflt B)], or [None] when
   - p exceeds [bound A],
   - A or B is empty, or
   - r is contained in A or in B (no proper split).
   The index comparison uses [Num]'s [>/]: polymorphic comparison is not
   meaningful on [num] values once they leave the small-int form. *)
let divide (r,(_,i,p)) =
  let rs = decode i r in
  let ss = Ls.del r rs @+ decode p rs in
  if p >/ bound rs || rs = [] || ss = [] ||
     subseteq r ss || subseteq r rs then
    None
  else
    Some (deflt rs, deflt ss)


(* debug *)
(* Debug counter: successive [get_count ()] calls return 1, 2, 3, ... *)
let count = ref 1
let get_count () =
  let current = !count in
  count := current + 1;
  current

(* In verbose mode, print a numbered banner with the rules of R and S on
   stderr before a check. *)
let begin_degmsg r s =
  if !Ref.verbose then
    Format.eprintf "[%d]Checking R,S\nR =\n%sS =\n%s@."
      (get_count ()) (Trs.sprint_rules r) (Trs.sprint_rules s)

(* In verbose mode, print the verdict of a check on stderr. *)
let end_degmsg result =
  if !Ref.verbose then
    (match result with
     | YES _ -> Format.eprintf "> Commute\n\n@."
     | NO  _ -> Format.eprintf "> Not Commute\n\n@."
     | MAYBE -> Format.eprintf "> MAYBE\n\n@.")


(* Run [direct0] on (R, S), record the result with [entry], and print the
   debug banners around it. *)
let logging r s =
  let () = begin_degmsg r s in
  let result = direct0 r s |> entry r s in
  let () = end_degmsg result in
  result

(* API for extern modules *)
(* Cached result for (R, S) if there is one; otherwise run [logging]. *)
let direct r s =
  match lookup r s with
  | Some cached -> cached
  | None -> logging r s

(* means already searched: raised by [direct_raise] when (R, S) is cached
   as MAYBE *)
exception FFF

(* Like [direct], but a cached MAYBE raises [FFF] instead of being
   returned.
   @raise FFF if (R, S) was already tried with result MAYBE. *)
let direct_raise r s =
  match lookup r s with
  | None -> logging r s
  | Some MAYBE -> raise FFF
  | Some decided -> decided

  
(* Successor of a [num] index; shadows [Stdlib.incr] from here on. *)
let incr i = add_num i (Int 1)

(* r and s are trs, l1 and l2 mean upper bound of i and j,
   i and j indicate subset of trss r and s,
   p and q indicate subset of i and j
   (i u p = r & j u q = s & p include r-i, q include s-j)

   [commute' (r, (l1, i, p)) (s, (l2, j, q))] first asks [direct_raise].
   - On MAYBE it calls [check_div1], which tries successive splits of R
     built by [divide].
   - When a split of R fails, [check_div2] tries splits of S.
   - When both halves of a split succeed, their proofs are combined in a
     [COM] node.
   - A pair already cached as MAYBE ([FFF]) answers MAYBE.
   @raise Invalid_argument if R or S is empty.
   All index comparisons use [Num]'s [>=/]: polymorphic comparison is not
   meaningful on [num] values once they leave the small-int form. *)
let rec commute' (r,(l1,i,p)) (s,(l2,j,q) as s0) =
  if r = [] || s = [] then
    invalid_arg "commute': empty"
  else
    try
      (match direct_raise r s with
       | YES thm -> YES thm
       | NO nf   -> NO nf
       | MAYBE   -> check_div1 (r,(l1,incr i,p)) s0)
      (* | MAYBE   -> check_div2 (r,(l1,i,p)) (s,(l1,incr j,q)) *)
    with FFF ->
      MAYBE
and check_div1 (r,(l1,i,p) as r0) (s,(l2,j,q) as s0) =
  (* R1 include R  where R = R1 u R2 and R2 include R-R1 *)
  if i >=/ l1 then
    MAYBE
  (* R2 include R1 where R = R1 u R2 and R2 include R-R1 *)
  else if p >=/ i then
    (* increase R1 *)
    check_div1 (r,(l1,incr i,Int 0)) (deflt s)
  else
    match divide r0 with
    (* TODO: improve algorithm *)
    | None -> check_div1 (r,(l1,incr i,p)) (deflt s)
    | Some (r1, r2) ->
        match commute' r1 s0, commute' r2 s0 with
        | YES thm1, YES thm2 -> YES (COM [thm1;thm2])
        (* | _ -> check_div1 (r,(l1,i,incr p)) (deflt s)  *)
        | _ ->
            (* splitting R failed: try splitting S before the next R split *)
            match check_div2 r0 (s,(l2,incr j,q)) with
            | MAYBE -> check_div1 (r,(l1,i,incr p)) (deflt s)
            | yesno -> yesno
and check_div2 r0 (s,(l2,j,q) as s0) =
  (* S1 include S  where S = S1 u S2 and S2 include S-S1 *)
  if j >=/ l2 then
    MAYBE
  (* S2 include S1 where S = S1 u S2 and S2 include S-S1 *)
  else if q >=/ j then
    (* increase S1 *)
    check_div2 r0 (s,(l2,incr j,Int 0))
  else
    match divide s0 with
    (* TODO: improve algorithm *)
    | None -> check_div2 r0 (s,(l2,j,incr q))
    | Some (s1, s2) ->
        match commute' r0 s1, commute' r0 s2 with
        | YES thm1, YES thm2 -> YES (COM [thm1;thm2])
        | _ -> check_div2 r0 (s,(l2,j,incr q))

(* Entry point: try to prove that R and S commute.
   - Non-left-linear R u S answers MAYBE.
   - A non-commutation witness from [noncomm] is returned immediately.
   - Otherwise the search of [commute'] starts from the default states.
   The cache size is computed from [bound]; it falls back to
   [_DEFAULT_SIZE] when the product does not fit an [int]. *)
let commute r s =
  let size =
    try int_of_num (bound r */ bound s) with Failure _ -> _DEFAULT_SIZE in
  (* create cache table *)
  create_table size;
  if Trs.left_linear (r @ s) then
    (match noncomm r s with
     | NO nf -> NO nf
     | YES _ | MAYBE -> commute' (deflt r) (deflt s))
  else
    MAYBE

(* Confluence of R, checked as commutation of R with itself. *)
let cr rs = commute rs rs
(*
example
let rs = [ a,b; a,d; b,c; c,d; f[x],f[x] ];;
*)
