open Ls
open Term
open Result


(* Apply [Ac_subst.flatten ac] to both sides of every pair in [rs],
   keeping the original order of the pairs. *)
let flatten ac rs =
  List.map
    (fun (lhs, rhs) -> (Ac_subst.flatten ac lhs, Ac_subst.flatten ac rhs))
    rs

(* Swap the two components of every pair (e.g. reverse every rule). *)
let inverse es = List.map (fun (left, right) -> (right, left)) es
(* Swap the first and third components of every triple, keeping the
   middle one (e.g. flip a peak [(a, peak, b)] into [(b, peak, a)]). *)
let inverse' es = List.map (fun (left, mid, right) -> (right, mid, left)) es



(** C-modulo **)

(* heuristics for C-terminating TRS *)
(* For every rule [l -> r] in [rs], produce [l' -> r] for each term [l']
   returned by [Rewriting.multistep c l]; rules are processed in order and
   the variants of one rule stay contiguous. *)
let emb_multi c rs =
  let steps = Rewriting.multistep c in
  let variants (lhs, rhs) = List.map (fun lhs' -> (lhs', rhs)) (steps lhs) in
  List.concat (List.map variants rs)

(* [c_eqaul_by rs p s t] is [true] iff no term [u] in
   [Rewriting.multistep rs t] satisfies [p s u].  Despite its name, it
   answers "NOT related": callers use it to keep the non-joinable triples.
   (The misspelt name is kept because other code refers to it.) *)
let c_eqaul_by rs p s t =
  let related_to_s = p s in
  List.for_all (fun u -> not (related_to_s u)) (Rewriting.multistep rs t)

(* A rule [(l, r)] counts as a C-rule for the symbol list [c] when the
   root of [l] is in [c] and [Acrpo.rule_C] accepts the rule.  The root is
   tested first, so [Acrpo.rule_C] only runs for candidate rules. *)
let is_C c (l, r) =
  if List.mem (root l) c then Acrpo.rule_C (l, r) else false

