From e2f4fc23076032935daa85e305d0453dc1470d73 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 28 Aug 2022 17:44:33 -0700 Subject: [PATCH] overhaul of proof format for new solver This commit overhauls the proof format (in development) for the new core. NOTE: this functionality is work in progress with a long way to go. It is shielded by the sat.euf option, which is off by default and in pre-release state. It is too early to fuzz or use it. It is pushed into master to shed light on road-map for certifying inferences of sat.euf. It retires the ad-hoc extension of DRUP used by the SAT solver. Instead it relies on SMT with ad-hoc extensions for proof terms. It adds the following commands (consumed by proof_cmds.cpp): - assume - for input clauses - learn - when a clause is learned (or redundant clause is added) - del - when a clause is deleted. The commands take a list of expressions of type Bool and the last argument can optionally be of type Proof. When the last argument is of type Proof it is provided as a hint to justify the learned clause. Proof hints can be checked using a self-contained proof checker. The sat/smt/euf_proof_checker.h class provides a plugin dispatcher for checkers. It is instantiated with a checker for arithmetic lemmas, so far for Farkas proofs. Use example: ``` (set-option :sat.euf true) (set-option :tactic.default_tactic smt) (set-option :sat.smt.proof f.proof) (declare-const x Int) (declare-const y Int) (declare-const z Int) (declare-const u Int) (assert (< x y)) (assert (< y z)) (assert (< z x)) (check-sat) ``` Run z3 on a file with above content. Then run z3 on f.proof ``` (verified-smt) (verified-smt) (verified-smt) (verified-farkas) (verified-smt) ``` --- src/ast/ast.cpp | 4 +- src/ast/ast_pp_util.cpp | 6 +- src/ast/scoped_proof.h | 4 +- src/cmd_context/CMakeLists.txt | 1 + src/cmd_context/cmd_context.cpp | 1 + src/cmd_context/cmd_context.h | 5 + src/cmd_context/proof_cmds.cpp | 207 ++++++++++++++ src/cmd_context/proof_cmds.h | 49 ++++ src/sat/dimacs.cpp | 119 +------- src/sat/dimacs.h | 7 - src/sat/sat_config.cpp | 8 +- src/sat/sat_config.h | 2 +- src/sat/sat_drat.cpp | 246 ++-------------- src/sat/sat_drat.h | 18 +- src/sat/sat_params.pyg | 2 +- src/sat/sat_solver.cpp | 4 +- src/sat/sat_types.h | 19 +- src/sat/smt/CMakeLists.txt | 1 + src/sat/smt/arith_axioms.cpp | 9 +- src/sat/smt/arith_diagnostics.cpp | 68 ++++- src/sat/smt/arith_proof_checker.h | 58 +++- src/sat/smt/arith_solver.cpp | 21 +- src/sat/smt/arith_solver.h | 63 ++++- src/sat/smt/bv_solver.cpp | 4 +- src/sat/smt/euf_proof.cpp | 263 ++++++++--------- src/sat/smt/euf_proof_checker.cpp | 48 ++++ src/sat/smt/euf_proof_checker.h | 46 +++ src/sat/smt/euf_solver.cpp | 3 - src/sat/smt/euf_solver.h | 51 ++-- src/sat/smt/q_solver.cpp | 14 +- src/sat/smt/sat_th.cpp | 38 +-- src/sat/smt/sat_th.h | 34 ++- src/sat/tactic/goal2sat.cpp | 8 - src/shell/drat_frontend.cpp | 450 ++++-------------------------- src/shell/drat_frontend.h | 2 +- src/shell/main.cpp | 2 +- src/shell/smtlib_frontend.cpp | 2 + 37 files changed, 809 insertions(+), 1078 deletions(-) create mode 100644 src/cmd_context/proof_cmds.cpp create mode 100644 src/cmd_context/proof_cmds.h create mode 100644 src/sat/smt/euf_proof_checker.cpp create mode 100644 src/sat/smt/euf_proof_checker.h diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index c51ea4e32..f60512e63 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -889,8 +889,10 @@ void basic_decl_plugin::set_manager(ast_manager * m, family_id id) { } void basic_decl_plugin::get_sort_names(svector & sort_names, symbol const & logic) { - if (logic == symbol::null) + if (logic == symbol::null) { sort_names.push_back(builtin_name("bool", BOOL_SORT)); + sort_names.push_back(builtin_name("Proof", PROOF_SORT)); // reserved name? + } sort_names.push_back(builtin_name("Bool", BOOL_SORT)); } diff --git a/src/ast/ast_pp_util.cpp b/src/ast/ast_pp_util.cpp index 76d50cbe6..a74566199 100644 --- a/src/ast/ast_pp_util.cpp +++ b/src/ast/ast_pp_util.cpp @@ -175,14 +175,14 @@ std::ostream& ast_pp_util::define_expr(std::ostream& out, expr* n) { visit.pop_back(); if (to_app(n)->get_num_args() > 0) { out << "(define-const $" << n->get_id() << " " << mk_pp(n->get_sort(), m) << " ("; - out << to_app(n)->get_name(); // fixme + out << mk_ismt2_func(to_app(n)->get_decl(), m); for (auto* e : *to_app(n)) display_expr_def(out << " ", e); - out << ")\n"; + out << "))\n"; } continue; } - out << "(define-const $" << n->get_id() << " " << mk_pp(n->get_sort(), m) << " " << mk_pp(n, m) << "\n"; + out << "(define-const $" << n->get_id() << " " << mk_pp(n->get_sort(), m) << " " << mk_pp(n, m) << ")\n"; m_defined.push_back(n); m_is_defined.mark(n, true); visit.pop_back(); diff --git a/src/ast/scoped_proof.h b/src/ast/scoped_proof.h index c8071031c..7943c6eb9 100644 --- a/src/ast/scoped_proof.h +++ b/src/ast/scoped_proof.h @@ -29,8 +29,8 @@ public: m.toggle_proof_mode(mode); } ~scoped_proof_mode() { - m.toggle_proof_mode(m_mode); - } + m.toggle_proof_mode(m_mode); + } }; diff --git a/src/cmd_context/CMakeLists.txt b/src/cmd_context/CMakeLists.txt index f8c1aa38f..2c65fd374 100644 --- a/src/cmd_context/CMakeLists.txt +++ b/src/cmd_context/CMakeLists.txt @@ -11,6 +11,7 @@ z3_add_component(cmd_context simplify_cmd.cpp tactic_cmds.cpp tactic_manager.cpp + proof_cmds.cpp COMPONENT_DEPENDENCIES rewriter solver diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index 2b19041d0..97ae0fbc1 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -561,6 +561,7 @@ cmd_context::~cmd_context() { finalize_cmds(); finalize_tactic_cmds(); finalize_probes(); + m_proof_cmds = nullptr; reset(true); m_mcs.reset(); m_solver = nullptr; diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index a51820ac1..f3f593bd6 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -39,6 +39,7 @@ Notes: #include "solver/progress_callback.h" #include "cmd_context/pdecl.h" #include "cmd_context/tactic_manager.h" +#include "cmd_context/proof_cmds.h" #include "params/context_params.h" @@ -172,6 +173,7 @@ public: bool owns_manager() const { return m_manager != nullptr; } }; + class cmd_context : public progress_callback, public tactic_manager, public ast_printer_context { public: enum status { @@ -225,6 +227,7 @@ protected: bool m_ignore_check = false; // used by the API to disable check-sat() commands when parsing SMT 2.0 files. bool m_exit_on_error = false; bool m_allow_duplicate_declarations = false; + scoped_ptr m_proof_cmds; static std::ostringstream g_error_stream; @@ -397,6 +400,8 @@ public: pdecl_manager & pm() const { if (!m_pmanager) const_cast(this)->init_manager(); return *m_pmanager; } sexpr_manager & sm() const { if (!m_sexpr_manager) const_cast(this)->m_sexpr_manager = alloc(sexpr_manager); return *m_sexpr_manager; } + proof_cmds& get_proof_cmds() { if (!m_proof_cmds) m_proof_cmds = proof_cmds::mk(m()); return *m_proof_cmds; } + void set_solver_factory(solver_factory * s); void set_check_sat_result(check_sat_result * r) { m_check_sat_result = r; } check_sat_result * get_check_sat_result() const { return m_check_sat_result.get(); } diff --git a/src/cmd_context/proof_cmds.cpp b/src/cmd_context/proof_cmds.cpp new file mode 100644 index 000000000..ae64ef701 --- /dev/null +++ b/src/cmd_context/proof_cmds.cpp @@ -0,0 +1,207 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + proof_cmds.cpp + +Abstract: + + Commands for reading and checking proofs. + +Author: + + Nikolaj Bjorner (nbjorner) 2022-8-26 + +Notes: + +- add theory hint bypass using proof checker plugins of SMT + - arith_proof_checker.h is currently +- could use m_drat for drup premises. + +--*/ + +#include "util/small_object_allocator.h" +#include "ast/ast_util.h" +#include "cmd_context/cmd_context.h" +#include "smt/smt_solver.h" +#include "sat/sat_solver.h" +#include "sat/sat_drat.h" +#include "sat/smt/euf_proof_checker.h" +#include + +class smt_checker { + ast_manager& m; + params_ref m_params; + euf::proof_checker m_checker; + + scoped_ptr m_solver; + +#if 0 + sat::solver sat_solver; + sat::drat m_drat; + sat::literal_vector m_units; + sat::literal_vector m_drup_units; + + void add_units() { + auto const& units = m_drat.units(); + for (unsigned i = m_units.size(); i < units.size(); ++i) + m_units.push_back(units[i].first); + } +#endif + +public: + smt_checker(ast_manager& m): + m(m), + m_checker(m) + // sat_solver(m_params, m.limit()), + // m_drat(sat_solver) + { + m_solver = mk_smt_solver(m, m_params, symbol()); + } + + void check(expr_ref_vector const& clause, expr* proof_hint) { + + if (m_checker.check(clause, proof_hint)) { + if (is_app(proof_hint)) + std::cout << "(verified-" << to_app(proof_hint)->get_name() << ")\n"; + else + std::cout << "(verified-checker)\n"; + return; + } + + m_solver->push(); + for (expr* lit : clause) + m_solver->assert_expr(m.mk_not(lit)); + lbool is_sat = m_solver->check_sat(); + if (is_sat != l_false) { + std::cout << "did not verify: " << is_sat << " " << clause << "\n\n"; + m_solver->display(std::cout); + if (is_sat == l_true) { + model_ref mdl; + m_solver->get_model(mdl); + std::cout << *mdl << "\n"; + } + exit(0); + } + m_solver->pop(1); + std::cout << "(verified-smt)\n"; + // assume(clause); + } + + void assume(expr_ref_vector const& clause) { + m_solver->assert_expr(mk_or(clause)); + } +}; + +class proof_cmds::imp { + ast_manager& m; + expr_ref_vector m_lits; + expr_ref m_proof_hint; + smt_checker m_checker; +public: + imp(ast_manager& m): m(m), m_lits(m), m_proof_hint(m), m_checker(m) {} + + void add_literal(expr* e) { + if (m.is_proof(e)) + m_proof_hint = e; + else + m_lits.push_back(e); + } + + void end_assumption() { + m_checker.assume(m_lits); + m_lits.reset(); + m_proof_hint.reset(); + } + + void end_learned() { + m_checker.check(m_lits, m_proof_hint); + m_lits.reset(); + m_proof_hint.reset(); + } + + void end_deleted() { + m_lits.reset(); + m_proof_hint.reset(); + } +}; + +proof_cmds* proof_cmds::mk(ast_manager& m) { + return alloc(proof_cmds, m); +} + +proof_cmds::proof_cmds(ast_manager& m) { + m_imp = alloc(imp, m); +} + +proof_cmds::~proof_cmds() { + dealloc(m_imp); +} + +void proof_cmds::add_literal(expr* e) { + m_imp->add_literal(e); +} + +void proof_cmds::end_assumption() { + m_imp->end_assumption(); +} + +void proof_cmds::end_learned() { + m_imp->end_learned(); +} + +void proof_cmds::end_deleted() { + m_imp->end_deleted(); +} + +// assumption +class assume_cmd : public cmd { +public: + assume_cmd():cmd("assume") {} + char const* get_usage() const override { return "+"; } + char const * get_descr(cmd_context& ctx) const override { return "proof command for adding assumption (input assertion)"; } + unsigned get_arity() const override { return VAR_ARITY; } + void prepare(cmd_context & ctx) override {} + void finalize(cmd_context & ctx) override {} + void failure_cleanup(cmd_context & ctx) override {} + cmd_arg_kind next_arg_kind(cmd_context & ctx) const override { return CPK_EXPR; } + void set_next_arg(cmd_context & ctx, expr * arg) override { ctx.get_proof_cmds().add_literal(arg); } + void execute(cmd_context& ctx) override { ctx.get_proof_cmds().end_assumption(); } +}; + +// deleted clause +class del_cmd : public cmd { +public: + del_cmd():cmd("del") {} + char const* get_usage() const override { return "+"; } + char const * get_descr(cmd_context& ctx) const override { return "proof command for clause deletion"; } + unsigned get_arity() const override { return VAR_ARITY; } + void prepare(cmd_context & ctx) override {} + void finalize(cmd_context & ctx) override {} + void failure_cleanup(cmd_context & ctx) override {} + cmd_arg_kind next_arg_kind(cmd_context & ctx) const override { return CPK_EXPR; } + void set_next_arg(cmd_context & ctx, expr * arg) override { ctx.get_proof_cmds().add_literal(arg); } + void execute(cmd_context& ctx) override { ctx.get_proof_cmds().end_deleted(); } +}; + +// learned/redundant clause +class learn_cmd : public cmd { +public: + learn_cmd():cmd("learn") {} + char const* get_usage() const override { return "+"; } + char const * get_descr(cmd_context& ctx) const override { return "proof command for learned (redundant) clauses"; } + unsigned get_arity() const override { return VAR_ARITY; } + void prepare(cmd_context & ctx) override {} + void finalize(cmd_context & ctx) override {} + void failure_cleanup(cmd_context & ctx) override {} + cmd_arg_kind next_arg_kind(cmd_context & ctx) const override { return CPK_EXPR; } + void set_next_arg(cmd_context & ctx, expr * arg) override { ctx.get_proof_cmds().add_literal(arg); } + void execute(cmd_context& ctx) override { ctx.get_proof_cmds().end_learned(); } +}; + +void install_proof_cmds(cmd_context & ctx) { + ctx.insert(alloc(del_cmd)); + ctx.insert(alloc(learn_cmd)); + ctx.insert(alloc(assume_cmd)); +} diff --git a/src/cmd_context/proof_cmds.h b/src/cmd_context/proof_cmds.h new file mode 100644 index 000000000..71589b371 --- /dev/null +++ b/src/cmd_context/proof_cmds.h @@ -0,0 +1,49 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + proof_cmds.h + +Abstract: + Commands for reading proofs. + +Author: + + Nikolaj Bjorner (nbjorner) 2022-8-26 + +Notes: + +--*/ +#pragma once + +/** + proof_cmds is a structure that tracks an evidence trail. + + The main interface is to: + add literals one by one, + add proof hints + until receiving end-command: assumption, learned, deleted. + Evidence can be checked: + - By DRUP + - Theory lemmas + +*/ + + +class proof_cmds { + class imp; + imp* m_imp; +public: + static proof_cmds* mk(ast_manager& m); + proof_cmds(ast_manager& m); + ~proof_cmds(); + void add_literal(expr* e); + void end_assumption(); + void end_learned(); + void end_deleted(); +}; + +class cmd_context; +void install_proof_cmds(cmd_context & ctx); + diff --git a/src/sat/dimacs.cpp b/src/sat/dimacs.cpp index 629fefb7b..1d19a60c7 100644 --- a/src/sat/dimacs.cpp +++ b/src/sat/dimacs.cpp @@ -112,27 +112,6 @@ static void read_clause(Buffer & in, std::ostream& err, sat::literal_vector & li } } -template -static void read_pragma(Buffer & in, std::ostream& err, std::string& p, sat::proof_hint& h) { - skip_whitespace(in); - if (*in != 'p') - return; - ++in; - while (*in == ' ') - ++in; - while (true) { - if (*in == EOF) - break; - if (*in == '\n') { - ++in; - break; - } - p.push_back(*in); - ++in; - } - if (!p.empty()) - h.from_string(p); -} template @@ -177,25 +156,7 @@ namespace dimacs { std::ostream& operator<<(std::ostream& out, drat_pp const& p) { auto const& r = p.r; sat::status_pp pp(r.m_status, p.th); - switch (r.m_tag) { - case drat_record::tag_t::is_clause: - if (!r.m_pragma.empty()) - return out << pp << " " << r.m_lits << " 0 p " << r.m_pragma << "\n"; - return out << pp << " " << r.m_lits << " 0\n"; - case drat_record::tag_t::is_node: - return out << "e " << r.m_node_id << " " << r.m_name << " " << r.m_args << "0\n"; - case drat_record::tag_t::is_sort: - return out << "s " << r.m_node_id << " " << r.m_name << " " << r.m_args << "0\n"; - case drat_record::tag_t::is_decl: - return out << "f " << r.m_node_id << " " << r.m_name << " " << r.m_args << "0\n"; - case drat_record::tag_t::is_bool_def: - return out << "b " << r.m_node_id << " " << r.m_args << "0\n"; - case drat_record::tag_t::is_var: - return out << "v " << r.m_node_id << " " << r.m_name << " " << r.m_args << "0\n"; - case drat_record::tag_t::is_quantifier: - return out << "q " << r.m_node_id << " " << r.m_name << " " << r.m_args << "0\n"; - } - return out; + return out << pp << " " << r.m_lits << " 0\n"; } char const* drat_parser::parse_identifier() { @@ -266,47 +227,10 @@ namespace dimacs { } bool drat_parser::next() { - int n, b, e, theory_id; - auto parse_ast = [&](drat_record::tag_t tag) { - ++in; - skip_whitespace(in); - n = parse_int(in, err); - skip_whitespace(in); - m_record.m_name = parse_sexpr(); - m_record.m_tag = tag; - m_record.m_node_id = n; - m_record.m_args.reset(); - while (true) { - n = parse_int(in, err); - if (n == 0) - break; - if (n < 0) - throw lex_error(); - m_record.m_args.push_back(n); - } - }; - auto parse_var = [&]() { - ++in; - skip_whitespace(in); - n = parse_int(in, err); - skip_whitespace(in); - m_record.m_name = parse_sexpr(); - m_record.m_tag = drat_record::tag_t::is_var; - m_record.m_node_id = n; - m_record.m_args.reset(); - n = parse_int(in, err); - if (n < 0) - throw lex_error(); - m_record.m_args.push_back(n); - n = parse_int(in, err); - if (n != 0) - throw lex_error(); - }; + int theory_id; try { loop: skip_whitespace(in); - m_record.m_pragma.clear(); - m_record.m_hint.reset(); switch (*in) { case EOF: return false; @@ -321,7 +245,6 @@ namespace dimacs { ++in; skip_whitespace(in); read_clause(in, err, m_record.m_lits); - m_record.m_tag = drat_record::tag_t::is_clause; m_record.m_status = sat::status::input(); break; case 'a': @@ -331,49 +254,13 @@ namespace dimacs { theory_id = read_theory_id(); skip_whitespace(in); read_clause(in, err, m_record.m_lits); - read_pragma(in, err, m_record.m_pragma, m_record.m_hint); - m_record.m_tag = drat_record::tag_t::is_clause; m_record.m_status = sat::status::th(false, theory_id); break; - case 'e': - // parse expression definition - parse_ast(drat_record::tag_t::is_node); - break; - case 'v': - parse_var(); - break; - case 'q': - parse_ast(drat_record::tag_t::is_quantifier); - break; - case 'f': - // parse function declaration - parse_ast(drat_record::tag_t::is_decl); - break; - case 's': - // parse sort declaration (not used) - parse_ast(drat_record::tag_t::is_sort); - break; - case 'b': - // parse bridge between Boolean variable identifier b - // and expression identifier e, which is of type Bool - ++in; - skip_whitespace(in); - b = parse_int(in, err); - n = parse_int(in, err); - e = parse_int(in, err); - if (e != 0) - throw lex_error(); - m_record.m_tag = drat_record::tag_t::is_bool_def; - m_record.m_node_id = b; - m_record.m_args.reset(); - m_record.m_args.push_back(n); - break; case 'd': // parse clause deletion ++in; skip_whitespace(in); read_clause(in, err, m_record.m_lits); - m_record.m_tag = drat_record::tag_t::is_clause; m_record.m_status = sat::status::deleted(); break; case 'r': @@ -383,13 +270,11 @@ namespace dimacs { skip_whitespace(in); theory_id = read_theory_id(); read_clause(in, err, m_record.m_lits); - m_record.m_tag = drat_record::tag_t::is_clause; m_record.m_status = sat::status::th(true, theory_id); break; default: // parse clause redundant modulo DRAT (or mostly just DRUP) read_clause(in, err, m_record.m_lits); - m_record.m_tag = drat_record::tag_t::is_clause; m_record.m_status = sat::status::redundant(); break; } diff --git a/src/sat/dimacs.h b/src/sat/dimacs.h index 7a5a66283..ca6aae07b 100644 --- a/src/sat/dimacs.h +++ b/src/sat/dimacs.h @@ -53,18 +53,11 @@ namespace dimacs { }; struct drat_record { - enum class tag_t { is_clause, is_node, is_decl, is_sort, is_bool_def, is_var, is_quantifier }; - tag_t m_tag{ tag_t::is_clause }; // a clause populates m_lits and m_status // a node populates m_node_id, m_name, m_args // a bool def populates m_node_id and one element in m_args sat::literal_vector m_lits; sat::status m_status = sat::status::redundant(); - unsigned m_node_id = 0; - std::string m_name; - unsigned_vector m_args; - std::string m_pragma; - sat::proof_hint m_hint; }; struct drat_pp { diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index b911ee971..d21ec5b93 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -197,10 +197,10 @@ namespace sat { m_drat_check_unsat = p.drat_check_unsat(); m_drat_check_sat = p.drat_check_sat(); m_drat_file = p.drat_file(); - m_drat = !p.drat_disable() && (sp.lemmas2console() || m_drat_check_unsat || m_drat_file.is_non_empty_string() || m_drat_check_sat) && p.threads() == 1; + m_smt_proof = p.smt_proof(); + m_drat = !p.drat_disable() && (sp.lemmas2console() || m_drat_check_unsat || m_drat_file.is_non_empty_string() || m_smt_proof.is_non_empty_string() || m_drat_check_sat) && p.threads() == 1; m_drat_binary = p.drat_binary(); m_drat_activity = p.drat_activity(); - m_drup_trim = p.drup_trim(); m_dyn_sub_res = p.dyn_sub_res(); // Parameters used in Liang, Ganesh, Poupart, Czarnecki AAAI 2016. @@ -254,10 +254,6 @@ namespace sat { sat_simplifier_params ssp(_p); m_elim_vars = ssp.elim_vars(); -#if 0 - if (m_drat && (m_xor_solver || m_card_solver)) - throw sat_param_exception("DRAT checking only works for pure CNF"); -#endif } void config::collect_param_descrs(param_descrs & r) { diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index 34ffeed5c..7d98b092c 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -177,9 +177,9 @@ namespace sat { bool m_drat; bool m_drat_binary; symbol m_drat_file; + symbol m_smt_proof; bool m_drat_check_unsat; bool m_drat_check_sat; - bool m_drup_trim; bool m_drat_activity; bool m_card_solver; diff --git a/src/sat/sat_drat.cpp b/src/sat/sat_drat.cpp index 8e63b4a19..93f4c2e91 100644 --- a/src/sat/sat_drat.cpp +++ b/src/sat/sat_drat.cpp @@ -52,8 +52,7 @@ namespace sat { void drat::updt_config() { m_check_unsat = s.get_config().m_drat_check_unsat; m_check_sat = s.get_config().m_drat_check_sat; - m_trim = s.get_config().m_drup_trim; - m_check = m_check_unsat || m_check_sat || m_trim; + m_check = m_check_unsat || m_check_sat; m_activity = s.get_config().m_drat_activity; } @@ -130,14 +129,6 @@ namespace sat { } } buffer[len++] = '0'; - if (st.get_hint()) { - buffer[len++] = ' '; - buffer[len++] = 'p'; - buffer[len++] = ' '; - auto* ps = st.get_hint(); - for (auto ch : ps->to_string()) - buffer[len++] = ch; - } buffer[len++] = '\n'; m_out->write(buffer, len); } @@ -210,8 +201,6 @@ namespace sat { if (st.is_redundant() && st.is_sat()) verify(1, &l); - if (m_trim) - m_proof.push_back({mk_clause(1, &l, st.is_redundant()), st}); if (st.is_deleted()) return; @@ -230,8 +219,7 @@ namespace sat { IF_VERBOSE(20, trace(verbose_stream(), 2, lits, st);); if (st.is_deleted()) { - if (m_trim) - m_proof.push_back({mk_clause(2, lits, true), st}); + ; } else { if (st.is_redundant() && st.is_sat()) @@ -658,7 +646,7 @@ namespace sat { verify(0, nullptr); SASSERT(m_inconsistent); } - if (m_print_clause) m_print_clause->on_clause(0, nullptr, status::redundant()); + if (m_clause_eh) m_clause_eh->on_clause(0, nullptr, status::redundant()); } void drat::add(literal l, bool learned) { ++m_stats.m_num_add; @@ -666,7 +654,8 @@ namespace sat { if (m_out) dump(1, &l, st); if (m_bout) bdump(1, &l, st); if (m_check) append(l, st); - if (m_print_clause) m_print_clause->on_clause(1, &l, st); + TRACE("sat", tout << "add " << m_clause_eh << "\n"); + if (m_clause_eh) m_clause_eh->on_clause(1, &l, st); } void drat::add(literal l1, literal l2, status st) { if (st.is_deleted()) @@ -677,7 +666,7 @@ namespace sat { if (m_out) dump(2, ls, st); if (m_bout) bdump(2, ls, st); if (m_check) append(l1, l2, st); - if (m_print_clause) m_print_clause->on_clause(2, ls, st); + if (m_clause_eh) m_clause_eh->on_clause(2, ls, st); } void drat::add(clause& c, status st) { if (st.is_deleted()) @@ -687,7 +676,7 @@ namespace sat { if (m_out) dump(c.size(), c.begin(), st); if (m_bout) bdump(c.size(), c.begin(), st); if (m_check) append(mk_clause(c), st); - if (m_print_clause) m_print_clause->on_clause(c.size(), c.begin(), st); + if (m_clause_eh) m_clause_eh->on_clause(c.size(), c.begin(), st); } void drat::add(literal_vector const& lits, status st) { @@ -709,8 +698,8 @@ namespace sat { if (m_out) dump(sz, lits, st); - if (m_print_clause) - m_print_clause->on_clause(sz, lits, st); + if (m_clause_eh) + m_clause_eh->on_clause(sz, lits, st); } void drat::add(literal_vector const& c) { @@ -730,8 +719,8 @@ namespace sat { } } } - if (m_print_clause) - m_print_clause->on_clause(c.size(), c.data(), status::redundant()); + if (m_clause_eh) + m_clause_eh->on_clause(c.size(), c.data(), status::redundant()); } void drat::del(literal l) { @@ -739,6 +728,7 @@ namespace sat { if (m_out) dump(1, &l, status::deleted()); if (m_bout) bdump(1, &l, status::deleted()); if (m_check) append(l, status::deleted()); + if (m_clause_eh) m_clause_eh->on_clause(1, &l, status::deleted()); } void drat::del(literal l1, literal l2) { @@ -747,6 +737,7 @@ namespace sat { if (m_out) dump(2, ls, status::deleted()); if (m_bout) bdump(2, ls, status::deleted()); if (m_check) append(l1, l2, status::deleted()); + if (m_clause_eh) m_clause_eh->on_clause(2, ls, status::deleted()); } void drat::del(clause& c) { @@ -764,7 +755,8 @@ namespace sat { ++m_stats.m_num_del; if (m_out) dump(c.size(), c.begin(), status::deleted()); if (m_bout) bdump(c.size(), c.begin(), status::deleted()); - if (m_check) append(mk_clause(c), status::deleted()); + if (m_check) append(mk_clause(c), status::deleted()); + if (m_clause_eh) m_clause_eh->on_clause(c.size(), c.begin(), status::deleted()); } clause& drat::mk_clause(clause& c) { @@ -780,23 +772,9 @@ namespace sat { if (m_out) dump(c.size(), c.begin(), status::deleted()); if (m_bout) bdump(c.size(), c.begin(), status::deleted()); if (m_check) append(mk_clause(c.size(), c.begin(), true), status::deleted()); + if (m_clause_eh) m_clause_eh->on_clause(c.size(), c.begin(), status::deleted()); } - // - // placeholder for trim function. - // 1. trail contains justification for the empty clause. - // 2. backward pass to prune. - // - svector> drat::trim() { - SASSERT(m_units.empty()); - svector> proof; - for (auto const& [c, st] : m_proof) - if (!st.is_deleted()) - proof.push_back({c,st}); - return proof; - } - - void drat::check_model(model const& m) { } @@ -828,196 +806,4 @@ namespace sat { return out; } - - std::string proof_hint::to_string() const { - std::ostringstream ous; - switch (m_ty) { - case hint_type::null_h: - return std::string(); - case hint_type::farkas_h: - ous << "farkas "; - break; - case hint_type::bound_h: - ous << "bound "; - break; - case hint_type::implied_eq_h: - ous << "implied_eq "; - break; - default: - UNREACHABLE(); - break; - } - for (auto const& [q, l] : m_literals) - ous << rational(q) << " * " << l << " "; - for (auto const& [a, b] : m_eqs) - ous << " = " << a << " " << b << " "; - for (auto const& [a, b] : m_diseqs) - ous << " != " << a << " " << b << " "; - return ous.str(); - } - - void proof_hint::from_string(char const* s) { - proof_hint& h = *this; - h.reset(); - h.m_ty = hint_type::null_h; - if (!s) - return; - auto ws = [&]() { - while (*s == ' ' || *s == '\n' || *s == '\t') - ++s; - }; - - auto parse_type = [&]() { - if (0 == strncmp(s, "farkas", 6)) { - h.m_ty = hint_type::farkas_h; - s += 6; - return true; - } - if (0 == strncmp(s, "bound", 5)) { - h.m_ty = hint_type::bound_h; - s += 5; - return true; - } - if (0 == strncmp(s, "implied_eq", 10)) { - h.m_ty = hint_type::implied_eq_h; - s += 10; - return true; - } - return false; - }; - - sbuffer buff; - auto parse_coeff = [&]() { - buff.reset(); - while (*s && *s != ' ') { - buff.push_back(*s); - ++s; - } - buff.push_back(0); - return rational(buff.data()); - }; - - auto parse_literal = [&]() { - rational r = parse_coeff(); - if (!r.is_int()) - return sat::null_literal; - if (r < 0) - return sat::literal((-r).get_unsigned(), true); - return sat::literal(r.get_unsigned(), false); - }; - auto parse_coeff_literal = [&]() { - if (*s == '=') { - ++s; - ws(); - unsigned a = parse_coeff().get_unsigned(); - ws(); - unsigned b = parse_coeff().get_unsigned(); - h.m_eqs.push_back(std::make_pair(a, b)); - return true; - } - if (*s == '!' && *(s + 1) == '=') { - s += 2; - ws(); - unsigned a = parse_coeff().get_unsigned(); - ws(); - unsigned b = parse_coeff().get_unsigned(); - h.m_diseqs.push_back(std::make_pair(a, b)); - return true; - } - rational coeff = parse_coeff(); - ws(); - if (*s == '*') { - ++s; - ws(); - sat::literal lit = parse_literal(); - h.m_literals.push_back(std::make_pair(coeff, lit)); - return true; - } - return false; - }; - - ws(); - if (!parse_type()) - return; - ws(); - while (*s) { - if (!parse_coeff_literal()) - return; - ws(); - } - } - -#if 0 - // debugging code - bool drat::is_clause(clause& c, literal l1, literal l2, literal l3, drat::status st1, drat::status st2) { - //if (st1 != st2) return false; - if (c.size() != 3) return false; - if (l1 == c[0]) { - if (l2 == c[1] && l3 == c[2]) return true; - if (l2 == c[2] && l3 == c[1]) return true; - } - if (l2 == c[0]) { - if (l1 == c[1] && l3 == c[2]) return true; - if (l1 == c[2] && l3 == c[1]) return true; - } - if (l3 == c[0]) { - if (l1 == c[1] && l2 == c[2]) return true; - if (l1 == c[2] && l2 == c[1]) return true; - } - return false; - } -#endif - - -#if 0 - if (!m_inconsistent) { - literal_vector lits(n, c); - IF_VERBOSE(0, verbose_stream() << "not drup " << lits << "\n"); - for (unsigned v = 0; v < m_assignment.size(); ++v) { - lbool val = m_assignment[v]; - if (val != l_undef) { - IF_VERBOSE(0, verbose_stream() << literal(v, false) << " |-> " << val << "\n"); - } - } - for (clause* cp : s.m_clauses) { - clause& cl = *cp; - bool found = false; - for (literal l : cl) { - if (m_assignment[l.var()] != (l.sign() ? l_true : l_false)) { - found = true; - break; - } - } - if (!found) { - IF_VERBOSE(0, verbose_stream() << "Clause is false under assignment: " << cl << "\n"); - } - } - for (clause* cp : s.m_learned) { - clause& cl = *cp; - bool found = false; - for (literal l : cl) { - if (m_assignment[l.var()] != (l.sign() ? l_true : l_false)) { - found = true; - break; - } - } - if (!found) { - IF_VERBOSE(0, verbose_stream() << "Clause is false under assignment: " << cl << "\n"); - } - } - svector bin; - s.collect_bin_clauses(bin, true); - for (auto& b : bin) { - bool found = false; - if (m_assignment[b.first.var()] != (b.first.sign() ? l_true : l_false)) found = true; - if (m_assignment[b.second.var()] != (b.second.sign() ? l_true : l_false)) found = true; - if (!found) { - IF_VERBOSE(0, verbose_stream() << "Bin clause is false under assignment: " << b.first << " " << b.second << "\n"); - } - } - IF_VERBOSE(0, s.display(verbose_stream())); - exit(0); - } -#endif - } diff --git a/src/sat/sat_drat.h b/src/sat/sat_drat.h index 1ba86d724..452b69701 100644 --- a/src/sat/sat_drat.h +++ b/src/sat/sat_drat.h @@ -60,8 +60,8 @@ namespace sat { class justification; class clause; - struct print_clause { - virtual ~print_clause() {} + struct clause_eh { + virtual ~clause_eh() {} virtual void on_clause(unsigned, literal const*, status) = 0; }; @@ -78,7 +78,7 @@ namespace sat { watched_clause(clause* c, literal l1, literal l2): m_clause(c), m_l1(l1), m_l2(l2) {} }; - print_clause* m_print_clause = nullptr; + clause_eh* m_clause_eh = nullptr; svector m_watched_clauses; typedef svector watch; solver& s; @@ -95,7 +95,6 @@ namespace sat { bool m_check_sat = false; bool m_check = false; bool m_activity = false; - bool m_trim = false; stats m_stats; @@ -145,17 +144,10 @@ namespace sat { void add(literal_vector const& c); // add learned clause void add(unsigned sz, literal const* lits, status st); - void set_print_clause(print_clause& print_clause) { - m_print_clause = &print_clause; - } + void set_clause_eh(clause_eh& clause_eh) { m_clause_eh = &clause_eh; } - // support for SMT - connect Boolean variables with AST nodes - // associate AST node id with Boolean variable v - - // declare AST node n with 'name' and arguments arg std::ostream* out() { return m_out; } - bool is_cleaned(clause& c) const; void del(literal l); void del(literal l1, literal l2); @@ -181,8 +173,6 @@ namespace sat { svector> const& units() { return m_units; } bool is_drup(unsigned n, literal const* c, literal_vector& units); solver& get_solver() { return s; } - - svector> trim(); }; diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index f322d98f8..41bfa7afa 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -47,11 +47,11 @@ def_module_params('sat', ('threads', UINT, 1, 'number of parallel threads to use'), ('dimacs.core', BOOL, False, 'extract core from DIMACS benchmarks'), ('drat.disable', BOOL, False, 'override anything that enables DRAT'), + ('smt.proof', SYMBOL, '', 'add SMT proof to file'), ('drat.file', SYMBOL, '', 'file to dump DRAT proofs'), ('drat.binary', BOOL, False, 'use Binary DRAT output format'), ('drat.check_unsat', BOOL, False, 'build up internal proof and check'), ('drat.check_sat', BOOL, False, 'build up internal trace, check satisfying model'), - ('drup.trim', BOOL, False, 'build and trim drup proof'), ('drat.activity', BOOL, False, 'dump variable activities'), ('cardinality.solver', BOOL, True, 'use cardinality solver'), ('pb.solver', SYMBOL, 'solver', 'method for handling Pseudo-Boolean constraints: circuit (arithmetical circuit), sorting (sorting circuit), totalizer (use totalizer encoding), binary_merge, segmented, solver (use native solver)'), diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index d393cc94e..6a5f15567 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -402,6 +402,7 @@ namespace sat { extension::scoped_drating _sd(*m_ext.get()); if (j.get_kind() == justification::EXT_JUSTIFICATION) fill_ext_antecedents(lit, j, false); + TRACE("sat", tout << "drat-unit\n"); m_drat.add(lit, m_searching); } @@ -948,8 +949,7 @@ namespace sat { if (j.level() == 0) { if (m_config.m_drat) drat_log_unit(l, j); - if (!m_config.m_drup_trim) - j = justification(0); // erase justification for level 0 + j = justification(0); // erase justification for level 0 } else { VERIFY(!at_base_lvl()); diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index db13db054..4e119a2ae 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -94,22 +94,9 @@ namespace sat { }; - enum class hint_type { - null_h, - farkas_h, - bound_h, - implied_eq_h, - }; - - struct proof_hint { - hint_type m_ty = hint_type::null_h; - vector> m_literals; - vector> m_eqs; - vector> m_diseqs; - void reset() { m_ty = hint_type::null_h; m_literals.reset(); m_eqs.reset(); m_diseqs.reset(); } - std::string to_string() const; - void from_string(char const* s); - void from_string(std::string const& s) { from_string(s.c_str()); } + class proof_hint { + public: + virtual ~proof_hint() {} }; class status { diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index a75c0022d..51bd8bdd7 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -21,6 +21,7 @@ z3_add_component(sat_smt euf_invariant.cpp euf_model.cpp euf_proof.cpp + euf_proof_checker.cpp euf_relevancy.cpp euf_solver.cpp fpa_solver.cpp diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 517efcfe3..4d1afb4cc 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -264,11 +264,12 @@ namespace arith { SASSERT(k1 != k2 || kind1 != kind2); auto bin_clause = [&](sat::literal l1, sat::literal l2) { - sat::proof_hint* bound_params = nullptr; + arith_proof_hint* bound_params = nullptr; if (ctx.use_drat()) { - bound_params = &m_farkas2; - m_farkas2.m_literals[0] = std::make_pair(rational(1), ~l1); - m_farkas2.m_literals[1] = std::make_pair(rational(1), ~l2); + m_arith_hint.set_type(ctx, hint_type::farkas_h); + m_arith_hint.add_lit(rational(1), ~l1); + m_arith_hint.add_lit(rational(1), ~l2); + bound_params = m_arith_hint.mk(ctx); } add_clause(l1, l2, bound_params); }; diff --git a/src/sat/smt/arith_diagnostics.cpp b/src/sat/smt/arith_diagnostics.cpp index 4f016746c..bfeeaff4d 100644 --- a/src/sat/smt/arith_diagnostics.cpp +++ b/src/sat/smt/arith_diagnostics.cpp @@ -15,6 +15,8 @@ Author: --*/ +#include "ast/ast_util.h" +#include "ast/scoped_proof.h" #include "sat/smt/euf_solver.h" #include "sat/smt/arith_solver.h" @@ -81,7 +83,6 @@ namespace arith { } void solver::explain_assumptions() { - m_arith_hint.reset(); unsigned i = 0; for (auto const & ev : m_explanation) { ++i; @@ -91,14 +92,12 @@ namespace arith { switch (m_constraint_sources[idx]) { case inequality_source: { literal lit = m_inequalities[idx]; - m_arith_hint.m_literals.push_back({ev.coeff(), lit}); + m_arith_hint.add_lit(ev.coeff(), lit); break; } case equality_source: { auto [u, v] = m_equalities[idx]; - ctx.drat_log_expr(u->get_expr()); - ctx.drat_log_expr(v->get_expr()); - m_arith_hint.m_eqs.push_back({u->get_expr_id(), v->get_expr_id()}); + m_arith_hint.add_eq(u, v); break; } default: @@ -115,22 +114,65 @@ namespace arith { * such that there is a r >= 1 * (r1*a1+..+r_k*a_k) = r*a, (r1*b1+..+r_k*b_k) <= r*b */ - sat::proof_hint const* solver::explain(sat::hint_type ty, sat::literal lit) { + arith_proof_hint const* solver::explain(hint_type ty, sat::literal lit) { if (!ctx.use_drat()) return nullptr; - m_arith_hint.m_ty = ty; + m_arith_hint.set_type(ctx, ty); explain_assumptions(); if (lit != sat::null_literal) - m_arith_hint.m_literals.push_back({rational(1), ~lit}); - return &m_arith_hint; + m_arith_hint.add_lit(rational(1), ~lit); + return m_arith_hint.mk(ctx); } - sat::proof_hint const* solver::explain_implied_eq(euf::enode* a, euf::enode* b) { + arith_proof_hint const* solver::explain_implied_eq(euf::enode* a, euf::enode* b) { if (!ctx.use_drat()) return nullptr; - m_arith_hint.m_ty = sat::hint_type::implied_eq_h; + m_arith_hint.set_type(ctx, hint_type::implied_eq_h); explain_assumptions(); - m_arith_hint.m_diseqs.push_back({a->get_expr_id(), b->get_expr_id()}); - return &m_arith_hint; + m_arith_hint.add_diseq(a, b); + return m_arith_hint.mk(ctx); + } + + expr* arith_proof_hint::get_hint(euf::solver& s) const { + ast_manager& m = s.get_manager(); + family_id fid = m.get_family_id("arith"); + arith_util arith(m); + solver& a = dynamic_cast(*s.fid2solver(fid)); + char const* name; + switch (m_ty) { + case hint_type::farkas_h: + name = "farkas"; + break; + case hint_type::bound_h: + name = "bound"; + break; + case hint_type::implied_eq_h: + name = "implied-eq"; + break; + } + rational lc(1); + for (unsigned i = m_lit_head; i < m_lit_tail; ++i) + lc = lcm(lc, denominator(a.m_arith_hint.lit(i).first)); + + expr_ref_vector args(m); + sort_ref_vector sorts(m); + for (unsigned i = m_lit_head; i < m_lit_tail; ++i) { + auto const& [coeff, lit] = a.m_arith_hint.lit(i); + args.push_back(arith.mk_int(coeff*lc)); + args.push_back(s.literal2expr(lit)); + } + for (unsigned i = m_eq_head; i < m_eq_tail; ++i) { + auto const& [a, b, is_eq] = a.m_arith_hint.eq(i); + expr_ref eq(m.mk_eq(a->get_expr(), b->get_expr()), m); + if (!is_eq) eq = m.mk_not(eq); + args.push_back(arith.mk_int(lc)); + args.push_back(eq); + } + for (expr* a : args) + sorts.push_back(a->get_sort()); + sort* range = m.mk_proof_sort(); + func_decl* d = m.mk_func_decl(symbol(name), args.size(), sorts.data(), range); + expr* r = m.mk_app(d, args); + return r; } } diff --git a/src/sat/smt/arith_proof_checker.h b/src/sat/smt/arith_proof_checker.h index c37a4f7c2..2333d1e3b 100644 --- a/src/sat/smt/arith_proof_checker.h +++ b/src/sat/smt/arith_proof_checker.h @@ -1,5 +1,5 @@ /*++ -Copyright (c) 2020 Microsoft Corporation +Copyright (c) 2022 Microsoft Corporation Module Name: @@ -11,7 +11,15 @@ Abstract: Author: - Nikolaj Bjorner (nbjorner) 2020-09-08 + Nikolaj Bjorner (nbjorner) 2022-08-28 + +Notes: + +The module assumes a limited repertoire of arithmetic proof rules. + +- farkas - inequalities, equalities and disequalities with coefficients +- implied-eq - last literal is a disequality. The literals before imply the corresponding equality. +- bound - last literal is a bound. It is implied by prior literals. --*/ #pragma once @@ -19,11 +27,12 @@ Author: #include "util/obj_pair_set.h" #include "ast/ast_trail.h" #include "ast/arith_decl_plugin.h" +#include "sat/smt/euf_proof_checker.h" namespace arith { - class proof_checker { + class proof_checker : public euf::proof_checker_plugin { struct row { obj_map m_coeffs; rational m_coeff; @@ -300,6 +309,8 @@ namespace arith { public: proof_checker(ast_manager& m): m(m), a(m) {} + + ~proof_checker() override {} void reset() { m_ineq.reset(); @@ -350,6 +361,47 @@ namespace arith { return out; } + bool check(expr_ref_vector const& clause, app* jst) override { + reset(); + + if (jst->get_name() == symbol("farkas")) { + bool even = true; + rational coeff; + expr* x, *y; + for (expr* arg : *jst) { + if (even) { + VERIFY(a.is_numeral(arg, coeff)); + } + else { + bool sign = m.is_not(arg, arg); + if (a.is_le(arg) || a.is_lt(arg) || a.is_ge(arg) || a.is_gt(arg)) + add_ineq(coeff, arg, sign); + else if (m.is_eq(arg, x, y)) { + if (sign) + add_diseq(x, y); + else + add_eq(x, y); + } + else + return false; + } + even = !even; + } + // display(verbose_stream()); + // todo: correlate with literals in clause, literals that are not in clause should have RUP property. + return check_farkas(); + } + + // todo: rules for bounds and implied-by + + return false; + } + + void register_plugins(euf::proof_checker& pc) { + pc.register_plugin(symbol("farkas"), this); + pc.register_plugin(symbol("bound"), this); + pc.register_plugin(symbol("implied-eq"), this); + } }; diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index fc3677cb9..27de68f3b 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -39,8 +39,6 @@ namespace arith { lp().settings().set_random_seed(get_config().m_random_seed); m_lia = alloc(lp::int_solver, *m_solver.get()); - m_farkas2.m_ty = sat::hint_type::farkas_h; - m_farkas2.m_literals.resize(2); } solver::~solver() { @@ -197,11 +195,12 @@ namespace arith { reset_evidence(); m_core.push_back(lit1); TRACE("arith", tout << lit2 << " <- " << m_core << "\n";); - sat::proof_hint* ph = nullptr; + arith_proof_hint* ph = nullptr; if (ctx.use_drat()) { - ph = &m_farkas2; - m_farkas2.m_literals[0] = std::make_pair(rational(1), lit1); - m_farkas2.m_literals[1] = std::make_pair(rational(1), ~lit2); + m_arith_hint.set_type(ctx, hint_type::farkas_h); + m_arith_hint.add_lit(rational(1), lit1); + m_arith_hint.add_lit(rational(1), ~lit2); + ph = m_arith_hint.mk(ctx); } assign(lit2, m_core, m_eqs, ph); ++m_stats.m_bounds_propagations; @@ -262,7 +261,7 @@ namespace arith { TRACE("arith", for (auto lit : m_core) tout << lit << ": " << s().value(lit) << "\n";); DEBUG_CODE(for (auto lit : m_core) { VERIFY(s().value(lit) == l_true); }); ++m_stats.m_bound_propagations1; - assign(lit, m_core, m_eqs, explain(sat::hint_type::bound_h, lit)); + assign(lit, m_core, m_eqs, explain(hint_type::bound_h, lit)); } if (should_refine_bounds() && first) @@ -378,7 +377,7 @@ namespace arith { reset_evidence(); m_explanation.clear(); lp().explain_implied_bound(be, m_bp); - assign(bound, m_core, m_eqs, explain(sat::hint_type::farkas_h, bound)); + assign(bound, m_core, m_eqs, explain(hint_type::farkas_h, bound)); } @@ -1178,7 +1177,7 @@ namespace arith { app_ref b = mk_bound(m_lia->get_term(), m_lia->get_offset(), !m_lia->is_upper()); IF_VERBOSE(4, verbose_stream() << "cut " << b << "\n"); literal lit = expr2literal(b); - assign(lit, m_core, m_eqs, explain(sat::hint_type::bound_h, lit)); + assign(lit, m_core, m_eqs, explain(hint_type::bound_h, lit)); lia_check = l_false; break; } @@ -1200,7 +1199,7 @@ namespace arith { return lia_check; } - void solver::assign(literal lit, literal_vector const& core, svector const& eqs, sat::proof_hint const* pma) { + void solver::assign(literal lit, literal_vector const& core, svector const& eqs, euf::th_proof_hint const* pma) { if (core.size() < small_lemma_size() && eqs.empty()) { m_core2.reset(); for (auto const& c : core) @@ -1247,7 +1246,7 @@ namespace arith { for (literal& c : m_core) c.neg(); - add_clause(m_core, explain(sat::hint_type::farkas_h)); + add_clause(m_core, explain(hint_type::farkas_h)); } bool solver::is_infeasible() const { diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 76bdeca9d..92d7cb120 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -48,8 +48,61 @@ namespace arith { typedef sat::literal_vector literal_vector; typedef lp_api::bound api_bound; + enum class hint_type { + farkas_h, + bound_h, + implied_eq_h + }; + + struct arith_proof_hint : public euf::th_proof_hint { + hint_type m_ty; + unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail; + arith_proof_hint(hint_type t, unsigned lh, unsigned lt, unsigned eh, unsigned et): + m_ty(t), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {} + expr* get_hint(euf::solver& s) const override; + }; + + class arith_proof_hint_builder { + vector> m_literals; + svector> m_eqs; + hint_type m_ty; + unsigned m_lit_head = 0, m_lit_tail = 0, m_eq_head = 0, m_eq_tail; + void reset() { m_lit_head = m_lit_tail; m_eq_head = m_eq_tail; } + void add(euf::enode* a, euf::enode* b, bool is_eq) { + if (m_eq_tail < m_eqs.size()) + m_eqs[m_eq_tail] = std::tuple(a, b, is_eq); + else + m_eqs.push_back(std::tuple(a, b, is_eq)); + m_eq_tail++; + } + public: + void set_type(euf::solver& ctx, hint_type ty) { + ctx.push(value_trail(m_eq_tail)); + ctx.push(value_trail(m_lit_tail)); + m_ty = ty; + reset(); + } + void add_eq(euf::enode* a, euf::enode* b) { add(a, b, true); } + void add_diseq(euf::enode* a, euf::enode* b) { add(a, b, false); } + void add_lit(rational const& coeff, literal lit) { + if (m_lit_tail < m_literals.size()) + m_literals[m_lit_tail] = {coeff, lit}; + else + m_literals.push_back({coeff, lit}); + m_lit_tail++; + } + std::pair const& lit(unsigned i) const { return m_literals[i]; } + std::tuple const& eq(unsigned i) const { return m_eqs[i]; } + arith_proof_hint* mk(euf::solver& s) { + return new (s.get_region()) arith_proof_hint(m_ty, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); + } + }; + + class solver : public euf::th_euf_solver { + friend struct arith_proof_hint; + struct scope { unsigned m_bounds_lim; unsigned m_idiv_lim; @@ -414,15 +467,15 @@ namespace arith { void set_conflict(); void set_conflict_or_lemma(literal_vector const& core, bool is_conflict); void set_evidence(lp::constraint_index idx); - void assign(literal lit, literal_vector const& core, svector const& eqs, sat::proof_hint const* pma); + void assign(literal lit, literal_vector const& core, svector const& eqs, euf::th_proof_hint const* pma); void false_case_of_check_nla(const nla::lemma& l); void dbg_finalize_model(model& mdl); - sat::proof_hint m_arith_hint; - sat::proof_hint m_farkas2; - sat::proof_hint const* explain(sat::hint_type ty, sat::literal lit = sat::null_literal); - sat::proof_hint const* explain_implied_eq(euf::enode* a, euf::enode* b); + arith_proof_hint_builder m_arith_hint; + + arith_proof_hint const* explain(hint_type ty, sat::literal lit = sat::null_literal); + arith_proof_hint const* explain_implied_eq(euf::enode* a, euf::enode* b); void explain_assumptions(); diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index 1e3cb0171..33ff829d6 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -391,8 +391,8 @@ namespace bv { if (c.m_kind != bv_justification::kind_t::bit2ne) { expr* e1 = var2expr(c.m_v1); expr* e2 = var2expr(c.m_v2); - eq = m.mk_eq(e1, e2); - ctx.drat_eq_def(leq, eq); + eq = m.mk_eq(e1, e2); + ctx.set_tmp_bool_var(leq.var(), eq); } sat::literal_vector lits; diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index 567f6fc1c..4ec90166a 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -21,130 +21,21 @@ Author: namespace euf { - void solver::init_drat() { - if (!m_drat_initialized) { + void solver::init_proof() { + if (!m_proof_initialized) { get_drat().add_theory(get_id(), symbol("euf")); get_drat().add_theory(m.get_basic_family_id(), symbol("bool")); } - m_drat_initialized = true; - } - - void solver::def_add_arg(unsigned arg) { - auto* out = get_drat().out(); - if (out) - (*out) << " " << arg; - } - - void solver::def_end() { - auto* out = get_drat().out(); - if (out) - (*out) << " 0\n"; - } - - void solver::def_begin(char id, unsigned n, std::string const& name) { - auto* out = get_drat().out(); - if (out) - (*out) << id << " " << n << " " << name; - } - - void solver::bool_def(bool_var v, unsigned n) { - auto* out = get_drat().out(); - if (out) - (*out) << "b " << v << " " << n << " 0\n"; - } - - - void solver::drat_log_params(func_decl* f) { - for (unsigned i = f->get_num_parameters(); i-- > 0; ) { - auto const& p = f->get_parameter(i); - if (!p.is_ast()) - continue; - ast* a = p.get_ast(); - if (is_func_decl(a)) - drat_log_decl(to_func_decl(a)); + if (!m_proof_out && s().get_config().m_drat && + (get_config().m_lemmas2console || s().get_config().m_smt_proof.is_non_empty_string())) { + TRACE("euf", tout << "init-proof\n"); + m_proof_out = alloc(std::ofstream, s().get_config().m_smt_proof.str(), std::ios_base::out); + if (get_config().m_lemmas2console) + get_drat().set_clause_eh(*this); + if (s().get_config().m_smt_proof.is_non_empty_string()) + get_drat().set_clause_eh(*this); } - } - - void solver::drat_log_expr1(expr* e) { - if (is_app(e)) { - app* a = to_app(e); - drat_log_params(a->get_decl()); - drat_log_decl(a->get_decl()); - std::stringstream strm; - strm << mk_ismt2_func(a->get_decl(), m); - def_begin('e', e->get_id(), strm.str()); - for (expr* arg : *a) - def_add_arg(arg->get_id()); - def_end(); - } - else if (is_var(e)) { - var* v = to_var(e); - def_begin('v', v->get_id(), "" + mk_pp(e->get_sort(), m)); - def_add_arg(v->get_idx()); - def_end(); - } - else if (is_quantifier(e)) { - quantifier* q = to_quantifier(e); - std::stringstream strm; - strm << "(" << (is_forall(q) ? "forall" : (is_exists(q) ? "exists" : "lambda")); - for (unsigned i = 0; i < q->get_num_decls(); ++i) - strm << " (" << q->get_decl_name(i) << " " << mk_pp(q->get_decl_sort(i), m) << ")"; - strm << ")"; - def_begin('q', q->get_id(), strm.str()); - def_add_arg(q->get_expr()->get_id()); - def_end(); - } - else - UNREACHABLE(); - m_drat_asts.insert(e); - push(insert_obj_trail(m_drat_asts, e)); - } - - void solver::drat_log_expr(expr* e) { - if (m_drat_asts.contains(e)) - return; - ptr_vector::scoped_stack _sc(m_drat_todo); - m_drat_todo.push_back(e); - while (!m_drat_todo.empty()) { - e = m_drat_todo.back(); - unsigned sz = m_drat_todo.size(); - if (is_app(e)) - for (expr* arg : *to_app(e)) - if (!m_drat_asts.contains(arg)) - m_drat_todo.push_back(arg); - if (is_quantifier(e)) { - expr* arg = to_quantifier(e)->get_expr(); - if (!m_drat_asts.contains(arg)) - m_drat_todo.push_back(arg); - } - if (m_drat_todo.size() != sz) - continue; - if (!m_drat_asts.contains(e)) - drat_log_expr1(e); - m_drat_todo.pop_back(); - } - } - - void solver::drat_bool_def(sat::bool_var v, expr* e) { - if (!use_drat()) - return; - drat_log_expr(e); - bool_def(v, e->get_id()); - } - - - void solver::drat_log_decl(func_decl* f) { - if (f->get_family_id() != null_family_id) - return; - if (m_drat_asts.contains(f)) - return; - m_drat_asts.insert(f); - push(insert_obj_trail< ast>(m_drat_asts, f)); - std::ostringstream strm; - smt2_pp_environment_dbg env(m); - ast_smt2_pp(strm, f, env); - def_begin('f', f->get_small_id(), strm.str()); - def_end(); + m_proof_initialized = true; } /** @@ -183,16 +74,19 @@ namespace euf { } } + void solver::set_tmp_bool_var(bool_var b, expr* e) { + m_bool_var2expr.setx(b, e, nullptr); + } + void solver::log_justification(literal l, th_explain const& jst) { literal_vector lits; - unsigned nv = s().num_vars(); expr_ref_vector eqs(m); + unsigned nv = s().num_vars(); auto add_lit = [&](enode_pair const& eq) { ++nv; - literal lit(nv, false); eqs.push_back(m.mk_eq(eq.first->get_expr(), eq.second->get_expr())); - drat_eq_def(lit, eqs.back()); - return lit; + set_tmp_bool_var(nv, eqs.back()); + return literal(nv, false); }; for (auto lit : euf::th_explain::lits(jst)) @@ -208,68 +102,133 @@ namespace euf { get_drat().add(lits, sat::status::th(m_is_redundant, jst.ext().get_id(), jst.get_pragma())); } - void solver::drat_eq_def(literal lit, expr* eq) { - expr *a = nullptr, *b = nullptr; - VERIFY(m.is_eq(eq, a, b)); - drat_log_expr(a); - drat_log_expr(b); - def_begin('e', eq->get_id(), std::string("=")); - def_add_arg(a->get_id()); - def_add_arg(b->get_id()); - def_end(); - bool_def(lit.var(), eq->get_id()); + void solver::on_clause(unsigned n, literal const* lits, sat::status st) { + TRACE("euf", tout << "on-clause " << n << "\n"); + on_lemma(n, lits, st); + on_proof(n, lits, st); } - void solver::on_clause(unsigned n, literal const* lits, sat::status st) { + void solver::on_proof(unsigned n, literal const* lits, sat::status st) { + if (!m_proof_out) + return; + flet _display_all_decls(m_display_all_decls, true); + std::ostream& out = *m_proof_out; + if (!visit_clause(out, n, lits)) + return; + if (st.is_asserted()) + display_redundant(out, n, lits, status2proof_hint(st)); + else if (st.is_deleted()) + display_deleted(out, n, lits); + else if (st.is_redundant()) + display_redundant(out, n, lits, status2proof_hint(st)); + else if (st.is_input()) + display_assume(out, n, lits); + else + UNREACHABLE(); + out.flush(); + } + + void solver::on_lemma(unsigned n, literal const* lits, sat::status st) { if (!get_config().m_lemmas2console) return; if (!st.is_redundant() && !st.is_asserted()) return; - - if (!visit_clause(n, lits)) + std::ostream& out = std::cout; + if (!visit_clause(out, n, lits)) return; - std::function ppth = [&](int th) { return m.get_family_name(th); }; if (!st.is_sat()) - std::cout << "; " << sat::status_pp(st, ppth) << "\n"; + out << "; " << sat::status_pp(st, ppth) << "\n"; - display_clause(n, lits); + display_assert(out, n, lits); } + void solver::on_instantiation(unsigned n, sat::literal const* lits, unsigned k, euf::enode* const* bindings) { + std::ostream& out = std::cout; + for (unsigned i = 0; i < k; ++i) + visit_expr(out, bindings[i]->get_expr()); + VERIFY(visit_clause(out, n, lits)); + out << "(instantiate"; + display_literals(out, n, lits); + for (unsigned i = 0; i < k; ++i) + display_expr(out << " :binding ", bindings[i]->get_expr()); + out << ")\n"; + } - bool solver::visit_clause(unsigned n, literal const* lits) { + bool solver::visit_clause(std::ostream& out, unsigned n, literal const* lits) { for (unsigned i = 0; i < n; ++i) { expr* e = bool_var2expr(lits[i].var()); if (!e) return false; - visit_expr(e); + visit_expr(out, e); } return true; } - void solver::display_clause(unsigned n, literal const* lits) { - std::cout << "(assert (or"; + void solver::display_assert(std::ostream& out, unsigned n, literal const* lits) { + display_literals(out << "(assert (or", n, lits) << "))\n"; + } + + void solver::display_assume(std::ostream& out, unsigned n, literal const* lits) { + display_literals(out << "(assume", n, lits) << ")\n"; + } + + void solver::display_redundant(std::ostream& out, unsigned n, literal const* lits, expr_ref& proof_hint) { + if (proof_hint) + visit_expr(out, proof_hint); + display_hint(display_literals(out << "(learn", n, lits), proof_hint) << ")\n"; + } + + void solver::display_deleted(std::ostream& out, unsigned n, literal const* lits) { + display_literals(out << "(del", n, lits) << ")\n"; + } + + std::ostream& solver::display_hint(std::ostream& out, expr* proof_hint) { + if (proof_hint) + return display_expr(out << " ", proof_hint); + else + return out; + } + + expr_ref solver::status2proof_hint(sat::status st) { + if (st.is_sat()) + return expr_ref(m.mk_const("rup", m.mk_proof_sort()), m); // provable by reverse unit propagation + auto* h = reinterpret_cast(st.get_hint()); + if (!h) + return expr_ref(m); + + expr* e = h->get_hint(*this); + if (e) + return expr_ref(e, m); + + return expr_ref(m); + } + + std::ostream& solver::display_literals(std::ostream& out, unsigned n, literal const* lits) { for (unsigned i = 0; i < n; ++i) { expr* e = bool_var2expr(lits[i].var()); if (lits[i].sign()) - m_clause_visitor.display_expr_def(std::cout << " (not ", e) << ")"; + display_expr(out << " (not ", e) << ")"; else - m_clause_visitor.display_expr_def(std::cout << " ", e); + display_expr(out << " ", e); } - std::cout << "))\n"; + return out; } - void solver::visit_expr(expr* e) { + void solver::visit_expr(std::ostream& out, expr* e) { m_clause_visitor.collect(e); - m_clause_visitor.display_skolem_decls(std::cout); - m_clause_visitor.define_expr(std::cout, e); + if (m_display_all_decls) + m_clause_visitor.display_decls(out); + else + m_clause_visitor.display_skolem_decls(out); + m_clause_visitor.define_expr(out, e); } - void solver::display_expr(expr* e) { - m_clause_visitor.display_expr_def(std::cout, e); + std::ostream& solver::display_expr(std::ostream& out, expr* e) { + return m_clause_visitor.display_expr_def(out, e); } } diff --git a/src/sat/smt/euf_proof_checker.cpp b/src/sat/smt/euf_proof_checker.cpp new file mode 100644 index 000000000..9943c5729 --- /dev/null +++ b/src/sat/smt/euf_proof_checker.cpp @@ -0,0 +1,48 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + euf_proof_checker.cpp + +Abstract: + + Plugin manager for checking EUF proofs + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-25 + +--*/ + +#include "ast/ast_pp.h" +#include "sat/smt/euf_proof_checker.h" +#include "sat/smt/arith_proof_checker.h" + +namespace euf { + + proof_checker::proof_checker(ast_manager& m): + m(m) { + arith::proof_checker* apc = alloc(arith::proof_checker, m); + m_plugins.push_back(apc); + apc->register_plugins(*this); + } + + proof_checker::~proof_checker() {} + + void proof_checker::register_plugin(symbol const& rule, proof_checker_plugin* p) { + m_map.insert(rule, p); + } + + bool proof_checker::check(expr_ref_vector const& clause, expr* e) { + if (!e || !is_app(e)) + return false; + app* a = to_app(e); + proof_checker_plugin* p = nullptr; + if (m_map.find(a->get_decl()->get_name(), p)) + return p->check(clause, a); + return false; + } + +} + diff --git a/src/sat/smt/euf_proof_checker.h b/src/sat/smt/euf_proof_checker.h new file mode 100644 index 000000000..be1e8fa70 --- /dev/null +++ b/src/sat/smt/euf_proof_checker.h @@ -0,0 +1,46 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + euf_proof_checker.h + +Abstract: + + Plugin manager for checking EUF proofs + +Author: + + Nikolaj Bjorner (nbjorner) 2022-08-25 + +--*/ +#pragma once + +#include "util/map.h" +#include "util/scoped_ptr_vector.h" +#include "ast/ast.h" + +namespace euf { + + class proof_checker; + + class proof_checker_plugin { + public: + virtual ~proof_checker_plugin() {} + virtual bool check(expr_ref_vector const& clause, app* jst) = 0; + virtual void register_plugins(proof_checker& pc) = 0; + }; + + class proof_checker { + ast_manager& m; + scoped_ptr_vector m_plugins; + map m_map; + public: + proof_checker(ast_manager& m); + ~proof_checker(); + void register_plugin(symbol const& rule, proof_checker_plugin*); + bool check(expr_ref_vector const& clause, expr* e); + }; + +} + diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index c094c0ac4..dcb919d65 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -170,9 +170,6 @@ namespace euf { TRACE("before_search", s().display(tout);); for (auto* s : m_solvers) s->init_search(); - - if (get_config().m_lemmas2console) - get_drat().set_print_clause(*this); } bool solver::is_external(bool_var v) { diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index ed1ad9b92..71569676a 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -60,11 +60,10 @@ namespace euf { std::ostream& display(std::ostream& out) const; }; - class solver : public sat::extension, public th_internalizer, public th_decompile, public sat::print_clause { + class solver : public sat::extension, public th_internalizer, public th_decompile, public sat::clause_eh { typedef top_sort deps_t; friend class ackerman; class user_sort; - // friend class sat::ba_solver; struct stats { unsigned m_ackerman; unsigned m_final_checks; @@ -175,26 +174,22 @@ namespace euf { void log_antecedents(std::ostream& out, literal l, literal_vector const& r); void log_antecedents(literal l, literal_vector const& r); void log_justification(literal l, th_explain const& jst); - void drat_log_decl(func_decl* f); - void drat_log_params(func_decl* f); - void drat_log_expr1(expr* n); - ptr_vector m_drat_todo; - obj_hashtable m_drat_asts; - bool m_drat_initialized{ false }; - void init_drat(); - ast_pp_util m_clause_visitor; - void on_clause(unsigned n, literal const* lits, sat::status st) override; - void def_add_arg(unsigned arg); - void def_end(); - void def_begin(char id, unsigned n, std::string const& name); + bool m_proof_initialized = false; + void init_proof(); + ast_pp_util m_clause_visitor; + bool m_display_all_decls = false; + void on_clause(unsigned n, literal const* lits, sat::status st) override; + void on_lemma(unsigned n, literal const* lits, sat::status st); + void on_proof(unsigned n, literal const* lits, sat::status st); + std::ostream& display_literals(std::ostream& out, unsigned n, sat::literal const* lits); + void display_assume(std::ostream& out, unsigned n, literal const* lits); + void display_redundant(std::ostream& out, unsigned n, literal const* lits, expr_ref& proof_hint); + void display_deleted(std::ostream& out, unsigned n, literal const* lits); + std::ostream& display_hint(std::ostream& out, expr* proof_hint); + expr_ref status2proof_hint(sat::status st); // relevancy - //bool_vector m_relevant_expr_ids; - //bool_vector m_relevant_visited; - //ptr_vector m_relevant_todo; - //void init_relevant_expr_ids(); - //void push_relevant(sat::bool_var v); bool is_propagated(sat::literal lit); // invariant void check_eqc_bool_assignment() const; @@ -347,16 +342,16 @@ namespace euf { // proof - bool use_drat() { return s().get_config().m_drat && (init_drat(), true); } + bool use_drat() { return s().get_config().m_drat && (init_proof(), true); } sat::drat& get_drat() { return s().get_drat(); } - void drat_bool_def(sat::bool_var v, expr* n); - void drat_eq_def(sat::literal lit, expr* eq); - void drat_log_expr(expr* n); - void bool_def(bool_var v, unsigned n); - bool visit_clause(unsigned n, literal const* lits); - void display_clause(unsigned n, literal const* lits); - void visit_expr(expr* e); - void display_expr(expr* e); + + void set_tmp_bool_var(sat::bool_var b, expr* e); + bool visit_clause(std::ostream& out, unsigned n, literal const* lits); + void display_assert(std::ostream& out, unsigned n, literal const* lits); + void visit_expr(std::ostream& out, expr* e); + std::ostream& display_expr(std::ostream& out, expr* e); + void on_instantiation(unsigned n, sat::literal const* lits, unsigned k, euf::enode* const* bindings); + scoped_ptr m_proof_out; // decompile bool extract_pb(std::function& card, diff --git a/src/sat/smt/q_solver.cpp b/src/sat/smt/q_solver.cpp index 3c6ff97d9..f40aa76c8 100644 --- a/src/sat/smt/q_solver.cpp +++ b/src/sat/smt/q_solver.cpp @@ -360,19 +360,7 @@ namespace q { void solver::log_instantiation(unsigned n, sat::literal const* lits, justification* j) { TRACE("q", for (unsigned i = 0; i < n; ++i) tout << literal2expr(lits[i]) << "\n";); if (get_config().m_instantiations2console) { - - ctx.visit_clause(n, lits); - if (j) { - for (unsigned i = 0; i < j->m_clause.num_decls(); ++i) - ctx.visit_expr(j->m_binding[i]->get_expr()); - std::cout << "; (instantiation"; - for (unsigned i = 0; i < j->m_clause.num_decls(); ++i) { - std::cout << " "; - ctx.display_expr(j->m_binding[i]->get_expr()); - } - std::cout << ")\n"; - } - ctx.display_clause(n, lits); + ctx.on_instantiation(n, lits, j ? j->m_clause.num_decls() : 0, j ? j->m_binding : nullptr); } } } diff --git a/src/sat/smt/sat_th.cpp b/src/sat/smt/sat_th.cpp index 7f4ff2f4d..3267f0940 100644 --- a/src/sat/smt/sat_th.cpp +++ b/src/sat/smt/sat_th.cpp @@ -125,7 +125,7 @@ namespace euf { pop_core(n); } - sat::status th_euf_solver::mk_status(sat::proof_hint const* ps) { + sat::status th_euf_solver::mk_status(th_proof_hint const* ps) { return sat::status::th(m_is_redundant, get_id(), ps); } @@ -149,7 +149,7 @@ namespace euf { return add_clause(2, lits); } - bool th_euf_solver::add_clause(sat::literal a, sat::literal b, sat::proof_hint const* ps) { + bool th_euf_solver::add_clause(sat::literal a, sat::literal b, th_proof_hint const* ps) { sat::literal lits[2] = { a, b }; return add_clause(2, lits, ps); } @@ -164,7 +164,7 @@ namespace euf { return add_clause(4, lits); } - bool th_euf_solver::add_clause(unsigned n, sat::literal* lits, sat::proof_hint const* ps) { + bool th_euf_solver::add_clause(unsigned n, sat::literal* lits, th_proof_hint const* ps) { bool was_true = false; for (unsigned i = 0; i < n; ++i) was_true |= is_true(lits[i]); @@ -226,13 +226,14 @@ namespace euf { return ctx.s().rand()(); } - size_t th_explain::get_obj_size(unsigned num_lits, unsigned num_eqs, sat::proof_hint const* pma) { - return sat::constraint_base::obj_size(sizeof(th_explain) + sizeof(sat::literal) * num_lits + sizeof(enode_pair) * num_eqs + (pma?pma->to_string().length()+1:1)); + size_t th_explain::get_obj_size(unsigned num_lits, unsigned num_eqs) { + return sat::constraint_base::obj_size(sizeof(th_explain) + sizeof(sat::literal) * num_lits + sizeof(enode_pair) * num_eqs); } - th_explain::th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& p, sat::proof_hint const* pma) { + th_explain::th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& p, th_proof_hint const* pma) { m_consequent = c; m_eq = p; + m_proof_hint = pma; m_num_literals = n_lits; m_num_eqs = n_eqs; char * base_ptr = reinterpret_cast(this) + sizeof(th_explain); @@ -244,33 +245,24 @@ namespace euf { m_eqs = reinterpret_cast(base_ptr); for (i = 0; i < n_eqs; ++i) m_eqs[i] = eqs[i]; - base_ptr += sizeof(enode_pair) * n_eqs; - m_pragma = reinterpret_cast(base_ptr); - i = 0; - if (pma) { - std::string s = pma->to_string(); - for (i = 0; s[i]; ++i) - m_pragma[i] = s[i]; - } - m_pragma[i] = 0; } - th_explain* th_explain::mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y, sat::proof_hint const* pma) { + th_explain* th_explain::mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y, th_proof_hint const* pma) { region& r = th.ctx.get_region(); - void* mem = r.allocate(get_obj_size(n_lits, n_eqs, pma)); + void* mem = r.allocate(get_obj_size(n_lits, n_eqs)); sat::constraint_base::initialize(mem, &th); - return new (sat::constraint_base::ptr2mem(mem)) th_explain(n_lits, lits, n_eqs, eqs, c, enode_pair(x, y)); + return new (sat::constraint_base::ptr2mem(mem)) th_explain(n_lits, lits, n_eqs, eqs, c, enode_pair(x, y), pma); } - th_explain* th_explain::propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, sat::literal consequent, sat::proof_hint const* pma) { + th_explain* th_explain::propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, sat::literal consequent, th_proof_hint const* pma) { return mk(th, lits.size(), lits.data(), eqs.size(), eqs.data(), consequent, nullptr, nullptr, pma); } - th_explain* th_explain::propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, sat::proof_hint const* pma) { + th_explain* th_explain::propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, th_proof_hint const* pma) { return mk(th, lits.size(), lits.data(), eqs.size(), eqs.data(), sat::null_literal, x, y, pma); } - th_explain* th_explain::propagate(th_euf_solver& th, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, sat::proof_hint const* pma) { + th_explain* th_explain::propagate(th_euf_solver& th, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, th_proof_hint const* pma) { return mk(th, 0, nullptr, eqs.size(), eqs.data(), sat::null_literal, x, y, pma); } @@ -313,8 +305,8 @@ namespace euf { out << "--> " << m_consequent; if (m_eq.first != nullptr) out << "--> " << m_eq.first->get_expr_id() << " == " << m_eq.second->get_expr_id(); - if (m_pragma != nullptr) - out << " p " << m_pragma; + if (m_proof_hint != nullptr) + out << " p "; return out; } diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index 833b6c05e..8289418fc 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -58,6 +58,7 @@ namespace euf { }; + class th_decompile { public: virtual ~th_decompile() = default; @@ -138,6 +139,11 @@ namespace euf { }; + class th_proof_hint : public sat::proof_hint { + public: + virtual expr* get_hint(euf::solver& s) const = 0; + }; + class th_euf_solver : public th_solver { protected: solver& ctx; @@ -150,16 +156,16 @@ namespace euf { region& get_region(); - sat::status mk_status(sat::proof_hint const* ps = nullptr); + sat::status mk_status(th_proof_hint const* ps = nullptr); bool add_unit(sat::literal lit); bool add_units(sat::literal_vector const& lits); bool add_clause(sat::literal lit) { return add_unit(lit); } bool add_clause(sat::literal a, sat::literal b); - bool add_clause(sat::literal a, sat::literal b, sat::proof_hint const* ps); + bool add_clause(sat::literal a, sat::literal b, th_proof_hint const* ps); bool add_clause(sat::literal a, sat::literal b, sat::literal c); bool add_clause(sat::literal a, sat::literal b, sat::literal c, sat::literal d); - bool add_clause(sat::literal_vector const& lits, sat::proof_hint const* ps = nullptr) { return add_clause(lits.size(), lits.data(), ps); } - bool add_clause(unsigned n, sat::literal* lits, sat::proof_hint const* ps = nullptr); + bool add_clause(sat::literal_vector const& lits, th_proof_hint const* ps = nullptr) { return add_clause(lits.size(), lits.data(), ps); } + bool add_clause(unsigned n, sat::literal* lits, th_proof_hint const* ps = nullptr); void add_equiv(sat::literal a, sat::literal b); void add_equiv_and(sat::literal a, sat::literal_vector const& bs); @@ -220,16 +226,16 @@ namespace euf { * that retrieve literals on demand. */ class th_explain { - sat::literal m_consequent = sat::null_literal; // literal consequent for propagations - enode_pair m_eq = enode_pair(); // equality consequent for propagations + sat::literal m_consequent = sat::null_literal; // literal consequent for propagations + enode_pair m_eq = enode_pair(); // equality consequent for propagations + th_proof_hint const* m_proof_hint; unsigned m_num_literals; unsigned m_num_eqs; sat::literal* m_literals; enode_pair* m_eqs; - char* m_pragma = nullptr; - static size_t get_obj_size(unsigned num_lits, unsigned num_eqs, sat::proof_hint const* pma); - th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& eq, sat::proof_hint const* pma = nullptr); - static th_explain* mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y, sat::proof_hint const* pma = nullptr); + static size_t get_obj_size(unsigned num_lits, unsigned num_eqs); + th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& eq, th_proof_hint const* pma = nullptr); + static th_explain* mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y, th_proof_hint const* pma = nullptr); public: static th_explain* conflict(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs); @@ -240,9 +246,9 @@ namespace euf { static th_explain* conflict(th_euf_solver& th, sat::literal lit, euf::enode* x, euf::enode* y); static th_explain* conflict(th_euf_solver& th, euf::enode* x, euf::enode* y); static th_explain* propagate(th_euf_solver& th, sat::literal lit, euf::enode* x, euf::enode* y); - static th_explain* propagate(th_euf_solver& th, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, sat::proof_hint const* pma = nullptr); - static th_explain* propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, sat::literal consequent, sat::proof_hint const* pma = nullptr); - static th_explain* propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, sat::proof_hint const* pma = nullptr); + static th_explain* propagate(th_euf_solver& th, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, th_proof_hint const* pma = nullptr); + static th_explain* propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, sat::literal consequent, th_proof_hint const* pma = nullptr); + static th_explain* propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, th_proof_hint const* pma = nullptr); sat::ext_constraint_idx to_index() const { return sat::constraint_base::mem2base(this); @@ -277,7 +283,7 @@ namespace euf { enode_pair eq_consequent() const { return m_eq; } - sat::proof_hint const* get_pragma() const { return nullptr; } //*m_pragma ? m_pragma : nullptr; } + th_proof_hint const* get_pragma() const { return m_proof_hint; } }; diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 1e44c6c21..2d390b4f2 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -75,7 +75,6 @@ struct goal2sat::imp : public sat::sat_internalizer { func_decl_ref_vector m_unhandled_funs; bool m_default_external; bool m_euf { false }; - bool m_drat { false }; bool m_is_redundant { false }; bool m_top_level { false }; sat::literal_vector aig_lits; @@ -102,7 +101,6 @@ struct goal2sat::imp : public sat::sat_internalizer { m_ite_extra = p.get_bool("ite_extra", true); m_max_memory = megabytes_to_bytes(p.get_uint("max_memory", UINT_MAX)); m_euf = sp.euf(); - m_drat = sp.drat_file().is_non_empty_string(); } void throw_op_not_handled(std::string const& s) { @@ -169,15 +167,9 @@ struct goal2sat::imp : public sat::sat_internalizer { if (m_expr2var_replay && m_expr2var_replay->find(n, v)) return v; v = m_solver.add_var(is_ext); - log_def(v, n); return v; } - void log_def(sat::bool_var v, expr* n) { - if (m_drat && m_euf) - ensure_euf()->drat_bool_def(v, n); - } - sat::bool_var to_bool_var(expr* e) override { sat::literal l; sat::bool_var v = m_map.to_bool_var(e); diff --git a/src/shell/drat_frontend.cpp b/src/shell/drat_frontend.cpp index 19bf4c3e8..091d8cd4d 100644 --- a/src/shell/drat_frontend.cpp +++ b/src/shell/drat_frontend.cpp @@ -17,90 +17,22 @@ Copyright (c) 2020 Microsoft Corporation #include "cmd_context/cmd_context.h" #include "ast/proofs/proof_checker.h" #include "ast/rewriter/th_rewriter.h" +#include "ast/reg_decl_plugins.h" #include "sat/smt/arith_proof_checker.h" -class smt_checker { - ast_manager& m; +class drup_checker { sat::drat& m_drat; - expr_ref_vector const& m_b2e; - expr_ref_vector m_fresh_exprs; - expr_ref_vector m_core; - expr_ref_vector m_inputs; - params_ref m_params; - scoped_ptr m_lemma_solver, m_input_solver; sat::literal_vector m_units; - bool m_check_inputs { false }; - - expr* fresh(expr* e) { - unsigned i = e->get_id(); - m_fresh_exprs.reserve(i + 1); - expr* r = m_fresh_exprs.get(i); - if (!r) { - r = m.mk_fresh_const("sk", e->get_sort()); - m_fresh_exprs[i] = r; - } - return r; - } - - expr_ref define(expr* e, unsigned depth) { - expr_ref r(fresh(e), m); - m_core.push_back(m.mk_eq(r, e)); - if (depth == 0) - return r; - r = e; - if (is_app(e)) { - expr_ref_vector args(m); - for (expr* arg : *to_app(e)) - args.push_back(define(arg, depth - 1)); - r = m.mk_app(to_app(e)->get_decl(), args.size(), args.data()); - } - return r; - } - - void unfold1(sat::literal_vector const& lits) { - m_core.reset(); - for (sat::literal lit : lits) { - expr* e = m_b2e[lit.var()]; - expr_ref fml = define(e, 2); - if (!lit.sign()) - fml = m.mk_not(fml); - m_core.push_back(fml); - } - } - - expr_ref lit2expr(sat::literal lit) { - return expr_ref(lit.sign() ? m.mk_not(m_b2e[lit.var()]) : m_b2e[lit.var()], m); - } + bool m_check_inputs = false; void add_units() { auto const& units = m_drat.units(); -#if 0 - for (unsigned i = m_units.size(); i < units.size(); ++i) { - sat::literal lit = units[i].first; - m_lemma_solver->assert_expr(lit2expr(lit)); - } -#endif for (unsigned i = m_units.size(); i < units.size(); ++i) m_units.push_back(units[i].first); } void check_assertion_redundant(sat::literal_vector const& input) { - expr_ref_vector args(m); - for (auto lit : input) - args.push_back(lit2expr(lit)); - m_inputs.push_back(args.size() == 1 ? args.back() : m.mk_or(args)); - - m_input_solver->push(); - for (auto lit : input) { - m_input_solver->assert_expr(lit2expr(~lit)); - } - lbool is_sat = m_input_solver->check_sat(); - if (is_sat != l_false) { - std::cout << "Failed to verify input\n"; - exit(0); - } - m_input_solver->pop(1); } @@ -112,59 +44,71 @@ class smt_checker { */ sat::literal_vector drup_units; - void check_clause(sat::literal_vector const& lits) { - + void check_clause(sat::literal_vector const& lits) { + } + + void check_drup(sat::literal_vector const& lits) { add_units(); drup_units.reset(); if (m_drat.is_drup(lits.size(), lits.data(), drup_units)) { std::cout << "drup\n"; return; } - m_input_solver->push(); -// for (auto lit : drup_units) -// m_input_solver->assert_expr(lit2expr(lit)); - for (auto lit : lits) - m_input_solver->assert_expr(lit2expr(~lit)); - lbool is_sat = m_input_solver->check_sat(); - if (is_sat != l_false) { - std::cout << "did not verify: " << is_sat << " " << lits << "\n"; - for (sat::literal lit : lits) - std::cout << lit2expr(lit) << "\n"; - std::cout << "\n"; - m_input_solver->display(std::cout); - if (is_sat == l_true) { - model_ref mdl; - m_input_solver->get_model(mdl); - std::cout << *mdl << "\n"; - } - - exit(0); - } - m_input_solver->pop(1); - std::cout << "smt\n"; - // check_assertion_redundant(lits); + std::cout << "did not verify " << lits << "\n"; + exit(0); } public: - smt_checker(sat::drat& drat, expr_ref_vector const& b2e): - m(b2e.m()), m_drat(drat), m_b2e(b2e), m_fresh_exprs(m), m_core(m), m_inputs(m) { - m_lemma_solver = mk_smt_solver(m, m_params, symbol()); - m_input_solver = mk_smt_solver(m, m_params, symbol()); - } + drup_checker(sat::drat& drat): m_drat(drat) {} - void add(sat::literal_vector const& lits, sat::status const& st, bool validated) { + void add(sat::literal_vector const& lits, sat::status const& st) { for (sat::literal lit : lits) while (lit.var() >= m_drat.get_solver().num_vars()) m_drat.get_solver().mk_var(true); - if (st.is_input() && m_check_inputs) - check_assertion_redundant(lits); - else if (!st.is_sat() && !st.is_deleted() && !validated) - check_clause(lits); - // m_drat.add(lits, st); + if (st.is_sat()) + check_drup(lits); + m_drat.add(lits, st); } +}; + +unsigned read_drat(char const* drat_file) { + ast_manager m; + reg_decl_plugins(m); + std::ifstream ins(drat_file); + dimacs::drat_parser drat(ins, std::cerr); + + std::function read_theory = [&](char const* r) { + return m.mk_family_id(symbol(r)); + }; + std::function write_theory = [&](int th) { + return m.get_family_name(th); + }; + drat.set_read_theory(read_theory); + params_ref p; + reslimit lim; + sat::solver solver(p, lim); + sat::drat drat_checker(solver); + drup_checker checker(drat_checker); + + for (auto const& r : drat) { + std::cout << dimacs::drat_pp(r, write_theory); + std::cout.flush(); + checker.add(r.m_lits, r.m_status); + if (drat_checker.inconsistent()) { + std::cout << "inconsistent\n"; + return 0; + } + statistics st; + drat_checker.collect_statistics(st); + std::cout << st << "\n"; + } + return 0; +} + + +#if 0 bool validate_hint(expr_ref_vector const& exprs, sat::literal_vector const& lits, sat::proof_hint const& hint) { - // return; // remove when testing this arith_util autil(m); arith::proof_checker achecker(m); proof_checker pc(m); @@ -247,7 +191,6 @@ public: std::cout << "p hint not verified\n"; return false; } - std::cout << "p hint verified\n"; return true; @@ -260,291 +203,4 @@ public: return false; } - /** - * Add an assertion from the source file - */ - void add_assertion(expr* a) { - m_input_solver->assert_expr(a); - } - - void display_input() { - scoped_ptr s = mk_smt_solver(m, m_params, symbol()); - for (auto* e : m_inputs) - s->assert_expr(e); - s->display(std::cout); - } - - symbol name; - unsigned_vector params; - ptr_vector sorts; - - void parse_quantifier(sexpr_ref const& sexpr, cmd_context& ctx, quantifier_kind& k, sort_ref_vector& domain, svector& names) { - k = quantifier_kind::forall_k; - symbol q; - unsigned sz; - if (sexpr->get_kind() != sexpr::kind_t::COMPOSITE) - goto bail; - sz = sexpr->get_num_children(); - if (sz == 0) - goto bail; - q = sexpr->get_child(0)->get_symbol(); - if (q == "forall") - k = quantifier_kind::forall_k; - else if (q == "exists") - k = quantifier_kind::exists_k; - else if (q == "lambda") - k = quantifier_kind::lambda_k; - else - goto bail; - for (unsigned i = 1; i < sz; ++i) { - auto* e = sexpr->get_child(i); - if (e->get_kind() != sexpr::kind_t::COMPOSITE) - goto bail; - if (2 != e->get_num_children()) - goto bail; - symbol name = e->get_child(0)->get_symbol(); - std::ostringstream ostrm; - e->get_child(1)->display(ostrm); - std::istringstream istrm(ostrm.str()); - params_ref p; - auto srt = parse_smt2_sort(ctx, istrm, false, p, "quantifier"); - if (!srt) - goto bail; - names.push_back(name); - domain.push_back(srt); - } - return; - bail: - std::cout << "Could not parse expression\n"; - sexpr->display(std::cout); - std::cout << "\n"; - exit(0); - } - - void parse_sexpr(sexpr_ref const& sexpr, cmd_context& ctx, expr_ref_vector const& args, expr_ref& result) { - params.reset(); - sorts.reset(); - for (expr* arg : args) - sorts.push_back(arg->get_sort()); - sort_ref rng(m); - func_decl* f = nullptr; - switch (sexpr->get_kind()) { - case sexpr::kind_t::COMPOSITE: { - unsigned sz = sexpr->get_num_children(); - if (sz == 0) - goto bail; - if (sexpr->get_child(0)->get_symbol() == symbol("_")) { - name = sexpr->get_child(1)->get_symbol(); - if (name == "bv" && sz == 4) { - bv_util bvu(m); - auto val = sexpr->get_child(2)->get_numeral(); - auto n = sexpr->get_child(3)->get_numeral().get_unsigned(); - result = bvu.mk_numeral(val, n); - return; - } - if (name == "is" && sz == 3) { - name = sexpr->get_child(2)->get_child(0)->get_symbol(); - f = ctx.find_func_decl(name, params.size(), params.data(), args.size(), sorts.data(), rng.get()); - if (!f) - goto bail; - datatype_util dtu(m); - result = dtu.mk_is(f, args[0]); - return; - } - if (name == "Real" && sz == 4) { - arith_util au(m); - rational r = sexpr->get_child(2)->get_numeral(); - // rational den = sexpr->get_child(3)->get_numeral(); - result = au.mk_numeral(r, false); - return; - } - if (name == "Int" && sz == 4) { - arith_util au(m); - rational num = sexpr->get_child(2)->get_numeral(); - result = au.mk_numeral(num, true); - return; - } - if (name == "as-array" && sz == 3) { - array_util au(m); - auto const* ch2 = sexpr->get_child(2); - switch (ch2->get_kind()) { - case sexpr::kind_t::COMPOSITE: - break; - default: - name = sexpr->get_child(2)->get_symbol(); - f = ctx.find_func_decl(name); - if (f) { - result = au.mk_as_array(f); - return; - } - } - } - for (unsigned i = 2; i < sz; ++i) { - auto* child = sexpr->get_child(i); - if (child->is_numeral() && child->get_numeral().is_unsigned()) - params.push_back(child->get_numeral().get_unsigned()); - else - goto bail; - } - break; - } - goto bail; - } - case sexpr::kind_t::SYMBOL: - name = sexpr->get_symbol(); - break; - case sexpr::kind_t::BV_NUMERAL: { - goto bail; - } - case sexpr::kind_t::STRING: - case sexpr::kind_t::KEYWORD: - case sexpr::kind_t::NUMERAL: - default: - goto bail; - } - f = ctx.find_func_decl(name, params.size(), params.data(), args.size(), sorts.data(), rng.get()); - if (!f) - goto bail; - result = ctx.m().mk_app(f, args); - return; - bail: - std::cout << "Could not parse expression\n"; - sexpr->display(std::cout); - std::cout << "\n"; - exit(0); - } -}; - -static void verify_smt(char const* drat_file, char const* smt_file) { - cmd_context ctx; - ctx.set_ignore_check(true); - ctx.set_regular_stream(std::cerr); - ctx.set_solver_factory(mk_smt_strategic_solver_factory()); - if (smt_file) { - std::ifstream smt_in(smt_file); - if (!parse_smt2_commands(ctx, smt_in)) { - std::cerr << "could not read file " << smt_file << "\n"; - return; - } - } - - std::ifstream ins(drat_file); - dimacs::drat_parser drat(ins, std::cerr); - ast_manager& m = ctx.m(); - std::function read_theory = [&](char const* r) { - return m.mk_family_id(symbol(r)); - }; - std::function write_theory = [&](int th) { - return m.get_family_name(th); - }; - drat.set_read_theory(read_theory); - params_ref p; - reslimit lim; - p.set_bool("drat.check_unsat", true); - sat::solver solver(p, lim); - sat::drat drat_checker(solver); - drat_checker.updt_config(); - - expr_ref_vector bool_var2expr(m); - expr_ref_vector exprs(m), args(m), inputs(m); - sort_ref_vector sargs(m), sorts(m); - func_decl_ref_vector decls(m); - - smt_checker checker(drat_checker, bool_var2expr); - - for (expr* a : ctx.assertions()) - checker.add_assertion(a); - - for (auto const& r : drat) { - std::cout << dimacs::drat_pp(r, write_theory); - std::cout.flush(); - switch (r.m_tag) { - case dimacs::drat_record::tag_t::is_clause: { - bool validated = checker.validate_hint(exprs, r.m_lits, r.m_hint); - checker.add(r.m_lits, r.m_status, validated); - if (drat_checker.inconsistent()) { - std::cout << "inconsistent\n"; - return; - } - break; - } - case dimacs::drat_record::tag_t::is_node: { - expr_ref e(m); - args.reset(); - for (auto n : r.m_args) - args.push_back(exprs.get(n)); - std::istringstream strm(r.m_name); - auto sexpr = parse_sexpr(ctx, strm, p, drat_file); - checker.parse_sexpr(sexpr, ctx, args, e); - exprs.reserve(r.m_node_id + 1); - exprs.set(r.m_node_id, e); - break; - } - case dimacs::drat_record::tag_t::is_var: { - var_ref e(m); - SASSERT(r.m_args.size() == 1); - std::istringstream strm(r.m_name); - auto srt = parse_smt2_sort(ctx, strm, false, p, drat_file); - e = m.mk_var(r.m_args[0], srt); - exprs.reserve(r.m_node_id + 1); - exprs.set(r.m_node_id, e); - break; - } - case dimacs::drat_record::tag_t::is_decl: { - std::istringstream strm(r.m_name); - ctx.set_allow_duplicate_declarations(); - parse_smt2_commands(ctx, strm); - break; - } - case dimacs::drat_record::tag_t::is_sort: { - sort_ref srt(m); - symbol name = symbol(r.m_name.c_str()); - sargs.reset(); - for (auto n : r.m_args) - sargs.push_back(sorts.get(n)); - psort_decl* pd = ctx.find_psort_decl(name); - if (pd) - srt = pd->instantiate(ctx.pm(), sargs.size(), sargs.data()); - else - srt = m.mk_uninterpreted_sort(name); - sorts.reserve(r.m_node_id + 1); - sorts.set(r.m_node_id, srt); - break; - } - case dimacs::drat_record::tag_t::is_quantifier: { - VERIFY(r.m_args.size() == 1); - quantifier_ref q(m); - std::istringstream strm(r.m_name); - auto sexpr = parse_sexpr(ctx, strm, p, drat_file); - sort_ref_vector domain(m); - svector names; - quantifier_kind k; - checker.parse_quantifier(sexpr, ctx, k, domain, names); - q = m.mk_quantifier(k, domain.size(), domain.data(), names.data(), exprs.get(r.m_args[0])); - exprs.reserve(r.m_node_id + 1); - exprs.set(r.m_node_id, q); - break; - } - case dimacs::drat_record::tag_t::is_bool_def: - bool_var2expr.reserve(r.m_node_id + 1); - bool_var2expr.set(r.m_node_id, exprs.get(r.m_args[0])); - break; - default: - UNREACHABLE(); - break; - } - } - statistics st; - drat_checker.collect_statistics(st); - std::cout << st << "\n"; -} - - -unsigned read_drat(char const* drat_file, char const* problem_file) { - if (!problem_file) { - std::cerr << "No smt2 file provided to checker\n"; - return -1; - } - verify_smt(drat_file, problem_file); - return 0; -} +#endif diff --git a/src/shell/drat_frontend.h b/src/shell/drat_frontend.h index ef37fcfaa..89760360e 100644 --- a/src/shell/drat_frontend.h +++ b/src/shell/drat_frontend.h @@ -3,6 +3,6 @@ Copyright (c) 2011 Microsoft Corporation --*/ #pragma once -unsigned read_drat(char const * drat_file, char const* problem_file); +unsigned read_drat(char const * drat_file); diff --git a/src/shell/main.cpp b/src/shell/main.cpp index b1da07db8..af3b22db0 100644 --- a/src/shell/main.cpp +++ b/src/shell/main.cpp @@ -402,7 +402,7 @@ int STD_CALL main(int argc, char ** argv) { return_value = read_mps_file(g_input_file); break; case IN_DRAT: - return_value = read_drat(g_drat_input_file, g_input_file); + return_value = read_drat(g_drat_input_file); break; default: UNREACHABLE(); diff --git a/src/shell/smtlib_frontend.cpp b/src/shell/smtlib_frontend.cpp index c86d4b5b9..7c81d2211 100644 --- a/src/shell/smtlib_frontend.cpp +++ b/src/shell/smtlib_frontend.cpp @@ -26,6 +26,7 @@ Revision History: #include "parsers/smt2/smt2parser.h" #include "muz/fp/dl_cmds.h" #include "cmd_context/extra_cmds/dbg_cmds.h" +#include "cmd_context/proof_cmds.h" #include "opt/opt_cmds.h" #include "cmd_context/extra_cmds/polynomial_cmds.h" #include "cmd_context/extra_cmds/subpaving_cmds.h" @@ -128,6 +129,7 @@ unsigned read_smtlib2_commands(char const * file_name) { install_subpaving_cmds(ctx); install_opt_cmds(ctx); install_smt2_extra_cmds(ctx); + install_proof_cmds(ctx); g_cmd_context = &ctx; signal(SIGINT, on_ctrl_c);