open Term

(* DEBUG hooks.  All three are currently silent:
   - say runs the given thunk and returns its result, ignoring the label;
   - print_eq and print_eqs ignore their argument and return unit. *)
let say _label thunk = thunk ()
let print_eq (_lhs, _rhs) = ()
let print_eqs _eqs = ()

(* remove_first elem xs drops the first element of xs that is structurally
   equal to elem; xs is returned unchanged (as a copy) when there is none. *)
let rec remove_first elem xs =
  match xs with
  | [] -> []
  | x :: rest when elem = x -> rest
  | x :: rest -> x :: remove_first elem rest

(* Cancels the elements shared by l1 and l2, counting multiplicities.
   Walks l2 from left to right: an element still present in the remainder of
   l1 is removed from both sides (first occurrence on the l1 side), any other
   element is kept.  Returns (what is left of l1, what is left of l2), the
   second component in its original order. *)
let remove_common_args l1 l2 =
  let cancel (left, kept_rev) x =
    if List.mem x left then (remove_first x left, kept_rev)
    else (left, x :: kept_rev)
  in
  let left, kept_rev = List.fold_left cancel (l1, []) l2 in
  (left, List.rev kept_rev)
    
(* For two terms with the same head symbol f, removes the arguments they have
   in common and returns (f, Ls.count of the remaining left arguments,
   Ls.count of the remaining right arguments).
   NOTE(review): Ls.count presumably pairs each distinct argument with its
   multiplicity (basic_sol reads the pairs that way) -- confirm in Ls.
   Raises Failure when the arguments are not F-terms with equal heads. *)
let group s t =
  let reject () = failwith "invalid arg in Ac.normalize" in
  match s with
  | F (f, ss) -> (
      match t with
      | F (g, tt) when f = g ->
          let rest_s, rest_t = remove_common_args ss tt in
          f, Ls.count rest_s, Ls.count rest_t
      | _ -> reject ())
  | _ -> reject ()

(* head f t is true exactly when t is a function application whose root
   symbol is f. *)
let head f t =
  match t with
  | F (g, _) -> f = g
  | _ -> false

(* flatten ac t puts t in flattened form with respect to the AC symbols
   listed in ac: every argument of an AC symbol f that is itself headed by f
   is replaced by its own arguments, e.g. f(x, f(y, z)) becomes f(x, y, z).
   Non-AC symbols keep their arity; only their arguments are flattened.

   The arguments are flattened first, so an argument headed by f has no
   f-headed argument of its own.  Splicing those arguments in (iter) therefore
   already yields a flat term, and no second flatten pass over the result is
   needed. *)
let rec flatten ac = function
  | V _ as v -> v
  | F (f, ts) ->
      let ts_flat = List.map (flatten ac) ts in
      if List.mem f ac then F (f, iter f ts_flat) else F (f, ts_flat)

(* iter f ts replaces every element of ts headed by f with its arguments,
   keeping all other elements and the overall left-to-right order. *)
and iter f = function
  | [] -> []
  | F (g, gs) :: tt when f = g -> gs @ iter f tt
  | h :: tt -> h :: iter f tt

(* Sets up and solves the linear Diophantine equation behind an AC problem;
   e.g. x+x+z = y+y is read as 2x + z = 2y.
   Both terms are flattened and grouped (see group); the second components of
   the grouped pairs are used as coefficients and passed to Dio.basis.
   Returns (f, aa, bs): the common head symbol, the first components of the
   grouped pairs (left side followed by right side), and one vector per
   element returned by Dio.basis, obtained by appending its two halves. *)
let basic_sol ac h1 h2 =
  let f, left, right = group (flatten ac h1) (flatten ac h2) in
  let coeffs_left = List.map snd left in
  let coeffs_right = List.map snd right in
  let aa = List.map fst (left @ right) in
  let glue (l, r) = l @ r in
  let bs = List.map glue (Dio.basis coeffs_left coeffs_right) in
  (f, aa, bs)
  
(* make_ac_term f v n builds the n-fold f-combination of v:
     0 -> None,  1 -> Some v,  n >= 2 -> Some (f(v, f(v, ... f(v, v)))).
   Raises Invalid_argument for negative n. *)
let rec make_ac_term f v n =
  if n = 0 then None
  else if n = 1 then Some v
  else Some (make_ac_term_aux f v n)

(* Right-nested binary term with n >= 2 copies of v.
   Raises Invalid_argument when n < 2. *)
and make_ac_term_aux f v n =
  if n < 2 then invalid_arg "make_ac_term"
  else if n = 2 then F (f, [v; v])
  else F (f, [v; make_ac_term_aux f v (n - 1)])

(* Combines a list of terms with the binary symbol f, nesting to the right:
     []           -> one
     [t]          -> t
     t1 :: rest   -> f(t1, make_ac_term_from rest) *)
let rec make_ac_term_from (f, one) terms =
  match terms with
  | [] -> one
  | [ last ] -> last
  | first :: rest -> F (f, [ first; make_ac_term_from (f, one) rest ])

(* merge_vec vec vecs conses the i-th element of vec onto the i-th list of
   vecs.  Raises Invalid_argument (from List.map2) on a length mismatch. *)
let merge_vec vec vecs =
  let push x column = x :: column in
  List.map2 push vec vecs

(* Builds a substitution from a basis of Diophantine solutions.
   vec  : the argument terms of the problem, one per column;
   gens : the basis vectors (int list list), each combined position-wise
          with vec.
   Every basis vector gets one fresh variable w.  The entry n of the vector
   at column i contributes n copies of w (see make_ac_term) to that column.
   For every column whose term is a variable x, the result binds x to the
   f-combination of all contributions of the column (one if there is none).
   Columns whose term is not a variable produce no binding. *)
