From 0d80e47350e61abdc9d91b1df117d7006ceece11 Mon Sep 17 00:00:00 2001 From: Jakob Rath Date: Tue, 18 Jul 2023 11:22:02 +0200 Subject: [PATCH] update deps handling (need to support pvars as well) --- src/math/polysat/slicing.cpp | 91 +++++++++++++++++++++++++++++------- src/math/polysat/slicing.h | 44 +++++++++++++---- src/test/slicing.cpp | 21 +++++---- 3 files changed, 122 insertions(+), 34 deletions(-) diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index 8e16d7175..2b72547ea 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -74,9 +74,49 @@ Recycle the z3 egraph? #include "math/polysat/log.h" #include "util/tptr.h" +namespace { + + template + [[maybe_unused]] + inline constexpr bool always_false_v = false; + +} namespace polysat { + unsigned slicing::dep_t::to_uint() const { + return std::visit([](auto arg) -> unsigned { + using T = std::decay_t; + if constexpr (std::is_same_v) + return UINT_MAX; + else if constexpr (std::is_same_v) + return (arg.to_uint() << 1); + else if constexpr (std::is_same_v) + return (arg << 1) + 1; + else + static_assert(always_false_v, "non-exhaustive visitor!"); + }, m_data); + } + + slicing::dep_t slicing::dep_t::from_uint(unsigned x) { + if (x == UINT_MAX) + return dep_t(); + else if ((x & 1) == 0) + return dep_t(sat::to_literal(x >> 1)); + else + return dep_t(static_cast(x >> 1)); + } + + std::ostream& slicing::dep_t::display(std::ostream& out) { + if (is_null()) + out << "null"; + else if (is_var()) + out << "v" << var(); + else if (is_lit()) + out << "lit(" << lit() << ")"; + return out; + } + void* slicing::encode_dep(dep_t d) { void* p = box(d.to_uint()); SASSERT_EQ(d, decode_dep(p)); @@ -84,7 +124,7 @@ namespace polysat { } slicing::dep_t slicing::decode_dep(void* p) { - return sat::to_literal(unbox(p)); + return dep_t::from_uint(unbox(p)); } void slicing::display_dep(std::ostream& out, void* d) { @@ -290,36 +330,48 @@ namespace polysat { } void slicing::begin_explain() { - SASSERT(m_marked_deps.empty()); + SASSERT(m_marked_lits.empty()); + SASSERT(m_marked_vars.empty()); } void slicing::end_explain() { - m_marked_deps.reset(); + m_marked_lits.reset(); + m_marked_vars.reset(); } - void slicing::push_dep(void* dp, dep_vector& out_deps) { + void slicing::push_dep(void* dp, sat::literal_vector& out_lits, unsigned_vector& out_vars) { dep_t d = decode_dep(dp); - if (d == sat::null_literal) - return; - if (m_marked_deps.contains(d)) - return; - m_marked_deps.insert(d); - out_deps.push_back(d); + if (d.is_var()) { + pvar v = d.var(); + if (m_marked_vars.contains(v)) + return; + m_marked_vars.insert(v); + out_vars.push_back(v); + } + else if (d.is_lit()) { + sat::literal lit = d.lit(); + if (m_marked_lits.contains(lit)) + return; + m_marked_lits.insert(lit); + out_lits.push_back(lit); + } + else { + SASSERT(d.is_null()); + } } - void slicing::explain_class(enode* x, enode* y, dep_vector& out_deps) { + void slicing::explain_class(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars) { SASSERT_EQ(x->get_root(), y->get_root()); SASSERT(m_tmp_justifications.empty()); m_egraph.begin_explain(); m_egraph.explain_eq(m_tmp_justifications, nullptr, x, y); m_egraph.end_explain(); for (void* dp : m_tmp_justifications) - push_dep(dp, out_deps); + push_dep(dp, out_lits, out_vars); m_tmp_justifications.reset(); } - void slicing::explain_equal(enode* x, enode* y, dep_vector& out_deps) { - begin_explain(); + void slicing::explain_equal(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars) { SASSERT(is_equal(x, y)); enode_vector& xs = m_tmp2; enode_vector& ys = m_tmp3; @@ -337,7 +389,7 @@ namespace polysat { enode* const rx = x->get_root(); enode* const ry = y->get_root(); if (rx == ry) - explain_class(x, y, out_deps); + explain_class(x, y, out_lits, out_vars); else { xs.push_back(sub_hi(rx)); xs.push_back(sub_lo(rx)); @@ -360,6 +412,11 @@ namespace polysat { } } SASSERT(ys.empty()); + } + + void slicing::explain_equal(pvar x, pvar y, sat::literal_vector& out_lits, unsigned_vector& out_vars) { + begin_explain(); + explain_equal(var2slice(x), var2slice(y), out_lits, out_vars); end_explain(); } @@ -548,7 +605,7 @@ namespace polysat { } pvar v = m_solver.add_var(hi - lo + 1); // TODO: can we use 'compressed' slice trees again if we store the source slice here as dependency? - VERIFY(merge(slices, var2slice(v), null_dep)); + VERIFY(merge(slices, var2slice(v), dep_t())); return v; } @@ -598,7 +655,7 @@ namespace polysat { enode_vector tmp; tmp.push_back(pdd2slice(p)); tmp.push_back(pdd2slice(q)); - VERIFY(merge(tmp, var2slice(v), null_dep)); + VERIFY(merge(tmp, var2slice(v), dep_t())); return m_solver.var(v); } diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index ff4880026..7147f6d34 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -28,7 +28,7 @@ Notation: #include "ast/bv_decl_plugin.h" #include "math/polysat/types.h" #include "math/polysat/constraint.h" -#include +#include namespace polysat { @@ -38,9 +38,25 @@ namespace polysat { friend class test_slicing; - using dep_t = sat::literal; - using dep_vector = sat::literal_vector; - static constexpr sat::literal null_dep = sat::null_literal; + class dep_t { + std::variant m_data; + public: + dep_t() { SASSERT(is_null()); } + dep_t(sat::literal l): m_data(l) { SASSERT(l != sat::null_literal); SASSERT_EQ(l, lit()); } + dep_t(pvar v): m_data(v) { SASSERT(v != null_var); SASSERT_EQ(v, var()); } + bool is_null() const { return std::holds_alternative(m_data); } + bool is_lit() const { return std::holds_alternative(m_data); } + bool is_var() const { return std::holds_alternative(m_data); } + sat::literal lit() const { SASSERT(is_lit()); return *std::get_if(&m_data); } + pvar var() const { SASSERT(is_var()); return *std::get_if(&m_data); } + bool operator==(dep_t other) const { return m_data == other.m_data; } + bool operator!=(dep_t other) const { return !operator==(other); } + std::ostream& display(std::ostream& out); + unsigned to_uint() const; + static dep_t from_uint(unsigned x); + }; + + friend std::ostream& operator<<(std::ostream&, slicing::dep_t); using enode = euf::enode; using enode_vector = euf::enode_vector; @@ -140,14 +156,14 @@ namespace polysat { void begin_explain(); void end_explain(); - void push_dep(void* dp, dep_vector& out_deps); + void push_dep(void* dp, sat::literal_vector& out_lits, unsigned_vector& out_vars); // Extract reason why slices x and y are in the same equivalence class - void explain_class(enode* x, enode* y, dep_vector& out_deps); + void explain_class(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars); // Extract reason why slices x and y are equal // (i.e., x and y have the same base, but are not necessarily in the same equivalence class) - void explain_equal(enode* x, enode* y, dep_vector& out_deps); + void explain_equal(enode* x, enode* y, sat::literal_vector& out_lits, unsigned_vector& out_vars); // Merge equivalence classes of two base slices. // Returns true if merge succeeded without conflict. @@ -184,7 +200,8 @@ namespace polysat { mutable enode_vector m_tmp2; mutable enode_vector m_tmp3; ptr_vector m_tmp_justifications; - sat::literal_set m_marked_deps; + sat::literal_set m_marked_lits; + uint_set m_marked_vars; // get a slice that is equivalent to the given pdd (may introduce new variable) enode* pdd2slice(pdd const& p); @@ -222,6 +239,16 @@ namespace polysat { void add_value(pvar v, rational const& value); void add_constraint(signed_constraint c); + // update congruences, egraph + void propagate(); + + bool is_conflict() const { return m_egraph.inconsistent(); } + + /** Extract reason for conflict */ + void explain(sat::literal_vector& out_lits, unsigned_vector& out_vars); + /** Extract reason for x == y */ + void explain_equal(pvar x, pvar y, sat::literal_vector& out_lits, unsigned_vector& out_vars); + // TODO: // Query for a given variable v: // - set of variables that share at least one slice with v (need variable, offset/width relative to v) @@ -232,4 +259,5 @@ namespace polysat { inline std::ostream& operator<<(std::ostream& out, slicing const& s) { return s.display(out); } + inline std::ostream& operator<<(std::ostream& out, slicing::dep_t d) { return d.display(out); } } diff --git a/src/test/slicing.cpp b/src/test/slicing.cpp index b25b5760f..67324ed2e 100644 --- a/src/test/slicing.cpp +++ b/src/test/slicing.cpp @@ -114,17 +114,19 @@ namespace polysat { << " root(v" << b << ") = " << sl.var2slice(b)->get_root_id() << " root(v" << c << ") = " << sl.var2slice(c)->get_root_id() << "\n"; - sat::literal_vector reason; - sl.explain_equal(sl.var2slice(b), sl.var2slice(c), reason); - std::cout << " Reason: " << reason << "\n\n"; + sat::literal_vector reason_lits; + unsigned_vector reason_vars; + sl.explain_equal(sl.var2slice(b), sl.var2slice(c), reason_lits, reason_vars); + std::cout << " Reason: " << reason_lits << " vars " << reason_vars << "\n"; std::cout << "v" << b << " = " << d << "? " << sl.is_equal(sl.var2slice(b), sl.pdd2slice(d)) << " root(v" << b << ") = " << sl.var2slice(b)->get_root_id() << " root(" << d << ") = " << sl.pdd2slice(d)->get_root_id() << "\n"; - reason.reset(); - sl.explain_equal(sl.var2slice(b), sl.pdd2slice(d), reason); - std::cout << " Reason: " << reason << "\n\n"; + reason_lits.reset(); + reason_vars.reset(); + sl.explain_equal(sl.var2slice(b), sl.pdd2slice(d), reason_lits, reason_vars); + std::cout << " Reason: " << reason_lits << " vars " << reason_vars << "\n"; VERIFY(sl.invariant()); } @@ -156,9 +158,10 @@ namespace polysat { << " slice(v" << d << ") = " << sl.var2slice(d)->get_id() << " slice(v" << e << ") = " << sl.var2slice(e)->get_id() << "\n"; - sat::literal_vector reason; - sl.explain_equal(sl.var2slice(d), sl.var2slice(e), reason); - std::cout << " Reason: " << reason << "\n"; + sat::literal_vector reason_lits; + unsigned_vector reason_vars; + sl.explain_equal(sl.var2slice(d), sl.var2slice(e), reason_lits, reason_vars); + std::cout << " Reason: " << reason_lits << " vars " << reason_vars << "\n"; sl.display_tree(std::cout); VERIFY(sl.invariant());