open Term
open Match
open Ls

(* Every rewrite function in this file has either the 'step' or the
   'stepS' type. *)
(* Variable-name abbreviations used throughout this file:
   rw    : rewrite
   rws   : rewrite list
   steps : rewrite list list
*)
type step    = t -> t list   (* normal step: the reducts of a term *)
type rewrite = t * (rule * pos) option * t
  (* (source term, Some (rule, position) of the applied rule -- or None for
     an identity step --, target term) *)
type stepS   = t -> rewrite list list (* stepS t <=> 't -> t1 -> ... -> tn' *)

(** Utilities *)
(* Print one rewrite on stdout (via Format) as
   "s --()--> t" for an identity step, or
   "s --(l,r,pos)--> t" when rule l -> r was applied at position pos. *)
let show_rewrite (src, label, dst) =
  match label with
  | None ->
      Format.printf "%a --()--> %a@." print src print dst
  | Some ((lhs, rhs), position) ->
      Format.printf "%a" print src;
      Format.printf " --(%a,%a,%s)--> " print lhs print rhs
        (Util.sprint_int_list position);
      Format.printf "%a@." print dst
(* Print every rewrite sequence, each one followed by a "-----" line. *)
let show_stepS steps =
  let show_sequence rws =
    List.iter show_rewrite rws;
    print_endline "-----"
  in
  List.iter show_sequence steps

(* Return the final rewrite of a sequence s1 -> ... -> sn.
   Raises Invalid_argument when the sequence is empty. *)
let rec last : rewrite list -> rewrite = function
  | [] -> invalid_arg "Step.last: there are no rewrite sequences"
  | [final] -> final
  | _ :: rest -> last rest

(* Target term of the final rewrite of a sequence (raises like [last] on
   an empty sequence). *)
let last_term ss =
  let (_, _, target) = last ss in
  target

(* Keep only the rewrites that applied a rule, dropping identity steps
   (those labelled None); order is preserved. *)
let nontrivials rws =
  List.filter (function (_, Some _, _) -> true | (_, None, _) -> false) rws

(* Turn a [stepS] into a [step]: collect the target term of every rewrite in
   every sequence produced by [arr], in order, with duplicates removed by
   [uniq]. *)
let shrink (arr : stepS) s =
  let targets rws = List.map (fun (_, _, target) -> target) rws in
  uniq (List.concat (List.map targets (arr s)))

(* ================================================= *)
(** All rewrite functions below return lists. **)