let build_ac1_unifier (f, one) (vec, gens) =
  (* Keeps the terms that are actually present in a column. *)
  let present column =
    List.fold_right
      (fun entry kept ->
        match entry with Some term -> term :: kept | None -> kept)
      column []
  in
  (* Adds the contribution of one basis vector to every column. *)
  let add_generator columns gen =
    let fresh = Term.new_var () in
    let powers =
      List.map
        (fun (_, mult) -> make_ac_term f fresh mult)
        (List.combine vec gen)
    in
    merge_vec powers columns
  in
  let columns =
    List.fold_left add_generator (Ls.rep (List.length vec) []) gens
  in
  List.fold_right
    (fun (arg, column) bindings ->
      match arg with
      | V x -> (x, make_ac_term_from (f, one) (present column)) :: bindings
      | _ -> bindings)
    (List.combine vec columns)
    []

(* Elementary unification with a unit element one; assumes |ac| = 1.
   The shape of the flattened inputs selects the case:
   - both headed by the same symbol f: solve the Diophantine system of
     (s, t) and turn its basis into a substitution;
   - one side is a variable x: bind x to the other (unflattened) input;
   - anything else: the empty substitution.
   No occurs check is performed here. *)
let elementary_ac1_unify ac one s t =
  match (flatten ac s, flatten ac t) with
  | F (f, _), F (g, _) when f = g ->
      let _, args, basis = basic_sol ac s t in
      build_ac1_unifier (f, one) (args, basis)
  | V x, _ -> [ (x, t) ]
  | _, V x -> [ (x, s) ]
  | _ -> []
  
(* erase (ac, one) t removes occurrences of the unit term one below binary
   applications of AC symbols: f(one, u) and f(u, one) simplify to u.  All
   other applications are rebuilt with erased arguments; variables and one
   itself are returned as they are. *)
let rec erase (ac, one) term =
  let go = erase (ac, one) in
  match term with
  | V x -> V x
  | _ when term = one -> one
  | F (f, [ left; right ]) when List.mem f ac ->
      let left' = go left
      and right' = go right in
      if left' = one then right'
      else if right' = one then left'
      else F (f, [ left'; right' ])
  | F (f, args) -> F (f, List.map go args)

(* Elementary AC unification derived from the unifier with a unit element.
   A fresh constant one plays the role of the unit.  Every subset of the
   variables occurring in the range of that unifier is tried as the set of
   variables sent to one; the unit is then erased from the instantiated
   bindings, and only substitutions in which no binding collapsed to one
   are returned. *)
let elementary_ac_unify ac s t =
  let one = new_const () in
  let ac1_mgu = elementary_ac1_unify ac one s t in
  let var_range =
    Ls.uniq (List.concat (List.map (fun (_, rhs) -> variables rhs) ac1_mgu))
  in
  (* The power set grows quickly, so it is traversed with tail-recursive
     functions only; this reverses its order, as the original fold did. *)
  let admissible_sigmas =
    List.rev_map
      (fun subset -> List.map (fun v -> (v, one)) subset)
      (Listx.power var_range)
  in
  let mksubst sigma =
    List.map
      (fun (x, rhs) -> (x, erase (ac, one) (substitute sigma rhs)))
      ac1_mgu
  in
  let no_binding_is_unit subst =
    List.for_all (fun (_, rhs) -> rhs <> one) subst
  in
  List.rev
    (List.fold_left
       (fun kept sigma ->
         let subst = mksubst sigma in
         if no_binding_is_unit subst then subst :: kept else kept)
       [] admissible_sigmas)

(*
let filterPreCond4 basic_sols aa = 
  let i1n = Ls.range 0 ((List.length aa)-1) in
  let violates4 b i = (List.nth b i >= 2) && (not (is_variable (List.nth aa i))) in
  [ b | b <- basic_sols; not (List.exists (violates4 b) i1n) ]


let condition3 base aa =
 let base_ar = Array.of_list (List.map Array.of_list base) in
 let b = ref true in 
 for i=0 to (List.length aa) - 1 do
  let column_i = [ base_ar.(j).(i) | j <- Ls.range 0 ((Array.length base_ar) - 1) ] in
  begin
  if List.for_all (fun x -> x=0) column_i 
  then b := false  
  end;
 done; !b


let condition4 base aa =
  let base_ar = Array.of_list (List.map Array.of_list base) in
  let b = ref true in 
  for i=0 to (List.length aa) - 1 do
    if not (is_variable (List.nth aa i)) then
      let column_i = [ base_ar.(j).(i) | j <- Ls.range 0 ((Array.length base_ar) - 1) ] in
      let sum = List.fold_left (+) 0 column_i in
      begin
  if not (sum = 1) then b := false
      end;
  done; !b
*)


(* apply phi (s, t) applies the substitution phi to both sides of an
   equation. *)
let apply phi (s, t) =
  let instantiate u = Term.substitute phi u in
  (instantiate s, instantiate t)

(* combine phi2 phi1 composes two substitutions given as association lists:
   every binding of phi1 has phi2 applied to its right-hand side, followed by
   the bindings of phi2 for variables that phi1 does not bind.
   The domain of phi1 is computed once instead of once per binding of
   phi2. *)
let combine phi2 phi1 =
  let dom1 = List.map fst phi1 in
  List.map (fun (v, t) -> (v, Term.substitute phi2 t)) phi1
  @ List.filter (fun (v, _) -> not (List.mem v dom1)) phi2

(*
let lookup i basis_var =
 [ List.nth b i, var | b,var <- basis_var; List.nth b i >= 1 ]

let rec make_term f acc = function
  | [] -> F(f,acc)
  | (n,s) :: tt ->
      let ls = [ s | i <- Ls.range 1 n ] in
      make_term f (acc @ ls) tt
      
let make_term_or_sing f ls = 
  match ls with
  | [(n,t)] when n = 1 -> t
  | _ -> make_term f [] ls
  
let make_eqs f basis_var aa = 
 let n = (List.length aa) - 1 in 
 [ List.nth aa i, make_term_or_sing f n_var_ls | i <- Ls.range 0 n; 
   n_var_ls <- [lookup i basis_var]; not (is_variable (List.nth aa i)) ]  

let s_of_v = function
  | V x -> x
  | _ -> failwith "invalid arg in Ac.s_of_v"

let make_sigma f basis_var aa =
 let n = (List.length aa) - 1 in 
 [ s_of_v (List.nth aa i), make_term_or_sing f n_var_ls | i <- Ls.range 0 n; 
   n_var_ls <- [lookup i basis_var]; (is_variable (List.nth aa i)) ]  
*)
  
