open Ls
open Term
open Step

(************ JOINABILITY ************)

(** We assume that every rewrite rule stored in Table belongs to the TRS
    given to the tool, i.e.
  - a single TRS R, when the tool is run as:
    > tool R
  - the union R1 u R2 u ... u Rn, when the tool is run as:
    > tool R1 ... Rn

In other words, this module is loaded only when running the tool.
**)

(* Finite maps keyed by an ordered pair of terms.  Keys are ordered by the
   polymorphic structural comparison applied to the whole pair. *)
module Table = Map.Make (struct
  type t = Term.t * Term.t

  let compare (lhs : t) (rhs : t) = Pervasives.compare lhs rhs
end)

(* Global join table, keyed by a pair of terms (s,t).
   Intended properties, as recorded by the original author:
   1: each binding is ((s,t), joins), where every element of joins is a
      pair of rewrite sequences witnessing  s ->_A* . *B_<- t
   2: if ((s,t), seqs) is in the table then ((t,s), rev seqs) is in the table
      NOTE(review): no code visible in this file registers the mirrored
      entry -- confirm that this property is still maintained.
   3: individual bindings are never removed (reset_table drops everything
      at once)
*)
let join_table =
  ref (Table.empty : (rewrite list * rewrite list) list Table.t)

(* Switch read by jj0: while it is set, every join that jj0 computes is
   also recorded in the join table.  Initially off. *)
let _cache : bool ref = ref false

(* Turn recording of computed joins on / off. *)
let cache_on : unit -> unit = fun () -> _cache := true
let cache_off : unit -> unit = fun () -> _cache := false

(************* Preliminary **************)

(** table operators **)
(* Print a pair of terms on standard output, formatted with print and
   followed by a newline and flush. *)
let show_pair pair =
  let lhs, rhs = pair in
  Format.printf "(%a,%a)@." print lhs print rhs
(* Print every rewrite step of seq in order, each one prefixed by the
   arrow string direct and a space. *)
let show_seq direct seq =
  let show_step step =
    Format.printf "%s " direct;
    show_rewrite step
  in
  List.iter show_step seq
(* Print a list of joins.  Each join is a pair of rewrite sequences; a
   separator line is printed, then the first sequence with ==> arrows and
   the second one with <== arrows. *)
let show_joins xs =
  let show_join (forward, backward) =
    Format.printf "---------------------------@.";
    show_seq "==>" forward;
    show_seq "<==" backward
  in
  List.iter show_join xs
(* Print every binding of a join table: the key pair, all joins recorded
   for it, then an empty line. *)
let show_table table =
  let show_binding key joins =
    show_pair key;
    show_joins joins;
    Format.printf "@."
  in
  Table.iter show_binding table
      
(* Forget every join recorded so far. *)
let reset_table () =
  join_table := Table.empty

(* Print the current contents of the global join table. *)
let show () =
  let current = !join_table in
  show_table current

(* is_over xs seq: true iff every step of seq that carries a rule uses a
   rule that is a member of xs.  Steps whose rule component is None are
   always accepted, so the empty sequence is accepted as well. *)
let is_over xs seq =
  let uses_foreign_rule = function
    | _, Some (rule, _), _ -> not (List.mem rule xs)
    | _, None, _ -> false
  in
  not (List.exists uses_foreign_rule seq)

(* lookup (r, s) pair table: the joins recorded for pair in table whose
   first sequence only uses rules of r and whose second sequence only uses
   rules of s (see is_over).  Returns None when pair is unbound or when no
   recorded join qualifies; never returns Some []. *)
let lookup ((r : trs),(s : trs)) pair table =
  let recorded = try Table.find pair table with Not_found -> [] in
  let usable (seq1, seq2) = is_over r seq1 && is_over s seq2 in
  match List.filter usable recorded with
  | [] -> None
  | _ :: _ as found -> Some found

(* Same as lookup, applied to the current contents of the global join
   table. *)
let lookup_joins r_s pair =
  let table = !join_table in
  lookup r_s pair table


(* add_join pair seq: record seq as a join for pair in the global join
   table.  Nothing changes when seq is already recorded for pair
   (structural equality); otherwise seq is put in front of the joins
   already known for pair, creating the binding if needed. *)
let add_join pair seq =
  (* an unregistered pair behaves like a pair with no known join *)
  let known = try Table.find pair !join_table with Not_found -> [] in
  if not (List.mem seq known) then
    join_table := Table.add pair (seq :: known) !join_table


(**************** Calculator *****************)
(** MEMO: argument name 'arr' is abbreviation of arrow **)

