diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 30dd1d720..8f04b1b00 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -571,6 +571,7 @@ namespace euf { m_updates.push_back(update_record(false, update_record::inconsistent())); m_n1 = n1; m_n2 = n2; + TRACE("euf", tout << "conflict " << bpp(n1) << " " << bpp(n2) << " " << j << "\n"); m_justification = j; } @@ -723,7 +724,7 @@ namespace euf { else if (j.is_congruence()) push_congruence(a, b, j.is_commutative()); if (cc && j.is_congruence()) - cc->push_back(std::tuple(a, b, j.timestamp(), j.is_commutative())); + cc->push_back(std::tuple(a->get_app(), b->get_app(), j.timestamp(), j.is_commutative())); } diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index c0d7f03d8..53aaf481a 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -77,7 +77,7 @@ namespace euf { // It is the only information collected from justifications in order to // reconstruct EUF proofs. Transitivity, Symmetry of equality are not // tracked. - typedef std::tuple cc_justification_record; + typedef std::tuple cc_justification_record; typedef svector cc_justification; class egraph { diff --git a/src/cmd_context/extra_cmds/proof_cmds.cpp b/src/cmd_context/extra_cmds/proof_cmds.cpp index b90daf29a..c3269253f 100644 --- a/src/cmd_context/extra_cmds/proof_cmds.cpp +++ b/src/cmd_context/extra_cmds/proof_cmds.cpp @@ -52,153 +52,6 @@ Proof checker for clauses created during search. #include "params/solver_params.hpp" #include -class smt_checker { - ast_manager& m; - params_ref m_params; - - // for checking proof rules (hints) - euf::proof_checker m_checker; - - // for fallback SMT checker - scoped_ptr m_solver; - - // for RUP - symbol m_rup; - sat::solver m_sat_solver; - sat::drat m_drat; - sat::literal_vector m_units; - sat::literal_vector m_clause; - - void add_units() { - auto const& units = m_drat.units(); - for (unsigned i = m_units.size(); i < units.size(); ++i) - m_units.push_back(units[i].first); - } - -public: - smt_checker(ast_manager& m): - m(m), - m_checker(m), - m_sat_solver(m_params, m.limit()), - m_drat(m_sat_solver) - { - m_params.set_bool("drat.check_unsat", true); - m_sat_solver.updt_params(m_params); - m_drat.updt_config(); - m_solver = mk_smt_solver(m, m_params, symbol()); - m_rup = symbol("rup"); - } - - bool is_rup(app* proof_hint) { - return - proof_hint && - proof_hint->get_name() == m_rup; - } - - void mk_clause(expr_ref_vector const& clause) { - m_clause.reset(); - for (expr* e : clause) { - bool sign = false; - while (m.is_not(e, e)) - sign = !sign; - m_clause.push_back(sat::literal(e->get_id(), sign)); - } - } - - void mk_clause(expr* e) { - m_clause.reset(); - bool sign = false; - while (m.is_not(e, e)) - sign = !sign; - m_clause.push_back(sat::literal(e->get_id(), sign)); - } - - bool check_rup(expr_ref_vector const& clause) { - add_units(); - mk_clause(clause); - return m_drat.is_drup(m_clause.size(), m_clause.data(), m_units); - } - - bool check_rup(expr* u) { - add_units(); - mk_clause(u); - return m_drat.is_drup(m_clause.size(), m_clause.data(), m_units); - } - - void add_clause(expr_ref_vector const& clause) { - mk_clause(clause); - m_drat.add(m_clause, sat::status::input()); - } - - void check(expr_ref_vector& clause, app* proof_hint) { - - if (is_rup(proof_hint) && check_rup(clause)) { - std::cout << "(verified-rup)\n"; - return; - } - - expr_ref_vector units(m); - if (m_checker.check(clause, proof_hint, units)) { - bool units_are_rup = true; - for (expr* u : units) { - if (!check_rup(u)) { - std::cout << "unit " << mk_pp(u, m) << " is not rup\n"; - units_are_rup = false; - } - } - if (units_are_rup) { - std::cout << "(verified-" << proof_hint->get_name() << ")\n"; - add_clause(clause); - return; - } - } - - // extract a simplified verification condition in case proof validation does not work. - // quantifier instantiation can be validated as follows: - // If quantifier instantiation claims that (forall x . phi(x)) => psi using instantiation x -> t - // then check the simplified VC: phi(t) => psi. - // in case psi is the literal instantiation, then the clause is a propositional tautology. - // The VC function is a no-op if the proof hint does not have an associated vc generator. - expr_ref_vector vc(clause); - if (m_checker.vc(proof_hint, clause, vc)) { - std::cout << "(verified-" << proof_hint->get_name() << ")\n"; - add_clause(clause); - return; - } - - m_solver->push(); - for (expr* lit : vc) - m_solver->assert_expr(m.mk_not(lit)); - lbool is_sat = m_solver->check_sat(); - if (is_sat != l_false) { - std::cout << "did not verify: " << is_sat << " " << clause << "\n\n"; - m_solver->display(std::cout); - if (is_sat == l_true) { - model_ref mdl; - m_solver->get_model(mdl); - std::cout << *mdl << "\n"; - } - exit(0); - } - m_solver->pop(1); - std::cout << "(verified-smt"; - if (proof_hint) std::cout << "\n" << mk_bounded_pp(proof_hint, m, 4); - for (expr* arg : clause) - std::cout << "\n " << mk_bounded_pp(arg, m); - std::cout << ")\n"; - add_clause(clause); - } - - void assume(expr_ref_vector const& clause) { - add_clause(clause); - m_solver->assert_expr(mk_or(clause)); - } - - void del(expr_ref_vector const& clause) { - - } - -}; /** * Replay proof entierly, then walk backwards extracting reduced proof. @@ -207,7 +60,7 @@ class proof_trim { cmd_context& ctx; ast_manager& m; sat::proof_trim trim; - euf::proof_checker m_checker; + euf::theory_checker m_checker; vector m_clauses; bool_vector m_is_infer; symbol m_rup; @@ -371,11 +224,11 @@ class proof_cmds_imp : public proof_cmds { bool m_check = true; bool m_save = false; bool m_trim = false; - scoped_ptr m_checker; + scoped_ptr m_checker; scoped_ptr m_saver; scoped_ptr m_trimmer; - smt_checker& checker() { if (!m_checker) m_checker = alloc(smt_checker, m); return *m_checker; } + euf::smt_proof_checker& checker() { params_ref p; if (!m_checker) m_checker = alloc(euf::smt_proof_checker, m, p); return *m_checker; } proof_saver& saver() { if (!m_saver) m_saver = alloc(proof_saver, ctx); return *m_saver; } proof_trim& trim() { if (!m_trimmer) m_trimmer = alloc(proof_trim, ctx); return *m_trimmer; } @@ -404,7 +257,7 @@ public: void end_infer() override { if (m_check) - checker().check(m_lits, m_proof_hint); + checker().infer(m_lits, m_proof_hint); if (m_save) saver().infer(m_lits, m_proof_hint); if (m_trim) diff --git a/src/model/model.cpp b/src/model/model.cpp index dfa76db68..c89ff59a2 100644 --- a/src/model/model.cpp +++ b/src/model/model.cpp @@ -225,6 +225,20 @@ struct model::top_sort : public ::top_sort { } }; +void model::evaluate_constants() { + for (auto& [k, p] : m_interp) { + auto & [i, e] = p; + if (m.is_value(e)) + continue; + expr_ref val(m); + val = (*this)(e); + m.dec_ref(e); + m.inc_ref(val); + p.second = val; + } +} + + void model::compress(bool force_inline) { if (m_cleaned) return; diff --git a/src/model/model.h b/src/model/model.h index 07049a522..a93fc1b4f 100644 --- a/src/model/model.h +++ b/src/model/model.h @@ -94,6 +94,8 @@ public: void compress(bool force_inline = false); + void evaluate_constants(); + void set_model_completion(bool f) { m_mev.set_model_completion(f); } void updt_params(params_ref const & p); diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index d21ec5b93..2330fe401 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -198,7 +198,16 @@ namespace sat { m_drat_check_sat = p.drat_check_sat(); m_drat_file = p.drat_file(); m_smt_proof = p.smt_proof(); - m_drat = !p.drat_disable() && (sp.lemmas2console() || m_drat_check_unsat || m_drat_file.is_non_empty_string() || m_smt_proof.is_non_empty_string() || m_drat_check_sat) && p.threads() == 1; + m_smt_proof_check = p.smt_proof_check(); + m_smt_proof_check_rup = p.smt_proof_check_rup(); + m_drat = + !p.drat_disable() && p.threads() == 1 && + (sp.lemmas2console() || + m_drat_check_unsat || + m_drat_file.is_non_empty_string() || + m_smt_proof.is_non_empty_string() || + m_smt_proof_check || + m_drat_check_sat); m_drat_binary = p.drat_binary(); m_drat_activity = p.drat_activity(); m_dyn_sub_res = p.dyn_sub_res(); diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index 7d98b092c..2d609b1bc 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -178,6 +178,8 @@ namespace sat { bool m_drat_binary; symbol m_drat_file; symbol m_smt_proof; + bool m_smt_proof_check; + bool m_smt_proof_check_rup; bool m_drat_check_unsat; bool m_drat_check_sat; bool m_drat_activity; diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index d6a956a32..8f15f5f68 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -89,7 +89,7 @@ namespace sat { virtual bool unit_propagate() = 0; virtual bool is_external(bool_var v) { return false; } virtual double get_reward(literal l, ext_constraint_idx idx, literal_occs_fun& occs) const { return 0; } - virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r, bool probing) = 0; + virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r, bool probing, proof_hint*& ph) = 0; virtual bool is_extended_binary(ext_justification_idx idx, literal_vector & r) { return false; } virtual bool decide(bool_var& var, lbool& phase) { return false; } virtual bool get_case_split(bool_var& var, lbool& phase) { return false; } diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index 41bfa7afa..c35e97b92 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -48,6 +48,8 @@ def_module_params('sat', ('dimacs.core', BOOL, False, 'extract core from DIMACS benchmarks'), ('drat.disable', BOOL, False, 'override anything that enables DRAT'), ('smt.proof', SYMBOL, '', 'add SMT proof to file'), + ('smt.proof.check', BOOL, False, 'check SMT proof while it is created'), + ('smt.proof.check_rup', BOOL, True, 'apply forward RUP proof checking'), ('drat.file', SYMBOL, '', 'file to dump DRAT proofs'), ('drat.binary', BOOL, False, 'use Binary DRAT output format'), ('drat.check_unsat', BOOL, False, 'build up internal proof and check'), diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index d10e1124d..7f7b34ea8 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -402,8 +402,8 @@ namespace sat { extension::scoped_drating _sd(*m_ext.get()); if (j.get_kind() == justification::EXT_JUSTIFICATION) fill_ext_antecedents(lit, j, false); - TRACE("sat", tout << "drat-unit\n"); - m_drat.add(lit, m_searching); + else + m_drat.add(lit, m_searching); } void solver::drat_log_clause(unsigned num_lits, literal const* lits, sat::status st) { @@ -2890,7 +2890,8 @@ namespace sat { SASSERT(m_ext); auto idx = js.get_ext_justification_idx(); m_ext_antecedents.reset(); - m_ext->get_antecedents(consequent, idx, m_ext_antecedents, probing); + proof_hint* ph = nullptr; + m_ext->get_antecedents(consequent, idx, m_ext_antecedents, probing, ph); } bool solver::is_two_phase() const { diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 6d75c0f1f..0b1e68e96 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -36,7 +36,7 @@ z3_add_component(sat_smt q_mam.cpp q_mbi.cpp q_model_fixer.cpp - q_proof_checker.cpp + q_theory_checker.cpp q_queue.cpp q_solver.cpp recfun_solver.cpp diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index f7e3f6293..5db5688b3 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1476,7 +1476,7 @@ namespace arith { return r; } - void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) { + void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing, sat::proof_hint*& ph) { auto& jst = euf::th_explain::from_index(idx); ctx.get_antecedents(l, jst, r, probing); } diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 3d3f1ddd7..672477df2 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -485,7 +485,7 @@ namespace arith { solver(euf::solver& ctx, theory_id id); ~solver() override; bool is_external(bool_var v) override { return false; } - void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override; + void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing, sat::proof_hint*& ph) override; void asserted(literal l) override; sat::check_result check() override; void simplify() override {} diff --git a/src/sat/smt/arith_proof_checker.h b/src/sat/smt/arith_theory_checker.h similarity index 98% rename from src/sat/smt/arith_proof_checker.h rename to src/sat/smt/arith_theory_checker.h index 0e5c891d5..f11e7ae29 100644 --- a/src/sat/smt/arith_proof_checker.h +++ b/src/sat/smt/arith_theory_checker.h @@ -34,7 +34,7 @@ The module assumes a limited repertoire of arithmetic proof rules. namespace arith { - class proof_checker : public euf::proof_checker_plugin { + class theory_checker : public euf::theory_checker_plugin { struct row { obj_map m_coeffs; rational m_coeff; @@ -304,7 +304,7 @@ namespace arith { } public: - proof_checker(ast_manager& m): + theory_checker(ast_manager& m): m(m), a(m), m_farkas("farkas"), @@ -468,7 +468,7 @@ namespace arith { return check(); } - void register_plugins(euf::proof_checker& pc) override { + void register_plugins(euf::theory_checker& pc) override { pc.register_plugin(m_farkas, this); pc.register_plugin(m_bound, this); pc.register_plugin(m_implied_eq, this); diff --git a/src/sat/smt/array_solver.h b/src/sat/smt/array_solver.h index 5c2708842..f161af01a 100644 --- a/src/sat/smt/array_solver.h +++ b/src/sat/smt/array_solver.h @@ -278,7 +278,7 @@ namespace array { solver(euf::solver& ctx, theory_id id); ~solver() override; bool is_external(bool_var v) override { return false; } - void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override {} + void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing, sat::proof_hint*& ph) override {} void asserted(literal l) override {} sat::check_result check() override; diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index 8ef2c5cd5..631f954b1 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -306,7 +306,7 @@ namespace bv { bool solver::is_extended_binary(sat::ext_justification_idx idx, literal_vector& r) { return false; } bool solver::is_external(bool_var v) { return true; } - void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) { + void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing, sat::proof_hint*& ph) { auto& c = bv_justification::from_index(idx); TRACE("bv", display_constraint(tout, idx) << "\n";); switch (c.m_kind) { @@ -395,6 +395,7 @@ namespace bv { sat::literal leq1(s().num_vars() + 1, false); sat::literal leq2(s().num_vars() + 2, false); expr_ref eq1(m), eq2(m); + sat::proof_hint* ph = nullptr; if (c.m_kind == bv_justification::kind_t::bv2int) { eq1 = m.mk_eq(c.a->get_expr(), c.b->get_expr()); eq2 = m.mk_eq(c.a->get_expr(), c.c->get_expr()); @@ -416,24 +417,24 @@ namespace bv { lits.push_back(c.m_consequent); break; case bv_justification::kind_t::ne2bit: - get_antecedents(c.m_consequent, c.to_index(), lits, true); + get_antecedents(c.m_consequent, c.to_index(), lits, true, ph); lits.push_back(c.m_consequent); break; case bv_justification::kind_t::bit2eq: - get_antecedents(leq1, c.to_index(), lits, true); + get_antecedents(leq1, c.to_index(), lits, true, ph); for (auto& lit : lits) lit.neg(); lits.push_back(leq1); break; case bv_justification::kind_t::bit2ne: - get_antecedents(c.m_consequent, c.to_index(), lits, true); + get_antecedents(c.m_consequent, c.to_index(), lits, true, ph); for (auto& lit : lits) lit.neg(); lits.push_back(c.m_consequent); break; case bv_justification::kind_t::bv2int: - get_antecedents(leq1, c.to_index(), lits, true); - get_antecedents(leq2, c.to_index(), lits, true); + get_antecedents(leq1, c.to_index(), lits, true, ph); + get_antecedents(leq2, c.to_index(), lits, true, ph); for (auto& lit : lits) lit.neg(); lits.push_back(leq1); diff --git a/src/sat/smt/bv_solver.h b/src/sat/smt/bv_solver.h index 2c8fb4ae9..313679443 100644 --- a/src/sat/smt/bv_solver.h +++ b/src/sat/smt/bv_solver.h @@ -335,7 +335,7 @@ namespace bv { double get_reward(literal l, sat::ext_constraint_idx idx, sat::literal_occs_fun& occs) const override; bool is_extended_binary(sat::ext_justification_idx idx, literal_vector& r) override; bool is_external(bool_var v) override; - void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector & r, bool probing) override; + void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector & r, bool probing, sat::proof_hint*& ph) override; void asserted(literal l) override; sat::check_result check() override; void push_core() override; diff --git a/src/sat/smt/dt_solver.cpp b/src/sat/smt/dt_solver.cpp index a87f8770b..80846ec58 100644 --- a/src/sat/smt/dt_solver.cpp +++ b/src/sat/smt/dt_solver.cpp @@ -755,7 +755,7 @@ namespace dt { SASSERT(m_find.get_num_vars() == get_num_vars()); } - void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) { + void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing, sat::proof_hint*& ph) { auto& jst = euf::th_explain::from_index(idx); ctx.get_antecedents(l, jst, r, probing); } diff --git a/src/sat/smt/dt_solver.h b/src/sat/smt/dt_solver.h index 4e2524f6b..136e40eac 100644 --- a/src/sat/smt/dt_solver.h +++ b/src/sat/smt/dt_solver.h @@ -140,7 +140,7 @@ namespace dt { ~solver() override; bool is_external(bool_var v) override { return false; } - void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override; + void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing, sat::proof_hint*& ph) override; void asserted(literal l) override; sat::check_result check() override; diff --git a/src/sat/smt/euf_ackerman.cpp b/src/sat/smt/euf_ackerman.cpp index c8639e302..2026120e4 100644 --- a/src/sat/smt/euf_ackerman.cpp +++ b/src/sat/smt/euf_ackerman.cpp @@ -20,7 +20,7 @@ Author: namespace euf { - ackerman::ackerman(solver& s, ast_manager& m): s(s), m(m) { + ackerman::ackerman(solver& ctx, ast_manager& m): ctx(ctx), m(m) { new_tmp(); } @@ -100,31 +100,31 @@ namespace euf { } bool ackerman::enable_cc(app* a, app* b) { - if (!s.enable_ackerman_axioms(a)) + if (!ctx.enable_ackerman_axioms(a)) return false; - if (!s.enable_ackerman_axioms(b)) + if (!ctx.enable_ackerman_axioms(b)) return false; for (expr* arg : *a) - if (!s.enable_ackerman_axioms(arg)) + if (!ctx.enable_ackerman_axioms(arg)) return false; for (expr* arg : *b) - if (!s.enable_ackerman_axioms(arg)) + if (!ctx.enable_ackerman_axioms(arg)) return false; return true; } bool ackerman::enable_eq(expr* a, expr* b, expr* c) { - return s.enable_ackerman_axioms(a) && - s.enable_ackerman_axioms(b) && - s.enable_ackerman_axioms(c); + return ctx.enable_ackerman_axioms(a) && + ctx.enable_ackerman_axioms(b) && + ctx.enable_ackerman_axioms(c); } void ackerman::cg_conflict_eh(expr * n1, expr * n2) { if (!is_app(n1) || !is_app(n2)) return; - if (!s.enable_ackerman_axioms(n1)) + if (!ctx.enable_ackerman_axioms(n1)) return; - SASSERT(!s.m_drating); + SASSERT(!ctx.m_drating); app* a = to_app(n1); app* b = to_app(n2); if (a->get_decl() != b->get_decl() || a->get_num_args() != b->get_num_args()) @@ -139,7 +139,7 @@ namespace euf { void ackerman::used_eq_eh(expr* a, expr* b, expr* c) { if (a == b || a == c || b == c) return; - if (s.m_drating) + if (ctx.m_drating) return; if (!enable_eq(a, b, c)) return; @@ -149,7 +149,7 @@ namespace euf { } void ackerman::used_cc_eh(app* a, app* b) { - if (s.m_drating) + if (ctx.m_drating) return; TRACE("ack", tout << "used cc: " << mk_pp(a, m) << " == " << mk_pp(b, m) << "\n";); SASSERT(a->get_decl() == b->get_decl()); @@ -162,7 +162,7 @@ namespace euf { void ackerman::gc() { m_num_propagations_since_last_gc++; - if (m_num_propagations_since_last_gc <= s.m_config.m_dack_gc) + if (m_num_propagations_since_last_gc <= ctx.m_config.m_dack_gc) return; m_num_propagations_since_last_gc = 0; @@ -175,14 +175,14 @@ namespace euf { } void ackerman::propagate() { - SASSERT(s.s().at_base_lvl()); + SASSERT(ctx.s().at_base_lvl()); auto* n = m_queue; inference* k = nullptr; - unsigned num_prop = static_cast(s.s().get_stats().m_conflict * s.m_config.m_dack_factor); + unsigned num_prop = static_cast(ctx.s().get_stats().m_conflict * ctx.m_config.m_dack_factor); num_prop = std::min(num_prop, m_table.size()); for (unsigned i = 0; i < num_prop; ++i, n = k) { k = n->next(); - if (n->m_count < s.m_config.m_dack_threshold) + if (n->m_count < ctx.m_config.m_dack_threshold) continue; if (n->m_count >= m_high_watermark && num_prop < m_table.size()) ++num_prop; @@ -190,13 +190,13 @@ namespace euf { add_cc(n->a, n->b); else add_eq(n->a, n->b, n->c); - ++s.m_stats.m_ackerman; + ++ctx.m_stats.m_ackerman; remove(n); } } void ackerman::add_cc(expr* _a, expr* _b) { - flet _is_redundant(s.m_is_redundant, true); + flet _is_redundant(ctx.m_is_redundant, true); app* a = to_app(_a); app* b = to_app(_b); TRACE("ack", tout << mk_pp(a, m) << " " << mk_pp(b, m) << "\n";); @@ -204,24 +204,33 @@ namespace euf { unsigned sz = a->get_num_args(); for (unsigned i = 0; i < sz; ++i) { - expr_ref eq = s.mk_eq(a->get_arg(i), b->get_arg(i)); - lits.push_back(~s.mk_literal(eq)); + expr* ai = a->get_arg(i); + expr* bi = b->get_arg(i); + if (ai != bi) { + expr_ref eq = ctx.mk_eq(ai, bi); + lits.push_back(~ctx.mk_literal(eq)); + } } - expr_ref eq = s.mk_eq(a, b); - lits.push_back(s.mk_literal(eq)); - s.s().mk_clause(lits, sat::status::th(true, m.get_basic_family_id())); + expr_ref eq = ctx.mk_eq(a, b); + lits.push_back(ctx.mk_literal(eq)); + th_proof_hint* ph = ctx.mk_cc_proof_hint(lits, a, b); + ctx.s().mk_clause(lits, sat::status::th(true, m.get_basic_family_id(), ph)); } void ackerman::add_eq(expr* a, expr* b, expr* c) { - flet _is_redundant(s.m_is_redundant, true); + if (a == c || b == c) + return; + flet _is_redundant(ctx.m_is_redundant, true); sat::literal lits[3]; - expr_ref eq1(s.mk_eq(a, c), m); - expr_ref eq2(s.mk_eq(b, c), m); - expr_ref eq3(s.mk_eq(a, b), m); + expr_ref eq1(ctx.mk_eq(a, c), m); + expr_ref eq2(ctx.mk_eq(b, c), m); + expr_ref eq3(ctx.mk_eq(a, b), m); TRACE("ack", tout << mk_pp(a, m) << " " << mk_pp(b, m) << " " << mk_pp(c, m) << "\n";); - lits[0] = ~s.mk_literal(eq1); - lits[1] = ~s.mk_literal(eq2); - lits[2] = s.mk_literal(eq3); - s.s().mk_clause(3, lits, sat::status::th(true, m.get_basic_family_id())); + lits[0] = ~ctx.mk_literal(eq1); + lits[1] = ~ctx.mk_literal(eq2); + lits[2] = ctx.mk_literal(eq3); + th_proof_hint* ph = ctx.mk_tc_proof_hint(lits); + ctx.s().mk_clause(3, lits, sat::status::th(true, m.get_basic_family_id(), ph)); } + } diff --git a/src/sat/smt/euf_ackerman.h b/src/sat/smt/euf_ackerman.h index 17c7404f2..b5af2f689 100644 --- a/src/sat/smt/euf_ackerman.h +++ b/src/sat/smt/euf_ackerman.h @@ -52,14 +52,14 @@ namespace euf { typedef hashtable table_t; - solver& s; + solver& ctx; ast_manager& m; table_t m_table; - inference* m_queue { nullptr }; - inference* m_tmp_inference { nullptr }; - unsigned m_gc_threshold { 100 }; - unsigned m_high_watermark { 1000 }; - unsigned m_num_propagations_since_last_gc { 0 }; + inference* m_queue = nullptr; + inference* m_tmp_inference = nullptr; + unsigned m_gc_threshold = 100; + unsigned m_high_watermark = 1000 ; + unsigned m_num_propagations_since_last_gc = 0; void reset(); void new_tmp(); @@ -75,7 +75,7 @@ namespace euf { public: - ackerman(solver& s, ast_manager& m); + ackerman(solver& ctx, ast_manager& m); ~ackerman(); void cg_conflict_eh(expr * n1, expr * n2); diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index 8ee1a1d28..dbabb1a7b 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -27,13 +27,12 @@ namespace euf { get_drat().add_theory(m.get_basic_family_id(), symbol("bool")); } if (!m_proof_out && s().get_config().m_drat && - (get_config().m_lemmas2console || s().get_config().m_smt_proof.is_non_empty_string())) { + (get_config().m_lemmas2console || + s().get_config().m_smt_proof_check || + s().get_config().m_smt_proof.is_non_empty_string())) { TRACE("euf", tout << "init-proof " << s().get_config().m_smt_proof << "\n"); m_proof_out = alloc(std::ofstream, s().get_config().m_smt_proof.str(), std::ios_base::out); - if (get_config().m_lemmas2console) - get_drat().set_clause_eh(*this); - if (s().get_config().m_smt_proof.is_non_empty_string()) - get_drat().set_clause_eh(*this); + get_drat().set_clause_eh(*this); } m_proof_initialized = true; } @@ -90,6 +89,46 @@ namespace euf { return new (get_region()) eq_proof_hint(m_lit_head, m_lit_tail, m_cc_head, m_cc_tail); } + th_proof_hint* solver::mk_cc_proof_hint(sat::literal_vector const& ante, app* a, app* b) { + if (!use_drat()) + return nullptr; + SASSERT(a->get_decl() == b->get_decl()); + push(value_trail(m_lit_tail)); + push(value_trail(m_cc_tail)); + push(restore_size_trail(m_proof_literals)); + push(restore_size_trail(m_explain_cc, m_explain_cc.size())); + + for (auto lit : ante) + m_proof_literals.push_back(~lit); + + m_explain_cc.push_back({a, b, 0, false}); + + m_lit_head = m_lit_tail; + m_cc_head = m_cc_tail; + m_lit_tail = m_proof_literals.size(); + m_cc_tail = m_explain_cc.size(); + return new (get_region()) eq_proof_hint(m_lit_head, m_lit_tail, m_cc_head, m_cc_tail); + } + + th_proof_hint* solver::mk_tc_proof_hint(sat::literal const* clause) { + if (!use_drat()) + return nullptr; + push(value_trail(m_lit_tail)); + push(value_trail(m_cc_tail)); + push(restore_size_trail(m_proof_literals)); + + for (unsigned i = 0; i < 3; ++i) + m_proof_literals.push_back(~clause[i]); + + + m_lit_head = m_lit_tail; + m_cc_head = m_cc_tail; + m_lit_tail = m_proof_literals.size(); + m_cc_tail = m_explain_cc.size(); + return new (get_region()) eq_proof_hint(m_lit_head, m_lit_tail, m_cc_head, m_cc_tail); + } + + expr* eq_proof_hint::get_hint(euf::solver& s) const { ast_manager& m = s.get_manager(); func_decl_ref cc(m), cc_comm(m); @@ -118,7 +157,7 @@ namespace euf { std::sort(s.m_explain_cc.data() + m_cc_head, s.m_explain_cc.data() + m_cc_tail, compare_ts); for (unsigned i = m_cc_head; i < m_cc_tail; ++i) { auto const& [a, b, ts, comm] = s.m_explain_cc[i]; - args.push_back(cc_proof(comm, m.mk_eq(a->get_expr(), b->get_expr()))); + args.push_back(cc_proof(comm, m.mk_eq(a, b))); } for (auto * arg : args) sorts.push_back(arg->get_sort()); @@ -126,6 +165,8 @@ namespace euf { func_decl* f = m.mk_func_decl(symbol("euf"), sorts.size(), sorts.data(), proof); return m.mk_app(f, args); } + + smt_proof_hint* solver::mk_smt_hint(symbol const& n, unsigned nl, literal const* lits, unsigned ne, expr_pair const* eqs, unsigned nd, expr_pair const* deqs) { if (!use_drat()) @@ -134,8 +175,14 @@ namespace euf { push(restore_size_trail(m_proof_literals)); for (unsigned i = 0; i < nl; ++i) - if (sat::null_literal != lits[i]) + if (sat::null_literal != lits[i]) { + if (!literal2expr(lits[i])) + IF_VERBOSE(0, verbose_stream() << lits[i] << "\n"; display(verbose_stream())); + + + SASSERT(literal2expr(lits[i])); m_proof_literals.push_back(lits[i]); + } push(value_trail(m_eq_tail)); push(restore_size_trail(m_proof_eqs)); @@ -231,6 +278,7 @@ namespace euf { TRACE("euf", tout << "on-clause " << n << "\n"); on_lemma(n, lits, st); on_proof(n, lits, st); + on_check(n, lits, st); } void solver::on_proof(unsigned n, literal const* lits, sat::status st) { @@ -252,6 +300,21 @@ namespace euf { UNREACHABLE(); out.flush(); } + + void solver::on_check(unsigned n, literal const* lits, sat::status st) { + if (!s().get_config().m_smt_proof_check) + return; + expr_ref_vector clause(m); + for (unsigned i = 0; i < n; ++i) + clause.push_back(literal2expr(lits[i])); + auto hint = status2proof_hint(st); + if (st.is_asserted() || st.is_redundant()) + m_smt_proof_checker.infer(clause, hint); + else if (st.is_deleted()) + m_smt_proof_checker.del(clause); + else if (st.is_input()) + m_smt_proof_checker.assume(clause); + } void solver::on_lemma(unsigned n, literal const* lits, sat::status st) { if (!get_config().m_lemmas2console) @@ -320,21 +383,21 @@ namespace euf { if (proof_hint) return display_expr(out << " ", proof_hint); else - return out; + return out; } - expr_ref solver::status2proof_hint(sat::status st) { + app_ref solver::status2proof_hint(sat::status st) { if (st.is_sat()) - return expr_ref(m.mk_const("rup", m.mk_proof_sort()), m); // provable by reverse unit propagation + return app_ref(m.mk_const("rup", m.mk_proof_sort()), m); // provable by reverse unit propagation auto* h = reinterpret_cast(st.get_hint()); if (!h) - return expr_ref(m); + return app_ref(m); expr* e = h->get_hint(*this); if (e) - return expr_ref(e, m); + return app_ref(to_app(e), m); - return expr_ref(m); + return app_ref(m); } std::ostream& solver::display_literals(std::ostream& out, unsigned n, literal const* lits) { @@ -345,6 +408,7 @@ namespace euf { k = m.mk_const(symbol(lits[i].var()), m.mk_bool_sort()); e = k; } + SASSERT(e); if (lits[i].sign()) display_expr(out << " (not ", e) << ")"; else diff --git a/src/sat/smt/euf_proof_checker.cpp b/src/sat/smt/euf_proof_checker.cpp index e7a77df5c..98032a86b 100644 --- a/src/sat/smt/euf_proof_checker.cpp +++ b/src/sat/smt/euf_proof_checker.cpp @@ -20,11 +20,14 @@ Author: #include "ast/ast_util.h" #include "ast/ast_ll_pp.h" #include "ast/arith_decl_plugin.h" +#include "smt/smt_solver.h" +#include "sat/sat_params.hpp" #include "sat/smt/euf_proof_checker.h" -#include "sat/smt/arith_proof_checker.h" -#include "sat/smt/q_proof_checker.h" +#include "sat/smt/arith_theory_checker.h" +#include "sat/smt/q_theory_checker.h" #include "sat/smt/tseitin_proof_checker.h" + namespace euf { /** @@ -57,7 +60,7 @@ namespace euf { * union-find checker. */ - class eq_proof_checker : public proof_checker_plugin { + class eq_theory_checker : public theory_checker_plugin { ast_manager& m; arith_util m_arith; expr_ref_vector m_trail; @@ -133,7 +136,7 @@ namespace euf { } public: - eq_proof_checker(ast_manager& m): m(m), m_arith(m), m_trail(m) {} + eq_theory_checker(ast_manager& m): m(m), m_arith(m), m_trail(m) {} expr_ref_vector clause(app* jst) override { expr_ref_vector result(m); @@ -208,7 +211,7 @@ namespace euf { return false; } - void register_plugins(proof_checker& pc) override { + void register_plugins(theory_checker& pc) override { pc.register_plugin(symbol("euf"), this); } }; @@ -219,12 +222,12 @@ namespace euf { The pivot occurs with opposite signs in proof1 and proof2 */ - class res_proof_checker : public proof_checker_plugin { + class res_checker : public theory_checker_plugin { ast_manager& m; - proof_checker& pc; + theory_checker& pc; public: - res_proof_checker(ast_manager& m, proof_checker& pc): m(m), pc(pc) {} + res_checker(ast_manager& m, theory_checker& pc): m(m), pc(pc) {} bool check(app* jst) override { if (jst->get_num_args() != 3) @@ -273,46 +276,46 @@ namespace euf { return result; } - void register_plugins(proof_checker& pc) override { + void register_plugins(theory_checker& pc) override { pc.register_plugin(symbol("res"), this); } }; - proof_checker::proof_checker(ast_manager& m): + theory_checker::theory_checker(ast_manager& m): m(m) { - add_plugin(alloc(arith::proof_checker, m)); - add_plugin(alloc(eq_proof_checker, m)); - add_plugin(alloc(res_proof_checker, m, *this)); - add_plugin(alloc(q::proof_checker, m)); - add_plugin(alloc(smt_proof_checker_plugin, m, symbol("datatype"))); // no-op datatype proof checker - add_plugin(alloc(tseitin::proof_checker, m)); + add_plugin(alloc(arith::theory_checker, m)); + add_plugin(alloc(eq_theory_checker, m)); + add_plugin(alloc(res_checker, m, *this)); + add_plugin(alloc(q::theory_checker, m)); + add_plugin(alloc(smt_theory_checker_plugin, m, symbol("datatype"))); // no-op datatype proof checker + add_plugin(alloc(tseitin::theory_checker, m)); } - proof_checker::~proof_checker() { + theory_checker::~theory_checker() { for (auto& [k, v] : m_checked_clauses) dealloc(v); } - void proof_checker::add_plugin(proof_checker_plugin* p) { + void theory_checker::add_plugin(theory_checker_plugin* p) { m_plugins.push_back(p); p->register_plugins(*this); } - void proof_checker::register_plugin(symbol const& rule, proof_checker_plugin* p) { + void theory_checker::register_plugin(symbol const& rule, theory_checker_plugin* p) { m_map.insert(rule, p); } - bool proof_checker::check(expr* e) { + bool theory_checker::check(expr* e) { if (!e || !is_app(e)) return false; if (m_checked_clauses.contains(e)) return true; app* a = to_app(e); - proof_checker_plugin* p = nullptr; + theory_checker_plugin* p = nullptr; return m_map.find(a->get_decl()->get_name(), p) && p->check(a); } - expr_ref_vector proof_checker::clause(expr* e) { + expr_ref_vector theory_checker::clause(expr* e) { expr_ref_vector* rr; if (m_checked_clauses.find(e, rr)) return *rr; @@ -322,17 +325,17 @@ namespace euf { return r; } - bool proof_checker::vc(expr* e, expr_ref_vector const& clause, expr_ref_vector& v) { + bool theory_checker::vc(expr* e, expr_ref_vector const& clause, expr_ref_vector& v) { SASSERT(is_app(e)); app* a = to_app(e); - proof_checker_plugin* p = nullptr; + theory_checker_plugin* p = nullptr; if (m_map.find(a->get_name(), p)) return p->vc(a, clause, v); IF_VERBOSE(10, verbose_stream() << "there is no proof plugin for " << mk_pp(e, m) << "\n"); return false; } - bool proof_checker::check(expr_ref_vector const& clause1, expr* e, expr_ref_vector & units) { + bool theory_checker::check(expr_ref_vector const& clause1, expr* e, expr_ref_vector & units) { if (!check(e)) return false; units.reset(); @@ -358,13 +361,166 @@ namespace euf { return true; } - expr_ref_vector smt_proof_checker_plugin::clause(app* jst) { + expr_ref_vector smt_theory_checker_plugin::clause(app* jst) { expr_ref_vector result(m); SASSERT(jst->get_name() == m_rule); for (expr* arg : *jst) result.push_back(mk_not(m, arg)); return result; } + + + smt_proof_checker::smt_proof_checker(ast_manager& m, params_ref const& p): + m(m), + m_params(p), + m_checker(m), + m_sat_solver(m_params, m.limit()), + m_drat(m_sat_solver) + { + m_params.set_bool("drat.check_unsat", true); + m_params.set_bool("euf", false); + m_sat_solver.updt_params(m_params); + m_drat.updt_config(); + m_rup = symbol("rup"); + sat_params sp(m_params); + m_check_rup = sp.smt_proof_check_rup(); + } + + void smt_proof_checker::ensure_solver() { + if (!m_solver) + m_solver = mk_smt_solver(m, m_params, symbol()); + } + + + void smt_proof_checker::log_verified(app* proof_hint) { + symbol n = proof_hint->get_name(); + if (n == m_last_rule) { + ++m_num_last_rules; + return; + } + if (m_num_last_rules > 0) + std::cout << "(verified-" << m_last_rule << "+" << m_num_last_rules << ")\n"; + + std::cout << "(verified-" << n << ")\n"; + m_last_rule = n; + m_num_last_rules = 0; + + } + + bool smt_proof_checker::check_rup(expr_ref_vector const& clause) { + if (!m_check_rup) + return true; + add_units(); + mk_clause(clause); + return m_drat.is_drup(m_clause.size(), m_clause.data(), m_units); + } + + bool smt_proof_checker::check_rup(expr* u) { + if (!m_check_rup) + return true; + add_units(); + mk_clause(u); + return m_drat.is_drup(m_clause.size(), m_clause.data(), m_units); + } + + void smt_proof_checker::infer(expr_ref_vector& clause, app* proof_hint) { + + if (is_rup(proof_hint) && check_rup(clause)) { + if (m_check_rup) { + log_verified(proof_hint); + add_clause(clause); + } + return; + } + + expr_ref_vector units(m); + if (m_checker.check(clause, proof_hint, units)) { + bool units_are_rup = true; + for (expr* u : units) { + if (!check_rup(u)) { + std::cout << "unit " << mk_bounded_pp(u, m) << " is not rup\n"; + units_are_rup = false; + } + } + if (units_are_rup) { + log_verified(proof_hint); + add_clause(clause); + return; + } + } + + // extract a simplified verification condition in case proof validation does not work. + // quantifier instantiation can be validated as follows: + // If quantifier instantiation claims that (forall x . phi(x)) => psi using instantiation x -> t + // then check the simplified VC: phi(t) => psi. + // in case psi is the literal instantiation, then the clause is a propositional tautology. + // The VC function is a no-op if the proof hint does not have an associated vc generator. + expr_ref_vector vc(clause); + if (m_checker.vc(proof_hint, clause, vc)) { + log_verified(proof_hint); + add_clause(clause); + return; + } + + ensure_solver(); + m_solver->push(); + for (expr* lit : vc) + m_solver->assert_expr(m.mk_not(lit)); + lbool is_sat = m_solver->check_sat(); + if (is_sat != l_false) { + std::cout << "did not verify: " << is_sat << " " << clause << "\n"; + if (proof_hint) + std::cout << "hint: " << mk_bounded_pp(proof_hint, m, 4) << "\n"; + m_solver->display(std::cout); + if (is_sat == l_true) { + model_ref mdl; + m_solver->get_model(mdl); + mdl->evaluate_constants(); + std::cout << *mdl << "\n"; + } + exit(0); + } + m_solver->pop(1); + std::cout << "(verified-smt"; + if (proof_hint) std::cout << "\n" << mk_bounded_pp(proof_hint, m, 4); + for (expr* arg : clause) + std::cout << "\n " << mk_bounded_pp(arg, m); + std::cout << ")\n"; + if (is_rup(proof_hint)) + diagnose_rup_failure(clause); + + add_clause(clause); + } + + void smt_proof_checker::diagnose_rup_failure(expr_ref_vector const& clause) { + expr_ref_vector fmls(m), assumptions(m), core(m); + m_solver->get_assertions(fmls); + for (unsigned i = 0; i < fmls.size(); ++i) { + assumptions.push_back(m.mk_fresh_const("a", m.mk_bool_sort())); + fmls[i] = m.mk_implies(assumptions.back(), fmls.get(i)); + } + + ref<::solver> core_solver = mk_smt_solver(m, m_params, symbol()); + // core_solver->assert_expr(fmls); + core_solver->assert_expr(m.mk_not(mk_or(clause))); + lbool ch = core_solver->check_sat(assumptions); + std::cout << "failed to verify\n" << clause << "\n"; + if (ch == l_false) { + core_solver->get_unsat_core(core); + std::cout << "core\n"; + for (expr* f : core) + std::cout << mk_pp(f, m) << "\n"; + } + SASSERT(false); + + exit(0); + } + + void smt_proof_checker::collect_statistics(statistics& st) const { + if (m_solver) + m_solver->collect_statistics(st); + + } } diff --git a/src/sat/smt/euf_proof_checker.h b/src/sat/smt/euf_proof_checker.h index 8b1b5d671..6d5cb3290 100644 --- a/src/sat/smt/euf_proof_checker.h +++ b/src/sat/smt/euf_proof_checker.h @@ -19,30 +19,35 @@ Author: #include "util/map.h" #include "util/scoped_ptr_vector.h" #include "ast/ast.h" +#include "ast/ast_util.h" +#include "solver/solver.h" +#include "sat/sat_solver.h" +#include "sat/sat_drat.h" + namespace euf { - class proof_checker; + class theory_checker; - class proof_checker_plugin { + class theory_checker_plugin { public: - virtual ~proof_checker_plugin() {} + virtual ~theory_checker_plugin() {} virtual bool check(app* jst) = 0; virtual expr_ref_vector clause(app* jst) = 0; - virtual void register_plugins(proof_checker& pc) = 0; - virtual bool vc(app* jst, expr_ref_vector const& clause, expr_ref_vector& v) { return false; } + virtual void register_plugins(theory_checker& pc) = 0; + virtual bool vc(app* jst, expr_ref_vector const& clause, expr_ref_vector& v) { v.reset(); v.append(this->clause(jst)); return false; } }; - class proof_checker { + class theory_checker { ast_manager& m; - scoped_ptr_vector m_plugins; // plugins of proof checkers - map m_map; // symbol table of proof checkers + scoped_ptr_vector m_plugins; // plugins of proof checkers + map m_map; // symbol table of proof checkers obj_map m_checked_clauses; // cache of previously checked proofs and their clauses. - void add_plugin(proof_checker_plugin* p); + void add_plugin(theory_checker_plugin* p); public: - proof_checker(ast_manager& m); - ~proof_checker(); - void register_plugin(symbol const& rule, proof_checker_plugin*); + theory_checker(ast_manager& m); + ~theory_checker(); + void register_plugin(symbol const& rule, theory_checker_plugin*); bool check(expr* jst); expr_ref_vector clause(expr* jst); bool vc(expr* jst, expr_ref_vector const& clause, expr_ref_vector& v); @@ -55,15 +60,107 @@ namespace euf { It provides shared implementations for clause and register_plugin. It overrides check to always fail. */ - class smt_proof_checker_plugin : public proof_checker_plugin { + class smt_theory_checker_plugin : public theory_checker_plugin { ast_manager& m; symbol m_rule; public: - smt_proof_checker_plugin(ast_manager& m, symbol const& n): m(m), m_rule(n) {} + smt_theory_checker_plugin(ast_manager& m, symbol const& n): m(m), m_rule(n) {} bool check(app* jst) override { return false; } expr_ref_vector clause(app* jst) override; - void register_plugins(proof_checker& pc) override { pc.register_plugin(m_rule, this); } + void register_plugins(theory_checker& pc) override { pc.register_plugin(m_rule, this); } }; + + class smt_proof_checker { + ast_manager& m; + params_ref m_params; + + // for checking proof rules (hints) + euf::theory_checker m_checker; + + // for fallback SMT checker + scoped_ptr<::solver> m_solver; + + // for RUP + symbol m_rup; + sat::solver m_sat_solver; + sat::drat m_drat; + sat::literal_vector m_units; + sat::literal_vector m_clause; + bool m_check_rup = false; + + // for logging + symbol m_last_rule; + unsigned m_num_last_rules = 0; + + void add_units() { + auto const& units = m_drat.units(); + for (unsigned i = m_units.size(); i < units.size(); ++i) + m_units.push_back(units[i].first); + } + + void log_verified(app* proof_hint); + + void diagnose_rup_failure(expr_ref_vector const& clause); + + void ensure_solver(); + + public: + smt_proof_checker(ast_manager& m, params_ref const& p); + + bool is_rup(app* proof_hint) { + return + proof_hint && + proof_hint->get_name() == m_rup; + } + + void mk_clause(expr_ref_vector const& clause) { + m_clause.reset(); + for (expr* e : clause) { + bool sign = false; + while (m.is_not(e, e)) + sign = !sign; + m_clause.push_back(sat::literal(e->get_id(), sign)); + } + } + + void mk_clause(expr* e) { + m_clause.reset(); + bool sign = false; + while (m.is_not(e, e)) + sign = !sign; + m_clause.push_back(sat::literal(e->get_id(), sign)); + } + + bool check_rup(expr_ref_vector const& clause); + + bool check_rup(expr* u); + + void add_clause(expr_ref_vector const& clause) { + if (!m_check_rup) + return; + mk_clause(clause); + m_drat.add(m_clause, sat::status::input()); + } + + void assume(expr_ref_vector const& clause) { + add_clause(clause); + if (!m_check_rup) + return; + ensure_solver(); + m_solver->assert_expr(mk_or(clause)); + } + + void del(expr_ref_vector const& clause) { + } + + + void infer(expr_ref_vector& clause, app* proof_hint); + + void collect_statistics(statistics& st) const; + + }; + + } diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index c29652803..cde3ef2e0 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -50,7 +50,8 @@ namespace euf { m_to_m(&m), m_to_si(&si), m_values(m), - m_clause_visitor(m) + m_clause_visitor(m), + m_smt_proof_checker(m, p) { updt_params(p); m_relevancy.set_enabled(get_config().m_relevancy_lvl > 2); @@ -72,7 +73,7 @@ namespace euf { void solver::updt_params(params_ref const& p) { m_config.updt_params(p); - use_drat(); + use_drat(); } /** @@ -215,19 +216,25 @@ namespace euf { x - 3 = 0 => x = 3 by arithmetic x = 3 => f(x) = f(3) by EUF resolve to produce clause x - 3 = 0 => f(x) = f(3) + + The last argument to get_assumptions is a place-holder to retrieve a justification of a propagation. + Theory solver would have to populate this hint and the combined hint would have to be composed from the + sub-hints. */ - void solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) { + void solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing, sat::proof_hint*& ph) { m_egraph.begin_explain(); m_explain.reset(); if (use_drat() && !probing) push(restore_size_trail(m_explain_cc, m_explain_cc.size())); auto* ext = sat::constraint_base::to_extension(idx); + th_proof_hint* hint = nullptr; + sat::proof_hint* shint = nullptr; bool has_theory = false; if (ext == this) get_antecedents(l, constraint::from_idx(idx), r, probing); else { - ext->get_antecedents(l, idx, r, probing); + ext->get_antecedents(l, idx, r, probing, shint); has_theory = true; } for (unsigned qhead = 0; qhead < m_explain.size(); ++qhead) { @@ -239,20 +246,19 @@ namespace euf { auto* ext = sat::constraint_base::to_extension(idx); SASSERT(ext != this); sat::literal lit = sat::null_literal; - ext->get_antecedents(lit, idx, r, probing); + ext->get_antecedents(lit, idx, r, probing, shint); has_theory = true; } } m_egraph.end_explain(); - th_proof_hint* hint = nullptr; - if (use_drat() && !probing) { - if (has_theory) { - r.push_back(~l); - hint = mk_smt_hint(symbol("smt"), r); - r.pop_back(); - } - else + if (use_drat() && !probing) { + if (!has_theory) hint = mk_hint(l, r); + else { + if (l != sat::null_literal) r.push_back(~l); + hint = mk_smt_hint(symbol("smt"), r); + if (l != sat::null_literal) r.pop_back(); + } } unsigned j = 0; for (sat::literal lit : r) @@ -957,6 +963,7 @@ namespace euf { m_egraph.collect_statistics(st); for (auto* e : m_solvers) e->collect_statistics(st); + m_smt_proof_checker.collect_statistics(st); st.update("euf ackerman", m_stats.m_ackerman); st.update("euf final check", m_stats.m_final_checks); } diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index bf18ca4bc..cedfcf2da 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -28,6 +28,7 @@ Author: #include "sat/smt/euf_ackerman.h" #include "sat/smt/user_solver.h" #include "sat/smt/euf_relevancy.h" +#include "sat/smt/euf_proof_checker.h" #include "smt/params/smt_params.h" @@ -203,19 +204,22 @@ namespace euf { unsigned m_eq_head = 0, m_eq_tail = 0, m_deq_head = 0, m_deq_tail = 0; eq_proof_hint* mk_hint(literal lit, literal_vector const& r); + bool m_proof_initialized = false; void init_proof(); ast_pp_util m_clause_visitor; bool m_display_all_decls = false; + smt_proof_checker m_smt_proof_checker; void on_clause(unsigned n, literal const* lits, sat::status st) override; void on_lemma(unsigned n, literal const* lits, sat::status st); void on_proof(unsigned n, literal const* lits, sat::status st); + void on_check(unsigned n, literal const* lits, sat::status st); std::ostream& display_literals(std::ostream& out, unsigned n, sat::literal const* lits); void display_assume(std::ostream& out, unsigned n, literal const* lits); void display_inferred(std::ostream& out, unsigned n, literal const* lits, expr* proof_hint); void display_deleted(std::ostream& out, unsigned n, literal const* lits); std::ostream& display_hint(std::ostream& out, expr* proof_hint); - expr_ref status2proof_hint(sat::status st); + app_ref status2proof_hint(sat::status st); // relevancy bool is_propagated(sat::literal lit); @@ -334,7 +338,7 @@ namespace euf { bool set_root(literal l, literal r) override; void flush_roots() override; - void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) override; + void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing, sat::proof_hint*& ph) override; void get_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing); void add_antecedent(bool probing, enode* a, enode* b); void add_diseq_antecedent(ptr_vector& ex, cc_justification* cc, enode* a, enode* b); @@ -400,6 +404,8 @@ namespace euf { smt_proof_hint* mk_smt_prop_hint(symbol const& n, literal lit, expr* a, expr* b) { expr_pair e(a, b); return mk_smt_hint(n, 1, &lit, 0, nullptr, 1, &e); } smt_proof_hint* mk_smt_prop_hint(symbol const& n, literal lit, enode* a, enode* b) { return mk_smt_prop_hint(n, lit, a->get_expr(), b->get_expr()); } smt_proof_hint* mk_smt_hint(symbol const& n, enode* a, enode* b) { expr_pair e(a->get_expr(), b->get_expr()); return mk_smt_hint(n, 0, nullptr, 1, &e); } + th_proof_hint* mk_cc_proof_hint(sat::literal_vector const& ante, app* a, app* b); + th_proof_hint* mk_tc_proof_hint(sat::literal const* ternary_clause); sat::status mk_tseitin_status(sat::literal a, sat::literal b); sat::status mk_tseitin_status(unsigned n, sat::literal const* lits); diff --git a/src/sat/smt/fpa_solver.h b/src/sat/smt/fpa_solver.h index 38abb399d..537ae0895 100644 --- a/src/sat/smt/fpa_solver.h +++ b/src/sat/smt/fpa_solver.h @@ -71,7 +71,7 @@ namespace fpa { void finalize_model(model& mdl) override; bool unit_propagate() override; - void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override { UNREACHABLE(); } + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing, sat::proof_hint*& ph) override { UNREACHABLE(); } sat::check_result check() override; euf::th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); } diff --git a/src/sat/smt/pb_solver.cpp b/src/sat/smt/pb_solver.cpp index 1c762cb3f..2a29f07d5 100644 --- a/src/sat/smt/pb_solver.cpp +++ b/src/sat/smt/pb_solver.cpp @@ -722,7 +722,8 @@ namespace pb { auto* ext = sat::constraint_base::to_extension(cindex); if (ext != this) { m_lemma.reset(); - ext->get_antecedents(consequent, idx, m_lemma, false); + sat::proof_hint* ph = nullptr; + ext->get_antecedents(consequent, idx, m_lemma, false, ph); for (literal l : m_lemma) process_antecedent(~l, offset); break; } @@ -1052,7 +1053,8 @@ namespace pb { auto* ext = sat::constraint_base::to_extension(index); if (ext != this) { m_lemma.reset(); - ext->get_antecedents(consequent, index, m_lemma, false); + sat::proof_hint* ph = nullptr; + ext->get_antecedents(consequent, index, m_lemma, false, ph); for (literal l : m_lemma) process_antecedent(~l, 1); break; @@ -1688,7 +1690,7 @@ namespace pb { // ---------------------------- // constraint generic methods - void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector & r, bool probing) { + void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector & r, bool probing, sat::proof_hint*& ph) { get_antecedents(l, index2constraint(idx), r, probing); } diff --git a/src/sat/smt/pb_solver.h b/src/sat/smt/pb_solver.h index 09c0e47e0..9f9e6835c 100644 --- a/src/sat/smt/pb_solver.h +++ b/src/sat/smt/pb_solver.h @@ -377,7 +377,7 @@ namespace pb { bool propagated(literal l, sat::ext_constraint_idx idx) override; bool unit_propagate() override { return false; } lbool resolve_conflict() override; - void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector & r, bool probing) override; + void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector & r, bool probing, sat::proof_hint*& ph) override; void asserted(literal l) override; sat::check_result check() override; void push() override; diff --git a/src/sat/smt/q_solver.cpp b/src/sat/smt/q_solver.cpp index 7fd7be97e..7bddcbd9f 100644 --- a/src/sat/smt/q_solver.cpp +++ b/src/sat/smt/q_solver.cpp @@ -353,7 +353,7 @@ namespace q { return !m.is_and(arg) && !m.is_or(arg) && !m.is_iff(arg) && !m.is_implies(arg); } - void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) { + void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing, sat::proof_hint*& ph) { m_ematch.get_antecedents(l, idx, r, probing); } diff --git a/src/sat/smt/q_solver.h b/src/sat/smt/q_solver.h index 3a95a00be..f199db610 100644 --- a/src/sat/smt/q_solver.h +++ b/src/sat/smt/q_solver.h @@ -83,7 +83,7 @@ namespace q { solver(euf::solver& ctx, family_id fid); bool is_external(sat::bool_var v) override { return false; } - void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override; + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing, sat::proof_hint*& ph) override; void asserted(sat::literal l) override; sat::check_result check() override; diff --git a/src/sat/smt/q_proof_checker.cpp b/src/sat/smt/q_theory_checker.cpp similarity index 79% rename from src/sat/smt/q_proof_checker.cpp rename to src/sat/smt/q_theory_checker.cpp index 8ddd3d75a..be246dd3c 100644 --- a/src/sat/smt/q_proof_checker.cpp +++ b/src/sat/smt/q_theory_checker.cpp @@ -3,7 +3,7 @@ Copyright (c) 2022 Microsoft Corporation Module Name: - q_proof_checker.cpp + q_theory_checker.cpp Abstract: @@ -16,12 +16,12 @@ Author: --*/ #include "ast/rewriter/var_subst.h" -#include "sat/smt/q_proof_checker.h" +#include "sat/smt/q_theory_checker.h" #include "sat/smt/q_solver.h" namespace q { - expr_ref_vector proof_checker::clause(app* jst) { + expr_ref_vector theory_checker::clause(app* jst) { expr_ref_vector result(m); for (expr* arg : *jst) if (!is_bind(arg)) @@ -29,7 +29,7 @@ namespace q { return result; } - expr_ref_vector proof_checker::binding(app* jst) { + expr_ref_vector theory_checker::binding(app* jst) { expr_ref_vector result(m); for (expr* arg : *jst) if (is_bind(arg)) { @@ -39,7 +39,7 @@ namespace q { return result; } - bool proof_checker::vc(app* jst, expr_ref_vector const& clause0, expr_ref_vector& v) { + bool theory_checker::vc(app* jst, expr_ref_vector const& clause0, expr_ref_vector& v) { expr* q = nullptr; if (!is_inst(jst)) return false; @@ -54,11 +54,11 @@ namespace q { return qi == clause1.get(1); } - bool proof_checker::is_inst(expr* jst) { + bool theory_checker::is_inst(expr* jst) { return is_app(jst) && to_app(jst)->get_name() == m_inst && m.mk_proof_sort() == jst->get_sort(); } - bool proof_checker::is_bind(expr* e) { + bool theory_checker::is_bind(expr* e) { return is_app(e) && to_app(e)->get_name() == m_bind && m.mk_proof_sort() == e->get_sort(); } diff --git a/src/sat/smt/q_proof_checker.h b/src/sat/smt/q_theory_checker.h similarity index 83% rename from src/sat/smt/q_proof_checker.h rename to src/sat/smt/q_theory_checker.h index 4072739c7..70c9938fe 100644 --- a/src/sat/smt/q_proof_checker.h +++ b/src/sat/smt/q_theory_checker.h @@ -3,7 +3,7 @@ Copyright (c) 2022 Microsoft Corporation Module Name: - q_proof_checker.h + q_theory_checker.h Abstract: @@ -25,7 +25,7 @@ Author: namespace q { - class proof_checker : public euf::proof_checker_plugin { + class theory_checker : public euf::theory_checker_plugin { ast_manager& m; symbol m_inst; symbol m_bind; @@ -37,7 +37,7 @@ namespace q { bool is_bind(expr* e); public: - proof_checker(ast_manager& m): + theory_checker(ast_manager& m): m(m), m_inst("inst"), m_bind("bind") { @@ -47,7 +47,7 @@ namespace q { bool check(app* jst) override { return false; } - void register_plugins(euf::proof_checker& pc) override { + void register_plugins(euf::theory_checker& pc) override { pc.register_plugin(symbol("inst"), this); } diff --git a/src/sat/smt/recfun_solver.cpp b/src/sat/smt/recfun_solver.cpp index c88138d3f..ebb1ec7c1 100644 --- a/src/sat/smt/recfun_solver.cpp +++ b/src/sat/smt/recfun_solver.cpp @@ -180,7 +180,7 @@ namespace recfun { add_clause(clause); } - void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) { + void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing, sat::proof_hint*& ph) { UNREACHABLE(); } diff --git a/src/sat/smt/recfun_solver.h b/src/sat/smt/recfun_solver.h index 4e41a35a9..463aee6d6 100644 --- a/src/sat/smt/recfun_solver.h +++ b/src/sat/smt/recfun_solver.h @@ -92,7 +92,7 @@ namespace recfun { solver(euf::solver& ctx); ~solver() override; bool is_external(sat::bool_var v) override { return false; } - void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override; + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing, sat::proof_hint*& ph) override; void asserted(sat::literal l) override; sat::check_result check() override; std::ostream& display(std::ostream& out) const override; diff --git a/src/sat/smt/tseitin_proof_checker.cpp b/src/sat/smt/tseitin_proof_checker.cpp index df0d60435..19f6e8660 100644 --- a/src/sat/smt/tseitin_proof_checker.cpp +++ b/src/sat/smt/tseitin_proof_checker.cpp @@ -32,13 +32,13 @@ TODOs: namespace tseitin { - expr_ref_vector proof_checker::clause(app* jst) { + expr_ref_vector theory_checker::clause(app* jst) { expr_ref_vector result(m); result.append(jst->get_num_args(), jst->get_args()); return result; } - bool proof_checker::check(app* jst) { + bool theory_checker::check(app* jst) { expr* main_expr = nullptr; unsigned max_depth = 0; for (expr* arg : *jst) { @@ -231,7 +231,7 @@ namespace tseitin { return false; } - bool proof_checker::equiv(expr* a, expr* b) { + bool theory_checker::equiv(expr* a, expr* b) { if (a == b) return true; if (!is_app(a) || !is_app(b)) diff --git a/src/sat/smt/tseitin_proof_checker.h b/src/sat/smt/tseitin_proof_checker.h index 86a109ed8..8bbacd53b 100644 --- a/src/sat/smt/tseitin_proof_checker.h +++ b/src/sat/smt/tseitin_proof_checker.h @@ -25,7 +25,7 @@ Author: namespace tseitin { - class proof_checker : public euf::proof_checker_plugin { + class theory_checker : public euf::theory_checker_plugin { ast_manager& m; expr_fast_mark1 m_mark; @@ -52,12 +52,12 @@ namespace tseitin { } struct scoped_mark { - proof_checker& pc; - scoped_mark(proof_checker& pc): pc(pc) {} + theory_checker& pc; + scoped_mark(theory_checker& pc): pc(pc) {} ~scoped_mark() { pc.m_mark.reset(); pc.m_nmark.reset(); } }; public: - proof_checker(ast_manager& m): + theory_checker(ast_manager& m): m(m) { } @@ -65,7 +65,7 @@ namespace tseitin { bool check(app* jst) override; - void register_plugins(euf::proof_checker& pc) override { + void register_plugins(euf::theory_checker& pc) override { pc.register_plugin(symbol("tseitin"), this); } diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 5c98a6fac..34fd26c77 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -201,7 +201,7 @@ namespace user_solver { return sat::justification::mk_ext_justification(s().scope_lvl(), constraint->to_index()); } - void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) { + void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing, sat::proof_hint*& ph) { auto& j = justification::from_index(idx); auto const& prop = m_prop[j.m_propagation_index]; for (unsigned id : prop.m_ids) diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index 28528b9a1..c996d1878 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -152,7 +152,7 @@ namespace user_solver { void push_core() override; void pop_core(unsigned n) override; bool unit_propagate() override; - void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override; + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing, sat::proof_hint*& ph) override; void collect_statistics(statistics& st) const override; sat::literal internalize(expr* e, bool sign, bool root, bool learned) override; void internalize(expr* e, bool redundant) override; diff --git a/src/shell/drat_frontend.cpp b/src/shell/drat_frontend.cpp index 091d8cd4d..b0b711ef0 100644 --- a/src/shell/drat_frontend.cpp +++ b/src/shell/drat_frontend.cpp @@ -5,20 +5,14 @@ Copyright (c) 2020 Microsoft Corporation #include #include -#include "ast/bv_decl_plugin.h" #include "util/memory_manager.h" #include "util/statistics.h" +#include "ast/proofs/proof_checker.h" +#include "ast/reg_decl_plugins.h" #include "sat/dimacs.h" #include "sat/sat_solver.h" #include "sat/sat_drat.h" -#include "smt/smt_solver.h" #include "shell/drat_frontend.h" -#include "parsers/smt2/smt2parser.h" -#include "cmd_context/cmd_context.h" -#include "ast/proofs/proof_checker.h" -#include "ast/rewriter/th_rewriter.h" -#include "ast/reg_decl_plugins.h" -#include "sat/smt/arith_proof_checker.h" class drup_checker { @@ -104,103 +98,3 @@ unsigned read_drat(char const* drat_file) { } return 0; } - - -#if 0 - - bool validate_hint(expr_ref_vector const& exprs, sat::literal_vector const& lits, sat::proof_hint const& hint) { - arith_util autil(m); - arith::proof_checker achecker(m); - proof_checker pc(m); - switch (hint.m_ty) { - case sat::hint_type::null_h: - break; - case sat::hint_type::bound_h: - case sat::hint_type::farkas_h: - case sat::hint_type::implied_eq_h: { - achecker.reset(); - for (auto const& [a, b]: hint.m_eqs) { - expr* x = exprs[a]; - expr* y = exprs[b]; - achecker.add_eq(x, y); - } - for (auto const& [a, b]: hint.m_diseqs) { - expr* x = exprs[a]; - expr* y = exprs[b]; - achecker.add_diseq(x, y); - } - - unsigned sz = hint.m_literals.size(); - for (unsigned i = 0; i < sz; ++i) { - auto const& [coeff, lit] = hint.m_literals[i]; - app_ref e(to_app(m_b2e[lit.var()]), m); - if (i + 1 == sz && sat::hint_type::bound_h == hint.m_ty) { - if (!achecker.add_conseq(coeff, e, lit.sign())) { - std::cout << "p failed checking hint " << e << "\n"; - return false; - } - - } - else if (!achecker.add_ineq(coeff, e, lit.sign())) { - std::cout << "p failed checking hint " << e << "\n"; - return false; - } - } - - // achecker.display(std::cout << "checking\n"); - bool ok = achecker.check(); - - if (!ok) { - rational lc(1); - for (auto const& [coeff, lit] : hint.m_literals) - lc = lcm(lc, denominator(coeff)); - bool is_strict = false; - expr_ref sum(m); - for (auto const& [coeff, lit] : hint.m_literals) { - app_ref e(to_app(m_b2e[lit.var()]), m); - VERIFY(pc.check_arith_literal(!lit.sign(), e, coeff*lc, sum, is_strict)); - std::cout << "sum: " << sum << "\n"; - } - sort* s = sum->get_sort(); - if (is_strict) - sum = autil.mk_lt(sum, autil.mk_numeral(rational(0), s)); - else - sum = autil.mk_le(sum, autil.mk_numeral(rational(0), s)); - th_rewriter rw(m); - rw(sum); - std::cout << "sum: " << sum << "\n"; - - for (auto const& [a, b]: hint.m_eqs) { - expr* x = exprs[a]; - expr* y = exprs[b]; - app_ref e(m.mk_eq(x, y), m); - std::cout << e << "\n"; - } - for (auto const& [a, b]: hint.m_diseqs) { - expr* x = exprs[a]; - expr* y = exprs[b]; - app_ref e(m.mk_not(m.mk_eq(x, y)), m); - std::cout << e << "\n"; - } - for (auto const& [coeff, lit] : hint.m_literals) { - app_ref e(to_app(m_b2e[lit.var()]), m); - if (lit.sign()) e = m.mk_not(e); - std::cout << e << "\n"; - } - achecker.display(std::cout); - std::cout << "p hint not verified\n"; - return false; - } - - std::cout << "p hint verified\n"; - return true; - break; - } - default: - UNREACHABLE(); - break; - } - return false; - } - -#endif