From 32825a26cb84111a7758304cf8dc919f33cbd3f9 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 30 Dec 2023 17:29:36 -0800 Subject: [PATCH] Update hints to carry premises --- src/sat/smt/polysat/core.cpp | 14 ++++-- src/sat/smt/polysat/core.h | 2 + src/sat/smt/polysat/viable.cpp | 12 ++++- src/sat/smt/polysat_solver.cpp | 85 +++++++++++++++++++++++++--------- src/sat/smt/polysat_solver.h | 51 +++++++++++++++----- 5 files changed, 127 insertions(+), 37 deletions(-) diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index 0fb1e1635..c5f1b336f 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -245,7 +245,7 @@ namespace polysat { } // If no saturation propagation was possible, explain the conflict using the variable assignment. - m_unsat_core = explain_eval(get_constraint(conflict_idx)); + m_unsat_core = explain_eval_unfold(get_constraint(conflict_idx)); m_unsat_core.push_back(get_dependency(conflict_idx)); s.set_conflict(m_unsat_core, "polysat-bail-out-conflict"); decay_activity(); @@ -456,9 +456,9 @@ namespace polysat { s.trail().push(unassign(*this, index.id)); } - dependency_vector core::explain_eval(signed_constraint const& sc) { + dependency_vector core::explain_eval(unsigned_vector const& vars) { dependency_vector deps; - for (auto v : sc.vars()) { + for (auto v : vars) { if (is_assigned(v)) { inc_activity(v); deps.push_back(m_justification[v]); @@ -467,6 +467,14 @@ namespace polysat { return deps; } + dependency_vector core::explain_eval(signed_constraint const& sc) { + return explain_eval(sc.vars()); + } + + dependency_vector core::explain_eval_unfold(signed_constraint const& sc) { + return explain_eval(sc.unfold_vars()); + } + lbool core::eval(signed_constraint const& sc) { return sc.eval(m_assignment); } diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index 8f58aa118..7a76e0fd9 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -86,6 +86,7 @@ namespace polysat { void propagate_assignment(pvar v, rational const& value, dependency dep); void propagate_activation(constraint_id idx, signed_constraint& sc, dependency dep); void propagate(constraint_id id, signed_constraint& sc, lbool value, dependency const& d); + dependency_vector explain_eval(unsigned_vector const& vars); void add_watch(unsigned idx, unsigned var); @@ -173,6 +174,7 @@ namespace polysat { lbool eval(signed_constraint const& sc); lbool eval_unfold(signed_constraint const& sc); dependency_vector explain_eval(signed_constraint const& sc); + dependency_vector explain_eval_unfold(signed_constraint const& sc); bool inconsistent() const; /* diff --git a/src/sat/smt/polysat/viable.cpp b/src/sat/smt/polysat/viable.cpp index bc2f11b22..4f51529ce 100644 --- a/src/sat/smt/polysat/viable.cpp +++ b/src/sat/smt/polysat/viable.cpp @@ -107,8 +107,7 @@ namespace polysat { m_num_bits = c.size(v); m_fixed_bits.reset(v); init_overlaps(v); - bool start_at0 = val1 == 0; - + bool start_at0 = val1 == 0; lbool r = next_viable(val1); TRACE("bv", display_state(tout); display(tout << "next viable v" << v << " " << val1 << " " << r << "\n")); @@ -132,6 +131,15 @@ namespace polysat { r = next_viable(val2); + if (r != l_false) + return r; + + if (!start_at0 && val1 == c.var2pdd(v).max_value()) + return l_false; + + val2 = 0; + r = next_viable(val2); + if (r != l_false) return r; val2 = val1; diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 9b7184c98..c32699411 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -119,9 +119,9 @@ namespace polysat { void solver::set_conflict(dependency_vector const& deps, char const* hint_info) { auto [lits, eqs] = explain_deps(deps); - polysat_proof* hint = nullptr; + proof_hint* hint = nullptr; if (ctx.use_drat() && hint_info) - hint = mk_proof_hint(hint_info); + hint = mk_proof_hint(hint_info, lits, eqs); auto ex = euf::th_explain::conflict(*this, lits, eqs, hint); TRACE("bv", ex->display(tout << "conflict: ") << "\n"; s().display(tout)); validate_conflict(lits, eqs); @@ -246,9 +246,12 @@ namespace polysat { if (s().value(lit) == l_true) return dependency(lit.var()); auto [core, eqs] = explain_deps(deps); - polysat_proof* hint = nullptr; - if (ctx.use_drat() && hint_info) - hint = mk_proof_hint(hint_info); + proof_hint* hint = nullptr; + if (ctx.use_drat() && hint_info) { + core.push_back(~lit); + hint = mk_proof_hint(hint_info, core, eqs); + core.pop_back(); + } auto ex = euf::th_explain::propagate(*this, core, eqs, lit, hint); validate_propagate(lit, core, eqs); ctx.propagate(lit, ex); @@ -273,14 +276,18 @@ namespace polysat { TRACE("bv", tout << "propagate " << d << " " << sign << "\n"); auto [core, eqs] = explain_deps(deps); SASSERT(d.is_bool_var() || d.is_eq()); - polysat_proof* hint = nullptr; - if (ctx.use_drat() && hint_info) - hint = mk_proof_hint(hint_info); + proof_hint* hint = nullptr; + if (d.is_bool_var()) { auto bv = d.bool_var(); auto lit = sat::literal(bv, sign); if (s().value(lit) == l_true) return; + if (ctx.use_drat() && hint_info) { + core.push_back(~lit); + hint = mk_proof_hint(hint_info, core, eqs); + core.pop_back(); + } auto ex = euf::th_explain::propagate(*this, core, eqs, lit, hint); validate_propagate(lit, core, eqs); ctx.propagate(lit, ex); @@ -291,6 +298,8 @@ namespace polysat { auto n1 = var2enode(v1); auto n2 = var2enode(v2); eqs.push_back({ n1, n2 }); + if (ctx.use_drat() && hint_info) + hint = mk_proof_hint(hint_info, core, eqs); auto ex = euf::th_explain::conflict(*this, core, eqs, hint); validate_conflict(core, eqs); ctx.set_conflict(ex); @@ -319,26 +328,36 @@ namespace polysat { lits.push_back(~ctx.mk_literal(constraint2expr(*std::get_if(&e)))); } for (auto [n1, n2] : eqs) - ctx.get_eq_antecedents(n1, n2, lits); + ctx.get_eq_antecedents(n1, n2, lits); + proof_hint* hint = nullptr; + if (ctx.use_drat()) + hint = mk_proof_hint(name, lits, {}); for (auto& lit : lits) lit.neg(); for (auto lit : lits) if (s().value(lit) == l_true) return false; validate_axiom(lits); - s().add_clause(lits.size(), lits.data(), sat::status::th(is_redundant, get_id(), mk_proof_hint(name))); + s().add_clause(lits.size(), lits.data(), sat::status::th(is_redundant, get_id(), hint)); return true; } void solver::add_axiom(char const* name, std::initializer_list const& clause) { bool is_redundant = false; sat::literal_vector lits; - for (auto lit : clause) - lits.push_back(lit); - validate_axiom(lits); - polysat_proof* hint = nullptr; - if (ctx.use_drat()) - hint = mk_proof_hint(name); + proof_hint* hint = nullptr; + if (ctx.use_drat()) { + for (auto lit : clause) + lits.push_back(~lit); + hint = mk_proof_hint(name, lits, {}); + for (auto& lit : lits) + lit.neg(); + } + else { + for (auto lit : clause) + lits.push_back(lit); + } + validate_axiom(lits); s().add_clause(lits.size(), lits.data(), sat::status::th(is_redundant, get_id(), hint)); } @@ -400,12 +419,36 @@ namespace polysat { return expr_ref(r, m); } - expr* solver::polysat_proof::get_hint(euf::solver& s) const { - auto& m = s.get_manager(); - return m.mk_app(symbol(name), 0, nullptr, m.mk_proof_sort()); + expr* solver::proof_hint::get_hint(euf::solver& s) const { + ast_manager& m = s.get_manager(); + family_id fid = m.get_family_id("bv"); + solver& p = dynamic_cast(*s.fid2solver(fid)); + expr_ref_vector args(m); + for (unsigned i = m_lit_head; i < m_lit_tail; ++i) + args.push_back(s.literal2expr(p.m_mk_hint.lit(i))); + for (unsigned i = m_eq_head; i < m_eq_tail; ++i) + args.push_back(s.mk_eq(p.m_mk_hint.eq(i).first, p.m_mk_hint.eq(i).second)); + expr* pr = m.mk_app(symbol(name), args.size(), args.data(), m.mk_proof_sort()); + return m.mk_app(symbol("bv"), 1, &pr, m.mk_proof_sort()); } - solver::polysat_proof* solver::mk_proof_hint(char const* name) { - return new (get_region()) polysat_proof(name); + void solver::proof_hint_builder::init(euf::solver& ctx, char const* name) { + ctx.push(value_trail(m_eq_tail)); + ctx.push(value_trail(m_lit_tail)); + m_name = name; + reset(); + } + + solver::proof_hint* solver::proof_hint_builder::mk(euf::solver& ctx) { + return new (ctx.get_region()) proof_hint(m_name, m_lit_head, m_lit_tail, m_eq_head, m_eq_tail); + } + + solver::proof_hint* solver::mk_proof_hint(char const* name, sat::literal_vector const& lits, euf::enode_pair_vector const& eqs) { + m_mk_hint.init(ctx, name); + for (auto lit : lits) + m_mk_hint.add_lit(lit); + for (auto [a,b] : eqs) + m_mk_hint.add_eq(a,b); + return m_mk_hint.mk(ctx); } } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index e8f85e41c..49e51de12 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -29,6 +29,8 @@ namespace euf { namespace polysat { + + class solver : public euf::th_euf_solver, public solver_interface { typedef euf::theory_var theory_var; typedef euf::theory_id theory_id; @@ -37,6 +39,42 @@ namespace polysat { typedef sat::literal_vector literal_vector; using pdd = dd::pdd; + struct proof_hint : public euf::th_proof_hint { + char const* name; + unsigned m_lit_head, m_lit_tail, m_eq_head, m_eq_tail; + proof_hint(char const* name, unsigned lh, unsigned lt, unsigned eh, unsigned et) : + name(name), m_lit_head(lh), m_lit_tail(lt), m_eq_head(eh), m_eq_tail(et) {} + expr* get_hint(euf::solver& s) const override; + }; + + class proof_hint_builder { + sat::literal_vector m_literals; + euf::enode_pair_vector m_eqs; + char const* m_name = nullptr; + unsigned m_lit_head = 0, m_lit_tail = 0, m_eq_head = 0, m_eq_tail = 0; + void reset() { m_lit_head = m_lit_tail; m_eq_head = m_eq_tail; } + void add(euf::enode* a, euf::enode* b) { + if (m_eq_tail < m_eqs.size()) + m_eqs[m_eq_tail] = { a, b }; + else + m_eqs.push_back({ a, b }); + m_eq_tail++; + } + public: + void init(euf::solver& ctx, char const* name); + void add_eq(euf::enode* a, euf::enode* b) { add(a, b); } + void add_lit(sat::literal lit) { + if (m_lit_tail < m_literals.size()) + m_literals[m_lit_tail] = lit; + else + m_literals.push_back(lit); + m_lit_tail++; + } + sat::literal const& lit(unsigned i) const { return m_literals[i]; } + euf::enode_pair const& eq(unsigned i) const { return m_eqs[i]; } + proof_hint* mk(euf::solver& s); + }; + struct stats { void reset() { memset(this, 0, sizeof(stats)); } stats() { reset(); } @@ -49,17 +87,8 @@ namespace polysat { ~atom() { } }; - - class polysat_proof : public euf::th_proof_hint { - // assume name is statically allocated - char const* name; - public: - polysat_proof(char const* name) : name(name) {} - ~polysat_proof() override {} - expr* get_hint(euf::solver& s) const override; - }; - - polysat_proof* mk_proof_hint(char const* name); + proof_hint_builder m_mk_hint; + proof_hint* mk_proof_hint(char const* name, sat::literal_vector const& lits, euf::enode_pair_vector const& eqs); bv_util bv; arith_util m_autil;