(* jjj1 arr ss: extend every (sequence, term) element of ss by one
   application of arr.  For each non-empty step sequence returned by arr on
   the term, the result contains the old sequence followed by the new steps,
   paired with last_term of the new steps.  The whole list is passed
   through uniq.
   Original note: the returned list respects  *<-  ( not  ->* ). *)
let jjj1 (arr : stepS) ss =
  let extend (prefix, s) =
    List.map
      (fun ext -> (prefix @ ext, last_term ext))
      (List.filter (fun ext -> ext <> []) (arr s))
  in
  uniq (List.concat (List.map extend ss))

(* jjj2 arrs ss: apply jjj1 once for each arrow of arrs, from left to
   right, feeding the result of one stage into the next.  With no arrow
   left, ss is returned unchanged. *)
let jjj2 (arrs : stepS list) ss =
  List.fold_left (fun reached arr -> jjj1 arr reached) ss arrs

(* Register in the join table every join obtained by shortening a join
   at its two starting ends:
     register si -> sm .. ->.<- tn <- .. <- tj
     for all i <= m and j <= n in xs and ys.
   The argument is a pair of rewrite sequences (xs, ys); the pair of two
   empty sequences is a no-op.  invalid_arg is raised when last_term raises
   Invalid_argument for both xs and ys. *)
let add_all = function
  | [],[] -> ()
  | xs,ys ->
    (* trivial: a single rule-less step (rule = None) from bottom to bottom,
       where bottom is last_term xs, or last_term ys when the former raises
       Invalid_argument (presumably because xs is empty -- confirm against
       last_term) *)
    let trivial =
      let bottom =
        try last_term xs with Invalid_argument _ ->
        try last_term ys with Invalid_argument _ ->
            invalid_arg "Join.add_all: incorrect application. xs = [] and ys = []"
      in [bottom, None, bottom]
    (* tails: the non-empty lists produced by Ls.suffix on the sequence
       extended with the trivial step (presumably all of its suffixes --
       confirm against Ls.suffix) *)
    in let tails xs = [ ys | ys <- Ls.suffix (xs @ trivial); ys <> [] ] in
    List.iter
      (fun (pair,seqs) -> add_join pair seqs)
      (* one entry per combination of a tail of xs starting at term si and a
         tail of ys starting at term ti, skipping si = ti; both tails go
         through nontrivials before being stored *)
      [ (si, ti), (nontrivials (x::ss1), nontrivials (y::ss2))
        | (si,_,_) as x :: ss1 <- tails xs;
          (ti,_,_) as y :: ss2 <- tails ys;
          si <> ti]
 
(* sequence order: embed xs ys holds iff ys can be obtained from xs by
   deleting elements, i.e. ys is a (not necessarily contiguous) subsequence
   of xs.  Elements are compared with structural equality.

   Matching is greedy: when the heads are equal the head of ys is consumed
   right away.  This is complete, because if y :: ys embeds into the tail of
   xs then ys does too, so retrying without consuming can never succeed
   where the greedy choice failed.  It keeps the check linear and
   tail-recursive instead of backtracking exponentially. *)
let rec embed xs ys =
  match xs, ys with
  | _, [] -> true
  | [], _ :: _ -> false
  | x :: xs', y :: ys' ->
    if x = y then embed xs' ys' else embed xs' ys

(* Component-wise embed on pairs of sequences: both the first and the
   second component of the right-hand pair must embed into the matching
   component of the left-hand pair. *)
let pair_embed (xs,ys) (xs',ys') =
  if embed xs xs' then embed ys ys' else false

(* jj0 r_arrs s_arrs (t,u): search for joins of t and u.
   - If t = u the only answer is the pair of two empty sequences.
   - Otherwise t is expanded through the arrows r_arrs and u through the
     arrows s_arrs (see jjj2); every combination that reaches the same term
     yields a join, after both sequences went through nontrivials (original
     note: this reforms the join sequence from  *<-  to  ->* ).
   The joins are de-duplicated with uniq and filtered with minimal under
   pair_embed.  When the _cache switch is set, each returned join is also
   registered with add_all.
   Open question from the original author: is minimality needed? *)
let jj0 r_arrs s_arrs (t,u) =
  match t = u with
  | true -> [ ([], []) ]
  | false ->
    let from_t = jjj2 r_arrs [ ([], t) ]
    and from_u = jjj2 s_arrs [ ([], u) ] in
    let meets =
      [ nontrivials fwd, nontrivials bwd
        | fwd, t' <- from_t; bwd, u' <- from_u; t' = u' ]
    in
    let mini = minimal pair_embed (uniq meets) in
    (* caching *)
    if !_cache then List.iter add_all mini;
    mini

(********** MAIN FUNCTIONS **********)

