open Term
open Ls
open Form

(* cf. Albert Rubio, A fully syntactic AC-RPO, Information and Computation, 2011 *)
(* NOTE: we assume that
  - every AC symbol has arity exactly 2 (before flattening), and
  - AC and non-AC symbols are disjoint (no symbol is used as both);
therefore the following terms are not valid:
  FF ("f",[ AC ("f",[VV "x"; VV "x"]) ]) (* means f(f(x,x)) with f both non-AC and AC *)
  AC ("F",[]) (* means F is AC and a constant *)
 *)


(* Terms of the AC system:
   - [VV x]       a variable named [x];
   - [FF (f, ts)] an application of the non-AC function symbol [f];
   - [AC (f, ts)] an application of the AC function symbol [f].
   By the assumption stated in the NOTE above, an AC node built from an
   ordinary term has exactly two arguments; [flatten] may merge nested
   nodes of the same AC symbol into one node with more arguments. *)
type act =
  | VV of string
  | FF of string * act list
  | AC of string * act list


(* [subterms s] lists [s] itself followed by all of its proper subterms,
   in pre-order, left to right. *)
let rec subterms s =
  let below =
    match s with
    | VV _ -> []
    | FF (_, ts) | AC (_, ts) -> List.concat (List.map subterms ts)
  in
  s :: below

(* ========================================== *)
(** Translation and Flattening **)

(* Associativity rules of the binary symbol [a_sym], in both orientations:
     a(x, a(y, z)) -> a(a(x, y), z)   and   a(a(x, y), z) -> a(x, a(y, z)).
   [rule_A] compares a rule against these to recognise A-axioms. *)
let a_class a_sym =
  let node l r = F (a_sym, [l; r]) in
  let x = Term.V "x" and y = Term.V "y" and z = Term.V "z" in
  let nested_right = node x (node y z)
  and nested_left = node (node x y) z in
  [ (nested_right, nested_left); (nested_left, nested_right) ]

(* Commutativity rule c(x, y) -> c(y, x) of the binary symbol [c_sym];
   [rule_C] compares a rule against it to recognise C-axioms. *)
let c_class c_sym =
  let node l r = F (c_sym, [l; r]) in
  let x = Term.V "x" and y = Term.V "y" in
  [ (node x y, node y x) ]

(* [rule_A lr] holds when [lr] matches (per [Match.variant_rule], presumably
   "equal up to variable renaming") one of the associativity rules of the
   root symbol of its left-hand side.  A variable lhs never qualifies. *)
let rule_A lr =
  match fst lr with
  | F (f, _) -> List.exists (Match.variant_rule lr) (a_class f)
  | _ -> false

(* [rule_C lr] holds when [lr] matches (per [Match.variant_rule]) the
   commutativity rule of the root symbol of its left-hand side.  A variable
   lhs never qualifies. *)
let rule_C lr =
  match fst lr with
  | F (f, _) -> List.exists (Match.variant_rule lr) (c_class f)
  | _ -> false

(* Root symbols of the associativity rules in [rs], listed in reverse order
   of the rules (duplicates kept). *)
let a_sym rs =
  List.rev_map (fun (l, _) -> root l) (List.filter rule_A rs)

(* Root symbols of the commutativity rules in [rs], listed in reverse order
   of the rules (duplicates kept). *)
let c_sym rs =
  List.rev_map (fun (l, _) -> root l) (List.filter rule_C rs)

(* AC symbols of [rs]: the A-symbols combined with the C-symbols via [@^]
   (from [Ls]).
   NOTE(review): a symbol is AC only if it is both A and C; confirm that
   [@^] is the intended combination (intersection vs. union). *)
let ac_sym rs =
  let assoc = a_sym rs in
  let comm = c_sym rs in
  assoc @^ comm

(* [is_AC_sym f rs] tells whether [f] is among the AC symbols that
   [ac_sym] detects in [rs]. *)
let is_AC_sym f rs =
  let syms = ac_sym rs in
  List.mem f syms

(* [is_AC_rule ac_sym lr] holds when the lhs root of [lr] is listed in
   [ac_sym] and [lr] is an associativity or commutativity rule. *)
let is_AC_rule ac_sym lr =
  match fst lr with
  | F (f, _) when List.mem f ac_sym -> rule_A lr || rule_C lr
  | _ -> false

(* Standard TRS ===> AC-style TRS.
   Applications of symbols listed in [ac_sym] become [AC] nodes and all
   other applications become [FF] nodes.  The A/C rules themselves (see
   [is_AC_rule]) are dropped from the result. *)
