diff --git a/scripts/mk_project.py b/scripts/mk_project.py index 0f5dc26ae..99051fe93 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -58,7 +58,7 @@ def init_project_def(): add_lib('proto_model', ['model', 'rewriter', 'smt_params'], 'smt/proto_model') add_lib('smt', ['bit_blaster', 'macros', 'normal_forms', 'cmd_context', 'proto_model', 'solver_assertions', 'substitution', 'grobner', 'simplex', 'proofs', 'pattern', 'parser_util', 'fpa', 'lp']) - add_lib('polysat', ['util', 'dd'], 'sat/smt/polysat'), + add_lib('polysat', ['util', 'dd', 'sat'], 'sat/smt/polysat'), add_lib('sat_smt', ['sat', 'euf', 'smt', 'tactic', 'solver', 'smt_params', 'bit_blaster', 'fpa', 'mbp', 'polysat', 'normal_forms', 'lp', 'pattern', 'qe_lite'], 'sat/smt') add_lib('sat_tactic', ['tactic', 'sat', 'solver', 'sat_smt'], 'sat/tactic') add_lib('nlsat_tactic', ['nlsat', 'sat_tactic', 'arith_tactics'], 'nlsat/tactic') diff --git a/src/sat/smt/polysat/polysat_core.cpp b/src/sat/smt/polysat/polysat_core.cpp index ff35f20ee..be25c9af8 100644 --- a/src/sat/smt/polysat/polysat_core.cpp +++ b/src/sat/smt/polysat/polysat_core.cpp @@ -107,6 +107,7 @@ namespace polysat { m_justification.push_back(null_dependency); m_watch.push_back({}); m_var_queue.mk_var_eh(v); + m_viable.ensure_var(v); s.trail().push(mk_add_var(*this)); return v; } @@ -147,8 +148,8 @@ namespace polysat { s.trail().push(mk_dqueue_var(m_var, *this)); switch (m_viable.find_viable(m_var, m_value)) { case find_t::empty: - m_unsat_core = m_viable.explain(); - propagate_unsat_core(); + s.set_lemma(m_viable.get_core(), 0, m_viable.explain()); + // propagate_unsat_core(); return sat::check_result::CR_CONTINUE; case find_t::singleton: s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); @@ -294,7 +295,7 @@ namespace polysat { // default is to use unsat core: // if core is based on viable, use s.set_lemma(); - s.set_conflict(m_unsat_core); + s.set_conflict(m_unsat_core); } void core::assign_eh(unsigned index, bool sign, dependency const& dep) { diff --git a/src/sat/smt/polysat/polysat_core.h b/src/sat/smt/polysat/polysat_core.h index b4f35e776..144b1256b 100644 --- a/src/sat/smt/polysat/polysat_core.h +++ b/src/sat/smt/polysat/polysat_core.h @@ -17,9 +17,10 @@ Author: --*/ #pragma once +#include "util/var_queue.h" #include "util/dependency.h" #include "math/dd/dd_pdd.h" -#include "sat/smt/sat_th.h" +#include "sat/sat_extension.h" #include "sat/smt/polysat/polysat_types.h" #include "sat/smt/polysat/polysat_constraints.h" #include "sat/smt/polysat/polysat_viable.h" diff --git a/src/sat/smt/polysat/polysat_types.h b/src/sat/smt/polysat/polysat_types.h index 92896a7f7..1df9912bc 100644 --- a/src/sat/smt/polysat/polysat_types.h +++ b/src/sat/smt/polysat/polysat_types.h @@ -26,7 +26,7 @@ namespace polysat { using pvar_vector = unsigned_vector; inline const pvar null_var = UINT_MAX; - + class signed_constraint; class dependency { std::variant> m_data; @@ -87,7 +87,9 @@ namespace polysat { using dependency_vector = vector; - class signed_constraint; + using core_vector = vector>; + + // // The interface that PolySAT uses to the SAT/SMT solver. @@ -97,7 +99,7 @@ namespace polysat { public: virtual void add_eq_literal(pvar v, rational const& val) = 0; virtual void set_conflict(dependency_vector const& core) = 0; - virtual void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core) = 0; + virtual void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) = 0; virtual dependency propagate(signed_constraint sc, dependency_vector const& deps) = 0; virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; virtual trail_stack& trail() = 0; diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp index 9af70d716..4152956de 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -79,7 +79,6 @@ namespace polysat { find_t viable::find_viable(pvar v, rational& lo) { rational hi; - ensure_var(v); switch (find_viable(v, lo, hi)) { case l_true: return (lo == hi) ? find_t::singleton : find_t::multiple; @@ -106,9 +105,6 @@ namespace polysat { // max size should always be present, regardless of whether we have intervals there (to make sure all fixed bits are considered) widths_set.insert(c.size(v)); - for (pvar v : overlaps) - ensure_var(v); - for (pvar v : overlaps) for (layer const& l : m_units[v].get_layers()) widths_set.insert(l.bit_width); @@ -121,6 +117,7 @@ namespace polysat { rational const& max_value = c.var2pdd(v).max_value(); + m_explain.reset(); lbool result_lo = find_on_layers(v, widths, overlaps, fbi, rational::zero(), max_value, lo); if (result_lo != l_true) return result_lo; @@ -129,18 +126,13 @@ namespace polysat { hi = lo; return l_true; } - + lbool result_hi = find_on_layers(v, widths, overlaps, fbi, lo + 1, max_value, hi); - switch (result_hi) { - case l_false: - hi = lo; - return l_true; - case l_undef: - return l_undef; - default: - return l_true; - } + if (result_hi != l_false) + return result_hi; + hi = lo; + return l_true; } // l_true ... found viable value @@ -153,18 +145,17 @@ namespace polysat { fixed_bits_info const& fbi, rational const& to_cover_lo, rational const& to_cover_hi, - rational& val - ) { - ptr_vector refine_todo; - ptr_vector relevant_entries; + rational& val) { + ptr_vector refine_todo; // max number of interval refinements before falling back to the univariate solver unsigned const refinement_budget = 100; unsigned refinements = refinement_budget; + unsigned explain_size = m_explain.size(); while (refinements--) { - relevant_entries.clear(); - lbool result = find_on_layer(v, widths.size() - 1, widths, overlaps, fbi, to_cover_lo, to_cover_hi, val, refine_todo, relevant_entries); + m_explain.shrink(explain_size); + lbool result = find_on_layer(v, widths.size() - 1, widths, overlaps, fbi, to_cover_lo, to_cover_hi, val, refine_todo); // store bit-intervals we have used for (entry* e : refine_todo) @@ -191,8 +182,6 @@ namespace polysat { if (!refined) return l_true; } - - LOG("Refinement budget exhausted! Fall back to univariate solver."); return l_undef; } @@ -211,12 +200,10 @@ namespace polysat { rational const& to_cover_lo, rational const& to_cover_hi, rational& val, - ptr_vector& refine_todo, - ptr_vector& relevant_entries - ) { + ptr_vector& refine_todo) { unsigned const w = widths[w_idx]; rational const& mod_value = rational::power_of_two(w); - unsigned const first_relevant_for_conflict = relevant_entries.size(); + unsigned const first_relevant_for_conflict = m_explain.size(); LOG("layer " << w << " bits, to_cover: [" << to_cover_lo << "; " << to_cover_hi << "["); SASSERT(0 <= to_cover_lo); @@ -295,12 +282,12 @@ namespace polysat { if (!e) break; - relevant_entries.push_back(e); + m_explain.push_back(e); if (e->interval.is_full()) { - relevant_entries.clear(); - relevant_entries.push_back(e); // full interval e -> all other intervals are subsumed/irrelevant - set_conflict_by_interval(v, w, relevant_entries, 0); + m_explain.reset(); + m_explain.push_back(e); // full interval e -> all other intervals are subsumed/irrelevant + set_conflict_by_interval(v, w, m_explain, 0); return l_false; } @@ -314,7 +301,7 @@ namespace polysat { if (progress >= mod_value) { // covered the whole domain => conflict - set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict); + set_conflict_by_interval(v, w, m_explain, first_relevant_for_conflict); return l_false; } if (progress >= to_cover_len) { @@ -365,7 +352,7 @@ namespace polysat { lower_cover_lo = 0; lower_cover_hi = lower_mod_value; rational a; - lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries); + lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo); VERIFY(result != l_undef); if (result == l_false) { SASSERT(c.inconsistent()); @@ -387,7 +374,7 @@ namespace polysat { lower_cover_hi = mod(next_val, lower_mod_value); rational a; - lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries); + lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo); if (result == l_false) { SASSERT(c.inconsistent()); return l_false; // conflict @@ -408,7 +395,7 @@ namespace polysat { if (progress >= mod_value) { // covered the whole domain => conflict - set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict); + set_conflict_by_interval(v, w, m_explain, first_relevant_for_conflict); return l_false; } @@ -421,6 +408,197 @@ namespace polysat { return l_undef; } + void viable::set_conflict_by_interval(pvar v, unsigned w, ptr_vector& intervals, unsigned first_interval) { + SASSERT(!intervals.empty()); + SASSERT(first_interval < intervals.size()); + +#if 0 + bool create_lemma = true; + uint_set vars_to_explain; + char const* lemma_name = nullptr; + + // if there is a full interval, it should be the only one in the given vector + if (intervals.size() == 1 && intervals[0]->interval.is_full()) { + lemma_name = "viable (full interval)"; + entry const* e = intervals[0]; + for (auto sc : e->side_cond) + lemma.insert_eval(~sc); + for (const auto& src : e->src) { + lemma.insert(~src); + core.insert(src); + core.insert_vars(src); + } + if (v != e->var) + vars_to_explain.insert(e->var); + } + else { + SASSERT(all_of(intervals, [](entry* e) { return e->interval.is_proper(); })); + lemma_name = "viable (proper intervals)"; + + // allocate one dummy space in intervals storage to simplify recursive calls + intervals.push_back(nullptr); + entry** intervals_begin = intervals.data() + first_interval; + unsigned num_intervals = intervals.size() - first_interval - 1; + + if (!set_conflict_by_interval_rec(v, w, intervals_begin, num_intervals, core, create_lemma, lemma, vars_to_explain)) + return false; + } + + for (pvar x : vars_to_explain) { + s.m_slicing.explain_simple_overlap(v, x, [this, &core, &lemma](sat::literal l) { + lemma.insert(~l); + core.insert(s.lit2cnstr(l)); + }); + } + + if (create_lemma) + core.add_lemma(lemma_name, lemma.build()); + + //core.logger().log(inf_fi(*this, v)); + core.init_viable_end(v); + return true; +#endif + } + + bool viable::set_conflict_by_interval_rec(pvar v, unsigned w, entry** intervals, unsigned num_intervals, bool& create_lemma, uint_set& vars_to_explain) { + SASSERT(std::all_of(intervals, intervals + num_intervals, [w](entry const* e) { return e->bit_width <= w; })); + // TODO: assert invariants on intervals list + rational const& mod_value = rational::power_of_two(w); + + // heuristic: find longest interval as starting point + unsigned longest_idx = UINT_MAX; + rational longest_len; + for (unsigned i = 0; i < num_intervals; ++i) { + entry* e = intervals[i]; + if (e->bit_width != w) + continue; + rational len = e->interval.current_len(); + if (len > longest_len) { + longest_idx = i; + longest_len = len; + } + } + SASSERT(longest_idx < UINT_MAX); + entry* longest = intervals[longest_idx]; + + if (!longest->valid_for_lemma) + create_lemma = false; + + unsigned i = longest_idx; + entry* e = longest; // e is the current baseline + + entry* tmp = nullptr; + on_scope_exit dont_leak_entry = [this, &tmp]() { + if (tmp) + m_alloc.push_back(tmp); + }; + +#if 0 + while (!longest->interval.currently_contains(e->interval.hi_val())) { + unsigned j = (i + 1) % num_intervals; + entry* n = intervals[j]; + + if (n->bit_width != w) { + // we have a hole starting with 'n', to be filled with intervals at lower bit-width. + // note that the next lower bit-width is not necessarily n->bit_width (e.g., the next layer may start with a hole itself) + unsigned w2 = n->bit_width; + // first, find the next constraint after the hole + unsigned last_idx = j; + unsigned k = j; + while (intervals[k]->bit_width != w) { + if (intervals[k]->bit_width > w2) + w2 = intervals[k]->bit_width; + last_idx = k; + k = (k + 1) % num_intervals; + } + entry* after = intervals[k]; + SASSERT(j < last_idx); // the hole cannot wrap around (but k may be 0) + + rational const& lower_mod_value = rational::power_of_two(w2); + SASSERT(distance(e->interval.hi_val(), after->interval.lo_val(), mod_value) < lower_mod_value); // otherwise we would have started the conflict at w2 already + + // create temporary entry that covers the hole-complement on the lower level + if (!tmp) + tmp = alloc_entry(v); + pdd dummy = s.var2pdd(v).mk_val(123); // we could create extract-terms for boundaries but let's skip that for now and just disable the lemma + create_lemma = false; + tmp->valid_for_lemma = false; + tmp->bit_width = w2; + tmp->interval = eval_interval::proper(dummy, mod(after->interval.lo_val(), lower_mod_value), dummy, mod(e->interval.hi_val(), lower_mod_value)); + + // the index "last_idx+1" is always valid because we allocate an extra dummy space at the end before starting the recursion. + // we only need a single dummy space because the temporary entry is always at bit-width w2. + entry* old = intervals[last_idx + 1]; + intervals[last_idx + 1] = tmp; + + bool result = set_conflict_by_interval_rec(v, w2, &intervals[j], last_idx - j + 2, create_lemma, vars_to_explain); + + intervals[last_idx + 1] = old; + + if (!result) + return false; + + if (create_lemma) { + // hole_length < 2^w2 + signed_constraint c = s.ult(after->interval.lo() - e->interval.hi(), lower_mod_value); + VERIFY(c.is_currently_true(s)); + // this constraint may already exist on the stack with opposite bool value, + // in that case we have a different, earlier conflict + if (c.bvalue(s) == l_false) { + core.reset(); + core.init(~c); + return false; + } + lemma.insert(~c); + } + + tmp->bit_width = w; + tmp->interval = eval_interval::proper(dummy, e->interval.hi_val(), dummy, after->interval.lo_val()); + e = tmp; + j = k; + n = after; + } + + // We may have a bunch of intervals that contain the current value. + // Choose the one making the most progress. + // TODO: it should be the last one, otherwise we wouldn't have added it to relevant_intervals? then we can skip the progress computation. + // (TODO: should note the relevant invariants of intervals list and assert them above.) + SASSERT(n->interval.currently_contains(e->interval.hi_val())); + unsigned best_j = j; + rational best_progress = distance(e->interval.hi_val(), n->interval.hi_val(), mod_value); + while (true) { + unsigned j1 = (j + 1) % num_intervals; + entry* n1 = intervals[j1]; + if (n1->bit_width != w) + break; + if (!n1->interval.currently_contains(e->interval.hi_val())) + break; + j = j1; + n = n1; + SASSERT(n != longest); // because of loop condition on outer while loop + rational progress = distance(e->interval.hi_val(), n->interval.hi_val(), mod_value); + if (progress > best_progress) { + best_j = j; + best_progress = progress; + } + } + j = best_j; + n = intervals[best_j]; + + if (!update_interval_conflict(v, e->interval.hi(), n, core, create_lemma, lemma, vars_to_explain)) + return false; + + i = j; + e = n; + } + + if (!update_interval_conflict(v, e->interval.hi(), longest, core, create_lemma, lemma, vars_to_explain)) + return false; +#endif + + return true; + } + // returns true iff no conflict was encountered bool viable::collect_bit_information(pvar v, bool add_conflict, fixed_bits_info& out_fbi) { @@ -429,7 +607,7 @@ namespace polysat { out_fbi.reset(v_sz); auto& [fixed, just_src, just_side_cond, just_slice] = out_fbi; - svector fbs; + svector fbs; c.get_fixed_bits(v, fbs); for (auto const& fb : fbs) { @@ -625,15 +803,23 @@ namespace polysat { /* * Explain why the current variable is not viable or signleton. */ - dependency_vector viable::explain() { throw default_exception("nyi"); } + dependency_vector viable::explain() { + dependency_vector result; + for (auto e : m_explain) { + auto index = e->constraint_index; + auto const& [sc, d, value] = c.m_constraint_index[index]; + result.push_back(d); + result.append(c.explain_eval(sc)); + } + // TODO: explaination for fixed bits + return result; + } /* * Register constraint at index 'idx' as unitary in v. */ void viable::add_unitary(pvar v, unsigned idx) { - ensure_var(v); - if (c.is_assigned(v)) return; auto [sc, d, value] = c.m_constraint_index[idx]; diff --git a/src/sat/smt/polysat/polysat_viable.h b/src/sat/smt/polysat/polysat_viable.h index 1812ec40c..f426dc326 100644 --- a/src/sat/smt/polysat/polysat_viable.h +++ b/src/sat/smt/polysat/polysat_viable.h @@ -138,6 +138,9 @@ namespace polysat { vector m_units; // set of viable values based on unit multipliers, layered by bit-width in descending order ptr_vector m_equal_lin; // entries that have non-unit multipliers, but are equal ptr_vector m_diseq_lin; // entries that have distinct non-zero multipliers + ptr_vector m_explain; // entries that explain the current propagation or conflict + core_vector m_core; // forbidden interval core + bool m_has_core = false; bool well_formed(entry* e); bool well_formed(layers const& ls); @@ -158,8 +161,6 @@ namespace polysat { bool intersect(pvar v, entry* e); - void ensure_var(pvar v); - lbool find_viable(pvar v, rational& lo, rational& hi); lbool find_on_layers( @@ -180,8 +181,7 @@ namespace polysat { rational const& to_cover_lo, rational const& to_cover_hi, rational& out_val, - ptr_vector& refine_todo, - ptr_vector& relevant_entries); + ptr_vector& refine_todo); template @@ -211,9 +211,8 @@ namespace polysat { throw default_exception("nyi"); } - bool set_conflict_by_interval(pvar v, unsigned w, ptr_vector& intervals, unsigned first_interval) { - throw default_exception("nyi"); - } + void set_conflict_by_interval(pvar v, unsigned w, ptr_vector& intervals, unsigned first_interval); + bool set_conflict_by_interval_rec(pvar v, unsigned w, entry** intervals, unsigned num_intervals, bool& create_lemma, uint_set& vars_to_explain); std::pair find_value(rational const& val, entry* entries) { throw default_exception("nyi"); @@ -236,11 +235,26 @@ namespace polysat { */ dependency_vector explain(); + /* + * flag whether there is a forbidden interval core + */ + bool has_core() const { return m_has_core; } + + /* + * Retrieve lemma corresponding to forbidden interval constraints + */ + core_vector const& get_core() { SASSERT(m_has_core); return m_core; } + /* * Register constraint at index 'idx' as unitary in v. */ void add_unitary(pvar v, unsigned idx); + /* + * Ensure data-structures tracking variable v. + */ + void ensure_var(pvar v); + }; } diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 7e867af2f..460501cd0 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -104,15 +104,26 @@ namespace polysat { return { core, eqs }; } - void solver::set_lemma(vector const& lemma, unsigned level, dependency_vector const& core) { + void solver::set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) { auto [lits, eqs] = explain_deps(core); auto ex = euf::th_explain::conflict(*this, lits, eqs, nullptr); ctx.push(value_trail(m_has_lemma)); m_has_lemma = true; m_lemma_level = level; m_lemma.reset(); - for (auto sc : lemma) - m_lemma.push_back(constraint2expr(sc)); + for (auto sc : aux_core) { + if (std::holds_alternative(sc)) { + auto d = *std::get_if(&sc); + if (d.is_literal()) + m_lemma.push_back(ctx.literal2expr(d.literal())); + else { + auto [v1, v2] = d.eq(); + m_lemma.push_back(ctx.mk_eq(var2enode(v1), var2enode(v2))); + } + } + else if (std::holds_alternative(sc)) + m_lemma.push_back(constraint2expr(*std::get_if(&sc))); + } ctx.set_conflict(ex); } @@ -129,7 +140,7 @@ namespace polysat { sat::literal_vector lits; for (auto* e : m_lemma) - lits.push_back(ctx.mk_literal(e)); + lits.push_back(~ctx.mk_literal(e)); s().add_clause(lits.size(), lits.data(), sat::status::th(true, get_id(), nullptr)); return l_false; } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 6c7260f95..b5e69c36a 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -128,7 +128,7 @@ namespace polysat { // callbacks from core void add_eq_literal(pvar v, rational const& val) override; void set_conflict(dependency_vector const& core) override; - void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core) override; + void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) override; dependency propagate(signed_constraint sc, dependency_vector const& deps) override; void propagate(dependency const& d, bool sign, dependency_vector const& deps) override; trail_stack& trail() override;