(* common C-symbols (excluding A-symbols), each system's rules without its C-rules, and the C-rules of the first system *)
(* [common_c r s] returns [Some ((r', s'), c_sym, c_rules)] when both
   systems have the same non-empty list of C-symbols (C-symbols that are
   not A-symbols), and none of them is an AC-symbol of [r @ s].
   - [r'], [s'] are [r] and [s] without their C-rules (see [is_C]);
   - [c_rules] are the C-rules taken from [r].
   Returns [None] otherwise. *)
let common_c r s =
  let c_sym = Acrpo.c_sym r \\ Acrpo.a_sym r in
  let c_sym' = Acrpo.c_sym s \\ Acrpo.a_sym s in
  let compatible =
    c_sym @+ c_sym' <> []
    && c_sym = c_sym'
    && c_sym @^ Acrpo.ac_sym (r @ s) = []
  in
  if not compatible then None
  else
    let c_rules, r_rest = List.partition (is_C c_sym) r in
    let s_rest = List.filter (fun rule -> not (is_C c_sym rule)) s in
    Some ((r_rest, s_rest), c_sym, c_rules)

(** AC-modulo **)

(* [common_ac r s] succeeds when [Ac_rewriting.find_ac] finds AC-structure
   in both systems with the same AC-symbols.  It returns the two remaining
   rule lists, the shared AC-symbols and the last component reported for
   [r] (used by callers as the full set of AC-rules). *)
let common_ac r s =
  match (Ac_rewriting.find_ac r, Ac_rewriting.find_ac s) with
  | Some (r_rest, ac_sym, _, ac_all), Some (s_rest, ac_sym', _, _)
    when ac_sym = ac_sym' ->
      Some ((r_rest, s_rest), ac_sym, ac_all)
  | _ -> None

(* [common_ac], enabled only when the [Ref.jkac] flag is set. *)
let common_acS r s =
  match !Ref.jkac with
  | true -> common_ac r s
  | false -> None

(* [common_c], enabled only when the [Ref.jkc] flag is set. *)
let common_cS r s =
  match !Ref.jkc with
  | true -> common_c r s
  | false -> None

(* check commutation modulo AC cup C, for systems whose AC- and C-symbols are disjoint (AC cap C = []) *)
(* [ac_jk86_modulo r s] checks commutation of [r] and [s] modulo AC and C.
   - Shared C-rules ([common_cS], only when [Ref.jkc] is set) and shared
     AC-structure ([common_acS], only when [Ref.jkac] is set) are split off;
     if either is unavailable, the systems are used as they are.
   - Returns [MAYBE] when [Acrpo.terminate] fails on the remaining rules.
   - Otherwise it builds critical peaks in both directions and normalises
     their sides.  It returns [NO] with the first triple that fails the
     joinability test, or [YES ()] when every triple passes. *)
let ac_jk86_modulo r s =
  (* split off shared C-rules; without them: no C-symbols, no C-rules *)
  let (rr,ss), c_sym, c_all = 
    match common_cS r s with
    | None   -> (r,s),[],[]
    | Some x -> x in
  (* likewise split off shared AC-structure; [common_ac] fails (and the
     systems are kept as they are) when their AC-symbols differ *)
  let ((rr1,ss1), ac_sym, ac_all) =
    match common_acS rr ss with
    | None -> (rr,ss),[],[]
    | Some x -> x in
  (* E-termination check.  Original note: only an AC-termination procedure
     is available.  NOTE(review): C-symbols are passed to [Acrpo.terminate]
     together with the AC-symbols; confirm how it treats them. *)
  if not (Acrpo.terminate (c_sym@ac_sym) (rr1@ss1)) then
    MAYBE
  else
    (* extend each rule by its C-multistep lhs variants ([emb_multi]) *)
    let rc,sc = emb_multi c_all rr1, emb_multi c_all ss1 in
    (* peak   <-R- . -S->
       Critical peaks in both directions.  The triples of the second call
       are flipped by [inverse'] so that all triples share one orientation;
       duplicates are removed by [Listx.unique]. *)
    let cps =
      Listx.unique (
          Ac_overlap.critical_peak2 ac_sym rc (sc@ac_all)
          @ inverse' (Ac_overlap.critical_peak2 ac_sym sc (rc@ac_all))) in
    (* valley -S-> . <-R-
       Normalise the first component with [sc] and the third with [rc],
       modulo [ac_sym]; the middle (peak) term is kept unchanged. *)
    let cps_nf =
      [ Ac_rewriting.nf ac_sym sc s,
        u,
        Ac_rewriting.nf ac_sym rc t 
      | s, u, t <- cps ] in
    (* equality modulo the AC-symbols *)
    let p s t = Ac_rewriting.equal ac_sym s t in
    (* non-joinable triples: no C-multistep reduct of [t] is AC-equal to
       [s] (see [c_eqaul_by], which returns true in that case) *)
    let nj =
      [ s,u,t | s,u,t <- cps_nf;
                (* s,t <- Trs.rename [s',t']; *)
                c_eqaul_by c_all p s t ] in
    (* the first non-joinable triple is reported as the witness *)
    match nj with
    | []             -> YES ()
    | (s, u, t) :: _ -> NO (s,u,t)


(* convert results using the implication CRM_AC => CR: YES carries over, but a NO modulo AC does not refute CR, so it becomes MAYBE *)
(* Confluence from commutation modulo AC: a [YES] carries over, while both
   [NO] and [MAYBE] from the modulo check become [MAYBE]. *)
let ac_jk86 r s =
  match ac_jk86_modulo r s with
  | YES x -> YES x
  | NO _ | MAYBE -> MAYBE


(** A-modulo **)

(* A rule [(l, r)] counts as an A-rule for the symbol list [a] when the
   root of [l] is in [a] and [Acrpo.rule_A] accepts the rule.  The root is
   tested first, so [Acrpo.rule_A] only runs for candidate rules. *)
let is_A a (l, r) =
  if List.mem (root l) a then Acrpo.rule_A (l, r) else false

(* common A-symbols (none of them AC), each system's rules without its A-rules, and all A-rules *)
(* [common_a r s] returns [Some ((r', s'), a_sym, a_all)] when:
   - both systems have the same non-empty list of A-symbols, none of which
     is an AC-symbol of [r @ s]; and
   - every rule of [Acrpo.a_class a] (for each such [a]) has a variant in
     the rules common to [r] and [s].
   [r'] and [s'] are [r] and [s] without their A-rules (see [is_A]), and
   [a_all] collects the [Acrpo.a_class] rules.  Returns [None] otherwise. *)
let common_a r s =
  let a_sym = Acrpo.a_sym r in
  let a_sym' = Acrpo.a_sym s in
  let shared_a_without_ac =
    a_sym @+ a_sym' <> []
    && a_sym = a_sym'
    && a_sym @^ Acrpo.ac_sym (r @ s) = []
  in
  if not shared_a_without_ac then None
  else
    let a_all = List.concat (List.map Acrpo.a_class a_sym) in
    let present rule = List.exists (Match.variant_rule rule) (r @^ s) in
    if not (List.for_all present a_all) then None
    else
      let without_a trs = List.filter (fun rule -> not (is_A a_sym rule)) trs in
      Some ((without_a r, without_a s), a_sym, a_all)

(* [common_a], enabled only when the [Ref.jka] flag is set. *)
let common_aS r s =
  match !Ref.jka with
  | true -> common_a r s
  | false -> None

(* A-unifiers of [l1] with the subterm of [l2] at position [p].  Returns
   [] unless the two rules share no variables and that subterm is not a
   variable. *)
let a_unifiers a_sym ((l1, r1), p, (l2, r2)) =
  let vars_of (lhs, rhs) = Term.variables lhs @ Term.variables rhs in
  let inner = subterm_at p l2 in
  let disjoint = vars_of (l1, r1) @^ vars_of (l2, r2) = [] in
  (* Excluding variant rules (p = [] with (l1,r1) a variant of (l2,r2)) is
     unsound for extended critical pairs, so such overlaps are kept. *)
  if disjoint && not (is_variable inner) then Assoc.unify a_sym l1 inner
  else []

(* All A-overlaps of a rule of [r] into a non-variable position of the lhs
   of a rule of [s0], each paired with one of its A-unifiers.  [s0] is
   renamed first when it shares variables with [r]. *)
let a_overlaps a_sym r s0 =
  let clash = Listset.inter (Trs.variables r) (Trs.variables s0) <> [] in
  let s = if clash then Trs.rename s0 else s0 in
  [ (inner_rule, pos, outer_rule), sigma
    | inner_rule <- r;
      (outer_lhs, _) as outer_rule <- s;
      pos <- function_positions outer_lhs;
      sigma <- a_unifiers a_sym (inner_rule, pos, outer_rule) ]

(* Critical peaks from [a_overlaps a_sym r s], as triples
   (peak with the [r]-rule applied at [p], peak, [s]-rule reduct at root),
   where the peak is the instantiated lhs of the [s]-rule. *)
let a_cpp2 a_sym r s =
  let peak_of (((_, r1), p, (l2, r2)), m) =
    let top = Term.substitute m l2 in
    (replace top (Term.substitute m r1) p, top, Term.substitute m r2)
  in
  List.map peak_of (a_overlaps a_sym r s)

(* ExtA(R) = 
     { f(l,x) -> f(r,x), 
       f(x,l) -> f(x,r),
       f(x,f(l,y)) -> f(x,f(r,y)) 
     | l -> r in R with root(l) = f in F_A}
*)
(* [rs] followed by the A-extensions (ExtA) of every rule whose lhs root is
   an A-symbol.  Two fresh variables are shared by all generated rules. *)
let a_ext a_sym rs =
  let x = Term.new_var () in
  let y = Term.new_var () in
  let lift f l r =
    [ F (f, [x; l]), F (f, [x; r]);
      F (f, [l; x]), F (f, [r; x]);
      F (f, [x; F (f, [l; y])]), F (f, [x; F (f, [r; y])]) ]
  in
  let extended =
    [ rule | (F (f, _) as l), r <- rs; List.mem f a_sym; rule <- lift f l r ]
  in
  rs @ extended


(* [a_c_jk86_modulo r s] checks commutation of [r] and [s] modulo A and C.
   - Shared C-rules ([common_cS], only when [Ref.jkc] is set) and shared
     A-rules ([common_aS], only when [Ref.jka] is set) are split off;
     if either is unavailable, the systems are used as they are.
   - Returns [MAYBE] when [Acrpo.terminate] fails on the remaining rules.
   - Otherwise it builds extended critical peaks ([a_ext], [a_cpp2]) in both
     directions and normalises their sides with [Assoc.nf].  It returns [NO]
     with the first triple that fails the joinability test, or [YES ()]. *)
let a_c_jk86_modulo r s =
  (* split off shared C-rules; without them: no C-symbols, no C-rules *)
  let ((rr0,ss0),c_sym,c_all) =
    match common_cS r s with
    | None   -> (r,s),[],[]
    | Some x -> x in
  (* split off shared A-rules; without them: no A-symbols, no A-rules *)
  let ((rr,ss), a_sym, a_all) =
    match common_aS rr0 ss0 with
    | None   -> (rr0,ss0),[],[]
    | Some x -> x in
  (* E-termination check.  Original note: only an AC-termination procedure
     is available.  NOTE(review): C- and A-symbols are both passed to
     [Acrpo.terminate]; confirm this is sound for A. *)
    if not (Acrpo.terminate (c_sym@a_sym) (rr@ss)) then
        MAYBE
    else
      (* extend each rule by its C-multistep lhs variants, then add the
         A-extensions ExtA of both systems *)
      let rc,sc = emb_multi c_all rr, emb_multi c_all ss in
      let ext_rc = a_ext a_sym rc in
      let ext_sc = a_ext a_sym sc in
      (* peak   <-R- . -S->
         Triples (R-side, peak, S-side): [a_cpp2] puts the reduct by its
         first rule list first; the second call is flipped by [inverse'].
         Unlike [ac_jk86_modulo], duplicates are not removed here. *)
      let cpps =
          (a_cpp2 a_sym ext_rc (ext_sc@a_all) @
           inverse' (a_cpp2 a_sym ext_sc (ext_rc@a_all))) in
      (* valley -S-> . <-R-
         Normalise the R-side with the extended S-rules and the S-side with
         the extended R-rules, modulo [a_sym]; the peak is kept. *)
      let cpps_nf =
        [ Assoc.nf a_sym ext_sc t,
          s,
          Assoc.nf a_sym ext_rc u
          | t,s,u <- cpps ] in
      (* equality modulo the A-symbols *)
      let p s t = Assoc.equal a_sym s t in
      (* non-joinable triples: after jointly renaming [t] and [u], no
         C-multistep reduct of [u'] is A-equal to [t'] ([c_eqaul_by]);
         the reported triple keeps the un-renamed terms *)
      let nj =
        [ t,s,u | t,s,u <- cpps_nf; 
                  t',u' <- Trs.rename [t,u];
                  c_eqaul_by c_all p t' u' ] in
      (* the first non-joinable triple is reported as the witness *)
      match nj with
      | []           -> YES ()
      | (t,s,u) :: _ -> NO (t,s,u)

(* CRM_A <=> CR, so results of the A-modulo check are returned unchanged *)
(* Since CRM_A <=> CR, every answer of the modulo check (including [NO]) is
   returned unchanged. *)
let a_c_jk86 = a_c_jk86_modulo


(* Run the A/C-modulo prover and then the AC-modulo prover on [rs] and [ss],
   combining their answers with [Result.exists].
   Earlier variants (kept for reference): reduce both systems with
   [Step.reduce] first (to be integrated into main), and use the prover
   list [ ac_jk86; a_jk86 ]. *)
let jk86 rs ss =
  let provers = [ a_c_jk86; ac_jk86 ] in
  Result.exists (fun prove -> prove rs ss) provers
