open Ls
open Format

(* First-order terms: a variable [V x], or a symbol applied to a list of
   arguments [F (f, args)]. Constants are represented as [F (c, [])]. *)
type t = 
  | V of string
  | F of string * t list

(* A rule is a pair of terms — presumably (lhs, rhs); TODO confirm
   orientation with the rewriting code. *)
type rule = (t * t)
(* A term rewriting system: a list of rules. *)
type trs  = rule list


(***** PRINT FUNCTIONS *****)
(* Build a new Format formatter writing to stdout (a fresh one per call). *)
let ppf_stdout () = stdout |> formatter_of_out_channel

(* [print_list pr sep ppf items] prints each element of [items] with [pr],
   emitting the format [sep] between consecutive elements. An empty list
   prints nothing. *)
let print_list pr sep ppf items =
  match items with
  | [] -> ()
  | first :: rest ->
      pr ppf first;
      List.iter (fun item -> fprintf ppf (sep ^^ "%a") pr item) rest

(* Pretty-print a term: a variable as its bare name, an application as
   f(t1,...,tn). Constants therefore print with empty parentheses: c(). *)
let rec print ppf term =
  match term with
  | V name -> pp_print_string ppf name
  | F (sym, args) -> fprintf ppf "%s(%a)" sym (print_list print ",") args

(* Render [t] to a string with the printer [pp], going through the shared
   [str_formatter] buffer and returning its flushed contents. *)
let sprint_pp pp t =
  ignore (pp str_formatter t);
  flush_str_formatter ()

(* Render a term as a string using [print]. *)
let sprint t = sprint_pp print t

(* Same as [sprint] but with a leading unit argument, presumably so it can
   be used as a %a printer with sprintf. *)
let sprint_ () = sprint

(********************************************************
|___|___|___|___|___|___|___|___|___|___|___|___|___|___|
__|___|___|___|___|___| WALL _|___|___|___|___|___|___|__
|___|___|___|___|___|___|___|___|___|___|___|___|___|___|
********************************************************)

(** generator **)
(* Index tables built with the project functor [Index.Make] over strings.
   [VI.index] numbers variable names and [CI.index] numbers constant names
   (both return ints, see fresh_var_s / fresh_const_s). Presumably the same
   string always gets the same index — confirm in Index. *)
module VI = Index.Make (String)
module CI = Index.Make (String)

(* Names derived from the index tables: "x<N>" for variables and "c<N>" for
   constants, where N is the index that VI / CI assigns to [x]. *)
let fresh_var_s x = "x" ^ string_of_int (VI.index x)
let fresh_const_s x = "c" ^ string_of_int (CI.index x)

(* Wrap those names as a variable term or a nullary (constant) term. *)
let fresh_var x =
  let name = fresh_var_s x in
  V name
let fresh_const x =
  let name = fresh_const_s x in
  F (name, [])

(* Counters used to invent new names. Each incr function bumps its counter
   and returns the updated value, so the first call yields 1. *)
let v_count = ref 0
let v_incr () =
  incr v_count;
  !v_count

let c_count = ref 0
let c_incr () =
  incr c_count;
  !c_count

(* Build a brand-new variable / constant. Bump the matching counter and
   push the synthetic name nx<k> / nc<k> through the index-based naming. *)
let new_var () =
  let seed = sprintf "nx%d" (v_incr ()) in
  V (fresh_var_s seed)
let new_const () =
  let seed = sprintf "nc%d" (c_incr ()) in
  F (fresh_const_s seed, [])

(** predicate1 **)
(* True when every boolean in the list is true (vacuously true for []). *)
let forall bools = not (List.exists not bools)

(* [is_var t] holds exactly when [t] is a variable. *)
let is_var t =
  match t with
  | V _ -> true
  | F _ -> false

(* [is_const t] holds for a symbol applied to no arguments. *)
let is_const t =
  match t with
  | F (_, []) -> true
  | V _ | F (_, _ :: _) -> false

(** basic functions **)
(* [vars_raw t] lists every variable occurrence of [t] left to right,
   duplicates included. [vars] removes duplicates via [uniq]. *)
