open Form
open Format
open Console

(** logical formula to SMT2 formula converter **)
(* Target logic of the generated script: linear integer arithmetic (LIA) or
   linear real arithmetic (LRA); [gen_logic] emits QF_LIA / QF_LRA. *)
type t = LIA | LRA

(* util functions *)
(* [command c args] builds the s-expression "(c a1 a2 ...)".  Note that an
   empty [args] yields "(c )" with a trailing space before the paren. *)
let command c args =
  let body = String.concat " " args in
  "(" ^ c ^ " " ^ body ^ ")"

(* convert to SMT2 format *)
(* Render a term (a leaf of a formula) in SMT2 concrete syntax.  A variable's
   type annotation is dropped; declarations are produced separately by
   [definition_to_string].
   NOTE(review): bit-vectors are printed as "0b..." (Yices style); SMT-LIB2
   proper writes binary literals as "#b..." -- confirm the target solver. *)
let rec term_to_string = function
  | A atom -> atom
  | V (name, _) -> name
  | BV bits ->
      let bit_char b = if b then "1" else "0" in
      "0b" ^ String.concat "" (List.map bit_char bits)
  | Tuple parts -> command "mk-tuple" (List.map form_to_string1 parts)

(* Render a formula as an SMT2 s-expression, without simplifying it first.
   Empty conjunctions / disjunctions are printed as the literals "true" /
   "false", so no nullary "(and )" / "(or )" is ever emitted. *)
and form_to_string1 form =
  let binary op a b = command op [ form_to_string1 a; form_to_string1 b ] in
  match form with
  | True | C [] -> "true"
  | False | D [] -> "false"
  | Term t -> term_to_string t
  | N f -> command "not" [ form_to_string1 f ]
  | C cnf -> command "and" (List.map form_to_string1 cnf)
  | D dnf -> command "or" (List.map form_to_string1 dnf)
  | OP (op, fs) -> command op (List.map form_to_string1 fs)
  | IMP (a, b) -> binary "=>" a b
  | EQ (a, b) -> binary "=" a b
  | GT (a, b) -> binary ">" a b
  | LT (a, b) -> binary "<" a b

(* Simplify [f] with [Form.simply], then render the result as SMT2 text. *)
let form_to_string f =
  let simplified = Form.simply f in
  form_to_string1 simplified

(* generate definition string *)
(* SMT2 sort name for a variable type.  NAT has no case here and falls into
   the error branch; [definition_to_string] handles NAT on its own.
   NOTE(review): "(bitvector n)" is Yices syntax; SMT-LIB2 writes the sort as
   "(_ BitVec n)" -- confirm the target solver.
   @raise Invalid_argument for any type not listed below. *)
let type_to_string ty =
  match ty with
  | BITVECTOR width -> sprintf "(bitvector %d)" width
  | REAL -> "Real"
  | INT -> "Int"
  | BOOL -> "Bool"
  | _ -> invalid_arg "Smt2.type_to_string: bad definition"
(* Render a typed variable term as SMT2 declaration text.
   A NAT variable is declared with the Int sort, followed by an assertion
   that it is >= 0.  Every other variable is declared with the sort given by
   [type_to_string].
   @raise Invalid_argument if the term is not a variable, or (through
   [type_to_string]) if its type has no SMT2 sort. *)
let definition_to_string = function
  | V (v, NAT) ->
      sprintf "(declare-const %s %s)(assert (>= %s 0))" v (type_to_string INT) v
  | V (v, t) -> sprintf "(declare-const %s %s)" v (type_to_string t)
  | _ -> invalid_arg "Smt2.definition_to_string: bad definition"

(****************** GENERATOR *****************)

(** generate typed variable **)
(* Typed variable terms for the names [xs], passed through [Ls.uniq].
   With [version = 1] each name becomes a NAT-typed variable.  With any other
   version it becomes an INT-typed variable, and a "x > -1" constraint is
   appended for every name. *)