let to_AC ac_sym rs =
  let rec conv = function
    | Term.V x -> VV x
    | F (f, ts) ->
        let ts' = List.map conv ts in
        if List.mem f ac_sym then AC (f, ts') else FF (f, ts')
  in
  rs
  |> List.filter (fun lr -> not (is_AC_rule ac_sym lr))
  |> List.map (fun (l, r) -> (conv l, conv r))

(* AC-style TRS ===> standard TRS.
   Every node becomes an ordinary [Term] application again; no rule is
   dropped.  The optional suffix [s] (default [""]) is appended to the name
   of every AC symbol, e.g. [from_AC ~s:"_AC" rs] renames AC symbol [f] to
   [f_AC]. *)
let from_AC ?s:(s="") rs =
  let rec back = function
    | VV x -> Term.V x
    | FF (f, ts) -> F (f, List.map back ts)
    | AC (f, ts) -> F (f ^ s, List.map back ts)
  in
  List.map (fun (l, r) -> (back l, back r)) rs

(* [is_vv t] holds iff [t] is a variable. *)
let is_vv t =
  match t with
  | VV _ -> true
  | FF _ | AC _ -> false

(* [is_fun t] holds iff [t] is a function application (AC or not). *)
let is_fun = function
  | VV _ -> false
  | FF _ | AC _ -> true

(* [is_AC t] holds iff [t] is an application of an AC symbol. *)
let is_AC t =
  match t with
  | AC _ -> true
  | VV _ | FF _ -> false

(* Root symbol of a function application (AC or not).
   @raise Invalid_argument if the argument is a variable. *)
let root_AC t =
  match t with
  | FF (f, _) | AC (f, _) -> f
  | VV _ -> invalid_arg "Acrpo.root_AC: argument is variable"

(* Argument list of a function application (AC or not).
   @raise Invalid_argument if the argument is a variable. *)
let args t =
  match t with
  | FF (_, ts) | AC (_, ts) -> ts
  | VV _ -> invalid_arg "Acrpo.args: variable has no arguments"


(* [flatten t] merges nested applications of the same AC symbol, bottom-up.
   An argument of [AC (f, _)] whose root is the AC symbol [f] is replaced by
   its own arguments, e.g. [AC ("+", [AC ("+", [x; y]); z])] becomes
   [AC ("+", [x; y; z])]. *)
let rec flatten t =
  match t with
  | VV _ -> t
  | FF (f, ts) -> FF (f, List.map flatten ts)
  | AC (f, ts) ->
      let splice arg =
        match arg with
        | AC (g, gs) when g = f -> gs
        | _ -> [arg]
      in
      AC (f, List.concat (List.map splice (List.map flatten ts)))

(* Flattens both sides of every rule of [rs] (see [flatten]). *)
let trs_flatten rs =
  List.map (fun (lhs, rhs) -> (flatten lhs, flatten rhs)) rs


(** AC Equality and Ordering **)
(* [eq_list_by eq xs ys] holds iff [xs] and [ys] have the same length and
   are related pointwise by [eq].  Elements are compared from the left and
   the comparison stops at the first mismatch. *)
