open Term
open Format

(* Counter used to generate fresh variable names; incremented on every call
   to [fresh]. *)
let count : int ref = ref 0

(* Return a new variable name "_ac_cps_x<k>", where <k> is the counter value
   after incrementing it, so successive calls yield distinct names. *)
let fresh () =
  count := !count + 1;
  Format.sprintf "_ac_cps_x%d" !count

(* Rename a rule apart: every variable of the left-hand side [l] is mapped
   to a fresh variable (via [fresh]), and the same substitution is applied
   to both [l] and [r].
   NOTE(review): only the variables of [l] are renamed; this assumes
   Var(r) is a subset of Var(l), as for rewrite rules. Variables occurring
   only in [r] are presumably left untouched by [substitute] - confirm. *)
let rename (l, r) =
  let sigma = List.map (fun x -> (x, V (fresh ()))) (variables l) in
  let rename_term t = substitute sigma t in
  (rename_term l, rename_term r)

(* AC overlaps *)
(* AC overlaps of [rule1] into [rule2].
   Both rules are first renamed apart with fresh variables, giving
   (l1, r1) and (l2, r2). For every position [p] returned by
   [function_positions l2] (presumably the non-variable positions) and every
   substitution [u] in the list returned by
   [Ac_subst.unify ac (subterm_at p l2, l1)], the result contains
   [((l1, r1), p, (l2, r2), u)].
   Positions with no AC unifier contribute nothing, because the inner
   generator ranges over an empty list. *)
let overlap_aux ac rule1 rule2 =
  let l1, r1 = rename rule1
  and l2, r2 = rename rule2 in
  (* No variant-rule filter is applied, e.g.
       p <> [] || not (Match.variant_rule (l1, r1) (l2, r2)):
     excluding variant rules is unsound for AC-critical pairs. *)
  [ (l1, r1), p, (l2, r2), u
  | p <- function_positions l2;
    u <- Ac_subst.unify ac (subterm_at p l2, l1) ]

(* All AC overlaps of a rule from [rules1] into a rule from [rules2].
   Every ordered pair is processed by [overlap_aux]; duplicates are then
   removed with [Listx.unique]. *)
let overlap2 ac rules1 rules2 =
  let all_overlaps =
    [ o | rule1 <- rules1; rule2 <- rules2;
          o <- overlap_aux ac rule1 rule2 ]
  in
  all_overlaps |> Listx.unique

(* AC overlaps *)
(* AC overlaps of a rule set with itself. *)
let overlap ac trs = overlap2 ac trs trs


(* AC critical peaks *)
(* Turn an overlap ((l1, r1), p, (l2, r2), mu) into a critical peak
   (s, u, t), where:
     u = mu applied to l2,
     s = mu applied to [replace l2 r1 p] (presumably l2 with r1 put at p),
     t = mu applied to r2. *)
let critical_peak_aux ((l1, r1), p, (l2, r2), mu) =
  let apply = substitute mu in
  let peak = apply l2 in
  let left = apply (replace l2 r1 p) in
  let right = apply r2 in
  (left, peak, right)
  
(* AC critical peaks arising from the overlaps of [rules1] into [rules2]. *)
let critical_peak2 ac rules1 rules2 =
  overlap2 ac rules1 rules2 |> List.map critical_peak_aux

(* AC critical peaks of a rule set with itself. The argument is a list of
   rules (it is passed to [critical_peak2] on both sides), so it is named
   [rules] for consistency with [overlap] and [cp]. *)
let critical_peak ac rules = critical_peak2 ac rules rules


(* AC critical pairs *)
(* AC critical pairs (s, t) from the critical peaks (s, u, t) of [rules1]
   into [rules2]; the peak term u is dropped. *)
let cp2 ac rules1 rules2 =
  let drop_peak (s, _, t) = (s, t) in
  critical_peak2 ac rules1 rules2 |> List.map drop_peak

(* AC critical pairs of a rule set with itself. *)
let cp ac trs = cp2 ac trs trs


(* utility *)
(* Print each overlap (rule1, p, rule2, sigma) as: rule1, then a line
   "position <p> with <sigma>", then rule2. *)
let print_overlaps xs =
  let print_one (outer, pos, inner, sigma) =
    Trs.print_rules [outer];
    Format.printf "position %a with %a@."
      (Formatx.print_list Format.pp_print_int "") pos
      print_subst sigma;
    Trs.print_rules [inner]
  in
  List.iter print_one xs

(* Print each critical peak (s, u, t) as "s <-- u --> t", prefixed with the
   integer index supplied by [Listx.index]. *)
let print_peaks xs =
  let print_peak (i, (s, u, t)) =
    Format.printf "%d: %a <-- @.  %a@.--> %a@.@."
      i Term.print s Term.print u Term.print t
  in
  List.iter print_peak (Listx.index xs)