let rec vars_raw term =
  match term with
  | V name -> [name]
  | F (_, args) ->
      List.fold_right (fun arg acc -> vars_raw arg @ acc) args []
let vars term = term |> vars_raw |> uniq

(* [funs_raw t] lists the function symbols of [t] in pre-order. Below a
   node, occurrences of that node's own symbol are dropped, although other
   duplicates may remain. [funs] removes all duplicates via [uniq]. *)
let rec funs_raw term =
  match term with
  | V _ -> []
  | F (f, args) ->
      let below = List.concat (List.map funs_raw args) in
      f :: List.filter (fun g -> g <> f) below
let funs term = term |> funs_raw |> uniq

(* All symbols of a term: its function symbols followed by its variables. *)
let symbols t = List.append (funs t) (vars t)

(* The name at the top of a term: the variable name or the head symbol. *)
let root t =
  match t with
  | V name | F (name, _) -> name

(* The head function symbol, or [None] for a variable. *)
let root_f t =
  match t with
  | F (f, _) -> Some f
  | V _ -> None

(* [top_is f t] holds when [t] is an application headed by symbol [f]. *)
let top_is f t =
  match t with
  | F (g, _) -> g = f
  | V _ -> false

(*********************************************)
(************* about substitution ************)

(* [subst s t] applies the substitution [s], an association list from
   variable names to terms, to [t]. When a variable is bound more than once,
   the first binding wins. Unbound variables are left as they are. *)
let rec subst s t =
  match t with
  | V x -> (try List.assoc x s with Not_found -> t)
  | F (f, args) -> F (f, List.map (subst s) args)

(* Rename symbol [f] according to [sigm]. The first matching pair wins, and
   [f] is returned unchanged when [sigm] does not mention it. *)
let replace_sym f (sigm : (string * string) list) =
  try List.assoc f sigm with Not_found -> f
(* Rename every function symbol of a term through [replace_sym];
   variables are rebuilt unchanged. *)
