open Term

(** A substitution, represented as an association list that pairs a
    variable name with the term bound to it. *)
type subst = (string * Term.t) list

(** [union xs ys] appends [ys] to [xs] and passes the combined list
    through [Ls.uniq] (presumably duplicate removal -- [Ls] is defined
    outside this file, confirm there). *)
let union xs ys =
  let joined = List.append xs ys in
  Ls.uniq joined

(** [pair x y] builds the tuple [(x, y)]. *)
let pair = fun x y -> x, y

(** [product f xs ys] applies [f] to every combination of an element of
    [xs] with an element of [ys].  Results are ordered by [xs] first and
    by [ys] second, e.g.
    [product f [a; b] [c; d] = [f a c; f a d; f b c; f b d]]. *)
let product f xs ys =
  (* One row per element of [xs]; rows are then flattened in order. *)
  let row x = List.map (f x) ys in
  List.flatten (List.map row xs)

(** [cons hd tl] is the list [hd :: tl]; a named version of the [::]
    constructor so it can be passed as a function. *)
let cons hd tl = hd :: tl

(** [pi lists] is the n-ary cartesian product of [lists]: every way of
    choosing one element from each inner list, in order.  For example
    [pi [[1; 2]; [3]] = [[1; 3]; [2; 3]]], and [pi [] = [[]]]. *)
let pi list =
  (* Prepend each candidate of the current position to every tail that
     was already built for the positions to its right. *)
  let extend candidates tails = product cons candidates tails in
  List.fold_right extend list [[]]



(** [associative_rules a] builds, for every symbol [f] in [a], the pair
    [(f(f(x, y), z), f(x, f(y, z)))] over the variables ["x"], ["y"] and
    ["z"] -- i.e. the associativity law for [f] oriented to the right. *)
let associative_rules a =
  let x = V "x" and y = V "y" and z = V "z" in
  (* Binary application of the symbol [f]. *)
  let app f l r = F (f, [l; r]) in
  let rule f = (app f (app f x y) z, app f x (app f y z)) in
  List.map rule a
  
(** [normalize a t] rewrites [t] so that nested binary applications of
    any symbol listed in [a] associate to the right: every
    [f(f(x, y), z)] with [f] in [a] becomes [f(x, f(y, z))].
    Arguments are normalized before the root is inspected. *)
let rec normalize a term =
  match term with
  | V x -> V x
  | F (f, ts) ->
    let args = List.map (normalize a) ts in
    (match args with
     | [F (g, [x; y]); z] when g = f && List.mem f a ->
       (* Rotate to the right and normalize again, since [y] may itself
          be an application of [f]. *)
       normalize a (F (f, [x; F (f, [y; z])]))
     | _ -> F (f, args))


(** [dom subst] lists the variable names of the bindings [(x, t)] in
    [subst] for which [t] is not the variable [x] itself.  Names appear
    in the order of [subst]; duplicates are not removed here. *)
let dom subst =
  let is_proper (x, t) = V x <> t in
  List.map fst (List.filter is_proper subst)

(** [compose subst1 subst2] builds one binding for every variable in the
    union of the two domains: [x] is bound to
    [Term.substitute subst2 (Term.substitute subst1 (V x))], i.e.
    [subst1] is applied first and [subst2] second. *)
let compose subst1 subst2 =
  let image x = Term.substitute subst2 (Term.substitute subst1 (V x)) in
  let bind x = (x, image x) in
  List.map bind (union (dom subst1) (dom subst2))

(** [unify_aux a s t] returns a list of candidate substitutions that
    unify [s] and [t], where the symbols listed in [a] are treated as
    associative.  An empty list means no candidate was found.

    - If either side is a variable [x], the single candidate binds [x]
      to the other side.  No occurs check is performed.
    - If both sides are applications of the same symbol with the same
      number of arguments, the arguments are unified pairwise and every
      combination of per-argument candidates is concatenated into one
      substitution; the candidates produced by [replicate] in both
      directions are appended.
    - Otherwise (different symbols, or the same symbol used with
      different argument counts) there is no candidate.

    NOTE(review): per-argument bindings are concatenated without checking
    that they agree on variables shared between arguments -- confirm that
    callers can tolerate several bindings for the same variable. *)
let rec unify_aux a s t =
  match s, t with
  | V x, t | t, V x -> [[x, t]]
  (* The arity test keeps [List.map2] from raising [Invalid_argument]
     when one symbol is applied to argument lists of different length. *)
  | F (f, ss), F (g, ts) when f = g && List.length ss = List.length ts ->
      let componentwise =
        List.map List.concat (pi (List.map2 (unify_aux a) ss ts)) in
      componentwise @ replicate a s t @ replicate a t s
  | _ -> []