(* unify_set f i k sigma eqs solves the equations eqs[i..k] one after the
   other.  sigma is the list of substitutions found so far; for each of them
   the current equation is instantiated and handed to the unifier f, and
   every solution returned by f is composed with that substitution (see
   combine).  Returns the substitutions that solve all equations up to k;
   this is sigma itself when i > k.
   List.nth fails if an index in i..k lies outside eqs. *)
let rec unify_set f i k sigma eqs =
  if i > k then sigma
  else
    let lhs, rhs = List.nth eqs i in
    let extend phi2 =
      List.map
        (fun phi1 -> combine phi1 phi2)
        (f (Term.substitute phi2 lhs, Term.substitute phi2 rhs))
    in
    let next_sigma = List.concat (List.map extend sigma) in
    unify_set f (i + 1) k next_sigma eqs


(*
let rec simplify_vt subst =
  let appears v sub1 =
    try 
      let _ = List.find (fun (x,y) -> List.mem v (Term.variables y)) sub1 in true 
    with Not_found -> false
  in
  let v,t = List.find (fun (x,y) -> appears x (Ls.del [x,y] subst)) subst in 
  let subst1 = Ls.del [v,t] subst in 
  [ w, Term.substitute [v,t] s | w,s <- subst1 ]


let rec simplify subst = 
 try 
  let subst1 = simplify_vt subst in
  simplify subst1
 with Not_found -> subst

let remove_vars vars subst =
 [ v,t | v,t <- subst; List.mem v vars ]

let filter_invalid subs =
 [ sub | sub <- subs; List.for_all (fun (x,y) -> not (List.mem x (Term.variables y))) sub ]

let rec powerset = function
 | [] -> [[]]
 | h::t -> List.fold_left (fun xs t -> (h::t)::t::xs) [] (powerset t);;
*)

(*
let rec unify ac = function
  | t1,t2 when t1 = t2 -> [[]]
  | (V x, t) 
  | (t, V x) when not (List.mem x (vars t)) -> [[x,t]]
  | F(f,ss),F(g,tt) when f=g && not (List.mem f ac) ->
      let k = (List.length ss) - 1 in
      let ss_tt = List.combine ss tt in 
      unify_set (unify ac) 0 k [[]] ss_tt 
  | F(f,ss),F(g,tt) when f=g && (List.mem f ac) ->
      let f,aa,bs = basic_sol [f] (F(f,ss)) (F(g,tt)) in
      let bs1 = filterPreCond4 bs aa in  
      let bss2 = [ b | b <- powerset bs1; b <> [] ] in
      let bss3 = [ b | b <- bss2; condition3 b aa ] in
      let bss4 = [ b | b  <- bss3; condition4 b aa ] in  
      let basis_var_ls = [ [ bj, Term.new_var () | bj <- bsi ] | bsi <- bss4 ] in 
      let subs =
      List.flatten
      [ unify_set (unify ac) 0 ((List.length eqs)-1)  [(make_sigma f basis_var aa)] eqs | basis_var <- 
  basis_var_ls; eqs <- [ make_eqs f basis_var aa ] ] 
      in subs (*
      let subs1 = List.map simplify subs in 
      List.map (remove_vars (Rule.variables (F(f,ss),F(g,tt)))) subs1 *)
  | _ -> []
*)
  

(* Rebuilds the binary term structure of a flattened AC-term:
     reconstruct f(x,y,z) => f(x,f(y,z))
   The whole term is traversed: arguments of AC and non-AC symbols alike are
   reconstructed first, so flattened AC subterms nested below other symbols
   (e.g. g(f(a,b,c))) are re-nested as well instead of being returned with
   the wrong arity.
   Raises Invalid_argument when an AC symbol has fewer than two arguments. *)
let rec reconstruct ac = function
  | V x -> V x
  | F (f, ts) when not (List.mem f ac) -> F (f, List.map (reconstruct ac) ts)
  | F (f, ts) ->
      (* Right-nest an argument list that is already reconstructed. *)
      let rec nest = function
        | [ t1; t2 ] -> F (f, [ t1; t2 ])
        | t1 :: (_ :: _ as rest) -> F (f, [ t1; nest rest ])
        | _ -> invalid_arg "reconstruct"
      in
      nest (List.map (reconstruct ac) ts)

(* Prototype AC-unification of s and t by recursive decomposition.
   Both terms are flattened, unified by unify_proto_aux, and every resulting
   substitution is restricted to the variables of s and t, has its
   right-hand sides rebuilt as binary terms (reconstruct), and duplicates are
   removed. *)
let rec unify_proto ac (s, t) =
  let s', t' = (flatten ac s, flatten ac t) in
  let vs = variables s @ variables t in
  (* Drop bindings of variables introduced during unification. *)
  let refresh sub = List.filter (fun (x, _) -> List.mem x vs) sub in
  let finish sub0 =
    List.map (fun (x, v) -> (x, reconstruct ac v)) (refresh sub0)
  in
  Ls.uniq (List.map finish (unify_proto_aux ac vs (s', t')))

(* unify_proto_aux ac vs (s, t) returns a list of substitutions unifying the
   flattened terms s and t; vs are the variables of the original problem. *)