let rec replace_symbol sigm term =
  match term with
  | V x -> V x
  | F (f, args) ->
      let args' = List.map (replace_symbol sigm) args in
      F (replace_sym f sigm, args')


(* Return a variant of [t]: every distinct variable of [t] is replaced by a
   variable obtained from [new_var] (the same fresh variable for all
   occurrences of one original variable, since a single substitution is
   built). Freshness relies on [new_var] producing unused names — presumably
   guaranteed by the counter plus Index; confirm.
   NOTE(review): [vars] already applies [uniq], so the extra [uniq] here is
   presumably redundant. The order in which [new_var] is invoked across the
   variables depends on how the comprehension is expanded; leave this
   expression as-is if exact fresh-variable numbering matters. *)
let rename t = subst
                 [x, new_var () | x <- uniq (vars t)]
                 t

(* Rename a list of terms with one shared renaming. The terms become the
   arguments of a dummy "" symbol, are renamed together with [rename], and
   are then unwrapped. A variable occurring in several of the terms
   therefore receives the same fresh name in each. *)
let renames ts =
  let wrapped = F ("", ts) in
  match rename wrapped with
  | F ("", renamed) -> renamed
  | _ -> failwith "Term.rename: impossible"

(** SUBTERM **)
(* [p_subterms t] lists the proper subterms of [t] in pre-order, left to
   right. [subterms t] additionally puts [t] itself first. *)
let rec p_subterms t =
  match t with
  | V _ -> []
  | F (_, args) -> List.concat (List.map subterms args)
and subterms t = t :: p_subterms t

(* Pair every subterm of [t] with its position, in pre-order. The root has
   position []. Argument number i of a node at position p sits at
   p @ [i], with the indices supplied by [range_of 1] (presumably 1..n). *)
let subterm_with_pos t =
  let rec walk u pos =
    match u with
    | V _ | F (_, []) -> [(u, pos)]
    | F (_, args) ->
        let indexed = List.combine args (range_of 1 args) in
        let below =
          List.map (fun (arg, i) -> walk arg (pos @ [i])) indexed
        in
        (u, pos) :: List.concat below
  in
  walk t []

(* [subterm_of s t] holds when [t] is a subterm of [s], including the case
   [t = s]. The function is not recursive, so the former [rec] flag only
   triggered an unused-rec warning. *)
let subterm_of s t = List.mem t (subterms s)

(** POSITION **)
(* A position is the path of argument indices from the root to a subterm.
   Indices start at 1, and the root position is the empty list. *)
type pos = int list

(* Render a position as its indices written back to back, e.g. [1;2] is
   "12", with the root [] rendered as "0". There is no separator, so
   positions using indices >= 10 can collide: [1;23] and [12;3] both give
   "123". *)
let sprint_pos p =
  match p with
  | [] -> "0"
  | _ :: _ -> String.concat "" (List.map string_of_int p)

(* Print a position (see [sprint_pos]) on its own line to stdout. *)
let print_pos p = p |> sprint_pos |> print_endline

(* All positions of [t], in pre-order. *)
let poss t = List.map snd (subterm_with_pos t)

(* Positions of [t] whose subterm is not a variable. *)
let poss_fun t =
  subterm_with_pos t
  |> List.filter (fun (u, _) -> not (is_var u))
  |> List.map snd

(* Positions at which the term [s] occurs in [t]. *)
let pos_in s t =
  subterm_with_pos t
  |> List.filter (fun (u, _) -> u = s)
  |> List.map snd

(* [subterm_at t p] returns the subterm of [t] at position [p]. Argument
   indices are 1-based, and [] denotes [t] itself.
   @raise Failure if [p] is not a position of [t]. An argument index that is
   out of range, including 0 or a negative index, gives the descriptive
   "invalid position" message for the current subterm. *)
let rec subterm_at t p =
  let invalid () =
    failwith @@
      sprintf "Term.subterm_at: invalid position ==> term %s and position %s"
        (sprint t) (Util.sprint_int_list p)
  in
  (* List.nth raises Failure when the list is too short, but Invalid_argument
     when the index is negative. A 0 component used to escape the Failure
     handler with an unhelpful error, so reject negative indices explicitly. *)
  let safe_nth xs n =
    if n < 0 then invalid ()
    else try List.nth xs n with Failure _ -> invalid ()
  in
  match p, t with
  | [], _           -> t
  | n::ns, F (_,ts) -> subterm_at (safe_nth ts (n-1)) ns
  | _               -> failwith "Term.subterm_at: index too large"

(* t[s]p : [replace t s p] is [t] with its subterm at position [p] replaced
   by [s]. Argument indices are 1-based and [] is the root, so
   [replace t s [] = s].
   @raise Failure if [p] is not a position of [t]. *)
let rec replace t s p = match t, p with
  | _, [] -> s
  | F (f, ts), n :: ns ->
      (* Check the index before descending. An out-of-range n used to
         surface as a bare List.nth error (Failure "nth" or Invalid_argument)
         rather than a Term.replace failure. *)
      if n < 1 || n > List.length ts then
        failwith "Term.replace: invalid position"
      else
        (* rebuild the arguments, descending only into the n-th one *)
        F (f, List.mapi
                (fun i ti -> if i = n - 1 then replace ti s ns else ti)
                ts)
  | V _, _ :: _ -> failwith "Term.replace: index too large"


(** predicate2 **)
(* A term is linear when no variable occurs in it more than once. *)
let linear t =
  let rec all_distinct xs =
    match xs with
    | [] -> true
    | y :: rest -> if List.mem y rest then false else all_distinct rest
  in
  all_distinct (vars_raw t)

(* [var_clash (x, t) sigma] is true when [sigma] also binds [x] to a term
   different from [t]. *)
let var_clash (x, t) sigma =
  List.exists (fun (y, s) -> x = y && t <> s) sigma

(* A binding list is consistent when no variable is bound to two different
   terms. *)
let consist sigma =
  List.for_all (fun binding -> not (var_clash binding sigma)) sigma
