open Term
open Match

(* copied from cache.ml *)
(** Memo table used by [lookup]: an association list from a
    (pattern, term) pair to the list of matchers computed for it.
    New entries are prepended; lookups scan the list linearly. *)
let cache : ((t * t) * (string * t) list list) list ref = ref []

(** [lookup f s t] returns [f (s, t)], memoised in [cache] under the key
    [(s, t)] (keys compared structurally).  On a miss, [f] is called once and
    its result is prepended to [cache].

    NOTE(review): the key does not include [f].  Callers must always pass
    the "same" function for a given key (currently [Ac_subst.matcher ac] for
    one fixed [ac]); if [ac] changes during a run, a stale result computed
    for a different [ac] can be returned. *)
let lookup f s t =
  let key = (s, t) in
  try List.assoc key !cache
  with Not_found ->
    (* [f] is called outside the [try], so a Not_found raised by [f]
       itself propagates instead of being mistaken for a cache miss. *)
    let m = f key in
    cache := (key, m) :: !cache;
    m

(* extended rewriting with AC-matching *)
(** [rewrite_root ac t rules] tries the rules in [rules] in order at the root
    of [t].  Each left-hand side [l] is matched against [t] with
    [Ac_subst.matcher ac] (memoised through [lookup]).  For the first rule
    [(l, r)] that yields at least one matcher, returns [(true, subst m r)]
    with [m] the first matcher.  If no rule applies, returns [(false, t)]. *)
let rec rewrite_root ac t rules =
  match rules with
  | [] -> (false, t)
  | (l, r) :: remaining ->
      (match lookup (Ac_subst.matcher ac) l t with
       | m :: _ -> (true, subst m r)
       | [] -> rewrite_root ac t remaining)

(** [remove_first elem l] is [l] without its first element structurally
    equal to [elem].  The order of the remaining elements is kept; if no
    element matches, the result equals [l]. *)
let remove_first elem l =
  let rec go seen = function
    | [] -> List.rev seen
    | h :: rest ->
        if elem = h then List.rev_append seen rest
        else go (h :: seen) rest
  in
  go [] l

(** [remove_arg ls xs] removes from [ls] one occurrence (structural
    equality) of each element of [xs], processing [xs] left to right.
    This is a multiset difference; elements of [xs] absent from [ls] are
    ignored. *)
let remove_arg ls xs =
  (* drop the first element equal to [x] *)
  let rec drop_one x = function
    | [] -> []
    | y :: ys -> if x = y then ys else y :: drop_one x ys
  in
  List.fold_left (fun remaining x -> drop_one x remaining) ls xs

(** [head f u] is [true] iff [u] is an application [F (f, _)], i.e. its root
    symbol is [f].  Variables never have a head symbol. *)
let head f = function
  | F (g, _) -> f = g
  | V _ -> false

(** [flatten ac u] normalises [u] with respect to the AC symbols listed in
    [ac].  Every subterm is flattened.  Below a symbol [f] in [ac], any
    argument whose root is also [f] is replaced by its own arguments, so
    [f(f(x, y), z)] becomes [f(x, y, z)].  Symbols not in [ac] keep their
    argument lists as is (only the arguments themselves are flattened). *)
let rec flatten ac = function
  | V _ as v -> v
  | F (f, ts) ->
      let ts_flat = List.map (flatten ac) ts in
      if List.mem f ac then
        (* Each element of [ts_flat] is already flat.  In particular, the
           arguments of an element rooted at [f] never have root [f]
           themselves.  One splice pass therefore yields a flat term, and
           the result need not be re-flattened (doing so would walk every
           subterm again). *)
        F (f, iter f ts_flat)
      else F (f, ts_flat)

(** [iter f args] replaces every element of [args] whose root symbol is [f]
    by that element's arguments, keeping all other elements in place.  Only
    one level is spliced. *)
and iter f = function
  | [] -> []
  | F (g, gs) :: tt when f = g -> gs @ iter f tt
  | h :: tt -> h :: iter f tt

(** [find_redex ac f args rules candidates] tries each argument list in
    [candidates] in order, rewriting at the root of the flattened term
    [F (f, candidate)] with [rules].
    - On the first success with result [u], returns [true] with the
      flattened term [F (f, u :: rest)], where [rest] is [args] with one
      occurrence of each candidate element removed.
    - If no candidate yields a redex, returns [(false, V "false")]; the
      term is only a placeholder in that case. *)