and unify_proto_aux ac vs st =
  let this = unify_proto_aux ac vs in
  match st with
  | t1, t2 when t1 = t2 -> [ [] ]
  | V x, V y ->
      (* Bind the other variable to x when x is an original variable;
         otherwise bind x to y. *)
      if List.mem x vs then [ [ (y, V x) ] ] else [ [ (x, V y) ] ]
  | V x, t when not (List.mem x (variables t)) -> [ [ (x, t) ] ]
  | t, V x when not (List.mem x (variables t)) -> [ [ (x, t) ] ]
  | F (f, ss), F (g, tt) when f = g && not (List.mem f ac) ->
      let k = List.length ss - 1 in
      let ss_tt = List.combine ss tt in
      unify_set this 0 k [ [] ] ss_tt
  | F (f, ss), F (g, tt) when f = g && List.mem f ac ->
      (* Abstract every argument by a fresh variable and solve the resulting
         elementary AC problem. *)
      let fresh_s = List.map (fun _ -> new_var ()) ss in
      let fresh_t = List.map (fun _ -> new_var ()) tt in
      let e_sigmas =
        elementary_ac_unify ac (F (f, fresh_s)) (F (f, fresh_t))
      in
      (* Each elementary unifier yields its own equations between the
         instantiated abstractions and the real arguments.  They are solved
         starting from that unifier only; seeding every system with all
         elementary unifiers repeated the same work once per unifier and
         only produced duplicates. *)
      let solve e_sigma =
        let eqs =
          List.combine (List.map (substitute e_sigma) fresh_s) ss
          @ List.combine (List.map (substitute e_sigma) fresh_t) tt
        in
        unify_set this 0 (List.length eqs - 1) [ e_sigma ] eqs
      in
      List.flatten (List.map solve e_sigmas)
  | _ -> []

(* All variables of a list of equations, left side before right side for
   each equation, with repetitions. *)
let vars_in_equations eqs =
  List.concat (List.map (fun (s, t) -> variables s @ variables t) eqs)

(* helper functions *)
(* not_mem x xs holds when no element of xs compares equal to x. *)
let not_mem x xs = not (List.exists (fun y -> compare y x = 0) xs)
(* elem x eqs holds when the variable x occurs in some equation of eqs. *)
let elem x eqs =
  List.exists (fun (s, t) -> List.mem x (variables s @ variables t)) eqs
(* free x eqs holds when the variable x occurs in no equation of eqs. *)
let free x eqs = not (List.mem x (vars_in_equations eqs))
(* split_by p xs returns (elements satisfying p, elements not satisfying p),
   both in their original order. *)
let split_by p xs = List.partition p xs
(* apply_to (x, t) eqs replaces x by t in the equations that mention x.
   Returns (the rewritten equations, the equation x = t followed by the
   equations that do not mention x). *)
let apply_to (x, t) eqs =
  let mentions_x eq = elem x [ eq ] in
  let touched, untouched = split_by mentions_x eqs in
  (List.map (apply [ (x, t) ]) touched, (V x, t) :: untouched)

  (* Variable Replacement *)
let rec var_rep org_vars eqs eq =
  match eq with
    V x, V y when
         (not_mem x org_vars || List.mem y org_vars)
         && elem x eqs && elem y eqs
    -> Some (apply_to (x,V y) eqs)
  | V y, V x when
         (not_mem x org_vars || List.mem y org_vars)
         && elem x eqs && elem y eqs
    -> Some (apply_to (x,V y) eqs)
  (* original modification *)
  | V x, V y when
         (not_mem x org_vars && List.mem y org_vars)
         && elem x eqs
    -> Some (apply_to (x,V y) eqs)
  | V y, V x when
         (not_mem x org_vars && List.mem y org_vars)
         && elem x eqs
    -> Some (apply_to (x,V y) eqs)
  | _ -> None


(* Special Variable Replacement (not in the paper).
   These are original rules dealing with a case the paper does not cover:
   an equation between two original variables.  When one of them, say x,
   occurs in eqs, x is replaced by the other variable y in eqs.  The left
   variable is tried first, then the right one.
   Returns Some (rewritten equations, the binding (x, V y), the remaining
   equations of apply_to with every copy of eq removed), or None. *)
let var_rep_direction_aux org_vars eqs eq =
  let applicable x y =
    elem x eqs && List.mem x org_vars && List.mem y org_vars
  in
  let direct x y =
    let rewritten, rest = apply_to (x, V y) eqs in
    Some (rewritten, (x, V y), List.filter (fun other -> other <> eq) rest)
  in
  match eq with
  | V a, V b when applicable a b -> direct a b
  | V a, V b when applicable b a -> direct b a
  | _ -> None

(* Let x be an original (not introduced) variable.  Then
   {x -> y; y -> z} is simplified to {x -> z; y -> z}.
   Returns Some system when at least one equation was directed, None
   otherwise. *)
let rec var_rep_direction org_vars eqs =
  var_rep_direction_loop org_vars [] [] eqs

(* directed     : bindings produced so far, kept instantiated by every later
                  binding;
   not_directed : equations on which the special rule did not apply;
   the last argument is the list of equations still to be examined. *)
and var_rep_direction_loop org_vars directed not_directed pending =
  match pending with
  | [] -> if directed = [] then None else Some (directed @ not_directed)
  | eq :: rest -> (
      match var_rep_direction_aux org_vars (not_directed @ rest) eq with
      | None ->
          var_rep_direction_loop org_vars directed (eq :: not_directed) rest
      | Some (results, (x, t), others) ->
          let directed' = (V x, t) :: List.map (apply [ (x, t) ]) directed in
          var_rep_direction_loop org_vars directed' not_directed
            (results @ others))

(* Replacement.
   For an equation x = rhs whose left side is a variable occurring in eqs,
   substitutes rhs for x in eqs (see apply_to for the shape of the result):
   - rhs is a different variable that also occurs in eqs, or
   - rhs is a function application not containing x.
   Equations of the form F = x are left to the Orient rule.  A stricter
   variant that refused non-linear right-hand sides for introduced variables
   (to avoid a conflict with prune) is disabled, so org_vars is currently
   unused. *)
