open Ls
open Format
open Term

(** Printing **)
module E = struct
  type t = rule

  (* Ordering on rules: the [compare] in scope at this point, eta-expanded. *)
  let compare a b = compare a b
end

(* Rule numbering: [I.index] is what [sort] orders by and what
   [sprint_rule] prints as the rule number. *)
module I = Index.Make (E)

(* [sort trs] returns [trs] ordered by rule number ([I.index]).

   Every rule is numbered up front, in list order, before sorting.  This
   assumes, as this code does, that [I.index] hands out increasing numbers
   to rules it has not seen yet.  The numbering cannot be left to the
   comparison function.  OCaml does not specify the evaluation order of
   function arguments, and the compilers go right to left.  So in
   [compare (I.index a) (I.index b)], two fresh rules would be numbered [b]
   first, making (index a) > (index b).  [List.iter] is documented to apply
   its function in list order. *)
let sort trs =
  List.iter (fun rule -> ignore (I.index rule : int)) trs;
  List.sort (fun a b -> compare (I.index a) (I.index b)) trs

(** [sprint_rule rule] renders [rule] as ["  N: l -> r"] followed by a
    newline (the [@.] flush).  [N] is [I.index rule], right-aligned in two
    columns, and both sides are printed with [sprint]. *)
let sprint_rule rule =
  let lhs, rhs = rule in
  sprintf "  %2d: %s -> %s@." (I.index rule) (sprint lhs) (sprint rhs)

(** [sprint_rules trs] renders the rules of [trs] ordered by rule number
    (see [sort]), one [sprint_rule] line per rule.  An empty [trs] gives
    [""]. *)
let sprint_rules trs =
  (* [String.concat] builds the result in one pass.  The previous fold
     re-copied the whole accumulated string for every rule (quadratic). *)
  String.concat "" (List.map sprint_rule (sort trs))

(** [print_rules trs] writes [sprint_rules trs] to stdout, then one more
    newline (from [print_endline]). *)
let print_rules trs =
  trs |> sprint_rules |> print_endline

(** [sprint_overlap _ppf ((a, _), p, (b, _))] renders an overlap as the
    line ["a, p, b"] followed by a newline.  [a] and [b] are printed with
    [sprint] and the position [p] with [sprint_pos].  The first argument is
    never used; it appears to be a leftover and is kept only so the
    signature stays unchanged. *)
let sprint_overlap _ppf overlap =
  let (a, _), pos, (b, _) = overlap in
  sprintf "%s, %s, %s@." (sprint a) (sprint_pos pos) (sprint b)

(** [sprint_overlaps laps] concatenates [sprint_overlap] over [laps], one
    line per overlap, in list order.  An empty [laps] gives [""]. *)
let sprint_overlaps laps =
  (* [sprint_overlap] ignores its first argument, so pass [()] rather than
     the meaningless format literal used before.  [String.concat] avoids
     re-copying the accumulator on every step as repeated [^] did. *)
  String.concat "" (List.map (sprint_overlap ()) laps)

(** [print_overlaps laps] writes [sprint_overlaps laps] to stdout, then one
    more newline (from [print_endline]). *)
let print_overlaps laps =
  laps |> sprint_overlaps |> print_endline

(* [sprint_cp (a, p, t, b)] renders the peak written a <-^p t -> b
   (presumably a critical pair) as the line ["a <- p, t -> b"] followed by
   a newline.  Terms are printed with [sprint], the position with
   [sprint_pos]. *)
let sprint_cp cp =
  let a, p, t, b = cp in
  sprintf "%s <- %s, %s -> %s@."
    (sprint a) (sprint_pos p)
    (sprint t) (sprint b)

(** [print_cp cp] writes [sprint_cp cp] to stdout, then one more newline
    (from [print_endline]). *)
let print_cp cp = print_endline @@ sprint_cp cp


(********************************************************)
(** Symbols **)

(** [vars trs] lists the variables occurring on either side of any rule in
    [trs], deduplicated with [uniq]. *)
let vars trs =
  uniq
    (List.concat
       (List.map (fun (lhs, rhs) -> Term.vars lhs @ Term.vars rhs) trs))

(** [funs trs] lists the function symbols occurring on either side of any
    rule in [trs], deduplicated with [uniq]. *)
let funs trs =
  uniq
    (List.concat
       (List.map (fun (lhs, rhs) -> Term.funs lhs @ Term.funs rhs) trs))

(** [def rule] is the root symbol of [rule]'s left-hand side, i.e. the
    symbol the rule defines. *)
let def = function lhs, _ -> root lhs

(** [defs trs] lists the defined symbols of [trs] (roots of the left-hand
    sides), deduplicated with [uniq]. *)
let defs trs = uniq (List.map def trs)

(** [cons rule] lists the function symbols of [rule] other than [def rule],
    via [-\\] (presumably list difference) and deduplicated with [uniq]. *)
let cons rule =
  let root_symbol = def rule in
  uniq (funs [ rule ] -\\ [ root_symbol ])

