open Ls
open Coll_result

(* Redundant Rules Method *)
(* Short on time, so these are defined in a provisional module for now. *)

(* Least fixed point for concat map.
   [fix0 f acc batch] returns [acc] as soon as [batch] is empty or
   [Ls.subseteq batch acc] holds; otherwise it merges [batch] into [acc]
   with [@+] and continues with the images of [batch] under [f].
   [fix f xs] starts that iteration from [xs] with the images of [xs] as
   the first batch and returns the final accumulator reversed. *)
let rec fix0 f acc batch =
  match batch with
  | [] -> acc
  | _ ->
      if Ls.subseteq batch acc then acc
      else fix0 f (batch @+ acc) (Ls.concat_map f batch)

let fix f xs = List.rev (fix0 f xs (Ls.concat_map f xs))

(* Instance of [Index.Make] for rules.  The rest of this file uses
   [I.index] to turn a rule into its index and [I.element] to get the rule
   back.  Rules are ordered with the polymorphic structural comparison. *)
module I = Index.Make (struct
  type t = Rule.t

  let compare (a : t) (b : t) = Pervasives.compare a b
end)

(* [chain1 t rules] returns, in the order of [rules], the indices
   ([I.index]) of those rules whose defined symbols (as reported by
   [Rules.defined_symbols] on the single rule) satisfy [Ls.subseteq]
   against [Term.functions t].
   The symbols of [t] do not depend on the rule being examined, so they
   are computed once per call instead of once per rule. *)
let chain1 t rules =
  let fs = Term.functions t in
  let rec select = function
    | [] -> []
    | lr :: rest ->
        let ds = Rules.defined_symbols [lr] in
        (* NOTE(review): [I.index] may assign indices on first use, so this
           is deliberately kept as a plain cons around the recursive call
           rather than turned into a filter/map pipeline. *)
        if Ls.subseteq ds fs then I.index lr :: select rest
        else select rest
  in
  select rules

(* [chain rs] maps a TRS [rs] to a list of entries [((t, i), js)] where
   - [i] is the index ([I.index]) of a rule [(l, r)] of [rs],
   - [t] is one side of that rule: [l] for the entries derived from
     [chain1L], [r] for the entries derived from [chain1R],
   - [js] are indices of rules that may possibly be used to rewrite [t]:
     the rules picked by [chain1] for [t], closed under [f i] by [fix]
     and then passed through [Ls.uniq].
   All left-hand-side entries come first, followed by all right-hand-side
   entries.
*)
let rec chain rs =
  (* For every rule: its index paired with the candidates [chain1] selects
     for the left-hand side.  [ss] is [setminus rs [l,r]], presumably [rs]
     without the rule itself -- TODO confirm against [setminus]. *)
  let chain1L = [ I.index (l,r), chain1 l ss
                | l,r <- rs;
                  ss <- [setminus rs [l,r]]] in
  (* Same as [chain1L], but the candidates are selected for the right-hand
     side of the rule. *)
  let chain1R = [ I.index (l,r), chain1 r ss
                | l,r <- rs;
                  ss <- [setminus rs [l,r]]] in
  (* Successor function used for the closure: [f i j] is empty when
     [j = i]; otherwise it is the candidate list recorded in [chain1R] for
     rule [j], with [i] taken out by [Ls.del].  [List.hd] raises if
     [chain1R] has no entry keyed by [j]. *)
  let f i j = if i = j then [] else
    Ls.del (List.hd [cs | k,cs <- chain1R; j = k]) [i] in
  (* Close every candidate list under [f i] and attach the term it belongs
     to: [fst] of the rule for [chain1L], [snd] of the rule for [chain1R]. *)
  [(fst (I.element i), i), Ls.uniq (fix (f i) cs) | i, cs <- chain1L] @
  [(snd (I.element i), i), Ls.uniq (fix (f i) cs) | i, cs <- chain1R]

(* Remove emulatable rules recursively.
   An entry [(key, deps)] of [map] is dropped when [Ls.subseteq deps rs]
   holds.  The entries that stay get [Ls.del deps rs] as their new
   dependency list, and the process restarts with the dropped keys playing
   the role of [rs].  When no entry can be dropped, [map] is returned
   unchanged. *)
let rec remove_rules rs map =
  let removable = [ key | key, deps <- map; Ls.subseteq deps rs ] in
  match removable with
  | [] -> map
  | _ :: _ ->
      let remaining =
        [ key, Ls.del deps rs
        | key, deps <- map; not (List.mem key removable) ] in
      remove_rules removable remaining

(* [merge_rules0 rs map] grows the list [rs] of rules that must be kept.
   [map] associates a rule index with a list of rule indices it depends
   on.  [merge_rules map] starts from an empty list and passes the final
   result through [Ls.uniq]. *)
