open Term
open Ls

open Step
open Form

(**********************************************)
(*             RULE LABELING MODULE           *)
(*                                            *)
(* Implements the extended rule labeling from
  Takahito Aoto.
  Automated Confluence Proof by Decreasing Diagrams based on Rule-Labelling.
  RTA 2010.
  pages 8-16.
*)
(**********************************************)


(**************** INDEX MODULE ****************)

(* Rule indexing: [I] assigns an integer index to each rule. Rules are ordered
   with the polymorphic structural comparison. *)
module E = struct
  type t = Rule.t

  let compare (x : t) (y : t) = Pervasives.compare x y
end

module I = Index.Make (E)

let var_prefix = "rl"

(* [var rule] is the name of the label variable of [rule]: [var_prefix]
   followed by the index of [rule] in [I]. *)
let var s = Printf.sprintf "%s%d" var_prefix (I.index s)

(* Pairs every rule of [trs] with its label variable name [var rule]. *)
let labeled trs = [ r, var r | r <- trs ]

(* Function-symbol indexing; [weight] below is used by [contextP]. *)
module F = Index.Make (String)

(* [weight f] is the NAT variable "rlw<i>" standing for the weight of the
   function symbol [f], where <i> is the index of [f] in [F]. *)
let weight s =
  let name = Printf.sprintf "rlw%d" (F.index s) in
  Form.var name NAT


(*************** EXTENDED PART ****************)

(* Lexicographic comparison of pairs (rewrite-relation weight, rule weight):
   [lex_gt] is the strict order, [lex_eq] componentwise equality, and
   [lex_geq] the disjunction of the two. *)
let lex_gt ((w1, l1), (w2, l2)) =
  D [ w1 >% w2; C [ w1 =% w2; l1 >% l2 ] ]

let lex_eq ((w1, l1), (w2, l2)) =
  C [ w1 =% w2; l1 =% l2 ]

let lex_geq pair = D [ lex_gt pair; lex_eq pair ]

(* [contextP c p] is the weight polynomial of the context of [c] above
   position [p]. It is the sum ("+") of [weight f] for every function symbol
   [f] found in [c] at a position of [prefix p] other than [p] itself.
   Variables contribute nothing. The empty position yields the constant "0". *)
let contextP c pos =
  match pos with
  | [] -> atom "0"
  | _ ->
      let above = prefix pos \\ [ pos ] in
      OP ("+", [ weight f | q <- above; F (f, _) <- [ subterm_at q c ] ])

(**************** CONSTRAINTS *****************)

(* [enc (poly, label)] pairs a weight polynomial with the NAT variable named
   [label] (a rule label variable name). *)
let enc (poly, label) = (poly, Form.var label NAT)

(* Constraint for the part  -->^=_b . -V{a,b}->^*  of a labelled sequence.
   The first step must be lexicographically at most [b]. Every later step
   must be strictly below [a] or strictly below [b]. The empty sequence
   imposes nothing. *)
let psi a b steps =
  match steps with
  | [] -> True
  | first :: rest ->
      let later =
        [ D [ lex_gt (enc a, enc d); lex_gt (enc b, enc d) ] | d <- rest ]
      in
      let head = D [ lex_geq (enc b, enc first) ] in
      C (head :: later)

(* Constraint for  -V{a}->^* . psi a b  over the labelled sequence [cs].
   For some pair (front, tail) of List.combine (prefix cs) (suffix cs),
   every step of [front] is strictly below [a] and [tail] satisfies [psi a b].
   This presumably enumerates the ways of splitting [cs]. *)
let phi_ab a b cs =
  let splits = List.combine (prefix cs) (suffix cs) in
  D [ C (psi a b tail :: [ lex_gt (enc a, enc c) | c <- front ])
    | front, tail <- splits ]


(* [lookup a table] returns the label paired, in the (rule, label) list
   [table], with the first rule that is a variant of [a].
   @raise Invalid_argument "Rl.lookup" if there is no such rule. *)
let lookup a table =
  let rec find = function
    | [] -> invalid_arg "Rl.lookup"
    | (rule, label) :: rest ->
        if Match.variant_rule a rule then label else find rest
  in
  find table

(* Joining sequences of [cp] computed by [Join.k_join k r s cp]. Each side of
   each join is flattened from its (_, steps, term) entries into a list of
   (term, rule, position) triples, one per (rule, position) step. *)