let rep org_vars eqs eq =
  let replace x rhs = Some (apply_to (x, rhs) eqs) in
  match eq with
  | V x, (V y as rhs) ->
      if x <> y && elem x eqs && elem y eqs then replace x rhs else None
  | V x, (F _ as rhs) ->
      if elem x eqs && not_mem x (variables rhs) then replace x rhs else None
  | _ -> None


(* Existential Quantifiers Elimination.
   An equation between an introduced variable x (not in org_vars) and a term
   t is dropped when x occurs neither in t nor in eqs.
   Returns Some eqs (unchanged) when the equation can be dropped, None
   otherwise. *)
let eqe org_vars eqs eq =
  let eliminable x t = not_mem x org_vars && free x ((t, t) :: eqs) in
  match eq with
  | V x, t when eliminable x t -> Some eqs
  | t, V x when eliminable x t -> Some eqs
  | _ -> None


(* Merge.
   The given eqs should be sorted by sort_equation.
   For an equation x = t with t not a variable: if eqs contains an equation
   whose left side is x and whose right side is not a variable, every
   equation x = r of eqs is turned into t = r.
   Returns Some (the new equations t = r, x = t followed by the equations of
   eqs whose left side is not x), or None when the rule does not apply. *)
let merge eqs eq =
  match eq with
  | V _, V _ -> None
  | (V _ as lhs), t ->
      let same_lhs (l, _) = l = lhs in
      let has_proper_partner =
        List.exists (fun (l, r) -> l = lhs && not (is_variable r)) eqs
      in
      if has_proper_partner then
        let shared, others = split_by same_lhs eqs in
        Some (List.map (fun (_, r) -> (t, r)) shared, (lhs, t) :: others)
      else None
  | _ -> None


(* Mutate.
   Decomposes an equation between two function applications.
   Returns None when the equation is not between two applications,
   Some [] when it has no solution, and otherwise Some of a list of
   alternative equation lists that replace the equation.
   - Standard symbols: clash on different heads or arities, otherwise the
     single alternative of argument-wise equations.
   - AC symbols: after flattening and cancelling common arguments, each
     distinct remaining argument is abstracted by a fresh variable, the
     abstracted problem is solved by elementary AC-unification (which is
     complete), and every unifier gives one alternative pairing the
     instantiated abstractions with the real arguments. *)
let mutate ac eq =
  match eq with
  (* for standard unification *)
  | F (f, ss), F (g, tt) when not_mem f ac ->
      if f <> g || List.length ss <> List.length tt then Some []
      else Some [ List.combine ss tt ]
  (* for AC-unification *)
  | (F (f, _) as s), (F (g, _) as t) when List.mem f ac ->
      if f <> g then Some []
      else
        let args_s, args_t =
          match (flatten ac s, flatten ac t) with
          | F (_, xs), F (_, ys) -> (xs, ys)
          | _ -> failwith "not reachable"
        in
        let ss, tt = remove_common_args args_s args_t in
        (* One fresh variable per distinct argument. *)
        let abstract terms =
          List.map (fun u -> (u, new_var ())) (Listx.unique terms)
        in
        let subs = abstract ss in
        let subt = abstract tt in
        let lookup table u =
          snd (List.find (fun (key, _) -> key = u) table)
        in
        let vs = List.map (lookup subs) ss in
        let vt = List.map (lookup subt) tt in
        let unifiers = elementary_ac_unify ac (F (f, vs)) (F (f, vt)) in
        let instantiate sigma =
          List.combine (List.map (substitute sigma) vs) ss
          @ List.combine (List.map (substitute sigma) vt) tt
        in
        Some (List.map instantiate unifiers)
  | _ -> None

(* Applies mutate exhaustively.  Returns the list of all equation systems
   obtained by repeatedly mutating every equation that can be mutated, one
   system per combination of alternatives; a system on which mutate applies
   nowhere is returned as the only result. *)
let rec mutate_all ac eqs =
  let mutated = List.map (mutate ac) eqs in
  if List.for_all (fun m -> m = None) mutated then [ eqs ]
  else
    (* An equation that cannot be mutated is its own single alternative. *)
    let alternatives =
      List.map2
        (fun m eq -> match m with Some alts -> alts | None -> [ [ eq ] ])
        mutated eqs
    in
    Ls.product alternatives
    |> List.map (fun choice -> mutate_all ac (List.concat choice))
    |> List.concat


(* Check.
   Returns false if eqs certainly has no solution:
   - some equation relates two applications with different head symbols, or
   - after substituting, one at a time, every binding x = t read off the
     equations with a variable on the left into all right-hand sides, some
     equation x = f(...) has x occurring in its right-hand side.
   Bindings written as t = x are not used (they may be needed as well). *)
let check eqs =
  if eqs = [] then true
  else
    let clash = function
      | F (f, _), F (g, _) -> f <> g
      | _ -> false
    in
    let bindings =
      List.fold_right
        (fun eq acc -> match eq with V x, t -> (x, t) :: acc | _ -> acc)
        eqs []
    in
    let propagate current binding =
      List.map (fun (l, r) -> (l, substitute [ binding ] r)) current
    in
    let saturated = List.fold_left propagate eqs bindings in
    let cyclic = function
      | V x, (F _ as t) -> List.mem x (variables t)
      | _ -> false
    in
    (not (List.exists clash eqs)) && not (List.exists cyclic saturated)

  
(* Prune: meant to give up on systems that may lead to non-terminating
   derivations.  The actual test (kept commented out below) is disabled, so
   both arguments are ignored and the result is always true, i.e. no system
   is ever pruned. *)
let prune org_vars eqs =
  (*
  let p x t = not_mem x org_vars && not (linear t) in
  [] = ([() | V x,t <- eqs; p x t] @ [() | t,V x <- eqs; p x t])
  *) true