let rec merge_rules0 rs map =
  (* A rule with an empty dependency list cannot be joined by other rules,
     so it is necessary, as is everything already collected in [rs]. *)
  let necessary = [ i | i, [] <- map ] @ rs in
  let reduced = remove_rules necessary map in
  (* Comparator handed to [Ls.minimal]: true when the first entry has the
     longer dependency list. *)
  let longer (_, cs1) (_, cs2) = List.length cs1 > List.length cs2 in
  match [ j, cs | j, cs <- reduced; not (Ls.subseteq cs necessary) ] with
  | [] ->
      (* Nothing left to decide. *)
      necessary
  | undecided ->
      (* Several entries are still open: keep the rule of the first entry
         reported by [Ls.minimal] and iterate on the rest. *)
      begin match Ls.minimal longer [ j, cs | j, cs <- undecided ] with
      | [] -> rs
      | (j, _) :: _ ->
          let promoted = j :: necessary in
          merge_rules0 promoted (remove_rules promoted undecided)
      end

let merge_rules map = Ls.uniq (merge_rules0 [] map)


(* [remove_redundants b_l2r b_conv k rs] returns a list of rules derived
   from [rs] from which rules found to be redundant have been left out.
   - [b_l2r]: when true, a candidate rule list is computed from the
     results of [k_join_idx] run with rules on the left side only.
   - [b_conv]: when true, the result is computed from the results of
     [k_join_idx] run with rules on both sides.
   - [k]: passed through unchanged to [Join.k_join].
   With both flags false, [rs] is returned as is.
   NOTE(review): when [b_conv] is true the result is built from [rs] and
   [chainSet] only; [rs1] is still computed but does not feed into it. *)
let rec remove_redundants b_l2r b_conv k rs =
  begin
    (* Dependencies of the lhs and rhs terms of every rule, see [chain]. *)
    let chainSet = chain rs in
    (* [minimal zs] is [] for an empty [zs]; otherwise it is what
       [Ls.minimum] selects under a comparison of list lengths, presumably
       a shortest list -- TODO confirm against [Ls.minimum]. *)
    let minimal = function
      | [] -> []
      | zs -> Ls.minimum (fun xs ys -> List.length xs > List.length ys) zs in
    (* For every rule (l, r): a minimal list of indices of other rules
       needed to rewrite l to r.  Only the left-hand-side entry of the rule
       in [chainSet] is used ([l = t]); its candidates are the only rules
       given to [k_join_idx], and the right-hand rule list is empty. *)
    let rs1 = if not b_l2r then rs else
      List.map I.element (
        merge_rules
          [ i, minimal 
              [seql | seql,_ <- k_join_idx k [ I.element i_rl | i_rl <- cs ] [] (l,r)]
          | l,r <- rs; (t,i),cs <- chainSet; i = I.index (l,r) && l = t ]
      ) in
    (* For every rule (l, r): a minimal list of indices of other rules
       needed to join l and r.  [cs] are the candidates of the
       left-hand-side entry and [ds] those of the right-hand-side entry;
       the index of the rule itself ([i] and [j] are both its index) is
       taken out of each before the search. *)
    let rs2 = if not b_conv then rs1 else
      List.map I.element (
        merge_rules
          [ i, minimal 
              [seql @+ seqr | seql,seqr <- k_join_idx
                        k
                        [ I.element i_rl | i_rl <- Ls.del cs [j] ]
                        [ I.element i_rl | i_rl <- Ls.del ds [i] ]
                        (l,r)]
          | l,r <- rs; (t,i),cs <- chainSet; i = I.index (l,r) && l = t;
                       (u,j),ds <- chainSet; j = I.index (l,r) && r = u ]
      )
    in rs2
  end
(* [k_join_idx k rs ss e] runs [Join.k_join k rs ss e] and, for every pair
   of sequences it returns, keeps only the indices ([I.index]) of the rules
   found in steps of the form [_, Some (rl, p), _], with [Ls.uniq] applied
   to each side.  The second component [p] of such a step is not used. *)
and k_join_idx k rs ss e =
  [ Ls.uniq [ I.index rl | _, Some (rl,p), _ <- seql ],
    Ls.uniq [ I.index rl | _, Some (rl,p), _ <- seqr ]
  | seql, seqr <- Join.k_join k rs ss e ]


(* Print a note on standard output, followed by both rule sets rendered
   with [Trs.sprint_rules], unless [r === r'] holds, in which case nothing
   is printed. *)
let if_reduced r r' =
  if not (r === r') then begin
    print_endline "NOTE: input TRS is reduced\n";
    Format.printf "original is\n%s@." (Trs.sprint_rules r);
    Format.printf "reduced to\n%s@." (Trs.sprint_rules r')
  end