(** [conss trs] is [funs trs] minus [defs trs]: the symbols of [trs] that
    are not the root of any left-hand side. *)
let conss trs =
  let defined = defs trs in
  funs trs -\\ defined


(** Subterms **)

(** [subterms trs] lists every subterm of both sides of every rule in
    [trs], deduplicated with [uniq].  This binding is not recursive, so the
    [subterms] called inside is the term-level one already in scope. *)
let subterms trs =
  trs
  |> List.map (fun (lhs, rhs) -> subterms lhs @ subterms rhs)
  |> List.concat
  |> uniq


(** Substitution **)

(** [subst sigma trs] applies [sigma] to both sides of every rule in [trs],
    preserving rule order.  This binding is not recursive, so the inner
    [subst] is the term-level one already in scope. *)
let subst sigma trs =
  List.map (fun (l, r) -> (subst sigma l, subst sigma r)) trs

(** [rename trs] applies [Term.rename] to both sides of every rule in
    [trs], preserving rule order.  Each side is renamed by its own call.
    NOTE(review): this relies on [Term.rename] renaming a given variable
    the same way in both calls; otherwise the right-hand side's variables
    would no longer match the left-hand side's.  Confirm. *)
let rename trs =
  List.map (fun (l, r) -> (Term.rename l, Term.rename r)) trs

(** [replace_symbol sigm trs] rewrites both sides of every rule in [trs]
    with [Term.replace_symbol sigm], preserving rule order. *)
let replace_symbol sigm trs =
  let replace = Term.replace_symbol sigm in
  List.map (fun (lhs, rhs) -> (replace lhs, replace rhs)) trs


(** Predicate **)
(** [trs_cond1 trs] holds when no left-hand side in [trs] is a variable. *)
let trs_cond1 trs =
  not (List.exists (fun (lhs, _) -> is_var lhs) trs)

(** [trs_cond2 trs] holds when, for every rule, the variables of the
    right-hand side are among those of the left-hand side ([subseteq]). *)
let trs_cond2 trs =
  let closed (lhs, rhs) = subseteq (Term.vars rhs) (Term.vars lhs) in
  List.for_all closed trs

(** [is_trs trs] holds when [trs] passes [trs_cond1] and then
    [trs_cond2]; the second check is skipped if the first fails. *)
let is_trs trs =
  if trs_cond1 trs then trs_cond2 trs else false

(** [linear trs] holds when both sides of every rule are linear
    ([Term.linear]). *)
let linear trs =
  let both_linear (lhs, rhs) = Term.linear lhs && Term.linear rhs in
  List.for_all both_linear trs

(** [left_linear trs] holds when every left-hand side is linear. *)
let left_linear trs =
  not (List.exists (fun (lhs, _) -> not (Term.linear lhs)) trs)

(** [right_linear trs] holds when every right-hand side is linear. *)
let right_linear trs =
  not (List.exists (fun (_, rhs) -> not (Term.linear rhs)) trs)

(** [non_erase trs] holds when, for every rule, the variables of the
    left-hand side are among those of the right-hand side ([subseteq]),
    i.e. no rule erases a variable. *)
let non_erase trs =
  let keeps_vars (lhs, rhs) = subseteq (Term.vars lhs) (Term.vars rhs) in
  List.for_all keeps_vars trs

(** Initialize **)
(** [trs_cond1' trs] lists, in order, the rules of [trs] that fail
    [trs_cond1] (their left-hand side is a variable). *)
let trs_cond1' trs =
  List.filter (fun rule -> not (trs_cond1 [ rule ])) trs

(** [trs_cond2' trs] lists, in order, the rules of [trs] that fail
    [trs_cond2] (a right-hand-side variable is missing from the left). *)
let trs_cond2' trs =
  List.filter (fun rule -> not (trs_cond2 [ rule ])) trs

(** [init trs] prepares a rule list for use.

    The rules are renamed with [rename], and every renamed rule is
    registered with [I.index] in list order.  Rule numbers therefore
    presumably follow input order.  The renamed rules are then checked:

    - rules whose left-hand side is a variable ([trs_cond1'])
    - rules with a right-hand-side variable missing from the left-hand side
      ([trs_cond2'])

    Any such rules are reported on stderr, in that order, and the program
    terminates with [exit 1].  Otherwise the renamed rules are returned. *)
let init trs =
  let rr = rename trs in
  (* Add rule indices; [List.iter] is documented to go in list order. *)
  List.iter (fun rule -> ignore (I.index rule : int)) rr;
  (* Each reporter prints nothing when given no offending rules. *)
  let error1 = function
    | [] -> ()
    | rs ->
        Format.eprintf "ERROR\nleft-hand-side is variable\n%s@."
          (sprint_rules rs)
  and error2 = function
    | [] -> ()
    | rs ->
        Format.eprintf "ERROR\nfree variable exists in right-hand side\n%s@."
          (sprint_rules rs)
  in
  match trs_cond1' rr, trs_cond2' rr with
  | [], [] -> rr
  | bad_lhs, bad_rhs ->
      error1 bad_lhs;
      error2 bad_rhs;
      exit 1