(** Joins of pair by the arrows r_arrs (from the first component) and
    s_arrs (from the second component).  This entry point never consults
    the join table; table lookup is only done for the standard rewriting
    relation (see jj). **)
let jj_ex r_arrs s_arrs pair =
  let joins = jj0 r_arrs s_arrs pair in
  joins

(* standard join function: try to find join sequences such that
     (fst pair) -rs1-> ... -rsm-> . <-ssn- ... <-ss1- (snd pair)
   where rss = [rs1; ...; rsm] and sss = [ss1; ...; ssn].
   Joins already recorded in the join table for the unions of rss and sss
   are returned directly; otherwise the search is run through jj_ex with
   one k_stepSs 1 arrow per TRS of rss and sss. *)
let jj rss sss pair =
  let to_arrows trss = List.map (k_stepSs 1) trss in
  let r_s = union rss, union sss in
  match lookup_joins r_s pair with
  | Some (_ :: _ as recorded) -> recorded
  | Some [] | None ->
      jj_ex (to_arrows rss) (to_arrows sss) pair

(* k_join k r s pair: joins of pair found by jj when the TRS lists are
   rep k r and rep k s (presumably k copies of r and of s -- confirm
   against rep).  The function is not recursive, so it is declared
   without the rec flag. *)
let k_join k r s pair =
  jj (rep k r) (rep k s) pair

(* True iff k_join finds at least one join for pair. *)
let is_k_join k r s pair =
  match k_join k r s pair with
  | [] -> false
  | _ :: _ -> true


(* unjoin rs ss (s,t): non-joinability test.  Every variable of s and t is
   replaced by a constant obtained from new_const (), Match.tcap is applied
   to the two resulting terms with rs and ss respectively, and the answer
   is true exactly when Match.unifiable rejects that single pair. *)
let unjoin rs ss (s,t) =
  let grounding =
    List.map (fun v -> (v, new_const ())) (vars s @+ vars t)
  in
  let sub = subst grounding in
  let problem = [ (Match.tcap rs (sub s), Match.tcap ss (sub t)) ] in
  not (Match.unifiable problem)


(** joinability **)

(* joinable rss sss pair:
   - Result.NO pair  when unjoin holds for the unions of rss and sss;
   - Result.YES ()   when jj finds at least one join;
   - Result.MAYBE    otherwise. *)
let joinable rss sss pair =
  match unjoin (union rss) (union sss) pair with
  | true -> Result.NO pair
  | false ->
    (match jj rss sss pair with
     | [] -> Result.MAYBE
     | _ :: _ -> Result.YES ())

(* joinable applied to the TRS lists rep k r and rep k s. *)
let k_joinable k r s pair =
  let repeated trs = rep k trs in
  joinable (repeated r) (repeated s) pair

(* predicate *)
(* True iff jj_ex finds at least one join of pair with the given arrows. *)
let join_by r_arrs s_arrs pair =
  match jj_ex r_arrs s_arrs pair with
  | [] -> false
  | _ :: _ -> true
(* True iff k_joinable answers Result.YES; any other answer counts as
   false. *)
let join k r s pair =
  let is_yes = function
    | Result.YES _ -> true
    | _ -> false
  in
  is_yes (k_joinable k r s pair)

(* joinability via direction *)
(* All results of applying step to an element of ss, concatenated in the
   order of ss and passed through uniq. *)
let steps step ss =
  uniq (List.concat (List.map step ss))
(* cutting k step ss t: bounded search for t.  Up to k rounds are made;
   each round succeeds if t is a member of the current frontier ss and
   otherwise replaces the frontier by steps step ss.  Always false when
   k < 1. *)
let rec cutting k step ss t =
  k >= 1 && (List.mem t ss || cutting (k - 1) step (steps step ss) t)
(* goto k step s t: cutting started from the single term s. *)
let goto k step s t =
  let start = [ s ] in
  cutting k step start t


(* Memo list used by jj_light: each entry is (((rs, ss), pair), result),
   most recently added entry first.  Starts empty. *)
let light_cache  = ref []
(* Switch read by l_lookup: lookups in the light cache only succeed while
   it is set.  Initially off. *)
let _light_cache : bool ref = ref false

(* Turn light-cache lookups on / off. *)
let light_on : unit -> unit = fun () -> _light_cache := true
let light_off : unit -> unit = fun () -> _light_cache := false

(* l_lookup (rr, ss) p table: search the memo list table for an entry
   stored for the rule sets rr and ss and the pair p.  Always None while
   the _light_cache switch is off.

   l_lookup0 does the actual scan and returns the result of the first
   matching entry.  Both rule sets are compared with ===; the second one
   used to be compared with ==, i.e. physical equality, so an entry was only
   found when the caller passed the very same list object for ss.
   NOTE(review): === is assumed to be the list/set equality provided by
   Ls -- confirm. *)