let jj_rl k r s cp : ((t * Rule.t * position) list * (t * Rule.t * position) list) list =
  let flatten seq : (t * Rule.t * position) list =
    try [ u, rule, pos | _, steps, u <- seq; (rule, pos) <- steps ]
    with Failure _ ->
      failwith "Rl.drop_data: rewrite sequence contains None?"
  in
  let joins = Join.k_join k r s cp in
  [ flatten xs, flatten ys | xs, ys <- joins ]

(* [label1 ll src steps] labels the steps of a rewrite sequence starting at
   [src]. A step (next, rule, pos) becomes the pair of the context weight of
   the current term at [pos] ([contextP]) and the label of [rule] in [ll]
   ([lookup]). Labelling then continues from [next]. *)
let rec label1 ll src steps =
  match steps with
  | [] -> []
  | (next, rule, pos) :: rest ->
      (contextP src pos, lookup rule ll) :: label1 ll next rest

(* let label cs ll = [ lookup c ll | c <- cs ] *)
(* [label t cs ll] is [label1] with the arguments in (term, steps, table)
   order. *)
let label src steps table = label1 table src steps
(* [unlabel l] keeps the first component of every pair in [l]; used to
   recover the rules from the (rule, label) lists built by [labeled].
   Eta-expanded on purpose: the partial application [List.map fst] is not a
   syntactic value, so the value restriction would give it weak, monomorphic
   type variables instead of a polymorphic type. *)
let unlabel l = List.map fst l

(* Labelled joining sequences of the pair (t, u). The joins come from
   [jj_rl] over the rules of [rl] and [sl]. The sequence from [t] is labelled
   with [rl] and the one from [u] with [sl]. *)
let wrap_jj k rl sl (t, u) =
  let r = unlabel rl and s = unlabel sl in
  let joins = jj_rl k r s (t, u) in
  [ label t cs rl, label u ds sl | cs, ds <- joins ]

(* Prints every step (t, rule, _) of a sequence on its own line as
   "<rule> <arr> <t>". *)
let show_seq arr steps =
  List.iter
    (fun (t, rl, _) ->
      Format.printf "%s %s %a@." (Trs.sprint_rules [ rl ]) arr print t)
    steps

(* Prints both halves of a join. [cs] is printed under a "=======>" banner
   with "=>" arrows, then [ds] under "<=======" with "<=" arrows. *)
let show_jj (cs, ds) =
  let section banner arrow seq =
    print_endline banner;
    show_seq arrow seq
  in
  section "=======>" "=>" cs;
  section "<=======" "<=" ds

(* Constraint for the critical pair (s, t) of [rule1] (labels from [rl]) and
   [rule2] (labels from [sl]), with context weights [p] and [q]. Some
   labelled joining sequence must satisfy [phi_ab] in both directions.
   Note that [wrap_jj] receives [sl] before [rl]. The sequence from [s] is
   therefore labelled with [sl], and the one from [t] with [rl]. *)
let phi k rl sl ((p, s), (q, t)) (rule1, rule2) =
  let l1, l2 = (p, lookup rule1 rl), (q, lookup rule2 sl) in
  let joins = wrap_jj k sl rl (s, t) in
  D [ C [ phi_ab l1 l2 cs; phi_ab l2 l1 ds ] | cs, ds <- joins ]

(*
let phii r n =
  let ((_,s),(_,t)),_ = List.nth (cp2 r r) n in
  wrap_jj 3 (labeled r) (labeled r) (s,t)
*)

(* Critical pairs from [Match.overlaps r s], tagged with context weights.
   The left term gets the weight of the context of the overlapped rule's lhs
   at the overlap position; the right term gets the constant "0". Each entry
   also carries the pair of overlapping rules. *)
let cp2 r s =
  [ ((contextP lhs pos, left), (atom "0", right)), (r1, r2)
  | r1, pos, ((lhs, _) as r2) <- Match.overlaps r s;
    left, _, _, right <- Match.cp_overlap [ r1, pos, r2 ] ]

(* Critical pairs in both directions: [cp2 r s], plus [cp2 s r] with both the
   term pair and the rule pair swapped. Duplicates are removed with [uniq]. *)
let cp2_all r s =
  let swapped = [ (y, x), (b, a) | (x, y), (a, b) <- cp2 s r ] in
  let direct = cp2 r s in
  uniq (direct @ swapped)

(* Joinability constraint: every critical pair between the rules of [rl] and
   [sl] (from [cp2_all]) must satisfy [phi]. *)
let rlv k rl sl =
  let r = unlabel rl and s = unlabel sl in
  let pairs = cp2_all r s in
  C [ phi k rl sl cp rules | cp, rules <- pairs ]