let def_nat version xs =
  let int_var x = Term (V (x, INT)) in
  let defs =
    match version with
    | 1 -> List.map (fun x -> Term (V (x, NAT))) xs
    | _ ->
        let vars = List.map int_var xs in
        let bounds = List.map (fun x -> GT (int_var x, Term (A "-1"))) xs in
        vars @ bounds
  in
  Ls.uniq defs

(* BOOL-typed variable terms for the names [css], passed through [Ls.uniq]. *)
let def_bool css =
  css |> List.map (fun cs -> Term (V (cs, BOOL))) |> Ls.uniq

(* INT-typed variable terms for the names [css], passed through [Ls.uniq]. *)
let def_int css =
  css |> List.map (fun cs -> Term (V (cs, INT))) |> Ls.uniq

(* definition generator *)
(* Declaration strings for the variables occurring in [fs].  The variables of
   all formulas are collected in order, passed through [Ls.uniq] (presumably
   deduplicating them), and each one is rendered by [definition_to_string]. *)
let gen_definition fs =
  let occurring = List.concat (List.map variables fs) in
  List.map definition_to_string (Ls.uniq occurring)

(** assertion **)
(* One "(assert ...)" line per formula of [fs], passed through [Ls.uniq]. *)
let gen_assert fs =
  let to_assertion f = command "assert" [ form_to_string f ] in
  Ls.uniq (List.map to_assertion fs)


(** set-logic header; [smt2] uses linear integer arithmetic (QF_LIA) by default **)
(* The "(set-logic ...)" header line for the chosen logic. *)
let gen_logic logic =
  let name = match logic with LIA -> "QF_LIA" | LRA -> "QF_LRA" in
  [ "(set-logic " ^ name ^ ")" ]

(** checking and modeling command **)
(* Trailing commands: ask for satisfiability, then for a model. *)
let gen_instruction =
  List.map (fun c -> "(" ^ c ^ ")") [ "check-sat"; "get-model" ]


(****************** TRANSLATE TOOL *****************)
(* all in one *)
(* Complete SMT2 script for [fs] under logic [t], as a list of lines: the
   logic header, the variable declarations, the assertions, and then the
   check-sat / get-model commands. *)
let to_string t fs =
  List.concat
    [ gen_logic t; gen_definition fs; gen_assert fs; gen_instruction ]

(* abbreviations *)
(* SMT2 script for [fs] using linear integer arithmetic (QF_LIA). *)
let smt2 fs = fs |> to_string LIA

(* SMT2 script for [fs] using linear real arithmetic (QF_LRA). *)
let smt2_r fs = fs |> to_string LRA


(** SAT FUNCTIONS **)
(* --------------------------------------------- *)
open Lexing
(* Report a parse failure at position [p] on stderr (naming [path]) and
   terminate the program with exit code 1.  Never returns. *)
let syntax_error path p =
  let line = p.pos_lnum and column = p.pos_cnum - p.pos_bol in
  eprintf "SMT2 file %S at line %d, character %d:@.Syntax error.@."
    path line column;
  exit 1

(* Run [solver] on the SMT2 script [lines] and parse the solver's output with
   [parser0] (driven by [Smt2_lexer.evidence]).
   [run_func] receives a continuation, the opened temp file [(path, ch)] and
   [lines].  Presumably it writes [lines] to the file and calls the
   continuation with the file name (TODO confirm against Console).  The
   continuation executes "solver <file>" and parses the output.
   On a parse error the program exits via [syntax_error]. *)
let run_and_parse parser0 solver run_func (path,ch) lines =
  let parse output =
    (* DEBUG: echo the script and the raw solver answer *)
    if !Ref.verbose then begin
      eprintf "-- SMT file is\n%s@." (run ("cat " ^ path));
      (* name the solver actually run; it differs from !Ref.smt_solver
         when the caller passed ?solver *)
      eprintf "-- output of %s is\n%s@." solver output
    end;
    let lexbuf = Lexing.from_string output in
    try parser0 Smt2_lexer.evidence lexbuf
    with Parsing.Parse_error -> syntax_error path lexbuf.lex_curr_p
  in
  (* RUN! *)
  run_func
    (fun tmp -> parse (run (String.concat " " [solver; tmp])))
    (path,ch)
    lines

