open Term
open Match
open Ls

(* All rewrite functions have type 'step', 'stepS' or 'stepSs' *)
(* Abbreviations used in variable names:
   rw    : rewrite
   rws   : rewrite list
*)  
type step    = t -> t list   (* ordinary one step: a term to its reducts *)
(* a rewrite (source, rules applied with their positions, reduct) *)
type rewrite = t * (Rule.t * position) list * t
type stepS   = t -> rewrite list (* one step with rules and positions *)
type stepSs  = t -> rewrite list list (* sequences of steps with rules and positions *)

(** Utilities *)
(* Print the rewrite [(src, steps, dst)] on one line and flush.
   [hd] and [tl] are drawn before and after the body of each arrow.
   When [rev] is true the reduct [dst] is printed first and [src] last;
   a rewrite with no steps is printed as a single "()" arrow. *)
let show_rewrite_aux hd tl rev (src, steps, dst) =
  let first, final = if rev then dst, src else src, dst in
  match steps with
  | [] ->
      Format.printf "%a %s--()--%s %a@." print first hd tl print final
  | _ :: _ ->
      Format.printf "%a" print first;
      steps |> List.iter (fun ((lhs, rhs), pos) ->
        Format.printf " %s-(%a,%a,%s)--%s "
          hd print lhs print rhs (Util.sprint_int_list pos) tl);
      Format.printf "%a@." print final

(* Print a rewrite from its source to its reduct. *)
let show_rewrite r = show_rewrite_aux "-" ">" false r
(* Print a rewrite from its reduct back to its source. *)
let show_rewrite_rev r = show_rewrite_aux "<" "-" true r

(* Print every rewrite sequence of [steps]: one rewrite per line, each
   sequence followed by a separator line. *)
let show_stepS steps =
  let show_sequence rws =
    List.iter show_rewrite rws;
    print_endline "-----"
  in
  List.iter show_sequence steps

(** utility **)
(* Projections of a single rewrite (source, rules and positions, reduct). *)
let source (rw : rewrite) = let (s, _, _) = rw in s
let reduct (rw : rewrite) = let (_, _, t) = rw in t
let rule_and_position (rw : rewrite) = let (_, xs, _) = rw in xs

(* The same projections lifted to lists of rewrites; rules and positions
   of all rewrites are concatenated in order. *)
let sources rws = rws |> List.map source
let reducts rws = rws |> List.map reduct
let rules_and_positions rws = rws |> Listx.concat_map rule_and_position
let rules rws     = rules_and_positions rws |> List.map fst
let positions rws = rules_and_positions rws |> List.map snd

(* Final rewrite of a sequence s1 -> ... -> sn.
   @raise Invalid_argument if the sequence is empty. *)
let rec last : rewrite list -> rewrite = function
  | [] -> invalid_arg "Step.last: there are no rewrite sequences"
  | [rw] -> rw
  | _ :: rest -> last rest

(* Reduct of the final rewrite of a sequence. *)
let last_term ss = reduct (last ss)

(* Bind for [stepS] relations (like >>=): extend every rewrite s -> t of
   [rws] with each rewrite that [arr] performs from t, concatenating the
   rule/position lists.  Duplicates are removed with [Listx.unique]. *)
let bindS (arr : stepS) rws =
  let extend (s, xs, t) =
    arr t |> List.map (fun (_, ys, u) -> (s, xs @ ys, u))
  in
  Listx.unique (Listx.concat_map extend rws)

(* Sequential composition of two [stepS] relations. *)
let (>>>) arr1 arr2 t = bindS arr2 (arr1 t)

(* Bind for [step] relations (like >>=): all reducts, under [arr], of all
   terms in [rws]. *)
let bind (arr : step) rws = rws |> Listx.concat_map arr

(* Sequential composition of two [step] relations. *)
let (>>) arr1 arr2 t = arr1 t |> bind arr2

(* Filters for binds: keep the rewrites satisfying [prop]. *)
let such_that prop (rws : rewrite list) = List.filter prop rws

(* Keep the rewrites whose list of rewrite positions satisfies [prop]. *)
let position_is (prop : position list -> bool) =
  such_that (fun rw -> positions [rw] |> prop)

(* Keep the steps of [arr t] whose variables, as reported by [variables_at]
   at the step's rewrite positions in the term [pick step], are all in
   [vs]. *)
let filter_by_variables pick vs (arr : stepS) t =
  let within_vs step =
    Listset.subset (variables_at (positions [step]) (pick step)) vs
  in
  such_that within_vs (arr t)