let rec find_redex ac f args rules = function
  | [] -> (false, V "false")
  | subset :: others ->
      (match rewrite_root ac (flatten ac (F (f, subset))) rules with
       | true, rewritten ->
           (true, flatten ac (F (f, rewritten :: remove_arg args subset)))
       | false, _ -> find_redex ac f args rules others)

(** [rewrite_aux ac rules u] performs one bottom-up rewriting pass over [u]
    modulo the AC symbols [ac].
    - Every argument is first processed recursively.
    - The rules are then tried at the root of the flattened term rebuilt
      from the processed arguments.
    - If nothing was rewritten so far and the root symbol is in [ac], every
      argument subset of size at least two produced by [Ls.powset] is tried
      as a redex via [find_redex].
    Returns [(true, u')] if some rewrite happened.  Otherwise returns
    [(false, u')], where [u'] is the term rebuilt and flattened without
    changes. *)
let rec rewrite_aux ac rules = function
  | V _ as t -> (false, t)
  | F (f, ts) ->
      let results = List.map (rewrite_aux ac rules) ts in
      let args = List.map snd results in
      let root_term = flatten ac (F (f, args)) in
      let root_rewritten, u = rewrite_root ac root_term rules in
      if root_rewritten || List.exists fst results then (true, u)
      else if not (List.mem f ac) then (false, root_term)
      else begin
        (* f is ac. Try to find a redex by taking all possible args *)
        (* assumes ac-symbols are always binary *)
        let candidates =
          List.filter (fun subset -> List.length subset >= 2) (Ls.powset args)
        in
        (* [find_redex] already returns a flattened term on success. *)
        match find_redex ac f args rules candidates with
        | true, t -> (true, t)
        | false, _ -> (false, root_term)
      end
	

(** [nf ac rules t] repeatedly applies [rewrite_aux] to [t], flattened with
    respect to the AC symbols [ac], until a pass performs no rewrite.  It
    returns that final flattened term.  The loop does not terminate if
    [rules] is non-terminating modulo AC. *)
let nf ac rules t =
  let rec loop u =
    (* Flatten exactly once per iteration: [flatten] is idempotent, so the
       extra flatten the rewritten term used to receive before recursing
       was redundant work. *)
    let u_flat = flatten ac u in
    match rewrite_aux ac rules u_flat with
    | true, next -> loop next
    | false, _ -> u_flat
  in
  loop t

(* check if s ==_AC t *)

(** [remove_first eq elem l] is [l] without its first element [h] for which
    [eq elem h] holds.  The remaining order is kept; if no element
    qualifies, the result equals [l]. *)
let remove_first eq elem l =
  let rec go seen = function
    | [] -> List.rev seen
    | h :: rest ->
        if eq elem h then List.rev_append seen rest
        else go (h :: seen) rest
  in
  go [] l

(** [exists_eq eq elem l] holds iff some [h] in [l] satisfies [eq elem h].
    Elements are tested left to right, stopping at the first hit. *)
let exists_eq eq elem l = List.exists (fun h -> eq elem h) l

(** [remove_common_args eq l1 l2] cancels the elements the two lists have in
    common, as multisets under [eq].  Each element [x] of [l2], taken left
    to right, removes the first remaining [y] of [l1] with [eq x y].  The
    result is the pair of what is left of [l1] and the uncancelled elements
    of [l2], the latter in their original order. *)
let remove_common_args eq l1 l2 =
  (* drop the first [y] with [eq x y] *)
  let rec drop_first x = function
    | [] -> []
    | y :: ys -> if eq x y then ys else y :: drop_first x ys
  in
  let step (rest1, kept2) x =
    if List.exists (fun y -> eq x y) rest1 then (drop_first x rest1, kept2)
    else (rest1, x :: kept2)
  in
  let rest1, kept2 = List.fold_left step (l1, []) l2 in
  (rest1, List.rev kept2)

(** [equal ac s t] decides whether [s] and [t] are equal modulo
    associativity and commutativity of the symbols in [ac].  Both terms are
    flattened once.
    - Arguments of a symbol in [ac] are compared as multisets (via
      [remove_common_args]).
    - Arguments of any other symbol are compared position by position, so
      [f(a, b)] and [f(b, a)] differ when [f] is not AC. *)
let equal ac s t =
  (* Subterms of a flattened term are themselves flattened, so the
     recursion works on the flat terms directly. *)
  let rec eq_flat s t =
    match s, t with
    | V x, V y -> x = y
    | F (f, ss), F (g, ts) when f = g ->
        if List.mem f ac then
          (match remove_common_args eq_flat ss ts with
           | [], [] -> true
           | _ -> false)
        else
          List.length ss = List.length ts && List.for_all2 eq_flat ss ts
    | _, _ -> false
  in
  eq_flat (flatten ac s) (flatten ac t)

(* Example:

let t1 = Tplparser.parse_t "f(x,a(),g(f(f(z,b(),z))),e(),x)"
let t2 = Tplparser.parse_t "f(f(a(),g(f(b(),z,z))),e(),x,x)"
equal ["f";"g"] t1 t2   (* expected: true *)

*)


(*
   Given a TRS, find_ac returns None if it contains no AC symbols, otherwise
   Some (trs \setminus AC axioms, ac_sym, C+A1 rules, all AC rules).
*)

(** [make_assoc1 f] is the associativity rule
    [f(f(x, y), z) -> f(x, f(y, z))] for the binary symbol [f]. *)
let make_assoc1 f =
  let app a b = F (f, [ a; b ]) in
  let x = V "x" and y = V "y" and z = V "z" in
  (app (app x y) z, app x (app y z))

(** [make_assoc2 f] is the associativity rule
    [f(x, f(y, z)) -> f(f(x, y), z)], the reverse of [make_assoc1 f]. *)
let make_assoc2 f =
  let app a b = F (f, [ a; b ]) in
  let x = V "x" and y = V "y" and z = V "z" in
  (app x (app y z), app (app x y) z)

(** [make_comm f] is the commutativity rule [f(x, y) -> f(y, x)]. *)
let make_comm f =
  let x = V "x" and y = V "y" in
  (F (f, [ x; y ]), F (f, [ y; x ]))

(** [is_ac f rr] holds iff [rr] contains a variant of the commutativity rule
    of [f] and a variant of one of its two associativity rules. *)
let is_ac f rr =
  let variant_of axiom rl = Match.variant_rule rl axiom in
  let comm = make_comm f in
  let assoc_l = make_assoc1 f in
  let assoc_r = make_assoc2 f in
  let has_comm = List.exists (variant_of comm) in
  let has_assoc =
    List.exists (fun rl -> variant_of assoc_l rl || variant_of assoc_r rl)
  in
  has_comm rr && has_assoc rr

(** [signature u] lists the pair (symbol, arity) for every function-symbol
    occurrence in [u], in pre-order (root first, then arguments left to
    right).  Duplicates are kept. *)
let rec signature = function
  | V _ -> []
  | F (f, args) -> (f, List.length args) :: List.concat (List.map signature args)

(** [signatures rs] collects the (symbol, arity) pairs of both sides of every
    rule in [rs] and passes them through [Ls.uniq] (presumably removing
    duplicates — see Ls). *)
let signatures rs =
  let of_rule (lhs, rhs) = signature lhs @ signature rhs in
  Ls.uniq (List.concat (List.map of_rule rs))

(** [remove_variant v_rule rs] keeps, in order, the rules [rl] of [rs] for
    which [Match.variant_rule v_rule rl] does not hold. *)
let remove_variant v_rule rs =
  List.filter (fun rl -> not (Match.variant_rule v_rule rl)) rs

(** [remove_variants rs vs] removes from [rs] every rule that is a variant of
    some rule in [vs], applying [remove_variant] for each element of [vs]
    left to right. *)
let remove_variants rs vs =
  List.fold_left (fun kept v -> remove_variant v kept) rs vs


(** [find_ac rr] detects the binary symbols of [rr] that are AC according to
    [is_ac].  Returns [None] if there are none.  Otherwise returns
    [Some (rules, ac_sym, ac_rules_1, ac_rules_all)] where:
    - [rules] is [rr] without any variant of an AC axiom;
    - [ac_sym] lists the AC symbols;
    - [ac_rules_1] holds the commutativity rule and the first associativity
      rule of each AC symbol;
    - [ac_rules_all] holds the commutativity rule and both associativity
      rules of each AC symbol. *)
let find_ac rr =
  let binary_syms =
    List.map fst (List.filter (fun (_, arity) -> arity = 2) (signatures rr))
  in
  let ac_sym = List.filter (fun f -> is_ac f rr) binary_syms in
  match ac_sym with
  | [] -> None
  | _ ->
      let all_axioms f = [ make_comm f; make_assoc1 f; make_assoc2 f ] in
      let core_axioms f = [ make_comm f; make_assoc1 f ] in
      let ac_rules_all = List.concat (List.map all_axioms ac_sym) in
      let ac_rules_1 = List.concat (List.map core_axioms ac_sym) in
      Some (remove_variants rr ac_rules_all, ac_sym, ac_rules_1, ac_rules_all)
  