(* Wrapper for Commute_smt.thm_RL. It labels [r] and [s], then requires [phi]
   for every critical pair between [rs1] and [rs2].
   NOTE(review): labels come from [r]/[s] but critical pairs from
   [rs1]/[rs2]. [lookup] raises Invalid_argument for a critical-pair rule
   with no variant among the labelled rules. *)
let wrap_phi k r s (rs1, rs2) =
  let rl, sl = labeled r, labeled s in
  let pairs = cp2_all rs1 rs2 in
  C [ phi k rl sl cp rules | cp, rules <- pairs ]


(** [[CDN]]: see Aoto, RTA2010 **)
(* The constraint ranges over every rule (l, r) in [rs], every variable x of
   l, and all positions p of x in l and q of x in r. Write w(.) for the
   context weight computed by [contextP]. Then:
     w(l, p) >= w(r, q)  when List.mem x (remove x (variables_list r)) is false
     w(l, p) >  w(r, q)  when it is true
   Presumably these correspond to x being linear / non-linear in r. *)
let rule_condition rs =
  let repeated x r = List.mem x (remove x (variables_list r)) in
  let strict =
    [ GT (contextP l p, contextP r q)
    | l, r <- rs;
      x <- Term.variables l;
      repeated x r;
      p <- positions_of (V x) l;
      q <- positions_of (V x) r ]
  in
  let weak =
    [ geq (contextP l p, contextP r q)
    | l, r <- rs;
      x <- Term.variables l;
      not (repeated x r);
      p <- positions_of (V x) l;
      q <- positions_of (V x) r ]
  in
  C (weak @ strict)


(******************* RLV -> Formula *******************)

(* Formula for rule labelling of [r] and [s]. It conjoins the joinability
   constraint [rlv] on the labelled systems with [rule_condition] on all rules
   of [r] and [s].
   NOTE(review): no ordering between the labels of R and those of S is
   imposed here. *)
let to_form k r s =
  let rl = labeled r in
  let sl = labeled s in
  C [ rlv k rl sl; rule_condition (r @ s) ]

(* [rm_prefix s] drops the first [String.length var_prefix] characters of [s].
   It does not check that those characters are actually [var_prefix].
   @raise Invalid_argument if [s] is shorter than [var_prefix]. *)
let rm_prefix s =
  let skip = String.length var_prefix in
  String.sub s skip (String.length s - skip)

(* [decode model] keeps, in order, the rule indices of the entries of [model]
   whose name is [var_prefix] followed by an integer, e.g. ("rl3", _) gives 3.
   All other entries are skipped, including the "rlw<i>" weight variables
   built by [weight].

   Each entry is examined exactly once, with a tail-recursive accumulator.
   Placing the recursive call inside the [try] would evaluate it before
   [int_of_string] (right-to-left operand order) and again in the handler,
   doubling the work for every skipped entry. *)
let decode model =
  let plen = String.length var_prefix in
  (* Check the prefix explicitly: [rm_prefix] alone raises Invalid_argument
     on names shorter than the prefix and accepts any prefix of that length. *)
  let index_of name =
    if String.length name > plen && String.sub name 0 plen = var_prefix then
      (try Some (int_of_string (rm_prefix name)) with Failure _ -> None)
    else None
  in
  let rec go acc = function
    | [] -> List.rev acc
    | (name, _) :: rest ->
        (match index_of name with
         | Some i -> go (i :: acc) rest
         | None -> go acc rest)
  in
  go [] model

(* Satisfiability of [to_form k r s]. The formula is only built and checked
   when both [r] and [s] are left-linear; otherwise the result is false. *)
let solve_rlv k r s =
  let both_left_linear = Trs.left_linear r && Trs.left_linear s in
  both_left_linear && Smt2.sat [ to_form k r s ]

(* Like [solve_rlv], but returns evidence. On a model it returns
   [Some (indices, order)]:
   - [indices] pairs every rule of [cup r s] with [I.index rule];
   - [order] is [decode] applied to the model's [I (name, value)] entries,
     sorted by decreasing value.
   It returns [None] when either system is not left-linear or
   [Smt2.sat_proof] finds no model. [I.refresh] is called before the formula
   is built. *)
let solve_rlv_proof k r s =
  if not (Trs.left_linear r && Trs.left_linear s) then None
  else begin
    I.refresh ();
    match Smt2.sat_proof [ to_form k r s ] with
    | None -> None
    | Some model ->
        let assigned = [ y, a | I (y, a) <- model ] in
        let by_value =
          List.sort (fun (_, w1) (_, w2) -> compare w2 w1) assigned
        in
        let order = decode by_value in
        Some ([ rule, I.index rule | rule <- cup r s ], order)
  end