let rec eq_list_by eq xs ys =
  match xs with
  | [] -> (match ys with [] -> true | _ :: _ -> false)
  | x :: xs' ->
      (match ys with
       | [] -> false
       | y :: ys' -> eq x y && eq_list_by eq xs' ys')

(* [eq_ac s t] decides whether [s] and [t] are equal modulo AC.
   Both terms are flattened and then put into a canonical form: every
   subterm is canonicalised first, and the arguments of every AC node are
   then sorted.  The terms are AC-equal iff their canonical forms are
   structurally equal.
   Canonicalising bottom-up matters: sorting the arguments of an AC node
   while they are still in arbitrary internal order can place AC-equal
   arguments at different positions.  For example, [a*c] vs. [c*a], each
   sorted next to [b*z], would end up misaligned and compare as unequal. *)
let eq_ac s t =
  let rec canon = function
    | VV _ as v -> v
    | FF (f, ts) -> FF (f, List.map canon ts)
    | AC (f, ts) -> AC (f, List.sort compare (List.map canon ts))
  in
  canon (flatten s) = canon (flatten t)

(* [diff_by eq xs ys] strips the longest common prefix (w.r.t. [eq]) of [xs]
   and [ys] and returns what remains of each list. *)
let rec diff_by eq xs ys =
  match xs, ys with
  | x :: xs', y :: ys' when eq x y -> diff_by eq xs' ys'
  | _ -> (xs, ys)

(* [split_two xs] lists every way to cut [xs] into a prefix and a suffix,
   shortest prefix first.  For a list of length n it returns the n + 1
   pairs (take k xs, drop k xs) for k = 0 .. n. *)
let split_two xs =
  let rec go rev_prefix suffix acc =
    let acc = (List.rev rev_prefix, suffix) :: acc in
    match suffix with
    | [] -> List.rev acc
    | y :: rest -> go (y :: rev_prefix) rest acc
  in
  go [] xs []


(* =========================================== *)
module F = Format
(* Fresh-variable supply: each call of [fresh_var t] bumps the global
   counter [count] and returns [var "fresh_<n>" t], where <n> is the new
   counter value. *)
let count = ref 0
let fresh_var t =
  incr count;
  let name = F.sprintf "fresh_%d" !count in
  var name t

(* Names for the non-negative integer variables that stand for term
   variables (see [f_var_occur]).  A name combines the variable name [x]
   with its [VIndex] number.  The '_' before the number keeps the encoding
   injective: with plain concatenation, "a1" with index 2 and "a" with
   index 12 would both become "v_a12". *)
module VIndex = Index.Make(String)
let vindex x = F.sprintf "v_%s_%d" x (VIndex.index x)

(* String representations of function symbols for the SMT encoding.  Only
   the [FIndex] number is used, because symbol names may contain characters
   that are reserved by the solver (e.g. ':' on Yices, which many TRSs
   use). *)
module FIndex = Index.Make(String)
(* let index f = F.sprintf "f_%s%d" f (FIndex.index f) *)
let index f = "f_" ^ string_of_int (FIndex.index f)

(* [f_prot_of h t] applies [h] to the root symbol of the function
   application [t].  It is the shared wrapper behind the [f_*_of] helpers
   ([f_var_of], [f_lex_of]), so the error message names this function
   rather than one particular caller.
   @raise Invalid_argument if [t] is a variable. *)
let f_prot_of h = function
  | FF (f, _) | AC (f, _) -> h f
  | VV _ -> invalid_arg "Acrpo.f_prot_of: argument is not functional term"

(* Integer (NAT) SMT variable holding the precedence of a function symbol.
   [f_var_of] reads it off the root of a function application. *)
let f_var sym = Form.var (index sym) NAT
let f_var_of t = f_prot_of f_var t

(* Precedence comparison of function symbols: [f >+ g] is the constraint
   "precedence of [f] > precedence of [g]" over the [f_var] variables. *)
let (>+) f g = f_var f >% f_var g

(* Boolean SMT variable choosing the status of a function symbol: if true,
   its arguments are compared lexicographically (rule 3), otherwise as a
   multiset (rule 4).  [f_lex_of] reads it off the root of an
   application. *)
let f_lex sym = Form.var ("lex_" ^ index sym) BOOL
let f_lex_of t = f_prot_of f_lex t

(** AC-eq **)
(* Lifts the syntactic decision [eq_ac s t] into a formula by comparing the
   computed boolean with [true] via [=?].  The resulting formula does not
   depend on any SMT variable. *)
let f_eq_ac s t =
  eq_ac s t =? true


(** Multiset **)
(* Boolean SMT variables "eps<suffix>" / "gam<suffix>" used by the multiset
   encoding ([f_geq_M_by]). *)
let var_eps suffix = var ("eps" ^ suffix) BOOL
let var_gam suffix = var ("gam" ^ suffix) BOOL

(* Bounded quantifiers over an indexed multiset.  The multiset is a list of
   [(i, (elem, p))] entries, as built by [Listx.ix], where [p] is the
   formula stating that [elem] is present:
     [forall xs f] = conjunction over all entries of  p ==> f i
     [exists xs f] = disjunction over all entries of  p /\ f i
   NOTE: both are shadowed further down by versions over plain
   [(elem, p)] lists. *)
let forall xs f = C [ p ==>% f i | (i,(_,p)) <- xs ]
let exists xs f = D [ C [p; f i] | (i,(_,p)) <- xs ]

(* cf. SAT Solving for Termination Proofs with Recursive Path Orders and
   Dependency Pairs (encoding of the multiset extension). *)

(* Names of the auxiliary boolean variables of one multiset comparison.
   [tag] identifies the comparison; [i] and [j] are element positions.  The
   '_' separators keep the names unique.  Plain concatenation ("%d%d") would
   map e.g. (tag 1, position 11) and (tag 11, position 1) to the same
   variable and silently couple unrelated comparisons. *)
let eps_of_tag tag i = var_eps (F.sprintf "%d_%d" tag i)
let gam_of_tag tag i j = var_gam (F.sprintf "%d_%d_%d" tag i j)

(* [f_geq_M_by tag f_eq f_gt m n] encodes "M >= N" in the multiset
   extension of ([f_eq], [f_gt]).
   [m] and [n] are lists of [(element, presence)] pairs; an element takes
   part only when its presence formula holds.
   The encoding guesses a map gamma from N to M:
   - [gam_of (j, i)] : element j of N is mapped to element i of M;
   - [eps_of i]      : element i of M is used for an equality.
   Constraints:
   - gamma is total and single-valued on the present elements of N;
   - if [eps_of i], exactly one present element of N is mapped to i;
   - every mapped pair (j, i) satisfies [f_eq (M i) (N j)] when
     [eps_of i] holds, and [f_gt (M i) (N j)] otherwise.
   [tag] names the auxiliary variables, so distinct comparisons in one
   formula need distinct tags (see [get_counter]). *)
let f_geq_M_by
    (tag : int)
    (f_eq : 'a -> 'a -> 'b form)
    (f_gt : 'a -> 'a -> 'b form)
    (m : ('a * 'b form) list)
    (n : ('a * 'b form) list)
  =
  (* index the multisets: entries are (position, (element, presence)) *)
  let rep_m = Listx.ix m
  and rep_n = Listx.ix n
  (* constant-time access to elements by position (instead of List.nth) *)
  and elems_m = Array.of_list (List.map fst m)
  and elems_n = Array.of_list (List.map fst n) in
  let atM i = elems_m.(i)
  and atN j = elems_n.(j) in
  let eps_of i = eps_of_tag tag i
  and gam_of (i,j) = gam_of_tag tag i j in
  C [ (** definition of function gamma **)
    (* totality *)
    forall rep_n (fun n_ ->
      exists rep_m (fun m_ -> gam_of (n_,m_)));
    (* uniqueness *)
    forall rep_n (fun n_ -> forall rep_m (fun m1_ ->
      forall rep_m (fun m2_ ->
        C [ gam_of (n_,m1_); gam_of (n_,m2_) ] ==>% (m1_ =? m2_))));
    (** restriction of epsilon and gamma **)
    forall rep_m (fun m_ ->
      eps_of m_ ==>%
        exists rep_n (fun n_ ->
          C [ gam_of (n_,m_);
              forall rep_n (fun ni_ ->
                if n_ = ni_ then True else  N (gam_of (ni_,m_)))
          ]));
    (* condition *)
    forall rep_m (fun m_ -> forall rep_n (fun n_ ->
      gam_of (n_,m_) ==>%
        C [ eps_of m_     ==>% f_eq (atM m_) (atN n_);
            N (eps_of m_) ==>% f_gt (atM m_) (atN n_)
          ]))
  ]

(* [f_gt_M_by tag f_eq f_gt m n] encodes the strict comparison "M > N":
   M >= N as in [f_geq_M_by] with the same [tag], and at least one present
   element of M is not used for an equality. *)
let f_gt_M_by (tag : int) f_eq f_gt m n =
  C [ (* some present element of M is not matched by equality *)
    exists (Listx.ix m)
      (fun i -> N (eps_of_tag tag i));
    (* and M is greater or equal to N *)
    f_geq_M_by tag f_eq f_gt m n ]

(* Identifier supply for multiset comparisons: successive calls of
   [get_counter ()] return 1, 2, 3, ...  The value serves as the [tag] of
   [f_geq_M_by] / [f_gt_M_by]. *)
let order_counter = ref 0
let get_counter () =
  incr order_counter;
  !order_counter


(** multiset ordering **)
(* Turns a plain list into a multiset whose elements are all
   unconditionally present. *)
let pure m = List.map (fun elem -> (elem, True)) m

(* Multiset comparisons "M >= N" / "M > N" that use [f_eq_ac] as equality
   and draw a fresh tag from [get_counter] on each call. *)
let f_geq_M_with f_gt m' n' =
  let tag = get_counter () in
  f_geq_M_by tag f_eq_ac f_gt m' n'

let f_gt_M_with f_gt m' n' =
  let tag = get_counter () in
  f_gt_M_by tag f_eq_ac f_gt m' n'

(* [f_geq_M_with] / [f_gt_M_with] on plain lists, i.e. every element is
   present. *)
let f_geq_M f_gt m n =
  let m' = pure m and n' = pure n in
  f_geq_M_with f_gt m' n'

let f_gt_M f_gt m n =
  let m' = pure m and n' = pure n in
  f_gt_M_with f_gt m' n'

(** lexicographic order **)
(* Lexicographic extension of [f_gt] with equality [eq].  After dropping the
   common prefix:
   - if [xs] is exhausted, the result is False;
   - if only [ys] is exhausted, the result is True;
   - otherwise the first differing elements decide via [f_gt]. *)
let f_gt_lex_by eq f_gt xs ys =
  match diff_by eq xs ys with
  | [], _ -> False
  | _ :: _, [] -> True
  | x :: _, y :: _ -> f_gt x y

(* Lexicographic extension of [f_gt], using [eq_ac] as equality. *)
let f_gt_lex f_gt = f_gt_lex_by eq_ac f_gt

(* [f_gt_with f f_gt s t] encodes the restricted order s >^f t.  It holds
   when [f_gt s t] holds and, unless prec(root s) > prec(f), also
   prec(root s) >= prec(root t).  Two variables are never related (False).
   NOTE(review): if exactly one of [s], [t] is a variable, [root_AC] raises
   Invalid_argument.  In rule 6 this is only applied to the elements
   returned by [f_no_small_head], which are function applications. *)
let f_gt_with f f_gt s t = match s,t with
  | VV _, VV _ -> False
  | _          ->
      let a,b = f_var (root_AC s), f_var (root_AC t) in
      C [ f_gt s t;
          N (a >% (f_var f))
            ==>% (a >=% b)
        ]

(* [f_big_head f t] (BigHead_f): the arguments of [t] that are function
   applications with a root different from [f].  Each is paired with the
   constraint "prec(root) > prec(f)".  A variable has none. *)
let f_big_head f = function
  | VV _ -> []
  | s ->
      [ arg, root_AC arg >+ f
        | arg <- args s; is_fun arg; root_AC arg <> f ]

(* [f_var_occur t]: one weight per argument of the function application
   [t].  A variable argument [x] gets the NAT variable "_" ^ vindex x; any
   other argument gets the constant 1.  Rule 6 sums these weights with
   OP "+".
   @raise Failure if [t] is a variable. *)
let f_var_occur t =
  let weight = function
    | VV x -> var ("_" ^ vindex x) NAT
    | FF _ | AC _ -> atom "1"
  in
  match t with
  | VV _ -> failwith "Acrpo.f_var_occur: applied variable"
  | FF _ | AC _ -> [ weight ti | ti <- args t ]

(* Side conditions for comparing the weight sums of rule 6 (b)/(c):
   - every weight in [vs] and [us] is > 0;
   - [Ls.subseteq] holds between the variables (Form [V] atoms) in [vs] and
     those in [us].  This is decided here, so it yields True or False
     rather than an SMT constraint.
   NOTE(review): rule 6 calls this as [rest_of_var_occur vs vt], with [vs]
   from s and [vt] from t.  Assuming [Ls.subseteq a b] means "a is a subset
   of b", this checks vars(s) subset of vars(t).  For #(s) > #(t) to hold
   under every instantiation one would expect vars(t) subset of vars(s);
   confirm the intended direction. *)
let rest_of_var_occur vs us =
  let vars_of xs = [v,x | Term (V (v,x)) <- xs] in
  C [
    (* positive *)
    C [ v >% atom "0" | v <- vs @ us ];
    (* no free variable *)
    if Ls.subseteq (vars_of vs) (vars_of us) then True else False;
  ]

(* [f_no_small_head f t] (NoSmallHead_f): the arguments of [t] that are
   function applications, each paired with the constraint
   "not (prec(f) > prec(root))".  Variable arguments are not included, and
   a variable [t] has none. *)
let f_no_small_head f = function
  | VV _ -> []
  | s -> [ arg, N (f >+ root_AC arg) | arg <- args s; is_fun arg ]

(* [tf f t] flattens [t] at the top w.r.t. [f].  It returns the arguments of
   [t] when [t] is an application of [f], and [[t]] otherwise (in particular
   for variables).
   @raise Invalid_argument if [t] has root [f] but fewer than two
   arguments. *)
let tf f t =
  match t with
  | VV _ -> [t]
  | _ when root_AC t <> f -> [t]
  | _ ->
      (match args t with
       | (_ :: _ :: _) as ts -> ts
       | _ -> invalid_arg "Acrpo.tf: argument term is not AC-term")

(* [f_emb_no_big t] computes EmbNoBig(t) for a function application [t]
   with root [f].  For every argument [si] of [t] that is itself a function
   application, and every argument [v] of [si], it builds the term that
   results from replacing [si] by [v] and flattening at the top w.r.t. [f].
   Each such term is paired with the constraint "not (prec(root si) >
   prec(f))".
   The whole sequence [tf f v] is spliced into the argument list.
   Inserting its elements one at a time would give terms strictly smaller
   than the flattened embedding: for t = f(h(f(a, b)), c) it would give
   f(a, c) and f(b, c) instead of f(a, b, c).  That weakens the universal
   check of rule 6.
   @raise Invalid_argument if [t] is a variable. *)
let f_emb_no_big t =
  let replacements f ts =
    [ pre @ tf f v @ post, N (root_AC si >+ f)
      | pre, si :: post <- split_two ts;
        is_fun si;
        v <- args si ]
  in match t with
  | FF (f,ts) -> [ FF (f,xs),p | xs,p <- replacements f ts ]
  | AC (f,ts) -> [ AC (f,xs),p | xs,p <- replacements f ts ]
  | _         -> invalid_arg "Acrpo.emb_no_big: applied variable"


(* Bounded quantifiers over a guarded list of [(elem, p)] pairs, where [p]
   is the formula stating that [elem] is present.  They shadow the
   index-based versions defined earlier:
     [forall xs f] = conjunction of  p ==> f elem
     [exists xs f] = disjunction of  p /\ f elem *)
let forall xs f = C [ p ==>% f x | x,p <- xs ]
let exists xs f = D [ C [p; f x] | x,p <- xs ]

(* AC-RPO constraints, built as formulas over the precedence variables
   ([f_var]), the status variables ([f_lex]) and the auxiliary variables
   of the multiset encoding.
   NOTE(review): the AC rules (5, 6) follow Rubio's definitions, which are
   stated for flattened terms.
   [f_acrpo_eq s t]: s >= t, i.e. s and t are AC-equal or s > t. *)
let rec f_acrpo_eq s t =
  D [ f_eq_ac s t; f_acrpo s t ]
(* [f_acrpo s t]: s > t, the disjunction of rules 1-6.  A variable is never
   greater than anything. *)
and f_acrpo s t = match s with
  | VV _ -> False
  | _    -> D [ rule s t | rule <- [
      f_acrpo_1; f_acrpo_2; f_acrpo_34;
      f_acrpo_5; f_acrpo_6 ]]
(* rule 1: some argument s_i of s satisfies s_i >= t *)
and f_acrpo_1 s t = match s with
  | VV _ -> False
  | _    -> D [ f_acrpo_eq si t | si <- args s ]
(* rule 2: both terms are applications, prec(root s) > prec(root t), and
   s > t_i for every argument t_i of t *)
and f_acrpo_2 s t = 
  if is_vv s || is_vv t then
    False
  else
    C [ f_var_of s >% f_var_of t;
        C [ f_acrpo s ti | ti <- args t]
      ]
(* rules 3 and 4: the same non-AC root f.  With lexicographic status
   ([f_lex f]), the argument lists are compared lexicographically and s
   must dominate every t_i (rule 3).  Otherwise the arguments are compared
   as multisets (rule 4). *)
and f_acrpo_34 s t = match s,t with
  | FF (f,ss), FF (g,ts) when f = g ->
      C [ (* rule 3 *)
        f_lex f ==>%
          C [ f_gt_lex f_acrpo ss ts;
              C [ f_acrpo s ti | ti <- ts ]];
        (* rule 4 *)
        N (f_lex f) ==>%
            f_gt_M f_acrpo ss ts]
  | _ -> False
(* rule 5: the same AC root; some s' in EmbNoBig(s) satisfies s' >= t *)
and f_acrpo_5 s t = match s,t with
  | AC (f,_), AC (g,_) when f = g ->
      exists (f_emb_no_big s)
        (fun s' -> f_acrpo_eq s' t)
  | _ -> False
(* rule 6: the same AC root f.  Required:
   - s > t' for every t' in EmbNoBig(t);
   - NoSmallHead_f(s) >= NoSmallHead_f(t) in the multiset extension of
     [f_gt_with f f_acrpo];
   - and one of:
     (a) BigHead_f(s) > BigHead_f(t) as multisets;
     (b) the side conditions of [rest_of_var_occur] hold and the weight sum
         of s (see [f_var_occur]) is strictly larger than that of t;
     (c) the same side conditions hold, the weight sums satisfy >=, and
         args(s) > args(t) as multisets. *)
and f_acrpo_6 s t = match s,t with
  | AC (f,ts), AC (g,us) when f = g ->
      let vs,vt = f_var_occur s, f_var_occur t in
      let ws, wt = OP ("+",vs), OP ("+",vt) in
      C [
        forall (f_emb_no_big t)
        (fun t' -> f_acrpo s t');
        f_geq_M_with (f_gt_with f f_acrpo)
        (f_no_small_head f s)
        (f_no_small_head f t);
        D [
          (* rule 6 (a) *)
          f_gt_M_with f_acrpo (f_big_head f s) (f_big_head f t);
          (* rule 6 (b) *)
          C [ rest_of_var_occur vs vt; ws >% wt ];
          (* rule 6 (c) *)
          C [ rest_of_var_occur vs vt; ws >=% wt; f_gt_M f_acrpo ts us ]]
        ]
  | _ -> False

(* Conjunction of the constraints "lhs > rhs" (see [f_acrpo]) over every
   rule of [rs]. *)
let f_acrpo_trs rs =
  C [ f_acrpo lhs rhs | lhs, rhs <- rs ]

(* Constraint stating that AC-RPO orients every non-AC rule of [rs].
   [to_AC] drops the A/C rules of the symbols in [ac_sym].  The remaining
   rules are flattened before encoding, because AC-RPO and its helpers
   ([f_emb_no_big], [f_no_small_head], [f_big_head], [f_var_occur]) are
   defined on flattened terms, where no AC argument has the same AC root as
   its parent.  [to_AC] alone yields nested binary AC nodes. *)
let acrpo_statement ac_sym rs =
  f_acrpo_trs (trs_flatten (to_AC ac_sym rs))

(* Passes the AC-RPO constraint of [rs] (see [acrpo_statement]) to
   [Smt2.sat] and returns its answer. *)
let acrpo ac_sym rs =
  let constraints = acrpo_statement ac_sym rs in
  Smt2.sat [constraints]

(* Termination-check entry point; currently the same as [acrpo]. *)
let terminate = acrpo

(* =========================================== *)
(*** TEST ***)
module Test = struct

  (* standard-term versions of the variables, used to build [r] *)
  let x,y,z = Term.V "x", Term.V "y", Term.V "z"
  let p x y = Term.F ("+",[x; y])
  let zero  = Term.F ("0",[])
  let s x   = Term.F ("s",[x])
  let r = [
    p zero y   , y;
    p (s x) y  , s (p x y);
    (* p (p x y) z, p x (p y z);
    p x y      , p y x *)
  ]

  (* from here on x, y and z are AC-term variables *)
  let x,y,z = VV "x", VV "y", VV "z"
  let aczero = FF ("0",[])
  let acs x  = FF ("s",[x])
  let acp ts = AC ("+",ts)
  let (++) x y = acp [x;y]

  let ac_f f ts   = FF (f,ts)
  let acf ts      = ac_f "f" ts
  let acg ts      = ac_f "g" ts
  let ac_F f ts   = AC (f,ts)
  let acF ts      = ac_F "F" ts
  let acG ts      = ac_F "G" ts

  let t1 = acF [ acF [x;y]; z]
  let t2 = acg [t1]


  open Format
  let alert s b = printf "status %s: %b@." s b

  let to_AC () =
    alert "to_AC"
      begin
        to_AC ["+"] r = [
        acp [aczero; y]    ,y;
        acp [acs x; y]     ,acs (acp [x;y]);
        (* acp [acp [x;y]; z] ,acp [x; acp [y;z]];
           acp [x;y]          ,acp [y;x] *)
      ]
      end

  let flatten () =
    alert "flatten"
      begin
        flatten ((x ++ y) ++ z) =
        acp [x;y;z]
      end


  (* see example 6 on the paper *)
  let acL x = ac_f "L" [x]
  and acT x = ac_f "T" [x]
  let example =
    [
     aczero ++ x       , x;
     x ++ x            , x;
     acL (acT x)       , acL x;
     acL (acT y ++ x)  , acL (x ++ y) ++ acL y;
     acT (acT x)       , acT x;
     acT x ++ x        , acT x;
     acT (x ++ y) ++ x , acT (x ++ y);
     acT (acT y ++ x)  , acT (x ++ y) ++ acT y
   ]

  (* TEST for formulae *)
  let is_sat t = Smt2.sat t
  let is_not_sat t = not (is_sat t )

  let f_geq_M_by () = alert "f_geq_M_by" @@
    begin
      is_sat
        [ f_geq_M_by 0 (=?) (>?)
            (pure [2;1]) (pure [1;2])
        ]
      && is_sat
        [
         f_geq_M_by 0 (=?) (>?)
           (pure [2;1]) (pure [1;1])
       ]
      && is_not_sat
        [ f_geq_M_by 0 (=?) (>?)
            (pure [2]) (pure [1;2])
        ]
    end

  let f_gt_M_by () = alert "f_gt_M_by" @@
    is_sat
      [ f_gt_M_by 0 (=?) (>?) (pure [2;1]) (pure [1;1;1])]

  let shallow_gtM m n =
    f_gt_M_with (fun _ _ -> True) m n

  let f_gt_lex () = alert "f_gt_lex" @@
    begin
      f_gt_lex_by (=) (>?) [1;2;1] [1;1;3;3]
        = (2 >? 1)
    end

  let f_gt_with () = alert "f_gt_with" @@
    let test =
      f_gt_with "F" (fun _ _ -> True) (acF []) (acG [])
    in begin
      is_sat
        [ test ]
      && is_not_sat
        [ f_var "G" >% f_var "F"; test ]
    end

  let f_big_head () = alert "f_big_head" @@
    let test =
      shallow_gtM
        (f_big_head "f" @@ acf [ acF []; acf [] ])
        (f_big_head "f" @@ acf [ acG []; acf [] ])
    in begin
      is_sat
        [ test ]
      && is_not_sat
        [ f_var "f" >% f_var "F"; f_var "G" >% f_var "f"; test ]
    end

  let f_var_occur () = alert "f_var_occur" @@
    begin
      is_sat
        [ OP ("+", f_var_occur (acF [x;y])) >%
          OP ("+", f_var_occur (acG [x])) ]
    && is_not_sat
        [ OP ("+", f_var_occur (acF [x;y])) >%
          OP ("+", f_var_occur (acG [x;y])) ]
    end

  let f_no_small_head () = alert "f_no_small_head" @@
    let test =
      f_gt_M_with (fun _ _ -> True)
        (f_no_small_head "f" @@ acf [ acF [] ])
        (f_no_small_head "f" @@ acf [ acG [] ])
    in begin
      is_sat
        [ test ]
      && is_not_sat
        [ f_var "G" >% f_var "f"; f_var "f" >% f_var "F"; test ]
    end

  let f_emb_no_big () = alert "f_emb_no_big" @@
    let test = shallow_gtM
        (f_emb_no_big (acF [ acf [aczero] ]))
        (f_emb_no_big (acG [ acf [aczero] ]))
    in begin
      is_sat
        [ test ]
      && is_not_sat
        [ f_var "f" >% f_var "F"; f_var "G" >% f_var "f"; test ]
    end

  (* TODO: real tests for the individual AC-RPO rules; these placeholders
     always report [false]. *)
  let f_acrpo_1  () = alert "f_acrpo_1" @@
    false
  let f_acrpo_2  () = alert "f_acrpo_2" @@
    false
  let f_acrpo_34 () = alert "f_acrpo_34" @@
    false
  let f_acrpo_5  () = alert "f_acrpo_5" @@
    false
  let f_acrpo_6  () = alert "f_acrpo_6" @@
    false

  (* AC-RPO is defined on flattened terms, so flatten the example first *)
  let f_acrpo_trs () = alert "f_acrpo_trs" @@
      is_sat [f_acrpo_trs (trs_flatten example)]

  (* runs every test above, including the translation tests *)
  let f_run () = List.iter (fun test -> test ())
      [ to_AC; flatten;
        f_geq_M_by; f_gt_M_by; f_gt_lex; f_gt_with;
        f_big_head; f_var_occur; f_no_small_head; f_emb_no_big;
        f_acrpo_1; f_acrpo_2; f_acrpo_34; f_acrpo_5; f_acrpo_6;
        f_acrpo_trs]

end