(* Rewrite [t] by exactly one step.
   For every rule (l,r) of [trs] (in order) and every subterm s of [t] at
   position p (in the order given by [subterm_with_pos]) such that
   [instance_of s l], produce the singleton sequence
     [ (t, Some ((l,r),p), t') ]
   where t' is [t] with the subterm at p replaced by r instantiated with
   [matcher l s], i.e. t ->_{l,r,p} t'.
   Returns [] when no rule applies. *)
let stepS trs t =
  let redexes_of (l, r) =
    List.map (fun (s, p) -> (l, r, s, p))
      (List.filter (fun (s, _) -> instance_of s l) (subterm_with_pos t))
  in
  let one_step (l, r, s, p) =
    [ (t, Some ((l, r), p), replace t (subst (matcher l s) r) p) ]
  in
  List.map one_step (List.concat (List.map redexes_of trs))

(* Distinct one-step reducts of [t] under [trs]. *)
let rewrite trs t = shrink (stepS trs) t

(* Return one one-step reduct of [t] under [trs]: the first element of
   [rewrite trs t].
   Raises Failure "rewrite_a: there are no reducible terms" when [t] has no
   reduct (it is a normal form). Only that empty case is converted into the
   Failure; any exception raised inside [rewrite] itself propagates
   unchanged rather than being swallowed by a catch-all handler. *)
let rewrite_a trs t =
  match rewrite trs t with
  | reduct :: _ -> reduct
  | [] -> failwith "rewrite_a: there are no reducible terms"

module Table = Map.Make (struct
  (* Key: a term paired with a number of rewrite steps; [k_stepS] looks
     entries up by (s, k). *)
  type t = Term.t * int
  (* Lexicographic order: structural order on the terms first, the step
     count breaks ties. One structural comparison decides both "equal?"
     and "which is smaller?", so unequal terms are traversed only once.
     [Pervasives.compare] is written out explicitly because the opened
     [Match] module appears to provide its own [compare]. *)
  let compare (s,i) (t,j) =
    match Pervasives.compare s t with
    | 0 -> Pervasives.compare i j
    | c -> c
end)

(* Memo table used by [k_stepS]: maps (term, k) to the k-step rewrite
   sequences recorded for that term. *)
let rw_table = ref (Table.empty : rewrite list list Table.t)

(* Forget every memoised entry. *)
let reset_table () = rw_table := Table.empty

(* Whether [k_stepS] consults and fills [rw_table]; initially off. *)
let _cache = ref false
let cache_on () = _cache := true
let cache_off () = _cache := false

(* Print every entry of a memo table: its step count k and term, then its
   rewrite sequences via [show_stepS]. *)
let show_table table =
  let show_entry (term, k) sequences =
    Format.printf "k = %d and %a@." k print term;
    show_stepS sequences
  in
  Table.iter show_entry table

(* Print the current contents of [rw_table]. *)
let show () = show_table !rw_table

(* Disabled code, kept for reference:
(* rename to m(t) if there is key (k,i) and mgu m such that m(k) = m(t) *)
let get_renaming : t -> t -> t = function t ->
  let matching = Table.fold
      (fun (k,i) _ v -> if Match.variant k v then k else v)
      !rw_table
      t
  in subst (Match.matcher matching t)

let renamings renm xs = List.map
    (fun (s,x,t) -> renm s, x, renm t)
    xs
*)
(* k_stepS' k r s: rewrite sequences of exactly k steps from [s] under the
   rules [r], computed directly (the cache is not consulted at this level).
   - k = 0: the single identity sequence [[s,None,s]].
   - k = 1: delegates to [stepS].
   - otherwise: every one-step reduct t of [s] is prefixed to every
     (k-1)-step sequence from t, obtained through [k_stepS] (so the cache
     may be used for the tail); duplicates are removed by [uniq].
   NOTE(review): a negative k never reaches the 0 or 1 case; the recursion
   only stops when [stepS] yields no reducts, so it may not terminate for a
   non-terminating TRS. Callers presumably pass k >= 0. *)
let rec k_stepS' k r s = match k with
  | 0 -> [[s,None,s]] (* better than [[]] ? *)
  | 1 -> stepS r s
  | _ -> uniq [ (s,arrow,t) :: ts
                | reducts   <- stepS r s;
                  _,arrow,t <- reducts;
                  ts        <- k_stepS (k-1) r t ]
(* k_stepS k r s: same sequences as [k_stepS' k r s], memoised in [rw_table]
   when caching is enabled ([cache_on]).
   On a hit for key (s,k), the stored sequences are filtered down to those
   in which every applied rule satisfies [variant_rule x] for some rule of
   [r] (presumably "is a renaming of"); identity steps (None) always pass.
   If nothing survives -- or the key is absent -- the sequences are
   recomputed with [k_stepS'] and stored under (s,k), replacing any older
   entry.
   NOTE(review): the key does not include the rule set. If an entry was
   computed with a rule set lacking some rule of [r], the filtered entry can
   be non-empty yet miss sequences using that rule. Confirm callers use a
   single rule set while the cache is on ([reduce_by] swaps in an empty
   table for its own run).
   NOTE(review): an entry equal to [] (no k-step sequence exists) never
   survives the filter, so such terms are recomputed on every call. *)
and k_stepS k r s =
  if not !_cache then
    k_stepS' k r s
  else
    try
      let steps = Table.find (s,k) !rw_table
      (* a rewrite is acceptable if it is an identity step or uses a rule
         that is a variant of one in [r] *)
      and member = function
        | _,None,_       -> true
        | _,Some (x,_),_ -> List.exists (variant_rule x) r in
      (* let renm = get_renaming s in *)
      (* match [ renamings renm rws | rws <- steps; List.for_all member rws] with *)
      match [ rws | rws <- steps; List.for_all member rws] with
      | [] -> raise Not_found (* treat "nothing usable" as a cache miss *)
      | ss -> ss
    with Not_found ->
      let ts = k_stepS' k r s in
      rw_table := Table.add (s,k) ts !rw_table;
      ts

(* Rewrite sequences from [s] of every length l in [range 0 k] (this
   includes the 0-step identity sequence), without duplicates. The lengths
   are processed in the order [range] returns them. *)
let k_stepSs k r s =
  let sequences_of_length l = k_stepS l r s in
  uniq (List.concat (List.map sequences_of_length (range 0 k)))

(* Term-level versions of [k_stepS] and [k_stepSs]: every target term that
   appears in the computed sequences, deduplicated by [shrink]. *)
let k_step k r s = shrink (k_stepS k r) s
let k_steps k r s = shrink (k_stepSs k r) s


(** MULTI STEP **)  
(* [cntx p s] is the context "s with a hole at p": applied to t it gives
   [replace s t p]. *)
let cntx p s t = replace s t p

(* Wrap both ends of a rewrite with the context function [c]. *)
let add_cntx c (s, rule, t) = (c s, rule, c t)

(* Prefix the position of the applied rule with [p]; identity steps (None)
   are returned unchanged. *)
let add_pos p (s, rule, t) =
  match rule with
  | None -> (s, rule, t)
  | Some (lr, pos) -> (s, Some (lr, p @ pos), t)

(* Lift a rewrite sequence into the term [t] at position [p]. Each rewrite
   has both ends plugged into [cntx p t] and its rule position prefixed with
   [p]. The context then advances: the next rewrite is lifted into [t] with
   the current rewrite's target plugged in at [p]. *)
let rec update p t = function
  | [] -> []
  | (s, rule, u) :: rest ->
      let plug = cntx p t in
      let lifted = add_pos p (add_cntx plug (s, rule, u)) in
      lifted :: update p (plug u) rest

(* integrate : t -> (rewrite list * pos) list list -> rewrite list list
   Each element of [product xss] (presumably one (sequence, position) choice
   per inner list -- confirm in Ls) becomes one rewrite sequence in [s]. The
   chosen pieces are lifted with [update] one after another, each into the
   end term ([last_term]) of the previously lifted piece, starting from [s].
   The lifted pieces are then concatenated. *)
let integrate (s : t) (xss : (rewrite list * pos) list list) =
  let rec chain ctx = function
    | [] -> []
    | (rws, p) :: rest ->
        let lifted = update p ctx rws in
        lifted :: chain (last_term lifted) rest
  in
  let combine pieces = List.concat (chain s pieces) in
  List.map combine (product xss)

(** wrong result **)
(* NOTE(review): the header above flags this multi-step implementation as
   giving wrong results; the comments below describe the code as written,
   not a verified specification.
   multi1 trs s (1st rule): multi-step rewrite sequences from [s]. A
   variable only has the identity sequence [[s, None, s]]. A function term
   gets the sequences of [multi2] (rewriting inside the arguments) followed
   by those of [multi3] (a rewrite at the root). *)
let rec multi1 trs s =
(* 1st rule *)
  match s with
  | V x      -> [[s, None, s]]
  | F (f,ts) -> multi2 trs s @ multi3 trs s
(* 2nd rule *)
(* multi2 trs s: a constant F (f,[]) only has the identity sequence.
   Otherwise every argument t, paired with i by [indexing] (presumably its
   argument index -- confirm in Ls), is reduced by [multi1]. Each resulting
   sequence is tagged with the position [i], and [integrate] lifts and
   chains one sequence per argument inside [s].
   Applying it to a variable raises Failure. *)
and multi2 trs s =
  match s with
  | F (f,[]) -> [[s, None, s]]
  | F (f,ts) ->
      (* add position and context to rewrite sequence *)
      (* let update_list p t = [ update p s ss | ss <- multi1 trs t ] in *)
      integrate s [[steps,[i] | steps <- multi1 trs t] | i,t <- indexing ts]
  | _ -> failwith "Step.multi2: incorrectly apply"
(* 3rd rule *)
(* multi3 trs s: for every rule (l,r) with substitution sigma returned by
   [matchers trs s], the sequence starts with the root step
   (s, Some ((l,r),[]), t) where t = subst sigma r. It continues with one
   [multi1] sequence of subst sigma (V x) for every variable x of r and
   every position p of V x in r. Each such sequence is lifted into t at p
   by [update]; [product] picks one sequence per occurrence and the picks
   are concatenated.
   NOTE(review): every piece is lifted into the same context t instead of
   into the term produced by the preceding piece (contrast [integrate],
   which chains through [last_term]), so consecutive pieces need not
   connect -- a possible cause of the wrong results flagged above.
   Applying it to a variable raises Failure. *)
and multi3 trs s =
  match s with
  | F _ ->
      (* recursive application *)
      let taus t = multi1 trs t in
      (* add context and position *)
      let update_list context p t = [ update p context ss | ss <- taus t] in
      (* sigma is matcher of s and l *)
      let rewrite_sigma c sigma r =
      List.map List.concat @@
        product [ update_list c p (subst sigma v)
                  | x <- vars r;
                    v <- [V x];
                    p <- pos_in v r ]
      in
      [ (s, Some ((l,r),[]), t) :: steps
        | (l,r),sigma <- matchers trs s;
          t           <- [subst sigma r];
          steps       <- rewrite_sigma t sigma r ]
  | _ -> failwith "Step.multi3: incorrectly applied"

(* Distinct multi-step rewrite sequences from [t] (built on [multi1], which
   this file marks "(** wrong result **)"), and their term-level version. *)
let multi_stepS trs t = uniq (multi1 trs t)
let multi_step trs t = shrink (multi_stepS trs) t



(** elimination **)
(* let cut rs s t = List.mem t (k_steps !Ref.k rs s) *)
(* Apply [arr] to every term of [ss] and collect all results, without
   duplicates: the next frontier of a breadth-first search. *)
let steps arr ss = uniq (List.concat (List.map arr ss))

(* cutting k arr ss t: is [t] in the frontier [ss] or in one of the next
   k-1 frontiers obtained by repeatedly applying [steps arr]?
   Always false when k < 1. *)
let rec cutting k arr ss t =
  k >= 1 && (List.mem t ss || cutting (k - 1) arr (steps arr ss) t)

(* cut k arr s t: is [t] reachable from [s] in at most k-1 [arr]-steps
   (t = s counts as 0 steps)?
   NOTE(review): the bound is k-1, one fewer than k -- confirm intended. *)
let cut k arr s t = cutting k arr [s] t

(* Remove redundant rules. Rules are examined in order. A rule l -> r is
   dropped when [cut k] finds r reachable from l using [step] built from the
   other rules still in play: those kept so far (in reverse order of
   examination) followed by those not yet examined. This helps termination,
   e.g. for k >= 1 it eliminates rules whose two sides are identical.
   The kept rules are returned in their original order. *)
let reduce0 k step r =
  let rec keep_needed kept = function
    | [] -> kept
    | ((lhs, rhs) as rule) :: rest ->
        if cut k (step (kept @ rest)) lhs rhs
        then keep_needed kept rest
        else keep_needed (rule :: kept) rest
  in
  List.rev (keep_needed [] r)

(* Run [reduce0 k step] on [rs] (wrapped in [Console.watcher "reducing"])
   against an empty memo table.
   Elimination is a meta-level operation, so its caching must neither see
   nor pollute the caller's [rw_table]. The caller's table is saved first
   and restored afterwards, including when the reduction raises; in that
   case the exception is re-raised after the restore. *)
let reduce_by k step rs =
  let saved = !rw_table in
  reset_table ();
  let result =
    try Console.watcher "reducing" (reduce0 k step) rs
    with e ->
      rw_table := saved;
      raise e
  in
  rw_table := saved;
  result

(* [reduce_by] with one-step rewriting ([rewrite]) as the step function. *)
let reduce k = reduce_by k rewrite


(* When the reduced rules [r'] differ from the original [r] (according to
   [===]), print a note followed by both rule lists; otherwise print
   nothing. *)
let if_reduced r r' =
  if not (r === r') then begin
    print_endline "NOTE: input TRS is reduced\n";
    Format.printf "original is\n%s@." (Trs.sprint_rules r);
    Format.printf "reduced to\n%s@." (Trs.sprint_rules r')
  end


(******* Property ********)

(* [t] is a normal form of [trs] iff it has no one-step reduct. *)
let is_NF trs t =
  match rewrite trs t with
  | [] -> true
  | _ :: _ -> false

(* All normal forms reached from [t], one per rewrite path, so the result
   may contain duplicates; it may not terminate if [trs] does not.
   Avoid it where possible: it explores every path and is slow. *)
let rec normalforms trs t =
  match rewrite trs t with
  | [] -> [t]
  | reducts -> List.concat (List.map (normalforms trs) reducts)

(* One normal form of [t], obtained by always following the first one-step
   reduct. Assumes [trs] is terminating; otherwise it may loop forever.
   The reducts of each term are computed once and matched on. Testing with
   [is_NF] and then picking with [rewrite_a] would compute them twice. *)
let rec normalform trs t =
  match rewrite trs t with
  | [] -> t
  | reduct :: _ -> normalform trs reduct