(* Orient, from Iwami, "正則項上の単一化について" (in Japanese; roughly
   "On unification over regular terms"), Computer Software, vol 35(4), 2018.
   An equation t = x with t a function application and x a variable is
   turned around: the result is Some (x = t followed by eqs).  The binding is
   not propagated into eqs.  Returns None for any other equation. *)
let orient eqs eq =
  match eq with
  | (F _ as t), V x -> Some ((V x, t) :: eqs)
  | _ -> None

(* Removes trivial equations (both sides structurally equal) and then passes
   the rest through Ls.uniq. *)
let cleanup eqs =
  eqs
  |> List.filter (fun (s, t) -> s <> t)
  |> Ls.uniq

(* Assumes the input equation is flattened.
   Returns true for an equation between two applications that is rejected
   outright: the head symbols differ, or they are the same AC symbol applied
   to different numbers of arguments.  Every other equation gives false. *)
let collapse ac = function
  | F (f, ss), F (g, tt) ->
      if f <> g then true
      else List.mem f ac && List.length ss <> List.length tt
  | _ -> false

(* occur_twice x eqs holds when x is a whole side of at least two equations
   of eqs. *)
let occur_twice x eqs =
  List.length (List.filter (fun (s, t) -> s = x || t = x) eqs) > 1
                      
(* Splits eqs into (equations having a side that is an original variable
   which is a whole side of at least two equations, all other equations). *)
let divide_non_linears org_vars eqs =
  let repeated_original side =
    match side with
    | V x -> List.mem x org_vars && occur_twice side eqs
    | _ -> false
  in
  split_by (fun (s, t) -> repeated_original s || repeated_original t) eqs

(* add_substitution s1 s2 returns the bindings of s1 followed by the
   bindings of s2 with s1 applied to their right-hand sides. *)
let add_substitution s1 s2 =
  let instantiated = List.map (fun (v, t) -> (v, substitute s1 t)) s2 in
  s1 @ instantiated


(**************************************************************)
(* "Syntactic" AC-Unification, A.Boudet and E.Contejean, 1994 *)
(**************************************************************)

(* One step of phase 1.
   ess is the current system and used0 the equations already set aside.
   The result describes the successor states as (set aside, system) pairs:
   - Some []  : failure, the system has no solution (or was pruned);
   - None     : no equation is left to work on;
   - Some l   : continue with every state of l.
   The rules are tried on the first target equation in a fixed order:
   Orient, Mutation, Merge, Variable Replacement, Replacement, EQE.  When
   none applies the equation is set aside and the next one is examined.
   NOTE(review): Ls.del xs ys presumably removes the elements of ys from
   xs -- confirm in Ls. *)
let unify_syn_aux_1_do ac org_vars used0 ess =
  (* setup *)
  (* equations on repeated original variables are excluded from the targets
     and treated like equations that were already set aside *)
  let excepts, targets = divide_non_linears org_vars ess in
  let used = used0 @ excepts in
  let all = used @ targets in
  (* check failure *)
  if not (prune org_vars ess && check ess) then
     Some []
  else
  (* apply rules *)
  match targets with
    [] ->
      None
  | eq :: eqs -> 
     (* attempt to apply Orient *)
     match orient (Ls.del all [eq]) eq with
       Some eqs1 ->
         Some [[], eqs1]
     | None ->
     (* attempt to apply Mutation *)
     match mutate ac eq with
       Some [] ->
        (* failed to mutate *)
        Some []
     | Some eqss ->
         (* one successor state per alternative decomposition *)
         Some [[], eqs1@used@eqs | eqs1 <- eqss]
     (* attempt to apply Merge *)
     | None ->
     match merge eqs eq with
       Some (results,others) ->
        Some [[], used@results@others]
     (* attempt to apply Variable Replacement *)
     | None ->
     match var_rep org_vars (Ls.del all [eq]) eq with
       Some (results,others) ->
        Some [[], results@others]
     | None ->
     (* attempt to apply Replacement *)
     match rep org_vars (Ls.del all [eq]) eq with
       Some (results,others) ->
        Some [[], results@others]
     (* attempt to apply EQE *)
     | None ->
     match eqe org_vars (used@eqs) eq with
       Some eqs1 ->
        Some [[], eqs1]
     (* otherwise go the next equation *)
     | None ->
        Some [eq::used, eqs]

(* Phase 1.  The input is an equation list; the output is a list of equation
   lists, one per branch of the derivation. *)
let rec unify_syn_aux_1 ac org_vars ess =
  unify_syn_aux_1_loop ac org_vars [] ess

(* Decomposes equations, except the ones introduced from non-linear terms:
   {x = s, x = t} introduced from x+x = s+t is left alone because it might
   involve non-terminating unification processes.
   Repeats unify_syn_aux_1_do on every successor state until no step applies
   and then returns the set-aside equations followed by the system. *)
and unify_syn_aux_1_loop ac org_vars used ess =
  let pending = cleanup ess in
  let resume (used_next, pending_next) =
    unify_syn_aux_1_loop ac org_vars used_next pending_next
  in
  match unify_syn_aux_1_do ac org_vars used pending with
  | None -> [ used @ pending ]
  | Some successors -> Ls.concat_map resume successors


(* One step of phase 2.
   ess is the current system, used the equations already set aside, and
   solver the function used to solve the equations created by Merge (it is
   called with the variables of the new equation and that single equation).
   The result describes the successor states as (set aside, system) pairs:
   - Some []  : failure, the system has no solution (or was pruned);
   - None     : no equation is left, this phase is finished;
   - Some l   : continue with every state of l.
   The rules are tried on the first equation in a fixed order: Orient,
   Variable Replacement, EQE, Merge (only for an original variable on the
   left).  When none applies the equation is set aside.
   NOTE(review): Ls.del xs ys presumably removes the elements of ys from
   xs -- confirm in Ls. *)