(* return the evidence if sat (only the Yices output format is supported) *)
(* Solve [lines] with [solver] and parse the answer with
   [Smt2_parser.sat_proof], which yields the evidence when the answer is sat. *)
let sat_with_proof solver run_func file lines =
  let parser0 = Smt2_parser.sat_proof in
  run_and_parse parser0 solver run_func file lines

(* return boolean *)
(* Solve [lines] with [solver] and parse the answer with [Smt2_parser.sat]
   (presumably a plain sat / unsat verdict). *)
let sat_with_no_proof solver run_func file lines =
  let parser0 = Smt2_parser.sat in
  run_and_parse parser0 solver run_func file lines

(* save as temporary file *)
(* Check [fs] (as a QF_LIA script) with [solver], keeping the script in a
   fresh temp file via the saving runner.  The suffix is ".smt2" (with the
   dot), matching [sat]; previously it was "smt2", which produced file names
   like "coll1234smt2" with no extension. *)
let sat_save solver fs =
  sat_with_no_proof
    solver
    run_with_save
    (Filename.open_temp_file "coll" ".smt2")
    (smt2 fs)

(* create and open temp file with tmp dir *)
(* Open a fresh temp file "<pre>XXXX<suf>" inside directory [tmp], first
   trying to create [tmp] (mode 0o775).  Any mkdir failure, e.g. the directory
   already existing, is ignored on purpose; if [tmp] is really unusable,
   [Filename.open_temp_file] raises [Sys_error] instead.
   Returns the file name and its output channel. *)
let open_temp_here tmp pre suf =
  let ensure_dir dir =
    try Unix.mkdir dir 0o775 with Unix.Unix_error _ -> ()
  in
  ensure_dir tmp;
  Filename.open_temp_file ~temp_dir:tmp pre suf

(* if a dump directory is set, use the non-removing runner so the SMT2 file is kept there *)
(* Satisfiability check of [fs] (QF_LIA) with [solver], which defaults to
   !Ref.smt_solver.  When !Ref.dump_file names a directory, the script goes
   there and the saving runner is used; otherwise a plain temp file and
   [run_with] are used. *)
let sat ?(solver=(!Ref.smt_solver)) fs =
  let lines = smt2 fs in
  let check runner file = sat_with_no_proof solver runner file lines in
  match !Ref.dump_file with
  | None -> check run_with (Filename.open_temp_file "coll" ".smt2")
  | Some temp_dir ->
      check run_with_save (open_temp_here temp_dir "coll" ".smt2")


(** CHECK **)
(* Run the configured solver on an empty problem (smoke test of the solver
   set-up). *)
let sat_check () = sat ~solver:(!Ref.smt_solver) []


(* if a dump directory is set, use the non-removing runner so the SMT2 file is kept there *)
(* Like [sat], but parses the answer with [Smt2_parser.sat_proof] so the
   evidence is returned when the answer is sat.  [solver] defaults to
   !Ref.smt_solver.  When !Ref.dump_file names a directory, the script is kept
   there via the saving runner.  The temp-file suffix is ".smt2" (with the
   dot), consistent with [sat]; previously it was "smt2", which produced
   extension-less names like "coll1234smt2". *)
let sat_proof ?(solver=(!Ref.smt_solver)) fs =
  let lines = smt2 fs in
  match !Ref.dump_file with
  | Some temp_dir ->
      sat_with_proof solver run_with_save (open_temp_here temp_dir "coll" ".smt2") lines
  | None ->
      sat_with_proof solver run_with (Filename.open_temp_file "coll" ".smt2") lines


(** for debug **)
(* Flatten nested conjunctions: a [C] node is replaced by its (recursively
   flattened) conjuncts; any other formula is kept as a singleton. *)
let rec simple' = function
  | C conjuncts -> List.concat (List.map simple' conjuncts)
  | other -> [ other ]

(* Flatten the conjunctions of every formula in [fs] into one list. *)
let simple fs = List.concat (List.map simple' fs)

(* Print the QF_LIA SMT2 script for [fs] to stdout, one line per command. *)
let show fs =
  let script = smt2 fs in
  List.iter print_endline script