(** [replicate a s t] handles the case where [s] is [f(x, _)] with a
    variable [x] as first argument, [t] is [f(t1, t2)], and [f] is in
    [a]: [x] is bound to [f(t1, x)] and [s] is then unified with the
    remainder [t2]; each resulting candidate is composed after that
    binding.  Any other shape yields no candidate. *)
and replicate a s t =
  match s, t with
  | F (f, [V x; _]), F (g, [t1; t2]) when f = g && List.mem f a ->
      let grow = [(x, F (f, [t1; V x]))] in
      List.map (compose grow) (unify_aux a s t2)
  | _ -> []

(** [unify a s t] normalizes both terms with respect to the associative
    symbols in [a] and returns the candidate substitutions computed by
    [unify_aux] on the normalized terms. *)
let unify a s t =
  let s' = normalize a s in
  let t' = normalize a t in
  unify_aux a s' t'
  (* Earlier variant, kept for reference: normalize by rewriting with
     the associativity rules instead of [normalize].
  let rules = associative_rules a in
  unify_aux a
    (Step.normalform rules s) 
    (Step.normalform rules t)
  *)

(** [lift xs t] replaces every constant (application with no arguments)
    whose name occurs in [xs] by the variable of the same name, and
    leaves the rest of [t] unchanged. *)
let rec lift xs term =
  match term with
  | V x -> V x
  | F (c, []) when List.mem c xs -> V c
  | F (f, ts) -> F (f, List.map (lift xs) ts)

(** [matcher a s t] computes candidate substitutions for [s] against
    [t], with the symbols in [a] treated as associative.  The variables
    of [t] (as reported by [Term.variables]) are first replaced by
    constants of the same name, [s] is unified with the result, and
    those constants are turned back into variables in every binding of
    every candidate via [lift]. *)
let matcher a s t =
  let xs = Term.variables t in
  (* Replace each variable of [t] by a constant of the same name. *)
  let freeze = List.map (fun x -> (x, F (x, []))) xs in
  let frozen = Term.substitute freeze t in
  (* Undo the replacement inside the bound terms. *)
  let thaw subst = List.map (fun (x, u) -> (x, lift xs u)) subst in
  List.map thaw (unify a s frozen)
  
(** [rewrite_root1 a t rules] tries the rules [(l, r)] in order at the
    root of [t].  For the first rule whose [matcher a l t] is non-empty,
    the first candidate substitution is applied to [r] and
    [(true, result)] is returned.  If no rule matches, the result is
    [(false, t)]. *)
let rec rewrite_root1 a t rules =
  match rules with
  | [] -> (false, t)
  | (l, r) :: rest ->
      (match matcher a l t with
       | subst :: _ -> (true, Term.substitute subst r)
       | [] -> rewrite_root1 a t rest)

(** [rewrite1_aux a rules t] performs one bottom-up rewriting pass over
    [t]: every argument is rewritten first, then the rules are tried at
    the root of the term rebuilt from the rewritten arguments.  Returns
    [(changed, term)] where [changed] is [true] when a rule applied at
    the root or inside any argument.  Variables are returned unchanged. *)
let rec rewrite1_aux a rules term =
  match term with
  | V _ -> (false, term)
  | F (f, ts) ->
      let results = List.map (rewrite1_aux a rules) ts in
      let rebuilt = F (f, List.map snd results) in
      let root_changed, u = rewrite_root1 a rebuilt rules in
      (root_changed || List.exists fst results, u)

(** [nf a rules t] repeats [rewrite1_aux] passes for as long as a pass
    reports that it changed the term, and returns the term on which the
    last pass reported no change. *)
let rec nf a rules t =
  match rewrite1_aux a rules t with
  | true, u -> nf a rules u
  | false, _ -> t

(** [equal a s t] tests whether [s] and [t] have structurally equal
    normal forms under [normalize a], i.e. whether they coincide once
    applications of the symbols in [a] are nested to the right. *)
let equal a s t =
  let lhs = normalize a s and rhs = normalize a t in
  lhs = rhs
  (* Earlier variant, kept for reference: compare normal forms obtained
     by rewriting with the associativity rules.
  let rules = associative_rules a in
  Step.normalform rules s = Step.normalform rules t
  *)

(*
#install_printer Term.print;;
let xx = V "x"
let yy = V "y"
let zz = V "z"
let ww = V "w"
let aa = F ("a", [])
let bb = F ("b", [])
let cc = F ("c", [])
let dd = F ("d", [])
let ff(x, y) = F ("f", [x;y])
let gg(x) = F ("f", [x])
let s = gg(ff(xx,yy))
let t = gg(ff(zz,ww))
let _ = unify ["f"] s t
*)