let unify_syn_aux_2_do ac org_vars solver used ess =
  (* setup *)
  let all = used @ ess in
  (* apply rules *)
  if not (prune org_vars ess && check ess) then
    (* give-up to find solutions, or no solutions *)
    (* failure detected *)
    Some []
  else
    begin
      match ess with
        [] ->
         (* successfully finished *)
         None
      | eq :: eqs -> 
         (* attempt to apply Orient *)
         match orient (Ls.del all [eq]) eq with
           Some eqs1 ->
            Some [[], eqs1]
        (* attempt to apply Variable Replacement *)
         | None ->
        match var_rep org_vars (Ls.del all [eq]) eq with
          Some (results,others) -> Some [[], results@others]
        (* attempt to apply EQE (this is an original modification; not in the paper) *)
        | None ->
        match eqe org_vars (used@eqs) eq with
        (* eqe returns its equation list unchanged, so eqs1 already starts
           with used and used occurs twice in the new system.
           NOTE(review): the duplicates are presumably removed by cleanup
           on the next iteration -- confirm. *)
        Some eqs1 -> Some [[], used@eqs1]
        (* attempt to apply Merge *)
        | None ->
           match eq with
             V x, t when List.mem x org_vars -> 
             begin
               match merge eqs (V x, t) with
                 Some (results, others) ->
                  (* recursively solve new equation *)
                  let vs es = vars_in_equations [es] in
                  (* one successor state per merged equation and per
                     solution of that equation *)
                  Some [[], eqs2@used@others | es <- results;
                                               eqs2 <- solver (vs es) [es]]
               (* otherwise go the next equation *)
               | _ -> Some [eq::used, eqs]
             end
           (* otherwise go the next equation *)
           | _ -> Some [eq::used, eqs]
    end

(* Phase 2.  Gives no result when some equation is rejected by collapse;
   otherwise runs the phase-2 loop on the system. *)
let rec unify_syn_aux_2 ac org_vars eqs =
  match List.exists (collapse ac) eqs with
  | true -> []
  | false -> unify_syn_aux_2_loop ac org_vars [] eqs

(* Merges the equations obtained by unify_syn_aux_1.
   Repeats unify_syn_aux_2_do on every successor state until no step applies
   and then returns the set-aside equations followed by the system. *)
and unify_syn_aux_2_loop ac org_vars used ess0 =
  let pending = cleanup ess0 in
  let solver = unify_syn_aux_1 ac in
  let resume (used_next, pending_next) =
    unify_syn_aux_2_loop ac org_vars used_next pending_next
  in
  match unify_syn_aux_2_do ac org_vars solver used pending with
  | None -> [ used @ pending ]
  | Some successors -> Ls.concat_map resume successors


(* Collects x1, x2, x3, ... from { x1 = s, t = x2, w = x3, ... }: for every
   equation with an original variable as a whole side, that variable is
   reported (the left side is preferred when both sides qualify).  Equations
   without such a side contribute nothing. *)
let rec sources org_vars eqs =
  match eqs with
  | [] -> []
  | eq :: rest -> (
      let tail = sources org_vars rest in
      match eq with
      | V x, _ when List.mem x org_vars -> x :: tail
      | _, V x when List.mem x org_vars -> x :: tail
      | _ -> tail)

(* E is a solved form if E = { x_1 = t_1, ... , x_n = t_n } where the
   variables x_i are distinct and not introduced by Mutation.
   Checked as: every equation has an original variable as a side, and each
   such variable occurs exactly once in the whole system.
   NOTE(review): Listx.count presumably pairs each distinct element with its
   number of occurrences -- confirm in Listx. *)
let solved org_vars eqs =
  let src = sources org_vars eqs in
  let vars = vars_in_equations eqs in
  List.length src = List.length eqs
  && List.for_all
       (fun (x, n) -> (not (List.mem x src)) || n = 1)
       (Listx.count vars)


(* Turns a solved system into a substitution.  Every equation must have an
   original variable x as a whole side; it yields the binding of x to the
   other side, rebuilt as a binary term by reconstruct.
   Raises Failure on the first equation without such a side. *)
let rec build_substitution ac org_vars eqs =
  let bind x t rest =
    (x, reconstruct ac t) :: build_substitution ac org_vars rest
  in
  match eqs with
  | [] -> []
  | (V x, t) :: rest when List.mem x org_vars -> bind x t rest
  | (t, V x) :: rest when List.mem x org_vars -> bind x t rest
  | (s, t) :: _ ->
      let msg = Term.sprint s ^ " = " ^ Term.sprint t in
      failwith ("Ac_subst.build_substitution: unsolved equation " ^ msg)


(* One step of phases 3 and 4.
   eqs is the current system, used the equations already set aside, and
   solver the function used to solve the equations created by Merge (it is
   called with the variables of the new equation and that single equation).
   The result describes the successor states as (set aside, system) pairs:
   - Some []  : failure, the system has no solution (or was pruned);
   - None     : no equation is left to work on;
   - Some l   : continue with every state of l.
   The rules are tried on the first equation in a fixed order: Orient, EQE,
   Merge, Replacement.  When none applies the equation is set aside.
   Every say call only labels the branch taken for debugging and returns the
   value of its thunk.
   NOTE(review): Ls.del xs ys presumably removes the elements of ys from
   xs -- confirm in Ls. *)