let rec l_lookup x p table =
  if !_light_cache then
    l_lookup0 x p table
  else
    None
and l_lookup0 (rr,ss) p1 = function
  | [] -> None
  | (((rs1,rs2),p2),result)::xs ->
      if rs1 === rr && rs2 === ss && p1 = p2 then
        Some result
      else
        l_lookup0 (rr,ss) p1 xs

(* jj_light0 k rs ss (t,u): depth-bounded joinability test without join
   sequences.
   - Result.NO (t,u) when unjoin holds for the pair;
   - Result.MAYBE    when the depth budget k is exhausted (k <= 0);
   - Result.YES ()   when t together with rewrite rs t, and u together with
                     rewrite ss u, share an element (operator @^);
   - otherwise the answer of Result.exists over all combinations of those
     two lists, each examined with budget k - 1.
   The recursion calls jj_light0 directly and does not go through the
   cache. *)
let rec jj_light0 k rs ss (t,u) =
  if unjoin rs ss (t,u) then Result.NO (t,u)
  else if k <= 0 then Result.MAYBE
  else begin
    let ts, us = uniq (t :: rewrite rs t), uniq (u :: rewrite ss u) in
    match ts @^ us with
    | _ :: _ -> Result.YES ()
    | [] ->
      (match Result.exists
               (jj_light0 (k-1) rs ss)
               [ t1,u1 | t1 <- ts; u1 <- us ] with
       | Result.YES _ -> Result.YES ()
       | Result.NO nf -> Result.NO nf
       | Result.MAYBE -> Result.MAYBE)
  end

(* jj_light k rs ss pair: memoised front end of jj_light0.  A result found
   by l_lookup is returned as is; otherwise the result is computed and put
   in front of light_cache.
   NOTE(review): the memo key is ((rs, ss), pair) and does not contain k, so
   a result computed with one depth budget is reused for any other budget.
   NOTE(review): the result is stored even while the light cache is
   switched off (l_lookup then answers None), so the list keeps growing. *)
let jj_light k rs ss pair =
  match l_lookup (rs,ss) pair !light_cache with
  | Some result -> result
  | None ->
      let result = jj_light0 k rs ss pair in
      light_cache := (((rs,ss),pair),result) :: !light_cache;
      result


(*************************************************)
(***** utility *****)
(* Reduce a rewrite sequence to (rule, target) pairs: for every step that
   carries a rule, keep the rule (dropping the second component stored next
   to it) and the third component of the step.  Steps without a rule are
   skipped; order is preserved. *)
let light rws =
  List.fold_right
    (fun rw acc ->
      match rw with
      | _, Some (rl, _), t -> (rl, t) :: acc
      | _, None, _ -> acc)
    rws []
(* Apply light to both sequences of every join of xs and pass the result
   through uniq. *)
let lights xs =
  let strip (fwd, bwd) = light fwd, light bwd in
  uniq (List.map strip xs)

(* Print a list of ((l, r), t) entries, one per line, each prefixed by the
   arrow string dis. *)
let show_lights dis ls =
  let show_one ((l, r), t) =
    Format.printf "%s -- %a,%a --> %a@."
      dis print l print r print t
  in
  List.iter show_one ls
(* Print a list of light joins: a separator line, then the first sequence
   with ==> arrows and the second one with <== arrows. *)
let show_ljoins xs =
  let show_ljoin (forward, backward) =
    Format.printf "---------------------------@.";
    show_lights "==>" forward;
    show_lights "<==" backward
  in
  List.iter show_ljoin xs

(* The rules used in a rewrite sequence, in order: the first component of
   the rule field of every step that has one.  Steps without a rule are
   skipped. *)
let rules_from rws =
  List.fold_right
    (fun rw acc ->
      match rw with
      | _, Some (x, _), _ -> x :: acc
      | _, None, _ -> acc)
    rws []
(* The complete rule field of every step that has one, in order.  Steps
   without a rule are skipped. *)
let rules_pos_from rws =
  List.fold_right
    (fun rw acc ->
      match rw with
      | _, Some x, _ -> x :: acc
      | _, None, _ -> acc)
    rws []

(* rules_from applied to both sequences of a join. *)
let rules_From join =
  let rws1, rws2 = join in
  rules_from rws1, rules_from rws2
(* rules_pos_from applied to both sequences of a join. *)
let rules_pos_From join =
  let rws1, rws2 = join in
  rules_pos_from rws1, rules_pos_from rws2