(* Filter rewrite steps by a variable condition on their reducts. *)
let including_variablesS vs (arr : stepS) =
  filter_by_variables reduct vs arr

let including_variables vs (arr : stepS) t =
  including_variablesS vs arr t |> reducts

(* Same filter, but the variable condition is checked on source terms;
   the reducts of the surviving steps are still what is returned by
   [rev_including_variables]. *)
let rev_including_variablesS vs (arr : stepS) =
  filter_by_variables source vs arr

let rev_including_variables vs (arr : stepS) t =
  rev_including_variablesS vs arr t |> reducts
(* Remove reflexive steps, i.e. rewrites that apply no rule. *)
let non_trivials rws =
  let applies_a_rule = function
    | (_, [], _) -> false
    | _ -> true
  in
  List.filter applies_a_rule rws

(* Translators between relation types. *)
let stepS2stepSs (arr : stepS) t = [arr t]
let stepS2step (arr : stepS) t = arr t |> reducts


(* ================================================= *)
(** all rewrite functions return list **)

(* Rewrite one step.
   Returns a list of rewrites (t, [(l,r), p], t'), i.e. t ->_{l->r,p} t',
   ordered by the rules of [trs] and, for each rule, by the order of
   [subterms_with_positions t]. *)
let stepS trs t =
  let subterms = subterms_with_positions t in
  let steps_with (l, r) =
    List.map
      (fun (s, p) -> (t, [(l, r), p], replace t (substitute (matcher l s) r) p))
      (List.filter (fun (s, _) -> instance_of s l) subterms)
  in
  List.concat (List.map steps_with trs)

(* One-step reducts of [t]. *)
let rewrite trs t = stepS trs t |> reducts

(* One-step reducts of a term obtained by rewriting exactly at [p]. *)
let rewrite_at trs p =
  let only_at_p = position_is (List.for_all (fun q -> q = p)) in
  stepS2step (fun t -> only_at_p (stepS trs t))

(* One reduct of [t] (the first produced by [rewrite]), for callers that
   need only one.
   @raise Failure if [t] has no reduct, i.e. is a normal form of [trs].
   Exceptions raised while computing the reducts are propagated unchanged
   rather than being disguised as this Failure. *)
let rewrite_a trs t =
  match rewrite trs t with
  | u :: _ -> u
  | [] -> failwith "rewrite_a: there are no reducible terms"

module Table = Map.Make (struct
  (* co-reduced term and number of applied rewrite steps *)
  type t = Term.t * int
  (* Lexicographic order: by term, then by step count.  A single
     comparison of the terms avoids walking equal terms twice (once for
     [=] and once for [compare]).  [Pervasives.compare] is spelled out
     because the opened modules may shadow [compare] (e.g. Match). *)
  let compare (s,i) (t,j) =
    match Pervasives.compare s t with
    | 0 -> Pervasives.compare i j
    | c -> c
end)

(* Memo table used by [k_stepSs], keyed by (term, number of steps). *)
let rw_table : rewrite list list Table.t ref = ref Table.empty

(* Forget every memoised result. *)
let reset_table () = rw_table := Table.empty

(* Whether [k_stepSs] consults and fills [rw_table]; off initially. *)
let _cache = ref false
let cache_on () = _cache := true
let cache_off () = _cache := false

(* Print every entry of [table]: its key, then its rewrite sequences. *)
let show_table table =
  let show_entry (t, k) res =
    Format.printf "k = %d and %a@." k print t;
    show_stepS res
  in
  Table.iter show_entry table

(* Print the global memo table. *)
let show () = !rw_table |> show_table

(* disable
(* rename to m(t) if there is key (k,i) and mgu m such that m(k) = m(t) *)
let get_renaming : t -> t -> t = function t ->
  let matching = Table.fold
      (fun (k,i) _ v -> if Match.variant k v then k else v)
      !rw_table
      t
  in substitute (Match.matcher matching t)

let renamings renm xs = List.map
    (fun (s,x,t) -> renm s, x, renm t)
    xs
*)
(* Rule lists that the entries of [rw_table] were computed with, under the
   same keys.  An entry computed with rules r0 may be reused for rules r
   only when every rule of r is a variant of some rule of r0; otherwise
   sequences using rules outside r0 would be missing from the result. *)
let rw_rules : Rule.t list Table.t ref = ref Table.empty

(* [k_stepSs0 k r s]: all rewrite sequences of exactly [k] steps from [s]
   with the rules [r], each given as the list of its single steps
   ([k = 0] gives the single trivial sequence [[s,[],s]]).  Tails are
   computed through [k_stepSs], so the cache is used when enabled.
   @raise Invalid_argument if [k < 0]. *)
let rec k_stepSs0 k r s = match k with
  | n when n < 0 -> invalid_arg "Step.k_stepS: k < 0"
  | 0 -> [[s,[],s]]
  | 1 -> List.map Listset.singleton (stepS r s)
  | _ -> Listset.unique
           (Listx.concat_map (fun (_,xs,t) ->
              List.map (fun rws -> (s,xs,t) :: rws)
                       (k_stepSs (k-1) r t))
           (stepS r s))
(* Front end of [k_stepSs0], memoised in [rw_table] under (s, k) when the
   cache is on.  A cached entry is reused only if the rules it was computed
   with cover [r] (up to variants); sequences using a rule outside [r] are
   then filtered out.  An empty (filtered) result is recomputed and
   stored again. *)
and k_stepSs k r s =
  if not !_cache then
    k_stepSs0 k r s
  else
    let covered rule rules = List.exists (variant_rule rule) rules in
    let cached =
      try
        let steps = Table.find (s,k) !rw_table
        and used = Table.find (s,k) !rw_rules in
        if List.for_all (fun rule -> covered rule used) r then
          let member = function
            | _,[],_    -> true
            | _,[x,_],_ -> covered x r
            | _ -> failwith "Step.k_stepSs: bug hits you!" in
          [ rws | rws <- steps; List.for_all member rws ]
        else []
      with Not_found -> []
    in
    match cached with
    | _ :: _ as ss -> ss
    | [] ->
        (* [reset_table] clears only [rw_table]; drop stale rule lists too *)
        if Table.is_empty !rw_table then rw_rules := Table.empty;
        let ts = k_stepSs0 k r s in
        rw_table := Table.add (s,k) ts !rw_table;
        rw_rules := Table.add (s,k) r !rw_rules;
        ts

(* Collapse each k-step rewrite sequence from [t] into a single rewrite
   (t, all rules and positions used, reduct of the final step). *)
let k_stepS k r t =
  let collapse steps =
    let final_step = List.hd (List.rev steps) in
    (t, rules_and_positions steps, reduct final_step)
  in
  List.map collapse (k_stepSs k r t)

(* Rewrite sequences from [s] of every length l in [range 0 k] (the
   0-step sequence included), without duplicates. *)
let less_k_stepSs k r s =
  let of_length l = k_stepSs l r s in
  Listx.unique (List.concat (List.map of_length (range 0 k)))

(* Simple k-step relations: reducts only. *)
let k_step k r t = k_stepS k r t |> reducts
let less_k_step k r t = less_k_stepSs k r t |> Listx.concat_map reducts


(*****************************************************)
(*****************************************************)

(* [cntx p s t] replaces the subterm of [s] at position [p] with [t]. *)
let cntx p s t = replace s t p

(* Apply the context [c] to both ends of a rewrite. *)
let add_cntx c (s, rules, t) = (c s, rules, c t)

(* Prefix every recorded rewrite position with [p]. *)
let add_pos p (s, rules, t) =
  (s, List.map (fun (rule, pos) -> (rule, p @ pos)) rules, t)

(*****************************************************)

(* Combine rewrites of the arguments of [f] into one rewrite of an
   [f]-rooted term: sources and reducts are rebuilt under [f] and the
   rule/position lists are concatenated. *)
let closed f steps =
  let before = sources steps and after = reducts steps in
  (F (f, before), rules_and_positions steps, F (f, after))

(** Parallel step with rewrite sequences (copy from rewriting.ml) **)
(* All parallel rewrites of a term: the empty step, a root step, or a
   combination of parallel rewrites of every argument (positions prefixed
   with the index given by [Listx.index]). *)
let rec parallel_stepS rs = function
  | V _ as s -> [s,[],s]
  | F (f, ss) as s ->
      let arg_steps (i, u) = List.map (add_pos [i]) (parallel_stepS rs u) in
      let inner =
        List.map (closed f) (Listx.pi (List.map arg_steps (Listx.index ss)))
      in
      Listx.unique (((s,[],s) :: root_parallel_step rs s) @ inner)
(* Root rewrites of [s] with every rule of [rs] that [s] is an instance of. *)
and root_parallel_step rs s =
  let at_root (l, r) = (s, [(l,r),[]], substitute (matcher l s) r) in
  let applicable = List.filter (fun (l, _) -> instance_of s l) rs in
  Listx.unique (List.map at_root applicable)

(* Distinct parallel-step reducts of [t]. *)
let parallel_step trs t = parallel_stepS trs t |> reducts |> Listx.unique


(*****************************************************)
(** MULTI STEP **)  

(* Multi-steps, built by three rules:
   1. a variable multi-steps only to itself;
   2. a function term multi-steps by combining multi-steps of its arguments;
   3. a redex l sigma multi-steps to r tau, where sigma(x) multi-steps to
      tau(x) for every variable x of r.
   [multi1] applies rule 1 to variables and rules 2 and 3 to function terms. *)
let rec multi1 trs s =
  match s with
  | V _ -> [s, [], s]
  | F _ -> multi2 trs s @ multi3 trs s
(* rule 2: combine multi-steps of the arguments, prefixing each recorded
   position with the argument index *)
and multi2 trs s =
  match s with
  | F (_,[]) -> [s, [], s]
  | F (f,ts) ->
      List.map
        (fun (i,t) -> List.map (add_pos [i]) (multi1 trs t))
        (Listx.ix ts)
      |> Listx.pi
      |> List.map (closed f)
  | _ -> failwith "Step.multi2: incorrect appling"
(* rule 3: rewrite [s] at the root with l -> r (matcher sigma), then
   multi-step the instance sigma(x) of every variable x of r *)
and multi3 trs s =
  match s with
  | F _ ->
      (* For a variable [x] of [r]: each multi-step sigma(x) => u as the pair
         (its rules and positions copied to every occurrence of [x] in [r],
         binding x := u).  One reduct is chosen per variable, so all
         occurrences of [x] are rewritten identically and the result is
         exactly r tau, consistent with the recorded positions. *)
      let variable_steps sigma r x =
        let occurrences = positions_of (V x) r in
        List.map
          (fun step ->
            (Listx.concat_map
               (fun p -> rule_and_position (add_pos p step))
               occurrences,
             (x, reduct step)))
          (multi1 trs (substitute sigma (V x)))
      in
      (* every combination of per-variable choices, as pairs
         (rules and positions, substitution tau) *)
      let rewrite_sigma sigma r =
        List.map (variable_steps sigma r) (Listset.unique (variables r))
        |> Listx.pi
        |> List.map (fun choices ->
               (Listset.big_union (List.map fst choices),
                List.map snd choices))
      in
      (* tau binds the variables of [r] (not of r sigma), so it must be
         applied to [r] itself; applying it to r sigma dropped the inner
         reductions *)
      matchers trs s
      |> Listx.concat_map (fun ((l,r),sigma) ->
          List.map
            (fun (xs,tau) -> (s, ((l,r),[])::xs, substitute tau r))
            (rewrite_sigma sigma r))
  | _ -> failwith "Step.multi3: incorrectly applied"

(* All multi-step rewrites of [t], without duplicates. *)
let multi_stepS trs t = multi1 trs t |> Listx.unique
(* Distinct multi-step reducts of [t]. *)
let multi_step  trs t = multi_stepS trs t |> reducts |> Listx.unique

(* xs2 is not correct reducts
let _ = multi_stepS [f[f[x]],h[x]] (f[f[g[f[f[a]];f[f[a]]]]])
        |> Listx.drop 4;;
let _ = multi_step [f[f[x]],h[x]] (f[f[g[f[f[a]];f[f[a]]]]])
let xs1 = multi_step r (F ("f",
  [F ("f",
    [F ("g", [F ("f", [F ("f", [V "$0"])]); F ("f", [F ("f", [V "$0"])])])])]))
let xs2 = Rewriting.multistep r (F ("f",
  [F ("f",
    [F ("g", [F ("f", [F ("f", [V "$0"])]); F ("f", [F ("f", [V "$0"])])])])]))
*)


(******* Property ********)

(* [t] is a normal form of [trs] when it has no one-step reduct. *)
let is_NF trs t =
  match rewrite trs t with
  | [] -> true
  | _ :: _ -> false

(* All normal forms reached from [t], one per rewrite path (duplicates
   kept).  The search is exhaustive and keeps no visited set, so it is
   slow and does not terminate if a rewrite cycle is reachable; avoid it
   where performance matters. *)
let rec normalforms trs t =
  match rewrite trs t with
  | [] -> [t]
  | terms -> List.concat (List.map (normalforms trs) terms)

(* One normal form of [t], obtained by always following the first reduct
   returned by [rewrite].  Assumes [trs] is terminating; otherwise this
   may loop forever.  Each step computes the reducts once (the previous
   version computed them twice: in [is_NF] and again in [rewrite_a]). *)
let rec normalform trs t =
  match rewrite trs t with
  | [] -> t
  | u :: _ -> normalform trs u