let unify_syn_aux_34_do ac org_vars solver used eqs =
  (* apply rules *)
  if not (prune org_vars eqs && check eqs) then
    say "prune!" (fun () -> Some [])
  else
    match eqs with
    | [] ->
       None
    | eq :: eqs' ->
       (* attempt to apply Orient *)
       match orient (Ls.del (used@eqs) [eq]) eq with
         Some eqs1 ->
          Some [[], eqs1]
       | None ->
      (* attempt to apply EQE (this is an original modification; not in the paper) *)
      match eqe org_vars (used@eqs') eq with
        Some eqs1 ->
         (* eqe returns its equation list unchanged, so eqs1 already starts
            with used and used occurs twice in the new system.
            NOTE(review): the duplicates are presumably removed by cleanup
            on the next iteration -- confirm. *)
         say "qeq" (fun() -> Some [[], used@eqs1])
      | None ->
      (* attempt to apply Merge *)
      match merge (used@eqs') eq with
        Some (results, others) ->
         say "merged!" (fun () ->
             (* recursively solve new equation *)
             let vs es = vars_in_equations [es] in
             (* one successor state per merged equation and per solution of
                that equation *)
             Some [[], eqs2@others | es <- results;
                                          eqs2 <- solver (vs es) [es]])
      | None -> 
       (* attempt to apply Replacement *)
      match rep org_vars (Ls.del (used@eqs) [eq]) eq with
        Some (results,others) ->
         say "REPed" (fun() -> Some [[],results@others])
      | None ->
         (* otherwise go to the next *)
         say "go next" (fun() -> Some [eq::used, eqs'])


(* Phases 3 and 4: runs the loop below on the system with nothing set
   aside. *)
let rec unify_syn_aux_34 ac org_vars eqs =
  unify_syn_aux_34_loop ac org_vars [] eqs

(* Checks failure and builds solved forms.
   Stops as soon as the whole system (set-aside equations included) is in
   solved form; otherwise repeats unify_syn_aux_34_do on every successor
   state, and falls back to unify_syn_aux_34_end when no step applies. *)
and unify_syn_aux_34_loop ac org_vars used eqs0 =
  (* setup *)
  let solver = unify_syn_aux_1 ac in
  let pending = cleanup eqs0 in
  let resume (used_next, pending_next) =
    unify_syn_aux_34_loop ac org_vars used_next pending_next
  in
  if solved org_vars (used @ pending) then
    say "solved!" (fun () -> [ pending @ used ])
  else
    match unify_syn_aux_34_do ac org_vars solver used pending with
    | None -> unify_syn_aux_34_end org_vars (pending @ used)
    | Some successors -> Ls.concat_map resume successors

(* Last resort: attempts Special Variable Replacement on the whole system.
   Raises Failure when that rule does not apply either. *)
and unify_syn_aux_34_end org_vars eqs =
  match var_rep_direction org_vars eqs with
  | None -> failwith "no applicative rules"
  | Some results -> say "Special REPed" (fun () -> [ results ])


(* Restricts the substitution s to the variables of org_vars and rebuilds
   the binary term structure of the remaining right-hand sides. *)
let cleanup_substitution ac org_vars s =
  s
  |> List.filter (fun (x, _) -> List.mem x org_vars)
  |> List.map (fun (x, t) -> (x, reconstruct ac t))

(* Syntactic AC Unification.
   Flattens both sides of every equation, runs the three phases with the
   variables of the input as original variables, converts every resulting
   system into a substitution, and removes duplicates with Ls.uniq. *)
let rec unify_syn ac eqs0 =
  let vs = Ls.uniq (vars_in_equations eqs0) in
  let flat_eqs = List.map (fun (s, t) -> (flatten ac s, flatten ac t)) eqs0 in
  unify_syn_aux ac vs flat_eqs
  |> List.map (build_substitution ac vs)
  |> Ls.uniq

(* Combines unify_syn_aux_{1,2,34}: each phase is applied to every system
   produced by the previous one. *)
and unify_syn_aux ac org_vars eqs =
  let after_phase_1 = unify_syn_aux_1 ac org_vars eqs in
  let after_phase_2 =
    Ls.concat_map (unify_syn_aux_2 ac org_vars) after_phase_1
  in
  Ls.concat_map (unify_syn_aux_34 ac org_vars) after_phase_2
  


(* ----------------------------------------------------------- *)
(*                       Main Functions                        *)
(* ----------------------------------------------------------- *)

(* Returns a complete set U of unifiers such that
   s sigma =_{AC} t sigma for every sigma in U.
   Implemented as syntactic AC-unification of the single equation s = t. *)
let unify ac (s, t) =
  let problem = [ (s, t) ] in
  unify_syn ac problem


(* Replaces the variables of t by fresh constants.  Returns the substitution
   used (one binding per element of Term.variables t) together with the
   instantiated term. *)
let replace_var_by_consts t =
  let sigma =
    List.map (fun v -> (v, Term.new_const ())) (Term.variables t)
  in
  (sigma, Term.substitute sigma t)

(* Inverse of replace_var_by_consts: every constant (application without
   arguments) that is the image of a variable under sigma is turned back
   into that variable (the first matching binding wins); other constants and
   variables are left alone, and applications are traversed. *)
let rec replace_consts_by_var sigma term =
  (* First variable of sigma bound to the constant c, if any. *)
  let rec origin c = function
    | [] -> None
    | (v, image) :: rest -> if image = c then Some v else origin c rest
  in
  match term with
  | V _ -> term
  | F (_, []) -> (
      match origin term sigma with
      | Some v -> V v
      | None -> term)
  | F (f, args) -> F (f, List.map (replace_consts_by_var sigma) args)

(* Returns a (complete) set S of substitutions such that
   p sigma =_{AC} t0 for every sigma in S.
   p matches t0 if p and a ground instance of t0 are unifiable: the
   variables of t0 are frozen into fresh constants, p is unified with the
   result, and the constants are turned back into variables in every
   binding. *)
let matcher ac (p, t0) =
  let sigma, ground = replace_var_by_consts t0 in
  let thaw phi =
    List.map (fun (x, u) -> (x, replace_consts_by_var sigma u)) phi
  in
  List.map thaw (unify ac (p, ground))

(* unifiable ac (t1, t2) holds when unify finds at least one unifier. *)
let unifiable ac (t1, t2) =
  match unify ac (t1, t2) with
  | [] -> false
  | _ :: _ -> true

