From ab1a2e27a7c887df540f713847771ebf23f16595 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 7 Dec 2023 15:53:07 -0800 Subject: [PATCH 01/89] v2 of polysat --- src/sat/sat_solver.cpp | 5 +- src/sat/smt/CMakeLists.txt | 6 + src/sat/smt/euf_solver.cpp | 9 + src/sat/smt/euf_solver.h | 2 + src/sat/smt/polysat_assignment.cpp | 119 ++++++++++ src/sat/smt/polysat_assignment.h | 120 ++++++++++ src/sat/smt/polysat_constraints.cpp | 25 ++ src/sat/smt/polysat_constraints.h | 128 +++++++++++ src/sat/smt/polysat_core.cpp | 276 ++++++++++++++++++++++ src/sat/smt/polysat_core.h | 128 +++++++++++ src/sat/smt/polysat_internalize.cpp | 343 ++++++++++++++++++++++++++++ src/sat/smt/polysat_model.cpp | 58 +++++ src/sat/smt/polysat_solver.cpp | 191 ++++++++++++++++ src/sat/smt/polysat_solver.h | 187 +++++++++++++++ src/sat/smt/polysat_substitution.h | 212 +++++++++++++++++ src/sat/smt/polysat_types.h | 45 ++++ src/sat/smt/polysat_viable.h | 55 +++++ src/util/var_queue.h | 2 + 18 files changed, 1908 insertions(+), 3 deletions(-) create mode 100644 src/sat/smt/polysat_assignment.cpp create mode 100644 src/sat/smt/polysat_assignment.h create mode 100644 src/sat/smt/polysat_constraints.cpp create mode 100644 src/sat/smt/polysat_constraints.h create mode 100644 src/sat/smt/polysat_core.cpp create mode 100644 src/sat/smt/polysat_core.h create mode 100644 src/sat/smt/polysat_internalize.cpp create mode 100644 src/sat/smt/polysat_model.cpp create mode 100644 src/sat/smt/polysat_solver.cpp create mode 100644 src/sat/smt/polysat_solver.h create mode 100644 src/sat/smt/polysat_substitution.h create mode 100644 src/sat/smt/polysat_types.h create mode 100644 src/sat/smt/polysat_viable.h diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 96b3c13c4..716a8effe 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -2429,9 +2429,8 @@ namespace sat { m_conflicts_since_restart++; m_conflicts_since_gc++; m_stats.m_conflict++; - if (m_step_size > m_config.m_step_size_min) { - m_step_size -= m_config.m_step_size_dec; - } + if (m_step_size > m_config.m_step_size_min) + m_step_size -= m_config.m_step_size_dec; bool unique_max; m_conflict_lvl = get_max_lvl(m_not_l, m_conflict, unique_max); diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 7caccded6..2a6fb9e66 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -33,6 +33,12 @@ z3_add_component(sat_smt pb_internalize.cpp pb_pb.cpp pb_solver.cpp + polysat_assignment.cpp + polysat_constraints.cpp + polysat_core.cpp + polysat_internalize.cpp + polysat_model.cpp + polysat_solver.cpp q_clause.cpp q_ematch.cpp q_eval.cpp diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 3ae4425fc..51c0518e5 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -209,6 +209,15 @@ namespace euf { s().assign(lit, sat::justification::mk_ext_justification(s().scope_lvl(), idx)); } + lbool solver::resolve_conflict() { + for (auto* s : m_solvers) { + lbool r = s->resolve_conflict(); + if (r != l_undef) + return r; + } + return l_undef; + } + /** Retrieve set of literals r that imply r. Since the set of literals are retrieved modulo multiple theories in a single implication diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index db99ec512..9cac6e02a 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -363,6 +363,7 @@ namespace euf { bool propagate(enode* a, enode* b, th_explain* p) { return propagate(a, b, p->to_index()); } size_t* to_justification(sat::literal l) { return to_ptr(l); } void set_conflict(th_explain* p) { set_conflict(p->to_index()); } + bool inconsistent() const { return s().inconsistent() || m_egraph.inconsistent(); } bool set_root(literal l, literal r) override; void flush_roots() override; @@ -378,6 +379,7 @@ namespace euf { bool get_case_split(bool_var& var, lbool& phase) override; void asserted(literal l) override; sat::check_result check() override; + lbool resolve_conflict() override; void push() override; void pop(unsigned n) override; void user_push() override; diff --git a/src/sat/smt/polysat_assignment.cpp b/src/sat/smt/polysat_assignment.cpp new file mode 100644 index 000000000..a985188fa --- /dev/null +++ b/src/sat/smt/polysat_assignment.cpp @@ -0,0 +1,119 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat substitution and assignment + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ + +#include "sat/smt/polysat_assignment.h" +#include "sat/smt/polysat_core.h" + +namespace polysat { + + substitution::substitution(pdd p) + : m_subst(std::move(p)) { } + + substitution::substitution(dd::pdd_manager& m) + : m_subst(m.one()) { } + + substitution substitution::add(pvar var, rational const& value) const { + return {m_subst.subst_add(var, value)}; + } + + pdd substitution::apply_to(pdd const& p) const { + return p.subst_val(m_subst); + } + + bool substitution::contains(pvar var) const { + rational out_value; + return value(var, out_value); + } + + bool substitution::value(pvar var, rational& out_value) const { + return m_subst.subst_get(var, out_value); + } + + assignment::assignment(core& s) + : m_core(s) { } + + + assignment assignment::clone() const { + assignment a(s()); + a.m_pairs = m_pairs; + a.m_subst.reserve(m_subst.size()); + for (unsigned i = m_subst.size(); i-- > 0; ) + if (m_subst[i]) + a.m_subst.set(i, alloc(substitution, *m_subst[i])); + a.m_subst_trail = m_subst_trail; + return a; + } + + bool assignment::contains(pvar var) const { + return subst(s().size(var)).contains(var); + } + + bool assignment::value(pvar var, rational& out_value) const { + return subst(s().size(var)).value(var, out_value); + } + + substitution& assignment::subst(unsigned sz) { + return const_cast(std::as_const(*this).subst(sz)); + } + + substitution const& assignment::subst(unsigned sz) const { + m_subst.reserve(sz + 1); + if (!m_subst[sz]) + m_subst.set(sz, alloc(substitution, s().sz2pdd(sz))); + return *m_subst[sz]; + } + + void assignment::push(pvar var, rational const& value) { + SASSERT(all_of(m_pairs, [var](assignment_item_t const& item) { return item.first != var; })); + m_pairs.push_back({var, value}); + unsigned const sz = s().size(var); + substitution& sub = subst(sz); + m_subst_trail.push_back(sub); + sub = sub.add(var, value); + SASSERT_EQ(sub, *m_subst[sz]); + } + + void assignment::pop() { + substitution& sub = m_subst_trail.back(); + unsigned sz = sub.bit_width(); + SASSERT_EQ(sz, s().size(m_pairs.back().first)); + *m_subst[sz] = sub; + m_subst_trail.pop_back(); + m_pairs.pop_back(); + } + + pdd assignment::apply_to(pdd const& p) const { + unsigned const sz = p.power_of_2(); + return subst(sz).apply_to(p); + } + + std::ostream& substitution::display(std::ostream& out) const { + char const* delim = ""; + pdd p = m_subst; + while (!p.is_val()) { + SASSERT(p.lo().is_val()); + out << delim << "v" << p.var() << " := " << p.lo(); + delim = " "; + p = p.hi(); + } + return out; + } + + std::ostream& assignment::display(std::ostream& out) const { + char const* delim = ""; + for (auto const& [var, value] : m_pairs) + out << delim << var << " == " << value, delim = " "; + return out; + } +} diff --git a/src/sat/smt/polysat_assignment.h b/src/sat/smt/polysat_assignment.h new file mode 100644 index 000000000..daff03dd5 --- /dev/null +++ b/src/sat/smt/polysat_assignment.h @@ -0,0 +1,120 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat substitution and assignment + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once +#include "util/scoped_ptr_vector.h" +#include "sat/smt/polysat_types.h" + +namespace polysat { + + class core; + + using assignment_item_t = std::pair; + + class substitution_iterator { + pdd m_current; + substitution_iterator(pdd current) : m_current(std::move(current)) {} + friend class substitution; + + public: + using value_type = assignment_item_t; + using difference_type = std::ptrdiff_t; + using pointer = value_type const*; + using reference = value_type const&; + using iterator_category = std::input_iterator_tag; + + substitution_iterator& operator++() { + SASSERT(!m_current.is_val()); + m_current = m_current.hi(); + return *this; + } + + value_type operator*() const { + SASSERT(!m_current.is_val()); + return { m_current.var(), m_current.lo().val() }; + } + + bool operator==(substitution_iterator const& other) const { return m_current == other.m_current; } + bool operator!=(substitution_iterator const& other) const { return !operator==(other); } + }; + + /** Substitution for a single bit width. */ + class substitution { + pdd m_subst; + + substitution(pdd p); + + public: + substitution(dd::pdd_manager& m); + [[nodiscard]] substitution add(pvar var, rational const& value) const; + [[nodiscard]] pdd apply_to(pdd const& p) const; + + [[nodiscard]] bool contains(pvar var) const; + [[nodiscard]] bool value(pvar var, rational& out_value) const; + + [[nodiscard]] bool empty() const { return m_subst.is_one(); } + + pdd const& to_pdd() const { return m_subst; } + unsigned bit_width() const { return to_pdd().power_of_2(); } + + bool operator==(substitution const& other) const { return &m_subst.manager() == &other.m_subst.manager() && m_subst == other.m_subst; } + bool operator!=(substitution const& other) const { return !operator==(other); } + + std::ostream& display(std::ostream& out) const; + + using const_iterator = substitution_iterator; + const_iterator begin() const { return {m_subst}; } + const_iterator end() const { return {m_subst.manager().one()}; } + }; + + /** Full variable assignment, may include variables of varying bit widths. */ + class assignment { + core& m_core; + vector m_pairs; + mutable scoped_ptr_vector m_subst; + vector m_subst_trail; + + substitution& subst(unsigned sz); + core& s() const { return m_core; } + public: + assignment(core& s); + // prevent implicit copy, use clone() if you do need a copy + assignment(assignment const&) = delete; + assignment& operator=(assignment const&) = delete; + assignment(assignment&&) = default; + assignment& operator=(assignment&&) = default; + assignment clone() const; + + void push(pvar var, rational const& value); + void pop(); + + pdd apply_to(pdd const& p) const; + + bool contains(pvar var) const; + bool value(pvar var, rational& out_value) const; + rational value(pvar var) const { rational val; VERIFY(value(var, val)); return val; } + bool empty() const { return pairs().empty(); } + substitution const& subst(unsigned sz) const; + vector const& pairs() const { return m_pairs; } + using const_iterator = decltype(m_pairs)::const_iterator; + const_iterator begin() const { return pairs().begin(); } + const_iterator end() const { return pairs().end(); } + + std::ostream& display(std::ostream& out) const; + }; + + inline std::ostream& operator<<(std::ostream& out, substitution const& sub) { return sub.display(out); } + + inline std::ostream& operator<<(std::ostream& out, assignment const& a) { return a.display(out); } +} + diff --git a/src/sat/smt/polysat_constraints.cpp b/src/sat/smt/polysat_constraints.cpp new file mode 100644 index 000000000..1c9de327c --- /dev/null +++ b/src/sat/smt/polysat_constraints.cpp @@ -0,0 +1,25 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat constraints + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#include "sat/smt/polysat_core.h" +#include "sat/smt/polysat_solver.h" +#include "sat/smt/polysat_constraints.h" + +namespace polysat { + + signed_constraint constraints::ule(pdd const& p, pdd const& q) { + auto* c = alloc(ule_constraint, p, q); + m_trail.push(new_obj_trail(c)); + return signed_constraint(ckind_t::ule_t, c); + } +} diff --git a/src/sat/smt/polysat_constraints.h b/src/sat/smt/polysat_constraints.h new file mode 100644 index 000000000..24c7f9a11 --- /dev/null +++ b/src/sat/smt/polysat_constraints.h @@ -0,0 +1,128 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat constraints + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ + + +#pragma once +#include "sat/smt/polysat_types.h" + +namespace polysat { + + class core; + + using pdd = dd::pdd; + using pvar = unsigned; + + enum ckind_t { ule_t, umul_ovfl_t, smul_fl_t, op_t }; + + class constraint { + unsigned_vector m_vars; + public: + virtual ~constraint() {} + unsigned_vector& vars() { return m_vars; } + unsigned_vector const& vars() const { return m_vars; } + unsigned var(unsigned idx) const { return m_vars[idx]; } + bool contains_var(pvar v) const { return m_vars.contains(v); } + }; + + class ule_constraint : public constraint { + pdd m_lhs, m_rhs; + public: + ule_constraint(pdd const& lhs, pdd const& rhs) : m_lhs(lhs), m_rhs(rhs) {} + }; + + class signed_constraint { + bool m_sign = false; + ckind_t m_op = ule_t; + constraint* m_constraint = nullptr; + public: + signed_constraint() {} + signed_constraint(ckind_t c, constraint* p) : m_op(c), m_constraint(p) {} + signed_constraint operator~() const { signed_constraint r(*this); r.m_sign = !r.m_sign; return r; } + bool sign() const { return m_sign; } + unsigned_vector& vars() { return m_constraint->vars(); } + unsigned_vector const& vars() const { return m_constraint->vars(); } + unsigned var(unsigned idx) const { return m_constraint->var(idx); } + bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + bool is_ule() const { return m_op == ule_t; } + ule_constraint& to_ule() { return *reinterpret_cast(m_constraint); } + bool is_eq(pvar& v, rational& val) { throw default_exception("nyi"); } + }; + + using dependent_constraint = std::pair; + + class constraints { + trail_stack& m_trail; + public: + constraints(trail_stack& c) : m_trail(c) {} + + signed_constraint eq(pdd const& p) { return ule(p, p.manager().mk_val(0)); } + signed_constraint eq(pdd const& p, rational const& v) { return eq(p - p.manager().mk_val(v)); } + signed_constraint ule(pdd const& p, pdd const& q); + signed_constraint sle(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint ult(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint slt(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint umul_ovfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint smul_ovfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint smul_udfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint bit(pdd const& p, unsigned i) { throw default_exception("nyi"); } + + signed_constraint diseq(pdd const& p) { return ~eq(p); } + signed_constraint diseq(pdd const& p, pdd const& q) { return diseq(p - q); } + signed_constraint diseq(pdd const& p, rational const& q) { return diseq(p - q); } + signed_constraint diseq(pdd const& p, int q) { return diseq(p, rational(q)); } + signed_constraint diseq(pdd const& p, unsigned q) { return diseq(p, rational(q)); } + + signed_constraint ule(pdd const& p, rational const& q) { return ule(p, p.manager().mk_val(q)); } + signed_constraint ule(rational const& p, pdd const& q) { return ule(q.manager().mk_val(p), q); } + signed_constraint ule(pdd const& p, int q) { return ule(p, rational(q)); } + signed_constraint ule(pdd const& p, unsigned q) { return ule(p, rational(q)); } + signed_constraint ule(int p, pdd const& q) { return ule(rational(p), q); } + signed_constraint ule(unsigned p, pdd const& q) { return ule(rational(p), q); } + + signed_constraint uge(pdd const& p, pdd const& q) { return ule(q, p); } + signed_constraint uge(pdd const& p, rational const& q) { return ule(q, p); } + + signed_constraint ult(pdd const& p, rational const& q) { return ult(p, p.manager().mk_val(q)); } + signed_constraint ult(rational const& p, pdd const& q) { return ult(q.manager().mk_val(p), q); } + signed_constraint ult(int p, pdd const& q) { return ult(rational(p), q); } + signed_constraint ult(unsigned p, pdd const& q) { return ult(rational(p), q); } + signed_constraint ult(pdd const& p, int q) { return ult(p, rational(q)); } + signed_constraint ult(pdd const& p, unsigned q) { return ult(p, rational(q)); } + + signed_constraint slt(pdd const& p, rational const& q) { return slt(p, p.manager().mk_val(q)); } + signed_constraint slt(rational const& p, pdd const& q) { return slt(q.manager().mk_val(p), q); } + signed_constraint slt(pdd const& p, int q) { return slt(p, rational(q)); } + signed_constraint slt(pdd const& p, unsigned q) { return slt(p, rational(q)); } + signed_constraint slt(int p, pdd const& q) { return slt(rational(p), q); } + signed_constraint slt(unsigned p, pdd const& q) { return slt(rational(p), q); } + + + signed_constraint sgt(pdd const& p, pdd const& q) { return slt(q, p); } + signed_constraint sgt(pdd const& p, int q) { return slt(q, p); } + signed_constraint sgt(pdd const& p, unsigned q) { return slt(q, p); } + signed_constraint sgt(int p, pdd const& q) { return slt(q, p); } + signed_constraint sgt(unsigned p, pdd const& q) { return slt(q, p); } + + signed_constraint umul_ovfl(pdd const& p, rational const& q) { return umul_ovfl(p, p.manager().mk_val(q)); } + signed_constraint umul_ovfl(rational const& p, pdd const& q) { return umul_ovfl(q.manager().mk_val(p), q); } + signed_constraint umul_ovfl(pdd const& p, int q) { return umul_ovfl(p, rational(q)); } + signed_constraint umul_ovfl(pdd const& p, unsigned q) { return umul_ovfl(p, rational(q)); } + signed_constraint umul_ovfl(int p, pdd const& q) { return umul_ovfl(rational(p), q); } + signed_constraint umul_ovfl(unsigned p, pdd const& q) { return umul_ovfl(rational(p), q); } + + + //signed_constraint even(pdd const& p) { return parity_at_least(p, 1); } + //signed_constraint odd(pdd const& p) { return ~even(p); } + }; +} \ No newline at end of file diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp new file mode 100644 index 000000000..27d6ee731 --- /dev/null +++ b/src/sat/smt/polysat_core.cpp @@ -0,0 +1,276 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + polysat_core.cpp + +Abstract: + + PolySAT core functionality + +Author: + + Nikolaj Bjorner (nbjorner) 2022-01-26 + Jakob Rath 2021-04-06 + +Notes: + +polysat::solver +- adds assignments +- calls propagation and check + +polysat::core +- propagates literals +- crates case splits by value assignment (equalities) +- detects conflicts based on Literal assignmets +- adds lemmas based on projections + +--*/ + +#include "params/bv_rewriter_params.hpp" +#include "sat/smt/polysat_solver.h" +#include "sat/smt/euf_solver.h" + +namespace polysat { + + class core::mk_assign_var : public trail { + pvar m_var; + core& c; + public: + mk_assign_var(pvar v, core& c) : m_var(v), c(c) {} + void undo() { + c.m_justification[m_var] = nullptr; + c.m_assignment.pop(); + } + }; + + class core::mk_dqueue_var : public trail { + pvar m_var; + core& c; + public: + mk_dqueue_var(pvar v, core& c) : m_var(v), c(c) {} + void undo() { + c.m_var_queue.unassign_var_eh(m_var); + } + }; + + class core::mk_add_var : public trail { + core& c; + public: + mk_add_var(core& c) : c(c) {} + void undo() override { + c.del_var(); + } + }; + + class core::mk_add_watch : public trail { + core& c; + unsigned m_idx; + public: + mk_add_watch(core& c, unsigned idx) : c(c), m_idx(idx) {} + void undo() override { + auto& sc = c.m_prop_queue[m_idx].first; + auto& vars = sc.vars(); + if (vars.size() > 0) + c.m_watch[vars[0]].pop_back(); + if (vars.size() > 1) + c.m_watch[vars[1]].pop_back(); + } + }; + + core::core(solver& s) : + s(s), + m_viable(*this), + m_constraints(s.get_trail_stack()), + m_assignment(*this), + m_dep(s.get_region()), + m_var_queue(m_activity) + {} + + pdd core::value(rational const& v, unsigned sz) { + return sz2pdd(sz).mk_val(v); + } + + dd::pdd_manager& core::sz2pdd(unsigned sz) const { + m_pdd.reserve(sz + 1); + if (!m_pdd[sz]) + m_pdd.set(sz, alloc(dd::pdd_manager, 1000, dd::pdd_manager::semantics::mod2N_e, sz)); + return *m_pdd[sz]; + } + + dd::pdd_manager& core::var2pdd(pvar v) const { + return sz2pdd(size(v)); + } + + pvar core::add_var(unsigned sz) { + unsigned v = m_vars.size(); + m_vars.push_back(sz2pdd(sz).mk_var(v)); + m_activity.push_back({ sz, 0 }); + m_justification.push_back(nullptr); + m_watch.push_back({}); + m_var_queue.mk_var_eh(v); + s.ctx.push(mk_add_var(*this)); + return v; + } + + void core::del_var() { + unsigned v = m_vars.size() - 1; + m_vars.pop_back(); + m_activity.pop_back(); + m_justification.pop_back(); + m_watch.pop_back(); + m_var_queue.del_var_eh(v); + } + + // case split on unassigned variables until all are assigned values. + // create equality literal for unassigned variable. + // return new_eq if there is a new literal. + + sat::check_result core::check() { + if (m_var_queue.empty()) + return sat::check_result::CR_DONE; + m_var = m_var_queue.next_var(); + s.ctx.push(mk_dqueue_var(m_var, *this)); + switch (m_viable.find_viable(m_var, m_value)) { + case find_t::empty: + m_unsat_core = m_viable.explain(); + propagate_unsat_core(); + return sat::check_result::CR_CONTINUE; + case find_t::singleton: + s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); + return sat::check_result::CR_CONTINUE; + case find_t::multiple: + return sat::check_result::CR_CONTINUE; + case find_t::resource_out: + return sat::check_result::CR_GIVEUP; + } + UNREACHABLE(); + return sat::check_result::CR_GIVEUP; + } + + // First propagate Boolean assignment, then propagate value assignment + bool core::propagate() { + if (m_qhead == m_prop_queue.size() && m_vqhead == m_prop_queue.size()) + return false; + s.ctx.push(value_trail(m_qhead)); + for (; m_qhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_qhead) + propagate_constraint(m_qhead, m_prop_queue[m_qhead]); + s.ctx.push(value_trail(m_vqhead)); + for (; m_vqhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_vqhead) + propagate_value(m_vqhead, m_prop_queue[m_vqhead]); + return true; + } + + void core::propagate_constraint(unsigned idx, dependent_constraint& dc) { + auto [sc, dep] = dc; + if (sc.is_eq(m_var, m_value)) { + propagate_assignment(m_var, m_value, dep); + return; + } + add_watch(idx, sc); + } + + void core::add_watch(unsigned idx, signed_constraint& sc) { + auto& vars = sc.vars(); + unsigned i = 0, j = 0, sz = vars.size(); + for (; i < sz && j < 2; ++i) + if (!is_assigned(vars[i])) + std::swap(vars[i], vars[j++]); + if (vars.size() > 0) + add_watch(idx, vars[0]); + if (vars.size() > 1) + add_watch(idx, vars[1]); + s.ctx.push(mk_add_watch(*this, idx)); + } + + void core::add_watch(unsigned idx, unsigned var) { + m_watch[var].push_back(idx); + } + + void core::propagate_assignment(pvar v, rational const& value, stacked_dependency* dep) { + if (is_assigned(v)) + return; + if (m_var_queue.contains(v)) { + m_var_queue.del_var_eh(v); + s.ctx.push(mk_dqueue_var(v, *this)); + } + m_values[v] = value; + m_justification[v] = dep; + m_assignment.push(v , value); + s.ctx.push(mk_assign_var(v, *this)); + + // update the watch lists for pvars + // remove constraints from m_watch[v] that have more than 2 free variables. + // for entries where there is only one free variable left add to viable set + unsigned j = 0; + for (auto idx : m_watch[v]) { + auto [sc, dep] = m_prop_queue[idx]; + auto& vars = sc.vars(); + if (vars[0] != v) + std::swap(vars[0], vars[1]); + SASSERT(vars[0] == v); + bool swapped = false; + for (unsigned i = vars.size(); i-- > 2; ) { + if (!is_assigned(vars[i])) { + add_watch(idx, vars[i]); + std::swap(vars[i], vars[0]); + swapped = true; + break; + } + } + if (!swapped) { + m_watch[v][j++] = idx; + } + + // constraint is unitary, add to viable set + if (vars.size() >= 2 && is_assigned(vars[0]) && !is_assigned(vars[1])) { + // m_viable.add_unitary(vars[1], idx); + } + } + m_watch[v].shrink(j); + } + + void core::propagate_value(unsigned idx, dependent_constraint const& dc) { + auto [sc, dep] = dc; + // check if sc evaluates to false + switch (eval(sc)) { + case l_true: + return; + case l_false: + m_unsat_core = explain_eval(dc); + propagate_unsat_core(); + return; + default: + break; + } + // if sc is v == value, then check the watch list for v if they evaluate to false. + if (sc.is_eq(m_var, m_value)) { + for (auto idx : m_watch[m_var]) { + auto [sc, dep] = m_prop_queue[idx]; + if (eval(sc) == l_false) { + m_unsat_core = explain_eval({ sc, dep }); + propagate_unsat_core(); + return; + } + } + } + + throw default_exception("nyi"); + } + + bool core::propagate_unsat_core() { + // default is to use unsat core: + s.set_conflict(m_unsat_core); + // if core is based on viable, use s.set_lemma(); + throw default_exception("nyi"); + } + + void core::assign_eh(signed_constraint const& sc, dependency const& dep) { + m_prop_queue.push_back({ sc, m_dep.mk_leaf(dep) }); + s.ctx.push(push_back_vector(m_prop_queue)); + } + + + +} diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h new file mode 100644 index 000000000..7fdf8c88c --- /dev/null +++ b/src/sat/smt/polysat_core.h @@ -0,0 +1,128 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + polysat_core.h + +Abstract: + + Core solver for polysat + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-30 + Jakob Rath 2021-04-06 + +--*/ +#pragma once + +#include "util/dependency.h" +#include "math/dd/dd_pdd.h" +#include "sat/smt/sat_th.h" +#include "sat/smt/polysat_types.h" +#include "sat/smt/polysat_constraints.h" +#include "sat/smt/polysat_viable.h" +#include "sat/smt/polysat_assignment.h" + +namespace polysat { + + class core; + class solver; + + class core { + class mk_add_var; + class mk_dqueue_var; + class mk_assign_var; + class mk_add_watch; + typedef svector> activity; + friend class viable; + friend class constraints; + friend class assignment; + + solver& s; + viable m_viable; + constraints m_constraints; + assignment m_assignment; + unsigned m_qhead = 0, m_vqhead = 0; + svector m_prop_queue; + stacked_dependency_manager m_dep; + mutable scoped_ptr_vector m_pdd; + dependency_vector m_unsat_core; + + + // attributes associated with variables + vector m_vars; // for each variable a pdd + vector m_values; // current value of assigned variable + ptr_vector m_justification; // justification for assignment + activity m_activity; // activity of variables + var_queue m_var_queue; // priority queue of variables to assign + vector m_watch; // watch lists for variables for constraints on m_prop_queue where they occur + + vector m_subst; // substitution, one for each size. + + // values to split on + rational m_value; + pvar m_var = 0; + + dd::pdd_manager& sz2pdd(unsigned sz) const; + dd::pdd_manager& var2pdd(pvar v) const; + unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } + void del_var(); + + bool is_assigned(pvar v) { return nullptr != m_justification[v]; } + void propagate_constraint(unsigned idx, dependent_constraint& dc); + void propagate_value(unsigned idx, dependent_constraint const& dc); + void propagate_assignment(pvar v, rational const& value, stacked_dependency* dep); + bool propagate_unsat_core(); + + void add_watch(unsigned idx, signed_constraint& sc); + void add_watch(unsigned idx, unsigned var); + + lbool eval(signed_constraint sc) { throw default_exception("nyi"); } + dependency_vector explain_eval(dependent_constraint const& dc) { throw default_exception("nyi"); } + + public: + core(solver& s); + + sat::check_result check(); + + bool propagate(); + void assign_eh(signed_constraint const& sc, dependency const& dep); + + expr_ref constraint2expr(signed_constraint const& sc) const { throw default_exception("nyi"); } + + pdd value(rational const& v, unsigned sz); + + signed_constraint eq(pdd const& p) { return m_constraints.eq(p); } + signed_constraint eq(pdd const& p, pdd const& q) { return m_constraints.eq(p - q); } + signed_constraint ule(pdd const& p, pdd const& q) { return m_constraints.ule(p, q); } + signed_constraint sle(pdd const& p, pdd const& q) { return m_constraints.sle(p, q); } + signed_constraint umul_ovfl(pdd const& p, pdd const& q) { return m_constraints.umul_ovfl(p, q); } + signed_constraint smul_ovfl(pdd const& p, pdd const& q) { return m_constraints.smul_ovfl(p, q); } + signed_constraint smul_udfl(pdd const& p, pdd const& q) { return m_constraints.smul_udfl(p, q); } + signed_constraint bit(pdd const& p, unsigned i) { return m_constraints.bit(p, i); } + + + pdd lshr(pdd a, pdd b) { throw default_exception("nyi"); } + pdd ashr(pdd a, pdd b) { throw default_exception("nyi"); } + pdd shl(pdd a, pdd b) { throw default_exception("nyi"); } + pdd band(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bxor(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bor(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bnand(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bxnor(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bnor(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bnot(pdd a) { throw default_exception("nyi"); } + std::pair quot_rem(pdd const& n, pdd const& d) { throw default_exception("nyi"); } + pdd zero_ext(pdd a, unsigned sz) { throw default_exception("nyi"); } + pdd sign_ext(pdd a, unsigned sz) { throw default_exception("nyi"); } + pdd extract(pdd src, unsigned hi, unsigned lo) { throw default_exception("nyi"); } + pdd concat(unsigned n, pdd const* args) { throw default_exception("nyi"); } + pvar add_var(unsigned sz); + pdd var(pvar p) { return m_vars[p]; } + + std::ostream& display(std::ostream& out) const { throw default_exception("nyi"); } + }; + +} diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp new file mode 100644 index 000000000..80c2fb19b --- /dev/null +++ b/src/sat/smt/polysat_internalize.cpp @@ -0,0 +1,343 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + polysat_model.cpp + +Abstract: + + PolySAT model generation + +Author: + + Nikolaj Bjorner (nbjorner) 2022-01-26 + +--*/ + +#include "params/bv_rewriter_params.hpp" +#include "sat/smt/polysat_solver.h" +#include "sat/smt/euf_solver.h" + +namespace polysat { + + euf::theory_var solver::mk_var(euf::enode* n) { + return euf::th_euf_solver::mk_var(n); + } + + sat::literal solver::internalize(expr* e, bool sign, bool root) { + force_push(); + SASSERT(m.is_bool(e)); + if (!visit_rec(m, e, sign, root)) + return sat::null_literal; + sat::literal lit = expr2literal(e); + if (sign) + lit.neg(); + return lit; + } + + void solver::internalize(expr* e) { + force_push(); + visit_rec(m, e, false, false); + } + + bool solver::visit(expr* e) { + force_push(); + if (!is_app(e) || to_app(e)->get_family_id() != get_id()) { + ctx.internalize(e); + return true; + } + m_stack.push_back(sat::eframe(e)); + return false; + } + + bool solver::visited(expr* e) { + euf::enode* n = expr2enode(e); + return n && n->is_attached_to(get_id()); + } + + bool solver::post_visit(expr* e, bool sign, bool root) { + euf::enode* n = expr2enode(e); + app* a = to_app(e); + + if (visited(e)) + return true; + + SASSERT(!n || !n->is_attached_to(get_id())); + if (!n) + n = mk_enode(e, false); + + SASSERT(!n->is_attached_to(get_id())); + mk_var(n); + SASSERT(n->is_attached_to(get_id())); + internalize_polysat(a); + return true; + } + + void solver::internalize_polysat(app* a) { + +#define if_unary(F) if (a->get_num_args() == 1) { internalize_unary(a, [&](pdd const& p) { return F(p); }); break; } + + switch (a->get_decl_kind()) { + case OP_BMUL: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p * q; }); break; + case OP_BADD: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p + q; }); break; + case OP_BSUB: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p - q; }); break; + case OP_BLSHR: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.lshr(p, q); }); break; + case OP_BSHL: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.shl(p, q); }); break; + case OP_BAND: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.band(p, q); }); break; + case OP_BOR: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.bor(p, q); }); break; + case OP_BXOR: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.bxor(p, q); }); break; + case OP_BNAND: if_unary(m_core.bnot); internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.bnand(p, q); }); break; + case OP_BNOR: if_unary(m_core.bnot); internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.bnor(p, q); }); break; + case OP_BXNOR: if_unary(m_core.bnot); internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.bxnor(p, q); }); break; + case OP_BNOT: internalize_unary(a, [&](pdd const& p) { return m_core.bnot(p); }); break; + case OP_BNEG: internalize_unary(a, [&](pdd const& p) { return -p; }); break; + case OP_MKBV: internalize_mkbv(a); break; + case OP_BV_NUM: internalize_num(a); break; + case OP_ULEQ: internalize_le(a); break; + case OP_SLEQ: internalize_le(a); break; + case OP_UGEQ: internalize_le(a); break; + case OP_SGEQ: internalize_le(a); break; + case OP_ULT: internalize_le(a); break; + case OP_SLT: internalize_le(a); break; + case OP_UGT: internalize_le(a); break; + case OP_SGT: internalize_le(a); break; + + case OP_BUMUL_NO_OVFL: internalize_binaryc(a, [&](pdd const& p, pdd const& q) { return m_core.umul_ovfl(p, q); }); break; + case OP_BSMUL_NO_OVFL: internalize_binaryc(a, [&](pdd const& p, pdd const& q) { return m_core.smul_ovfl(p, q); }); break; + case OP_BSMUL_NO_UDFL: internalize_binaryc(a, [&](pdd const& p, pdd const& q) { return m_core.smul_udfl(p, q); }); break; + + case OP_BUMUL_OVFL: + case OP_BSMUL_OVFL: + case OP_BSDIV_OVFL: + case OP_BNEG_OVFL: + case OP_BUADD_OVFL: + case OP_BSADD_OVFL: + case OP_BUSUB_OVFL: + case OP_BSSUB_OVFL: + // handled by bv_rewriter for now + UNREACHABLE(); + break; + + case OP_BUDIV_I: internalize_div_rem_i(a, true); break; + case OP_BUREM_I: internalize_div_rem_i(a, false); break; + + case OP_BUDIV: internalize_div_rem(a, true); break; + case OP_BUREM: internalize_div_rem(a, false); break; + case OP_BSDIV0: UNREACHABLE(); break; + case OP_BUDIV0: UNREACHABLE(); break; + case OP_BSREM0: UNREACHABLE(); break; + case OP_BUREM0: UNREACHABLE(); break; + case OP_BSMOD0: UNREACHABLE(); break; + + case OP_EXTRACT: internalize_extract(a); break; + case OP_CONCAT: internalize_concat(a); break; + case OP_ZERO_EXT: internalize_par_unary(a, [&](pdd const& p, unsigned sz) { return m_core.zero_ext(p, sz); }); break; + case OP_SIGN_EXT: internalize_par_unary(a, [&](pdd const& p, unsigned sz) { return m_core.sign_ext(p, sz); }); break; + + // polysat::solver should also support at least: + case OP_BREDAND: // x == 2^K - 1 unary, return single bit, 1 if all input bits are set. + case OP_BREDOR: // x > 0 unary, return single bit, 1 if at least one input bit is set. + case OP_BCOMP: // x == y binary, return single bit, 1 if the arguments are equal. + case OP_BSDIV: + case OP_BSREM: + case OP_BSMOD: + case OP_BSDIV_I: + case OP_BSREM_I: + case OP_BSMOD_I: + case OP_BASHR: + IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); + NOT_IMPLEMENTED_YET(); + return; + default: + IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); + NOT_IMPLEMENTED_YET(); + return; + } +#undef if_unary + } + + class solver::mk_atom_trail : public trail { + solver& th; + sat::bool_var m_var; + public: + mk_atom_trail(sat::bool_var v, solver& th) : th(th), m_var(v) {} + void undo() override { + solver::atom* a = th.get_bv2a(m_var); + a->~atom(); + th.erase_bv2a(m_var); + } + }; + + solver::atom* solver::mk_atom(sat::bool_var bv) { + atom* a = get_bv2a(bv); + if (a) + return a; + a = new (get_region()) atom(bv); + insert_bv2a(bv, a); + ctx.push(mk_atom_trail(bv, *this)); + return a; + } + + void solver::internalize_binaryc(app* e, std::function const& fn) { + auto p = expr2pdd(e->get_arg(0)); + auto q = expr2pdd(e->get_arg(1)); + auto sc = ~fn(p, q); + sat::literal lit = expr2literal(e); + mk_atom(lit.var())->m_sc = sc; + } + + void solver::internalize_div_rem_i(app* e, bool is_div) { + auto p = expr2pdd(e->get_arg(0)); + auto q = expr2pdd(e->get_arg(1)); + auto [quot, rem] = m_core.quot_rem(p, q); + internalize_set(e, is_div ? quot : rem); + } + + void solver::internalize_div_rem(app* e, bool is_div) { + bv_rewriter_params p(s().params()); + if (p.hi_div0()) { + internalize_div_rem_i(e, is_div); + return; + } + expr* arg1 = e->get_arg(0); + expr* arg2 = e->get_arg(1); + unsigned sz = bv.get_bv_size(e); + expr_ref zero(bv.mk_numeral(0, sz), m); + sat::literal eqZ = eq_internalize(arg2, zero); + sat::literal eqU = eq_internalize(e, is_div ? bv.mk_bv_udiv0(arg1) : bv.mk_bv_urem0(arg1)); + sat::literal eqI = eq_internalize(e, is_div ? bv.mk_bv_udiv_i(arg1, arg2) : bv.mk_bv_urem_i(arg1, arg2)); + add_clause(~eqZ, eqU); + add_clause(eqZ, eqI); + ctx.add_aux(~eqZ, eqU); + ctx.add_aux(eqZ, eqI); + } + + void solver::internalize_num(app* a) { + rational val; + unsigned sz = 0; + VERIFY(bv.is_numeral(a, val, sz)); + auto p = m_core.value(val, sz); + internalize_set(a, p); + } + + // TODO - test that internalize works with recursive call on bit2bool + void solver::internalize_mkbv(app* a) { + unsigned i = 0; + for (expr* arg : *a) { + expr_ref b2b(m); + b2b = bv.mk_bit2bool(a, i); + sat::literal bit_i = ctx.internalize(b2b, false, false); + sat::literal lit = expr2literal(arg); + add_equiv(lit, bit_i); +#if 0 + ctx.add_aux_equiv(lit, bit_i); +#endif + ++i; + } + } + + void solver::internalize_extract(app* e) { + unsigned const hi = bv.get_extract_high(e); + unsigned const lo = bv.get_extract_low(e); + auto const src = expr2pdd(e->get_arg(0)); + auto const p = m_core.extract(src, hi, lo); + SASSERT_EQ(p.power_of_2(), hi - lo + 1); + internalize_set(e, p); + } + + void solver::internalize_concat(app* e) { + SASSERT(bv.is_concat(e)); + vector args; + for (expr* arg : *e) + args.push_back(expr2pdd(arg)); + auto const p = m_core.concat(args.size(), args.data()); + internalize_set(e, p); + } + + void solver::internalize_par_unary(app* e, std::function const& fn) { + pdd const p = expr2pdd(e->get_arg(0)); + unsigned const par = e->get_parameter(0).get_int(); + internalize_set(e, fn(p, par)); + } + + void solver::internalize_binary(app* e, std::function const& fn) { + SASSERT(e->get_num_args() >= 1); + auto p = expr2pdd(e->get_arg(0)); + for (unsigned i = 1; i < e->get_num_args(); ++i) + p = fn(p, expr2pdd(e->get_arg(i))); + internalize_set(e, p); + } + + void solver::internalize_unary(app* e, std::function const& fn) { + SASSERT(e->get_num_args() == 1); + auto p = expr2pdd(e->get_arg(0)); + internalize_set(e, fn(p)); + } + + template + void solver::internalize_le(app* e) { + auto p = expr2pdd(e->get_arg(0)); + auto q = expr2pdd(e->get_arg(1)); + if (Rev) + std::swap(p, q); + auto sc = Signed ? m_core.sle(p, q) : m_core.ule(p, q); + if (Negated) + sc = ~sc; + + sat::literal lit = expr2literal(e); + atom* a = mk_atom(lit.var()); + a->m_sc = sc; + } + + void solver::internalize_bit2bool(atom* a, expr* e, unsigned idx) { + pdd p = expr2pdd(e); + a->m_sc = m_core.bit(p, idx); + } + + dd::pdd solver::expr2pdd(expr* e) { + return var2pdd(get_th_var(e)); + } + + dd::pdd solver::var2pdd(euf::theory_var v) { + if (!m_var2pdd_valid.get(v, false)) { + unsigned bv_size = get_bv_size(v); + pvar pv = m_core.add_var(bv_size); + m_pddvar2var.setx(pv, v, UINT_MAX); + pdd p = m_core.var(pv); + internalize_set(v, p); + return p; + } + return m_var2pdd[v]; + } + + void solver::apply_sort_cnstr(euf::enode* n, sort* s) { + if (!bv.is_bv(n->get_expr())) + return; + theory_var v = n->get_th_var(get_id()); + if (v == euf::null_theory_var) + v = mk_var(n); + var2pdd(v); + } + + void solver::internalize_set(expr* e, pdd const& p) { + internalize_set(get_th_var(e), p); + } + + void solver::internalize_set(euf::theory_var v, pdd const& p) { + SASSERT_EQ(get_bv_size(v), p.power_of_2()); + m_var2pdd.reserve(get_num_vars(), p); + m_var2pdd_valid.reserve(get_num_vars(), false); + ctx.push(set_bitvector_trail(m_var2pdd_valid, v)); +#if 0 + m_var2pdd[v].reset(p.manager()); +#endif + m_var2pdd[v] = p; + } + + void solver::eq_internalized(euf::enode* n) { + SASSERT(m.is_eq(n->get_expr())); + } + + +} diff --git a/src/sat/smt/polysat_model.cpp b/src/sat/smt/polysat_model.cpp new file mode 100644 index 000000000..383a3f692 --- /dev/null +++ b/src/sat/smt/polysat_model.cpp @@ -0,0 +1,58 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + polysat_model.cpp + +Abstract: + + PolySAT model generation + +Author: + + Nikolaj Bjorner (nbjorner) 2022-01-26 + +--*/ + +#include "params/bv_rewriter_params.hpp" +#include "sat/smt/polysat_solver.h" +#include "sat/smt/euf_solver.h" + +namespace polysat { + + + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { +#if 0 + auto p = expr2pdd(n->get_expr()); + rational val; + VERIFY(m_polysat.try_eval(p, val)); + values[n->get_root_id()] = bv.mk_numeral(val, get_bv_size(n)); +#endif + } + + + bool solver::check_model(sat::model const& m) const { + return false; + } + + void solver::finalize_model(model& mdl) { + + } + + std::ostream& solver::display_justification(std::ostream& out, sat::ext_justification_idx idx) const { + return out; + } + + std::ostream& solver::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const { + return out; + } + + std::ostream& solver::display(std::ostream& out) const { + m_core.display(out); + for (unsigned v = 0; v < get_num_vars(); ++v) + if (m_var2pdd_valid.get(v, false)) + out << ctx.bpp(var2enode(v)) << " := " << m_var2pdd[v] << "\n"; + return out; + } +} diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp new file mode 100644 index 000000000..cee129327 --- /dev/null +++ b/src/sat/smt/polysat_solver.cpp @@ -0,0 +1,191 @@ +/*--- +Copyright (c 2022 Microsoft Corporation + +Module Name: + + polysat_internalize.cpp + +Abstract: + + PolySAT interface to bit-vector + +Author: + + Nikolaj Bjorner (nbjorner) 2022-01-26 + +Notes: +The solver adds literals to polysat::core, calls propagation and check +The result of polysat::core::check is one of: +- is_sat: the model is complete +- is_unsat: there is a Boolean conflict. The SAT solver backtracks and resolves the conflict. +- new_eq: the solver adds a new equality literal to the SAT solver. +- new_lemma: there is a conflict, but it is resolved by backjumping and adding a lemma to the SAT solver. +- giveup: Polysat was unable to determine satisfiability. + +--*/ + +#include "ast/euf/euf_bv_plugin.h" +#include "sat/smt/polysat_solver.h" +#include "sat/smt/euf_solver.h" + + +namespace polysat { + + solver::solver(euf::solver& ctx, theory_id id): + euf::th_euf_solver(ctx, symbol("bv"), id), + bv(ctx.get_manager()), + m_autil(ctx.get_manager()), + m_core(*this), + m_lemma(ctx.get_manager()) + { + ctx.get_egraph().add_plugin(alloc(euf::bv_plugin, ctx.get_egraph())); + } + + unsigned solver::get_bv_size(euf::enode* n) { + return bv.get_bv_size(n->get_expr()); + } + + unsigned solver::get_bv_size(theory_var v) { + return bv.get_bv_size(var2expr(v)); + } + + bool solver::unit_propagate() { + return m_core.propagate(); + } + + sat::check_result solver::check() { + return m_core.check(); + } + + void solver::asserted(literal l) { + atom* a = get_bv2a(l.var()); + TRACE("bv", tout << "asserted: " << l << "\n";); + if (!a) + return; + force_push(); + auto sc = a->m_sc; + if (l.sign()) + sc = ~sc; + m_core.assign_eh(sc, dependency(l, s().lvl(l))); + } + + void solver::set_conflict(dependency_vector const& core) { + auto [lits, eqs] = explain_deps(core); + auto ex = euf::th_explain::conflict(*this, lits, eqs, nullptr); + ctx.set_conflict(ex); + } + + std::pair solver::explain_deps(dependency_vector const& deps) { + sat::literal_vector core; + euf::enode_pair_vector eqs; + for (auto d : deps) { + if (d.is_literal()) { + core.push_back(d.literal()); + } + else { + auto const [v1, v2] = m_var_eqs[d.index()]; + euf::enode* const n1 = var2enode(v1); + euf::enode* const n2 = var2enode(v2); + VERIFY(n1->get_root() == n2->get_root()); + eqs.push_back(euf::enode_pair(n1, n2)); + } + } + DEBUG_CODE({ + for (auto lit : core) + VERIFY(s().value(lit) == l_true); + for (auto const& [n1, n2] : eqs) + VERIFY(n1->get_root() == n2->get_root()); + }); + IF_VERBOSE(10, + for (auto lit : core) + verbose_stream() << " " << lit << ": " << mk_ismt2_pp(literal2expr(lit), m) << "\n"; + for (auto const& [n1, n2] : eqs) + verbose_stream() << " " << ctx.bpp(n1) << " == " << ctx.bpp(n2) << "\n";); + + return { core, eqs }; + } + + void solver::set_lemma(vector const& lemma, unsigned level, dependency_vector const& core) { + auto [lits, eqs] = explain_deps(core); + auto ex = euf::th_explain::conflict(*this, lits, eqs, nullptr); + ctx.push(value_trail(m_has_lemma)); + m_has_lemma = true; + m_lemma_level = level; + m_lemma.reset(); + for (auto sc : lemma) + m_lemma.push_back(m_core.constraint2expr(sc)); + ctx.set_conflict(ex); + } + + // If an MCSat lemma is added, then backjump based on the lemma level and + // add the lemma to the solver with potentially fresh literals. + // return l_false to signal sat::solver that backjumping has been taken care of internally. + lbool solver::resolve_conflict() { + if (!m_has_lemma) + return l_undef; + + unsigned num_scopes = s().scope_lvl() - m_lemma_level; + + // s().pop_reinit(num_scopes); + + sat::literal_vector lits; + for (auto* e : m_lemma) + lits.push_back(ctx.mk_literal(e)); + s().add_clause(lits.size(), lits.data(), sat::status::th(true, get_id(), nullptr)); + return l_false; + } + + // Create an equality literal that represents the value assignment + // Prefer case split to true. + // The equality gets added in a callback using asserted(). + void solver::add_eq_literal(pvar pvar, rational const& val) { + auto v = m_pddvar2var[pvar]; + auto n = var2enode(v); + auto eq = eq_internalize(n->get_expr(), bv.mk_numeral(val, get_bv_size(v))); + s().set_phase(eq); + } + + void solver::new_eq_eh(euf::th_eq const& eq) { + auto v1 = eq.v1(), v2 = eq.v2(); + pdd p = var2pdd(v1); + pdd q = var2pdd(v2); + auto sc = m_core.eq(p, q); + m_var_eqs.setx(m_var_eqs_head, std::make_pair(v1, v2), std::make_pair(v1, v2)); + ctx.push(value_trail(m_var_eqs_head)); + m_core.assign_eh(sc, dependency(m_var_eqs_head, s().scope_lvl())); + m_var_eqs_head++; + } + + void solver::new_diseq_eh(euf::th_eq const& ne) { + euf::theory_var v1 = ne.v1(), v2 = ne.v2(); + pdd p = var2pdd(v1); + pdd q = var2pdd(v2); + auto sc = ~m_core.eq(p, q); + sat::literal neq = ~expr2literal(ne.eq()); + TRACE("bv", tout << neq << " := " << s().value(neq) << " @" << s().scope_lvl() << "\n"); + m_core.assign_eh(sc, dependency(neq, s().lvl(neq))); + } + + // Core uses the propagate callback to add unit propagations to the trail. + // The polysat::solver takes care of translating signed constraints into expressions, which translate into literals. + // Everything goes over expressions/literals. polysat::core is not responsible for replaying expressions. + void solver::propagate(signed_constraint sc, dependency_vector const& deps) { + sat::literal lit = ctx.mk_literal(m_core.constraint2expr(sc)); + auto [core, eqs] = explain_deps(deps); + auto ex = euf::th_explain::propagate(*this, core, eqs, lit, nullptr); + ctx.propagate(lit, ex); + } + + void solver::add_lemma(vector const& lemma) { + sat::literal_vector lits; + for (auto sc : lemma) + lits.push_back(ctx.mk_literal(m_core.constraint2expr(sc))); + s().add_clause(lits.size(), lits.data(), sat::status::th(true, get_id(), nullptr)); + } + + void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) { + auto& jst = euf::th_explain::from_index(idx); + ctx.get_th_antecedents(l, jst, r, probing); + } + +} diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h new file mode 100644 index 000000000..5d9cd19a3 --- /dev/null +++ b/src/sat/smt/polysat_solver.h @@ -0,0 +1,187 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + polysat_solver.h + +Abstract: + + Theory plugin for bit-vectors + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-30 + +--*/ +#pragma once + +#include "sat/smt/sat_th.h" +#include "math/dd/dd_pdd.h" +#include "sat/smt/polysat_core.h" + +namespace euf { + class solver; +} + +namespace polysat { + + + class solver : public euf::th_euf_solver { + typedef euf::theory_var theory_var; + typedef euf::theory_id theory_id; + typedef sat::literal literal; + typedef sat::bool_var bool_var; + typedef sat::literal_vector literal_vector; + using pdd = dd::pdd; + + struct stats { + void reset() { memset(this, 0, sizeof(stats)); } + stats() { reset(); } + }; + + struct atom { + bool_var m_bv; + signed_constraint m_sc; + atom(bool_var b) :m_bv(b) {} + ~atom() { } + }; + + class polysat_proof : public euf::th_proof_hint { + public: + ~polysat_proof() override {} + expr* get_hint(euf::solver& s) const override { return nullptr; } + }; + + friend class core; + + bv_util bv; + arith_util m_autil; + stats m_stats; + core m_core; + polysat_proof m_proof; + + vector m_var2pdd; // theory_var 2 pdd + bool_vector m_var2pdd_valid; // valid flag + unsigned_vector m_pddvar2var; // pvar -> theory_var + ptr_vector m_bool_var2atom; // bool_var -> atom + + svector> m_var_eqs; + unsigned m_var_eqs_head = 0; + + bool m_has_lemma = false; + unsigned m_lemma_level = 0; + expr_ref_vector m_lemma; + + // internalize + bool visit(expr* e) override; + bool visited(expr* e) override; + bool post_visit(expr* e, bool sign, bool root) override; + unsigned get_bv_size(euf::enode* n); + unsigned get_bv_size(theory_var v); + theory_var get_var(euf::enode* n); + void fixed_var_eh(theory_var v); + bool is_fixed(euf::theory_var v, expr_ref& val, sat::literal_vector& lits) override { return false; } + bool is_bv(theory_var v) const { return bv.is_bv(var2expr(v)); } + void register_true_false_bit(theory_var v, unsigned i); + void add_bit(theory_var v, sat::literal lit); + void eq_internalized(sat::bool_var b1, sat::bool_var b2, unsigned idx, theory_var v1, theory_var v2, sat::literal eq, euf::enode* n); + + void insert_bv2a(bool_var bv, atom* a) { m_bool_var2atom.setx(bv, a, nullptr); } + void erase_bv2a(bool_var bv) { m_bool_var2atom[bv] = nullptr; } + atom* get_bv2a(bool_var bv) const { return m_bool_var2atom.get(bv, nullptr); } + class mk_atom_trail; + atom* mk_atom(sat::bool_var bv); + void set_bit_eh(theory_var v, literal l, unsigned idx); + void init_bits(expr* e, expr_ref_vector const & bits); + void mk_bits(theory_var v); + void add_def(sat::literal def, sat::literal l); + void internalize_unary(app* e, std::function const& fn); + void internalize_binary(app* e, std::function const& fn); + void internalize_binaryc(app* e, std::function const& fn); + void internalize_par_unary(app* e, std::function const& fn); + void internalize_novfl(app* n, std::function& fn); + void internalize_interp(app* n, std::function& ibin, std::function& un); + void internalize_num(app * n); + void internalize_concat(app * n); + void internalize_bv2int(app* n); + void internalize_int2bv(app* n); + void internalize_mkbv(app* n); + void internalize_xor3(app* n); + void internalize_carry(app* n); + void internalize_sub(app* n); + void internalize_extract(app* n); + void internalize_repeat(app* n); + void internalize_bit2bool(app* n); + void internalize_udiv_i(app* n); + template + void internalize_le(app* n); + void internalize_div_rem_i(app* e, bool is_div); + void internalize_div_rem(app* e, bool is_div); + void internalize_polysat(app* a); + void assert_bv2int_axiom(app * n); + void assert_int2bv_axiom(app* n); + void internalize_bit2bool(atom* a, expr* e, unsigned idx); + + pdd expr2pdd(expr* e); + pdd var2pdd(euf::theory_var v); + void internalize_set(expr* e, pdd const& p); + void internalize_set(euf::theory_var v, pdd const& p); + + // callbacks from core + void add_eq_literal(pvar v, rational const& val); + void set_conflict(dependency_vector const& core); + void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core); + void propagate(signed_constraint sc, dependency_vector const& deps); + void add_lemma(vector const& lemma); + + std::pair explain_deps(dependency_vector const& deps); + + public: + solver(euf::solver& ctx, theory_id id); + void set_lookahead(sat::lookahead* s) override { } + void init_search() override {} + double get_reward(literal l, sat::ext_constraint_idx idx, sat::literal_occs_fun& occs) const override { return 0; } + bool is_extended_binary(sat::ext_justification_idx idx, literal_vector& r) override { return false; } + bool is_external(bool_var v) override { return false; } + void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector & r, bool probing) override; + void asserted(literal l) override; + sat::check_result check() override; + void simplify() override {} + void clauses_modifed() override {} + lbool get_phase(bool_var v) override { return l_undef; } + std::ostream& display(std::ostream& out) const override; + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override; + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override; + void collect_statistics(statistics& st) const override {} + euf::th_solver* clone(euf::solver& ctx) override { throw default_exception("nyi"); } + extension* copy(sat::solver* s) override { throw default_exception("nyi"); } + void find_mutexes(literal_vector& lits, vector & mutexes) override {} + void gc() override {} + void pop_reinit() override {} + lbool resolve_conflict() override; + bool validate() override { return true; } + void init_use_list(sat::ext_use_list& ul) override {} + bool is_blocked(literal l, sat::ext_constraint_idx) override { return false; } + bool check_model(sat::model const& m) const override; + void finalize_model(model& mdl) override; + + void new_eq_eh(euf::th_eq const& eq) override; + void new_diseq_eh(euf::th_eq const& ne) override; + bool use_diseqs() const override { return true; } + bool unit_propagate() override; + + void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; + + bool extract_pb(std::function& card, + std::function& pb) override { return false; } + + bool to_formulas(std::function& l2e, expr_ref_vector& fmls) override { return false; } + sat::literal internalize(expr* e, bool sign, bool root) override; + void internalize(expr* e) override; + void eq_internalized(euf::enode* n) override; + euf::theory_var mk_var(euf::enode* n) override; + void apply_sort_cnstr(euf::enode * n, sort * s) override; + }; + +} diff --git a/src/sat/smt/polysat_substitution.h b/src/sat/smt/polysat_substitution.h new file mode 100644 index 000000000..a30c6b710 --- /dev/null +++ b/src/sat/smt/polysat_substitution.h @@ -0,0 +1,212 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat substitution + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once +#include "sat/smt/polysat_types.h" + +namespace polysat { + + using assignment_item_t = std::pair; + + class substitution_iterator { + pdd m_current; + substitution_iterator(pdd current) : m_current(std::move(current)) {} + friend class substitution; + + public: + using value_type = assignment_item_t; + using difference_type = std::ptrdiff_t; + using pointer = value_type const*; + using reference = value_type const&; + using iterator_category = std::input_iterator_tag; + + substitution_iterator& operator++() { + SASSERT(!m_current.is_val()); + m_current = m_current.hi(); + return *this; + } + + value_type operator*() const { + SASSERT(!m_current.is_val()); + return { m_current.var(), m_current.lo().val() }; + } + + bool operator==(substitution_iterator const& other) const { return m_current == other.m_current; } + bool operator!=(substitution_iterator const& other) const { return !operator==(other); } + }; + + /** Substitution for a single bit width. */ + class substitution { + pdd m_subst; + + substitution(pdd p); + + public: + substitution(dd::pdd_manager& m); + [[nodiscard]] substitution add(pvar var, rational const& value) const; + [[nodiscard]] pdd apply_to(pdd const& p) const; + + [[nodiscard]] bool contains(pvar var) const; + [[nodiscard]] bool value(pvar var, rational& out_value) const; + + [[nodiscard]] bool empty() const { return m_subst.is_one(); } + + pdd const& to_pdd() const { return m_subst; } + unsigned bit_width() const { return to_pdd().power_of_2(); } + + bool operator==(substitution const& other) const { return &m_subst.manager() == &other.m_subst.manager() && m_subst == other.m_subst; } + bool operator!=(substitution const& other) const { return !operator==(other); } + + std::ostream& display(std::ostream& out) const; + + using const_iterator = substitution_iterator; + const_iterator begin() const { return {m_subst}; } + const_iterator end() const { return {m_subst.manager().one()}; } + }; + + /** Full variable assignment, may include variables of varying bit widths. */ + class assignment { + vector m_pairs; + mutable scoped_ptr_vector m_subst; + vector m_subst_trail; + + substitution& subst(unsigned sz); + solver& s() const { return *m_solver; } + public: + assignment(solver& s); + // prevent implicit copy, use clone() if you do need a copy + assignment(assignment const&) = delete; + assignment& operator=(assignment const&) = delete; + assignment(assignment&&) = default; + assignment& operator=(assignment&&) = default; + assignment clone() const; + + void push(pvar var, rational const& value); + void pop(); + + pdd apply_to(pdd const& p) const; + + bool contains(pvar var) const; + bool value(pvar var, rational& out_value) const; + rational value(pvar var) const { rational val; VERIFY(value(var, val)); return val; } + bool empty() const { return pairs().empty(); } + substitution const& subst(unsigned sz) const; + vector const& pairs() const { return m_pairs; } + using const_iterator = decltype(m_pairs)::const_iterator; + const_iterator begin() const { return pairs().begin(); } + const_iterator end() const { return pairs().end(); } + + std::ostream& display(std::ostream& out) const; + }; + + inline std::ostream& operator<<(std::ostream& out, substitution const& sub) { return sub.display(out); } + + inline std::ostream& operator<<(std::ostream& out, assignment const& a) { return a.display(out); } +} + +namespace polysat { + + enum class search_item_k + { + assignment, + boolean, + }; + + class search_item { + search_item_k m_kind; + union { + pvar m_var; + sat::literal m_lit; + }; + bool m_resolved = false; // when marked as resolved it is no longer valid to reduce the conflict state + + search_item(pvar var): m_kind(search_item_k::assignment), m_var(var) {} + search_item(sat::literal lit): m_kind(search_item_k::boolean), m_lit(lit) {} + public: + static search_item assignment(pvar var) { return search_item(var); } + static search_item boolean(sat::literal lit) { return search_item(lit); } + bool is_assignment() const { return m_kind == search_item_k::assignment; } + bool is_boolean() const { return m_kind == search_item_k::boolean; } + bool is_resolved() const { return m_resolved; } + search_item_k kind() const { return m_kind; } + pvar var() const { SASSERT(is_assignment()); return m_var; } + sat::literal lit() const { SASSERT(is_boolean()); return m_lit; } + void set_resolved() { m_resolved = true; } + }; + + class search_state { + solver& s; + + vector m_items; + assignment m_assignment; + + // store index into m_items + unsigned_vector m_pvar_to_idx; + unsigned_vector m_bool_to_idx; + + bool value(pvar v, rational& r) const; + + public: + search_state(solver& s): s(s), m_assignment(s) {} + unsigned size() const { return m_items.size(); } + search_item const& back() const { return m_items.back(); } + search_item const& operator[](unsigned i) const { return m_items[i]; } + + assignment const& get_assignment() const { return m_assignment; } + substitution const& subst(unsigned sz) const { return m_assignment.subst(sz); } + + // TODO: implement the following method if we actually need the assignments without resolved items already during conflict resolution + // (no separate trail needed, just a second m_subst and an index into the trail, I think) + // (update on set_resolved? might be one iteration too early, looking at the old solver::resolve_conflict loop) + substitution const& unresolved_assignment(unsigned sz) const; + + void push_assignment(pvar v, rational const& r); + void push_boolean(sat::literal lit); + void pop(); + + unsigned get_pvar_index(pvar v) const; + unsigned get_bool_index(sat::bool_var var) const; + unsigned get_bool_index(sat::literal lit) const { return get_bool_index(lit.var()); } + + void set_resolved(unsigned i) { m_items[i].set_resolved(); } + + using const_iterator = decltype(m_items)::const_iterator; + const_iterator begin() const { return m_items.begin(); } + const_iterator end() const { return m_items.end(); } + + std::ostream& display(std::ostream& out) const; + std::ostream& display(search_item const& item, std::ostream& out) const; + std::ostream& display_verbose(std::ostream& out) const; + std::ostream& display_verbose(search_item const& item, std::ostream& out) const; + }; + + struct search_state_pp { + search_state const& s; + bool verbose; + search_state_pp(search_state const& s, bool verbose = false) : s(s), verbose(verbose) {} + }; + + struct search_item_pp { + search_state const& s; + search_item const& i; + bool verbose; + search_item_pp(search_state const& s, search_item const& i, bool verbose = false) : s(s), i(i), verbose(verbose) {} + }; + + inline std::ostream& operator<<(std::ostream& out, search_state const& s) { return s.display(out); } + + inline std::ostream& operator<<(std::ostream& out, search_state_pp const& p) { return p.verbose ? p.s.display_verbose(out) : p.s.display(out); } + + inline std::ostream& operator<<(std::ostream& out, search_item_pp const& p) { return p.verbose ? p.s.display_verbose(p.i, out) : p.s.display(p.i, out); } + +} diff --git a/src/sat/smt/polysat_types.h b/src/sat/smt/polysat_types.h new file mode 100644 index 000000000..4296a8247 --- /dev/null +++ b/src/sat/smt/polysat_types.h @@ -0,0 +1,45 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once + +#include "math/dd/dd_pdd.h" +#include "util/sat_literal.h" +#include "util/dependency.h" + +namespace polysat { + + using pdd = dd::pdd; + using pvar = unsigned; + + + class dependency { + unsigned m_index; + unsigned m_level; + public: + dependency(sat::literal lit, unsigned level) : m_index(2 * lit.index()), m_level(level) {} + dependency(unsigned var_idx, unsigned level) : m_index(1 + 2 * var_idx), m_level(level) {} + bool is_literal() const { return m_index % 2 == 0; } + sat::literal literal() const { SASSERT(is_literal()); return sat::to_literal(m_index / 2); } + unsigned index() const { SASSERT(!is_literal()); return (m_index - 1) / 2; } + unsigned level() const { return m_level; } + }; + + using stacked_dependency = stacked_dependency_manager::dependency; + + inline std::ostream& operator<<(std::ostream& out, dependency d) { + if (d.is_literal()) + return out << d.literal() << "@" << d.level(); + else + return out << "v" << d.index() << "@" << d.level(); + } + + using dependency_vector = vector; + +} diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat_viable.h new file mode 100644 index 000000000..def069652 --- /dev/null +++ b/src/sat/smt/polysat_viable.h @@ -0,0 +1,55 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + maintain viable domains + It uses the interval extraction functions from forbidden intervals. + An empty viable set corresponds directly to a conflict that does not rely on + the non-viable variable. + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once + +#include "util/rational.h" +#include "sat/smt/polysat_types.h" + +namespace polysat { + + enum class find_t { + empty, + singleton, + multiple, + resource_out, + }; + + class core; + + class viable { + core& c; + public: + viable(core& c) : c(c) {} + + /** + * Find a next viable value for variable. + */ + find_t find_viable(pvar v, rational& out_val) { throw default_exception("nyi"); } + + /* + * Explain why the current variable is not viable or signleton. + */ + dependency_vector explain() { throw default_exception("nyi"); } + + /* + * Register constraint at index 'idx' as unitary in v. + */ + void add_unitary(pvar v, unsigned idx) { throw default_exception("nyi"); } + + }; + +} diff --git a/src/util/var_queue.h b/src/util/var_queue.h index 9807e5ac2..0af4de3b8 100644 --- a/src/util/var_queue.h +++ b/src/util/var_queue.h @@ -68,6 +68,8 @@ public: void reset() { m_queue.reset(); } + + bool contains(var v) const { return m_queue.contains(v); } bool empty() const { return m_queue.empty(); } From 9df89e1640ea7a8bd8af218554246761c7955f16 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 7 Dec 2023 16:03:30 -0800 Subject: [PATCH 02/89] tidy --- src/sat/smt/polysat_core.cpp | 2 +- src/sat/smt/polysat_core.h | 2 -- src/sat/smt/polysat_internalize.cpp | 4 ++-- src/sat/smt/polysat_solver.cpp | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index 27d6ee731..931a92992 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -225,7 +225,7 @@ namespace polysat { // constraint is unitary, add to viable set if (vars.size() >= 2 && is_assigned(vars[0]) && !is_assigned(vars[1])) { - // m_viable.add_unitary(vars[1], idx); + m_viable.add_unitary(vars[1], idx); } } m_watch[v].shrink(j); diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h index 7fdf8c88c..075840a48 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat_core.h @@ -59,8 +59,6 @@ namespace polysat { var_queue m_var_queue; // priority queue of variables to assign vector m_watch; // watch lists for variables for constraints on m_prop_queue where they occur - vector m_subst; // substitution, one for each size. - // values to split on rational m_value; pvar m_var = 0; diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 80c2fb19b..dd761fdc3 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -3,11 +3,11 @@ Copyright (c) 2022 Microsoft Corporation Module Name: - polysat_model.cpp + polysat_internalize.cpp Abstract: - PolySAT model generation + PolySAT internalize Author: diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index cee129327..58b2c97d1 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -3,7 +3,7 @@ Copyright (c 2022 Microsoft Corporation Module Name: - polysat_internalize.cpp + polysat_solver.cpp Abstract: From bb03f1f1ec424f36478d7793c6de3bb816166265 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 7 Dec 2023 19:43:08 -0800 Subject: [PATCH 03/89] allow propagation on equalities and literals that are not assigned. --- src/sat/smt/polysat_core.cpp | 102 ++++++++++++++++------------ src/sat/smt/polysat_core.h | 26 +++++-- src/sat/smt/polysat_internalize.cpp | 11 +-- src/sat/smt/polysat_solver.cpp | 30 ++++++-- src/sat/smt/polysat_solver.h | 5 +- 5 files changed, 114 insertions(+), 60 deletions(-) diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index 931a92992..b97a223f7 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -66,16 +66,16 @@ namespace polysat { class core::mk_add_watch : public trail { core& c; - unsigned m_idx; public: - mk_add_watch(core& c, unsigned idx) : c(c), m_idx(idx) {} + mk_add_watch(core& c) : c(c) {} void undo() override { - auto& sc = c.m_prop_queue[m_idx].first; + auto& [sc, lit] = c.m_constraint_trail.back(); auto& vars = sc.vars(); if (vars.size() > 0) c.m_watch[vars[0]].pop_back(); if (vars.size() > 1) c.m_watch[vars[1]].pop_back(); + c.m_constraint_trail.pop_back(); } }; @@ -123,6 +123,22 @@ namespace polysat { m_var_queue.del_var_eh(v); } + unsigned core::register_constraint(signed_constraint& sc, solver_assertion as) { + unsigned idx = m_constraint_trail.size(); + m_constraint_trail.push_back({ sc, as }); + auto& vars = sc.vars(); + unsigned i = 0, j = 0, sz = vars.size(); + for (; i < sz && j < 2; ++i) + if (!is_assigned(vars[i])) + std::swap(vars[i], vars[j++]); + if (vars.size() > 0) + add_watch(idx, vars[0]); + if (vars.size() > 1) + add_watch(idx, vars[1]); + s.ctx.push(mk_add_watch(*this)); + return idx; + } + // case split on unassigned variables until all are assigned values. // create equality literal for unassigned variable. // return new_eq if there is a new literal. @@ -141,6 +157,7 @@ namespace polysat { s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); return sat::check_result::CR_CONTINUE; case find_t::multiple: + s.add_eq_literal(m_var, m_value); return sat::check_result::CR_CONTINUE; case find_t::resource_out: return sat::check_result::CR_GIVEUP; @@ -155,33 +172,17 @@ namespace polysat { return false; s.ctx.push(value_trail(m_qhead)); for (; m_qhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_qhead) - propagate_constraint(m_qhead, m_prop_queue[m_qhead]); + propagate_constraint(m_prop_queue[m_qhead]); s.ctx.push(value_trail(m_vqhead)); for (; m_vqhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_vqhead) - propagate_value(m_vqhead, m_prop_queue[m_vqhead]); + propagate_value(m_prop_queue[m_vqhead]); return true; } - void core::propagate_constraint(unsigned idx, dependent_constraint& dc) { - auto [sc, dep] = dc; - if (sc.is_eq(m_var, m_value)) { + void core::propagate_constraint(prop_item& dc) { + auto [idx, sc, dep] = dc; + if (sc.is_eq(m_var, m_value)) propagate_assignment(m_var, m_value, dep); - return; - } - add_watch(idx, sc); - } - - void core::add_watch(unsigned idx, signed_constraint& sc) { - auto& vars = sc.vars(); - unsigned i = 0, j = 0, sz = vars.size(); - for (; i < sz && j < 2; ++i) - if (!is_assigned(vars[i])) - std::swap(vars[i], vars[j++]); - if (vars.size() > 0) - add_watch(idx, vars[0]); - if (vars.size() > 1) - add_watch(idx, vars[1]); - s.ctx.push(mk_add_watch(*this, idx)); } void core::add_watch(unsigned idx, unsigned var) { @@ -205,7 +206,7 @@ namespace polysat { // for entries where there is only one free variable left add to viable set unsigned j = 0; for (auto idx : m_watch[v]) { - auto [sc, dep] = m_prop_queue[idx]; + auto [sc, as] = m_constraint_trail[idx]; auto& vars = sc.vars(); if (vars[0] != v) std::swap(vars[0], vars[1]); @@ -219,39 +220,51 @@ namespace polysat { break; } } - if (!swapped) { - m_watch[v][j++] = idx; - } - // constraint is unitary, add to viable set - if (vars.size() >= 2 && is_assigned(vars[0]) && !is_assigned(vars[1])) { - m_viable.add_unitary(vars[1], idx); - } + SASSERT(!swapped || vars.size() <= 1 || (!is_assigned(vars[0]) && !is_assigned(vars[1]))); + if (swapped) + continue; + m_watch[v][j++] = idx; + if (vars.size() <= 1) + continue; + auto v1 = vars[1]; + if (is_assigned(v1)) + continue; + SASSERT(is_assigned(vars[0]) && vars.size() >= 2); + // detect unitary, add to viable, detect conflict? + m_viable.add_unitary(v1, idx); } m_watch[v].shrink(j); } - void core::propagate_value(unsigned idx, dependent_constraint const& dc) { - auto [sc, dep] = dc; + void core::propagate_value(prop_item const& dc) { + auto [idx, sc, dep] = dc; // check if sc evaluates to false switch (eval(sc)) { case l_true: return; case l_false: - m_unsat_core = explain_eval(dc); + m_unsat_core = explain_eval({ sc, dep }); propagate_unsat_core(); return; default: break; } - // if sc is v == value, then check the watch list for v if they evaluate to false. + // if sc is v == value, then check the watch list for v to propagate truth assignments if (sc.is_eq(m_var, m_value)) { for (auto idx : m_watch[m_var]) { - auto [sc, dep] = m_prop_queue[idx]; - if (eval(sc) == l_false) { - m_unsat_core = explain_eval({ sc, dep }); - propagate_unsat_core(); - return; + auto [sc, as] = m_constraint_trail[idx]; + switch (eval(sc)) { + case l_false: + m_unsat_core = explain_eval({ sc, nullptr }); + s.propagate(as, true, m_unsat_core); + break; + case l_true: + m_unsat_core = explain_eval({ sc, nullptr }); + s.propagate(as, false, m_unsat_core); + break; + default: + break; } } } @@ -259,15 +272,14 @@ namespace polysat { throw default_exception("nyi"); } - bool core::propagate_unsat_core() { + void core::propagate_unsat_core() { // default is to use unsat core: s.set_conflict(m_unsat_core); // if core is based on viable, use s.set_lemma(); - throw default_exception("nyi"); } - void core::assign_eh(signed_constraint const& sc, dependency const& dep) { - m_prop_queue.push_back({ sc, m_dep.mk_leaf(dep) }); + void core::assign_eh(unsigned index, signed_constraint const& sc, dependency const& dep) { + m_prop_queue.push_back({ index, sc, m_dep.mk_leaf(dep) }); s.ctx.push(push_back_vector(m_prop_queue)); } diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h index 075840a48..0c173da77 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat_core.h @@ -30,12 +30,25 @@ namespace polysat { class core; class solver; + struct solver_assertion { + unsigned m_var1; + unsigned m_var2 = 0; + public: + solver_assertion(sat::literal lit) : m_var1(2*lit.index()) {} + solver_assertion(unsigned v1, unsigned v2) : m_var1(1 + 2*v1), m_var2(v2) {} + bool is_literal() const { return m_var1 % 2 == 0; } + sat::literal get_literal() const { SASSERT(is_literal()); return sat::to_literal(m_var1 / 2); } + unsigned var1() const { SASSERT(!is_literal()); return (m_var1 - 1) / 2; } + unsigned var2() const { SASSERT(!is_literal()); return m_var2; } + }; + class core { class mk_add_var; class mk_dqueue_var; class mk_assign_var; class mk_add_watch; typedef svector> activity; + typedef std::tuple prop_item; friend class viable; friend class constraints; friend class assignment; @@ -45,7 +58,8 @@ namespace polysat { constraints m_constraints; assignment m_assignment; unsigned m_qhead = 0, m_vqhead = 0; - svector m_prop_queue; + svector m_prop_queue; + svector> m_constraint_trail; // stacked_dependency_manager m_dep; mutable scoped_ptr_vector m_pdd; dependency_vector m_unsat_core; @@ -69,12 +83,11 @@ namespace polysat { void del_var(); bool is_assigned(pvar v) { return nullptr != m_justification[v]; } - void propagate_constraint(unsigned idx, dependent_constraint& dc); - void propagate_value(unsigned idx, dependent_constraint const& dc); + void propagate_constraint(prop_item& dc); + void propagate_value(prop_item const& dc); void propagate_assignment(pvar v, rational const& value, stacked_dependency* dep); - bool propagate_unsat_core(); + void propagate_unsat_core(); - void add_watch(unsigned idx, signed_constraint& sc); void add_watch(unsigned idx, unsigned var); lbool eval(signed_constraint sc) { throw default_exception("nyi"); } @@ -85,8 +98,9 @@ namespace polysat { sat::check_result check(); + unsigned register_constraint(signed_constraint& sc, solver_assertion sa); bool propagate(); - void assign_eh(signed_constraint const& sc, dependency const& dep); + void assign_eh(unsigned idx, signed_constraint const& sc, dependency const& dep); expr_ref constraint2expr(signed_constraint const& sc) const { throw default_exception("nyi"); } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index dd761fdc3..af53750fc 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -169,12 +169,15 @@ namespace polysat { } }; - solver::atom* solver::mk_atom(sat::bool_var bv) { + solver::atom* solver::mk_atom(sat::literal lit, signed_constraint& sc) { + auto bv = lit.var(); atom* a = get_bv2a(bv); if (a) return a; a = new (get_region()) atom(bv); insert_bv2a(bv, a); + a->m_sc = sc; + a->m_index = m_core.register_constraint(sc, lit); ctx.push(mk_atom_trail(bv, *this)); return a; } @@ -184,7 +187,7 @@ namespace polysat { auto q = expr2pdd(e->get_arg(1)); auto sc = ~fn(p, q); sat::literal lit = expr2literal(e); - mk_atom(lit.var())->m_sc = sc; + auto* a = mk_atom(lit, sc); } void solver::internalize_div_rem_i(app* e, bool is_div) { @@ -277,6 +280,7 @@ namespace polysat { template void solver::internalize_le(app* e) { + SASSERT(e->get_num_args() == 2); auto p = expr2pdd(e->get_arg(0)); auto q = expr2pdd(e->get_arg(1)); if (Rev) @@ -286,8 +290,7 @@ namespace polysat { sc = ~sc; sat::literal lit = expr2literal(e); - atom* a = mk_atom(lit.var()); - a->m_sc = sc; + atom* a = mk_atom(lit, sc); } void solver::internalize_bit2bool(atom* a, expr* e, unsigned idx) { diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 58b2c97d1..e3aff6e5f 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -66,7 +66,7 @@ namespace polysat { auto sc = a->m_sc; if (l.sign()) sc = ~sc; - m_core.assign_eh(sc, dependency(l, s().lvl(l))); + m_core.assign_eh(a->m_index, sc, dependency(l, s().lvl(l))); } void solver::set_conflict(dependency_vector const& core) { @@ -151,8 +151,10 @@ namespace polysat { pdd q = var2pdd(v2); auto sc = m_core.eq(p, q); m_var_eqs.setx(m_var_eqs_head, std::make_pair(v1, v2), std::make_pair(v1, v2)); - ctx.push(value_trail(m_var_eqs_head)); - m_core.assign_eh(sc, dependency(m_var_eqs_head, s().scope_lvl())); + ctx.push(value_trail(m_var_eqs_head)); + unsigned index = 0; +// unsigned index = m_core.register_constraint(sc); + m_core.assign_eh(index, sc, dependency(m_var_eqs_head, s().scope_lvl())); m_var_eqs_head++; } @@ -162,8 +164,9 @@ namespace polysat { pdd q = var2pdd(v2); auto sc = ~m_core.eq(p, q); sat::literal neq = ~expr2literal(ne.eq()); + auto index = m_core.register_constraint(sc, neq); TRACE("bv", tout << neq << " := " << s().value(neq) << " @" << s().scope_lvl() << "\n"); - m_core.assign_eh(sc, dependency(neq, s().lvl(neq))); + m_core.assign_eh(index, sc, dependency(neq, s().lvl(neq))); } // Core uses the propagate callback to add unit propagations to the trail. @@ -176,6 +179,25 @@ namespace polysat { ctx.propagate(lit, ex); } + void solver::propagate(solver_assertion as, bool sign, dependency_vector const& deps) { + auto [core, eqs] = explain_deps(deps); + if (as.is_literal()) { + auto lit = as.get_literal(); + if (sign) + lit.neg(); + auto ex = euf::th_explain::propagate(*this, core, eqs, lit, nullptr); + ctx.propagate(lit, ex); + } + else if (sign) { + // equalities are always asserted so a negative propagation is a conflict. + auto n1 = var2enode(as.var1()); + auto n2 = var2enode(as.var2()); + eqs.push_back({ n1, n2 }); + auto ex = euf::th_explain::conflict(*this, core, eqs, nullptr); + ctx.set_conflict(ex); + } + } + void solver::add_lemma(vector const& lemma) { sat::literal_vector lits; for (auto sc : lemma) diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 5d9cd19a3..7940a7223 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -42,6 +42,7 @@ namespace polysat { struct atom { bool_var m_bv; + unsigned m_index = 0; signed_constraint m_sc; atom(bool_var b) :m_bv(b) {} ~atom() { } @@ -91,7 +92,7 @@ namespace polysat { void erase_bv2a(bool_var bv) { m_bool_var2atom[bv] = nullptr; } atom* get_bv2a(bool_var bv) const { return m_bool_var2atom.get(bv, nullptr); } class mk_atom_trail; - atom* mk_atom(sat::bool_var bv); + atom* mk_atom(sat::literal lit, signed_constraint& sc); void set_bit_eh(theory_var v, literal l, unsigned idx); void init_bits(expr* e, expr_ref_vector const & bits); void mk_bits(theory_var v); @@ -133,6 +134,8 @@ namespace polysat { void set_conflict(dependency_vector const& core); void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core); void propagate(signed_constraint sc, dependency_vector const& deps); + void propagate(solver_assertion as, bool sign, dependency_vector const& deps); + void add_lemma(vector const& lemma); std::pair explain_deps(dependency_vector const& deps); From f3fa6fdb84fd4b999b2e20eadf668b59065ae5c4 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 7 Dec 2023 19:47:55 -0800 Subject: [PATCH 04/89] n/a --- src/sat/smt/polysat_core.cpp | 4 +--- src/sat/smt/polysat_solver.cpp | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index b97a223f7..b7ff2e740 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -267,9 +267,7 @@ namespace polysat { break; } } - } - - throw default_exception("nyi"); + } } void core::propagate_unsat_core() { diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index e3aff6e5f..35b177754 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -152,8 +152,7 @@ namespace polysat { auto sc = m_core.eq(p, q); m_var_eqs.setx(m_var_eqs_head, std::make_pair(v1, v2), std::make_pair(v1, v2)); ctx.push(value_trail(m_var_eqs_head)); - unsigned index = 0; -// unsigned index = m_core.register_constraint(sc); + unsigned index = m_core.register_constraint(sc, solver_assertion(v1, v2)); m_core.assign_eh(index, sc, dependency(m_var_eqs_head, s().scope_lvl())); m_var_eqs_head++; } From 8207732d27ab44eb13abe721f499fd86e5511aaa Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 7 Dec 2023 20:53:04 -0800 Subject: [PATCH 05/89] n/a --- src/sat/smt/polysat_constraints.h | 16 ++++++++++++- src/sat/smt/polysat_core.h | 2 -- src/sat/smt/polysat_solver.cpp | 37 ++++++++++++++++++++++++++++--- src/sat/smt/polysat_solver.h | 3 +++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/src/sat/smt/polysat_constraints.h b/src/sat/smt/polysat_constraints.h index 24c7f9a11..156776517 100644 --- a/src/sat/smt/polysat_constraints.h +++ b/src/sat/smt/polysat_constraints.h @@ -39,6 +39,16 @@ namespace polysat { pdd m_lhs, m_rhs; public: ule_constraint(pdd const& lhs, pdd const& rhs) : m_lhs(lhs), m_rhs(rhs) {} + pdd const& lhs() const { return m_lhs; } + pdd const& rhs() const { return m_rhs; } + }; + + class umul_ovfl_constraint : public constraint { + pdd m_lhs, m_rhs; + public: + umul_ovfl_constraint(pdd const& lhs, pdd const& rhs) : m_lhs(lhs), m_rhs(rhs) {} + pdd const& lhs() const { return m_lhs; } + pdd const& rhs() const { return m_rhs; } }; class signed_constraint { @@ -54,8 +64,12 @@ namespace polysat { unsigned_vector const& vars() const { return m_constraint->vars(); } unsigned var(unsigned idx) const { return m_constraint->var(idx); } bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + ckind_t op() const { return m_op; } bool is_ule() const { return m_op == ule_t; } - ule_constraint& to_ule() { return *reinterpret_cast(m_constraint); } + bool is_umul_ovfl() const { return m_op == umul_ovfl_t; } + bool is_smul_fl() const { return m_op == smul_fl_t; } + ule_constraint const& to_ule() const { return *reinterpret_cast(m_constraint); } + umul_ovfl_constraint const& to_umul_ovfl() const { return *reinterpret_cast(m_constraint); } bool is_eq(pvar& v, rational& val) { throw default_exception("nyi"); } }; diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h index 0c173da77..b3e7dead5 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat_core.h @@ -102,8 +102,6 @@ namespace polysat { bool propagate(); void assign_eh(unsigned idx, signed_constraint const& sc, dependency const& dep); - expr_ref constraint2expr(signed_constraint const& sc) const { throw default_exception("nyi"); } - pdd value(rational const& v, unsigned sz); signed_constraint eq(pdd const& p) { return m_constraints.eq(p); } diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 35b177754..84e06636e 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -113,7 +113,7 @@ namespace polysat { m_lemma_level = level; m_lemma.reset(); for (auto sc : lemma) - m_lemma.push_back(m_core.constraint2expr(sc)); + m_lemma.push_back(constraint2expr(sc)); ctx.set_conflict(ex); } @@ -172,7 +172,7 @@ namespace polysat { // The polysat::solver takes care of translating signed constraints into expressions, which translate into literals. // Everything goes over expressions/literals. polysat::core is not responsible for replaying expressions. void solver::propagate(signed_constraint sc, dependency_vector const& deps) { - sat::literal lit = ctx.mk_literal(m_core.constraint2expr(sc)); + sat::literal lit = ctx.mk_literal(constraint2expr(sc)); auto [core, eqs] = explain_deps(deps); auto ex = euf::th_explain::propagate(*this, core, eqs, lit, nullptr); ctx.propagate(lit, ex); @@ -200,7 +200,7 @@ namespace polysat { void solver::add_lemma(vector const& lemma) { sat::literal_vector lits; for (auto sc : lemma) - lits.push_back(ctx.mk_literal(m_core.constraint2expr(sc))); + lits.push_back(ctx.mk_literal(constraint2expr(sc))); s().add_clause(lits.size(), lits.data(), sat::status::th(true, get_id(), nullptr)); } @@ -209,4 +209,35 @@ namespace polysat { ctx.get_th_antecedents(l, jst, r, probing); } + expr_ref solver::constraint2expr(signed_constraint const& sc) { + switch (sc.op()) { + case ckind_t::ule_t: { + auto l = pdd2expr(sc.to_ule().lhs()); + auto h = pdd2expr(sc.to_ule().rhs()); + return expr_ref(bv.mk_ule(l, h), m); + } + case ckind_t::umul_ovfl_t: { + auto l = pdd2expr(sc.to_umul_ovfl().lhs()); + auto r = pdd2expr(sc.to_umul_ovfl().rhs()); + return expr_ref(bv.mk_bvumul_ovfl(l, r), m); + } + case ckind_t::smul_fl_t: + case ckind_t::op_t: + break; + } + throw default_exception("nyi"); + } + + expr_ref solver::pdd2expr(pdd const& p) { + if (p.is_val()) { + expr* n = bv.mk_numeral(p.val(), p.power_of_2()); + return expr_ref(n, m); + } + auto lo = pdd2expr(p.lo()); + auto hi = pdd2expr(p.hi()); + auto v = var2enode(m_pddvar2var[p.var()]); + hi = bv.mk_bv_mul(v->get_expr(), hi); + return expr_ref(bv.mk_bv_add(lo, hi), m); + } + } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 7940a7223..6f8e7a640 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -139,6 +139,9 @@ namespace polysat { void add_lemma(vector const& lemma); std::pair explain_deps(dependency_vector const& deps); + + expr_ref constraint2expr(signed_constraint const& sc); + expr_ref pdd2expr(pdd const& p); public: solver(euf::solver& ctx, theory_id id); From fda5f29e70f24359d5df030bff8933f5bdbbb661 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 04:49:38 -0800 Subject: [PATCH 06/89] tidy' Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat_constraints.h | 2 -- src/sat/smt/polysat_core.cpp | 49 +++++++++++++++++------------ src/sat/smt/polysat_core.h | 41 ++++++++++-------------- src/sat/smt/polysat_internalize.cpp | 30 +++++++----------- src/sat/smt/polysat_solver.cpp | 35 +++++++++++---------- src/sat/smt/polysat_solver.h | 10 +++--- src/sat/smt/polysat_types.h | 6 ++-- 7 files changed, 83 insertions(+), 90 deletions(-) diff --git a/src/sat/smt/polysat_constraints.h b/src/sat/smt/polysat_constraints.h index 156776517..b8068c68a 100644 --- a/src/sat/smt/polysat_constraints.h +++ b/src/sat/smt/polysat_constraints.h @@ -73,8 +73,6 @@ namespace polysat { bool is_eq(pvar& v, rational& val) { throw default_exception("nyi"); } }; - using dependent_constraint = std::pair; - class constraints { trail_stack& m_trail; public: diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index b7ff2e740..f3950adf2 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -40,7 +40,7 @@ namespace polysat { public: mk_assign_var(pvar v, core& c) : m_var(v), c(c) {} void undo() { - c.m_justification[m_var] = nullptr; + c.m_justification[m_var] = dependency::null_dependency(); c.m_assignment.pop(); } }; @@ -84,7 +84,6 @@ namespace polysat { m_viable(*this), m_constraints(s.get_trail_stack()), m_assignment(*this), - m_dep(s.get_region()), m_var_queue(m_activity) {} @@ -107,7 +106,7 @@ namespace polysat { unsigned v = m_vars.size(); m_vars.push_back(sz2pdd(sz).mk_var(v)); m_activity.push_back({ sz, 0 }); - m_justification.push_back(nullptr); + m_justification.push_back(dependency::null_dependency()); m_watch.push_back({}); m_var_queue.mk_var_eh(v); s.ctx.push(mk_add_var(*this)); @@ -123,9 +122,9 @@ namespace polysat { m_var_queue.del_var_eh(v); } - unsigned core::register_constraint(signed_constraint& sc, solver_assertion as) { + unsigned core::register_constraint(signed_constraint& sc, dependency d) { unsigned idx = m_constraint_trail.size(); - m_constraint_trail.push_back({ sc, as }); + m_constraint_trail.push_back({ sc, d }); auto& vars = sc.vars(); unsigned i = 0, j = 0, sz = vars.size(); for (; i < sz && j < 2; ++i) @@ -172,15 +171,24 @@ namespace polysat { return false; s.ctx.push(value_trail(m_qhead)); for (; m_qhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_qhead) - propagate_constraint(m_prop_queue[m_qhead]); + propagate_assignment(m_prop_queue[m_qhead]); s.ctx.push(value_trail(m_vqhead)); for (; m_vqhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_vqhead) propagate_value(m_prop_queue[m_vqhead]); return true; } - void core::propagate_constraint(prop_item& dc) { - auto [idx, sc, dep] = dc; + signed_constraint core::get_constraint(unsigned idx, bool sign) { + auto sc = m_constraint_trail[idx].sc; + if (sign) + sc = ~sc; + return sc; + } + + + void core::propagate_assignment(prop_item& dc) { + auto [idx, sign, dep] = dc; + auto sc = get_constraint(idx, sign); if (sc.is_eq(m_var, m_value)) propagate_assignment(m_var, m_value, dep); } @@ -189,7 +197,7 @@ namespace polysat { m_watch[var].push_back(idx); } - void core::propagate_assignment(pvar v, rational const& value, stacked_dependency* dep) { + void core::propagate_assignment(pvar v, rational const& value, dependency dep) { if (is_assigned(v)) return; if (m_var_queue.contains(v)) { @@ -238,13 +246,15 @@ namespace polysat { } void core::propagate_value(prop_item const& dc) { - auto [idx, sc, dep] = dc; + auto [idx, sign, dep] = dc; + auto sc = get_constraint(idx, sign); // check if sc evaluates to false switch (eval(sc)) { case l_true: - return; + break; case l_false: - m_unsat_core = explain_eval({ sc, dep }); + m_unsat_core = explain_eval(sc); + m_unsat_core.push_back(dep); propagate_unsat_core(); return; default: @@ -253,15 +263,13 @@ namespace polysat { // if sc is v == value, then check the watch list for v to propagate truth assignments if (sc.is_eq(m_var, m_value)) { for (auto idx : m_watch[m_var]) { - auto [sc, as] = m_constraint_trail[idx]; + auto [sc, d] = m_constraint_trail[idx]; switch (eval(sc)) { case l_false: - m_unsat_core = explain_eval({ sc, nullptr }); - s.propagate(as, true, m_unsat_core); + s.propagate(d, true, explain_eval(sc)); break; case l_true: - m_unsat_core = explain_eval({ sc, nullptr }); - s.propagate(as, false, m_unsat_core); + s.propagate(d, false, explain_eval(sc)); break; default: break; @@ -272,12 +280,13 @@ namespace polysat { void core::propagate_unsat_core() { // default is to use unsat core: - s.set_conflict(m_unsat_core); // if core is based on viable, use s.set_lemma(); + + s.set_conflict(m_unsat_core); } - void core::assign_eh(unsigned index, signed_constraint const& sc, dependency const& dep) { - m_prop_queue.push_back({ index, sc, m_dep.mk_leaf(dep) }); + void core::assign_eh(unsigned index, bool sign, dependency const& dep) { + m_prop_queue.push_back({ index, sign, dep }); s.ctx.push(push_back_vector(m_prop_queue)); } diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h index b3e7dead5..802108ffc 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat_core.h @@ -30,45 +30,36 @@ namespace polysat { class core; class solver; - struct solver_assertion { - unsigned m_var1; - unsigned m_var2 = 0; - public: - solver_assertion(sat::literal lit) : m_var1(2*lit.index()) {} - solver_assertion(unsigned v1, unsigned v2) : m_var1(1 + 2*v1), m_var2(v2) {} - bool is_literal() const { return m_var1 % 2 == 0; } - sat::literal get_literal() const { SASSERT(is_literal()); return sat::to_literal(m_var1 / 2); } - unsigned var1() const { SASSERT(!is_literal()); return (m_var1 - 1) / 2; } - unsigned var2() const { SASSERT(!is_literal()); return m_var2; } - }; - class core { class mk_add_var; class mk_dqueue_var; class mk_assign_var; class mk_add_watch; typedef svector> activity; - typedef std::tuple prop_item; + typedef std::tuple prop_item; friend class viable; friend class constraints; friend class assignment; + struct constraint_info { + signed_constraint sc; + dependency d; + }; solver& s; viable m_viable; constraints m_constraints; assignment m_assignment; unsigned m_qhead = 0, m_vqhead = 0; svector m_prop_queue; - svector> m_constraint_trail; // - stacked_dependency_manager m_dep; + svector m_constraint_trail; // index of constraints mutable scoped_ptr_vector m_pdd; dependency_vector m_unsat_core; // attributes associated with variables - vector m_vars; // for each variable a pdd - vector m_values; // current value of assigned variable - ptr_vector m_justification; // justification for assignment + vector m_vars; // for each variable a pdd + vector m_values; // current value of assigned variable + svector m_justification; // justification for assignment activity m_activity; // activity of variables var_queue m_var_queue; // priority queue of variables to assign vector m_watch; // watch lists for variables for constraints on m_prop_queue where they occur @@ -82,25 +73,27 @@ namespace polysat { unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } void del_var(); - bool is_assigned(pvar v) { return nullptr != m_justification[v]; } - void propagate_constraint(prop_item& dc); + bool is_assigned(pvar v) { return !m_justification[v].is_null(); } void propagate_value(prop_item const& dc); - void propagate_assignment(pvar v, rational const& value, stacked_dependency* dep); + void propagate_assignment(prop_item& dc); + void propagate_assignment(pvar v, rational const& value, dependency dep); void propagate_unsat_core(); void add_watch(unsigned idx, unsigned var); + signed_constraint get_constraint(unsigned idx, bool sign); + lbool eval(signed_constraint sc) { throw default_exception("nyi"); } - dependency_vector explain_eval(dependent_constraint const& dc) { throw default_exception("nyi"); } + dependency_vector explain_eval(signed_constraint const& dc) { throw default_exception("nyi"); } public: core(solver& s); sat::check_result check(); - unsigned register_constraint(signed_constraint& sc, solver_assertion sa); + unsigned register_constraint(signed_constraint& sc, dependency d); bool propagate(); - void assign_eh(unsigned idx, signed_constraint const& sc, dependency const& dep); + void assign_eh(unsigned idx, bool sign, dependency const& d); pdd value(rational const& v, unsigned sz); diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index af53750fc..95979348d 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -163,23 +163,18 @@ namespace polysat { public: mk_atom_trail(sat::bool_var v, solver& th) : th(th), m_var(v) {} void undo() override { - solver::atom* a = th.get_bv2a(m_var); - a->~atom(); th.erase_bv2a(m_var); } }; - solver::atom* solver::mk_atom(sat::literal lit, signed_constraint& sc) { - auto bv = lit.var(); - atom* a = get_bv2a(bv); - if (a) - return a; - a = new (get_region()) atom(bv); + void solver::mk_atom(sat::bool_var bv, signed_constraint& sc) { + if (get_bv2a(bv)) + return; + sat::literal lit(bv, false); + auto index = m_core.register_constraint(sc, dependency(lit, 0)); + auto a = new (get_region()) atom(bv, index); insert_bv2a(bv, a); - a->m_sc = sc; - a->m_index = m_core.register_constraint(sc, lit); ctx.push(mk_atom_trail(bv, *this)); - return a; } void solver::internalize_binaryc(app* e, std::function const& fn) { @@ -187,7 +182,9 @@ namespace polysat { auto q = expr2pdd(e->get_arg(1)); auto sc = ~fn(p, q); sat::literal lit = expr2literal(e); - auto* a = mk_atom(lit, sc); + if (lit.sign()) + sc = ~sc; + mk_atom(lit.var(), sc); } void solver::internalize_div_rem_i(app* e, bool is_div) { @@ -290,12 +287,9 @@ namespace polysat { sc = ~sc; sat::literal lit = expr2literal(e); - atom* a = mk_atom(lit, sc); - } - - void solver::internalize_bit2bool(atom* a, expr* e, unsigned idx) { - pdd p = expr2pdd(e); - a->m_sc = m_core.bit(p, idx); + if (lit.sign()) + sc = ~sc; + mk_atom(lit.var(), sc); } dd::pdd solver::expr2pdd(expr* e) { diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 84e06636e..65183578d 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -58,15 +58,12 @@ namespace polysat { } void solver::asserted(literal l) { - atom* a = get_bv2a(l.var()); TRACE("bv", tout << "asserted: " << l << "\n";); + atom* a = get_bv2a(l.var()); if (!a) return; force_push(); - auto sc = a->m_sc; - if (l.sign()) - sc = ~sc; - m_core.assign_eh(a->m_index, sc, dependency(l, s().lvl(l))); + m_core.assign_eh(a->m_index, l.sign(), dependency(l, s().lvl(l))); } void solver::set_conflict(dependency_vector const& core) { @@ -150,10 +147,11 @@ namespace polysat { pdd p = var2pdd(v1); pdd q = var2pdd(v2); auto sc = m_core.eq(p, q); - m_var_eqs.setx(m_var_eqs_head, std::make_pair(v1, v2), std::make_pair(v1, v2)); + m_var_eqs.setx(m_var_eqs_head, {v1, v2}, {v1, v2}); ctx.push(value_trail(m_var_eqs_head)); - unsigned index = m_core.register_constraint(sc, solver_assertion(v1, v2)); - m_core.assign_eh(index, sc, dependency(m_var_eqs_head, s().scope_lvl())); + auto d = dependency(m_var_eqs_head, s().scope_lvl()); + unsigned index = m_core.register_constraint(sc, d); + m_core.assign_eh(index, false, d); m_var_eqs_head++; } @@ -163,9 +161,10 @@ namespace polysat { pdd q = var2pdd(v2); auto sc = ~m_core.eq(p, q); sat::literal neq = ~expr2literal(ne.eq()); - auto index = m_core.register_constraint(sc, neq); + auto d = dependency(neq, s().lvl(neq)); + auto index = m_core.register_constraint(sc, d); TRACE("bv", tout << neq << " := " << s().value(neq) << " @" << s().scope_lvl() << "\n"); - m_core.assign_eh(index, sc, dependency(neq, s().lvl(neq))); + m_core.assign_eh(index, false, d); } // Core uses the propagate callback to add unit propagations to the trail. @@ -178,19 +177,22 @@ namespace polysat { ctx.propagate(lit, ex); } - void solver::propagate(solver_assertion as, bool sign, dependency_vector const& deps) { + void solver::propagate(dependency const& d, bool sign, dependency_vector const& deps) { auto [core, eqs] = explain_deps(deps); - if (as.is_literal()) { - auto lit = as.get_literal(); + if (d.is_literal()) { + auto lit = d.literal(); if (sign) lit.neg(); + if (s().value(lit) == l_true) + return; auto ex = euf::th_explain::propagate(*this, core, eqs, lit, nullptr); ctx.propagate(lit, ex); } - else if (sign) { + else if (sign) { + auto const [v1, v2] = m_var_eqs[d.index()]; // equalities are always asserted so a negative propagation is a conflict. - auto n1 = var2enode(as.var1()); - auto n2 = var2enode(as.var2()); + auto n1 = var2enode(v1); + auto n2 = var2enode(v2); eqs.push_back({ n1, n2 }); auto ex = euf::th_explain::conflict(*this, core, eqs, nullptr); ctx.set_conflict(ex); @@ -223,6 +225,7 @@ namespace polysat { } case ckind_t::smul_fl_t: case ckind_t::op_t: + NOT_IMPLEMENTED_YET(); break; } throw default_exception("nyi"); diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 6f8e7a640..408d0756c 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -42,9 +42,8 @@ namespace polysat { struct atom { bool_var m_bv; - unsigned m_index = 0; - signed_constraint m_sc; - atom(bool_var b) :m_bv(b) {} + unsigned m_index; + atom(bool_var b, unsigned index) :m_bv(b), m_index(index) {} ~atom() { } }; @@ -92,7 +91,7 @@ namespace polysat { void erase_bv2a(bool_var bv) { m_bool_var2atom[bv] = nullptr; } atom* get_bv2a(bool_var bv) const { return m_bool_var2atom.get(bv, nullptr); } class mk_atom_trail; - atom* mk_atom(sat::literal lit, signed_constraint& sc); + void mk_atom(sat::bool_var bv, signed_constraint& sc); void set_bit_eh(theory_var v, literal l, unsigned idx); void init_bits(expr* e, expr_ref_vector const & bits); void mk_bits(theory_var v); @@ -122,7 +121,6 @@ namespace polysat { void internalize_polysat(app* a); void assert_bv2int_axiom(app * n); void assert_int2bv_axiom(app* n); - void internalize_bit2bool(atom* a, expr* e, unsigned idx); pdd expr2pdd(expr* e); pdd var2pdd(euf::theory_var v); @@ -134,7 +132,7 @@ namespace polysat { void set_conflict(dependency_vector const& core); void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core); void propagate(signed_constraint sc, dependency_vector const& deps); - void propagate(solver_assertion as, bool sign, dependency_vector const& deps); + void propagate(dependency const& d, bool sign, dependency_vector const& deps); void add_lemma(vector const& lemma); diff --git a/src/sat/smt/polysat_types.h b/src/sat/smt/polysat_types.h index 4296a8247..42f283cc7 100644 --- a/src/sat/smt/polysat_types.h +++ b/src/sat/smt/polysat_types.h @@ -11,28 +11,26 @@ Author: #include "math/dd/dd_pdd.h" #include "util/sat_literal.h" -#include "util/dependency.h" namespace polysat { using pdd = dd::pdd; using pvar = unsigned; - class dependency { unsigned m_index; unsigned m_level; public: dependency(sat::literal lit, unsigned level) : m_index(2 * lit.index()), m_level(level) {} dependency(unsigned var_idx, unsigned level) : m_index(1 + 2 * var_idx), m_level(level) {} + static dependency null_dependency() { return dependency(0, UINT_MAX); } + bool is_null() const { return m_level == UINT_MAX; } bool is_literal() const { return m_index % 2 == 0; } sat::literal literal() const { SASSERT(is_literal()); return sat::to_literal(m_index / 2); } unsigned index() const { SASSERT(!is_literal()); return (m_index - 1) / 2; } unsigned level() const { return m_level; } }; - using stacked_dependency = stacked_dependency_manager::dependency; - inline std::ostream& operator<<(std::ostream& out, dependency d) { if (d.is_literal()) return out << d.literal() << "@" << d.level(); From 237ee6b083accd5db2206dc89189f897934cfcb5 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 04:54:08 -0800 Subject: [PATCH 07/89] tidy' Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat_core.cpp | 10 ++++++++++ src/sat/smt/polysat_core.h | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index f3950adf2..73e9c5a6b 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -290,6 +290,16 @@ namespace polysat { s.ctx.push(push_back_vector(m_prop_queue)); } + dependency_vector core::explain_eval(signed_constraint const& sc) { + dependency_vector deps; + for (auto v : sc.vars()) + if (is_assigned(v)) + deps.push_back(m_justification[v]); + return deps; + } + lbool core::eval(signed_constraint const& sc) { + throw default_exception("nyi"); + } } diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h index 802108ffc..6944c39d8 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat_core.h @@ -83,8 +83,8 @@ namespace polysat { signed_constraint get_constraint(unsigned idx, bool sign); - lbool eval(signed_constraint sc) { throw default_exception("nyi"); } - dependency_vector explain_eval(signed_constraint const& dc) { throw default_exception("nyi"); } + lbool eval(signed_constraint const& sc); + dependency_vector explain_eval(signed_constraint const& sc); public: core(solver& s); From 642f1ea1f60b4f6d0212e976bce879d59b0bdd25 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 10:53:28 -0800 Subject: [PATCH 08/89] port over ule_constraint --- src/sat/smt/CMakeLists.txt | 1 + src/sat/smt/polysat_constraints.cpp | 8 +- src/sat/smt/polysat_constraints.h | 14 +- src/sat/smt/polysat_solver.cpp | 4 +- src/sat/smt/polysat_solver.h | 2 +- src/sat/smt/polysat_ule.cpp | 346 ++++++++++++++++++++++++++++ src/sat/smt/polysat_ule.h | 56 +++++ 7 files changed, 421 insertions(+), 10 deletions(-) create mode 100644 src/sat/smt/polysat_ule.cpp create mode 100644 src/sat/smt/polysat_ule.h diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 2a6fb9e66..e27849391 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -39,6 +39,7 @@ z3_add_component(sat_smt polysat_internalize.cpp polysat_model.cpp polysat_solver.cpp + polysat_ule.cpp q_clause.cpp q_ematch.cpp q_eval.cpp diff --git a/src/sat/smt/polysat_constraints.cpp b/src/sat/smt/polysat_constraints.cpp index 1c9de327c..101d60e62 100644 --- a/src/sat/smt/polysat_constraints.cpp +++ b/src/sat/smt/polysat_constraints.cpp @@ -11,15 +11,21 @@ Author: Jakob Rath 2021-04-06 --*/ + #include "sat/smt/polysat_core.h" #include "sat/smt/polysat_solver.h" #include "sat/smt/polysat_constraints.h" +#include "sat/smt/polysat_ule.h" namespace polysat { signed_constraint constraints::ule(pdd const& p, pdd const& q) { + pdd lhs = p, rhs = q; + bool is_positive = true; + ule_constraint::simplify(is_positive, lhs, rhs); auto* c = alloc(ule_constraint, p, q); m_trail.push(new_obj_trail(c)); - return signed_constraint(ckind_t::ule_t, c); + auto sc = signed_constraint(ckind_t::ule_t, c); + return is_positive ? sc : ~sc; } } diff --git a/src/sat/smt/polysat_constraints.h b/src/sat/smt/polysat_constraints.h index b8068c68a..da82431c4 100644 --- a/src/sat/smt/polysat_constraints.h +++ b/src/sat/smt/polysat_constraints.h @@ -14,11 +14,14 @@ Author: #pragma once +#include "util/trail.h" #include "sat/smt/polysat_types.h" namespace polysat { class core; + class ule_constraint; + class assignment; using pdd = dd::pdd; using pvar = unsigned; @@ -33,15 +36,12 @@ namespace polysat { unsigned_vector const& vars() const { return m_vars; } unsigned var(unsigned idx) const { return m_vars[idx]; } bool contains_var(pvar v) const { return m_vars.contains(v); } + virtual std::ostream& display(std::ostream& out, lbool status) const = 0; + virtual std::ostream& display(std::ostream& out) const = 0; + virtual lbool eval() const = 0; + virtual lbool eval(assignment const& a) const = 0; }; - class ule_constraint : public constraint { - pdd m_lhs, m_rhs; - public: - ule_constraint(pdd const& lhs, pdd const& rhs) : m_lhs(lhs), m_rhs(rhs) {} - pdd const& lhs() const { return m_lhs; } - pdd const& rhs() const { return m_rhs; } - }; class umul_ovfl_constraint : public constraint { pdd m_lhs, m_rhs; diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 65183578d..57098c447 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -27,6 +27,7 @@ The result of polysat::core::check is one of: #include "ast/euf/euf_bv_plugin.h" #include "sat/smt/polysat_solver.h" #include "sat/smt/euf_solver.h" +#include "sat/smt/polysat_ule.h" namespace polysat { @@ -170,11 +171,12 @@ namespace polysat { // Core uses the propagate callback to add unit propagations to the trail. // The polysat::solver takes care of translating signed constraints into expressions, which translate into literals. // Everything goes over expressions/literals. polysat::core is not responsible for replaying expressions. - void solver::propagate(signed_constraint sc, dependency_vector const& deps) { + dependency solver::propagate(signed_constraint sc, dependency_vector const& deps) { sat::literal lit = ctx.mk_literal(constraint2expr(sc)); auto [core, eqs] = explain_deps(deps); auto ex = euf::th_explain::propagate(*this, core, eqs, lit, nullptr); ctx.propagate(lit, ex); + return dependency(lit, s().lvl(lit)); } void solver::propagate(dependency const& d, bool sign, dependency_vector const& deps) { diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 408d0756c..76923f88f 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -131,7 +131,7 @@ namespace polysat { void add_eq_literal(pvar v, rational const& val); void set_conflict(dependency_vector const& core); void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core); - void propagate(signed_constraint sc, dependency_vector const& deps); + dependency propagate(signed_constraint sc, dependency_vector const& deps); void propagate(dependency const& d, bool sign, dependency_vector const& deps); void add_lemma(vector const& lemma); diff --git a/src/sat/smt/polysat_ule.cpp b/src/sat/smt/polysat_ule.cpp new file mode 100644 index 000000000..08448b34d --- /dev/null +++ b/src/sat/smt/polysat_ule.cpp @@ -0,0 +1,346 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat unsigned <= constraints + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +Notes: + + Canonical representation of equation p == 0 is the constraint p <= 0. + The alternatives p < 1, -1 <= q, q > -2 are eliminated. + + Rewrite rules to simplify expressions. + In the following let k, k1, k2 be values. + + - k1 <= k2 ==> 0 <= 0 if k1 <= k2 + - k1 <= k2 ==> 1 <= 0 if k1 > k2 + - 0 <= p ==> 0 <= 0 + - p <= 0 ==> 1 <= 0 if p is never zero due to parity + - p <= -1 ==> 0 <= 0 + - k <= p ==> p - k <= - k - 1 + - k*2^n*p <= 0 ==> 2^n*p <= 0 if k is odd, leading coeffient is always a power of 2. + + Note: the rules will rewrite alternative formulations of equations: + - -1 <= p ==> p + 1 <= 0 + - 1 <= p ==> p - 1 <= -2 + + Rewrite rules on signed constraints: + + - p > -2 ==> p + 1 <= 0 + - p <= -2 ==> p + 1 > 0 + + At this point, all equations are in canonical form. + +TODO: clause simplifications: + + - p + k <= p ==> p + k <= k & p != 0 for k != 0 + - p*q = 0 ==> p = 0 or q = 0 applies to any factoring + - 2*p <= 2*q ==> (p >= 2^n-1 & q < 2^n-1) or (p >= 2^n-1 = q >= 2^n-1 & p <= q) + ==> (p >= 2^n-1 => q < 2^n-1 or p <= q) & + (p < 2^n-1 => p <= q) & + (p < 2^n-1 => q < 2^n-1) + + - 3*p <= 3*q ==> ? + +Note: + case p <= p + k is already covered because we test (lhs - rhs).is_val() + + It can be seen as an instance of lemma 5.2 of Supratik and John. + +The following forms are equivalent: + + p <= q + p <= p - q - 1 + q - p <= q + q - p <= -p - 1 + -q - 1 <= -p - 1 + -q - 1 <= p - q - 1 + +Useful lemmas: + + p <= q && q+1 != 0 ==> p+1 <= q+1 + + p <= q && p != 0 ==> -q <= -p + +--*/ + +#include "sat/smt/polysat_constraints.h" +#include "sat/smt/polysat_ule.h" + +#define LOG(_msg_) verbose_stream() << _msg_ << "\n" + +namespace polysat { + + // Simplify lhs <= rhs. + // + // NOTE: the result should not depend on the initial value of is_positive; + // the purpose of is_positive is to allow flipping the sign as part of a rewrite rule. + static void simplify_impl(bool& is_positive, pdd& lhs, pdd& rhs) { + + SASSERT_EQ(lhs.power_of_2(), rhs.power_of_2()); + unsigned const N = lhs.power_of_2(); + + // 0 <= p --> 0 <= 0 + if (lhs.is_zero()) { + rhs = 0; + return; + } + // p <= -1 --> 0 <= 0 + if (rhs.is_max()) { + lhs = 0, rhs = 0; + return; + } + // p <= p --> 0 <= 0 + if (lhs == rhs) { + lhs = 0, rhs = 0; + return; + } + // Evaluate constants + if (lhs.is_val() && rhs.is_val()) { + if (lhs.val() <= rhs.val()) + lhs = 0, rhs = 0; + else + lhs = 0, rhs = 0, is_positive = !is_positive; + return; + } + + // Try to reduce the number of variables on one side using one of these rules: + // + // p <= q --> p <= p - q - 1 + // p <= q --> q - p <= q + // + // Possible alternative to 1: + // p <= q --> q - p <= -p - 1 + // Possible alternative to 2: + // p <= q --> -q-1 <= p - q - 1 + // + // Example: + // + // x <= x + y --> x <= - y - 1 + // x + y <= x --> -y <= x + if (!lhs.is_val() && !rhs.is_val()) { + unsigned const lhs_vars = lhs.free_vars().size(); + unsigned const rhs_vars = rhs.free_vars().size(); + unsigned const diff_vars = (lhs - rhs).free_vars().size(); + if (diff_vars < lhs_vars || diff_vars < rhs_vars) { + LOG("reduce number of varables"); + // verbose_stream() << "IN: " << ule_pp(to_lbool(is_positive), lhs, rhs) << "\n"; + if (lhs_vars <= rhs_vars) + rhs = lhs - rhs - 1; + else + lhs = rhs - lhs; + // verbose_stream() << "OUT: " << ule_pp(to_lbool(is_positive), lhs, rhs) << "\n"; + } + } + + // -p + k <= k --> p <= k + if (rhs.is_val() && !rhs.is_zero() && lhs.offset() == rhs.val()) { + LOG("-p + k <= k --> p <= k"); + lhs = rhs - lhs; + } + + // k <= p + k --> p <= -k-1 + if (lhs.is_val() && !lhs.is_zero() && lhs.val() == rhs.offset()) { + LOG("k <= p + k --> p <= -k-1"); + pdd k = lhs; + lhs = rhs - lhs; + rhs = -k - 1; + } + + // k <= -p --> p-1 <= -k-1 + if (lhs.is_val() && rhs.leading_coefficient().get_bit(N - 1) && !rhs.offset().is_zero()) { + LOG("k <= -p --> p-1 <= -k-1"); + pdd k = lhs; + lhs = -(rhs + 1); + rhs = -k - 1; + } + + // -p <= k --> -k-1 <= p-1 + // if (rhs.is_val() && lhs.leading_coefficient() > rational::power_of_two(N - 1) && !lhs.offset().is_zero()) { + if (rhs.is_val() && lhs.leading_coefficient().get_bit(N - 1) && !lhs.offset().is_zero()) { + LOG("-p <= k --> -k-1 <= p-1"); + pdd k = rhs; + rhs = -(lhs + 1); + lhs = -k - 1; + } + + // NOTE: do not use pdd operations in conditions when comparing pdd values. + // e.g.: "lhs.offset() == (rhs + 1).val()" is problematic with the following evaluation: + // 1. return reference into pdd_manager::m_values from lhs.offset() + // 2. compute rhs+1, which may reallocate pdd_manager::m_values + // 3. now the reference returned from lhs.offset() may be invalid + pdd const rhs_plus_one = rhs + 1; + + // p - k <= -k - 1 --> k <= p + // TODO: potential bug here: first call offset(), then rhs+1 has to reallocate pdd_manager::m_values, then the reference to offset is broken. + if (rhs.is_val() && !rhs.is_zero() && lhs.offset() == rhs_plus_one.val()) { + LOG("p - k <= -k - 1 --> k <= p"); + pdd k = -(rhs + 1); + rhs = lhs + k; + lhs = k; + } + + pdd const lhs_minus_one = lhs - 1; + + // k <= 2^(N-1)*p + q + k-1 --> k <= 2^(N-1)*p - q + if (lhs.is_val() && rhs.leading_coefficient() == rational::power_of_two(N-1) && rhs.offset() == lhs_minus_one.val()) { + LOG("k <= 2^(N-1)*p + q + k-1 --> k <= 2^(N-1)*p - q"); + rhs = (lhs - 1) - rhs; + } + + // -1 <= p --> p + 1 <= 0 + if (lhs.is_max()) { + lhs = rhs + 1; + rhs = 0; + } + + // 1 <= p --> p > 0 + if (lhs.is_one()) { + lhs = rhs; + rhs = 0; + is_positive = !is_positive; + } + + // p > -2 --> p + 1 <= 0 + // p <= -2 --> p + 1 > 0 + if (rhs.is_val() && !rhs.is_zero() && (rhs + 2).is_zero()) { + // Note: rhs.is_zero() iff rhs.manager().power_of_2() == 1 (the rewrite is not wrong for M=2, but useless) + lhs = lhs + 1; + rhs = 0; + is_positive = !is_positive; + } + // 2p + 1 <= 0 --> 0 < 0 + if (rhs.is_zero() && lhs.is_never_zero()) { + lhs = 0; + is_positive = !is_positive; + return; + } + // a*p + q <= 0 --> p + a^-1*q <= 0 for a odd + if (rhs.is_zero() && !lhs.leading_coefficient().is_power_of_two()) { + rational lc = lhs.leading_coefficient(); + rational x, y; + gcd(lc, lhs.manager().two_to_N(), x, y); + if (x.is_neg()) + x = mod(x, lhs.manager().two_to_N()); + lhs *= x; + SASSERT(lhs.leading_coefficient().is_power_of_two()); + } + } // simplify_impl +} + + +namespace polysat { + + ule_constraint::ule_constraint(pdd const& l, pdd const& r) : + m_lhs(l), m_rhs(r) { + + SASSERT_EQ(m_lhs.power_of_2(), m_rhs.power_of_2()); + + vars().append(m_lhs.free_vars()); + for (auto v : m_rhs.free_vars()) + if (!vars().contains(v)) + vars().push_back(v); + } + + void ule_constraint::simplify(bool& is_positive, pdd& lhs, pdd& rhs) { + SASSERT_EQ(lhs.power_of_2(), rhs.power_of_2()); +#ifndef NDEBUG + bool const old_is_positive = is_positive; + pdd const old_lhs = lhs; + pdd const old_rhs = rhs; +#endif + simplify_impl(is_positive, lhs, rhs); +#ifndef NDEBUG + if (old_is_positive != is_positive || old_lhs != lhs || old_rhs != rhs) { + ule_pp const old_ule(to_lbool(old_is_positive), old_lhs, old_rhs); + ule_pp const new_ule(to_lbool(is_positive), lhs, rhs); + // always-false and always-true constraints should be rewritten to 0 != 0 and 0 == 0, respectively. + if (is_always_false(old_is_positive, old_lhs, old_rhs)) { + SASSERT(!is_positive); + SASSERT(lhs.is_zero()); + SASSERT(rhs.is_zero()); + } + if (is_always_true(old_is_positive, old_lhs, old_rhs)) { + SASSERT(is_positive); + SASSERT(lhs.is_zero()); + SASSERT(rhs.is_zero()); + } + } + SASSERT(is_simplified(lhs, rhs)); // rewriting should be idempotent +#endif + } + + bool ule_constraint::is_simplified(pdd const& lhs0, pdd const& rhs0) { + bool const pos0 = true; + bool pos1 = pos0; + pdd lhs1 = lhs0; + pdd rhs1 = rhs0; + simplify_impl(pos1, lhs1, rhs1); + bool const is_simplified = (pos1 == pos0 && lhs1 == lhs0 && rhs1 == rhs0); + DEBUG_CODE({ + // check that simplification doesn't depend on initial sign + bool pos2 = !pos0; + pdd lhs2 = lhs0; + pdd rhs2 = rhs0; + simplify_impl(pos2, lhs2, rhs2); + SASSERT_EQ(pos2, !pos1); + SASSERT_EQ(lhs2, lhs1); + SASSERT_EQ(rhs2, rhs1); + }); + return is_simplified; + } + + std::ostream& ule_constraint::display(std::ostream& out, lbool status, pdd const& lhs, pdd const& rhs) { + out << lhs; + if (rhs.is_zero() && status == l_true) out << " == "; + else if (rhs.is_zero() && status == l_false) out << " != "; + else if (status == l_true) out << " <= "; + else if (status == l_false) out << " > "; + else out << " <=/> "; + return out << rhs; + } + + std::ostream& ule_constraint::display(std::ostream& out, lbool status) const { + return display(out, status, m_lhs, m_rhs); + } + + std::ostream& ule_constraint::display(std::ostream& out) const { + return display(out, l_true, m_lhs, m_rhs); + } + + + + // Evaluate lhs <= rhs + lbool ule_constraint::eval(pdd const& lhs, pdd const& rhs) { + // NOTE: don't assume simplifications here because we also call this on partially substituted constraints + if (lhs.is_zero()) + return l_true; // 0 <= p + if (lhs == rhs) + return l_true; // p <= p + if (rhs.is_max()) + return l_true; // p <= -1 + if (rhs.is_zero() && lhs.is_never_zero()) + return l_false; // p <= 0 implies p == 0 + if (lhs.is_one() && rhs.is_never_zero()) + return l_true; // 1 <= p implies p != 0 + if (lhs.is_val() && rhs.is_val()) + return to_lbool(lhs.val() <= rhs.val()); + return l_undef; + + } + + lbool ule_constraint::eval() const { + return eval(lhs(), rhs()); + } + + lbool ule_constraint::eval(assignment const& a) const { + return eval(a.apply_to(lhs()), a.apply_to(rhs())); + } + +} diff --git a/src/sat/smt/polysat_ule.h b/src/sat/smt/polysat_ule.h new file mode 100644 index 000000000..12efe506a --- /dev/null +++ b/src/sat/smt/polysat_ule.h @@ -0,0 +1,56 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat unsigned <= constraint + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once +#include "sat/smt/polysat_ule.h" +#include "sat/smt/polysat_assignment.h" +#include "sat/smt/polysat_constraints.h" + + +namespace polysat { + + class ule_constraint final : public constraint { + pdd m_lhs; + pdd m_rhs; + static bool is_always_true(bool is_positive, pdd const& lhs, pdd const& rhs) { return eval(lhs, rhs) == to_lbool(is_positive); } + static bool is_always_false(bool is_positive, pdd const& lhs, pdd const& rhs) { return is_always_true(!is_positive, lhs, rhs); } + static lbool eval(pdd const& lhs, pdd const& rhs); + + public: + ule_constraint(pdd const& l, pdd const& r); + ~ule_constraint() override {} + pdd const& lhs() const { return m_lhs; } + pdd const& rhs() const { return m_rhs; } + static std::ostream& display(std::ostream& out, lbool status, pdd const& lhs, pdd const& rhs); + std::ostream& display(std::ostream& out, lbool status) const override; + std::ostream& display(std::ostream& out) const override; + lbool eval() const override; + lbool eval(assignment const& a) const override; + bool is_eq() const { return m_rhs.is_zero(); } + unsigned power_of_2() const { return m_lhs.power_of_2(); } + + static void simplify(bool& is_positive, pdd& lhs, pdd& rhs); + static bool is_simplified(pdd const& lhs, pdd const& rhs); // return true if lhs <= rhs is not simplified further. this is meant to be used in assertions. + }; + + struct ule_pp { + lbool status; + pdd lhs; + pdd rhs; + ule_pp(lbool status, pdd const& lhs, pdd const& rhs): status(status), lhs(lhs), rhs(rhs) {} + ule_pp(lbool status, ule_constraint const& ule): status(status), lhs(ule.lhs()), rhs(ule.rhs()) {} + }; + + inline std::ostream& operator<<(std::ostream& out, ule_pp const& u) { return ule_constraint::display(out, u.status, u.lhs, u.rhs); } + +} From 8546b275ef7a5eaa708cd32ad9ede5a4b54cef51 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 12:04:19 -0800 Subject: [PATCH 09/89] port forbidden intervals --- src/sat/smt/CMakeLists.txt | 3 + src/sat/smt/polysat_constraints.cpp | 10 + src/sat/smt/polysat_constraints.h | 15 +- src/sat/smt/polysat_core.cpp | 12 +- src/sat/smt/polysat_core.h | 6 +- src/sat/smt/polysat_fi.cpp | 588 ++++++++++++++++++++++++++++ src/sat/smt/polysat_fi.h | 122 ++++++ src/sat/smt/polysat_interval.h | 224 +++++++++++ src/sat/smt/polysat_solver.cpp | 5 +- src/sat/smt/polysat_umul_ovfl.cpp | 73 ++++ src/sat/smt/polysat_umul_ovfl.h | 39 ++ src/sat/smt/polysat_viable.cpp | 36 ++ src/sat/smt/polysat_viable.h | 2 + 13 files changed, 1122 insertions(+), 13 deletions(-) create mode 100644 src/sat/smt/polysat_fi.cpp create mode 100644 src/sat/smt/polysat_fi.h create mode 100644 src/sat/smt/polysat_interval.h create mode 100644 src/sat/smt/polysat_umul_ovfl.cpp create mode 100644 src/sat/smt/polysat_umul_ovfl.h create mode 100644 src/sat/smt/polysat_viable.cpp diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index e27849391..bdc602da9 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -37,9 +37,12 @@ z3_add_component(sat_smt polysat_constraints.cpp polysat_core.cpp polysat_internalize.cpp + polysat_fi.cpp polysat_model.cpp polysat_solver.cpp polysat_ule.cpp + polysat_umul_ovfl.cpp + polysat_viable.cpp q_clause.cpp q_ematch.cpp q_eval.cpp diff --git a/src/sat/smt/polysat_constraints.cpp b/src/sat/smt/polysat_constraints.cpp index 101d60e62..a03b4f5f5 100644 --- a/src/sat/smt/polysat_constraints.cpp +++ b/src/sat/smt/polysat_constraints.cpp @@ -28,4 +28,14 @@ namespace polysat { auto sc = signed_constraint(ckind_t::ule_t, c); return is_positive ? sc : ~sc; } + + lbool signed_constraint::eval(assignment& a) const { + lbool r = m_constraint->eval(a); + return m_sign ? ~r : r; + } + + std::ostream& signed_constraint::display(std::ostream& out) const { + if (m_sign) out << "~"; + return out << *m_constraint; + } } diff --git a/src/sat/smt/polysat_constraints.h b/src/sat/smt/polysat_constraints.h index da82431c4..121fc2da6 100644 --- a/src/sat/smt/polysat_constraints.h +++ b/src/sat/smt/polysat_constraints.h @@ -21,6 +21,7 @@ namespace polysat { class core; class ule_constraint; + class umul_ovfl_constraint; class assignment; using pdd = dd::pdd; @@ -42,14 +43,8 @@ namespace polysat { virtual lbool eval(assignment const& a) const = 0; }; + inline std::ostream& operator<<(std::ostream& out, constraint const& c) { return c.display(out); } - class umul_ovfl_constraint : public constraint { - pdd m_lhs, m_rhs; - public: - umul_ovfl_constraint(pdd const& lhs, pdd const& rhs) : m_lhs(lhs), m_rhs(rhs) {} - pdd const& lhs() const { return m_lhs; } - pdd const& rhs() const { return m_rhs; } - }; class signed_constraint { bool m_sign = false; @@ -60,10 +55,13 @@ namespace polysat { signed_constraint(ckind_t c, constraint* p) : m_op(c), m_constraint(p) {} signed_constraint operator~() const { signed_constraint r(*this); r.m_sign = !r.m_sign; return r; } bool sign() const { return m_sign; } + bool is_positive() const { return !m_sign; } + bool is_negative() const { return m_sign; } unsigned_vector& vars() { return m_constraint->vars(); } unsigned_vector const& vars() const { return m_constraint->vars(); } unsigned var(unsigned idx) const { return m_constraint->var(idx); } bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + lbool eval(assignment& a) const; ckind_t op() const { return m_op; } bool is_ule() const { return m_op == ule_t; } bool is_umul_ovfl() const { return m_op == umul_ovfl_t; } @@ -71,8 +69,11 @@ namespace polysat { ule_constraint const& to_ule() const { return *reinterpret_cast(m_constraint); } umul_ovfl_constraint const& to_umul_ovfl() const { return *reinterpret_cast(m_constraint); } bool is_eq(pvar& v, rational& val) { throw default_exception("nyi"); } + std::ostream& display(std::ostream& out) const; }; + inline std::ostream& operator<<(std::ostream& out, signed_constraint const& c) { return c.display(out); } + class constraints { trail_stack& m_trail; public: diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index 73e9c5a6b..3de88d93b 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -262,8 +262,10 @@ namespace polysat { } // if sc is v == value, then check the watch list for v to propagate truth assignments if (sc.is_eq(m_var, m_value)) { - for (auto idx : m_watch[m_var]) { - auto [sc, d] = m_constraint_trail[idx]; + for (auto idx1 : m_watch[m_var]) { + if (idx == idx1) + continue; + auto [sc, d] = m_constraint_trail[idx1]; switch (eval(sc)) { case l_false: s.propagate(d, true, explain_eval(sc)); @@ -299,7 +301,11 @@ namespace polysat { } lbool core::eval(signed_constraint const& sc) { - throw default_exception("nyi"); + return sc.eval(m_assignment); + } + + pdd core::subst(pdd const& p) { + return m_assignment.apply_to(p); } } diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h index 6944c39d8..3c8a79bd6 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat_core.h @@ -70,7 +70,7 @@ namespace polysat { dd::pdd_manager& sz2pdd(unsigned sz) const; dd::pdd_manager& var2pdd(pvar v) const; - unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } + void del_var(); bool is_assigned(pvar v) { return !m_justification[v].is_null(); } @@ -96,6 +96,7 @@ namespace polysat { void assign_eh(unsigned idx, bool sign, dependency const& d); pdd value(rational const& v, unsigned sz); + pdd subst(pdd const&); signed_constraint eq(pdd const& p) { return m_constraints.eq(p); } signed_constraint eq(pdd const& p, pdd const& q) { return m_constraints.eq(p - q); } @@ -124,6 +125,9 @@ namespace polysat { pdd concat(unsigned n, pdd const* args) { throw default_exception("nyi"); } pvar add_var(unsigned sz); pdd var(pvar p) { return m_vars[p]; } + unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } + + constraints& cs() { return m_constraints; } std::ostream& display(std::ostream& out) const { throw default_exception("nyi"); } }; diff --git a/src/sat/smt/polysat_fi.cpp b/src/sat/smt/polysat_fi.cpp new file mode 100644 index 000000000..349243ed8 --- /dev/null +++ b/src/sat/smt/polysat_fi.cpp @@ -0,0 +1,588 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Conflict explanation using forbidden intervals as described in + "Solving bitvectors with MCSAT: explanations from bits and pieces" + by S. Graham-Lengrand, D. Jovanovic, B. Dutertre. + +Author: + + Jakob Rath 2021-04-06 + Nikolaj Bjorner (nbjorner) 2021-03-19 + +--*/ +#include "sat/smt/polysat_fi.h" +#include "sat/smt/polysat_interval.h" +#include "sat/smt/polysat_umul_ovfl.h" +#include "sat/smt/polysat_ule.h" +#include "sat/smt/polysat_core.h" + +namespace polysat { + + /** + * + * \param[in] c Original constraint + * \param[in] v Variable that is bounded by constraint + * \param[out] fi "forbidden interval" record that captures values not allowed for v + * \returns True iff a forbidden interval exists and the output parameters were set. + */ + bool forbidden_intervals::get_interval(signed_constraint const& c, pvar v, fi_record& fi) { + // verbose_stream() << "get_interval for v" << v << " " << c << "\n"; + SASSERT(fi.side_cond.empty()); + SASSERT(fi.src.empty()); + fi.bit_width = s.size(v); // TODO: preliminary + if (c.is_ule()) + return get_interval_ule(c, v, fi); + if (c.is_umul_ovfl()) + return get_interval_umul_ovfl(c, v, fi); + return false; + } + + bool forbidden_intervals::get_interval_umul_ovfl(signed_constraint const& c, pvar v, fi_record& fi) { + using std::swap; + + backtrack _backtrack(fi.side_cond); + + fi.coeff = 1; + fi.src.push_back(c); + + // eval(lhs) = a1*v + eval(e1) = a1*v + b1 + // eval(rhs) = a2*v + eval(e2) = a2*v + b2 + // We keep the e1, e2 around in case we need side conditions such as e1=b1, e2=b2. + auto [ok1, a1, e1, b1] = linear_decompose(v, c.to_umul_ovfl().p(), fi.side_cond); + auto [ok2, a2, e2, b2] = linear_decompose(v, c.to_umul_ovfl().q(), fi.side_cond); + + auto& m = e1.manager(); + rational bound = m.max_value(); + + if (ok2 && !ok1) { + swap(a1, a2); + swap(e1, e2); + swap(b1, b2); + swap(ok1, ok2); + } + if (ok1 && !ok2 && a1.is_one() && b1.is_zero()) { + if (c.is_positive()) { + _backtrack.released = true; + rational lo_val(0); + rational hi_val(2); + pdd lo = m.mk_val(lo_val); + pdd hi = m.mk_val(hi_val); + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + return true; + } + } + + if (!ok1 || !ok2) + return false; + + + if (a2.is_one() && a1.is_zero()) { + swap(a1, a2); + swap(e1, e2); + swap(b1, b2); + } + + if (!a1.is_one() || !a2.is_zero()) + return false; + + if (!b1.is_zero()) + return false; + + _backtrack.released = true; + + // Ovfl(v, e2) + + + if (c.is_positive()) { + if (b2.val() <= 1) { + fi.interval = eval_interval::full(); + fi.side_cond.push_back(s.cs().ule(e2, 1)); + } + else { + // [0, div(bound, b2.val()) + 1[ + rational lo_val(0); + rational hi_val(div(bound, b2.val()) + 1); + pdd lo = m.mk_val(lo_val); + pdd hi = m.mk_val(hi_val); + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + fi.side_cond.push_back(s.cs().ule(e2, b2.val())); + } + + } + else { + if (b2.val() <= 1) { + _backtrack.released = false; + return false; + } + else { + // [div(bound, b2.val()) + 1, 0[ + rational lo_val(div(bound, b2.val()) + 1); + rational hi_val(0); + pdd lo = m.mk_val(lo_val); + pdd hi = m.mk_val(hi_val); + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + fi.side_cond.push_back(s.cs().ule(b2.val(), e2)); + } + } + + // LOG("overflow interval " << fi.interval); + + return true; + } + + static char const* _last_function = ""; + + bool forbidden_intervals::get_interval_ule(signed_constraint const& c, pvar v, fi_record& fi) { + + backtrack _backtrack(fi.side_cond); + + fi.coeff = 1; + fi.src.push_back(c); + + struct show { + forbidden_intervals& f; + signed_constraint const& c; + pvar v; + fi_record& fi; + backtrack& _backtrack; + show(forbidden_intervals& f, + signed_constraint const& c, + pvar v, + fi_record& fi, + backtrack& _backtrack):f(f), c(c), v(v), fi(fi), _backtrack(_backtrack) {} + ~show() { + if (!_backtrack.released) + return; + IF_VERBOSE(0, verbose_stream() << _last_function << " " << v << " " << c << " " << fi.interval << " " << fi.side_cond << "\n"); + } + }; + // uncomment to trace intervals + // show _show(*this, c, v, fi, _backtrack); + + // eval(lhs) = a1*v + eval(e1) = a1*v + b1 + // eval(rhs) = a2*v + eval(e2) = a2*v + b2 + // We keep the e1, e2 around in case we need side conditions such as e1=b1, e2=b2. + auto [ok1, a1, e1, b1] = linear_decompose(v, c.to_ule().lhs(), fi.side_cond); + auto [ok2, a2, e2, b2] = linear_decompose(v, c.to_ule().rhs(), fi.side_cond); + + _backtrack.released = true; + + // v > q + if (false && ok1 && !ok2 && match_non_zero(c, a1, b1, e1, c.to_ule().rhs(), fi)) + return true; + + // p > v + if (false && !ok1 && ok2 && match_non_max(c, c.to_ule().lhs(), a2, b2, e2, fi)) + return true; + + if (!ok1 || !ok2 || (a1.is_zero() && a2.is_zero())) { + _backtrack.released = false; + return false; + } + SASSERT(b1.is_val()); + SASSERT(b2.is_val()); + + // a*v + b <= 0, a odd + // a*v + b > 0, a odd + if (match_zero(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + + // -1 <= a*v + b, a odd + // -1 > a*v + b, a odd + if (match_max(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + + if (match_linear1(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + if (match_linear2(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + if (match_linear3(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + if (match_linear4(c, a1, b1, e1, a2, b2, e2, fi)) + return true; + + _backtrack.released = false; + return false; + } + + void forbidden_intervals::push_eq(bool is_zero, pdd const& p, vector& side_cond) { + SASSERT(!p.is_val() || (is_zero == p.is_zero())); + if (p.is_val()) + return; + else if (is_zero) + side_cond.push_back(s.eq(p)); + else + side_cond.push_back(~s.eq(p)); + } + + std::tuple forbidden_intervals::linear_decompose(pvar v, pdd const& p, vector& out_side_cond) { + auto& m = p.manager(); + pdd q = m.zero(); + pdd e = m.zero(); + unsigned const deg = p.degree(v); + if (deg == 0) + // p = 0*v + e + e = p; + else if (deg == 1) + // p = q*v + e + p.factor(v, 1, q, e); + else + return std::tuple(false, rational(0), q, e); + + // r := eval(q) + // Add side constraint q = r. + if (!q.is_val()) { + pdd r = s.subst(q); + + + if (!r.is_val()) + return std::tuple(false, rational(0), q, e); + out_side_cond.push_back(s.eq(q, r)); + q = r; + } + auto b = s.subst(e); + return std::tuple(b.is_val(), q.val(), e, b); + }; + + eval_interval forbidden_intervals::to_interval( + signed_constraint const& c, bool is_trivial, rational & coeff, + rational & lo_val, pdd & lo, + rational & hi_val, pdd & hi) { + + dd::pdd_manager& m = lo.manager(); + + if (is_trivial) { + if (c.is_positive()) + // TODO: we cannot use empty intervals for interpolation. So we + // can remove the empty case (make it represent 'full' instead), + // and return 'false' here. Then we do not need the proper/full + // tag on intervals. + return eval_interval::empty(m); + else + return eval_interval::full(); + } + + rational pow2 = m.two_to_N(); + + if (coeff > pow2/2) { + // TODO: if coeff != pow2 - 1, isn't this counterproductive now? considering the gap condition on refine-equal-lin acceleration. + + coeff = pow2 - coeff; + SASSERT(coeff > 0); + // Transform according to: y \in [l;u[ <=> -y \in [1-u;1-l[ + // -y \in [1-u;1-l[ + // <=> -y - (1 - u) < (1 - l) - (1 - u) { by: y \in [l;u[ <=> y - l < u - l } + // <=> u - y - 1 < u - l { simplified } + // <=> (u-l) - (u-y-1) - 1 < u-l { by: a < b <=> b - a - 1 < b } + // <=> y - l < u - l { simplified } + // <=> y \in [l;u[. + lo = 1 - lo; + hi = 1 - hi; + swap(lo, hi); + lo_val = mod(1 - lo_val, pow2); + hi_val = mod(1 - hi_val, pow2); + lo_val.swap(hi_val); + } + + if (c.is_positive()) + return eval_interval::proper(lo, lo_val, hi, hi_val); + else + return eval_interval::proper(hi, hi_val, lo, lo_val); + } + + /** + * Match e1 + t <= e2, with t = a1*y + * condition for empty/full: e2 == -1 + */ + bool forbidden_intervals::match_linear1(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a2.is_zero() && !a1.is_zero()) { + SASSERT(!a1.is_zero()); + bool is_trivial = (b2 + 1).is_zero(); + push_eq(is_trivial, e2 + 1, fi.side_cond); + auto lo = e2 - e1 + 1; + rational lo_val = (b2 - b1 + 1).val(); + auto hi = -e1; + rational hi_val = (-b1).val(); + fi.coeff = a1; + fi.interval = to_interval(c, is_trivial, fi.coeff, lo_val, lo, hi_val, hi); + add_non_unit_side_conds(fi, b1, e1, b2, e2); + return true; + } + return false; + } + + /** + * e1 <= e2 + t, with t = a2*y + * condition for empty/full: e1 == 0 + */ + bool forbidden_intervals::match_linear2(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a1.is_zero() && !a2.is_zero()) { + SASSERT(!a2.is_zero()); + bool is_trivial = b1.is_zero(); + push_eq(is_trivial, e1, fi.side_cond); + auto lo = -e2; + rational lo_val = (-b2).val(); + auto hi = e1 - e2; + rational hi_val = (b1 - b2).val(); + fi.coeff = a2; + fi.interval = to_interval(c, is_trivial, fi.coeff, lo_val, lo, hi_val, hi); + add_non_unit_side_conds(fi, b1, e1, b2, e2); + return true; + } + return false; + } + + /** + * e1 + t <= e2 + t, with t = a1*y = a2*y + * condition for empty/full: e1 == e2 + */ + bool forbidden_intervals::match_linear3(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a1 == a2 && !a1.is_zero()) { + bool is_trivial = b1.val() == b2.val(); + push_eq(is_trivial, e1 - e2, fi.side_cond); + auto lo = -e2; + rational lo_val = (-b2).val(); + auto hi = -e1; + rational hi_val = (-b1).val(); + fi.coeff = a1; + fi.interval = to_interval(c, is_trivial, fi.coeff, lo_val, lo, hi_val, hi); + add_non_unit_side_conds(fi, b1, e1, b2, e2); + return true; + } + return false; + } + + /** + * e1 + t <= e2 + t', with t = a1*y, t' = a2*y, a1 != a2, a1, a2 non-zero + */ + bool forbidden_intervals::match_linear4(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a1 != a2 && !a1.is_zero() && !a2.is_zero()) { + // NOTE: we don't have an interval here in the same sense as in the other cases. + // We use the interval to smuggle out the values a1,b1,a2,b2 without adding additional fields. + // to_interval flips a1,b1 with a2,b2 for negative constraints, which we also need for this case. + auto lo = b1; + rational lo_val = a1; + auto hi = b2; + rational hi_val = a2; + // We use fi.coeff = -1 to tell the caller to treat it as a diseq_lin. + fi.coeff = -1; + fi.interval = to_interval(c, false, fi.coeff, lo_val, lo, hi_val, hi); + add_non_unit_side_conds(fi, b1, e1, b2, e2); + SASSERT(!fi.interval.is_currently_empty()); + return true; + } + return false; + } + + /** + * a*v <= 0, a odd + * forbidden interval for v is [1;0[ + * + * a*v + b <= 0, a odd + * forbidden interval for v is [n+1;n[ where n = -b * a^-1 + * + * TODO: extend to + * 2^k*a*v <= 0, a odd + * (using intervals for the lower bits of v) + */ + bool forbidden_intervals::match_zero( + signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a1.is_odd() && a2.is_zero() && b2.is_zero()) { + auto& m = e1.manager(); + rational const& mod_value = m.two_to_N(); + rational a1_inv; + VERIFY(a1.mult_inverse(m.power_of_2(), a1_inv)); + + // interval for a*v + b > 0 is [n;n+1[ where n = -b * a^-1 + rational lo_val = mod(-b1.val() * a1_inv, mod_value); + pdd lo = -e1 * a1_inv; + rational hi_val = mod(lo_val + 1, mod_value); + pdd hi = lo + 1; + + // interval for a*v + b <= 0 is the complement + if (c.is_positive()) { + std::swap(lo_val, hi_val); + std::swap(lo, hi); + } + + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + // RHS == 0 is a precondition because we can only multiply with a^-1 in equations, not inequalities + if (b2 != e2) + fi.side_cond.push_back(s.eq(b2, e2)); + return true; + } + return false; + } + + /** + * -1 <= a*v + b, a odd + * forbidden interval for v is [n+1;n[ where n = (-b-1) * a^-1 + */ + bool forbidden_intervals::match_max( + signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + if (a1.is_zero() && b1.is_max() && a2.is_odd()) { + auto& m = e2.manager(); + rational const& mod_value = m.two_to_N(); + rational a2_inv; + VERIFY(a2.mult_inverse(m.power_of_2(), a2_inv)); + + // interval for -1 > a*v + b is [n;n+1[ where n = (-b-1) * a^-1 + rational lo_val = mod((-1 - b2.val()) * a2_inv, mod_value); + pdd lo = (-1 - e2) * a2_inv; + rational hi_val = mod(lo_val + 1, mod_value); + pdd hi = lo + 1; + + // interval for -1 <= a*v + b is the complement + if (c.is_positive()) { + std::swap(lo_val, hi_val); + std::swap(lo, hi); + } + + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + // LHS == -1 is a precondition because we can only multiply with a^-1 in equations, not inequalities + if (b1 != e1) + fi.side_cond.push_back(s.eq(b1, e1)); + return true; + } + return false; + } + + /** + * v > q + * forbidden interval for v is [0,1[ + * + * v - k > q + * forbidden interval for v is [k,k+1[ + * + * v > q + * forbidden interval for v is [0;q+1[ but at least [0;1[ + * + * The following cases are implemented, and subsume the simple ones above. + * + * v - k > q + * forbidden interval for v is [k;k+q+1[ but at least [k;k+1[ + * + * a*v - k > q, a odd + * forbidden interval for v is [a^-1*k, a^-1*k + 1[ + */ + bool forbidden_intervals::match_non_zero( + signed_constraint const& c, + rational const& a1, pdd const& b1, pdd const& e1, + pdd const& q, + fi_record& fi) { + _last_function = __func__; + SASSERT(b1.is_val()); + if (a1.is_one() && c.is_negative()) { + // v - k > q + auto& m = e1.manager(); + rational const& mod_value = m.two_to_N(); + rational lo_val = (-b1).val(); + pdd lo = -e1; + rational hi_val = mod(lo_val + 1, mod_value); + pdd hi = lo + q + 1; + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + return true; + } + if (a1.is_odd() && c.is_negative()) { + // a*v - k > q, a odd + auto& m = e1.manager(); + rational const& mod_value = m.two_to_N(); + rational a1_inv; + VERIFY(a1.mult_inverse(m.power_of_2(), a1_inv)); + rational lo_val(mod(-b1.val() * a1_inv, mod_value)); + auto lo = -e1 * a1_inv; + rational hi_val(mod(lo_val + 1, mod_value)); + auto hi = lo + 1; + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + return true; + } + return false; + } + + /** + * p > v + * forbidden interval for v is [p;0[ but at least [-1,0[ + * + * p > v + k + * forbidden interval for v is [p-k;-k[ but at least [-1-k,-k[ + * + * p > a*v + k, a odd + * forbidden interval for v is [ a^-1*(-1-k) ; a^-1*(-1-k) + 1 [ + */ + bool forbidden_intervals::match_non_max( + signed_constraint const& c, + pdd const& p, + rational const& a2, pdd const& b2, pdd const& e2, + fi_record& fi) { + _last_function = __func__; + SASSERT(b2.is_val()); + if (a2.is_one() && c.is_negative()) { + // p > v + k + auto& m = e2.manager(); + rational const& mod_value = m.two_to_N(); + rational hi_val = (-b2).val(); + pdd hi = -e2; + rational lo_val = mod(hi_val - 1, mod_value); + pdd lo = p - e2; + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + return true; + } + if (a2.is_odd() && c.is_negative()) { + // p > a*v + k, a odd + auto& m = e2.manager(); + rational const& mod_value = m.two_to_N(); + rational a2_inv; + VERIFY(a2.mult_inverse(m.power_of_2(), a2_inv)); + rational lo_val = mod(a2_inv * (-1 - b2.val()), mod_value); + pdd lo = a2_inv * (-1 - e2); + rational hi_val = mod(lo_val + 1, mod_value); + pdd hi = lo + 1; + fi.coeff = 1; + fi.interval = eval_interval::proper(lo, lo_val, hi, hi_val); + return true; + } + return false; + } + + + void forbidden_intervals::add_non_unit_side_conds(fi_record& fi, pdd const& b1, pdd const& e1, pdd const& b2, pdd const& e2) { + if (fi.coeff == 1) + return; + if (b1 != e1) + fi.side_cond.push_back(s.eq(b1, e1)); + if (b2 != e2) + fi.side_cond.push_back(s.eq(b2, e2)); + } +} diff --git a/src/sat/smt/polysat_fi.h b/src/sat/smt/polysat_fi.h new file mode 100644 index 000000000..7782deb4a --- /dev/null +++ b/src/sat/smt/polysat_fi.h @@ -0,0 +1,122 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Conflict explanation using forbidden intervals as described in + "Solving bitvectors with MCSAT: explanations from bits and pieces" + by S. Graham-Lengrand, D. Jovanovic, B. Dutertre. + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once +#include "sat/smt/polysat_types.h" +#include "sat/smt/polysat_interval.h" +#include "sat/smt/polysat_constraints.h" + +namespace polysat { + + class core; + + struct fi_record { + eval_interval interval; + vector side_cond; + vector src; // only units may have multiple src (as they can consist of contracted bit constraints) + rational coeff; + unsigned bit_width = 0; // number of lower bits; TODO: should move this to viable::entry; where the coeff/bit-width is adapted accordingly + + /** Create invalid fi_record */ + fi_record(): interval(eval_interval::full()) {} + + void reset() { + interval = eval_interval::full(); + side_cond.reset(); + src.reset(); + coeff.reset(); + bit_width = 0; + } + + struct less { + bool operator()(fi_record const& a, fi_record const& b) const { + return a.interval.lo_val() < b.interval.lo_val(); + } + }; + }; + + class forbidden_intervals { + + void push_eq(bool is_trivial, pdd const& p, vector& side_cond); + eval_interval to_interval(signed_constraint const& c, bool is_trivial, rational& coeff, + rational & lo_val, pdd & lo, rational & hi_val, pdd & hi); + + + std::tuple linear_decompose(pvar v, pdd const& p, vector& out_side_cond); + + bool match_linear1(signed_constraint const& c, + rational const& a1, pdd const& b1, pdd const& e1, + rational const& a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool match_linear2(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool match_linear3(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool match_linear4(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + void add_non_unit_side_conds(fi_record& fi, pdd const& b1, pdd const& e1, pdd const& b2, pdd const& e2); + + bool match_zero(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool match_max(signed_constraint const& c, + rational const & a1, pdd const& b1, pdd const& e1, + rational const & a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool match_non_zero(signed_constraint const& c, + rational const& a1, pdd const& b1, pdd const& e1, + pdd const& q, + fi_record& fi); + + bool match_non_max(signed_constraint const& c, + pdd const& p, + rational const& a2, pdd const& b2, pdd const& e2, + fi_record& fi); + + bool get_interval_ule(signed_constraint const& c, pvar v, fi_record& fi); + + bool get_interval_umul_ovfl(signed_constraint const& c, pvar v, fi_record& fi); + + struct backtrack { + bool released = false; + vector& side_cond; + unsigned sz; + backtrack(vector& s):side_cond(s), sz(s.size()) {} + ~backtrack() { + if (!released) + side_cond.shrink(sz); + } + }; + + core& s; + + public: + forbidden_intervals(core& s): s(s) {} + bool get_interval(signed_constraint const& c, pvar v, fi_record& fi); + }; +} diff --git a/src/sat/smt/polysat_interval.h b/src/sat/smt/polysat_interval.h new file mode 100644 index 000000000..9965dbab1 --- /dev/null +++ b/src/sat/smt/polysat_interval.h @@ -0,0 +1,224 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat intervals + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-6 + +--*/ +#pragma once +#include "sat/smt/polysat_types.h" +#include + +namespace polysat { + + struct pdd_bounds { + pdd lo; ///< lower bound, inclusive + pdd hi; ///< upper bound, exclusive + }; + + /** + * An interval is either [lo; hi[ (excl. upper bound) or the full domain Z_{2^w}. + * If lo > hi, the interval wraps around, i.e., represents the union of [lo; 2^w[ and [0; hi[. + * Membership test t \in [lo; hi[ is equivalent to t - lo < hi - lo. + */ + class interval { + std::optional m_bounds = std::nullopt; + + interval() = default; + interval(pdd const& lo, pdd const& hi): m_bounds({lo, hi}) {} + public: + static interval empty(dd::pdd_manager& m) { return proper(m.zero(), m.zero()); } + static interval full() { return {}; } + static interval proper(pdd const& lo, pdd const& hi) { return {lo, hi}; } + + interval(interval const&) = default; + interval(interval&&) = default; + interval& operator=(interval const& other) { + m_bounds = std::nullopt; // clear pdds first to allow changing pdd_manager (probably should change the PDD assignment operator; but for now I want to be able to detect manager confusions) + m_bounds = other.m_bounds; + return *this; + } + interval& operator=(interval&& other) { + m_bounds = std::nullopt; // clear pdds first to allow changing pdd_manager + m_bounds = std::move(other.m_bounds); + return *this; + } + ~interval() = default; + + bool is_full() const { return !m_bounds; } + bool is_proper() const { return !!m_bounds; } + bool is_always_empty() const { return is_proper() && lo() == hi(); } + pdd const& lo() const { SASSERT(is_proper()); return m_bounds->lo; } + pdd const& hi() const { SASSERT(is_proper()); return m_bounds->hi; } + }; + + inline std::ostream& operator<<(std::ostream& os, interval const& i) { + if (i.is_full()) + return os << "full"; + else + return os << "[" << i.lo() << " ; " << i.hi() << "["; + } + + // distance from a to b, wrapping around at mod_value. + // basically mod(b - a, mod_value), but distance(0, mod_value, mod_value) = mod_value. + inline rational distance(rational const& a, rational const& b, rational const& mod_value) { + SASSERT(mod_value.is_power_of_two()); + SASSERT(0 <= a && a < mod_value); + SASSERT(0 <= b && b <= mod_value); + rational x = b - a; + if (x.is_neg()) + x += mod_value; + return x; + } + + class r_interval { + rational m_lo; + rational m_hi; + + r_interval(rational lo, rational hi) + : m_lo(std::move(lo)), m_hi(std::move(hi)) + {} + + public: + + static r_interval empty() { + return {rational::zero(), rational::zero()}; + } + + static r_interval full() { + return {rational(-1), rational::zero()}; + } + + static r_interval proper(rational lo, rational hi) { + SASSERT(0 <= lo); + SASSERT(0 <= hi); + return {std::move(lo), std::move(hi)}; + } + + bool is_full() const { return m_lo.is_neg(); } + bool is_proper() const { return !is_full(); } + bool is_empty() const { return is_proper() && lo() == hi(); } + rational const& lo() const { SASSERT(is_proper()); return m_lo; } + rational const& hi() const { SASSERT(is_proper()); return m_hi; } + + // this one also supports representing full intervals as [lo;mod_value[ + static rational len(rational const& lo, rational const& hi, rational const& mod_value) { + SASSERT(mod_value.is_power_of_two()); + SASSERT(0 <= lo && lo < mod_value); + SASSERT(0 <= hi && hi <= mod_value); + SASSERT(hi != mod_value || lo == 0); // hi == mod_value only allowed when lo == 0 + rational len = hi - lo; + if (len.is_neg()) + len += mod_value; + return len; + } + + rational len(rational const& mod_value) const { + SASSERT(is_proper()); + return len(lo(), hi(), mod_value); + } + + // deals only with proper intervals + // but works with full intervals represented as [0;mod_value[ -- maybe we should just change representation of full intervals to this always + static bool contains(rational const& lo, rational const& hi, rational const& val) { + if (lo <= hi) + return lo <= val && val < hi; + else + return val < hi || val >= lo; + } + + bool contains(rational const& val) const { + if (is_full()) + return true; + else + return contains(lo(), hi(), val); + } + + }; + + class eval_interval { + interval m_symbolic; + rational m_concrete_lo; + rational m_concrete_hi; + + eval_interval(interval&& i, rational const& lo_val, rational const& hi_val): + m_symbolic(std::move(i)), m_concrete_lo(lo_val), m_concrete_hi(hi_val) {} + public: + static eval_interval empty(dd::pdd_manager& m) { + return {interval::empty(m), rational::zero(), rational::zero()}; + } + + static eval_interval full() { + return {interval::full(), rational::zero(), rational::zero()}; + } + + static eval_interval proper(pdd const& lo, rational const& lo_val, pdd const& hi, rational const& hi_val) { + SASSERT(0 <= lo_val && lo_val <= lo.manager().max_value()); + SASSERT(0 <= hi_val && hi_val <= hi.manager().max_value()); + return {interval::proper(lo, hi), lo_val, hi_val}; + } + + bool is_full() const { return m_symbolic.is_full(); } + bool is_proper() const { return m_symbolic.is_proper(); } + bool is_always_empty() const { return m_symbolic.is_always_empty(); } + bool is_currently_empty() const { return is_proper() && lo_val() == hi_val(); } + interval const& symbolic() const { return m_symbolic; } + pdd const& lo() const { return m_symbolic.lo(); } + pdd const& hi() const { return m_symbolic.hi(); } + rational const& lo_val() const { SASSERT(is_proper()); return m_concrete_lo; } + rational const& hi_val() const { SASSERT(is_proper()); return m_concrete_hi; } + + rational current_len() const { + SASSERT(is_proper()); + return mod(hi_val() - lo_val(), lo().manager().two_to_N()); + } + + bool currently_contains(rational const& val) const { + if (is_full()) + return true; + else if (lo_val() <= hi_val()) + return lo_val() <= val && val < hi_val(); + else + return val < hi_val() || val >= lo_val(); + } + + bool currently_contains(eval_interval const& other) const { + if (is_full()) + return true; + if (other.is_full()) + return false; + // lo <= lo' <= hi' <= hi' + if (lo_val() <= other.lo_val() && other.lo_val() <= other.hi_val() && other.hi_val() <= hi_val()) + return true; + if (lo_val() <= hi_val()) + return false; + // hi < lo <= lo' <= hi' + if (lo_val() <= other.lo_val() && other.lo_val() <= other.hi_val()) + return true; + // lo' <= hi' <= hi < lo + if (other.lo_val() <= other.hi_val() && other.hi_val() <= hi_val()) + return true; + // hi' <= hi < lo <= lo' + if (other.hi_val() <= hi_val() && lo_val() <= other.lo_val()) + return true; + return false; + } + + }; // class eval_interval + + inline std::ostream& operator<<(std::ostream& os, eval_interval const& i) { + if (i.is_full()) + return os << "full"; + else { + auto& m = i.hi().manager(); + return os << i.symbolic() << " := [" << m.normalize(i.lo_val()) << ";" << m.normalize(i.hi_val()) << "["; + } + } + +} diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 57098c447..3e01ff391 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -28,6 +28,7 @@ The result of polysat::core::check is one of: #include "sat/smt/polysat_solver.h" #include "sat/smt/euf_solver.h" #include "sat/smt/polysat_ule.h" +#include "sat/smt/polysat_umul_ovfl.h" namespace polysat { @@ -221,8 +222,8 @@ namespace polysat { return expr_ref(bv.mk_ule(l, h), m); } case ckind_t::umul_ovfl_t: { - auto l = pdd2expr(sc.to_umul_ovfl().lhs()); - auto r = pdd2expr(sc.to_umul_ovfl().rhs()); + auto l = pdd2expr(sc.to_umul_ovfl().p()); + auto r = pdd2expr(sc.to_umul_ovfl().q()); return expr_ref(bv.mk_bvumul_ovfl(l, r), m); } case ckind_t::smul_fl_t: diff --git a/src/sat/smt/polysat_umul_ovfl.cpp b/src/sat/smt/polysat_umul_ovfl.cpp new file mode 100644 index 000000000..5c448bc0a --- /dev/null +++ b/src/sat/smt/polysat_umul_ovfl.cpp @@ -0,0 +1,73 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat multiplication overflow constraint + +Author: + + Jakob Rath, Nikolaj Bjorner (nbjorner) 2021-12-09 + +--*/ +#include "sat/smt/polysat_constraints.h" +#include "sat/smt/polysat_assignment.h" +#include "sat/smt/polysat_umul_ovfl.h" + + +namespace polysat { + + umul_ovfl_constraint::umul_ovfl_constraint(pdd const& p, pdd const& q): + m_p(p), m_q(q) { + simplify(); + vars().append(m_p.free_vars()); + for (auto v : m_q.free_vars()) + if (!vars().contains(v)) + vars().push_back(v); + + } + void umul_ovfl_constraint::simplify() { + if (m_p.is_zero() || m_q.is_zero() || m_p.is_one() || m_q.is_one()) { + m_q = 0; + m_p = 0; + return; + } + if (m_p.index() > m_q.index()) + swap(m_p, m_q); + } + + std::ostream& umul_ovfl_constraint::display(std::ostream& out, lbool status) const { + switch (status) { + case l_true: return display(out); + case l_false: return display(out << "~"); + case l_undef: return display(out << "?"); + } + return out; + } + + std::ostream& umul_ovfl_constraint::display(std::ostream& out) const { + return out << "ovfl*(" << m_p << ", " << m_q << ")"; + } + + lbool umul_ovfl_constraint::eval(pdd const& p, pdd const& q) { + if (p.is_zero() || q.is_zero() || p.is_one() || q.is_one()) + return l_false; + + if (p.is_val() && q.is_val()) { + if (p.val() * q.val() > p.manager().max_value()) + return l_true; + else + return l_false; + } + return l_undef; + } + + lbool umul_ovfl_constraint::eval() const { + return eval(p(), q()); + } + + lbool umul_ovfl_constraint::eval(assignment const& a) const { + return eval(a.apply_to(p()), a.apply_to(q())); + } + +} diff --git a/src/sat/smt/polysat_umul_ovfl.h b/src/sat/smt/polysat_umul_ovfl.h new file mode 100644 index 000000000..502ed4bbf --- /dev/null +++ b/src/sat/smt/polysat_umul_ovfl.h @@ -0,0 +1,39 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat multiplication overflow constraint + +Author: + + Jakob Rath, Nikolaj Bjorner (nbjorner) 2021-12-09 + +--*/ +#pragma once +#include "sat/smt/polysat_constraints.h" + +namespace polysat { + + class umul_ovfl_constraint final : public constraint { + + pdd m_p; + pdd m_q; + + void simplify(); + static bool is_always_true(bool is_positive, pdd const& p, pdd const& q) { return eval(p, q) == to_lbool(is_positive); } + static bool is_always_false(bool is_positive, pdd const& p, pdd const& q) { return is_always_true(!is_positive, p, q); } + static lbool eval(pdd const& p, pdd const& q); + + public: + umul_ovfl_constraint(pdd const& p, pdd const& q); + ~umul_ovfl_constraint() override {} + pdd const& p() const { return m_p; } + pdd const& q() const { return m_q; } + std::ostream& display(std::ostream& out, lbool status) const override; + std::ostream& display(std::ostream& out) const override; + lbool eval() const override; + lbool eval(assignment const& a) const override; + }; + +} diff --git a/src/sat/smt/polysat_viable.cpp b/src/sat/smt/polysat_viable.cpp new file mode 100644 index 000000000..79689d01f --- /dev/null +++ b/src/sat/smt/polysat_viable.cpp @@ -0,0 +1,36 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + maintain viable domains + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +Notes: + + +--*/ + + +#include "util/debug.h" +#include "sat/smt/polysat_viable.h" +#include "sat/smt/polysat_core.h" + +namespace polysat { + + std::ostream& operator<<(std::ostream& out, find_t f) { + switch (f) { + case find_t::empty: return out << "empty"; + case find_t::singleton: return out << "singleton"; + case find_t::multiple: return out << "multiple"; + case find_t::resource_out: return out << "resource-out"; + default: return out << ""; + } + } + + +} diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat_viable.h index def069652..2f87e79cc 100644 --- a/src/sat/smt/polysat_viable.h +++ b/src/sat/smt/polysat_viable.h @@ -30,6 +30,8 @@ namespace polysat { class core; + std::ostream& operator<<(std::ostream& out, find_t x); + class viable { core& c; public: From aa82ca3017dd0e82f504da6faec1fdbc310a9313 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 13:25:50 -0800 Subject: [PATCH 10/89] add log helper to util --- src/sat/smt/polysat_constraints.cpp | 7 ++ src/sat/smt/polysat_constraints.h | 2 +- src/util/CMakeLists.txt | 1 + src/util/log.cpp | 125 +++++++++++++++++++++++ src/util/log.h | 148 ++++++++++++++++++++++++++++ src/util/log_helper.h | 105 ++++++++++++++++++++ 6 files changed, 387 insertions(+), 1 deletion(-) create mode 100644 src/util/log.cpp create mode 100644 src/util/log.h create mode 100644 src/util/log_helper.h diff --git a/src/sat/smt/polysat_constraints.cpp b/src/sat/smt/polysat_constraints.cpp index a03b4f5f5..b588019dc 100644 --- a/src/sat/smt/polysat_constraints.cpp +++ b/src/sat/smt/polysat_constraints.cpp @@ -16,6 +16,7 @@ Author: #include "sat/smt/polysat_solver.h" #include "sat/smt/polysat_constraints.h" #include "sat/smt/polysat_ule.h" +#include "sat/smt/polysat_umul_ovfl.h" namespace polysat { @@ -29,6 +30,12 @@ namespace polysat { return is_positive ? sc : ~sc; } + signed_constraint constraints::umul_ovfl(pdd const& p, pdd const& q) { + auto* c = alloc(umul_ovfl_constraint, p, q); + m_trail.push(new_obj_trail(c)); + return signed_constraint(ckind_t::umul_ovfl_t, c); + } + lbool signed_constraint::eval(assignment& a) const { lbool r = m_constraint->eval(a); return m_sign ? ~r : r; diff --git a/src/sat/smt/polysat_constraints.h b/src/sat/smt/polysat_constraints.h index 121fc2da6..b62fbba21 100644 --- a/src/sat/smt/polysat_constraints.h +++ b/src/sat/smt/polysat_constraints.h @@ -85,7 +85,7 @@ namespace polysat { signed_constraint sle(pdd const& p, pdd const& q) { throw default_exception("nyi"); } signed_constraint ult(pdd const& p, pdd const& q) { throw default_exception("nyi"); } signed_constraint slt(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint umul_ovfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint umul_ovfl(pdd const& p, pdd const& q); signed_constraint smul_ovfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } signed_constraint smul_udfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } signed_constraint bit(pdd const& p, unsigned i) { throw default_exception("nyi"); } diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index c2a7e8296..eead1069e 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -26,6 +26,7 @@ z3_add_component(util inf_rational.cpp inf_s_integer.cpp lbool.cpp + log.cpp luby.cpp memory_manager.cpp min_cut.cpp diff --git a/src/util/log.cpp b/src/util/log.cpp new file mode 100644 index 000000000..f3101bfc4 --- /dev/null +++ b/src/util/log.cpp @@ -0,0 +1,125 @@ +#ifndef _MSC_VER +#include // for isatty +#else +#include +#include +#undef min +#endif +#include +#include + +#include "util/util.h" +#include "util/log.h" + +/** +For windows: +https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/isatty?view=msvc-160 + +So include and create platform wrapper for _isatty / isatty. + +Other: +- add option to configure z3 trace feature to point to std::err +- roll this functionality into trace.cpp/trace.h in util +- Generally, generic functionality should not reside in specific directories. +- code diverges on coding conventions. +*/ + +/* + TODO: add deferred logs, i.e., the messages are held back and only printed when a non-conditional message is logged. + Purpose: reduce noise, e.g., when printing prerequisites for transformations that do not always apply. +*/ + +char const* color_red() { return "\x1B[31m"; } +char const* color_yellow() { return "\x1B[33m"; } +char const* color_blue() { return "\x1B[34m"; } +char const* color_reset() { return "\x1B[0m"; } + +#if POLYSAT_LOGGING_ENABLED + +std::atomic g_log_enabled(true); + +void set_log_enabled(bool log_enabled) { + g_log_enabled = log_enabled; +} + +bool get_log_enabled() { + return g_log_enabled; +} + +static LogLevel get_max_log_level(std::string const& fn, std::string const& pretty_fn) { + (void)fn; + (void)pretty_fn; + + // if (fn == "pop_levels") + // return LogLevel::Default; + + // also covers 'reset_marks' and 'set_marks' + if (fn.find("set_mark") != std::string::npos) + return LogLevel::Default; + + // return LogLevel::Verbose; + return LogLevel::Default; +} + +/// Filter log messages +bool polysat_should_log(unsigned verbose_lvl, LogLevel msg_level, std::string fn, std::string pretty_fn) { + if (!g_log_enabled) + return false; + if (get_verbosity_level() < verbose_lvl) + return false; + LogLevel max_log_level = get_max_log_level(fn, pretty_fn); + return msg_level <= max_log_level; +} + +static char const* level_color(LogLevel msg_level) { + switch (msg_level) { + case LogLevel::Heading1: return color_red(); + case LogLevel::Heading2: return color_yellow(); + case LogLevel::Heading3: return color_blue(); + default: return nullptr; + } +} + +int polysat_log_indent_level = 0; + +std::pair polysat_log(LogLevel msg_level, std::string fn, std::string /* pretty_fn */) { + std::ostream& os = std::cerr; + + size_t width = 20; + size_t padding = 0; + if (width >= fn.size()) + padding = width - fn.size(); + else + fn = fn.substr(0, width - 3) + "..."; + char const* color = nullptr; + color = level_color(msg_level); +#ifdef _MSC_VER + HANDLE hOut = GetStdHandle(STD_ERROR_HANDLE); + bool ok = hOut != INVALID_HANDLE_VALUE; + DWORD dwMode = 0; + ok = ok && GetConsoleMode(hOut, &dwMode); + dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + ok = ok && SetConsoleMode(hOut, dwMode); +#else + int const fd = fileno(stderr); + if (color && !isatty(fd)) { color = nullptr; } +#endif + + if (color) + os << color; + os << "[" << fn << "] " << std::string(padding, ' '); + os << std::string(polysat_log_indent_level, ' '); + return {os, (bool)color}; + +} + +polysat_log_indent::polysat_log_indent(int amount): m_amount{amount} { + polysat_log_indent_level += m_amount; +} + +polysat_log_indent::~polysat_log_indent() { + polysat_log_indent_level -= m_amount; +} + + +#endif // POLYSAT_LOGGING_ENABLED diff --git a/src/util/log.h b/src/util/log.h new file mode 100644 index 000000000..fe89f91cf --- /dev/null +++ b/src/util/log.h @@ -0,0 +1,148 @@ +#ifndef POLYSAT_LOG_HPP +#define POLYSAT_LOG_HPP + + +#include +#include +#include +#include "util/log_helper.h" + + +// By default, enable logging only in debug mode +#ifndef POLYSAT_LOGGING_ENABLED +# ifndef NDEBUG +# define POLYSAT_LOGGING_ENABLED 1 +# else +# define POLYSAT_LOGGING_ENABLED 0 +# endif +#endif + + +char const* color_blue(); +char const* color_yellow(); +char const* color_red(); +char const* color_reset(); + + +#if POLYSAT_LOGGING_ENABLED + +void set_log_enabled(bool log_enabled); +bool get_log_enabled(); + +class scoped_set_log_enabled { + bool m_prev; +public: + scoped_set_log_enabled(bool enabled) { + m_prev = get_log_enabled(); + set_log_enabled(enabled); + } + ~scoped_set_log_enabled() { + set_log_enabled(m_prev); + } +}; + +class polysat_log_indent +{ + int m_amount; +public: + polysat_log_indent(int amount); + ~polysat_log_indent(); +}; + +/// Lower log level means more important +enum class LogLevel : int { + None = 0, + Heading1 = 1, + Heading2 = 2, + Heading3 = 3, + Default = 4, +}; + +/// Filter log messages +bool +polysat_should_log(unsigned verbose_lvl, LogLevel msg_level, std::string fn, std::string pretty_fn); + +std::pair +polysat_log(LogLevel msg_level, std::string fn, std::string pretty_fn); + +#ifdef _MSC_VER +#define __PRETTY_FUNCTION__ __FUNCSIG__ +#endif + +#define LOG_(verbose_lvl, log_lvl, x) \ + do { \ + if (polysat_should_log(verbose_lvl, log_lvl, __func__, __PRETTY_FUNCTION__)) { \ + auto pair = polysat_log(log_lvl, __func__, __PRETTY_FUNCTION__); \ + std::ostream& os = pair.first; \ + bool should_reset = pair.second; \ + os << x; \ + if (should_reset) \ + os << color_reset(); \ + os << std::endl; \ + } \ + } while (false) + +#define LOG_CONCAT_HELPER(a,b) a ## b +#define LOG_CONCAT(a,b) LOG_CONCAT_HELPER(a,b) + +#define LOG_INDENT(verbose_lvl, log_lvl, x) \ + LOG_(verbose_lvl, log_lvl, x); \ + polysat_log_indent LOG_CONCAT(polysat_log_indent_obj_, __LINE__) (4); + +#define LOG_H1(x) LOG_INDENT(0, LogLevel::Heading1, x) +#define LOG_H2(x) LOG_INDENT(0, LogLevel::Heading2, x) +#define LOG_H3(x) LOG_INDENT(0, LogLevel::Heading3, x) +#define LOG(x) LOG_(0, LogLevel::Default , x) + +#define LOG_H1_V(verbose_lvl, x) LOG_INDENT(verbose_lvl, LogLevel::Heading1, x) +#define LOG_H2_V(verbose_lvl, x) LOG_INDENT(verbose_lvl, LogLevel::Heading2, x) +#define LOG_H3_V(verbose_lvl, x) LOG_INDENT(verbose_lvl, LogLevel::Heading3, x) +#define LOG_V(verbose_lvl, x) LOG_(verbose_lvl, LogLevel::Default , x) + +#define COND_LOG(c, x) if (c) LOG(x) +#define LOGE(x) LOG(#x << " = " << (x)) + +#define IF_LOGGING(x) \ + do { \ + if (get_log_enabled()) { \ + x; \ + } \ + } while (false) + + +#else // POLYSAT_LOGGING_ENABLED + +inline void set_log_enabled(bool) {} +inline bool get_log_enabled() { return false; } +class scoped_set_log_enabled { +public: + scoped_set_log_enabled(bool) {} +}; + +#define LOG_(vlvl, lvl, x) \ + do { \ + /* do nothing */ \ + } while (false) + +#define LOG_H1(x) LOG_(0, 0, x) +#define LOG_H2(x) LOG_(0, 0, x) +#define LOG_H3(x) LOG_(0, 0, x) +#define LOG(x) LOG_(0, 0, x) + +#define LOG_H1_V(v, x) LOG_(v, 0, x) +#define LOG_H2_V(v, x) LOG_(v, 0, x) +#define LOG_H3_V(v, x) LOG_(v, 0, x) +#define LOG_V(v, x) LOG_(v, 0, x) + +#define COND_LOG(c, x) LOG_(0, c, x) +#define LOGE(x) LOG_(0, 0, x) + +#define IF_LOGGING(x) \ + do { \ + /* do nothing */ \ + } while (false) + +#endif // POLYSAT_LOGGING_ENABLED + + +#endif // POLYSAT_LOG_HPP diff --git a/src/util/log_helper.h b/src/util/log_helper.h new file mode 100644 index 000000000..7be2693b4 --- /dev/null +++ b/src/util/log_helper.h @@ -0,0 +1,105 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Logging support + +Abstract: + + Utilities for logging. + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-6 + +--*/ +#pragma once + +#include +#include +#include +#include + +template +struct show_deref_t { + T const* ptr; +}; + +template +std::ostream& operator<<(std::ostream& out, show_deref_t s) { + if (s.ptr) + return out << *s.ptr; + else + return out << ""; +} + +template +show_deref_t show_deref(T* ptr) { + return show_deref_t{ptr}; +} + +template ().get())>::type> +show_deref_t show_deref(Ptr const& ptr) { + return show_deref_t{ptr.get()}; +} + + +template +struct repeat { + size_t count; + T const& obj; + repeat(size_t count, T const& obj): count(count), obj(obj) {} +}; + +template +std::ostream& operator<<(std::ostream& out, repeat const& r) { + for (size_t i = r.count; i-- > 0; ) + out << r.obj; + return out; +} + +enum class pad_direction { + left, + right, +}; + +template +struct pad { + pad_direction dir; + unsigned width; + T const& obj; + pad(pad_direction dir, unsigned width, T const& obj): dir(dir), width(width), obj(obj) {} +}; + +template +std::ostream& operator<<(std::ostream& out, pad const& p) { + std::stringstream tmp; + tmp << p.obj; + std::string s = tmp.str(); + size_t n = (s.length() < p.width) ? (p.width - s.length()) : 0; + switch (p.dir) { + case pad_direction::left: + out << repeat(n, ' ') << s; + break; + case pad_direction::right: + out << s << repeat(n, ' '); + break; + } + return out; +} + +/// Fill with spaces to the right: +/// out << rpad(8, "hello") +/// writes "hello ". +template +pad rpad(unsigned width, T const& obj) { + return pad(pad_direction::right, width, obj); +} + +/// Fill with spaces to the left. +template +pad lpad(unsigned width, T const& obj) { + return pad(pad_direction::left, width, obj); +} From ddb55cc3dcc371ec5e73b48acf931269f1921218 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 14:50:33 -0800 Subject: [PATCH 11/89] porting viable --- src/sat/smt/polysat_core.cpp | 4 +- src/sat/smt/polysat_types.h | 7 +- src/sat/smt/polysat_viable.cpp | 196 +++++++++++++++++++++++++++++++++ src/sat/smt/polysat_viable.h | 77 ++++++++++++- src/util/rational.h | 10 ++ 5 files changed, 285 insertions(+), 9 deletions(-) diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index 3de88d93b..d4b62e8b6 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -40,7 +40,7 @@ namespace polysat { public: mk_assign_var(pvar v, core& c) : m_var(v), c(c) {} void undo() { - c.m_justification[m_var] = dependency::null_dependency(); + c.m_justification[m_var] = null_dependency; c.m_assignment.pop(); } }; @@ -106,7 +106,7 @@ namespace polysat { unsigned v = m_vars.size(); m_vars.push_back(sz2pdd(sz).mk_var(v)); m_activity.push_back({ sz, 0 }); - m_justification.push_back(dependency::null_dependency()); + m_justification.push_back(null_dependency); m_watch.push_back({}); m_var_queue.mk_var_eh(v); s.ctx.push(mk_add_var(*this)); diff --git a/src/sat/smt/polysat_types.h b/src/sat/smt/polysat_types.h index 42f283cc7..c8c8324d7 100644 --- a/src/sat/smt/polysat_types.h +++ b/src/sat/smt/polysat_types.h @@ -17,13 +17,16 @@ namespace polysat { using pdd = dd::pdd; using pvar = unsigned; + using pvar_vector = unsigned_vector; + inline const pvar null_var = UINT_MAX; + + class dependency { unsigned m_index; unsigned m_level; public: dependency(sat::literal lit, unsigned level) : m_index(2 * lit.index()), m_level(level) {} dependency(unsigned var_idx, unsigned level) : m_index(1 + 2 * var_idx), m_level(level) {} - static dependency null_dependency() { return dependency(0, UINT_MAX); } bool is_null() const { return m_level == UINT_MAX; } bool is_literal() const { return m_index % 2 == 0; } sat::literal literal() const { SASSERT(is_literal()); return sat::to_literal(m_index / 2); } @@ -31,6 +34,8 @@ namespace polysat { unsigned level() const { return m_level; } }; + inline const dependency null_dependency = dependency(0, UINT_MAX); + inline std::ostream& operator<<(std::ostream& out, dependency d) { if (d.is_literal()) return out << d.literal() << "@" << d.level(); diff --git a/src/sat/smt/polysat_viable.cpp b/src/sat/smt/polysat_viable.cpp index 79689d01f..deb6d415c 100644 --- a/src/sat/smt/polysat_viable.cpp +++ b/src/sat/smt/polysat_viable.cpp @@ -17,11 +17,21 @@ Notes: #include "util/debug.h" +#include "util/log.h" #include "sat/smt/polysat_viable.h" #include "sat/smt/polysat_core.h" namespace polysat { + using dd::val_pp; + + viable::viable(core& c) : c(c), cs(c.cs()), m_forbidden_intervals(c) {} + + viable::~viable() { + for (auto* e : m_alloc) + dealloc(e); + } + std::ostream& operator<<(std::ostream& out, find_t f) { switch (f) { case find_t::empty: return out << "empty"; @@ -32,5 +42,191 @@ namespace polysat { } } + viable::entry* viable::alloc_entry(pvar var) { + if (m_alloc.empty()) + return alloc(entry); + auto* e = m_alloc.back(); + e->reset(); + e->var = var; + m_alloc.pop_back(); + return e; + } + + find_t viable::find_viable(pvar v, rational& out_val) { throw default_exception("nyi"); } + + /* + * Explain why the current variable is not viable or signleton. + */ + dependency_vector viable::explain() { throw default_exception("nyi"); } + + /* + * Register constraint at index 'idx' as unitary in v. + */ + void viable::add_unitary(pvar v, unsigned idx) { + if (c.is_assigned(v)) + return; + auto [sc, d] = c.m_constraint_trail[idx]; + + entry* ne = alloc_entry(v); + if (!m_forbidden_intervals.get_interval(sc, v, *ne)) { + m_alloc.push_back(ne); + return; + } + + if (ne->interval.is_currently_empty()) { + m_alloc.push_back(ne); + return; + } + + if (ne->coeff == 1) { + intersect(v, ne); + return; + } + else if (ne->coeff == -1) { + insert(ne, v, m_diseq_lin, entry_kind::diseq_e); + return; + } + else { + unsigned const w = c.size(v); + unsigned const k = ne->coeff.parity(w); + // unsigned const lo_parity = ne->interval.lo_val().parity(w); + // unsigned const hi_parity = ne->interval.hi_val().parity(w); + + display_one(std::cerr << "try to reduce entry: ", v, ne) << "\n"; + + if (k > 0 && ne->coeff.is_power_of_two()) { + // reduction of coeff gives us a unit entry + // + // 2^k a x \not\in [ lo ; hi [ + // + // new_lo = lo[w-1:k] if lo[k-1:0] = 0 + // lo[w-1:k] + 1 otherwise + // + // new_hi = hi[w-1:k] if hi[k-1:0] = 0 + // hi[w-1:k] + 1 otherwise + // + // Reference: Fig. 1 (dtrim) in BitvectorsMCSAT + // + pdd const& pdd_lo = ne->interval.lo(); + pdd const& pdd_hi = ne->interval.hi(); + rational const& lo = ne->interval.lo_val(); + rational const& hi = ne->interval.hi_val(); + + rational new_lo = machine_div2k(lo, k); + if (mod2k(lo, k).is_zero()) + ne->side_cond.push_back(cs.eq(pdd_lo * rational::power_of_two(w - k))); + else { + new_lo += 1; + ne->side_cond.push_back(~cs.eq(pdd_lo * rational::power_of_two(w - k))); + } + + rational new_hi = machine_div2k(hi, k); + if (mod2k(hi, k).is_zero()) + ne->side_cond.push_back(cs.eq(pdd_hi * rational::power_of_two(w - k))); + else { + new_hi += 1; + ne->side_cond.push_back(~cs.eq(pdd_hi * rational::power_of_two(w - k))); + } + + // we have to update also the pdd bounds accordingly, but it seems not worth introducing new variables for this eagerly + // new_lo = lo[:k] etc. + // TODO: for now just disable the FI-lemma if this case occurs + ne->valid_for_lemma = false; + + if (new_lo == new_hi) { + // empty or full + // if (ne->interval.currently_contains(rational::zero())) + NOT_IMPLEMENTED_YET(); + } + + ne->coeff = machine_div2k(ne->coeff, k); + ne->interval = eval_interval::proper(pdd_lo, new_lo, pdd_hi, new_hi); + ne->bit_width -= k; + display_one(std::cerr << "reduced entry: ", v, ne) << "\n"; + LOG("reduced entry to unit in bitwidth " << ne->bit_width); + return intersect(v, ne); + } + + // TODO: later, can reduce according to shared_parity + // unsigned const shared_parity = std::min(coeff_parity, std::min(lo_parity, hi_parity)); + + insert(ne, v, m_equal_lin, entry_kind::equal_e); + return; + } + + } + + void viable::intersect(pvar v, entry* e) { + + throw default_exception("nyi"); + } + + void viable::log() { + for (pvar v = 0; v < m_units.size(); ++v) + log(v); + } + + void viable::log(pvar v) { + throw default_exception("nyi"); + } + + void viable::insert(entry* e, pvar v, ptr_vector& entries, entry_kind k) { + throw default_exception("nyi"); + } + + std::ostream& viable::display_one(std::ostream& out, pvar v, entry const* e) const { + auto& m = c.var2pdd(v); + if (e->coeff == -1) { + // p*val + q > r*val + s if e->src.is_positive() + // p*val + q >= r*val + s if e->src.is_negative() + // Note that e->interval is meaningless in this case, + // we just use it to transport the values p,q,r,s + rational const& p = e->interval.lo_val(); + rational const& q_ = e->interval.lo().val(); + rational const& r = e->interval.hi_val(); + rational const& s_ = e->interval.hi().val(); + out << "[ "; + out << val_pp(m, p, true) << "*v" << v << " + " << val_pp(m, q_); + out << (e->src[0].is_positive() ? " > " : " >= "); + out << val_pp(m, r, true) << "*v" << v << " + " << val_pp(m, s_); + out << " ] "; + } + else if (e->coeff != 1) + out << e->coeff << " * v" << v << " " << e->interval << " "; + else + out << e->interval << " "; + if (e->side_cond.size() <= 5) + out << e->side_cond << " "; + else + out << e->side_cond.size() << " side-conditions "; + unsigned count = 0; + for (const auto& src : e->src) { + ++count; + out << src << "; "; + if (count > 10) { + out << " ..."; + break; + } + } + return out; + } + + std::ostream& viable::display_all(std::ostream& out, pvar v, entry const* e, char const* delimiter) const { + if (!e) + return out; + entry const* first = e; + unsigned count = 0; + do { + display_one(out, v, e) << delimiter; + e = e->next(); + ++count; + if (count > 10) { + out << " ..."; + break; + } + } + while (e != first); + return out; + } } diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat_viable.h index 2f87e79cc..322b30551 100644 --- a/src/sat/smt/polysat_viable.h +++ b/src/sat/smt/polysat_viable.h @@ -17,7 +17,12 @@ Author: #pragma once #include "util/rational.h" +#include "util/dlist.h" +#include "util/map.h" +#include "util/small_object_allocator.h" + #include "sat/smt/polysat_types.h" +#include "sat/smt/polysat_fi.h" namespace polysat { @@ -29,28 +34,88 @@ namespace polysat { }; class core; + class constraints; std::ostream& operator<<(std::ostream& out, find_t x); class viable { core& c; - public: - viable(core& c) : c(c) {} + constraints& cs; + forbidden_intervals m_forbidden_intervals; - /** + struct entry final : public dll_base, public fi_record { + /// whether the entry has been created by refinement (from constraints in 'fi_record::src') + bool refined = false; + /// whether the entry is part of the current set of intervals, or stashed away for backtracking + bool active = true; + bool valid_for_lemma = true; + pvar var = null_var; + + void reset() { + // dll_base::init(this); // we never did this in alloc_entry either + fi_record::reset(); + refined = false; + active = true; + valid_for_lemma = true; + var = null_var; + } + }; + + enum class entry_kind { unit_e, equal_e, diseq_e }; + + struct layer final { + entry* entries = nullptr; + unsigned bit_width = 0; + layer(unsigned bw) : bit_width(bw) {} + }; + + class layers final { + svector m_layers; + public: + svector const& get_layers() const { return m_layers; } + layer& ensure_layer(unsigned bit_width); + layer* get_layer(unsigned bit_width); + layer* get_layer(entry* e) { return get_layer(e->bit_width); } + layer const* get_layer(unsigned bit_width) const; + layer const* get_layer(entry* e) const { return get_layer(e->bit_width); } + entry* get_entries(unsigned bit_width) const { layer const* l = get_layer(bit_width); return l ? l->entries : nullptr; } + }; + + ptr_vector m_alloc; + vector m_units; // set of viable values based on unit multipliers, layered by bit-width in descending order + ptr_vector m_equal_lin; // entries that have non-unit multipliers, but are equal + ptr_vector m_diseq_lin; // entries that have distinct non-zero multipliers + + entry* alloc_entry(pvar v); + + std::ostream& display_one(std::ostream& out, pvar v, entry const* e) const; + std::ostream& display_all(std::ostream& out, pvar v, entry const* e, char const* delimiter = "") const; + void log(); + void log(pvar v); + + void insert(entry* e, pvar v, ptr_vector& entries, entry_kind k); + + void intersect(pvar v, entry* e); + + + public: + viable(core& c); + ~viable(); + + /** * Find a next viable value for variable. */ - find_t find_viable(pvar v, rational& out_val) { throw default_exception("nyi"); } + find_t find_viable(pvar v, rational& out_val); /* * Explain why the current variable is not viable or signleton. */ - dependency_vector explain() { throw default_exception("nyi"); } + dependency_vector explain(); /* * Register constraint at index 'idx' as unitary in v. */ - void add_unitary(pvar v, unsigned idx) { throw default_exception("nyi"); } + void add_unitary(pvar v, unsigned idx); }; diff --git a/src/util/rational.h b/src/util/rational.h index f47fddefe..4253bd4a7 100644 --- a/src/util/rational.h +++ b/src/util/rational.h @@ -501,6 +501,16 @@ public: return k; } + /** Number of trailing zeros in an N-bit representation */ + unsigned parity(unsigned num_bits) const { + SASSERT(!is_neg()); + SASSERT(*this < rational::power_of_two(num_bits)); + if (is_zero()) + return num_bits; + return trailing_zeros(); + } + + static bool limit_denominator(rational &num, rational const& limit); }; From 45f3aab5ff7facb1ef7b2cbf6a0332df5e417e4f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 15:28:05 -0800 Subject: [PATCH 12/89] porting viable --- src/sat/smt/polysat_core.cpp | 4 + src/sat/smt/polysat_core.h | 1 + src/sat/smt/polysat_viable.cpp | 259 ++++++++++++++++++++++++++++++++- src/sat/smt/polysat_viable.h | 12 +- src/util/dlist.h | 10 ++ 5 files changed, 277 insertions(+), 9 deletions(-) diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index d4b62e8b6..be4bb0cf0 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -308,4 +308,8 @@ namespace polysat { return m_assignment.apply_to(p); } + trail_stack& core::trail() { + return s.get_trail_stack(); + } + } diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h index 3c8a79bd6..b7c9f9eb8 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat_core.h @@ -128,6 +128,7 @@ namespace polysat { unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } constraints& cs() { return m_constraints; } + trail_stack& trail(); std::ostream& display(std::ostream& out) const { throw default_exception("nyi"); } }; diff --git a/src/sat/smt/polysat_viable.cpp b/src/sat/smt/polysat_viable.cpp index deb6d415c..d68822563 100644 --- a/src/sat/smt/polysat_viable.cpp +++ b/src/sat/smt/polysat_viable.cpp @@ -42,6 +42,32 @@ namespace polysat { } } + struct viable::pop_viable_trail : public trail { + viable& m_s; + entry* e; + pvar v; + entry_kind k; + public: + pop_viable_trail(viable& s, entry* e, pvar v, entry_kind k) + : m_s(s), e(e), v(v), k(k) {} + void undo() override { + m_s.pop_viable(e, v, k); + } + }; + + struct viable::push_viable_trail : public trail { + viable& m_s; + entry* e; + pvar v; + entry_kind k; + public: + push_viable_trail(viable& s, entry* e, pvar v, entry_kind k) + : m_s(s), e(e), v(v), k(k) {} + void undo() override { + m_s.push_viable(e, v, k); + } + }; + viable::entry* viable::alloc_entry(pvar var) { if (m_alloc.empty()) return alloc(entry); @@ -52,7 +78,10 @@ namespace polysat { return e; } - find_t viable::find_viable(pvar v, rational& out_val) { throw default_exception("nyi"); } + find_t viable::find_viable(pvar v, rational& out_val) { + ensure_var(v); + throw default_exception("nyi"); + } /* * Explain why the current variable is not viable or signleton. @@ -63,6 +92,9 @@ namespace polysat { * Register constraint at index 'idx' as unitary in v. */ void viable::add_unitary(pvar v, unsigned idx) { + + ensure_var(v); + if (c.is_assigned(v)) return; auto [sc, d] = c.m_constraint_trail[idx]; @@ -76,7 +108,7 @@ namespace polysat { if (ne->interval.is_currently_empty()) { m_alloc.push_back(ne); return; - } + } if (ne->coeff == 1) { intersect(v, ne); @@ -144,7 +176,7 @@ namespace polysat { ne->bit_width -= k; display_one(std::cerr << "reduced entry: ", v, ne) << "\n"; LOG("reduced entry to unit in bitwidth " << ne->bit_width); - return intersect(v, ne); + intersect(v, ne); } // TODO: later, can reduce according to shared_parity @@ -153,12 +185,92 @@ namespace polysat { insert(ne, v, m_equal_lin, entry_kind::equal_e); return; } - } - void viable::intersect(pvar v, entry* e) { + void viable::ensure_var(pvar v) { + while (v >= m_units.size()) { + m_units.push_back(layers()); + m_equal_lin.push_back(nullptr); + m_diseq_lin.push_back(nullptr); + } + } - throw default_exception("nyi"); + bool viable::intersect(pvar v, entry* ne) { + SASSERT(!c.is_assigned(v)); + SASSERT(!ne->src.empty()); + entry*& entries = m_units[v].ensure_layer(ne->bit_width).entries; + entry* e = entries; + if (e && e->interval.is_full()) { + m_alloc.push_back(ne); + return false; + } + + if (ne->interval.is_currently_empty()) { + m_alloc.push_back(ne); + return false; + } + + auto create_entry = [&]() { + c.trail().push(pop_viable_trail(*this, ne, v, entry_kind::unit_e)); + ne->init(ne); + return ne; + }; + + auto remove_entry = [&](entry* e) { + c.trail().push(push_viable_trail(*this, e, v, entry_kind::unit_e)); + e->remove_from(entries, e); + e->active = false; + }; + + if (ne->interval.is_full()) { + // for (auto const& l : m_units[v].get_layers()) + // while (l.entries) + // remove_entry(l.entries); + while (entries) + remove_entry(entries); + entries = create_entry(); + return true; + } + + if (!e) + entries = create_entry(); + else { + entry* first = e; + do { + if (e->interval.currently_contains(ne->interval)) { + m_alloc.push_back(ne); + return false; + } + while (ne->interval.currently_contains(e->interval)) { + entry* n = e->next(); + remove_entry(e); + if (!entries) { + entries = create_entry(); + return true; + } + if (e == first) + first = n; + e = n; + } + SASSERT(e->interval.lo_val() != ne->interval.lo_val()); + if (e->interval.lo_val() > ne->interval.lo_val()) { + if (first->prev()->interval.currently_contains(ne->interval)) { + m_alloc.push_back(ne); + return false; + } + e->insert_before(create_entry()); + if (e == first) + entries = e->prev(); + SASSERT(well_formed(m_units[v])); + return true; + } + e = e->next(); + } while (e != first); + // otherwise, append to end of list + first->insert_before(create_entry()); + } + SASSERT(well_formed(m_units[v])); + return true; } void viable::log() { @@ -170,10 +282,94 @@ namespace polysat { throw default_exception("nyi"); } - void viable::insert(entry* e, pvar v, ptr_vector& entries, entry_kind k) { - throw default_exception("nyi"); + + viable::layer& viable::layers::ensure_layer(unsigned bit_width) { + for (unsigned i = 0; i < m_layers.size(); ++i) { + layer& l = m_layers[i]; + if (l.bit_width == bit_width) + return l; + else if (l.bit_width < bit_width) { + m_layers.push_back(layer(0)); + for (unsigned j = m_layers.size(); --j > i; ) + m_layers[j] = m_layers[j - 1]; + m_layers[i] = layer(bit_width); + return m_layers[i]; + } + } + m_layers.push_back(layer(bit_width)); + return m_layers.back(); } + viable::layer* viable::layers::get_layer(unsigned bit_width) { + return const_cast(std::as_const(*this).get_layer(bit_width)); + } + + viable::layer const* viable::layers::get_layer(unsigned bit_width) const { + for (layer const& l : m_layers) + if (l.bit_width == bit_width) + return &l; + return nullptr; + } + + void viable::pop_viable(entry* e, pvar v, entry_kind k) { + SASSERT(well_formed(m_units[v])); + SASSERT(e->active); + e->active = false; + switch (k) { + case entry_kind::unit_e: + entry::remove_from(m_units[v].get_layer(e)->entries, e); + SASSERT(well_formed(m_units[v])); + break; + case entry_kind::equal_e: + entry::remove_from(m_equal_lin[v], e); + break; + case entry_kind::diseq_e: + entry::remove_from(m_diseq_lin[v], e); + break; + default: + UNREACHABLE(); + break; + } + m_alloc.push_back(e); + } + + void viable::push_viable(entry* e, pvar v, entry_kind k) { + // display_one(verbose_stream() << "Push entry: ", v, e) << "\n"; + entry*& entries = m_units[v].get_layer(e)->entries; + SASSERT(e->prev() != e || !entries); + SASSERT(e->prev() != e || e->next() == e); + SASSERT(k == entry_kind::unit_e); + SASSERT(!e->active); + e->active = true; + (void)k; + SASSERT(well_formed(m_units[v])); + if (e->prev() != e) { + entry* pos = e->prev(); + e->init(e); + pos->insert_after(e); + if (e->interval.lo_val() < entries->interval.lo_val()) + entries = e; + } + else + entries = e; + SASSERT(well_formed(m_units[v])); + } + + void viable::insert(entry* e, pvar v, ptr_vector& entries, entry_kind k) { + SASSERT(well_formed(m_units[v])); + + c.trail().push(pop_viable_trail(*this, e, v, k)); + + e->init(e); + if (!entries[v]) + entries[v] = e; + else + e->insert_after(entries[v]); + SASSERT(entries[v]->invariant()); + SASSERT(well_formed(m_units[v])); + } + + std::ostream& viable::display_one(std::ostream& out, pvar v, entry const* e) const { auto& m = c.var2pdd(v); if (e->coeff == -1) { @@ -229,4 +425,51 @@ namespace polysat { return out; } + /* + * Lower bounds are strictly ascending. + * Intervals don't contain each-other (since lower bounds are ascending, it suffices to check containment in one direction). + */ + bool viable::well_formed(entry* e) { + if (!e) + return true; + entry* first = e; + while (true) { + if (!e->active) + return false; + + if (e->interval.is_full()) + return e->next() == e; + if (e->interval.is_currently_empty()) + return false; + + auto* n = e->next(); + if (n != e && e->interval.currently_contains(n->interval)) + return false; + + if (n == first) + break; + if (e->interval.lo_val() >= n->interval.lo_val()) + return false; + e = n; + } + return true; + } + + /* + * Layers are ordered in strictly descending bit-width. + * Entries in each layer are well-formed. + */ + bool viable::well_formed(layers const& ls) { + unsigned prev_width = std::numeric_limits::max(); + for (layer const& l : ls.get_layers()) { + if (!well_formed(l.entries)) + return false; + if (!all_of(dll_elements(l.entries), [&l](entry const& e) { return e.bit_width == l.bit_width; })) + return false; + if (prev_width <= l.bit_width) + return false; + prev_width = l.bit_width; + } + return true; + } } diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat_viable.h index 322b30551..f1c826d10 100644 --- a/src/sat/smt/polysat_viable.h +++ b/src/sat/smt/polysat_viable.h @@ -86,6 +86,9 @@ namespace polysat { ptr_vector m_equal_lin; // entries that have non-unit multipliers, but are equal ptr_vector m_diseq_lin; // entries that have distinct non-zero multipliers + bool well_formed(entry* e); + bool well_formed(layers const& ls); + entry* alloc_entry(pvar v); std::ostream& display_one(std::ostream& out, pvar v, entry const* e) const; @@ -93,13 +96,20 @@ namespace polysat { void log(); void log(pvar v); + struct pop_viable_trail; + void pop_viable(entry* e, pvar v, entry_kind k); + struct push_viable_trail; + void push_viable(entry* e, pvar v, entry_kind k); + void insert(entry* e, pvar v, ptr_vector& entries, entry_kind k); - void intersect(pvar v, entry* e); + bool intersect(pvar v, entry* e); + void ensure_var(pvar v); public: viable(core& c); + ~viable(); /** diff --git a/src/util/dlist.h b/src/util/dlist.h index e5c95b8cf..07aefa97e 100644 --- a/src/util/dlist.h +++ b/src/util/dlist.h @@ -223,6 +223,16 @@ public: } }; +template +class dll_elements { + T const* m_list; +public: + dll_elements(T const* list) : m_list(list) {} + dll_iterator begin() const { return dll_iterator::mk_begin(m_list); } + dll_iterator end() const { return dll_iterator::mk_end(m_list); } +}; + + template < typename T , typename U = std::enable_if_t, T>> // should only match if T actually inherits from dll_base > From 9bfecead73c67fea0e00cea44a3da7c6b8c7eca0 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 09:38:18 -0800 Subject: [PATCH 13/89] reorganize polysat functionality to use abstract solver interface make dependency be self-contained --- scripts/mk_project.py | 3 +- src/CMakeLists.txt | 1 + src/sat/smt/CMakeLists.txt | 7 - src/sat/smt/polysat/CMakeLists.txt | 15 ++ .../smt/{ => polysat}/polysat_assignment.cpp | 4 +- .../smt/{ => polysat}/polysat_assignment.h | 2 +- .../smt/{ => polysat}/polysat_constraints.cpp | 9 +- .../smt/{ => polysat}/polysat_constraints.h | 2 +- src/sat/smt/{ => polysat}/polysat_core.cpp | 29 ++- src/sat/smt/{ => polysat}/polysat_core.h | 14 +- src/sat/smt/{ => polysat}/polysat_fi.cpp | 10 +- src/sat/smt/{ => polysat}/polysat_fi.h | 6 +- src/sat/smt/polysat/polysat_interval.h | 224 ++++++++++++++++++ src/sat/smt/polysat/polysat_types.h | 67 ++++++ src/sat/smt/{ => polysat}/polysat_ule.cpp | 4 +- src/sat/smt/{ => polysat}/polysat_ule.h | 5 +- .../smt/{ => polysat}/polysat_umul_ovfl.cpp | 6 +- src/sat/smt/{ => polysat}/polysat_umul_ovfl.h | 2 +- src/sat/smt/{ => polysat}/polysat_viable.cpp | 4 +- src/sat/smt/{ => polysat}/polysat_viable.h | 4 +- src/sat/smt/polysat_solver.cpp | 18 +- src/sat/smt/polysat_solver.h | 20 +- src/sat/smt/polysat_types.h | 48 ---- 23 files changed, 381 insertions(+), 123 deletions(-) create mode 100644 src/sat/smt/polysat/CMakeLists.txt rename src/sat/smt/{ => polysat}/polysat_assignment.cpp (97%) rename src/sat/smt/{ => polysat}/polysat_assignment.h (98%) rename src/sat/smt/{ => polysat}/polysat_constraints.cpp (85%) rename src/sat/smt/{ => polysat}/polysat_constraints.h (99%) rename src/sat/smt/{ => polysat}/polysat_core.cpp (92%) rename src/sat/smt/{ => polysat}/polysat_core.h (94%) rename src/sat/smt/{ => polysat}/polysat_fi.cpp (98%) rename src/sat/smt/{ => polysat}/polysat_fi.h (96%) create mode 100644 src/sat/smt/polysat/polysat_interval.h create mode 100644 src/sat/smt/polysat/polysat_types.h rename src/sat/smt/{ => polysat}/polysat_ule.cpp (99%) rename src/sat/smt/{ => polysat}/polysat_ule.h (94%) rename src/sat/smt/{ => polysat}/polysat_umul_ovfl.cpp (92%) rename src/sat/smt/{ => polysat}/polysat_umul_ovfl.h (95%) rename src/sat/smt/{ => polysat}/polysat_viable.cpp (99%) rename src/sat/smt/{ => polysat}/polysat_viable.h (97%) delete mode 100644 src/sat/smt/polysat_types.h diff --git a/scripts/mk_project.py b/scripts/mk_project.py index 77bb1b680..0f5dc26ae 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -58,7 +58,8 @@ def init_project_def(): add_lib('proto_model', ['model', 'rewriter', 'smt_params'], 'smt/proto_model') add_lib('smt', ['bit_blaster', 'macros', 'normal_forms', 'cmd_context', 'proto_model', 'solver_assertions', 'substitution', 'grobner', 'simplex', 'proofs', 'pattern', 'parser_util', 'fpa', 'lp']) - add_lib('sat_smt', ['sat', 'euf', 'smt', 'tactic', 'solver', 'smt_params', 'bit_blaster', 'fpa', 'mbp', 'normal_forms', 'lp', 'pattern', 'qe_lite'], 'sat/smt') + add_lib('polysat', ['util', 'dd'], 'sat/smt/polysat'), + add_lib('sat_smt', ['sat', 'euf', 'smt', 'tactic', 'solver', 'smt_params', 'bit_blaster', 'fpa', 'mbp', 'polysat', 'normal_forms', 'lp', 'pattern', 'qe_lite'], 'sat/smt') add_lib('sat_tactic', ['tactic', 'sat', 'solver', 'sat_smt'], 'sat/tactic') add_lib('nlsat_tactic', ['nlsat', 'sat_tactic', 'arith_tactics'], 'nlsat/tactic') add_lib('bv_tactics', ['tactic', 'bit_blaster', 'core_tactics'], 'tactic/bv') diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a12571f35..13fd58db8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -73,6 +73,7 @@ add_subdirectory(parsers/smt2) add_subdirectory(solver/assertions) add_subdirectory(ast/pattern) add_subdirectory(math/lp) +add_subdirectory(sat/smt/polysat) add_subdirectory(sat/smt) add_subdirectory(sat/tactic) add_subdirectory(nlsat/tactic) diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index bdc602da9..95d0a5324 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -33,16 +33,9 @@ z3_add_component(sat_smt pb_internalize.cpp pb_pb.cpp pb_solver.cpp - polysat_assignment.cpp - polysat_constraints.cpp - polysat_core.cpp polysat_internalize.cpp - polysat_fi.cpp polysat_model.cpp polysat_solver.cpp - polysat_ule.cpp - polysat_umul_ovfl.cpp - polysat_viable.cpp q_clause.cpp q_ematch.cpp q_eval.cpp diff --git a/src/sat/smt/polysat/CMakeLists.txt b/src/sat/smt/polysat/CMakeLists.txt new file mode 100644 index 000000000..0011e0ee5 --- /dev/null +++ b/src/sat/smt/polysat/CMakeLists.txt @@ -0,0 +1,15 @@ +z3_add_component(polysat + SOURCES + polysat_assignment.cpp + polysat_constraints.cpp + polysat_core.cpp + polysat_fi.cpp + polysat_ule.cpp + polysat_umul_ovfl.cpp + polysat_viable.cpp + COMPONENT_DEPENDENCIES + util + dd + smt_params +) + diff --git a/src/sat/smt/polysat_assignment.cpp b/src/sat/smt/polysat/polysat_assignment.cpp similarity index 97% rename from src/sat/smt/polysat_assignment.cpp rename to src/sat/smt/polysat/polysat_assignment.cpp index a985188fa..aedf6d409 100644 --- a/src/sat/smt/polysat_assignment.cpp +++ b/src/sat/smt/polysat/polysat_assignment.cpp @@ -12,8 +12,8 @@ Author: --*/ -#include "sat/smt/polysat_assignment.h" -#include "sat/smt/polysat_core.h" +#include "sat/smt/polysat/polysat_assignment.h" +#include "sat/smt/polysat/polysat_core.h" namespace polysat { diff --git a/src/sat/smt/polysat_assignment.h b/src/sat/smt/polysat/polysat_assignment.h similarity index 98% rename from src/sat/smt/polysat_assignment.h rename to src/sat/smt/polysat/polysat_assignment.h index daff03dd5..befaad0b7 100644 --- a/src/sat/smt/polysat_assignment.h +++ b/src/sat/smt/polysat/polysat_assignment.h @@ -13,7 +13,7 @@ Author: --*/ #pragma once #include "util/scoped_ptr_vector.h" -#include "sat/smt/polysat_types.h" +#include "sat/smt/polysat/polysat_types.h" namespace polysat { diff --git a/src/sat/smt/polysat_constraints.cpp b/src/sat/smt/polysat/polysat_constraints.cpp similarity index 85% rename from src/sat/smt/polysat_constraints.cpp rename to src/sat/smt/polysat/polysat_constraints.cpp index b588019dc..99da7b0db 100644 --- a/src/sat/smt/polysat_constraints.cpp +++ b/src/sat/smt/polysat/polysat_constraints.cpp @@ -12,11 +12,10 @@ Author: --*/ -#include "sat/smt/polysat_core.h" -#include "sat/smt/polysat_solver.h" -#include "sat/smt/polysat_constraints.h" -#include "sat/smt/polysat_ule.h" -#include "sat/smt/polysat_umul_ovfl.h" +#include "sat/smt/polysat/polysat_core.h" +#include "sat/smt/polysat/polysat_constraints.h" +#include "sat/smt/polysat/polysat_ule.h" +#include "sat/smt/polysat/polysat_umul_ovfl.h" namespace polysat { diff --git a/src/sat/smt/polysat_constraints.h b/src/sat/smt/polysat/polysat_constraints.h similarity index 99% rename from src/sat/smt/polysat_constraints.h rename to src/sat/smt/polysat/polysat_constraints.h index b62fbba21..687b3d91a 100644 --- a/src/sat/smt/polysat_constraints.h +++ b/src/sat/smt/polysat/polysat_constraints.h @@ -15,7 +15,7 @@ Author: #pragma once #include "util/trail.h" -#include "sat/smt/polysat_types.h" +#include "sat/smt/polysat/polysat_types.h" namespace polysat { diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat/polysat_core.cpp similarity index 92% rename from src/sat/smt/polysat_core.cpp rename to src/sat/smt/polysat/polysat_core.cpp index be4bb0cf0..07eeaa0c1 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat/polysat_core.cpp @@ -29,8 +29,7 @@ polysat::core --*/ #include "params/bv_rewriter_params.hpp" -#include "sat/smt/polysat_solver.h" -#include "sat/smt/euf_solver.h" +#include "sat/smt/polysat/polysat_core.h" namespace polysat { @@ -79,10 +78,10 @@ namespace polysat { } }; - core::core(solver& s) : + core::core(solver_interface& s) : s(s), m_viable(*this), - m_constraints(s.get_trail_stack()), + m_constraints(s.trail()), m_assignment(*this), m_var_queue(m_activity) {} @@ -109,7 +108,7 @@ namespace polysat { m_justification.push_back(null_dependency); m_watch.push_back({}); m_var_queue.mk_var_eh(v); - s.ctx.push(mk_add_var(*this)); + s.trail().push(mk_add_var(*this)); return v; } @@ -134,7 +133,7 @@ namespace polysat { add_watch(idx, vars[0]); if (vars.size() > 1) add_watch(idx, vars[1]); - s.ctx.push(mk_add_watch(*this)); + s.trail().push(mk_add_watch(*this)); return idx; } @@ -146,7 +145,7 @@ namespace polysat { if (m_var_queue.empty()) return sat::check_result::CR_DONE; m_var = m_var_queue.next_var(); - s.ctx.push(mk_dqueue_var(m_var, *this)); + s.trail().push(mk_dqueue_var(m_var, *this)); switch (m_viable.find_viable(m_var, m_value)) { case find_t::empty: m_unsat_core = m_viable.explain(); @@ -169,11 +168,11 @@ namespace polysat { bool core::propagate() { if (m_qhead == m_prop_queue.size() && m_vqhead == m_prop_queue.size()) return false; - s.ctx.push(value_trail(m_qhead)); - for (; m_qhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_qhead) + s.trail().push(value_trail(m_qhead)); + for (; m_qhead < m_prop_queue.size() && !s.inconsistent(); ++m_qhead) propagate_assignment(m_prop_queue[m_qhead]); - s.ctx.push(value_trail(m_vqhead)); - for (; m_vqhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_vqhead) + s.trail().push(value_trail(m_vqhead)); + for (; m_vqhead < m_prop_queue.size() && !s.inconsistent(); ++m_vqhead) propagate_value(m_prop_queue[m_vqhead]); return true; } @@ -202,12 +201,12 @@ namespace polysat { return; if (m_var_queue.contains(v)) { m_var_queue.del_var_eh(v); - s.ctx.push(mk_dqueue_var(v, *this)); + s.trail().push(mk_dqueue_var(v, *this)); } m_values[v] = value; m_justification[v] = dep; m_assignment.push(v , value); - s.ctx.push(mk_assign_var(v, *this)); + s.trail().push(mk_assign_var(v, *this)); // update the watch lists for pvars // remove constraints from m_watch[v] that have more than 2 free variables. @@ -289,7 +288,7 @@ namespace polysat { void core::assign_eh(unsigned index, bool sign, dependency const& dep) { m_prop_queue.push_back({ index, sign, dep }); - s.ctx.push(push_back_vector(m_prop_queue)); + s.trail().push(push_back_vector(m_prop_queue)); } dependency_vector core::explain_eval(signed_constraint const& sc) { @@ -309,7 +308,7 @@ namespace polysat { } trail_stack& core::trail() { - return s.get_trail_stack(); + return s.trail(); } } diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat/polysat_core.h similarity index 94% rename from src/sat/smt/polysat_core.h rename to src/sat/smt/polysat/polysat_core.h index b7c9f9eb8..bb21ee641 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat/polysat_core.h @@ -20,15 +20,15 @@ Author: #include "util/dependency.h" #include "math/dd/dd_pdd.h" #include "sat/smt/sat_th.h" -#include "sat/smt/polysat_types.h" -#include "sat/smt/polysat_constraints.h" -#include "sat/smt/polysat_viable.h" -#include "sat/smt/polysat_assignment.h" +#include "sat/smt/polysat/polysat_types.h" +#include "sat/smt/polysat/polysat_constraints.h" +#include "sat/smt/polysat/polysat_viable.h" +#include "sat/smt/polysat/polysat_assignment.h" namespace polysat { class core; - class solver; + class solver_interface; class core { class mk_add_var; @@ -45,7 +45,7 @@ namespace polysat { signed_constraint sc; dependency d; }; - solver& s; + solver_interface& s; viable m_viable; constraints m_constraints; assignment m_assignment; @@ -87,7 +87,7 @@ namespace polysat { dependency_vector explain_eval(signed_constraint const& sc); public: - core(solver& s); + core(solver_interface& s); sat::check_result check(); diff --git a/src/sat/smt/polysat_fi.cpp b/src/sat/smt/polysat/polysat_fi.cpp similarity index 98% rename from src/sat/smt/polysat_fi.cpp rename to src/sat/smt/polysat/polysat_fi.cpp index 349243ed8..e54fb5cea 100644 --- a/src/sat/smt/polysat_fi.cpp +++ b/src/sat/smt/polysat/polysat_fi.cpp @@ -13,11 +13,11 @@ Author: Nikolaj Bjorner (nbjorner) 2021-03-19 --*/ -#include "sat/smt/polysat_fi.h" -#include "sat/smt/polysat_interval.h" -#include "sat/smt/polysat_umul_ovfl.h" -#include "sat/smt/polysat_ule.h" -#include "sat/smt/polysat_core.h" +#include "sat/smt/polysat/polysat_fi.h" +#include "sat/smt/polysat/polysat_interval.h" +#include "sat/smt/polysat/polysat_umul_ovfl.h" +#include "sat/smt/polysat/polysat_ule.h" +#include "sat/smt/polysat/polysat_core.h" namespace polysat { diff --git a/src/sat/smt/polysat_fi.h b/src/sat/smt/polysat/polysat_fi.h similarity index 96% rename from src/sat/smt/polysat_fi.h rename to src/sat/smt/polysat/polysat_fi.h index 7782deb4a..e1f876c3c 100644 --- a/src/sat/smt/polysat_fi.h +++ b/src/sat/smt/polysat/polysat_fi.h @@ -14,9 +14,9 @@ Author: --*/ #pragma once -#include "sat/smt/polysat_types.h" -#include "sat/smt/polysat_interval.h" -#include "sat/smt/polysat_constraints.h" +#include "sat/smt/polysat/polysat_types.h" +#include "sat/smt/polysat/polysat_interval.h" +#include "sat/smt/polysat/polysat_constraints.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_interval.h b/src/sat/smt/polysat/polysat_interval.h new file mode 100644 index 000000000..0299f83b3 --- /dev/null +++ b/src/sat/smt/polysat/polysat_interval.h @@ -0,0 +1,224 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat intervals + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-6 + +--*/ +#pragma once +#include "sat/smt/polysat/polysat_types.h" +#include + +namespace polysat { + + struct pdd_bounds { + pdd lo; ///< lower bound, inclusive + pdd hi; ///< upper bound, exclusive + }; + + /** + * An interval is either [lo; hi[ (excl. upper bound) or the full domain Z_{2^w}. + * If lo > hi, the interval wraps around, i.e., represents the union of [lo; 2^w[ and [0; hi[. + * Membership test t \in [lo; hi[ is equivalent to t - lo < hi - lo. + */ + class interval { + std::optional m_bounds = std::nullopt; + + interval() = default; + interval(pdd const& lo, pdd const& hi): m_bounds({lo, hi}) {} + public: + static interval empty(dd::pdd_manager& m) { return proper(m.zero(), m.zero()); } + static interval full() { return {}; } + static interval proper(pdd const& lo, pdd const& hi) { return {lo, hi}; } + + interval(interval const&) = default; + interval(interval&&) = default; + interval& operator=(interval const& other) { + m_bounds = std::nullopt; // clear pdds first to allow changing pdd_manager (probably should change the PDD assignment operator; but for now I want to be able to detect manager confusions) + m_bounds = other.m_bounds; + return *this; + } + interval& operator=(interval&& other) { + m_bounds = std::nullopt; // clear pdds first to allow changing pdd_manager + m_bounds = std::move(other.m_bounds); + return *this; + } + ~interval() = default; + + bool is_full() const { return !m_bounds; } + bool is_proper() const { return !!m_bounds; } + bool is_always_empty() const { return is_proper() && lo() == hi(); } + pdd const& lo() const { SASSERT(is_proper()); return m_bounds->lo; } + pdd const& hi() const { SASSERT(is_proper()); return m_bounds->hi; } + }; + + inline std::ostream& operator<<(std::ostream& os, interval const& i) { + if (i.is_full()) + return os << "full"; + else + return os << "[" << i.lo() << " ; " << i.hi() << "["; + } + + // distance from a to b, wrapping around at mod_value. + // basically mod(b - a, mod_value), but distance(0, mod_value, mod_value) = mod_value. + inline rational distance(rational const& a, rational const& b, rational const& mod_value) { + SASSERT(mod_value.is_power_of_two()); + SASSERT(0 <= a && a < mod_value); + SASSERT(0 <= b && b <= mod_value); + rational x = b - a; + if (x.is_neg()) + x += mod_value; + return x; + } + + class r_interval { + rational m_lo; + rational m_hi; + + r_interval(rational lo, rational hi) + : m_lo(std::move(lo)), m_hi(std::move(hi)) + {} + + public: + + static r_interval empty() { + return {rational::zero(), rational::zero()}; + } + + static r_interval full() { + return {rational(-1), rational::zero()}; + } + + static r_interval proper(rational lo, rational hi) { + SASSERT(0 <= lo); + SASSERT(0 <= hi); + return {std::move(lo), std::move(hi)}; + } + + bool is_full() const { return m_lo.is_neg(); } + bool is_proper() const { return !is_full(); } + bool is_empty() const { return is_proper() && lo() == hi(); } + rational const& lo() const { SASSERT(is_proper()); return m_lo; } + rational const& hi() const { SASSERT(is_proper()); return m_hi; } + + // this one also supports representing full intervals as [lo;mod_value[ + static rational len(rational const& lo, rational const& hi, rational const& mod_value) { + SASSERT(mod_value.is_power_of_two()); + SASSERT(0 <= lo && lo < mod_value); + SASSERT(0 <= hi && hi <= mod_value); + SASSERT(hi != mod_value || lo == 0); // hi == mod_value only allowed when lo == 0 + rational len = hi - lo; + if (len.is_neg()) + len += mod_value; + return len; + } + + rational len(rational const& mod_value) const { + SASSERT(is_proper()); + return len(lo(), hi(), mod_value); + } + + // deals only with proper intervals + // but works with full intervals represented as [0;mod_value[ -- maybe we should just change representation of full intervals to this always + static bool contains(rational const& lo, rational const& hi, rational const& val) { + if (lo <= hi) + return lo <= val && val < hi; + else + return val < hi || val >= lo; + } + + bool contains(rational const& val) const { + if (is_full()) + return true; + else + return contains(lo(), hi(), val); + } + + }; + + class eval_interval { + interval m_symbolic; + rational m_concrete_lo; + rational m_concrete_hi; + + eval_interval(interval&& i, rational const& lo_val, rational const& hi_val): + m_symbolic(std::move(i)), m_concrete_lo(lo_val), m_concrete_hi(hi_val) {} + public: + static eval_interval empty(dd::pdd_manager& m) { + return {interval::empty(m), rational::zero(), rational::zero()}; + } + + static eval_interval full() { + return {interval::full(), rational::zero(), rational::zero()}; + } + + static eval_interval proper(pdd const& lo, rational const& lo_val, pdd const& hi, rational const& hi_val) { + SASSERT(0 <= lo_val && lo_val <= lo.manager().max_value()); + SASSERT(0 <= hi_val && hi_val <= hi.manager().max_value()); + return {interval::proper(lo, hi), lo_val, hi_val}; + } + + bool is_full() const { return m_symbolic.is_full(); } + bool is_proper() const { return m_symbolic.is_proper(); } + bool is_always_empty() const { return m_symbolic.is_always_empty(); } + bool is_currently_empty() const { return is_proper() && lo_val() == hi_val(); } + interval const& symbolic() const { return m_symbolic; } + pdd const& lo() const { return m_symbolic.lo(); } + pdd const& hi() const { return m_symbolic.hi(); } + rational const& lo_val() const { SASSERT(is_proper()); return m_concrete_lo; } + rational const& hi_val() const { SASSERT(is_proper()); return m_concrete_hi; } + + rational current_len() const { + SASSERT(is_proper()); + return mod(hi_val() - lo_val(), lo().manager().two_to_N()); + } + + bool currently_contains(rational const& val) const { + if (is_full()) + return true; + else if (lo_val() <= hi_val()) + return lo_val() <= val && val < hi_val(); + else + return val < hi_val() || val >= lo_val(); + } + + bool currently_contains(eval_interval const& other) const { + if (is_full()) + return true; + if (other.is_full()) + return false; + // lo <= lo' <= hi' <= hi' + if (lo_val() <= other.lo_val() && other.lo_val() <= other.hi_val() && other.hi_val() <= hi_val()) + return true; + if (lo_val() <= hi_val()) + return false; + // hi < lo <= lo' <= hi' + if (lo_val() <= other.lo_val() && other.lo_val() <= other.hi_val()) + return true; + // lo' <= hi' <= hi < lo + if (other.lo_val() <= other.hi_val() && other.hi_val() <= hi_val()) + return true; + // hi' <= hi < lo <= lo' + if (other.hi_val() <= hi_val() && lo_val() <= other.lo_val()) + return true; + return false; + } + + }; // class eval_interval + + inline std::ostream& operator<<(std::ostream& os, eval_interval const& i) { + if (i.is_full()) + return os << "full"; + else { + auto& m = i.hi().manager(); + return os << i.symbolic() << " := [" << m.normalize(i.lo_val()) << ";" << m.normalize(i.hi_val()) << "["; + } + } + +} diff --git a/src/sat/smt/polysat/polysat_types.h b/src/sat/smt/polysat/polysat_types.h new file mode 100644 index 000000000..e77e755bf --- /dev/null +++ b/src/sat/smt/polysat/polysat_types.h @@ -0,0 +1,67 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once + +#include +#include "math/dd/dd_pdd.h" +#include "util/trail.h" +#include "util/sat_literal.h" + +namespace polysat { + + using pdd = dd::pdd; + using pvar = unsigned; + using theory_var = unsigned; + + using pvar_vector = unsigned_vector; + inline const pvar null_var = UINT_MAX; + + + + class dependency { + std::variant> m_data; + unsigned m_level; + public: + dependency(sat::literal lit, unsigned level) : m_data(lit), m_level(level) {} + dependency(theory_var v1, theory_var v2, unsigned level) : m_data(std::make_pair(v1, v2)), m_level(level) {} + bool is_null() const { return is_literal() && *std::get_if(&m_data) == sat::null_literal; } + bool is_literal() const { return std::holds_alternative(m_data); } + sat::literal literal() const { SASSERT(is_literal()); return *std::get_if(&m_data); } + std::pair eq() const { SASSERT(!is_literal()); return *std::get_if>(&m_data); } + unsigned level() const { return m_level; } + }; + + inline const dependency null_dependency = dependency(sat::null_literal, UINT_MAX); + + inline std::ostream& operator<<(std::ostream& out, dependency d) { + if (d.is_null()) + return out << "null"; + else if (d.is_literal()) + return out << d.literal() << "@" << d.level(); + else + return out << "v" << d.eq().first << " == v" << d.eq().second << "@" << d.level(); + } + + using dependency_vector = vector; + + class signed_constraint; + + class solver_interface { + public: + virtual void add_eq_literal(pvar v, rational const& val) = 0; + virtual void set_conflict(dependency_vector const& core) = 0; + virtual void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core) = 0; + virtual dependency propagate(signed_constraint sc, dependency_vector const& deps) = 0; + virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; + virtual trail_stack& trail() = 0; + virtual bool inconsistent() const = 0; + }; + +} diff --git a/src/sat/smt/polysat_ule.cpp b/src/sat/smt/polysat/polysat_ule.cpp similarity index 99% rename from src/sat/smt/polysat_ule.cpp rename to src/sat/smt/polysat/polysat_ule.cpp index 08448b34d..0fb01bcae 100644 --- a/src/sat/smt/polysat_ule.cpp +++ b/src/sat/smt/polysat/polysat_ule.cpp @@ -70,8 +70,8 @@ Useful lemmas: --*/ -#include "sat/smt/polysat_constraints.h" -#include "sat/smt/polysat_ule.h" +#include "sat/smt/polysat/polysat_constraints.h" +#include "sat/smt/polysat/polysat_ule.h" #define LOG(_msg_) verbose_stream() << _msg_ << "\n" diff --git a/src/sat/smt/polysat_ule.h b/src/sat/smt/polysat/polysat_ule.h similarity index 94% rename from src/sat/smt/polysat_ule.h rename to src/sat/smt/polysat/polysat_ule.h index 12efe506a..e21ed1029 100644 --- a/src/sat/smt/polysat_ule.h +++ b/src/sat/smt/polysat/polysat_ule.h @@ -12,9 +12,8 @@ Author: --*/ #pragma once -#include "sat/smt/polysat_ule.h" -#include "sat/smt/polysat_assignment.h" -#include "sat/smt/polysat_constraints.h" +#include "sat/smt/polysat/polysat_assignment.h" +#include "sat/smt/polysat/polysat_constraints.h" namespace polysat { diff --git a/src/sat/smt/polysat_umul_ovfl.cpp b/src/sat/smt/polysat/polysat_umul_ovfl.cpp similarity index 92% rename from src/sat/smt/polysat_umul_ovfl.cpp rename to src/sat/smt/polysat/polysat_umul_ovfl.cpp index 5c448bc0a..dfe400603 100644 --- a/src/sat/smt/polysat_umul_ovfl.cpp +++ b/src/sat/smt/polysat/polysat_umul_ovfl.cpp @@ -10,9 +10,9 @@ Author: Jakob Rath, Nikolaj Bjorner (nbjorner) 2021-12-09 --*/ -#include "sat/smt/polysat_constraints.h" -#include "sat/smt/polysat_assignment.h" -#include "sat/smt/polysat_umul_ovfl.h" +#include "sat/smt/polysat/polysat_constraints.h" +#include "sat/smt/polysat/polysat_assignment.h" +#include "sat/smt/polysat/polysat_umul_ovfl.h" namespace polysat { diff --git a/src/sat/smt/polysat_umul_ovfl.h b/src/sat/smt/polysat/polysat_umul_ovfl.h similarity index 95% rename from src/sat/smt/polysat_umul_ovfl.h rename to src/sat/smt/polysat/polysat_umul_ovfl.h index 502ed4bbf..41972ef59 100644 --- a/src/sat/smt/polysat_umul_ovfl.h +++ b/src/sat/smt/polysat/polysat_umul_ovfl.h @@ -11,7 +11,7 @@ Author: --*/ #pragma once -#include "sat/smt/polysat_constraints.h" +#include "sat/smt/polysat/polysat_constraints.h" namespace polysat { diff --git a/src/sat/smt/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp similarity index 99% rename from src/sat/smt/polysat_viable.cpp rename to src/sat/smt/polysat/polysat_viable.cpp index d68822563..a11a02b91 100644 --- a/src/sat/smt/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -18,8 +18,8 @@ Notes: #include "util/debug.h" #include "util/log.h" -#include "sat/smt/polysat_viable.h" -#include "sat/smt/polysat_core.h" +#include "sat/smt/polysat/polysat_viable.h" +#include "sat/smt/polysat/polysat_core.h" namespace polysat { diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat/polysat_viable.h similarity index 97% rename from src/sat/smt/polysat_viable.h rename to src/sat/smt/polysat/polysat_viable.h index f1c826d10..79fcfa76e 100644 --- a/src/sat/smt/polysat_viable.h +++ b/src/sat/smt/polysat/polysat_viable.h @@ -21,8 +21,8 @@ Author: #include "util/map.h" #include "util/small_object_allocator.h" -#include "sat/smt/polysat_types.h" -#include "sat/smt/polysat_fi.h" +#include "sat/smt/polysat/polysat_types.h" +#include "sat/smt/polysat/polysat_fi.h" namespace polysat { diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 3e01ff391..9e57b8cae 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -27,8 +27,8 @@ The result of polysat::core::check is one of: #include "ast/euf/euf_bv_plugin.h" #include "sat/smt/polysat_solver.h" #include "sat/smt/euf_solver.h" -#include "sat/smt/polysat_ule.h" -#include "sat/smt/polysat_umul_ovfl.h" +#include "sat/smt/polysat/polysat_ule.h" +#include "sat/smt/polysat/polysat_umul_ovfl.h" namespace polysat { @@ -82,7 +82,7 @@ namespace polysat { core.push_back(d.literal()); } else { - auto const [v1, v2] = m_var_eqs[d.index()]; + auto const [v1, v2] = d.eq(); euf::enode* const n1 = var2enode(v1); euf::enode* const n2 = var2enode(v2); VERIFY(n1->get_root() == n2->get_root()); @@ -151,7 +151,7 @@ namespace polysat { auto sc = m_core.eq(p, q); m_var_eqs.setx(m_var_eqs_head, {v1, v2}, {v1, v2}); ctx.push(value_trail(m_var_eqs_head)); - auto d = dependency(m_var_eqs_head, s().scope_lvl()); + auto d = dependency(v1, v2, s().scope_lvl()); unsigned index = m_core.register_constraint(sc, d); m_core.assign_eh(index, false, d); m_var_eqs_head++; @@ -192,7 +192,7 @@ namespace polysat { ctx.propagate(lit, ex); } else if (sign) { - auto const [v1, v2] = m_var_eqs[d.index()]; + auto const [v1, v2] = d.eq(); // equalities are always asserted so a negative propagation is a conflict. auto n1 = var2enode(v1); auto n2 = var2enode(v2); @@ -202,6 +202,14 @@ namespace polysat { } } + bool solver::inconsistent() const { + return s().inconsistent(); + } + + trail_stack& solver::trail() { + return ctx.get_trail_stack(); + } + void solver::add_lemma(vector const& lemma) { sat::literal_vector lits; for (auto sc : lemma) diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 76923f88f..8489619da 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -18,7 +18,7 @@ Author: #include "sat/smt/sat_th.h" #include "math/dd/dd_pdd.h" -#include "sat/smt/polysat_core.h" +#include "sat/smt/polysat/polysat_core.h" namespace euf { class solver; @@ -27,7 +27,7 @@ namespace euf { namespace polysat { - class solver : public euf::th_euf_solver { + class solver : public euf::th_euf_solver, public solver_interface { typedef euf::theory_var theory_var; typedef euf::theory_id theory_id; typedef sat::literal literal; @@ -53,8 +53,6 @@ namespace polysat { expr* get_hint(euf::solver& s) const override { return nullptr; } }; - friend class core; - bv_util bv; arith_util m_autil; stats m_stats; @@ -128,12 +126,14 @@ namespace polysat { void internalize_set(euf::theory_var v, pdd const& p); // callbacks from core - void add_eq_literal(pvar v, rational const& val); - void set_conflict(dependency_vector const& core); - void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core); - dependency propagate(signed_constraint sc, dependency_vector const& deps); - void propagate(dependency const& d, bool sign, dependency_vector const& deps); - + void add_eq_literal(pvar v, rational const& val) override; + void set_conflict(dependency_vector const& core) override; + void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core) override; + dependency propagate(signed_constraint sc, dependency_vector const& deps) override; + void propagate(dependency const& d, bool sign, dependency_vector const& deps) override; + trail_stack& trail() override; + bool inconsistent() const override; + void add_lemma(vector const& lemma); std::pair explain_deps(dependency_vector const& deps); diff --git a/src/sat/smt/polysat_types.h b/src/sat/smt/polysat_types.h deleted file mode 100644 index c8c8324d7..000000000 --- a/src/sat/smt/polysat_types.h +++ /dev/null @@ -1,48 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-06 - ---*/ -#pragma once - -#include "math/dd/dd_pdd.h" -#include "util/sat_literal.h" - -namespace polysat { - - using pdd = dd::pdd; - using pvar = unsigned; - - using pvar_vector = unsigned_vector; - inline const pvar null_var = UINT_MAX; - - - class dependency { - unsigned m_index; - unsigned m_level; - public: - dependency(sat::literal lit, unsigned level) : m_index(2 * lit.index()), m_level(level) {} - dependency(unsigned var_idx, unsigned level) : m_index(1 + 2 * var_idx), m_level(level) {} - bool is_null() const { return m_level == UINT_MAX; } - bool is_literal() const { return m_index % 2 == 0; } - sat::literal literal() const { SASSERT(is_literal()); return sat::to_literal(m_index / 2); } - unsigned index() const { SASSERT(!is_literal()); return (m_index - 1) / 2; } - unsigned level() const { return m_level; } - }; - - inline const dependency null_dependency = dependency(0, UINT_MAX); - - inline std::ostream& operator<<(std::ostream& out, dependency d) { - if (d.is_literal()) - return out << d.literal() << "@" << d.level(); - else - return out << "v" << d.index() << "@" << d.level(); - } - - using dependency_vector = vector; - -} From bff51b699de5ac0656555c0a741391bfc9889651 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 09:39:59 -0800 Subject: [PATCH 14/89] remove stale files --- src/sat/smt/polysat_interval.h | 224 ----------------------------- src/sat/smt/polysat_substitution.h | 212 --------------------------- 2 files changed, 436 deletions(-) delete mode 100644 src/sat/smt/polysat_interval.h delete mode 100644 src/sat/smt/polysat_substitution.h diff --git a/src/sat/smt/polysat_interval.h b/src/sat/smt/polysat_interval.h deleted file mode 100644 index 9965dbab1..000000000 --- a/src/sat/smt/polysat_interval.h +++ /dev/null @@ -1,224 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Module Name: - - polysat intervals - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-6 - ---*/ -#pragma once -#include "sat/smt/polysat_types.h" -#include - -namespace polysat { - - struct pdd_bounds { - pdd lo; ///< lower bound, inclusive - pdd hi; ///< upper bound, exclusive - }; - - /** - * An interval is either [lo; hi[ (excl. upper bound) or the full domain Z_{2^w}. - * If lo > hi, the interval wraps around, i.e., represents the union of [lo; 2^w[ and [0; hi[. - * Membership test t \in [lo; hi[ is equivalent to t - lo < hi - lo. - */ - class interval { - std::optional m_bounds = std::nullopt; - - interval() = default; - interval(pdd const& lo, pdd const& hi): m_bounds({lo, hi}) {} - public: - static interval empty(dd::pdd_manager& m) { return proper(m.zero(), m.zero()); } - static interval full() { return {}; } - static interval proper(pdd const& lo, pdd const& hi) { return {lo, hi}; } - - interval(interval const&) = default; - interval(interval&&) = default; - interval& operator=(interval const& other) { - m_bounds = std::nullopt; // clear pdds first to allow changing pdd_manager (probably should change the PDD assignment operator; but for now I want to be able to detect manager confusions) - m_bounds = other.m_bounds; - return *this; - } - interval& operator=(interval&& other) { - m_bounds = std::nullopt; // clear pdds first to allow changing pdd_manager - m_bounds = std::move(other.m_bounds); - return *this; - } - ~interval() = default; - - bool is_full() const { return !m_bounds; } - bool is_proper() const { return !!m_bounds; } - bool is_always_empty() const { return is_proper() && lo() == hi(); } - pdd const& lo() const { SASSERT(is_proper()); return m_bounds->lo; } - pdd const& hi() const { SASSERT(is_proper()); return m_bounds->hi; } - }; - - inline std::ostream& operator<<(std::ostream& os, interval const& i) { - if (i.is_full()) - return os << "full"; - else - return os << "[" << i.lo() << " ; " << i.hi() << "["; - } - - // distance from a to b, wrapping around at mod_value. - // basically mod(b - a, mod_value), but distance(0, mod_value, mod_value) = mod_value. - inline rational distance(rational const& a, rational const& b, rational const& mod_value) { - SASSERT(mod_value.is_power_of_two()); - SASSERT(0 <= a && a < mod_value); - SASSERT(0 <= b && b <= mod_value); - rational x = b - a; - if (x.is_neg()) - x += mod_value; - return x; - } - - class r_interval { - rational m_lo; - rational m_hi; - - r_interval(rational lo, rational hi) - : m_lo(std::move(lo)), m_hi(std::move(hi)) - {} - - public: - - static r_interval empty() { - return {rational::zero(), rational::zero()}; - } - - static r_interval full() { - return {rational(-1), rational::zero()}; - } - - static r_interval proper(rational lo, rational hi) { - SASSERT(0 <= lo); - SASSERT(0 <= hi); - return {std::move(lo), std::move(hi)}; - } - - bool is_full() const { return m_lo.is_neg(); } - bool is_proper() const { return !is_full(); } - bool is_empty() const { return is_proper() && lo() == hi(); } - rational const& lo() const { SASSERT(is_proper()); return m_lo; } - rational const& hi() const { SASSERT(is_proper()); return m_hi; } - - // this one also supports representing full intervals as [lo;mod_value[ - static rational len(rational const& lo, rational const& hi, rational const& mod_value) { - SASSERT(mod_value.is_power_of_two()); - SASSERT(0 <= lo && lo < mod_value); - SASSERT(0 <= hi && hi <= mod_value); - SASSERT(hi != mod_value || lo == 0); // hi == mod_value only allowed when lo == 0 - rational len = hi - lo; - if (len.is_neg()) - len += mod_value; - return len; - } - - rational len(rational const& mod_value) const { - SASSERT(is_proper()); - return len(lo(), hi(), mod_value); - } - - // deals only with proper intervals - // but works with full intervals represented as [0;mod_value[ -- maybe we should just change representation of full intervals to this always - static bool contains(rational const& lo, rational const& hi, rational const& val) { - if (lo <= hi) - return lo <= val && val < hi; - else - return val < hi || val >= lo; - } - - bool contains(rational const& val) const { - if (is_full()) - return true; - else - return contains(lo(), hi(), val); - } - - }; - - class eval_interval { - interval m_symbolic; - rational m_concrete_lo; - rational m_concrete_hi; - - eval_interval(interval&& i, rational const& lo_val, rational const& hi_val): - m_symbolic(std::move(i)), m_concrete_lo(lo_val), m_concrete_hi(hi_val) {} - public: - static eval_interval empty(dd::pdd_manager& m) { - return {interval::empty(m), rational::zero(), rational::zero()}; - } - - static eval_interval full() { - return {interval::full(), rational::zero(), rational::zero()}; - } - - static eval_interval proper(pdd const& lo, rational const& lo_val, pdd const& hi, rational const& hi_val) { - SASSERT(0 <= lo_val && lo_val <= lo.manager().max_value()); - SASSERT(0 <= hi_val && hi_val <= hi.manager().max_value()); - return {interval::proper(lo, hi), lo_val, hi_val}; - } - - bool is_full() const { return m_symbolic.is_full(); } - bool is_proper() const { return m_symbolic.is_proper(); } - bool is_always_empty() const { return m_symbolic.is_always_empty(); } - bool is_currently_empty() const { return is_proper() && lo_val() == hi_val(); } - interval const& symbolic() const { return m_symbolic; } - pdd const& lo() const { return m_symbolic.lo(); } - pdd const& hi() const { return m_symbolic.hi(); } - rational const& lo_val() const { SASSERT(is_proper()); return m_concrete_lo; } - rational const& hi_val() const { SASSERT(is_proper()); return m_concrete_hi; } - - rational current_len() const { - SASSERT(is_proper()); - return mod(hi_val() - lo_val(), lo().manager().two_to_N()); - } - - bool currently_contains(rational const& val) const { - if (is_full()) - return true; - else if (lo_val() <= hi_val()) - return lo_val() <= val && val < hi_val(); - else - return val < hi_val() || val >= lo_val(); - } - - bool currently_contains(eval_interval const& other) const { - if (is_full()) - return true; - if (other.is_full()) - return false; - // lo <= lo' <= hi' <= hi' - if (lo_val() <= other.lo_val() && other.lo_val() <= other.hi_val() && other.hi_val() <= hi_val()) - return true; - if (lo_val() <= hi_val()) - return false; - // hi < lo <= lo' <= hi' - if (lo_val() <= other.lo_val() && other.lo_val() <= other.hi_val()) - return true; - // lo' <= hi' <= hi < lo - if (other.lo_val() <= other.hi_val() && other.hi_val() <= hi_val()) - return true; - // hi' <= hi < lo <= lo' - if (other.hi_val() <= hi_val() && lo_val() <= other.lo_val()) - return true; - return false; - } - - }; // class eval_interval - - inline std::ostream& operator<<(std::ostream& os, eval_interval const& i) { - if (i.is_full()) - return os << "full"; - else { - auto& m = i.hi().manager(); - return os << i.symbolic() << " := [" << m.normalize(i.lo_val()) << ";" << m.normalize(i.hi_val()) << "["; - } - } - -} diff --git a/src/sat/smt/polysat_substitution.h b/src/sat/smt/polysat_substitution.h deleted file mode 100644 index a30c6b710..000000000 --- a/src/sat/smt/polysat_substitution.h +++ /dev/null @@ -1,212 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Module Name: - - polysat substitution - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-06 - ---*/ -#pragma once -#include "sat/smt/polysat_types.h" - -namespace polysat { - - using assignment_item_t = std::pair; - - class substitution_iterator { - pdd m_current; - substitution_iterator(pdd current) : m_current(std::move(current)) {} - friend class substitution; - - public: - using value_type = assignment_item_t; - using difference_type = std::ptrdiff_t; - using pointer = value_type const*; - using reference = value_type const&; - using iterator_category = std::input_iterator_tag; - - substitution_iterator& operator++() { - SASSERT(!m_current.is_val()); - m_current = m_current.hi(); - return *this; - } - - value_type operator*() const { - SASSERT(!m_current.is_val()); - return { m_current.var(), m_current.lo().val() }; - } - - bool operator==(substitution_iterator const& other) const { return m_current == other.m_current; } - bool operator!=(substitution_iterator const& other) const { return !operator==(other); } - }; - - /** Substitution for a single bit width. */ - class substitution { - pdd m_subst; - - substitution(pdd p); - - public: - substitution(dd::pdd_manager& m); - [[nodiscard]] substitution add(pvar var, rational const& value) const; - [[nodiscard]] pdd apply_to(pdd const& p) const; - - [[nodiscard]] bool contains(pvar var) const; - [[nodiscard]] bool value(pvar var, rational& out_value) const; - - [[nodiscard]] bool empty() const { return m_subst.is_one(); } - - pdd const& to_pdd() const { return m_subst; } - unsigned bit_width() const { return to_pdd().power_of_2(); } - - bool operator==(substitution const& other) const { return &m_subst.manager() == &other.m_subst.manager() && m_subst == other.m_subst; } - bool operator!=(substitution const& other) const { return !operator==(other); } - - std::ostream& display(std::ostream& out) const; - - using const_iterator = substitution_iterator; - const_iterator begin() const { return {m_subst}; } - const_iterator end() const { return {m_subst.manager().one()}; } - }; - - /** Full variable assignment, may include variables of varying bit widths. */ - class assignment { - vector m_pairs; - mutable scoped_ptr_vector m_subst; - vector m_subst_trail; - - substitution& subst(unsigned sz); - solver& s() const { return *m_solver; } - public: - assignment(solver& s); - // prevent implicit copy, use clone() if you do need a copy - assignment(assignment const&) = delete; - assignment& operator=(assignment const&) = delete; - assignment(assignment&&) = default; - assignment& operator=(assignment&&) = default; - assignment clone() const; - - void push(pvar var, rational const& value); - void pop(); - - pdd apply_to(pdd const& p) const; - - bool contains(pvar var) const; - bool value(pvar var, rational& out_value) const; - rational value(pvar var) const { rational val; VERIFY(value(var, val)); return val; } - bool empty() const { return pairs().empty(); } - substitution const& subst(unsigned sz) const; - vector const& pairs() const { return m_pairs; } - using const_iterator = decltype(m_pairs)::const_iterator; - const_iterator begin() const { return pairs().begin(); } - const_iterator end() const { return pairs().end(); } - - std::ostream& display(std::ostream& out) const; - }; - - inline std::ostream& operator<<(std::ostream& out, substitution const& sub) { return sub.display(out); } - - inline std::ostream& operator<<(std::ostream& out, assignment const& a) { return a.display(out); } -} - -namespace polysat { - - enum class search_item_k - { - assignment, - boolean, - }; - - class search_item { - search_item_k m_kind; - union { - pvar m_var; - sat::literal m_lit; - }; - bool m_resolved = false; // when marked as resolved it is no longer valid to reduce the conflict state - - search_item(pvar var): m_kind(search_item_k::assignment), m_var(var) {} - search_item(sat::literal lit): m_kind(search_item_k::boolean), m_lit(lit) {} - public: - static search_item assignment(pvar var) { return search_item(var); } - static search_item boolean(sat::literal lit) { return search_item(lit); } - bool is_assignment() const { return m_kind == search_item_k::assignment; } - bool is_boolean() const { return m_kind == search_item_k::boolean; } - bool is_resolved() const { return m_resolved; } - search_item_k kind() const { return m_kind; } - pvar var() const { SASSERT(is_assignment()); return m_var; } - sat::literal lit() const { SASSERT(is_boolean()); return m_lit; } - void set_resolved() { m_resolved = true; } - }; - - class search_state { - solver& s; - - vector m_items; - assignment m_assignment; - - // store index into m_items - unsigned_vector m_pvar_to_idx; - unsigned_vector m_bool_to_idx; - - bool value(pvar v, rational& r) const; - - public: - search_state(solver& s): s(s), m_assignment(s) {} - unsigned size() const { return m_items.size(); } - search_item const& back() const { return m_items.back(); } - search_item const& operator[](unsigned i) const { return m_items[i]; } - - assignment const& get_assignment() const { return m_assignment; } - substitution const& subst(unsigned sz) const { return m_assignment.subst(sz); } - - // TODO: implement the following method if we actually need the assignments without resolved items already during conflict resolution - // (no separate trail needed, just a second m_subst and an index into the trail, I think) - // (update on set_resolved? might be one iteration too early, looking at the old solver::resolve_conflict loop) - substitution const& unresolved_assignment(unsigned sz) const; - - void push_assignment(pvar v, rational const& r); - void push_boolean(sat::literal lit); - void pop(); - - unsigned get_pvar_index(pvar v) const; - unsigned get_bool_index(sat::bool_var var) const; - unsigned get_bool_index(sat::literal lit) const { return get_bool_index(lit.var()); } - - void set_resolved(unsigned i) { m_items[i].set_resolved(); } - - using const_iterator = decltype(m_items)::const_iterator; - const_iterator begin() const { return m_items.begin(); } - const_iterator end() const { return m_items.end(); } - - std::ostream& display(std::ostream& out) const; - std::ostream& display(search_item const& item, std::ostream& out) const; - std::ostream& display_verbose(std::ostream& out) const; - std::ostream& display_verbose(search_item const& item, std::ostream& out) const; - }; - - struct search_state_pp { - search_state const& s; - bool verbose; - search_state_pp(search_state const& s, bool verbose = false) : s(s), verbose(verbose) {} - }; - - struct search_item_pp { - search_state const& s; - search_item const& i; - bool verbose; - search_item_pp(search_state const& s, search_item const& i, bool verbose = false) : s(s), i(i), verbose(verbose) {} - }; - - inline std::ostream& operator<<(std::ostream& out, search_state const& s) { return s.display(out); } - - inline std::ostream& operator<<(std::ostream& out, search_state_pp const& p) { return p.verbose ? p.s.display_verbose(out) : p.s.display(out); } - - inline std::ostream& operator<<(std::ostream& out, search_item_pp const& p) { return p.verbose ? p.s.display_verbose(p.i, out) : p.s.display(p.i, out); } - -} From 70bddb35beae0e8eee214b73d9030d6d946b3e2e Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 09:44:05 -0800 Subject: [PATCH 15/89] update viable Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/polysat_viable.cpp | 108 ++++++++++++++++++++----- src/sat/smt/polysat/polysat_viable.h | 63 ++++++++++++++- 2 files changed, 147 insertions(+), 24 deletions(-) diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp index a11a02b91..27b36c582 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -45,44 +45,111 @@ namespace polysat { struct viable::pop_viable_trail : public trail { viable& m_s; entry* e; - pvar v; entry_kind k; public: - pop_viable_trail(viable& s, entry* e, pvar v, entry_kind k) - : m_s(s), e(e), v(v), k(k) {} + pop_viable_trail(viable& s, entry* e, entry_kind k) + : m_s(s), e(e), k(k) {} void undo() override { - m_s.pop_viable(e, v, k); + m_s.pop_viable(e, k); } }; struct viable::push_viable_trail : public trail { viable& m_s; entry* e; - pvar v; - entry_kind k; public: - push_viable_trail(viable& s, entry* e, pvar v, entry_kind k) - : m_s(s), e(e), v(v), k(k) {} + push_viable_trail(viable& s, entry* e) + : m_s(s), e(e) {} void undo() override { - m_s.push_viable(e, v, k); + m_s.push_viable(e); } }; - viable::entry* viable::alloc_entry(pvar var) { + viable::entry* viable::alloc_entry(pvar var, unsigned constraint_index) { if (m_alloc.empty()) return alloc(entry); auto* e = m_alloc.back(); e->reset(); e->var = var; + e->constraint_index = constraint_index; m_alloc.pop_back(); return e; } - find_t viable::find_viable(pvar v, rational& out_val) { + find_t viable::find_viable(pvar v, rational& lo) { + rational hi; ensure_var(v); - throw default_exception("nyi"); + switch (find_viable(v, lo, hi)) { + case l_true: + return (lo == hi) ? find_t::singleton : find_t::multiple; + case l_false: + return find_t::empty; + default: + return find_t::resource_out; + } } + lbool viable::find_viable(pvar v, rational& lo, rational& hi) { + fixed_bits_info fbi; + +#if 0 + if (!collect_bit_information(v, true, fbi)) + return l_false; // conflict already added +#endif + + pvar_vector overlaps; +#if 0 + // TODO s.m_slicing.collect_simple_overlaps(v, overlaps); +#endif + std::sort(overlaps.begin(), overlaps.end(), [&](pvar x, pvar y) { return c.size(x) > c.size(y); }); + + uint_set widths_set; + // max size should always be present, regardless of whether we have intervals there (to make sure all fixed bits are considered) + widths_set.insert(c.size(v)); + +#if 0 + LOG("Overlaps with v" << v << ":"); + for (pvar x : overlaps) { + unsigned hi, lo; + if (s.m_slicing.is_extract(x, v, hi, lo)) + LOG(" v" << x << " = v" << v << "[" << hi << ":" << lo << "]"); + else + LOG(" v" << x << " not extracted from v" << v << "; size " << s.size(x)); + for (layer const& l : m_units[x].get_layers()) { + widths_set.insert(l.bit_width); + } + } +#endif + + unsigned_vector widths; + for (unsigned w : widths_set) + widths.push_back(w); + LOG("widths: " << widths); + + rational const& max_value = c.var2pdd(v).max_value(); + +#if 0 + lbool result_lo = find_on_layers(v, widths, overlaps, fbi, rational::zero(), max_value, lo); + if (result_lo == l_false) + return l_false; // conflict + if (result_lo == l_undef) + return find_viable_fallback(v, overlaps, lo, hi); + + if (lo == max_value) { + hi = lo; + return l_true; + } + + lbool result_hi = find_on_layers(v, widths, overlaps, fbi, lo + 1, max_value, hi); + if (result_hi == l_false) + hi = lo; // no other viable value + if (result_hi == l_undef) + return find_viable_fallback(v, overlaps, lo, hi); +#endif + return l_true; + } + + /* * Explain why the current variable is not viable or signleton. */ @@ -99,7 +166,7 @@ namespace polysat { return; auto [sc, d] = c.m_constraint_trail[idx]; - entry* ne = alloc_entry(v); + entry* ne = alloc_entry(v, idx); if (!m_forbidden_intervals.get_interval(sc, v, *ne)) { m_alloc.push_back(ne); return; @@ -211,13 +278,13 @@ namespace polysat { } auto create_entry = [&]() { - c.trail().push(pop_viable_trail(*this, ne, v, entry_kind::unit_e)); + c.trail().push(pop_viable_trail(*this, ne, entry_kind::unit_e)); ne->init(ne); return ne; }; auto remove_entry = [&](entry* e) { - c.trail().push(push_viable_trail(*this, e, v, entry_kind::unit_e)); + c.trail().push(push_viable_trail(*this, e)); e->remove_from(entries, e); e->active = false; }; @@ -311,7 +378,8 @@ namespace polysat { return nullptr; } - void viable::pop_viable(entry* e, pvar v, entry_kind k) { + void viable::pop_viable(entry* e, entry_kind k) { + unsigned v = e->var; SASSERT(well_formed(m_units[v])); SASSERT(e->active); e->active = false; @@ -333,15 +401,15 @@ namespace polysat { m_alloc.push_back(e); } - void viable::push_viable(entry* e, pvar v, entry_kind k) { + void viable::push_viable(entry* e) { // display_one(verbose_stream() << "Push entry: ", v, e) << "\n"; + auto v = e->var; entry*& entries = m_units[v].get_layer(e)->entries; SASSERT(e->prev() != e || !entries); SASSERT(e->prev() != e || e->next() == e); - SASSERT(k == entry_kind::unit_e); SASSERT(!e->active); e->active = true; - (void)k; + SASSERT(well_formed(m_units[v])); if (e->prev() != e) { entry* pos = e->prev(); @@ -358,7 +426,7 @@ namespace polysat { void viable::insert(entry* e, pvar v, ptr_vector& entries, entry_kind k) { SASSERT(well_formed(m_units[v])); - c.trail().push(pop_viable_trail(*this, e, v, k)); + c.trail().push(pop_viable_trail(*this, e, k)); e->init(e); if (!entries[v]) diff --git a/src/sat/smt/polysat/polysat_viable.h b/src/sat/smt/polysat/polysat_viable.h index 79fcfa76e..7c9c019dc 100644 --- a/src/sat/smt/polysat/polysat_viable.h +++ b/src/sat/smt/polysat/polysat_viable.h @@ -50,6 +50,7 @@ namespace polysat { bool active = true; bool valid_for_lemma = true; pvar var = null_var; + unsigned constraint_index = UINT_MAX; void reset() { // dll_base::init(this); // we never did this in alloc_entry either @@ -58,6 +59,7 @@ namespace polysat { active = true; valid_for_lemma = true; var = null_var; + constraint_index = UINT_MAX; } }; @@ -70,7 +72,7 @@ namespace polysat { }; class layers final { - svector m_layers; + svector m_layers; public: svector const& get_layers() const { return m_layers; } layer& ensure_layer(unsigned bit_width); @@ -81,6 +83,57 @@ namespace polysat { entry* get_entries(unsigned bit_width) const { layer const* l = get_layer(bit_width); return l ? l->entries : nullptr; } }; + struct fixed_bits_info { + svector fixed; + vector> just_src; + vector> just_side_cond; + vector> just_slicing; + + bool is_empty() const { + SASSERT_EQ(fixed.empty(), just_src.empty()); + SASSERT_EQ(fixed.empty(), just_side_cond.empty()); + return fixed.empty(); + } + + bool is_empty_at(unsigned i) const { + return fixed[i] == l_undef && just_src[i].empty() && just_side_cond[i].empty(); + } + + void reset(unsigned num_bits) { + fixed.reset(); + fixed.resize(num_bits, l_undef); + just_src.reset(); + just_src.resize(num_bits); + just_side_cond.reset(); + just_side_cond.resize(num_bits); + just_slicing.reset(); + just_slicing.resize(num_bits); + } + + void reset_just(unsigned i) { + just_src[i].reset(); + just_side_cond[i].reset(); + just_slicing[i].reset(); + } + + void set_just(unsigned i, entry* e) { + reset_just(i); + push_just(i, e); + } + + void push_just(unsigned i, entry* e) { + just_src[i].append(e->src); + just_side_cond[i].append(e->side_cond); + } + + void push_from_bit(unsigned i, unsigned src) { + just_src[i].append(just_src[src]); + just_side_cond[i].append(just_side_cond[src]); + just_slicing[i].append(just_slicing[src]); + } + }; + + ptr_vector m_alloc; vector m_units; // set of viable values based on unit multipliers, layered by bit-width in descending order ptr_vector m_equal_lin; // entries that have non-unit multipliers, but are equal @@ -89,7 +142,7 @@ namespace polysat { bool well_formed(entry* e); bool well_formed(layers const& ls); - entry* alloc_entry(pvar v); + entry* alloc_entry(pvar v, unsigned constraint_index); std::ostream& display_one(std::ostream& out, pvar v, entry const* e) const; std::ostream& display_all(std::ostream& out, pvar v, entry const* e, char const* delimiter = "") const; @@ -97,9 +150,9 @@ namespace polysat { void log(pvar v); struct pop_viable_trail; - void pop_viable(entry* e, pvar v, entry_kind k); + void pop_viable(entry* e, entry_kind k); struct push_viable_trail; - void push_viable(entry* e, pvar v, entry_kind k); + void push_viable(entry* e); void insert(entry* e, pvar v, ptr_vector& entries, entry_kind k); @@ -107,6 +160,8 @@ namespace polysat { void ensure_var(pvar v); + lbool find_viable(pvar v, rational& lo, rational& hi); + public: viable(core& c); From 683a5dda377c0372f32a7aca74b4c545d74630d2 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 10:58:54 -0800 Subject: [PATCH 16/89] remove include to bv-params Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/polysat_core.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sat/smt/polysat/polysat_core.cpp b/src/sat/smt/polysat/polysat_core.cpp index 07eeaa0c1..a6ac25eda 100644 --- a/src/sat/smt/polysat/polysat_core.cpp +++ b/src/sat/smt/polysat/polysat_core.cpp @@ -28,7 +28,6 @@ polysat::core --*/ -#include "params/bv_rewriter_params.hpp" #include "sat/smt/polysat/polysat_core.h" namespace polysat { From 94ba85bb12530156c00928e9bbc37ee836891c50 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 12:08:02 -0800 Subject: [PATCH 17/89] updates to viable --- src/sat/smt/polysat/polysat_core.cpp | 8 + src/sat/smt/polysat/polysat_core.h | 6 +- src/sat/smt/polysat/polysat_types.h | 4 + src/sat/smt/polysat/polysat_viable.cpp | 327 ++++- src/sat/smt/polysat/polysat_viable.h | 57 + src/sat/smt/polysat/slicing.cpp | 1727 ++++++++++++++++++++++++ src/sat/smt/polysat/slicing.h | 397 ++++++ src/sat/smt/polysat_solver.cpp | 42 + src/sat/smt/polysat_solver.h | 1 + 9 files changed, 2539 insertions(+), 30 deletions(-) create mode 100644 src/sat/smt/polysat/slicing.cpp create mode 100644 src/sat/smt/polysat/slicing.h diff --git a/src/sat/smt/polysat/polysat_core.cpp b/src/sat/smt/polysat/polysat_core.cpp index a6ac25eda..c41938a1a 100644 --- a/src/sat/smt/polysat/polysat_core.cpp +++ b/src/sat/smt/polysat/polysat_core.cpp @@ -278,6 +278,14 @@ namespace polysat { } } + void core::get_bitvector_prefixes(pvar v, pvar_vector& out) { + s.get_bitvector_prefixes(v, out); + } + + bool core::inconsistent() const { + return s.inconsistent(); + } + void core::propagate_unsat_core() { // default is to use unsat core: // if core is based on viable, use s.set_lemma(); diff --git a/src/sat/smt/polysat/polysat_core.h b/src/sat/smt/polysat/polysat_core.h index bb21ee641..766c3a9bc 100644 --- a/src/sat/smt/polysat/polysat_core.h +++ b/src/sat/smt/polysat/polysat_core.h @@ -79,6 +79,9 @@ namespace polysat { void propagate_assignment(pvar v, rational const& value, dependency dep); void propagate_unsat_core(); + void get_bitvector_prefixes(pvar v, pvar_vector& out); + bool inconsistent() const; + void add_watch(unsigned idx, unsigned var); signed_constraint get_constraint(unsigned idx, bool sign); @@ -89,8 +92,7 @@ namespace polysat { public: core(solver_interface& s); - sat::check_result check(); - + sat::check_result check(); unsigned register_constraint(signed_constraint& sc, dependency d); bool propagate(); void assign_eh(unsigned idx, bool sign, dependency const& d); diff --git a/src/sat/smt/polysat/polysat_types.h b/src/sat/smt/polysat/polysat_types.h index e77e755bf..5b63b5ee5 100644 --- a/src/sat/smt/polysat/polysat_types.h +++ b/src/sat/smt/polysat/polysat_types.h @@ -10,6 +10,9 @@ Author: #pragma once #include + + + #include "math/dd/dd_pdd.h" #include "util/trail.h" #include "util/sat_literal.h" @@ -62,6 +65,7 @@ namespace polysat { virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; virtual trail_stack& trail() = 0; virtual bool inconsistent() const = 0; + virtual void get_bitvector_prefixes(pvar v, pvar_vector& out) = 0; }; } diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp index 27b36c582..d69f40180 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -98,42 +98,31 @@ namespace polysat { #endif pvar_vector overlaps; -#if 0 - // TODO s.m_slicing.collect_simple_overlaps(v, overlaps); -#endif + c.get_bitvector_prefixes(v, overlaps); std::sort(overlaps.begin(), overlaps.end(), [&](pvar x, pvar y) { return c.size(x) > c.size(y); }); uint_set widths_set; // max size should always be present, regardless of whether we have intervals there (to make sure all fixed bits are considered) widths_set.insert(c.size(v)); -#if 0 - LOG("Overlaps with v" << v << ":"); - for (pvar x : overlaps) { - unsigned hi, lo; - if (s.m_slicing.is_extract(x, v, hi, lo)) - LOG(" v" << x << " = v" << v << "[" << hi << ":" << lo << "]"); - else - LOG(" v" << x << " not extracted from v" << v << "; size " << s.size(x)); - for (layer const& l : m_units[x].get_layers()) { - widths_set.insert(l.bit_width); - } - } -#endif + for (pvar v : overlaps) + ensure_var(v); + for (pvar v : overlaps) + for (layer const& l : m_units[v].get_layers()) + widths_set.insert(l.bit_width); + unsigned_vector widths; for (unsigned w : widths_set) - widths.push_back(w); + widths.push_back(w); + LOG("Overlaps with v" << v << ":" << overlaps); LOG("widths: " << widths); rational const& max_value = c.var2pdd(v).max_value(); -#if 0 lbool result_lo = find_on_layers(v, widths, overlaps, fbi, rational::zero(), max_value, lo); - if (result_lo == l_false) - return l_false; // conflict - if (result_lo == l_undef) - return find_viable_fallback(v, overlaps, lo, hi); + if (result_lo != l_true) + return result_lo; if (lo == max_value) { hi = lo; @@ -141,12 +130,294 @@ namespace polysat { } lbool result_hi = find_on_layers(v, widths, overlaps, fbi, lo + 1, max_value, hi); - if (result_hi == l_false) - hi = lo; // no other viable value - if (result_hi == l_undef) - return find_viable_fallback(v, overlaps, lo, hi); -#endif - return l_true; + + switch (result_hi) { + case l_false: + hi = lo; + return l_true; + case l_undef: + return l_undef; + default: + return l_true; + } + } + + // l_true ... found viable value + // l_false ... no viable value in [to_cover_lo;to_cover_hi[ + // l_undef ... out of resources + lbool viable::find_on_layers( + pvar const v, + unsigned_vector const& widths, + pvar_vector const& overlaps, + fixed_bits_info const& fbi, + rational const& to_cover_lo, + rational const& to_cover_hi, + rational& val + ) { + ptr_vector refine_todo; + ptr_vector relevant_entries; + + // max number of interval refinements before falling back to the univariate solver + unsigned const refinement_budget = 100; + unsigned refinements = refinement_budget; + + while (refinements--) { + relevant_entries.clear(); + lbool result = find_on_layer(v, widths.size() - 1, widths, overlaps, fbi, to_cover_lo, to_cover_hi, val, refine_todo, relevant_entries); + + // store bit-intervals we have used + for (entry* e : refine_todo) + intersect(v, e); + refine_todo.clear(); + + if (result != l_true) + return l_false; + + // overlaps are sorted by variable size in descending order + // start refinement on smallest variable + // however, we probably should rotate to avoid getting stuck in refinement loop on a 'bad' constraint + bool refined = false; + for (unsigned i = overlaps.size(); i-- > 0; ) { + pvar x = overlaps[i]; + rational const& mod_value = c.var2pdd(x).two_to_N(); + rational x_val = mod(val, mod_value); + if (!refine_viable(x, x_val)) { + refined = true; + break; + } + } + + if (!refined) + return l_true; + } + + LOG("Refinement budget exhausted! Fall back to univariate solver."); + return l_undef; + } + + // find viable values in half-open interval [to_cover_lo;to_cover_hi[ w.r.t. unit intervals on the given layer + // + // Returns: + // - l_true ... found value that is viable w.r.t. units and fixed bits + // - l_false ... found conflict + // - l_undef ... found no viable value in target interval [to_cover_lo;to_cover_hi[ + lbool viable::find_on_layer( + pvar const v, + unsigned const w_idx, + unsigned_vector const& widths, + pvar_vector const& overlaps, + fixed_bits_info const& fbi, + rational const& to_cover_lo, + rational const& to_cover_hi, + rational& val, + ptr_vector& refine_todo, + ptr_vector& relevant_entries + ) { + unsigned const w = widths[w_idx]; + rational const& mod_value = rational::power_of_two(w); + unsigned const first_relevant_for_conflict = relevant_entries.size(); + + LOG("layer " << w << " bits, to_cover: [" << to_cover_lo << "; " << to_cover_hi << "["); + SASSERT(0 <= to_cover_lo); + SASSERT(0 <= to_cover_hi); + SASSERT(to_cover_lo < mod_value); + SASSERT(to_cover_hi <= mod_value); // full interval if to_cover_hi == mod_value + SASSERT(to_cover_lo != to_cover_hi); // non-empty search domain (but it may wrap) + + // TODO: refinement of eq/diseq should happen only on the correct layer: where (one of) the coefficient(s) are odd + // for refinement, we have to choose an entry that is violated, but if there are multiple, we can choose the one on smallest domain (lowest bit-width). + // (by maintaining descending order by bit-width; and refine first that fails). + // but fixed-bit-refinement is cheap and could be done during the search. + + // when we arrive at a hole the possibilities are: + // 1) go to lower bitwidth + // 2) refinement of some eq/diseq constraint (if one is violated at that point) -- defer this until point is viable for all layers and fixed bits. + // 3) refinement by using bit constraints? -- TODO: do this during search (another interval check after/before the entry_cursors) + // 4) (point is actually feasible) + + // a complication is that we have to iterate over multiple lists of intervals. + // might be useful to merge them upfront to simplify the remainder of the algorithm? + // (non-trivial since prev/next pointers are embedded into entries and lists are updated by refinement) + struct entry_cursor { + entry* cur; + // entry* first; + // entry* last; + }; + + // find relevant interval lists + svector ecs; + for (pvar x : overlaps) { + if (c.size(x) < w) // note that overlaps are sorted by variable size descending + break; + if (entry* e = m_units[x].get_entries(w)) { + display_all(std::cerr << "units for width " << w << ":\n", 0, e, "\n"); + entry_cursor ec; + ec.cur = e; // TODO: e->prev() probably makes it faster when querying 0 (can often save going around the interval list once) + // ec.first = nullptr; + // ec.last = nullptr; + ecs.push_back(ec); + } + } + + rational const to_cover_len = r_interval::len(to_cover_lo, to_cover_hi, mod_value); + val = to_cover_lo; + + rational progress; // = 0 + SASSERT(progress.is_zero()); + while (true) { + while (true) { + entry* e = nullptr; + + // try to make progress using any of the relevant interval lists + for (entry_cursor& ec : ecs) { + // advance until current value 'val' + auto const [n, n_contains_val] = find_value(val, ec.cur); + // display_one(std::cerr << "found entry n: ", 0, n) << "\n"; + // LOG("n_contains_val: " << n_contains_val << " val = " << val); + ec.cur = n; + if (n_contains_val) { + e = n; + break; + } + } + + // when we cannot make progress by existing intervals any more, try interval from fixed bits + if (!e) { + e = refine_bits(v, val, w, fbi); + if (e) { + refine_todo.push_back(e); + display_one(std::cerr << "found entry by bits: ", 0, e) << "\n"; + } + } + + // no more progress on current layer + if (!e) + break; + + relevant_entries.push_back(e); + + if (e->interval.is_full()) { + relevant_entries.clear(); + relevant_entries.push_back(e); // full interval e -> all other intervals are subsumed/irrelevant + set_conflict_by_interval(v, w, relevant_entries, 0); + return l_false; + } + + SASSERT(e->interval.currently_contains(val)); + rational const& new_val = e->interval.hi_val(); + rational const dist = distance(val, new_val, mod_value); + SASSERT(dist > 0); + val = new_val; + progress += dist; + LOG("val: " << val << " progress: " << progress); + + if (progress >= mod_value) { + // covered the whole domain => conflict + set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict); + return l_false; + } + if (progress >= to_cover_len) { + // we covered the hole left at larger bit-width + // TODO: maybe we want to keep trying a bit longer to see if we can cover the whole domain. or maybe only if we enter this layer multiple times. + return l_undef; + } + + // (another way to compute 'progress') + SASSERT_EQ(progress, distance(to_cover_lo, val, mod_value)); + } + + // no more progress + // => 'val' is viable w.r.t. unit intervals until current layer + + if (!w_idx) { + // we are at the lowest layer + // => found viable value w.r.t. unit intervals and fixed bits + return l_true; + } + + // find next covered value + rational next_val = to_cover_hi; + for (entry_cursor& ec : ecs) { + // each ec.cur is now the next interval after 'lo' + rational const& n = ec.cur->interval.lo_val(); + SASSERT(r_interval::contains(ec.cur->prev()->interval.hi_val(), n, val)); + if (distance(val, n, mod_value) < distance(val, next_val, mod_value)) + next_val = n; + } + if (entry* e = refine_bits(v, next_val, w, fbi)) { + refine_todo.push_back(e); + rational const& n = e->interval.lo_val(); + SASSERT(distance(val, n, mod_value) < distance(val, next_val, mod_value)); + next_val = n; + } + SASSERT(!refine_bits(v, val, w, fbi)); + SASSERT(val != next_val); + + unsigned const lower_w = widths[w_idx - 1]; + rational const lower_mod_value = rational::power_of_two(lower_w); + + rational lower_cover_lo, lower_cover_hi; + if (distance(val, next_val, mod_value) >= lower_mod_value) { + // NOTE: in this case we do not get the first viable value, but the one with smallest value in the lower bits. + // this is because we start the search in the recursive case at 0. + // if this is a problem, adapt to lower_cover_lo = mod(val, lower_mod_value), lower_cover_hi = ... + lower_cover_lo = 0; + lower_cover_hi = lower_mod_value; + rational a; + lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries); + VERIFY(result != l_undef); + if (result == l_false) { + SASSERT(c.inconsistent()); + return l_false; // conflict + } + SASSERT(result == l_true); + // replace lower bits of 'val' by 'a' + rational const val_lower = mod(val, lower_mod_value); + val = val - val_lower + a; + if (a < val_lower) + a += lower_mod_value; + LOG("distance(val, cover_hi) = " << distance(val, to_cover_hi, mod_value)); + LOG("distance(next_val, cover_hi) = " << distance(next_val, to_cover_hi, mod_value)); + SASSERT(distance(val, to_cover_hi, mod_value) >= distance(next_val, to_cover_hi, mod_value)); + return l_true; + } + + lower_cover_lo = mod(val, lower_mod_value); + lower_cover_hi = mod(next_val, lower_mod_value); + + rational a; + lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries); + if (result == l_false) { + SASSERT(c.inconsistent()); + return l_false; // conflict + } + + // replace lower bits of 'val' by 'a' + rational const dist = distance(lower_cover_lo, a, lower_mod_value); + val += dist; + progress += dist; + LOG("distance(val, cover_hi) = " << distance(val, to_cover_hi, mod_value)); + LOG("distance(next_val, cover_hi) = " << distance(next_val, to_cover_hi, mod_value)); + SASSERT(distance(val, to_cover_hi, mod_value) >= distance(next_val, to_cover_hi, mod_value)); + + if (result == l_true) + return l_true; // done + + SASSERT(result == l_undef); + + if (progress >= mod_value) { + // covered the whole domain => conflict + set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict); + return l_false; + } + + if (progress >= to_cover_len) { + // we covered the hole left at larger bit-width + return l_undef; + } + } + UNREACHABLE(); + return l_undef; } diff --git a/src/sat/smt/polysat/polysat_viable.h b/src/sat/smt/polysat/polysat_viable.h index 7c9c019dc..37b1d7b0c 100644 --- a/src/sat/smt/polysat/polysat_viable.h +++ b/src/sat/smt/polysat/polysat_viable.h @@ -162,6 +162,63 @@ namespace polysat { lbool find_viable(pvar v, rational& lo, rational& hi); + lbool find_on_layers( + pvar v, + unsigned_vector const& widths, + pvar_vector const& overlaps, + fixed_bits_info const& fbi, + rational const& to_cover_lo, + rational const& to_cover_hi, + rational& out_val); + + lbool find_on_layer( + pvar v, + unsigned w_idx, + unsigned_vector const& widths, + pvar_vector const& overlaps, + fixed_bits_info const& fbi, + rational const& to_cover_lo, + rational const& to_cover_hi, + rational& out_val, + ptr_vector& refine_todo, + ptr_vector& relevant_entries); + + + template + bool refine_viable(pvar v, rational const& val, fixed_bits_info const& fbi) { + throw default_exception("nyi"); + } + + bool refine_viable(pvar v, rational const& val) { + throw default_exception("nyi"); + } + + template + bool refine_bits(pvar v, rational const& val, fixed_bits_info const& fbi) { + throw default_exception("nyi"); + } + + template + entry* refine_bits(pvar v, rational const& val, unsigned num_bits, fixed_bits_info const& fbi) { + throw default_exception("nyi"); + } + + bool refine_equal_lin(pvar v, rational const& val) { + throw default_exception("nyi"); + } + + bool refine_disequal_lin(pvar v, rational const& val) { + throw default_exception("nyi"); + } + + bool set_conflict_by_interval(pvar v, unsigned w, ptr_vector& intervals, unsigned first_interval) { + throw default_exception("nyi"); + } + + std::pair find_value(rational const& val, entry* entries) { + throw default_exception("nyi"); + } + public: viable(core& c); diff --git a/src/sat/smt/polysat/slicing.cpp b/src/sat/smt/polysat/slicing.cpp new file mode 100644 index 000000000..04fe8c4fc --- /dev/null +++ b/src/sat/smt/polysat/slicing.cpp @@ -0,0 +1,1727 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polysat slicing + +Author: + + Jakob Rath 2023-06-01 + +--*/ + + + + +/* + +Example: +(1) x = y +(2) z = y[3:0] +(3) explain(x[3:0] == z)? should be { (1), (2) } + + (1) + x ========================> y + / \ / \ (2) + x[7:4] x[3:0] y[7:4] y[3:0] ===========> z + + +TODO: +- About the sub-slice sharing among equivalent nodes: + - When extracting a variable y := x[h:l], we always need to create a new slice for y. + - Merge slices for x[h:l] with y; store as dependency 'x[h:l]' (rather than 'null_dep' as we do now). + - when merging, we must avoid that the new variable becomes the root of the equivalence class, + because when finding dependencies for 'y := x[h:l]', such extraction-dependencies would be false/unnecessary. + (alternatively, just ignore them. but we never *have* to have them as root, so just don't do it. but add assertions for 1. new var is not root, 2. no extraction-dependency when walking from 'x' to 'x[h:l]'.) + - When encountering this dependency, need to start at slice for 'x' and walk towards 'x[h:l]', + collecting dependencies whenever we move to a representative. +- when solver assigns value of a variable v, add equations with v substituted by its value? + - since we only track equations over variables/names, this might not lead to many further additions + - a question is how to track/handle the dependency on the assignment +- check Algorithm 2 of "Solving Bitvectors with MCSAT"; what is the difference to what we are doing now? +- track equalities such as x = -y ? +- on_merge could propagate values upwards: + if slice has value but parent has no value, then check if sub_other(parent(s)) [sibling(s)?] has a value. + if yes, can propagate value upwards. (add a congruence term to track deps properly?). + we have to check the whole equivalence class, because the parents may be in different classes. + it is enough to propagate values to variables. we could count (in the variable slice) the number of its base slices that are still unassigned. + +*/ + + +#include "ast/reg_decl_plugins.h" +#include "math/polysat/slicing.h" +#include "math/polysat/solver.h" +#include "math/polysat/log.h" +#include "util/tptr.h" + + +namespace { + + template + [[maybe_unused]] + inline constexpr bool always_false_v = false; + +} + +namespace polysat { + + void* slicing::dep_t::encode() const { + void* p = std::visit([](auto arg) -> void* { + using T = std::decay_t; + if constexpr (std::is_same_v) + return nullptr; + else if constexpr (std::is_same_v) + return box(arg.to_uint(), 1); + else if constexpr (std::is_same_v) + return box(arg, 2); + else + static_assert(always_false_v, "non-exhaustive visitor!"); + }, m_data); + SASSERT(*this == decode(p)); + return p; + } + + slicing::dep_t slicing::dep_t::decode(void* p) { + if (!p) + return {}; + unsigned tag = get_tag(p); + SASSERT(tag == 1 || tag == 2); + if (tag == 1) + return dep_t(sat::to_literal(unbox(p))); + else + return dep_t(unbox(p)); + } + + std::ostream& slicing::display(std::ostream& out, dep_t d) const { + if (d.is_null()) + out << "null"; + else if (d.is_value()) { + pvar x = get_dep_var(d); + enode* n = get_dep_slice(d); + sat::literal lit = get_dep_lit(d); + out << "value(v" << x << " on slice "; + if (n) + out << n->get_id(); + else + out << ""; + if (lit != sat::null_literal) + out << " by literal " << lit; + out << ")"; + } + else if (d.is_lit()) + out << "lit(" << d.lit() << ")"; + return out; + } + + slicing::dep_t slicing::mk_var_dep(pvar v, enode* s, sat::literal lit) { + SASSERT_EQ(m_dep_var.size(), m_dep_slice.size()); + SASSERT_EQ(m_dep_var.size(), m_dep_lit.size()); + unsigned const idx = m_dep_var.size(); + m_dep_var.push_back(v); + m_dep_lit.push_back(lit); + m_dep_slice.push_back(s); + return dep_t(idx); + } + + slicing::slicing(solver& s): + m_solver(s), + m_egraph(m_ast) + { + reg_decl_plugins(m_ast); + m_bv = alloc(bv_util, m_ast); + m_egraph.set_display_justification([&](std::ostream& out, void* dp) { display(out, dep_t::decode(dp)); }); + m_egraph.set_on_merge([&](enode* root, enode* other) { egraph_on_merge(root, other); }); + m_egraph.set_on_propagate([&](enode* lit, enode* ante) { egraph_on_propagate(lit, ante); }); + // m_egraph.set_on_make([&](enode* n) { egraph_on_make(n); }); + } + + slicing::slice_info& slicing::info(enode* n) { + return const_cast(std::as_const(*this).info(n)); + } + + slicing::slice_info const& slicing::info(enode* n) const { + SASSERT(n); + SASSERT(!n->is_equality()); + SASSERT(m_bv->is_bv_sort(n->get_sort())); + slice_info const& i = m_info[n->get_id()]; + return i.slice ? info(i.slice) : i; + } + + bool slicing::is_slice(enode* n) const { + if (n->is_equality()) + return false; + SASSERT(m_bv->is_bv_sort(n->get_sort())); + slice_info const& i = m_info[n->get_id()]; + return !i.slice; + } + + bool slicing::is_concat(enode* n) const { + if (n->is_equality()) + return false; + return !is_slice(n); + } + + unsigned slicing::width(enode* s) const { + SASSERT(!s->is_equality()); + return m_bv->get_bv_size(s->get_expr()); + } + + slicing::enode* slicing::sibling(enode* s) const { + enode* p = parent(s); + SASSERT(p); + SASSERT(sub_lo(p) == s || sub_hi(p) == s); + if (s != sub_hi(p)) + return sub_hi(p); + else + return sub_lo(p); + } + + func_decl* slicing::mk_concat_decl(ptr_vector const& args) { + SASSERT(args.size() >= 2); + ptr_vector domain; + unsigned sz = 0; + for (expr* e : args) { + domain.push_back(e->get_sort()); + sz += m_bv->get_bv_size(e); + } + sort* range = m_bv->mk_sort(sz); + return m_ast.mk_func_decl(symbol("slice-concat"), domain.size(), domain.data(), range); + } + + void slicing::push_scope() { + LOG("push_scope"); + if (can_propagate()) + propagate(); + m_scopes.push_back(m_trail.size()); + m_egraph.push(); + m_dep_size_trail.push_back(m_dep_var.size()); + SASSERT(!use_var_congruences() || m_needs_congruence.empty()); + } + + void slicing::pop_scope(unsigned num_scopes) { + LOG("pop_scope(" << num_scopes << ")"); + if (num_scopes == 0) + return; + unsigned const lvl = m_scopes.size(); + SASSERT(num_scopes <= lvl); + unsigned const target_lvl = lvl - num_scopes; + unsigned const target_size = m_scopes[target_lvl]; + m_scopes.shrink(target_lvl); + svector replay_trail; + unsigned_vector replay_add_var_trail; + svector> replay_extract_trail; + svector replay_concat_trail; + unsigned num_replay_concat = 0; + for (unsigned i = m_trail.size(); i-- > target_size; ) { + switch (m_trail[i]) { + case trail_item::add_var: + replay_trail.push_back(trail_item::add_var); + replay_add_var_trail.push_back(width(m_var2slice.back())); + undo_add_var(); + break; + case trail_item::split_core: + undo_split_core(); + break; + case trail_item::mk_extract: { + replay_trail.push_back(trail_item::mk_extract); + extract_args const& args = m_extract_trail.back(); + replay_extract_trail.push_back({args, m_extract_dedup[args]}); + undo_mk_extract(); + break; + } + case trail_item::mk_concat: + replay_trail.push_back(trail_item::mk_concat); + num_replay_concat++; + break; + case trail_item::set_value_node: + undo_set_value_node(); + break; + default: + UNREACHABLE(); + } + } + m_egraph.pop(num_scopes); + m_needs_congruence.reset(); + m_disequality_conflict = nullptr; + m_dep_var.shrink(m_dep_size_trail[target_lvl]); + m_dep_lit.shrink(m_dep_size_trail[target_lvl]); + m_dep_slice.shrink(m_dep_size_trail[target_lvl]); + m_dep_size_trail.shrink(target_lvl); + m_trail.shrink(target_size); + // replay add_var/mk_extract/mk_concat in the same order + // (only until polysat::solver supports proper garbage collection of variables) + unsigned add_var_idx = replay_add_var_trail.size(); + unsigned extract_idx = replay_extract_trail.size(); + unsigned concat_idx = m_concat_trail.size() - num_replay_concat; + for (auto it = replay_trail.rbegin(); it != replay_trail.rend(); ++it) { + switch (*it) { + case trail_item::add_var: { + unsigned const sz = replay_add_var_trail[--add_var_idx]; + add_var(sz); + break; + } + case trail_item::mk_extract: { + auto const [args, v] = replay_extract_trail[--extract_idx]; + replay_extract(args, v); + break; + } + case trail_item::mk_concat: { + NOT_IMPLEMENTED_YET(); + auto const ci = m_concat_trail[concat_idx++]; + num_replay_concat++; + replay_concat(ci.num_args, &m_concat_args[ci.args_idx], ci.v); + break; + } + default: + UNREACHABLE(); + } + } + SASSERT(invariant()); + } + + void slicing::add_var(unsigned bit_width) { + pvar const v = m_var2slice.size(); + enode* s = alloc_slice(bit_width, v); + m_var2slice.push_back(s); + m_trail.push_back(trail_item::add_var); + LOG_V(10, "add_var: v" << v << " -> " << slice_pp(*this, s)); + } + + void slicing::undo_add_var() { + m_var2slice.pop_back(); + } + + slicing::enode* slicing::find_or_alloc_disequality(enode* x, enode* y, sat::literal lit) { + expr_ref eq(m_ast.mk_eq(x->get_expr(), y->get_expr()), m_ast); + enode* eqn = m_egraph.find(eq); + if (eqn) + return eqn; + auto args = {x, y}; + eqn = m_egraph.mk(eq, 0, args.size(), args.begin()); + auto j = euf::justification::external(dep_t(lit).encode()); + m_egraph.set_value(eqn, l_false, j); + SASSERT(eqn->is_equality()); + SASSERT_EQ(eqn->value(), l_false); + return eqn; + } + + void slicing::egraph_on_make(enode* n) { + LOG("on_make: " << e_pp(n)); + } + + slicing::enode* slicing::alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var) { + SASSERT(!m_egraph.find(e)); + // NOTE: sometimes egraph::mk already triggers a merge due to congruence. + // in this case we have to make sure to allocate m_info early enough. + unsigned const id = e->get_id(); + m_info.reserve(id + 1); + slice_info& i = m_info[id]; + i.reset(); + i.var = var; + enode* n = m_egraph.mk(e, 0, num_args, args); // NOTE: the egraph keeps a strong reference to 'e' + LOG_V(10, "alloc_enode: " << slice_pp(*this, n) << " " << e_pp(n)); + return n; + } + + slicing::enode* slicing::find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var) { + enode* n = m_egraph.find(e); + if (n) { + SASSERT_EQ(info(n).var, var); + return n; + } + return alloc_enode(e, num_args, args, var); + } + + slicing::enode* slicing::alloc_slice(unsigned width, pvar var) { + SASSERT(width > 0); + app_ref a(m_ast.mk_fresh_const("s", m_bv->mk_sort(width), false), m_ast); + return alloc_enode(a, 0, nullptr, var); + } + + slicing::enode* slicing::mk_concat_node(enode_vector const& slices) { + return mk_concat_node(slices.size(), slices.data()); + } + + slicing::enode* slicing::mk_concat_node(unsigned num_slices, enode* const* slices) { + ptr_vector args; + for (unsigned i = 0; i < num_slices; ++i) + args.push_back(slices[i]->get_expr()); + app_ref a(m_ast.mk_app(mk_concat_decl(args), args), m_ast); + return find_or_alloc_enode(a, num_slices, slices, null_var); + } + + void slicing::add_concat_node(enode* s, enode* concat) { + SASSERT(slice2var(s) != null_var); // all concat nodes should point to a variable node + SASSERT(is_app(concat->get_expr())); + slice_info& concat_info = m_info[concat->get_id()]; + if (s->get_root() == concat->get_root()) { + SASSERT_EQ(s, concat_info.slice); + return; + } + SASSERT(!concat_info.slice); // not yet set + concat_info.slice = s; + egraph_merge(s, concat, dep_t()); + } + + void slicing::add_var_congruence(pvar v) { + enode_vector& base = m_tmp2; + SASSERT(base.empty()); + enode* sv = var2slice(v); + get_base(sv, base); + // Add equation v == concat(s1, ..., sn) + add_concat_node(sv, mk_concat_node(base)); + base.clear(); + } + + void slicing::add_var_congruence_if_needed(pvar v) { + if (!m_needs_congruence.contains(v)) + return; + m_needs_congruence.remove(v); + add_var_congruence(v); + } + + void slicing::update_var_congruences() { + if (!use_var_congruences()) + return; + // TODO: this is only needed once per equivalence class + // (mark root of var2slice to detect duplicates?) + for (pvar v : m_needs_congruence) + add_var_congruence(v); + m_needs_congruence.reset(); + } + + bool slicing::use_var_congruences() const { + return m_solver.config().m_slicing_congruence; + } + + // split a single slice without updating any equivalences + void slicing::split_core(enode* s, unsigned cut) { + SASSERT(is_slice(s)); // this action only makes sense for slices + SASSERT(!has_sub(s)); + SASSERT(info(s).sub_hi == nullptr && info(s).sub_lo == nullptr); + SASSERT(width(s) > cut + 1); + unsigned const width_hi = width(s) - cut - 1; + unsigned const width_lo = cut + 1; + enode* sub_hi; + enode* sub_lo; + if (is_value(s)) { + rational const val = get_value(s); + sub_hi = mk_value_slice(machine_div2k(val, width_lo), width_hi); + sub_lo = mk_value_slice(mod2k(val, width_lo), width_lo); + } + else { + sub_hi = alloc_slice(width_hi); + sub_lo = alloc_slice(width_lo); + } + SASSERT(!parent(sub_hi)); + SASSERT(!parent(sub_lo)); + info(sub_hi).parent = s; + info(sub_lo).parent = s; + info(s).set_cut(cut, sub_hi, sub_lo); + m_trail.push_back(trail_item::split_core); + m_enode_trail.push_back(s); + for (enode* n = s; n != nullptr; n = parent(n)) { + pvar const v = slice2var(n); + if (v == null_var) + continue; + if (m_needs_congruence.contains(v)) { + SASSERT(invariant_needs_congruence()); + break; // added parents already previously + } + m_needs_congruence.insert(v); + } + } + + bool slicing::invariant_needs_congruence() const { + for (pvar v : m_needs_congruence) + for (enode* s = var2slice(v); s != nullptr; s = parent(s)) + if (slice2var(s) != null_var) { + VERIFY(m_needs_congruence.contains(slice2var(s))); + } + return true; + } + + void slicing::undo_split_core() { + enode* s = m_enode_trail.back(); + m_enode_trail.pop_back(); + info(s).set_cut(null_cut, nullptr, nullptr); + } + + void slicing::split(enode* s, unsigned cut) { + // this action only makes sense for base slices. + // a base slice is never equivalent to a congruence node. + SASSERT(is_slice(s)); + SASSERT(!has_sub(s)); + SASSERT(cut != null_cut); + // split all slices in the equivalence class + for (enode* n : euf::enode_class(s)) + split_core(n, cut); + // propagate equivalences to subslices + for (enode* n : euf::enode_class(s)) { + enode* target = n->get_target(); + if (!target) + continue; + euf::justification const j = n->get_justification(); + SASSERT(j.is_external()); // cannot be a congruence since the slice wasn't split before. + dep_t const d = dep_t::decode(j.ext()); + egraph_merge(sub_hi(n), sub_hi(target), d); + egraph_merge(sub_lo(n), sub_lo(target), d); + } + } + + void slicing::mk_slice(enode* src, unsigned const hi, unsigned const lo, enode_vector& out, bool output_full_src, bool output_base) { + SASSERT(hi >= lo); + SASSERT(width(src) > hi); // extracted range must be fully contained inside the src slice + auto output_slice = [this, output_base, &out](enode* s) { + if (output_base) + get_base(s, out); + else + out.push_back(s); + }; + if (lo == 0 && width(src) - 1 == hi) { + output_slice(src); + return; + } + if (has_sub(src)) { + // src is split into [src.width-1, cut+1] and [cut, 0] + unsigned const cut = info(src).cut; + if (lo >= cut + 1) { + // target slice falls into upper subslice + mk_slice(sub_hi(src), hi - cut - 1, lo - cut - 1, out, output_full_src, output_base); + if (output_full_src) + output_slice(sub_lo(src)); + return; + } + else if (cut >= hi) { + // target slice falls into lower subslice + if (output_full_src) + output_slice(sub_hi(src)); + mk_slice(sub_lo(src), hi, lo, out, output_full_src, output_base); + return; + } + else { + SASSERT(hi > cut && cut >= lo); + // desired range spans over the cutpoint, so we get multiple slices in the result + mk_slice(sub_hi(src), hi - cut - 1, 0, out, output_full_src, output_base); + mk_slice(sub_lo(src), cut, lo, out, output_full_src, output_base); + return; + } + } + else { + // [src.width-1, 0] has no subdivision yet + if (width(src) - 1 > hi) { + split(src, hi); + SASSERT(!has_sub(sub_hi(src))); + if (output_full_src) + out.push_back(sub_hi(src)); + mk_slice(sub_lo(src), hi, lo, out, output_full_src, output_base); // recursive call to take care of case lo > 0 + return; + } + else { + SASSERT(lo > 0); + split(src, lo - 1); + out.push_back(sub_hi(src)); + SASSERT(!has_sub(sub_lo(src))); + if (output_full_src) + out.push_back(sub_lo(src)); + return; + } + } + UNREACHABLE(); + } + + slicing::enode* slicing::mk_value_slice(rational const& val, unsigned bit_width) { + SASSERT(bit_width > 0); + SASSERT(0 <= val && val < rational::power_of_two(bit_width)); + sort* bv_sort = m_bv->mk_sort(bit_width); + func_decl_ref f(m_ast.mk_fresh_func_decl("val", nullptr, 1, &bv_sort, bv_sort, false), m_ast); + app_ref a(m_ast.mk_app(f, m_bv->mk_numeral(val, bit_width)), m_ast); + enode* s = alloc_enode(a, 0, nullptr, null_var); + set_value_node(s, s); + SASSERT_EQ(get_value(s), val); + return s; + } + + slicing::enode* slicing::mk_interpreted_value_node(enode* s) { + SASSERT(is_value(s)); + // NOTE: how this is used now, the node will not yet be contained in the egraph. + enode* n = alloc_enode(s->get_app()->get_arg(0), 0, nullptr, null_var); + info(n).value_node = s; + n->mark_interpreted(); + SASSERT(n->interpreted()); + SASSERT_EQ(get_value_node(n), s); + return n; + } + + bool slicing::is_value(enode* n) const { + SASSERT(n); + SASSERT(is_app(n->get_expr())); // we only create app nodes at the moment; if this fails just return false here. + app* a = n->get_app(); + return a->get_num_args() == 1 && m_bv->is_numeral(a->get_arg(0)); + } + + rational slicing::get_value(enode* s) const { + SASSERT(is_value(s)); + rational val; + VERIFY(try_get_value(s, val)); + return val; + } + + bool slicing::try_get_value(enode* s, rational& val) const { + if (!s) + return false; + app* a = s->get_app(); + if (a->get_num_args() != 1) + return false; + bool const ok = m_bv->is_numeral(a->get_arg(0), val); + SASSERT_EQ(ok, is_value(s)); + return ok; + } + + void slicing::explain_class(enode* x, enode* y, ptr_vector& out_deps) { + SASSERT_EQ(x->get_root(), y->get_root()); + m_egraph.begin_explain(); + m_egraph.explain_eq(out_deps, nullptr, x, y); + m_egraph.end_explain(); + } + + void slicing::explain_equal(enode* x, enode* y, ptr_vector& out_deps) { + SASSERT(is_equal(x, y)); + SASSERT_EQ(width(x), width(y)); + enode_vector& xs = m_tmp2; + enode_vector& ys = m_tmp3; + SASSERT(xs.empty()); + SASSERT(ys.empty()); + xs.push_back(x); + ys.push_back(y); + while (!xs.empty()) { + SASSERT(!ys.empty()); + enode* const x = xs.back(); xs.pop_back(); + enode* const y = ys.back(); ys.pop_back(); + if (x == y) + continue; + if (width(x) == width(y)) { + enode* const rx = x->get_root(); + enode* const ry = y->get_root(); + if (rx == ry) + explain_class(x, y, out_deps); + else { + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + ys.push_back(sub_hi(y)); + ys.push_back(sub_lo(y)); + } + } + else if (width(x) > width(y)) { + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + ys.push_back(y); + } + else { + SASSERT(width(x) < width(y)); + xs.push_back(x); + ys.push_back(sub_hi(y)); + ys.push_back(sub_lo(y)); + } + } + SASSERT(ys.empty()); + } + + void slicing::explain_equal(pvar x, pvar y, ptr_vector& out_deps) { + explain_equal(var2slice(x), var2slice(y), out_deps); + } + + void slicing::explain_equal(pvar x, pvar y, std::function const& on_lit) { + SASSERT(m_marked_lits.empty()); + SASSERT(m_tmp_deps.empty()); + explain_equal(x, y, m_tmp_deps); + for (void* dp : m_tmp_deps) { + dep_t const d = dep_t::decode(dp); + if (d.is_null()) + continue; + if (d.is_lit()) { + sat::literal lit = d.lit(); + if (m_marked_lits.contains(lit)) + continue; + m_marked_lits.insert(lit); + on_lit(d.lit()); + } + else { + // equivalence between to variables cannot be due to value assignment + UNREACHABLE(); + } + } + m_marked_lits.reset(); + m_tmp_deps.reset(); + } + + void slicing::explain(ptr_vector& out_deps) { + SASSERT(is_conflict()); + m_egraph.begin_explain(); + if (m_disequality_conflict) { + LOG("Disequality conflict: " << m_disequality_conflict); + enode* eqn = m_disequality_conflict; + SASSERT(eqn->is_equality()); + SASSERT_EQ(eqn->value(), l_false); + SASSERT(eqn->get_lit_justification().is_external()); + SASSERT(m_ast.is_eq(eqn->get_expr())); + SASSERT_EQ(eqn->get_arg(0)->get_root(), eqn->get_arg(1)->get_root()); + m_egraph.explain_eq(out_deps, nullptr, eqn->get_arg(0), eqn->get_arg(1)); + out_deps.push_back(eqn->get_lit_justification().ext()); + } + else { + SASSERT(m_egraph.inconsistent()); + m_egraph.explain(out_deps, nullptr); + } + m_egraph.end_explain(); + } + + clause_ref slicing::build_conflict_clause() { + LOG_H1("slicing: build_conflict_clause"); + // display_tree(verbose_stream()); + + SASSERT(invariant()); + SASSERT(is_conflict()); + SASSERT(m_marked_lits.empty()); + SASSERT(m_tmp_deps.empty()); + explain(m_tmp_deps); + clause_builder cb(m_solver, "slicing"); + + auto add_premise = [this, &cb](sat::literal lit) { + LOG("Premise: " << lit_pp(m_solver, lit)); + cb.insert(~lit); + }; + + auto add_conclusion = [this, &cb](signed_constraint c) { + LOG("Conclusion: " << lit_pp(m_solver, c)); + cb.insert_eval(c); + }; + + pvar x = null_var; enode* sx = nullptr; sat::literal xlit = sat::null_literal; + pvar y = null_var; enode* sy = nullptr; sat::literal ylit = sat::null_literal; + for (void* dp : m_tmp_deps) { + dep_t const d = dep_t::decode(dp); + // LOG("dep: " << dep_pp(*this, d)); + if (d.is_null()) + continue; + if (d.is_lit()) { + sat::literal const lit = d.lit(); + if (m_marked_lits.contains(lit)) + continue; + m_marked_lits.insert(lit); + add_premise(lit); + } + else { + SASSERT(d.is_value()); + pvar const v = get_dep_var(d); + enode* const sv = get_dep_slice(d); + sat::literal const lit = get_dep_lit(d); + if (x == null_var) + x = v, sx = sv, xlit = lit; + else if (y == null_var) + y = v, sy = sv, ylit = lit; + else { + // pvar justifications are only introduced by add_value, i.e., when a variable is assigned in the solver. + // thus there can be at most two pvar justifications in a single conflict. + UNREACHABLE(); + } + } + } + m_marked_lits.reset(); + m_tmp_deps.reset(); + + if (x != null_var && y != null_var && xlit == sat::null_literal && ylit != sat::null_literal) { + using std::swap; + swap(x, y); + swap(sx, sy); + swap(xlit, ylit); + } + + if (x != null_var) { + LOG("Variable v" << x << " with slice " << slice_pp(*this, sx) << " by literal " << lit_pp(m_solver, xlit)); + if (m_solver.is_assigned(x)) + LOG("solver-value " << assignment_pp(m_solver, x, m_solver.get_value(x))); + } + if (y != null_var) { + LOG("Variable v" << y << " with slice " << slice_pp(*this, sy) << " by literal " << lit_pp(m_solver, ylit)); + if (m_solver.is_assigned(y)) + LOG("solver-value " << assignment_pp(m_solver, y, m_solver.get_value(y))); + } + + // conflict has either 0 or 2 vars + VERIFY(x != null_var || y == null_var); + VERIFY(y != null_var || x == null_var); + + if (xlit != sat::null_literal && ylit != sat::null_literal) { + verbose_stream() << "build_conflict_clause (2)" << std::endl; + add_premise(xlit); + add_premise(ylit); + } + else if (xlit != sat::null_literal && ylit == sat::null_literal) { + verbose_stream() << "build_conflict_clause (1)" << std::endl; + add_premise(xlit); + + // rational const x_slice_value = get_value(get_value_node(var2slice(x))); + // LOG("v" << x << " slice_value: " << x_slice_value); + rational const y_slice_value = get_value(get_value_node(var2slice(y))); + LOG("v" << y << " slice_value: " << y_slice_value); + // SASSERT(x_slice_value != y_slice_value); + add_conclusion(~m_solver.eq(m_solver.var(y), y_slice_value)); + +/* + unsigned x_hi, x_lo; + VERIFY(find_range_in_ancestor(sx, var2slice(x), x_hi, x_lo)); + pvar const xx = mk_extract(x, x_hi, x_lo); + LOG("find_range_in_ancestor: v" << x << "[" << x_hi << ":" << x_lo << "] = " << slice_pp(*this, sx) << " --> represented by variable v" << xx); + unsigned y_hi, y_lo; + VERIFY(find_range_in_ancestor(sy, var2slice(y), y_hi, y_lo)); + pvar const yy = mk_extract(y, y_hi, y_lo); + LOG("find_range_in_ancestor: v" << y << "[" << y_hi << ":" << y_lo << "] = " << slice_pp(*this, sy) << " --> represented by variable v" << yy); + // LOG("v" << x << " has solver-value? " << m_solver.is_assigned(x)); + if (m_solver.is_assigned(x)) LOG("v" << x << " has solver-value " << m_solver.get_value(x)); + // LOG("v" << y << " has solver-value? " << m_solver.is_assigned(y)); + if (m_solver.is_assigned(y)) LOG("v" << y << " has solver-value " << m_solver.get_value(y)); + LOG("v" << x << " is slice " << slice_pp(*this, var2slice(x))); + LOG("v" << y << " is slice " << slice_pp(*this, var2slice(y))); + SASSERT_EQ(sy->get_root(), var2slice(yy)->get_root()); + rational const sy_slice_value = get_value(get_value_node(sy)); + // rational const sy_solver_value = mod2k(machine_div2k(m_solver.get_value(y), lo), hi - lo + 1); + // c = m_solver.eq(m_solver.var(yy), sy_slice_value); +*/ + } + else { + verbose_stream() << "build_conflict_clause (0)" << std::endl; + SASSERT(xlit == sat::null_literal); + SASSERT(ylit == sat::null_literal); + + // unsigned x_hi, x_lo, y_hi, y_lo; + // VERIFY(find_range_in_ancestor(sx, var2slice(x), x_hi, x_lo)); + // VERIFY(find_range_in_ancestor(sy, var2slice(y), y_hi, y_lo)); + // pvar const xx = mk_extract(x, x_hi, x_lo); + // pvar const yy = mk_extract(y, y_hi, y_lo); + // SASSERT_EQ(sx->get_root(), var2slice(xx)->get_root()); + // SASSERT_EQ(sy->get_root(), var2slice(yy)->get_root()); + // rational sval = mod2k(machine_div2k(m_solver.get_value(x), x_lo), x_hi - x_lo + 1); + // LOG("find_range_in_ancestor: v" << x << "[" << x_hi << ":" << x_lo << "] = " << slice_pp(*this, sx) << " --> represented by variable v" << xx); + // LOG("find_range_in_ancestor: v" << y << "[" << y_hi << ":" << y_lo << "] = " << slice_pp(*this, sy) << " --> represented by variable v" << yy); + LOG("v" << x << " is slice " << slice_pp(*this, var2slice(x))); + LOG("v" << y << " is slice " << slice_pp(*this, var2slice(y))); + if (m_solver.is_assigned(x)) LOG("v" << x << " has solver-value " << m_solver.get_value(x)); + if (m_solver.is_assigned(y)) LOG("v" << y << " has solver-value " << m_solver.get_value(y)); + // SASSERT(xx != yy); + // c = m_solver.eq(m_solver.var(xx), m_solver.var(yy)); // similar to what Algorithm 1 in BitvectorsMCSAT is doing + // LOG("c: " << lit_pp(m_solver, c)); + + rational const x_slice_value = get_value(get_value_node(var2slice(x))); + LOG("v" << x << " slice-value: " << x_slice_value); + add_conclusion(~m_solver.eq(m_solver.var(x), x_slice_value)); + + rational const y_slice_value = get_value(get_value_node(var2slice(y))); + LOG("v" << y << " slice-value: " << y_slice_value); + add_conclusion(~m_solver.eq(m_solver.var(y), y_slice_value)); + } + + // TODO: we don't need clauses like this ... rather set up the conflict core from it + + return cb.build(); + } + + void slicing::explain_value(enode* s, std::function const& on_lit, std::function const& on_var) { + SASSERT(invariant()); + SASSERT(m_marked_lits.empty()); + + enode* n = get_value_node(s); + SASSERT(is_value(n)); + + SASSERT(m_tmp_deps.empty()); + explain_equal(s, n, m_tmp_deps); + + for (void* dp : m_tmp_deps) { + dep_t const d = dep_t::decode(dp); + if (d.is_null()) + continue; + if (d.is_lit()) { + sat::literal const lit = d.lit(); + if (!m_marked_lits.contains(lit)) { + on_lit(lit); + m_marked_lits.insert(lit); + } + } + else { + SASSERT(d.is_value()); + sat::literal const lit = get_dep_lit(d); + if (lit == sat::null_literal) + on_var(get_dep_var(d)); + else if (!m_marked_lits.contains(lit)) { + on_lit(lit); + m_marked_lits.insert(lit); + } + } + } + m_tmp_deps.reset(); + m_marked_lits.reset(); + } + + void slicing::explain_value(pvar v, std::function const& on_lit, std::function const& on_var) { + explain_value(var2slice(v), on_lit, on_var); + } + + bool slicing::find_range_in_ancestor(enode* s, enode* a, unsigned& out_hi, unsigned& out_lo) { + out_hi = width(s) - 1; + out_lo = 0; + while (true) { + if (s == a) + return true; + enode* p = parent(s); + if (!p) + return false; + if (s == sub_hi(p)) { + unsigned offset = 1 + info(p).cut; + out_hi += offset; + out_lo += offset; + } + else { + SASSERT_EQ(s, sub_lo(p)); + /* range stays unchanged */ + } + s = p; + } + } + + bool slicing::is_extract(pvar x, pvar src, unsigned& out_hi, unsigned& out_lo) { + return find_range_in_ancestor(var2slice(x), var2slice(src), out_hi, out_lo); + } + + void slicing::egraph_on_merge(enode* root, enode* other) { + LOG("on_merge: root " << slice_pp(*this, root) << " other " << slice_pp(*this, other)); + if (root->interpreted()) + return; + if (root->is_equality()) { + SASSERT(other->is_equality()); + return; + } + SASSERT(!other->interpreted()); // by convention, interpreted nodes are always chosen as root + SASSERT(root != other); + SASSERT_EQ(root, root->get_root()); + SASSERT_EQ(root, other->get_root()); + + enode* const v1 = info(root).value_node; // root is the root + enode* const v2 = info(other).value_node; // 'other' was its own root before the merge + if (v1 && v2 && get_value(v1) != get_value(v2)) { + // we have a conflict. add interpreted enodes to make the egraph realize it. + enode* const i1 = mk_interpreted_value_node(v1); + enode* const i2 = mk_interpreted_value_node(v2); + m_egraph.merge(i1, v1, dep_t().encode()); + m_egraph.merge(i2, v2, dep_t().encode()); + SASSERT(is_conflict()); + return; + } + + enode* const v = v1 ? v1 : v2; + if (v && !(v1 && v2)) { + // exactly one node has a value + rational const val = get_value(v); + for (enode* n : euf::enode_class(other)) { + enode* const vn = get_value_node(n); + if (!vn) + set_value_node(n, v); + + pvar const var = slice2var(n); + if (var == null_var) + continue; + if (m_solver.is_assigned(var)) + continue; + LOG("on_merge: v" << var << " := " << val); + m_solver.assign_propagate_by_slicing(var, val); + } + } + } + + void slicing::set_value_node(enode* s, enode* value_node) { + SASSERT(!info(s).value_node); + SASSERT(is_value(value_node)); + info(s).value_node = value_node; + if (s != value_node) { + m_trail.push_back(trail_item::set_value_node); + m_enode_trail.push_back(s); + } + } + + void slicing::undo_set_value_node() { + enode* s = m_enode_trail.back(); + m_enode_trail.pop_back(); + info(s).value_node = nullptr; + } + + void slicing::egraph_on_propagate(enode* lit, enode* ante) { + // ante may be set when symmetric equality is added by congruence + if (ante) + return; + // on_propagate may be called before set_value + if (lit->value() == l_undef) + return; + SASSERT(lit->is_equality()); + SASSERT_EQ(lit->value(), l_false); + SASSERT(lit->get_lit_justification().is_external()); + m_disequality_conflict = lit; + } + + bool slicing::can_propagate() const { + if (use_var_congruences() && !m_needs_congruence.empty()) + return true; + return m_egraph.can_propagate(); + } + + void slicing::propagate() { + // m_egraph.propagate(); + if (is_conflict()) + return; + update_var_congruences(); + m_egraph.propagate(); + } + + bool slicing::egraph_merge(enode* s1, enode* s2, dep_t dep) { + LOG("egraph_merge: " << slice_pp(*this, s1) << " and " << slice_pp(*this, s2) << " by " << dep_pp(*this, dep)); + SASSERT_EQ(width(s1), width(s2)); + if (dep.is_value()) { + if (is_value(s1)) + std::swap(s1, s2); + SASSERT(is_value(s2)); + SASSERT(!is_value(s1)); // we never merge two value slices directly + if (get_dep_slice(dep) != s1) + dep = mk_var_dep(get_dep_var(dep), s1, get_dep_lit(dep)); + } + m_egraph.merge(s1, s2, dep.encode()); + return !is_conflict(); + } + + bool slicing::merge_base(enode* s1, enode* s2, dep_t dep) { + SASSERT(!has_sub(s1)); + SASSERT(!has_sub(s2)); + return egraph_merge(s1, s2, dep); + } + + bool slicing::merge(enode_vector& xs, enode_vector& ys, dep_t dep) { + while (!xs.empty()) { + SASSERT(!ys.empty()); + enode* const x = xs.back(); + enode* const y = ys.back(); + xs.pop_back(); + ys.pop_back(); + if (x == y) + continue; + if (x->get_root() == y->get_root()) { + DEBUG_CODE({ + // invariant: parents merged => base slices merged + enode_vector const x_base = get_base(x); + enode_vector const y_base = get_base(y); + SASSERT_EQ(x_base.size(), y_base.size()); + for (unsigned i = x_base.size(); i-- > 0; ) { + SASSERT_EQ(x_base[i]->get_root(), y_base[i]->get_root()); + } + }); + continue; + } +#if 0 + if (has_sub(x)) { + get_base(x, xs); + x = xs.back(); + xs.pop_back(); + } + if (has_sub(y)) { + get_base(y, ys); + y = ys.back(); + ys.pop_back(); + } + SASSERT(!has_sub(x)); + SASSERT(!has_sub(y)); + if (width(x) == width(y)) { + if (!merge_base(x, y, dep)) { + xs.clear(); + ys.clear(); + return false; + } + } + else if (width(x) > width(y)) { + // need to split x according to y + mk_slice(x, width(y) - 1, 0, xs, true); + ys.push_back(y); + } + else { + SASSERT(width(y) > width(x)); + // need to split y according to x + mk_slice(y, width(x) - 1, 0, ys, true); + xs.push_back(x); + } +#else + if (width(x) == width(y)) { + // We may merge slices if at least one of them doesn't have a subslice yet, + // because in that case all intermediate cut points will be aligned. + // NOTE: it is necessary to merge intermediate slices for value nodes, to ensure downward propagation of assignments. + bool const should_merge = (!has_sub(x) || !has_sub(y)); + // If either slice has a subdivision, we have to cut the other and advance to subslices + if (has_sub(x) || has_sub(y)) { + if (!has_sub(x)) + split(x, get_cut(y)); + if (!has_sub(y)) + split(y, get_cut(x)); + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + ys.push_back(sub_hi(y)); + ys.push_back(sub_lo(y)); + } + // We may only merge intermediate nodes after we're done with splitting (since we currently split the whole equivalence class at once) + if (should_merge) { + if (!egraph_merge(x, y, dep)) { + xs.clear(); + ys.clear(); + return false; + } + } + } + else if (width(x) > width(y)) { + if (!has_sub(x)) + split(x, width(y) - 1); + // split(x, has_sub(y) ? get_cut(y) : (width(y) - 1)); + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + ys.push_back(y); + } + else { + SASSERT(width(y) > width(x)); + if (!has_sub(y)) + split(y, width(x) - 1); + // split(y, has_sub(x) ? get_cut(x) : (width(x) - 1)); + ys.push_back(sub_hi(y)); + ys.push_back(sub_lo(y)); + xs.push_back(x); + } +#endif + } + SASSERT(ys.empty()); + return true; + } + + bool slicing::merge(enode_vector& xs, enode* y, dep_t dep) { + enode_vector& ys = m_tmp2; + SASSERT(ys.empty()); + ys.push_back(y); + return merge(xs, ys, dep); // will clear xs and ys + } + + bool slicing::merge(enode* x, enode* y, dep_t dep) { + LOG("merge: " << slice_pp(*this, x) << " and " << slice_pp(*this, y)); + SASSERT_EQ(width(x), width(y)); + if (!has_sub(x) && !has_sub(y)) + return merge_base(x, y, dep); + enode_vector& xs = m_tmp2; + enode_vector& ys = m_tmp3; + SASSERT(xs.empty()); + SASSERT(ys.empty()); + xs.push_back(x); + ys.push_back(y); + return merge(xs, ys, dep); // will clear xs and ys + } + + bool slicing::is_equal(enode* x, enode* y) { + SASSERT_EQ(width(x), width(y)); + x = x->get_root(); + y = y->get_root(); + if (x == y) + return true; + enode_vector& xs = m_tmp2; + enode_vector& ys = m_tmp3; + SASSERT(xs.empty()); + SASSERT(ys.empty()); + on_scope_exit clear_vectors([&xs, &ys](){ + xs.clear(); + ys.clear(); + }); + // TODO: we don't always have to collect the full base if intermediate nodes are already equal + get_base(x, xs); + get_base(y, ys); + if (xs.size() != ys.size()) + return false; + for (unsigned i = xs.size(); i-- > 0; ) + if (xs[i]->get_root() != ys[i]->get_root()) + return false; + return true; + } + + void slicing::get_base(enode* src, enode_vector& out_base) const { + enode_vector& todo = m_tmp1; + SASSERT(todo.empty()); + todo.push_back(src); + while (!todo.empty()) { + enode* s = todo.back(); + todo.pop_back(); + if (!has_sub(s)) + out_base.push_back(s); + else { + todo.push_back(sub_lo(s)); + todo.push_back(sub_hi(s)); + } + } + SASSERT(todo.empty()); + } + + slicing::enode_vector slicing::get_base(enode* src) const { + enode_vector out; + get_base(src, out); + return out; + } + + pvar slicing::mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var) { + LOG("mk_extract: src=" << slice_pp(*this, src) << " hi=" << hi << " lo=" << lo); + enode_vector& slices = m_tmp3; + SASSERT(slices.empty()); + mk_slice(src, hi, lo, slices, false, false); + pvar v = null_var; + // try to re-use variable of an existing slice + if (slices.size() == 1) + v = slice2var(slices[0]); + if (replay_var != null_var && v != replay_var) { + // replayed variable should be 'fresh', unless it was a re-used variable + enode* s = var2slice(replay_var); + SASSERT(s->is_root()); + SASSERT_EQ(s->class_size(), 1); + SASSERT(!has_sub(s)); + SASSERT_EQ(width(s), hi - lo + 1); + v = replay_var; + } + // allocate new variable if we cannot reuse it + if (v == null_var) { + v = m_solver.add_var(hi - lo + 1, pvar_kind::internal); +#if 1 + // slice didn't have a variable yet; so we can re-use it for the new variable + // (we end up with a "phantom" enode that was first created for the variable) + if (slices.size() == 1) { + enode* s = slices[0]; + LOG("re-using slice " << slice_pp(*this, s) << " for new variable v" << v); + // display_tree(std::cerr, s, 0, hi, lo); + SASSERT_EQ(info(s).var, null_var); + info(m_var2slice[v]).var = null_var; // disconnect the "phantom" enode + info(s).var = v; + m_var2slice[v] = s; + } +#endif + } + // connect new variable + VERIFY(merge(slices, var2slice(v), dep_t())); + slices.reset(); + return v; + } + + void slicing::replay_extract(extract_args const& args, pvar r) { + LOG("replay_extract"); + SASSERT(r != null_var); + SASSERT(!m_extract_dedup.contains(args)); + VERIFY_EQ(mk_extract(var2slice(args.src), args.hi, args.lo, r), r); + m_extract_dedup.insert(args, r); + m_extract_trail.push_back(args); + m_trail.push_back(trail_item::mk_extract); + } + + pvar slicing::mk_extract(pvar src, unsigned hi, unsigned lo) { + LOG_H3("mk_extract: v" << src << "[" << hi << ":" << lo << "] size(v" << src << ") = " << m_solver.size(src)); + if (m_solver.size(src) == hi - lo + 1) + return src; + extract_args args{src, hi, lo}; + auto it = m_extract_dedup.find_iterator(args); + if (it != m_extract_dedup.end()) + return it->m_value; + pvar const v = mk_extract(var2slice(src), hi, lo); + m_extract_dedup.insert(args, v); + m_extract_trail.push_back(args); + m_trail.push_back(trail_item::mk_extract); + LOG("mk_extract: v" << src << "[" << hi << ":" << lo << "] = v" << v); + return v; + } + + void slicing::undo_mk_extract() { + extract_args args = m_extract_trail.back(); + m_extract_trail.pop_back(); + m_extract_dedup.remove(args); + } + + pvar slicing::mk_concat(unsigned num_args, pvar const* args, pvar replay_var) { + enode_vector& slices = m_tmp3; + SASSERT(slices.empty()); + unsigned total_width = 0; + for (unsigned i = 0; i < num_args; ++i) { + enode* s = var2slice(args[i]); + slices.push_back(s); + total_width += width(s); + } + // NOTE: we use concat nodes to deduplicate (syntactically equal) concat expressions. + // we might end up reusing variables that are not introduced by mk_concat (if we enable the variable re-use optimization in mk_extract), + // but because such congruence nodes are only added over direct descendants, we do not get unwanted dependencies from this re-use. + // (but note that the nodes from mk_concat are not only over direct descendants) + enode* concat = mk_concat_node(slices); + pvar v = slice2var(concat); + if (v != null_var) + return v; + if (replay_var != null_var) { + // replayed variable should be 'fresh' + enode* s = var2slice(replay_var); + SASSERT(s->is_root()); + SASSERT_EQ(s->class_size(), 1); + SASSERT(!has_sub(s)); + SASSERT_EQ(width(s), total_width); + v = replay_var; + } + else + v = m_solver.add_var(total_width, pvar_kind::internal); + enode* sv = var2slice(v); + VERIFY(merge(slices, sv, dep_t())); + // NOTE: add_concat_node must be done after merge to preserve the invariant: "a base slice is never equivalent to a congruence node". + add_concat_node(sv, concat); + slices.reset(); + + // don't mess with the concat_trail during replay + if (replay_var == null_var) { + concat_info ci; + ci.v = v; + ci.num_args = num_args; + ci.args_idx = m_concat_args.size(); + m_concat_trail.push_back(ci); + for (unsigned i = 0; i < num_args; ++i) + m_concat_args.push_back(args[i]); + } + m_trail.push_back(trail_item::mk_concat); + + return v; + } + + void slicing::replay_concat(unsigned num_args, pvar const* args, pvar r) { + SASSERT(r != null_var); + VERIFY_EQ(mk_concat(num_args, args, r), r); + } + + pvar slicing::mk_concat(std::initializer_list args) { + return mk_concat(args.size(), args.begin()); + } + + void slicing::add_constraint(signed_constraint c) { + LOG(c); + SASSERT(!is_conflict()); + if (!add_fixed_bits(c)) + return; + if (c->is_eq()) + add_constraint_eq(c->to_eq(), c.blit()); + } + + bool slicing::add_fixed_bits(signed_constraint c) { + // TODO: what is missing here: + // - we don't prioritize constraints that set larger bit ranges + // e.g., c1 sets 3 lower bits, and c2 sets 5 lower bits. + // slicing may have both {c1,c2} in justifications while previously we always prefer c2. + // - instead of prioritizing constraints (which is annoying to do incrementally), let subsumption take care of this issue. + // if constraint C subsumes constraint D, then we might even want to completely deactivate D in the solver? (not easy if D is on higher level than C). + // - (we could wait until propagate() to add fixed bits to the egraph. but that would only work on a single decision level.) + if (c->vars().size() != 1) + return true; + fixed_bits fb; + if (!get_fixed_bits(c, fb)) + return true; + pvar const x = c->vars()[0]; + return add_fixed_bits(x, fb.hi, fb.lo, fb.value, c.blit()); + } + + bool slicing::add_fixed_bits(pvar x, unsigned hi, unsigned lo, rational const& value, sat::literal lit) { + LOG("add_fixed_bits: v" << x << "[" << hi << ":" << lo << "] = " << value << " by " << lit_pp(m_solver, lit)); + enode_vector& xs = m_tmp3; + SASSERT(xs.empty()); + mk_slice(var2slice(x), hi, lo, xs, false, false); + enode* const sval = mk_value_slice(value, hi - lo + 1); + // 'xs' will be cleared by 'merge'. + // NOTE: the 'nullptr' argument will be fixed by 'egraph_merge' + return merge(xs, sval, mk_var_dep(x, nullptr, lit)); + } + + bool slicing::add_constraint_eq(pdd const& p, sat::literal lit) { + auto& m = p.manager(); + for (auto& [a, x] : p.linear_monomials()) { + if (a != 1 && a != m.max_value()) + continue; + pdd const body = a.is_one() ? (m.mk_var(x) - p) : (m.mk_var(x) + p); + // c is either x = body or x != body, depending on polarity + if (!add_equation(x, body, lit)) { + SASSERT(is_conflict()); + return false; + } + // without this check, when p = x - y we would handle both x = y and y = x separately + if (body.is_unary()) + break; + } + return true; + } + + // TODO: handle equations 2^k x = 2^k y? (lower n-k bits of x and y are equal) + bool slicing::add_equation(pvar x, pdd const& body, sat::literal lit) { + LOG("Equation from lit(" << lit << "): v" << x << (lit.sign() ? " != " : " = ") << body); + if (!lit.sign() && body.is_val()) { + LOG(" simple assignment"); + // Simple assignment x = value + return add_value(x, body.val(), lit); + } + enode* const sx = var2slice(x); + pvar const y = m_solver.m_names.get_name(body); + if (y == null_var) { + if (!body.is_val()) { + // TODO: register name trigger (if a name for value 'body' is created later, then merge x=y at that time) + // could also count how often 'body' was registered and introduce name when more than once. + // maybe better: register x as an existing name for 'body'? question is how to track the dependency on c. + LOG(" skip for now (unnamed body)"); + } else + LOG(" skip for now (disequality with constant)"); + return true; + } + enode* const sy = var2slice(y); + if (!lit.sign()) { + LOG(" merge v" << x << " and v" << y); + return merge(sx, sy, lit); + } + else { + LOG(" store disequality v" << x << " != v" << y); + enode* n = find_or_alloc_disequality(sx, sy, lit); + if (!m_disequality_conflict && is_equal(sx, sy)) { + add_var_congruence_if_needed(x); + add_var_congruence_if_needed(y); + m_disequality_conflict = n; + return false; + } + } + return true; + } + + bool slicing::add_value(pvar v, rational const& value, sat::literal lit) { + enode* const sv = var2slice(v); + if (get_value_node(sv) && get_value(get_value_node(sv)) == value) + return true; + enode* const sval = mk_value_slice(value, width(sv)); + return merge(sv, sval, mk_var_dep(v, sv, lit)); + } + + void slicing::add_value(pvar v, rational const& value) { + LOG("v" << v << " := " << value); + SASSERT(!is_conflict()); + (void)add_value(v, value, sat::null_literal); + } + + void slicing::collect_simple_overlaps(pvar v, pvar_vector& out) { + unsigned const first_out = out.size(); + enode* const sv = var2slice(v); + unsigned const v_width = width(sv); + enode_vector& v_base = m_tmp2; + SASSERT(v_base.empty()); + get_base(var2slice(v), v_base); + + SASSERT(all_of(m_egraph.nodes(), [](enode* n) { return !n->is_marked1(); })); + + // Collect direct sub-slices of v and their equivalences + // (these don't need any extra checks) + for (enode* s = sv; s != nullptr; s = has_sub(s) ? sub_lo(s) : nullptr) { + for (enode* n : euf::enode_class(s)) { + if (!is_proper_slice(n)) + continue; + pvar const w = slice2var(n); + if (w == null_var) + continue; + SASSERT(!n->is_marked1()); + n->mark1(); + out.push_back(w); + } + } + + // lowermost base slice of v + enode* const v_base_lo = v_base.back(); + + svector> candidates; + // Collect all other candidate variables, + // i.e., those who share the lowermost base slice with v. + for (enode* n : euf::enode_class(v_base_lo)) { + if (!is_proper_slice(n)) + continue; + if (n == v_base_lo) + continue; + enode* const n0 = n; + pvar w2 = null_var; // the highest variable we care about from this equivalence class + // examine parents to find variables + SASSERT(!has_sub(n)); + while (true) { + pvar const w = slice2var(n); + if (w != null_var && !n->is_marked1()) + w2 = w; + enode* p = parent(n); + if (!p) + break; + if (sub_lo(p) != n) // we only care about lowermost slices of variables + break; + if (width(p) > v_width) + break; + n = p; + } + if (w2 != null_var) + candidates.push_back({n0, w2}); + } + + // Check candidates + for (auto const& [n0, w2] : candidates) { + // unsigned v_next = v_base.size(); + auto v_it = v_base.rbegin(); + enode* n = n0; + SASSERT_EQ(n->get_root(), (*v_it)->get_root()); + ++v_it; + while (true) { + // here: base of n is equivalent to lower portion of base of v + pvar const w = slice2var(n); + if (w != null_var && !n->is_marked1()) { + n->mark1(); + out.push_back(w); + } + if (w == w2) + break; + // + enode* const p = parent(n); + SASSERT(p); + SASSERT_EQ(sub_lo(p), n); // otherwise not a candidate + // check if base of sub_hi(p) matches the base of v + enode_vector& p_hi_base = m_tmp3; + SASSERT(p_hi_base.empty()); + get_base(sub_hi(p), p_hi_base); + auto p_it = p_hi_base.rbegin(); + bool ok = true; + while (ok && p_it != p_hi_base.rend()) { + if (v_it == v_base.rend()) + ok = false; + else if ((*p_it)->get_root() != (*v_it)->get_root()) + ok = false; + else { + ++p_it; + ++v_it; + } + } + p_hi_base.reset(); + if (!ok) + break; + n = p; + } + } + + v_base.reset(); + for (unsigned i = out.size(); i-- > first_out; ) { + enode* n = var2slice(out[i]); + SASSERT(n->is_marked1()); + n->unmark1(); + } + SASSERT(all_of(m_egraph.nodes(), [](enode* n) { return !n->is_marked1(); })); + } + + void slicing::explain_simple_overlap(pvar v, pvar x, std::function const& on_lit) { + SASSERT(width(var2slice(x)) <= width(var2slice(v))); + SASSERT(m_marked_lits.empty()); + SASSERT(m_tmp_deps.empty()); + + if (v == x) + return; + + enode_vector& v_base = m_tmp4; + SASSERT(v_base.empty()); + get_base(var2slice(v), v_base); + enode_vector& x_base = m_tmp5; + SASSERT(x_base.empty()); + get_base(var2slice(x), x_base); + + auto v_it = v_base.rbegin(); + auto x_it = x_base.rbegin(); + while (x_it != x_base.rend()) { + SASSERT(v_it != v_base.rend()); + enode* nv = *v_it; ++v_it; + enode* nx = *x_it; ++x_it; + SASSERT_EQ(nv->get_root(), nx->get_root()); + explain_equal(nv, nx, m_tmp_deps); + } + + for (void* dp : m_tmp_deps) { + dep_t const d = dep_t::decode(dp); + if (d.is_null()) + continue; + if (d.is_lit()) { + sat::literal lit = d.lit(); + if (m_marked_lits.contains(lit)) + continue; + m_marked_lits.insert(lit); + on_lit(d.lit()); + } + else { + // equivalence between to variables cannot be due to value assignment + UNREACHABLE(); + } + } + m_marked_lits.reset(); + m_tmp_deps.reset(); + } + + void slicing::collect_fixed(pvar v, justified_fixed_bits_vector& out) { + enode_vector& base = m_tmp2; + SASSERT(base.empty()); + get_base(var2slice(v), base); + rational a; + unsigned lo = 0; + for (auto it = base.rbegin(); it != base.rend(); ++it) { + enode* const n = *it; + enode* const nv = get_value_node(n); + unsigned const w = width(n); + unsigned const hi = lo + w - 1; + if (try_get_value(nv, a)) + out.push_back({hi, lo, a, n}); + lo += w; + } + base.reset(); + } + + void slicing::explain_fixed(euf::enode* const n, std::function const& on_lit, std::function const& on_var) { + explain_value(n, on_lit, on_var); + } + + pvar_vector slicing::equivalent_vars(pvar v) const { + pvar_vector xs; + for (enode* n : euf::enode_class(var2slice(v))) { + pvar const x = slice2var(n); + if (x != null_var) + xs.push_back(x); + } + return xs; + } + + std::ostream& slicing::display(std::ostream& out) const { + enode_vector base; + for (pvar v = 0; v < m_var2slice.size(); ++v) { + out << "v" << v << ":"; + base.reset(); + enode* const vs = var2slice(v); + get_base(vs, base); + for (enode* s : base) + display(out << " ", s); + if (enode* vnode = get_value_node(vs)) + out << " [root_value: " << get_value(vnode) << "]"; + out << "\n"; + } + return out; + } + + std::ostream& slicing::display_tree(std::ostream& out) const { + for (pvar v = 0; v < m_var2slice.size(); ++v) { + out << "v" << v << ":\n"; + enode* const s = var2slice(v); + display_tree(out, s, 4, width(s) - 1, 0); + } + out << m_egraph << "\n"; + return out; + } + + std::ostream& slicing::display_tree(std::ostream& out, enode* s, unsigned indent, unsigned hi, unsigned lo) const { + out << std::string(indent, ' ') << "[" << hi << ":" << lo << "]"; + out << " id=" << s->get_id(); + out << " w=" << width(s); + if (slice2var(s) != null_var) + out << " var=v" << slice2var(s); + if (parent(s)) + out << " parent=" << parent(s)->get_id(); + if (!s->is_root()) + out << " root=" << s->get_root_id(); + if (enode* n = get_value_node(s)) + out << " value=" << get_value(n); + // if (is_proper_slice(s)) + // out << " "; + if (is_value(s)) + out << " "; + if (is_concat(s)) + out << " "; + if (is_equality(s)) + out << " "; + out << "\n"; + if (has_sub(s)) { + unsigned cut = info(s).cut; + display_tree(out, sub_hi(s), indent + 4, hi, cut + 1 + lo); + display_tree(out, sub_lo(s), indent + 4, cut + lo, lo); + } + return out; + } + + std::ostream& slicing::display(std::ostream& out, enode* s) const { + out << "{id:" << s->get_id(); + if (is_equality(s)) + return out << ",}"; + out << ",w:" << width(s); + out << ",root:" << s->get_root_id(); + if (slice2var(s) != null_var) + out << ",var:v" << slice2var(s); + if (enode* n = get_value_node(s)) + out << ",value:" << get_value(n); + if (s->interpreted()) + out << ","; + if (is_concat(s)) + out << ","; + out << "}"; + return out; + } + + bool slicing::invariant() const { + VERIFY(m_tmp1.empty()); + VERIFY(m_tmp2.empty()); + VERIFY(m_tmp3.empty()); + if (is_conflict()) // if we break a loop early on conflict, we can't guarantee that all properties are satisfied + return true; + for (enode* s : m_egraph.nodes()) { + // we use equality enodes only to track disequalities + if (s->is_equality()) + continue; + // if the slice is equivalent to a variable, then the variable's slice is in the equivalence class + pvar const v = slice2var(s); + if (v != null_var) { + VERIFY_EQ(var2slice(v)->get_root(), s->get_root()); + } + // if slice has a value, it should be propagated to its sub-slices + if (get_value_node(s) && has_sub(s)) { + VERIFY(get_value_node(sub_hi(s))); + VERIFY(get_value_node(sub_lo(s))); + } + // a base slice is never equivalent to a congruence node + if (is_slice(s) && !has_sub(s)) { + VERIFY(all_of(euf::enode_class(s), [&](enode* n) { return is_slice(n); })); + } + if (is_concat(s)) { + // all concat nodes point to a variable slice + VERIFY(slice2var(s) != null_var); + enode* sv = var2slice(slice2var(s)); // the corresponding variable slice + VERIFY(s != sv); + VERIFY(is_slice(sv)); + VERIFY(s->num_args() >= 2); + } + ///////////////////////////////////////////////////////////////// + // class properties (i.e., skip non-representatives) + if (!s->is_root()) + continue; + bool const sub = has_sub(s); + enode_vector const s_base = get_base(s); + for (enode* n : euf::enode_class(s)) { + // equivalence class only contains slices of equal length + VERIFY_EQ(width(s), width(n)); + // either all nodes in the class have subslices or none do + SASSERT_EQ(sub, has_sub(n)); + // bases of equivalent nodes are equivalent + enode_vector const n_base = get_base(n); + VERIFY_EQ(s_base.size(), n_base.size()); + for (unsigned i = s_base.size(); i-- > 0; ) { + VERIFY_EQ(s_base[i]->get_root(), n_base[i]->get_root()); + } + } + } + return true; + } + +} diff --git a/src/sat/smt/polysat/slicing.h b/src/sat/smt/polysat/slicing.h new file mode 100644 index 000000000..f9f90610b --- /dev/null +++ b/src/sat/smt/polysat/slicing.h @@ -0,0 +1,397 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polysat slicing (relating variables of different bit-widths by extraction) + +Author: + + Jakob Rath 2023-06-01 + +Notation: + + Let x be a bit-vector of width w. + Let l, h indices such that 0 <= l <= h < w. + Then x[h:l] extracts h - l + 1 bits of x. + Shorthands: + - x[h:] stands for x[h:0], and + - x[:l] stands for x[w-1:l]. + + Example: + 0001[0:] = 1 + 0001[2:0] = 001 + +--*/ +#pragma once +#include "ast/euf/euf_egraph.h" +#include "ast/bv_decl_plugin.h" +#include "math/polysat/types.h" +#include "math/polysat/constraint.h" +#include "math/polysat/fixed_bits.h" +#include + +namespace polysat { + + class solver; + + class slicing final { + + friend class test_slicing; + + public: + using enode = euf::enode; + using enode_vector = euf::enode_vector; + using enode_pair = euf::enode_pair; + using enode_pair_vector = euf::enode_pair_vector; + + private: + class dep_t { + std::variant m_data; + public: + dep_t() { SASSERT(is_null()); } + dep_t(sat::literal l): m_data(l) { SASSERT(l != sat::null_literal); SASSERT_EQ(l, lit()); } + explicit dep_t(unsigned idx): m_data(idx) { SASSERT_EQ(idx, value_idx()); } + bool is_null() const { return std::holds_alternative(m_data); } + bool is_lit() const { return std::holds_alternative(m_data); } + bool is_value() const { return std::holds_alternative(m_data); } + sat::literal lit() const { SASSERT(is_lit()); return *std::get_if(&m_data); } + unsigned value_idx() const { SASSERT(is_value()); return *std::get_if(&m_data); } + bool operator==(dep_t other) const { return m_data == other.m_data; } + bool operator!=(dep_t other) const { return !operator==(other); } + void* encode() const; + static dep_t decode(void* p); + }; + + using dep_vector = svector; + + std::ostream& display(std::ostream& out, dep_t d) const; + + dep_t mk_var_dep(pvar v, enode* s, sat::literal lit); + + pvar_vector m_dep_var; + ptr_vector m_dep_slice; + sat::literal_vector m_dep_lit; // optional, value assignment comes from a literal "x == val" rather than a solver assignment + unsigned_vector m_dep_size_trail; + + pvar get_dep_var(dep_t d) const { return m_dep_var[d.value_idx()]; } + sat::literal get_dep_lit(dep_t d) const { return m_dep_lit[d.value_idx()]; } + enode* get_dep_slice(dep_t d) const { return m_dep_slice[d.value_idx()]; } + + static constexpr unsigned null_cut = std::numeric_limits::max(); + + // We use the following kinds of enodes: + // - proper slices (of variables) + // - value slices + // - interpreted value nodes ... these are short-lived, and only created to immediately trigger a conflict inside the egraph + // - virtual concat(...) expressions + // - equalities between enodes (to track disequalities; currently not represented in slice_info) + struct slice_info { + // Cut point: if not null_cut, the slice s has been subdivided into s[|s|-1:cut+1] and s[cut:0]. + // The cut point is relative to the parent slice (rather than a root variable, which might not be unique) + unsigned cut = null_cut; // cut point, or null_cut if no subslices + pvar var = null_var; // slice is equivalent to this variable, if any (without dependencies) + enode* parent = nullptr; // parent slice, only for proper slices (if not null: s == sub_hi(parent(s)) || s == sub_lo(parent(s))) + enode* slice = nullptr; // if enode corresponds to a concat(...) expression, this field links to the represented slice. + enode* sub_hi = nullptr; // upper subslice s[|s|-1:cut+1] + enode* sub_lo = nullptr; // lower subslice s[cut:0] + enode* value_node = nullptr; // the root of an equivalence class stores the value slice here, if any + + void reset() { *this = slice_info(); } + bool has_sub() const { return !!sub_hi; } + void set_cut(unsigned cut, enode* sub_hi, enode* sub_lo) { this->cut = cut; this->sub_hi = sub_hi; this->sub_lo = sub_lo; } + }; + using slice_info_vector = svector; + + // Return true iff n is either a proper slice or a value slice + bool is_slice(enode* n) const; + + bool is_proper_slice(enode* n) const { return !is_value(n) && is_slice(n); } + bool is_value(enode* n) const; + bool is_concat(enode* n) const; + bool is_equality(enode* n) const { return n->is_equality(); } + + solver& m_solver; + + ast_manager m_ast; + scoped_ptr m_bv; + + euf::egraph m_egraph; + slice_info_vector m_info; // indexed by enode::get_id() + enode_vector m_var2slice; // pvar -> slice + tracked_uint_set m_needs_congruence; // set of pvars that need updated concat(...) expressions + enode* m_disequality_conflict = nullptr; + + // Add an equation v = concat(s1, ..., sn) + // for each variable v with base slices s1, ..., sn + void update_var_congruences(); + void add_var_congruence(pvar v); + void add_var_congruence_if_needed(pvar v); + bool use_var_congruences() const; + + func_decl* mk_concat_decl(ptr_vector const& args); + enode* mk_concat_node(enode_vector const& slices); + enode* mk_concat_node(std::initializer_list slices) { return mk_concat_node(slices.size(), slices.begin()); } + enode* mk_concat_node(unsigned num_slices, enode* const* slices); + // Add s = concat(s1, ..., sn) + void add_concat_node(enode* s, enode* concat); + + slice_info& info(euf::enode* n); + slice_info const& info(euf::enode* n) const; + + enode* alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var); + enode* find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var); + enode* alloc_slice(unsigned width, pvar var = null_var); + enode* find_or_alloc_disequality(enode* x, enode* y, sat::literal lit); + + // Find hi, lo such that s = a[hi:lo] + bool find_range_in_ancestor(enode* s, enode* a, unsigned& out_hi, unsigned& out_lo); + + enode* var2slice(pvar v) const { return m_var2slice[v]; } + pvar slice2var(enode* s) const { return info(s).var; } + + unsigned width(enode* s) const; + + enode* parent(enode* s) const { return info(s).parent; } + + enode* get_value_node(enode* s) const { return info(s).value_node; } + void set_value_node(enode* s, enode* value_node); + + unsigned get_cut(enode* s) const { return info(s).cut; } + + bool has_sub(enode* s) const { return info(s).has_sub(); } + + /// Upper subslice (direct child, not necessarily the representative) + enode* sub_hi(enode* s) const { return info(s).sub_hi; } + + /// Lower subslice (direct child, not necessarily the representative) + enode* sub_lo(enode* s) const { return info(s).sub_lo; } + + /// sub_lo(parent(s)) or sub_hi(parent(s)), whichever is different from s. + enode* sibling(enode* s) const; + + // Retrieve (or create) a slice representing the given value. + enode* mk_value_slice(rational const& val, unsigned bit_width); + + // Turn value node into unwrapped BV constant node + enode* mk_interpreted_value_node(enode* value_slice); + + rational get_value(enode* s) const; + bool try_get_value(enode* s, rational& val) const; + + /// Split slice s into s[|s|-1:cut+1] and s[cut:0] + void split(enode* s, unsigned cut); + void split_core(enode* s, unsigned cut); + + /// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... ++ s_n (actual descendant subslices) + void get_base(enode* src, enode_vector& out_base) const; + enode_vector get_base(enode* src) const; + + /// Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n. + /// If output_full_src is true, return the new base for src, i.e., src == s_1 ++ ... ++ s_n. + /// If output_base is false, return coarsest intermediate slices instead of only base slices. + void mk_slice(enode* src, unsigned hi, unsigned lo, enode_vector& out, bool output_full_src = false, bool output_base = true); + + // Extract reason why slices x and y are in the same equivalence class + void explain_class(enode* x, enode* y, ptr_vector& out_deps); + + // Extract reason why slices x and y are equal + // (i.e., x and y have the same base, but are not necessarily in the same equivalence class) + void explain_equal(enode* x, enode* y, ptr_vector& out_deps); + + /** Explain why slice is equivalent to a value */ + void explain_value(enode* s, std::function const& on_lit, std::function const& on_var); + + /** Extract reason for conflict */ + void explain(ptr_vector& out_deps); + + /** Extract reason for x == y */ + void explain_equal(pvar x, pvar y, ptr_vector& out_deps); + + void egraph_on_make(enode* n); + void egraph_on_merge(enode* root, enode* other); + void egraph_on_propagate(enode* lit, enode* ante); + + // Merge slices in the e-graph. + bool egraph_merge(enode* s1, enode* s2, dep_t dep); + + // Merge equivalence classes of two base slices. + // Returns true if merge succeeded without conflict. + [[nodiscard]] bool merge_base(enode* s1, enode* s2, dep_t dep); + + // Merge equality x_1 ++ ... ++ x_n == y_1 ++ ... ++ y_k + // + // Precondition: + // - sequence of slices with equal total width + // - ordered from msb to lsb + // + // The argument vectors will be cleared. + // + // Returns true if merge succeeded without conflict. + [[nodiscard]] bool merge(enode_vector& xs, enode_vector& ys, dep_t dep); + [[nodiscard]] bool merge(enode_vector& xs, enode* y, dep_t dep); + [[nodiscard]] bool merge(enode* x, enode* y, dep_t dep); + + // Check whether two slices are known to be equal + bool is_equal(enode* x, enode* y); + + // deduplication of extract terms + struct extract_args { + pvar src = null_var; + unsigned hi = 0; + unsigned lo = 0; + bool operator==(extract_args const& other) const { return src == other.src && hi == other.hi && lo == other.lo; } + unsigned hash() const { return mk_mix(src, hi, lo); } + }; + using extract_args_eq = default_eq; + using extract_args_hash = obj_hash; + using extract_map = map; + extract_map m_extract_dedup; + // svector m_extract_origin; // pvar -> extract_args + // TODO: add 'm_extract_origin' (pvar -> extract_args)? 1. for dependency tracking when sharing subslice trees; 2. for easily checking if a variable is an extraction of another; 3. also makes the replay easier + // bool is_extract(pvar v) const { return m_extract_origin[v].src != null_var; } + + enum class trail_item : std::uint8_t { + add_var, + split_core, + mk_extract, + mk_concat, + set_value_node, + }; + svector m_trail; + enode_vector m_enode_trail; + svector m_extract_trail; + unsigned_vector m_scopes; + + struct concat_info { + pvar v; + unsigned num_args; + unsigned args_idx; + unsigned next_args_idx() const { return args_idx + num_args; } + }; + svector m_concat_trail; + svector m_concat_args; + + void undo_add_var(); + void undo_split_core(); + void undo_mk_extract(); + void undo_set_value_node(); + + mutable enode_vector m_tmp1; + mutable enode_vector m_tmp2; + mutable enode_vector m_tmp3; + mutable enode_vector m_tmp4; + mutable enode_vector m_tmp5; + ptr_vector m_tmp_deps; + sat::literal_set m_marked_lits; + + /** Get variable representing src[hi:lo] */ + pvar mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var = null_var); + /** Restore r = src[hi:lo] */ + void replay_extract(extract_args const& args, pvar r); + + pvar mk_concat(unsigned num_args, pvar const* args, pvar replay_var); + void replay_concat(unsigned num_args, pvar const* args, pvar r); + + bool add_constraint_eq(pdd const& p, sat::literal lit); + bool add_equation(pvar x, pdd const& body, sat::literal lit); + bool add_value(pvar v, rational const& value, sat::literal lit); + bool add_fixed_bits(signed_constraint c); + bool add_fixed_bits(pvar x, unsigned hi, unsigned lo, rational const& value, sat::literal lit); + + bool invariant() const; + bool invariant_needs_congruence() const; + + std::ostream& display(std::ostream& out, enode* s) const; + std::ostream& display_tree(std::ostream& out, enode* s, unsigned indent, unsigned hi, unsigned lo) const; + + class slice_pp { + slicing const& s; + enode* n; + public: + slice_pp(slicing const& s, enode* n): s(s), n(n) {} + std::ostream& display(std::ostream& out) const { return s.display(out, n); } + }; + friend std::ostream& operator<<(std::ostream& out, slice_pp const& s) { return s.display(out); } + + class dep_pp { + slicing const& s; + dep_t d; + public: + dep_pp(slicing const& s, dep_t d): s(s), d(d) {} + std::ostream& display(std::ostream& out) const { return s.display(out, d); } + }; + friend std::ostream& operator<<(std::ostream& out, dep_pp const& d) { return d.display(out); } + + euf::egraph::e_pp e_pp(enode* n) const { return euf::egraph::e_pp(m_egraph, n); } + + public: + slicing(solver& s); + + void push_scope(); + void pop_scope(unsigned num_scopes = 1); + + void add_var(unsigned bit_width); + + /** Get or create variable representing x[hi:lo] */ + pvar mk_extract(pvar x, unsigned hi, unsigned lo); + + /** Get or create variable representing x1 ++ x2 ++ ... ++ xn */ + pvar mk_concat(unsigned num_args, pvar const* args) { return mk_concat(num_args, args, null_var); } + pvar mk_concat(std::initializer_list args); + + // Find hi, lo such that x = src[hi:lo]. + bool is_extract(pvar x, pvar src, unsigned& out_hi, unsigned& out_lo); + + // Track value assignments to variables (and propagate to subslices) + void add_value(pvar v, rational const& value); + void add_value(pvar v, unsigned value) { add_value(v, rational(value)); } + void add_value(pvar v, int value) { add_value(v, rational(value)); } + void add_constraint(signed_constraint c); + + bool can_propagate() const; + + // update congruences, egraph + void propagate(); + + bool is_conflict() const { return m_disequality_conflict || m_egraph.inconsistent(); } + + /** Extract conflict clause */ + clause_ref build_conflict_clause(); + + /** Explain why slicing has propagated the value assignment for v */ + void explain_value(pvar v, std::function const& on_lit, std::function const& on_var); + + /** For a given variable v, find the set of variables w such that w = v[|w|:0]. */ + void collect_simple_overlaps(pvar v, pvar_vector& out); + void explain_simple_overlap(pvar v, pvar x, std::function const& on_lit); + + struct justified_fixed_bits : public fixed_bits { + enode* just; + + justified_fixed_bits(unsigned hi, unsigned lo, rational value, enode* just): fixed_bits(hi, lo, value), just(just) {} + }; + + using justified_fixed_bits_vector = vector; + + /** Collect fixed portions of the variable v */ + void collect_fixed(pvar v, justified_fixed_bits_vector& out); + void explain_fixed(enode* just, std::function const& on_lit, std::function const& on_var); + + /** + * Collect variables that are equivalent to v (including v itself) + * + * NOTE: this might miss some variables that are equal due to equivalent base slices. With 'polysat.slicing.congruence=true' and after propagate(), it should return all equal variables. + */ + pvar_vector equivalent_vars(pvar v) const; + + /** Explain why variables x and y are equivalent */ + void explain_equal(pvar x, pvar y, std::function const& on_lit); + + std::ostream& display(std::ostream& out) const; + std::ostream& display_tree(std::ostream& out) const; + }; + + inline std::ostream& operator<<(std::ostream& out, slicing const& s) { return s.display(out); } + +} diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 9e57b8cae..999db42ea 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -254,4 +254,46 @@ namespace polysat { return expr_ref(bv.mk_bv_add(lo, hi), m); } + // walk the egraph starting with pvar for overlaps. + void solver::get_bitvector_prefixes(pvar pv, pvar_vector& out) { + theory_var v = m_pddvar2var[pv]; + euf::enode_vector todo; + uint_set seen; + unsigned lo, hi; + expr* e = nullptr; + todo.push_back(var2enode(v)); + for (unsigned i = 0; i < todo.size(); ++i) { + auto n = todo[i]->get_root(); + if (n->is_marked1()) + continue; + n->mark1(); + for (auto sib : euf::enode_class(n)) { + theory_var w = sib->get_th_var(get_id()); + + // identify prefixes + if (bv.is_concat(sib->get_expr())) + todo.push_back(sib->get_arg(0)); + if (w == euf::null_theory_var) + continue; + if (seen.contains(w)) + continue; + seen.insert(w); + auto const& p = m_var2pdd[w]; + if (p.is_var()) + out.push_back(p.var()); + } + for (auto p : euf::enode_parents(n)) { + if (p->is_marked1()) + continue; + // find prefixes: e[hi:0] a parent of n + if (bv.is_extract(p->get_expr(), lo, hi, e) && lo == 0) { + SASSERT(n == expr2enode(e)->get_root()); + todo.push_back(p); + } + } + } + for (auto n : todo) + n->get_root()->unmark1(); + } + } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 8489619da..56ab615dc 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -133,6 +133,7 @@ namespace polysat { void propagate(dependency const& d, bool sign, dependency_vector const& deps) override; trail_stack& trail() override; bool inconsistent() const override; + void get_bitvector_prefixes(pvar v, pvar_vector& out) override; void add_lemma(vector const& lemma); From 0c2ecf8b90be15bc2c72dd978b4f978e8e801045 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 13:10:47 -0800 Subject: [PATCH 18/89] working on viable --- src/sat/smt/polysat/CMakeLists.txt | 1 + src/sat/smt/polysat/fixed_bits.cpp | 180 ++++++++++++++++++++++ src/sat/smt/polysat/fixed_bits.h | 31 ++++ src/sat/smt/polysat/polysat_core.cpp | 4 + src/sat/smt/polysat/polysat_core.h | 2 + src/sat/smt/polysat/polysat_types.h | 34 +++++ src/sat/smt/polysat/polysat_viable.cpp | 203 +++++++++++++++++++++++++ src/sat/smt/polysat/polysat_viable.h | 4 +- src/sat/smt/polysat_solver.cpp | 20 +++ src/sat/smt/polysat_solver.h | 1 + src/util/mpq.cpp | 6 + src/util/mpq.h | 8 + src/util/mpz.cpp | 13 ++ src/util/mpz.h | 7 + src/util/rational.h | 19 ++- 15 files changed, 530 insertions(+), 3 deletions(-) create mode 100644 src/sat/smt/polysat/fixed_bits.cpp create mode 100644 src/sat/smt/polysat/fixed_bits.h diff --git a/src/sat/smt/polysat/CMakeLists.txt b/src/sat/smt/polysat/CMakeLists.txt index 0011e0ee5..6c8bed74d 100644 --- a/src/sat/smt/polysat/CMakeLists.txt +++ b/src/sat/smt/polysat/CMakeLists.txt @@ -1,5 +1,6 @@ z3_add_component(polysat SOURCES + fixed_bits.cpp polysat_assignment.cpp polysat_constraints.cpp polysat_core.cpp diff --git a/src/sat/smt/polysat/fixed_bits.cpp b/src/sat/smt/polysat/fixed_bits.cpp new file mode 100644 index 000000000..9b67c883d --- /dev/null +++ b/src/sat/smt/polysat/fixed_bits.cpp @@ -0,0 +1,180 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + Extract fixed bits from constraints + +Author: + + Jakob Rath, Nikolaj Bjorner (nbjorner), Clemens Eisenhofer 2022-08-22 + +--*/ + +#include "sat/smt/polysat/fixed_bits.h" +#include "sat/smt/polysat/polysat_ule.h" + +namespace polysat { + + /** + * 2^k * x = 2^k * b + * ==> x[N-k-1:0] = b[N-k-1:0] + */ + bool get_eq_fixed_lsb(pdd const& p, fixed_bits& out) { + SASSERT(!p.is_val()); + unsigned const N = p.power_of_2(); + // Recognize p = 2^k * a * x - 2^k * b + if (!p.hi().is_val()) + return false; + if (!p.lo().is_val()) + return false; + // p = c * x - d + rational const c = p.hi().val(); + rational const d = (-p.lo()).val(); + SASSERT(!c.is_zero()); +#if 1 + // NOTE: ule_constraint::simplify removes odd factors of the leading term + unsigned k; + VERIFY(c.is_power_of_two(k)); + if (d.parity(N) < k) + return false; + rational const b = machine_div2k(d, k); + out = fixed_bits(N - k - 1, 0, b); + SASSERT_EQ(d, b * rational::power_of_two(k)); + SASSERT_EQ(p, (p.manager().mk_var(p.var()) - out.value) * rational::power_of_two(k)); + return true; +#else + // branch if we want to support non-simplifed constraints (not recommended) + // + // 2^k * a * x = 2^k * b + // ==> x[N-k-1:0] = a^-1 * b[N-k-1:0] + // for odd a + unsigned k = c.parity(N); + if (d.parity(N) < k) + return false; + rational const a = machine_div2k(c, k); + SASSERT(a.is_odd()); + SASSERT(a.is_one()); // TODO: ule-simplify will multiply with a_inv already, so we can drop the check here. + rational a_inv; + VERIFY(a.mult_inverse(N, a_inv)); + rational const b = machine_div2k(d, k); + out.hi = N - k - 1; + out.lo = 0; + out.value = a_inv * b; + SASSERT_EQ(p, (p.manager().mk_var(p.var()) - out.value) * a * rational::power_of_two(k)); + return true; +#endif + } + + bool get_eq_fixed_bits(pdd const& p, fixed_bits& out) { + if (get_eq_fixed_lsb(p, out)) + return true; + return false; + } + + /** + * Constraint lhs <= rhs. + * + * -2^(k - 2) * x > 2^(k - 1) + * <=> 2 + x[1:0] > 2 (mod 4) + * ==> x[1:0] = 1 + * -- TODO: Generalize [the obvious solution does not work] + */ + bool get_ule_fixed_lsb(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out) { + return false; + } + + /** + * Constraint lhs <= rhs. + * + * x <= 2^k - 1 ==> x[N-1:k] = 0 + * x < 2^k ==> x[N-1:k] = 0 + */ + bool get_ule_fixed_msb(pdd const& p, pdd const& q, bool is_positive, fixed_bits& out) { + SASSERT(!q.is_zero()); // equalities are handled elsewhere + unsigned const N = p.power_of_2(); + pdd const& lhs = is_positive ? p : q; + pdd const& rhs = is_positive ? q : p; + bool const is_strict = !is_positive; + if (lhs.is_var() && rhs.is_val()) { + // x <= c + // find smallest k such that c <= 2^k - 1, i.e., c+1 <= 2^k + // ==> x <= 2^k - 1 ==> x[N-1:k] = 0 + // + // x < c + // find smallest k such that c <= 2^k + // ==> x < 2^k ==> x[N-1:k] = 0 + rational const c = is_strict ? rhs.val() : (rhs.val() + 1); + unsigned const k = c.next_power_of_two(); + if (k < N) { + out.hi = N - 1; + out.lo = k; + out.value = 0; + return true; + } + } + return false; + } + + // 2^(N-1) <= 2^(N-1-i) * x + bool get_ule_fixed_bit(pdd const& p, pdd const& q, bool is_positive, fixed_bits& out) { + return false; + } + + bool get_ule_fixed_bits(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out) { + SASSERT(ule_constraint::is_simplified(lhs, rhs)); + if (rhs.is_zero()) + return is_positive ? get_eq_fixed_bits(lhs, out) : false; + if (get_ule_fixed_msb(lhs, rhs, is_positive, out)) + return true; + if (get_ule_fixed_lsb(lhs, rhs, is_positive, out)) + return true; + if (get_ule_fixed_bit(lhs, rhs, is_positive, out)) + return true; + return false; + } + + bool get_fixed_bits(signed_constraint c, fixed_bits& out) { + SASSERT_EQ(c.vars().size(), 1); // this only makes sense for univariate constraints + if (c.is_ule()) + return get_ule_fixed_bits(c.to_ule().lhs(), c.to_ule().rhs(), c.is_positive(), out); + // if (c->is_op()) + // ; // TODO: x & constant = constant ==> bitmask ... but we have trouble recognizing that because we introduce a new variable for '&' before we see the equality. + return false; + } + + + + +/* + // 2^(k - d) * x = m * 2^(k - d) + // Special case [still seems to occur frequently]: -2^(k - 2) * x > 2^(k - 1) - TODO: Generalize [the obvious solution does not work] => lsb(x, 2) = 1 + bool get_lsb(pdd lhs, pdd rhs, pdd& p, trailing_bits& info, bool pos) { + SASSERT(lhs.is_univariate() && lhs.degree() <= 1); + SASSERT(rhs.is_univariate() && rhs.degree() <= 1); + + else { // inequality - check for special case + if (pos || lhs.power_of_2() < 3) + return false; + auto it = lhs.begin(); + if (it == lhs.end()) + return false; + if (it->vars.size() != 1) + return false; + rational coeff = it->coeff; + it++; + if (it != lhs.end()) + return false; + if ((mod2k(-coeff, lhs.power_of_2())) != rational::power_of_two(lhs.power_of_2() - 2)) + return false; + p = lhs.div(coeff); + SASSERT(p.is_var()); + info.bits = 1; + info.length = 2; + info.positive = true; // this is a conjunction + return true; + } + } +*/ + +} // namespace polysat diff --git a/src/sat/smt/polysat/fixed_bits.h b/src/sat/smt/polysat/fixed_bits.h new file mode 100644 index 000000000..78b4a643f --- /dev/null +++ b/src/sat/smt/polysat/fixed_bits.h @@ -0,0 +1,31 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + Extract fixed bits of variables from univariate constraints + +Author: + + Jakob Rath, Nikolaj Bjorner (nbjorner), Clemens Eisenhofer 2022-08-22 + +--*/ +#pragma once +#include "sat/smt/polysat/polysat_types.h" +#include "sat/smt/polysat/polysat_constraints.h" +#include "util/vector.h" + +namespace polysat { + + using fixed_bits_vector = vector; + + bool get_eq_fixed_lsb(pdd const& p, fixed_bits& out); + bool get_eq_fixed_bits(pdd const& p, fixed_bits& out); + + bool get_ule_fixed_lsb(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out); + bool get_ule_fixed_msb(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out); + bool get_ule_fixed_bit(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out); + bool get_ule_fixed_bits(pdd const& lhs, pdd const& rhs, bool is_positive, fixed_bits& out); + bool get_fixed_bits(signed_constraint c, fixed_bits& out); + +} diff --git a/src/sat/smt/polysat/polysat_core.cpp b/src/sat/smt/polysat/polysat_core.cpp index c41938a1a..09bfbd244 100644 --- a/src/sat/smt/polysat/polysat_core.cpp +++ b/src/sat/smt/polysat/polysat_core.cpp @@ -282,6 +282,10 @@ namespace polysat { s.get_bitvector_prefixes(v, out); } + void core::get_fixed_bits(pvar v, svector& fixed_bits) { + s.get_fixed_bits(v, fixed_bits); + } + bool core::inconsistent() const { return s.inconsistent(); } diff --git a/src/sat/smt/polysat/polysat_core.h b/src/sat/smt/polysat/polysat_core.h index 766c3a9bc..b5b7dd380 100644 --- a/src/sat/smt/polysat/polysat_core.h +++ b/src/sat/smt/polysat/polysat_core.h @@ -80,7 +80,9 @@ namespace polysat { void propagate_unsat_core(); void get_bitvector_prefixes(pvar v, pvar_vector& out); + void get_fixed_bits(pvar v, svector& fixed_bits); bool inconsistent() const; + void add_watch(unsigned idx, unsigned var); diff --git a/src/sat/smt/polysat/polysat_types.h b/src/sat/smt/polysat/polysat_types.h index 5b63b5ee5..b9b440481 100644 --- a/src/sat/smt/polysat/polysat_types.h +++ b/src/sat/smt/polysat/polysat_types.h @@ -52,6 +52,39 @@ namespace polysat { return out << "v" << d.eq().first << " == v" << d.eq().second << "@" << d.level(); } + struct trailing_bits { + unsigned length; + rational bits; + bool positive; + unsigned src_idx; + }; + + struct leading_bits { + unsigned length; + bool positive; // either all 0 or all 1 + unsigned src_idx; + }; + + struct single_bit { + bool positive; + unsigned position; + unsigned src_idx; + }; + + struct fixed_bits { + unsigned hi = 0; + unsigned lo = 0; + rational value; + + /// The constraint is equivalent to setting fixed bits on a variable. + // bool is_equivalent; + + fixed_bits() = default; + fixed_bits(unsigned hi, unsigned lo, rational value) : hi(hi), lo(lo), value(value) {} + }; + + struct justified_fixed_bits : public fixed_bits, public dependency {}; + using dependency_vector = vector; class signed_constraint; @@ -66,6 +99,7 @@ namespace polysat { virtual trail_stack& trail() = 0; virtual bool inconsistent() const = 0; virtual void get_bitvector_prefixes(pvar v, pvar_vector& out) = 0; + virtual void get_fixed_bits(pvar v, svector& fixed_bits) = 0; }; } diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp index d69f40180..a8a9fd6af 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -20,6 +20,7 @@ Notes: #include "util/log.h" #include "sat/smt/polysat/polysat_viable.h" #include "sat/smt/polysat/polysat_core.h" +#include "sat/smt/polysat/polysat_ule.h" namespace polysat { @@ -420,6 +421,206 @@ namespace polysat { return l_undef; } + // returns true iff no conflict was encountered + bool viable::collect_bit_information(pvar v, bool add_conflict, fixed_bits_info& out_fbi) { + + pdd p = c.var(v); + unsigned const v_sz = c.size(v); + out_fbi.reset(v_sz); + auto& [fixed, just_src, just_side_cond, just_slice] = out_fbi; + + svector fbs; + c.get_fixed_bits(v, fbs); + + for (auto const& fb : fbs) { + LOG("slicing fixed bits: v" << v << "[" << fb.hi << ":" << fb.lo << "] = " << fb.value); + for (unsigned i = fb.lo; i <= fb.hi; ++i) { + SASSERT(out_fbi.just_src[i].empty()); // since we don't get overlapping ranges from collect_fixed. + SASSERT(out_fbi.just_side_cond[i].empty()); + SASSERT(out_fbi.just_slicing[i].empty()); + out_fbi.fixed[i] = to_lbool(fb.value.get_bit(i - fb.lo)); + out_fbi.just_slicing[i].push_back(fb); + } + } + + entry* e1 = m_equal_lin[v]; + entry* e2 = m_units[v].get_entries(c.size(v)); // TODO: take other widths into account (will be done automatically by tracking fixed bits in the slicing egraph) + entry* first = e1; + if (!e1 && !e2) + return true; +#if 0 + + clause_builder builder(s, "bit check"); + sat::literal_set added; + vector> postponed; + + auto add_literal = [&builder, &added](sat::literal lit) { + if (added.contains(lit)) + return; + added.insert(lit); + builder.insert_eval(~lit); + }; + + auto add_literals = [&add_literal](sat::literal_vector const& lits) { + for (sat::literal lit : lits) + add_literal(lit); + }; + + auto add_entry = [&add_literal](entry* e) { + for (const auto& sc : e->side_cond) + add_literal(sc.blit()); + for (const auto& src : e->src) + add_literal(src.blit()); + }; + + auto add_slicing = [this, &add_literal](slicing::enode* n) { + s.m_slicing.explain_fixed(n, [&](sat::literal lit) { + add_literal(lit); + }, [&](pvar v) { + LOG("from slicing: v" << v); + add_literal(s.cs().eq(c.var(v), c.get_value(v)).blit()); + }); + }; + + auto add_bit_justification = [&add_literals, &add_slicing](fixed_bits_info const& fbi, unsigned i) { + add_literals(fbi.just_src[i]); + add_literals(fbi.just_side_cond[i]); + for (slicing::enode* n : fbi.just_slicing[i]) + add_slicing(n); + }; + + if (e1) { + unsigned largest_lsb = 0; + do { + if (e1->src.size() != 1) { + // We just consider the ordinary constraints and not already contracted ones + e1 = e1->next(); + continue; + } + signed_constraint& src = e1->src[0]; + single_bit bit; + trailing_bits lsb; + if (src.is_ule() && + simplify_clause::get_bit(s.subst(src.to_ule().lhs()), s.subst(src.to_ule().rhs()), p, bit, src.is_positive()) && p.is_var()) { + lbool prev = fixed[bit.position]; + fixed[bit.position] = to_lbool(bit.positive); + //verbose_stream() << "Setting bit " << bit.position << " to " << bit.positive << " because of " << e->src << "\n"; + if (prev != l_undef && fixed[bit.position] != prev) { + // LOG("Bit conflicting " << e1->src << " with " << just_src[bit.position][0]); // NOTE: just_src may be empty if the justification is by slicing + if (add_conflict) { + add_bit_justification(out_fbi, bit.position); + add_entry(e1); + s.set_conflict(*builder.build()); + } + return false; + } + // just override; we prefer bit constraints over parity as those are easier for subsumption to remove + // do we just introduce a new justification here that subsumption will remove anyway? + // the only way it will not is if all bits are overwritten like this. + // but in that case we basically replace one parity constraint by multiple bit constraints? + // verbose_stream() << "Adding bit constraint: " << e->src[0] << " (" << bit.position << ")\n"; + if (prev == l_undef) { + out_fbi.set_just(bit.position, e1); + } + } + else if (src.is_eq() && + simplify_clause::get_lsb(s.subst(src.to_ule().lhs()), s.subst(src.to_ule().rhs()), p, lsb, src.is_positive()) && p.is_var()) { + if (src.is_positive()) { + for (unsigned i = 0; i < lsb.length; i++) { + lbool prev = fixed[i]; + fixed[i] = to_lbool(lsb.bits.get_bit(i)); + if (prev == l_undef) { + SASSERT(just_src[i].empty()); + out_fbi.set_just(i, e1); + continue; + } + if (fixed[i] != prev) { + // LOG("Positive parity conflicting " << e1->src << " with " << just_src[i][0]); // NOTE: just_src may be empty if the justification is by slicing + if (add_conflict) { + add_bit_justification(out_fbi, i); + add_entry(e1); + s.set_conflict(*builder.build()); + } + return false; + } + // Prefer justifications from larger masks (less premises) + // TODO: Check that we don't override justifications coming from bit constraints + if (largest_lsb < lsb.length) + out_fbi.set_just(i, e1); + } + largest_lsb = std::max(largest_lsb, lsb.length); + } + else + postponed.push_back({ e1, lsb }); + } + e1 = e1->next(); + } while (e1 != first); + } + + // so far every bit is justified by a single constraint + SASSERT(all_of(just_src, [](auto const& vec) { return vec.size() <= 1; })); + + // TODO: Incomplete - e.g., if we know the trailing bits are not 00 not 10 not 01 and not 11 we could also detect a conflict + // This would require partially clause solving (worth the effort?) + bool_vector removed(postponed.size(), false); + bool changed; + do { // fixed-point required? + changed = false; + for (unsigned j = 0; j < postponed.size(); j++) { + if (removed[j]) + continue; + const auto& neg = postponed[j]; + unsigned indet = 0; + unsigned last_indet = 0; + unsigned i = 0; + for (; i < neg.second.length; i++) { + if (fixed[i] != l_undef) { + if (fixed[i] != to_lbool(neg.second.bits.get_bit(i))) { + removed[j] = true; + break; // this is already satisfied + } + } + else { + indet++; + last_indet = i; + } + } + if (i == neg.second.length) { + if (indet == 0) { + // Already false + LOG("Found conflict with constraint " << neg.first->src); + if (add_conflict) { + for (unsigned k = 0; k < neg.second.length; k++) + add_bit_justification(out_fbi, k); + add_entry(neg.first); + s.set_conflict(*builder.build()); + } + return false; + } + else if (indet == 1) { + // Simple BCP + SASSERT(just_src[last_indet].empty()); + SASSERT(just_side_cond[last_indet].empty()); + for (unsigned k = 0; k < neg.second.length; k++) { + if (k != last_indet) { + SASSERT(fixed[k] != l_undef); + out_fbi.push_from_bit(last_indet, k); + } + } + out_fbi.push_just(last_indet, neg.first); + fixed[last_indet] = neg.second.bits.get_bit(last_indet) ? l_false : l_true; + removed[j] = true; + LOG("Applying fast BCP on bit " << last_indet << " from constraint " << neg.first->src); + changed = true; + } + } + } + } while (changed); +#endif + + return true; + } + /* * Explain why the current variable is not viable or signleton. @@ -436,6 +637,8 @@ namespace polysat { if (c.is_assigned(v)) return; auto [sc, d] = c.m_constraint_trail[idx]; + // fixme: constraint must be assigned a value l_true or l_false at this point. + // adjust sc to the truth value of the constraint when passed to forbidden intervals. entry* ne = alloc_entry(v, idx); if (!m_forbidden_intervals.get_interval(sc, v, *ne)) { diff --git a/src/sat/smt/polysat/polysat_viable.h b/src/sat/smt/polysat/polysat_viable.h index 37b1d7b0c..1812ec40c 100644 --- a/src/sat/smt/polysat/polysat_viable.h +++ b/src/sat/smt/polysat/polysat_viable.h @@ -87,7 +87,7 @@ namespace polysat { svector fixed; vector> just_src; vector> just_side_cond; - vector> just_slicing; + vector> just_slicing; bool is_empty() const { SASSERT_EQ(fixed.empty(), just_src.empty()); @@ -219,6 +219,8 @@ namespace polysat { throw default_exception("nyi"); } + bool collect_bit_information(pvar v, bool add_conflict, fixed_bits_info& out_fbi); + public: viable(core& c); diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 999db42ea..7e867af2f 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -296,4 +296,24 @@ namespace polysat { n->get_root()->unmark1(); } + void solver::get_fixed_bits(pvar pv, svector& fixed_bits) { + theory_var v = m_pddvar2var[pv]; + auto n = var2enode(v); + auto r = n->get_root(); + unsigned lo, hi; + expr* e = nullptr; + for (auto p : euf::enode_parents(r)) { + if (!p->interpreted()) + continue; + for (auto sib : euf::enode_class(p)) { + if (bv.is_extract(sib->get_expr(), lo, hi, e) && r == expr2enode(e)->get_root()) { + throw default_exception("nyi"); + // TODO + // dependency d = dependency(p->get_th_var(get_id()), n->get_th_var(get_id()), s().scope_lvl()); + // fixed_bits.push_back({ hi, lo, rational::zero(), null_dependency()}); + } + } + } + } + } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 56ab615dc..6c7260f95 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -134,6 +134,7 @@ namespace polysat { trail_stack& trail() override; bool inconsistent() const override; void get_bitvector_prefixes(pvar v, pvar_vector& out) override; + void get_fixed_bits(pvar v, svector& fixed_bits) override; void add_lemma(vector const& lemma); diff --git a/src/util/mpq.cpp b/src/util/mpq.cpp index 324750cfa..f90843e36 100644 --- a/src/util/mpq.cpp +++ b/src/util/mpq.cpp @@ -316,6 +316,12 @@ unsigned mpq_manager::prev_power_of_two(mpq const & a) { return prev_power_of_two(_tmp); } +template +unsigned mpq_manager::next_power_of_two(mpq const & a) { + _scoped_numeral > _tmp(*this); + ceil(a, _tmp); + return next_power_of_two(_tmp); +} template template diff --git a/src/util/mpq.h b/src/util/mpq.h index e254ade69..1bdf8f31b 100644 --- a/src/util/mpq.h +++ b/src/util/mpq.h @@ -848,6 +848,14 @@ public: unsigned prev_power_of_two(mpz const & a) { return mpz_manager::prev_power_of_two(a); } unsigned prev_power_of_two(mpq const & a); + /** + \brief Return the smallest k s.t. a <= 2^k. + + \remark Return 0 if a is not positive. + */ + unsigned next_power_of_two(mpz const & a) { return mpz_manager::next_power_of_two(a); } + unsigned next_power_of_two(mpq const & a); + bool is_int_perfect_square(mpq const & a, mpq & r) { SASSERT(is_int(a)); reset_denominator(r); diff --git a/src/util/mpz.cpp b/src/util/mpz.cpp index c3ba30161..296b4426e 100644 --- a/src/util/mpz.cpp +++ b/src/util/mpz.cpp @@ -2288,6 +2288,19 @@ unsigned mpz_manager::bitsize(mpz const & a) { return mlog2(a) + 1; } +template +unsigned mpz_manager::next_power_of_two(mpz const & a) { + if (is_nonpos(a)) + return 0; + if (is_one(a)) + return 0; + unsigned shift; + if (is_power_of_two(a, shift)) + return shift; + else + return log2(a) + 1; +} + template bool mpz_manager::is_perfect_square(mpz const & a, mpz & root) { if (is_neg(a)) diff --git a/src/util/mpz.h b/src/util/mpz.h index a1bb19395..bb1141ea7 100644 --- a/src/util/mpz.h +++ b/src/util/mpz.h @@ -692,6 +692,13 @@ public: \remark Return 0 if a is not positive. */ unsigned prev_power_of_two(mpz const & a) { return log2(a); } + + /** + \brief Return the smallest k s.t. a <= 2^k. + + \remark Return 0 if a is not positive. + */ + unsigned next_power_of_two(mpz const & a); /** \brief Return true if a^{1/n} is an integer, and store the result in a. diff --git a/src/util/rational.h b/src/util/rational.h index 4253bd4a7..88a0552ba 100644 --- a/src/util/rational.h +++ b/src/util/rational.h @@ -55,7 +55,7 @@ public: explicit rational(double z) { UNREACHABLE(); } explicit rational(char const * v) { m().set(m_val, v); } - + explicit rational(unsigned const * v, unsigned sz) { m().set(m_val, sz, v); } struct i64 {}; @@ -489,6 +489,18 @@ public: return get_num_digits(rational(10)); } + /** + * \brief Return the biggest k s.t. 2^k <= a. + * \remark Return 0 if a is not positive. + */ + unsigned prev_power_of_two() const { return m().prev_power_of_two(m_val); } + + /** + * \brief Return the smallest k s.t. a <= 2^k. + * \remark Return 0 if a is not positive. + */ + unsigned next_power_of_two() const { return m().next_power_of_two(m_val); } + bool get_bit(unsigned index) const { return m().get_bit(m_val, index); } @@ -510,7 +522,6 @@ public: return trailing_zeros(); } - static bool limit_denominator(rational &num, rational const& limit); }; @@ -659,3 +670,7 @@ inline rational gcd(rational const & r1, rational const & r2, rational & a, rati rational::m().gcd(r1.m_val, r2.m_val, a.m_val, b.m_val, result.m_val); return result; } + +inline void swap(rational& r1, rational& r2) { + r1.swap(r2); +} From 09eac8e3715382278b580cbd28d22fc511c79a27 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 09:54:17 -0800 Subject: [PATCH 19/89] allow tracking values of constraints --- src/sat/smt/polysat/polysat_core.cpp | 26 ++++++++++++++++++-------- src/sat/smt/polysat/polysat_core.h | 7 ++++--- src/sat/smt/polysat/polysat_types.h | 4 ++++ src/sat/smt/polysat/polysat_viable.cpp | 2 +- 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/src/sat/smt/polysat/polysat_core.cpp b/src/sat/smt/polysat/polysat_core.cpp index 09bfbd244..ff35f20ee 100644 --- a/src/sat/smt/polysat/polysat_core.cpp +++ b/src/sat/smt/polysat/polysat_core.cpp @@ -67,13 +67,13 @@ namespace polysat { public: mk_add_watch(core& c) : c(c) {} void undo() override { - auto& [sc, lit] = c.m_constraint_trail.back(); + auto& [sc, lit, val] = c.m_constraint_index.back(); auto& vars = sc.vars(); if (vars.size() > 0) c.m_watch[vars[0]].pop_back(); if (vars.size() > 1) c.m_watch[vars[1]].pop_back(); - c.m_constraint_trail.pop_back(); + c.m_constraint_index.pop_back(); } }; @@ -121,8 +121,8 @@ namespace polysat { } unsigned core::register_constraint(signed_constraint& sc, dependency d) { - unsigned idx = m_constraint_trail.size(); - m_constraint_trail.push_back({ sc, d }); + unsigned idx = m_constraint_index.size(); + m_constraint_index.push_back({ sc, d, l_undef }); auto& vars = sc.vars(); unsigned i = 0, j = 0, sz = vars.size(); for (; i < sz && j < 2; ++i) @@ -177,7 +177,7 @@ namespace polysat { } signed_constraint core::get_constraint(unsigned idx, bool sign) { - auto sc = m_constraint_trail[idx].sc; + auto sc = m_constraint_index[idx].sc; if (sign) sc = ~sc; return sc; @@ -212,7 +212,7 @@ namespace polysat { // for entries where there is only one free variable left add to viable set unsigned j = 0; for (auto idx : m_watch[v]) { - auto [sc, as] = m_constraint_trail[idx]; + auto [sc, as, value] = m_constraint_index[idx]; auto& vars = sc.vars(); if (vars[0] != v) std::swap(vars[0], vars[1]); @@ -263,7 +263,7 @@ namespace polysat { for (auto idx1 : m_watch[m_var]) { if (idx == idx1) continue; - auto [sc, d] = m_constraint_trail[idx1]; + auto [sc, d, value] = m_constraint_index[idx1]; switch (eval(sc)) { case l_false: s.propagate(d, true, explain_eval(sc)); @@ -298,8 +298,18 @@ namespace polysat { } void core::assign_eh(unsigned index, bool sign, dependency const& dep) { + struct unassign : public trail { + core& c; + unsigned m_index; + unassign(core& c, unsigned index): c(c), m_index(index) {} + void undo() override { + c.m_constraint_index[m_index].value = l_undef; + c.m_prop_queue.pop_back(); + } + }; m_prop_queue.push_back({ index, sign, dep }); - s.trail().push(push_back_vector(m_prop_queue)); + m_constraint_index[index].value = to_lbool(!sign); + s.trail().push(unassign(*this, index)); } dependency_vector core::explain_eval(signed_constraint const& sc) { diff --git a/src/sat/smt/polysat/polysat_core.h b/src/sat/smt/polysat/polysat_core.h index b5b7dd380..b4f35e776 100644 --- a/src/sat/smt/polysat/polysat_core.h +++ b/src/sat/smt/polysat/polysat_core.h @@ -42,8 +42,9 @@ namespace polysat { friend class assignment; struct constraint_info { - signed_constraint sc; - dependency d; + signed_constraint sc; // signed constraint representation + dependency d; // justification for constraint + lbool value; // value assigned by solver }; solver_interface& s; viable m_viable; @@ -51,7 +52,7 @@ namespace polysat { assignment m_assignment; unsigned m_qhead = 0, m_vqhead = 0; svector m_prop_queue; - svector m_constraint_trail; // index of constraints + svector m_constraint_index; // index of constraints mutable scoped_ptr_vector m_pdd; dependency_vector m_unsat_core; diff --git a/src/sat/smt/polysat/polysat_types.h b/src/sat/smt/polysat/polysat_types.h index b9b440481..92896a7f7 100644 --- a/src/sat/smt/polysat/polysat_types.h +++ b/src/sat/smt/polysat/polysat_types.h @@ -89,6 +89,10 @@ namespace polysat { class signed_constraint; + // + // The interface that PolySAT uses to the SAT/SMT solver. + // + class solver_interface { public: virtual void add_eq_literal(pvar v, rational const& val) = 0; diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp index a8a9fd6af..8d9009af0 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -636,7 +636,7 @@ namespace polysat { if (c.is_assigned(v)) return; - auto [sc, d] = c.m_constraint_trail[idx]; + auto [sc, d, value] = c.m_constraint_index[idx]; // fixme: constraint must be assigned a value l_true or l_false at this point. // adjust sc to the truth value of the constraint when passed to forbidden intervals. From 2b49bd189ab64ebc9c6d78640b670115cae44377 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 16:20:35 -0800 Subject: [PATCH 20/89] fixed fixme Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/polysat_viable.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp index 8d9009af0..9af70d716 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -637,8 +637,9 @@ namespace polysat { if (c.is_assigned(v)) return; auto [sc, d, value] = c.m_constraint_index[idx]; - // fixme: constraint must be assigned a value l_true or l_false at this point. - // adjust sc to the truth value of the constraint when passed to forbidden intervals. + SASSERT(value != l_undef); + if (value == l_false) + sc = ~sc; entry* ne = alloc_entry(v, idx); if (!m_forbidden_intervals.get_interval(sc, v, *ne)) { From 207735d55c07f0b686912fbe0273fbf403106cdf Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 17:14:59 -0800 Subject: [PATCH 21/89] n/a Signed-off-by: Nikolaj Bjorner --- scripts/mk_project.py | 2 +- src/sat/smt/polysat/polysat_core.cpp | 7 +- src/sat/smt/polysat/polysat_core.h | 3 +- src/sat/smt/polysat/polysat_types.h | 8 +- src/sat/smt/polysat/polysat_viable.cpp | 262 +++++++++++++++++++++---- src/sat/smt/polysat/polysat_viable.h | 28 ++- src/sat/smt/polysat_solver.cpp | 19 +- src/sat/smt/polysat_solver.h | 2 +- 8 files changed, 273 insertions(+), 58 deletions(-) diff --git a/scripts/mk_project.py b/scripts/mk_project.py index 0f5dc26ae..99051fe93 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -58,7 +58,7 @@ def init_project_def(): add_lib('proto_model', ['model', 'rewriter', 'smt_params'], 'smt/proto_model') add_lib('smt', ['bit_blaster', 'macros', 'normal_forms', 'cmd_context', 'proto_model', 'solver_assertions', 'substitution', 'grobner', 'simplex', 'proofs', 'pattern', 'parser_util', 'fpa', 'lp']) - add_lib('polysat', ['util', 'dd'], 'sat/smt/polysat'), + add_lib('polysat', ['util', 'dd', 'sat'], 'sat/smt/polysat'), add_lib('sat_smt', ['sat', 'euf', 'smt', 'tactic', 'solver', 'smt_params', 'bit_blaster', 'fpa', 'mbp', 'polysat', 'normal_forms', 'lp', 'pattern', 'qe_lite'], 'sat/smt') add_lib('sat_tactic', ['tactic', 'sat', 'solver', 'sat_smt'], 'sat/tactic') add_lib('nlsat_tactic', ['nlsat', 'sat_tactic', 'arith_tactics'], 'nlsat/tactic') diff --git a/src/sat/smt/polysat/polysat_core.cpp b/src/sat/smt/polysat/polysat_core.cpp index ff35f20ee..be25c9af8 100644 --- a/src/sat/smt/polysat/polysat_core.cpp +++ b/src/sat/smt/polysat/polysat_core.cpp @@ -107,6 +107,7 @@ namespace polysat { m_justification.push_back(null_dependency); m_watch.push_back({}); m_var_queue.mk_var_eh(v); + m_viable.ensure_var(v); s.trail().push(mk_add_var(*this)); return v; } @@ -147,8 +148,8 @@ namespace polysat { s.trail().push(mk_dqueue_var(m_var, *this)); switch (m_viable.find_viable(m_var, m_value)) { case find_t::empty: - m_unsat_core = m_viable.explain(); - propagate_unsat_core(); + s.set_lemma(m_viable.get_core(), 0, m_viable.explain()); + // propagate_unsat_core(); return sat::check_result::CR_CONTINUE; case find_t::singleton: s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); @@ -294,7 +295,7 @@ namespace polysat { // default is to use unsat core: // if core is based on viable, use s.set_lemma(); - s.set_conflict(m_unsat_core); + s.set_conflict(m_unsat_core); } void core::assign_eh(unsigned index, bool sign, dependency const& dep) { diff --git a/src/sat/smt/polysat/polysat_core.h b/src/sat/smt/polysat/polysat_core.h index b4f35e776..144b1256b 100644 --- a/src/sat/smt/polysat/polysat_core.h +++ b/src/sat/smt/polysat/polysat_core.h @@ -17,9 +17,10 @@ Author: --*/ #pragma once +#include "util/var_queue.h" #include "util/dependency.h" #include "math/dd/dd_pdd.h" -#include "sat/smt/sat_th.h" +#include "sat/sat_extension.h" #include "sat/smt/polysat/polysat_types.h" #include "sat/smt/polysat/polysat_constraints.h" #include "sat/smt/polysat/polysat_viable.h" diff --git a/src/sat/smt/polysat/polysat_types.h b/src/sat/smt/polysat/polysat_types.h index 92896a7f7..1df9912bc 100644 --- a/src/sat/smt/polysat/polysat_types.h +++ b/src/sat/smt/polysat/polysat_types.h @@ -26,7 +26,7 @@ namespace polysat { using pvar_vector = unsigned_vector; inline const pvar null_var = UINT_MAX; - + class signed_constraint; class dependency { std::variant> m_data; @@ -87,7 +87,9 @@ namespace polysat { using dependency_vector = vector; - class signed_constraint; + using core_vector = vector>; + + // // The interface that PolySAT uses to the SAT/SMT solver. @@ -97,7 +99,7 @@ namespace polysat { public: virtual void add_eq_literal(pvar v, rational const& val) = 0; virtual void set_conflict(dependency_vector const& core) = 0; - virtual void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core) = 0; + virtual void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) = 0; virtual dependency propagate(signed_constraint sc, dependency_vector const& deps) = 0; virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; virtual trail_stack& trail() = 0; diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp index 9af70d716..4152956de 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -79,7 +79,6 @@ namespace polysat { find_t viable::find_viable(pvar v, rational& lo) { rational hi; - ensure_var(v); switch (find_viable(v, lo, hi)) { case l_true: return (lo == hi) ? find_t::singleton : find_t::multiple; @@ -106,9 +105,6 @@ namespace polysat { // max size should always be present, regardless of whether we have intervals there (to make sure all fixed bits are considered) widths_set.insert(c.size(v)); - for (pvar v : overlaps) - ensure_var(v); - for (pvar v : overlaps) for (layer const& l : m_units[v].get_layers()) widths_set.insert(l.bit_width); @@ -121,6 +117,7 @@ namespace polysat { rational const& max_value = c.var2pdd(v).max_value(); + m_explain.reset(); lbool result_lo = find_on_layers(v, widths, overlaps, fbi, rational::zero(), max_value, lo); if (result_lo != l_true) return result_lo; @@ -129,18 +126,13 @@ namespace polysat { hi = lo; return l_true; } - + lbool result_hi = find_on_layers(v, widths, overlaps, fbi, lo + 1, max_value, hi); - switch (result_hi) { - case l_false: - hi = lo; - return l_true; - case l_undef: - return l_undef; - default: - return l_true; - } + if (result_hi != l_false) + return result_hi; + hi = lo; + return l_true; } // l_true ... found viable value @@ -153,18 +145,17 @@ namespace polysat { fixed_bits_info const& fbi, rational const& to_cover_lo, rational const& to_cover_hi, - rational& val - ) { - ptr_vector refine_todo; - ptr_vector relevant_entries; + rational& val) { + ptr_vector refine_todo; // max number of interval refinements before falling back to the univariate solver unsigned const refinement_budget = 100; unsigned refinements = refinement_budget; + unsigned explain_size = m_explain.size(); while (refinements--) { - relevant_entries.clear(); - lbool result = find_on_layer(v, widths.size() - 1, widths, overlaps, fbi, to_cover_lo, to_cover_hi, val, refine_todo, relevant_entries); + m_explain.shrink(explain_size); + lbool result = find_on_layer(v, widths.size() - 1, widths, overlaps, fbi, to_cover_lo, to_cover_hi, val, refine_todo); // store bit-intervals we have used for (entry* e : refine_todo) @@ -191,8 +182,6 @@ namespace polysat { if (!refined) return l_true; } - - LOG("Refinement budget exhausted! Fall back to univariate solver."); return l_undef; } @@ -211,12 +200,10 @@ namespace polysat { rational const& to_cover_lo, rational const& to_cover_hi, rational& val, - ptr_vector& refine_todo, - ptr_vector& relevant_entries - ) { + ptr_vector& refine_todo) { unsigned const w = widths[w_idx]; rational const& mod_value = rational::power_of_two(w); - unsigned const first_relevant_for_conflict = relevant_entries.size(); + unsigned const first_relevant_for_conflict = m_explain.size(); LOG("layer " << w << " bits, to_cover: [" << to_cover_lo << "; " << to_cover_hi << "["); SASSERT(0 <= to_cover_lo); @@ -295,12 +282,12 @@ namespace polysat { if (!e) break; - relevant_entries.push_back(e); + m_explain.push_back(e); if (e->interval.is_full()) { - relevant_entries.clear(); - relevant_entries.push_back(e); // full interval e -> all other intervals are subsumed/irrelevant - set_conflict_by_interval(v, w, relevant_entries, 0); + m_explain.reset(); + m_explain.push_back(e); // full interval e -> all other intervals are subsumed/irrelevant + set_conflict_by_interval(v, w, m_explain, 0); return l_false; } @@ -314,7 +301,7 @@ namespace polysat { if (progress >= mod_value) { // covered the whole domain => conflict - set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict); + set_conflict_by_interval(v, w, m_explain, first_relevant_for_conflict); return l_false; } if (progress >= to_cover_len) { @@ -365,7 +352,7 @@ namespace polysat { lower_cover_lo = 0; lower_cover_hi = lower_mod_value; rational a; - lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries); + lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo); VERIFY(result != l_undef); if (result == l_false) { SASSERT(c.inconsistent()); @@ -387,7 +374,7 @@ namespace polysat { lower_cover_hi = mod(next_val, lower_mod_value); rational a; - lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries); + lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo); if (result == l_false) { SASSERT(c.inconsistent()); return l_false; // conflict @@ -408,7 +395,7 @@ namespace polysat { if (progress >= mod_value) { // covered the whole domain => conflict - set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict); + set_conflict_by_interval(v, w, m_explain, first_relevant_for_conflict); return l_false; } @@ -421,6 +408,197 @@ namespace polysat { return l_undef; } + void viable::set_conflict_by_interval(pvar v, unsigned w, ptr_vector& intervals, unsigned first_interval) { + SASSERT(!intervals.empty()); + SASSERT(first_interval < intervals.size()); + +#if 0 + bool create_lemma = true; + uint_set vars_to_explain; + char const* lemma_name = nullptr; + + // if there is a full interval, it should be the only one in the given vector + if (intervals.size() == 1 && intervals[0]->interval.is_full()) { + lemma_name = "viable (full interval)"; + entry const* e = intervals[0]; + for (auto sc : e->side_cond) + lemma.insert_eval(~sc); + for (const auto& src : e->src) { + lemma.insert(~src); + core.insert(src); + core.insert_vars(src); + } + if (v != e->var) + vars_to_explain.insert(e->var); + } + else { + SASSERT(all_of(intervals, [](entry* e) { return e->interval.is_proper(); })); + lemma_name = "viable (proper intervals)"; + + // allocate one dummy space in intervals storage to simplify recursive calls + intervals.push_back(nullptr); + entry** intervals_begin = intervals.data() + first_interval; + unsigned num_intervals = intervals.size() - first_interval - 1; + + if (!set_conflict_by_interval_rec(v, w, intervals_begin, num_intervals, core, create_lemma, lemma, vars_to_explain)) + return false; + } + + for (pvar x : vars_to_explain) { + s.m_slicing.explain_simple_overlap(v, x, [this, &core, &lemma](sat::literal l) { + lemma.insert(~l); + core.insert(s.lit2cnstr(l)); + }); + } + + if (create_lemma) + core.add_lemma(lemma_name, lemma.build()); + + //core.logger().log(inf_fi(*this, v)); + core.init_viable_end(v); + return true; +#endif + } + + bool viable::set_conflict_by_interval_rec(pvar v, unsigned w, entry** intervals, unsigned num_intervals, bool& create_lemma, uint_set& vars_to_explain) { + SASSERT(std::all_of(intervals, intervals + num_intervals, [w](entry const* e) { return e->bit_width <= w; })); + // TODO: assert invariants on intervals list + rational const& mod_value = rational::power_of_two(w); + + // heuristic: find longest interval as starting point + unsigned longest_idx = UINT_MAX; + rational longest_len; + for (unsigned i = 0; i < num_intervals; ++i) { + entry* e = intervals[i]; + if (e->bit_width != w) + continue; + rational len = e->interval.current_len(); + if (len > longest_len) { + longest_idx = i; + longest_len = len; + } + } + SASSERT(longest_idx < UINT_MAX); + entry* longest = intervals[longest_idx]; + + if (!longest->valid_for_lemma) + create_lemma = false; + + unsigned i = longest_idx; + entry* e = longest; // e is the current baseline + + entry* tmp = nullptr; + on_scope_exit dont_leak_entry = [this, &tmp]() { + if (tmp) + m_alloc.push_back(tmp); + }; + +#if 0 + while (!longest->interval.currently_contains(e->interval.hi_val())) { + unsigned j = (i + 1) % num_intervals; + entry* n = intervals[j]; + + if (n->bit_width != w) { + // we have a hole starting with 'n', to be filled with intervals at lower bit-width. + // note that the next lower bit-width is not necessarily n->bit_width (e.g., the next layer may start with a hole itself) + unsigned w2 = n->bit_width; + // first, find the next constraint after the hole + unsigned last_idx = j; + unsigned k = j; + while (intervals[k]->bit_width != w) { + if (intervals[k]->bit_width > w2) + w2 = intervals[k]->bit_width; + last_idx = k; + k = (k + 1) % num_intervals; + } + entry* after = intervals[k]; + SASSERT(j < last_idx); // the hole cannot wrap around (but k may be 0) + + rational const& lower_mod_value = rational::power_of_two(w2); + SASSERT(distance(e->interval.hi_val(), after->interval.lo_val(), mod_value) < lower_mod_value); // otherwise we would have started the conflict at w2 already + + // create temporary entry that covers the hole-complement on the lower level + if (!tmp) + tmp = alloc_entry(v); + pdd dummy = s.var2pdd(v).mk_val(123); // we could create extract-terms for boundaries but let's skip that for now and just disable the lemma + create_lemma = false; + tmp->valid_for_lemma = false; + tmp->bit_width = w2; + tmp->interval = eval_interval::proper(dummy, mod(after->interval.lo_val(), lower_mod_value), dummy, mod(e->interval.hi_val(), lower_mod_value)); + + // the index "last_idx+1" is always valid because we allocate an extra dummy space at the end before starting the recursion. + // we only need a single dummy space because the temporary entry is always at bit-width w2. + entry* old = intervals[last_idx + 1]; + intervals[last_idx + 1] = tmp; + + bool result = set_conflict_by_interval_rec(v, w2, &intervals[j], last_idx - j + 2, create_lemma, vars_to_explain); + + intervals[last_idx + 1] = old; + + if (!result) + return false; + + if (create_lemma) { + // hole_length < 2^w2 + signed_constraint c = s.ult(after->interval.lo() - e->interval.hi(), lower_mod_value); + VERIFY(c.is_currently_true(s)); + // this constraint may already exist on the stack with opposite bool value, + // in that case we have a different, earlier conflict + if (c.bvalue(s) == l_false) { + core.reset(); + core.init(~c); + return false; + } + lemma.insert(~c); + } + + tmp->bit_width = w; + tmp->interval = eval_interval::proper(dummy, e->interval.hi_val(), dummy, after->interval.lo_val()); + e = tmp; + j = k; + n = after; + } + + // We may have a bunch of intervals that contain the current value. + // Choose the one making the most progress. + // TODO: it should be the last one, otherwise we wouldn't have added it to relevant_intervals? then we can skip the progress computation. + // (TODO: should note the relevant invariants of intervals list and assert them above.) + SASSERT(n->interval.currently_contains(e->interval.hi_val())); + unsigned best_j = j; + rational best_progress = distance(e->interval.hi_val(), n->interval.hi_val(), mod_value); + while (true) { + unsigned j1 = (j + 1) % num_intervals; + entry* n1 = intervals[j1]; + if (n1->bit_width != w) + break; + if (!n1->interval.currently_contains(e->interval.hi_val())) + break; + j = j1; + n = n1; + SASSERT(n != longest); // because of loop condition on outer while loop + rational progress = distance(e->interval.hi_val(), n->interval.hi_val(), mod_value); + if (progress > best_progress) { + best_j = j; + best_progress = progress; + } + } + j = best_j; + n = intervals[best_j]; + + if (!update_interval_conflict(v, e->interval.hi(), n, core, create_lemma, lemma, vars_to_explain)) + return false; + + i = j; + e = n; + } + + if (!update_interval_conflict(v, e->interval.hi(), longest, core, create_lemma, lemma, vars_to_explain)) + return false; +#endif + + return true; + } + // returns true iff no conflict was encountered bool viable::collect_bit_information(pvar v, bool add_conflict, fixed_bits_info& out_fbi) { @@ -429,7 +607,7 @@ namespace polysat { out_fbi.reset(v_sz); auto& [fixed, just_src, just_side_cond, just_slice] = out_fbi; - svector fbs; + svector fbs; c.get_fixed_bits(v, fbs); for (auto const& fb : fbs) { @@ -625,15 +803,23 @@ namespace polysat { /* * Explain why the current variable is not viable or signleton. */ - dependency_vector viable::explain() { throw default_exception("nyi"); } + dependency_vector viable::explain() { + dependency_vector result; + for (auto e : m_explain) { + auto index = e->constraint_index; + auto const& [sc, d, value] = c.m_constraint_index[index]; + result.push_back(d); + result.append(c.explain_eval(sc)); + } + // TODO: explaination for fixed bits + return result; + } /* * Register constraint at index 'idx' as unitary in v. */ void viable::add_unitary(pvar v, unsigned idx) { - ensure_var(v); - if (c.is_assigned(v)) return; auto [sc, d, value] = c.m_constraint_index[idx]; diff --git a/src/sat/smt/polysat/polysat_viable.h b/src/sat/smt/polysat/polysat_viable.h index 1812ec40c..f426dc326 100644 --- a/src/sat/smt/polysat/polysat_viable.h +++ b/src/sat/smt/polysat/polysat_viable.h @@ -138,6 +138,9 @@ namespace polysat { vector m_units; // set of viable values based on unit multipliers, layered by bit-width in descending order ptr_vector m_equal_lin; // entries that have non-unit multipliers, but are equal ptr_vector m_diseq_lin; // entries that have distinct non-zero multipliers + ptr_vector m_explain; // entries that explain the current propagation or conflict + core_vector m_core; // forbidden interval core + bool m_has_core = false; bool well_formed(entry* e); bool well_formed(layers const& ls); @@ -158,8 +161,6 @@ namespace polysat { bool intersect(pvar v, entry* e); - void ensure_var(pvar v); - lbool find_viable(pvar v, rational& lo, rational& hi); lbool find_on_layers( @@ -180,8 +181,7 @@ namespace polysat { rational const& to_cover_lo, rational const& to_cover_hi, rational& out_val, - ptr_vector& refine_todo, - ptr_vector& relevant_entries); + ptr_vector& refine_todo); template @@ -211,9 +211,8 @@ namespace polysat { throw default_exception("nyi"); } - bool set_conflict_by_interval(pvar v, unsigned w, ptr_vector& intervals, unsigned first_interval) { - throw default_exception("nyi"); - } + void set_conflict_by_interval(pvar v, unsigned w, ptr_vector& intervals, unsigned first_interval); + bool set_conflict_by_interval_rec(pvar v, unsigned w, entry** intervals, unsigned num_intervals, bool& create_lemma, uint_set& vars_to_explain); std::pair find_value(rational const& val, entry* entries) { throw default_exception("nyi"); @@ -236,11 +235,26 @@ namespace polysat { */ dependency_vector explain(); + /* + * flag whether there is a forbidden interval core + */ + bool has_core() const { return m_has_core; } + + /* + * Retrieve lemma corresponding to forbidden interval constraints + */ + core_vector const& get_core() { SASSERT(m_has_core); return m_core; } + /* * Register constraint at index 'idx' as unitary in v. */ void add_unitary(pvar v, unsigned idx); + /* + * Ensure data-structures tracking variable v. + */ + void ensure_var(pvar v); + }; } diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 7e867af2f..460501cd0 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -104,15 +104,26 @@ namespace polysat { return { core, eqs }; } - void solver::set_lemma(vector const& lemma, unsigned level, dependency_vector const& core) { + void solver::set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) { auto [lits, eqs] = explain_deps(core); auto ex = euf::th_explain::conflict(*this, lits, eqs, nullptr); ctx.push(value_trail(m_has_lemma)); m_has_lemma = true; m_lemma_level = level; m_lemma.reset(); - for (auto sc : lemma) - m_lemma.push_back(constraint2expr(sc)); + for (auto sc : aux_core) { + if (std::holds_alternative(sc)) { + auto d = *std::get_if(&sc); + if (d.is_literal()) + m_lemma.push_back(ctx.literal2expr(d.literal())); + else { + auto [v1, v2] = d.eq(); + m_lemma.push_back(ctx.mk_eq(var2enode(v1), var2enode(v2))); + } + } + else if (std::holds_alternative(sc)) + m_lemma.push_back(constraint2expr(*std::get_if(&sc))); + } ctx.set_conflict(ex); } @@ -129,7 +140,7 @@ namespace polysat { sat::literal_vector lits; for (auto* e : m_lemma) - lits.push_back(ctx.mk_literal(e)); + lits.push_back(~ctx.mk_literal(e)); s().add_clause(lits.size(), lits.data(), sat::status::th(true, get_id(), nullptr)); return l_false; } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 6c7260f95..b5e69c36a 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -128,7 +128,7 @@ namespace polysat { // callbacks from core void add_eq_literal(pvar v, rational const& val) override; void set_conflict(dependency_vector const& core) override; - void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core) override; + void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) override; dependency propagate(signed_constraint sc, dependency_vector const& deps) override; void propagate(dependency const& d, bool sign, dependency_vector const& deps) override; trail_stack& trail() override; From 21ef68991870b66ee2a14ba1f5b6630a22e5079c Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 17:16:31 -0800 Subject: [PATCH 22/89] n/a Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/polysat_types.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sat/smt/polysat/polysat_types.h b/src/sat/smt/polysat/polysat_types.h index 1df9912bc..207ea091e 100644 --- a/src/sat/smt/polysat/polysat_types.h +++ b/src/sat/smt/polysat/polysat_types.h @@ -97,6 +97,7 @@ namespace polysat { class solver_interface { public: + virtual ~solver_interface() {} virtual void add_eq_literal(pvar v, rational const& val) = 0; virtual void set_conflict(dependency_vector const& core) = 0; virtual void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) = 0; From 7ba5d2024dc7cb9c78626b23acb9d85369cee131 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 17:26:32 -0800 Subject: [PATCH 23/89] remove stale file Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/slicing.cpp | 1727 ------------------------------- src/sat/smt/polysat/slicing.h | 397 ------- 2 files changed, 2124 deletions(-) delete mode 100644 src/sat/smt/polysat/slicing.cpp delete mode 100644 src/sat/smt/polysat/slicing.h diff --git a/src/sat/smt/polysat/slicing.cpp b/src/sat/smt/polysat/slicing.cpp deleted file mode 100644 index 04fe8c4fc..000000000 --- a/src/sat/smt/polysat/slicing.cpp +++ /dev/null @@ -1,1727 +0,0 @@ -/*++ -Copyright (c) 2023 Microsoft Corporation - -Module Name: - - polysat slicing - -Author: - - Jakob Rath 2023-06-01 - ---*/ - - - - -/* - -Example: -(1) x = y -(2) z = y[3:0] -(3) explain(x[3:0] == z)? should be { (1), (2) } - - (1) - x ========================> y - / \ / \ (2) - x[7:4] x[3:0] y[7:4] y[3:0] ===========> z - - -TODO: -- About the sub-slice sharing among equivalent nodes: - - When extracting a variable y := x[h:l], we always need to create a new slice for y. - - Merge slices for x[h:l] with y; store as dependency 'x[h:l]' (rather than 'null_dep' as we do now). - - when merging, we must avoid that the new variable becomes the root of the equivalence class, - because when finding dependencies for 'y := x[h:l]', such extraction-dependencies would be false/unnecessary. - (alternatively, just ignore them. but we never *have* to have them as root, so just don't do it. but add assertions for 1. new var is not root, 2. no extraction-dependency when walking from 'x' to 'x[h:l]'.) - - When encountering this dependency, need to start at slice for 'x' and walk towards 'x[h:l]', - collecting dependencies whenever we move to a representative. -- when solver assigns value of a variable v, add equations with v substituted by its value? - - since we only track equations over variables/names, this might not lead to many further additions - - a question is how to track/handle the dependency on the assignment -- check Algorithm 2 of "Solving Bitvectors with MCSAT"; what is the difference to what we are doing now? -- track equalities such as x = -y ? -- on_merge could propagate values upwards: - if slice has value but parent has no value, then check if sub_other(parent(s)) [sibling(s)?] has a value. - if yes, can propagate value upwards. (add a congruence term to track deps properly?). - we have to check the whole equivalence class, because the parents may be in different classes. - it is enough to propagate values to variables. we could count (in the variable slice) the number of its base slices that are still unassigned. - -*/ - - -#include "ast/reg_decl_plugins.h" -#include "math/polysat/slicing.h" -#include "math/polysat/solver.h" -#include "math/polysat/log.h" -#include "util/tptr.h" - - -namespace { - - template - [[maybe_unused]] - inline constexpr bool always_false_v = false; - -} - -namespace polysat { - - void* slicing::dep_t::encode() const { - void* p = std::visit([](auto arg) -> void* { - using T = std::decay_t; - if constexpr (std::is_same_v) - return nullptr; - else if constexpr (std::is_same_v) - return box(arg.to_uint(), 1); - else if constexpr (std::is_same_v) - return box(arg, 2); - else - static_assert(always_false_v, "non-exhaustive visitor!"); - }, m_data); - SASSERT(*this == decode(p)); - return p; - } - - slicing::dep_t slicing::dep_t::decode(void* p) { - if (!p) - return {}; - unsigned tag = get_tag(p); - SASSERT(tag == 1 || tag == 2); - if (tag == 1) - return dep_t(sat::to_literal(unbox(p))); - else - return dep_t(unbox(p)); - } - - std::ostream& slicing::display(std::ostream& out, dep_t d) const { - if (d.is_null()) - out << "null"; - else if (d.is_value()) { - pvar x = get_dep_var(d); - enode* n = get_dep_slice(d); - sat::literal lit = get_dep_lit(d); - out << "value(v" << x << " on slice "; - if (n) - out << n->get_id(); - else - out << ""; - if (lit != sat::null_literal) - out << " by literal " << lit; - out << ")"; - } - else if (d.is_lit()) - out << "lit(" << d.lit() << ")"; - return out; - } - - slicing::dep_t slicing::mk_var_dep(pvar v, enode* s, sat::literal lit) { - SASSERT_EQ(m_dep_var.size(), m_dep_slice.size()); - SASSERT_EQ(m_dep_var.size(), m_dep_lit.size()); - unsigned const idx = m_dep_var.size(); - m_dep_var.push_back(v); - m_dep_lit.push_back(lit); - m_dep_slice.push_back(s); - return dep_t(idx); - } - - slicing::slicing(solver& s): - m_solver(s), - m_egraph(m_ast) - { - reg_decl_plugins(m_ast); - m_bv = alloc(bv_util, m_ast); - m_egraph.set_display_justification([&](std::ostream& out, void* dp) { display(out, dep_t::decode(dp)); }); - m_egraph.set_on_merge([&](enode* root, enode* other) { egraph_on_merge(root, other); }); - m_egraph.set_on_propagate([&](enode* lit, enode* ante) { egraph_on_propagate(lit, ante); }); - // m_egraph.set_on_make([&](enode* n) { egraph_on_make(n); }); - } - - slicing::slice_info& slicing::info(enode* n) { - return const_cast(std::as_const(*this).info(n)); - } - - slicing::slice_info const& slicing::info(enode* n) const { - SASSERT(n); - SASSERT(!n->is_equality()); - SASSERT(m_bv->is_bv_sort(n->get_sort())); - slice_info const& i = m_info[n->get_id()]; - return i.slice ? info(i.slice) : i; - } - - bool slicing::is_slice(enode* n) const { - if (n->is_equality()) - return false; - SASSERT(m_bv->is_bv_sort(n->get_sort())); - slice_info const& i = m_info[n->get_id()]; - return !i.slice; - } - - bool slicing::is_concat(enode* n) const { - if (n->is_equality()) - return false; - return !is_slice(n); - } - - unsigned slicing::width(enode* s) const { - SASSERT(!s->is_equality()); - return m_bv->get_bv_size(s->get_expr()); - } - - slicing::enode* slicing::sibling(enode* s) const { - enode* p = parent(s); - SASSERT(p); - SASSERT(sub_lo(p) == s || sub_hi(p) == s); - if (s != sub_hi(p)) - return sub_hi(p); - else - return sub_lo(p); - } - - func_decl* slicing::mk_concat_decl(ptr_vector const& args) { - SASSERT(args.size() >= 2); - ptr_vector domain; - unsigned sz = 0; - for (expr* e : args) { - domain.push_back(e->get_sort()); - sz += m_bv->get_bv_size(e); - } - sort* range = m_bv->mk_sort(sz); - return m_ast.mk_func_decl(symbol("slice-concat"), domain.size(), domain.data(), range); - } - - void slicing::push_scope() { - LOG("push_scope"); - if (can_propagate()) - propagate(); - m_scopes.push_back(m_trail.size()); - m_egraph.push(); - m_dep_size_trail.push_back(m_dep_var.size()); - SASSERT(!use_var_congruences() || m_needs_congruence.empty()); - } - - void slicing::pop_scope(unsigned num_scopes) { - LOG("pop_scope(" << num_scopes << ")"); - if (num_scopes == 0) - return; - unsigned const lvl = m_scopes.size(); - SASSERT(num_scopes <= lvl); - unsigned const target_lvl = lvl - num_scopes; - unsigned const target_size = m_scopes[target_lvl]; - m_scopes.shrink(target_lvl); - svector replay_trail; - unsigned_vector replay_add_var_trail; - svector> replay_extract_trail; - svector replay_concat_trail; - unsigned num_replay_concat = 0; - for (unsigned i = m_trail.size(); i-- > target_size; ) { - switch (m_trail[i]) { - case trail_item::add_var: - replay_trail.push_back(trail_item::add_var); - replay_add_var_trail.push_back(width(m_var2slice.back())); - undo_add_var(); - break; - case trail_item::split_core: - undo_split_core(); - break; - case trail_item::mk_extract: { - replay_trail.push_back(trail_item::mk_extract); - extract_args const& args = m_extract_trail.back(); - replay_extract_trail.push_back({args, m_extract_dedup[args]}); - undo_mk_extract(); - break; - } - case trail_item::mk_concat: - replay_trail.push_back(trail_item::mk_concat); - num_replay_concat++; - break; - case trail_item::set_value_node: - undo_set_value_node(); - break; - default: - UNREACHABLE(); - } - } - m_egraph.pop(num_scopes); - m_needs_congruence.reset(); - m_disequality_conflict = nullptr; - m_dep_var.shrink(m_dep_size_trail[target_lvl]); - m_dep_lit.shrink(m_dep_size_trail[target_lvl]); - m_dep_slice.shrink(m_dep_size_trail[target_lvl]); - m_dep_size_trail.shrink(target_lvl); - m_trail.shrink(target_size); - // replay add_var/mk_extract/mk_concat in the same order - // (only until polysat::solver supports proper garbage collection of variables) - unsigned add_var_idx = replay_add_var_trail.size(); - unsigned extract_idx = replay_extract_trail.size(); - unsigned concat_idx = m_concat_trail.size() - num_replay_concat; - for (auto it = replay_trail.rbegin(); it != replay_trail.rend(); ++it) { - switch (*it) { - case trail_item::add_var: { - unsigned const sz = replay_add_var_trail[--add_var_idx]; - add_var(sz); - break; - } - case trail_item::mk_extract: { - auto const [args, v] = replay_extract_trail[--extract_idx]; - replay_extract(args, v); - break; - } - case trail_item::mk_concat: { - NOT_IMPLEMENTED_YET(); - auto const ci = m_concat_trail[concat_idx++]; - num_replay_concat++; - replay_concat(ci.num_args, &m_concat_args[ci.args_idx], ci.v); - break; - } - default: - UNREACHABLE(); - } - } - SASSERT(invariant()); - } - - void slicing::add_var(unsigned bit_width) { - pvar const v = m_var2slice.size(); - enode* s = alloc_slice(bit_width, v); - m_var2slice.push_back(s); - m_trail.push_back(trail_item::add_var); - LOG_V(10, "add_var: v" << v << " -> " << slice_pp(*this, s)); - } - - void slicing::undo_add_var() { - m_var2slice.pop_back(); - } - - slicing::enode* slicing::find_or_alloc_disequality(enode* x, enode* y, sat::literal lit) { - expr_ref eq(m_ast.mk_eq(x->get_expr(), y->get_expr()), m_ast); - enode* eqn = m_egraph.find(eq); - if (eqn) - return eqn; - auto args = {x, y}; - eqn = m_egraph.mk(eq, 0, args.size(), args.begin()); - auto j = euf::justification::external(dep_t(lit).encode()); - m_egraph.set_value(eqn, l_false, j); - SASSERT(eqn->is_equality()); - SASSERT_EQ(eqn->value(), l_false); - return eqn; - } - - void slicing::egraph_on_make(enode* n) { - LOG("on_make: " << e_pp(n)); - } - - slicing::enode* slicing::alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var) { - SASSERT(!m_egraph.find(e)); - // NOTE: sometimes egraph::mk already triggers a merge due to congruence. - // in this case we have to make sure to allocate m_info early enough. - unsigned const id = e->get_id(); - m_info.reserve(id + 1); - slice_info& i = m_info[id]; - i.reset(); - i.var = var; - enode* n = m_egraph.mk(e, 0, num_args, args); // NOTE: the egraph keeps a strong reference to 'e' - LOG_V(10, "alloc_enode: " << slice_pp(*this, n) << " " << e_pp(n)); - return n; - } - - slicing::enode* slicing::find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var) { - enode* n = m_egraph.find(e); - if (n) { - SASSERT_EQ(info(n).var, var); - return n; - } - return alloc_enode(e, num_args, args, var); - } - - slicing::enode* slicing::alloc_slice(unsigned width, pvar var) { - SASSERT(width > 0); - app_ref a(m_ast.mk_fresh_const("s", m_bv->mk_sort(width), false), m_ast); - return alloc_enode(a, 0, nullptr, var); - } - - slicing::enode* slicing::mk_concat_node(enode_vector const& slices) { - return mk_concat_node(slices.size(), slices.data()); - } - - slicing::enode* slicing::mk_concat_node(unsigned num_slices, enode* const* slices) { - ptr_vector args; - for (unsigned i = 0; i < num_slices; ++i) - args.push_back(slices[i]->get_expr()); - app_ref a(m_ast.mk_app(mk_concat_decl(args), args), m_ast); - return find_or_alloc_enode(a, num_slices, slices, null_var); - } - - void slicing::add_concat_node(enode* s, enode* concat) { - SASSERT(slice2var(s) != null_var); // all concat nodes should point to a variable node - SASSERT(is_app(concat->get_expr())); - slice_info& concat_info = m_info[concat->get_id()]; - if (s->get_root() == concat->get_root()) { - SASSERT_EQ(s, concat_info.slice); - return; - } - SASSERT(!concat_info.slice); // not yet set - concat_info.slice = s; - egraph_merge(s, concat, dep_t()); - } - - void slicing::add_var_congruence(pvar v) { - enode_vector& base = m_tmp2; - SASSERT(base.empty()); - enode* sv = var2slice(v); - get_base(sv, base); - // Add equation v == concat(s1, ..., sn) - add_concat_node(sv, mk_concat_node(base)); - base.clear(); - } - - void slicing::add_var_congruence_if_needed(pvar v) { - if (!m_needs_congruence.contains(v)) - return; - m_needs_congruence.remove(v); - add_var_congruence(v); - } - - void slicing::update_var_congruences() { - if (!use_var_congruences()) - return; - // TODO: this is only needed once per equivalence class - // (mark root of var2slice to detect duplicates?) - for (pvar v : m_needs_congruence) - add_var_congruence(v); - m_needs_congruence.reset(); - } - - bool slicing::use_var_congruences() const { - return m_solver.config().m_slicing_congruence; - } - - // split a single slice without updating any equivalences - void slicing::split_core(enode* s, unsigned cut) { - SASSERT(is_slice(s)); // this action only makes sense for slices - SASSERT(!has_sub(s)); - SASSERT(info(s).sub_hi == nullptr && info(s).sub_lo == nullptr); - SASSERT(width(s) > cut + 1); - unsigned const width_hi = width(s) - cut - 1; - unsigned const width_lo = cut + 1; - enode* sub_hi; - enode* sub_lo; - if (is_value(s)) { - rational const val = get_value(s); - sub_hi = mk_value_slice(machine_div2k(val, width_lo), width_hi); - sub_lo = mk_value_slice(mod2k(val, width_lo), width_lo); - } - else { - sub_hi = alloc_slice(width_hi); - sub_lo = alloc_slice(width_lo); - } - SASSERT(!parent(sub_hi)); - SASSERT(!parent(sub_lo)); - info(sub_hi).parent = s; - info(sub_lo).parent = s; - info(s).set_cut(cut, sub_hi, sub_lo); - m_trail.push_back(trail_item::split_core); - m_enode_trail.push_back(s); - for (enode* n = s; n != nullptr; n = parent(n)) { - pvar const v = slice2var(n); - if (v == null_var) - continue; - if (m_needs_congruence.contains(v)) { - SASSERT(invariant_needs_congruence()); - break; // added parents already previously - } - m_needs_congruence.insert(v); - } - } - - bool slicing::invariant_needs_congruence() const { - for (pvar v : m_needs_congruence) - for (enode* s = var2slice(v); s != nullptr; s = parent(s)) - if (slice2var(s) != null_var) { - VERIFY(m_needs_congruence.contains(slice2var(s))); - } - return true; - } - - void slicing::undo_split_core() { - enode* s = m_enode_trail.back(); - m_enode_trail.pop_back(); - info(s).set_cut(null_cut, nullptr, nullptr); - } - - void slicing::split(enode* s, unsigned cut) { - // this action only makes sense for base slices. - // a base slice is never equivalent to a congruence node. - SASSERT(is_slice(s)); - SASSERT(!has_sub(s)); - SASSERT(cut != null_cut); - // split all slices in the equivalence class - for (enode* n : euf::enode_class(s)) - split_core(n, cut); - // propagate equivalences to subslices - for (enode* n : euf::enode_class(s)) { - enode* target = n->get_target(); - if (!target) - continue; - euf::justification const j = n->get_justification(); - SASSERT(j.is_external()); // cannot be a congruence since the slice wasn't split before. - dep_t const d = dep_t::decode(j.ext()); - egraph_merge(sub_hi(n), sub_hi(target), d); - egraph_merge(sub_lo(n), sub_lo(target), d); - } - } - - void slicing::mk_slice(enode* src, unsigned const hi, unsigned const lo, enode_vector& out, bool output_full_src, bool output_base) { - SASSERT(hi >= lo); - SASSERT(width(src) > hi); // extracted range must be fully contained inside the src slice - auto output_slice = [this, output_base, &out](enode* s) { - if (output_base) - get_base(s, out); - else - out.push_back(s); - }; - if (lo == 0 && width(src) - 1 == hi) { - output_slice(src); - return; - } - if (has_sub(src)) { - // src is split into [src.width-1, cut+1] and [cut, 0] - unsigned const cut = info(src).cut; - if (lo >= cut + 1) { - // target slice falls into upper subslice - mk_slice(sub_hi(src), hi - cut - 1, lo - cut - 1, out, output_full_src, output_base); - if (output_full_src) - output_slice(sub_lo(src)); - return; - } - else if (cut >= hi) { - // target slice falls into lower subslice - if (output_full_src) - output_slice(sub_hi(src)); - mk_slice(sub_lo(src), hi, lo, out, output_full_src, output_base); - return; - } - else { - SASSERT(hi > cut && cut >= lo); - // desired range spans over the cutpoint, so we get multiple slices in the result - mk_slice(sub_hi(src), hi - cut - 1, 0, out, output_full_src, output_base); - mk_slice(sub_lo(src), cut, lo, out, output_full_src, output_base); - return; - } - } - else { - // [src.width-1, 0] has no subdivision yet - if (width(src) - 1 > hi) { - split(src, hi); - SASSERT(!has_sub(sub_hi(src))); - if (output_full_src) - out.push_back(sub_hi(src)); - mk_slice(sub_lo(src), hi, lo, out, output_full_src, output_base); // recursive call to take care of case lo > 0 - return; - } - else { - SASSERT(lo > 0); - split(src, lo - 1); - out.push_back(sub_hi(src)); - SASSERT(!has_sub(sub_lo(src))); - if (output_full_src) - out.push_back(sub_lo(src)); - return; - } - } - UNREACHABLE(); - } - - slicing::enode* slicing::mk_value_slice(rational const& val, unsigned bit_width) { - SASSERT(bit_width > 0); - SASSERT(0 <= val && val < rational::power_of_two(bit_width)); - sort* bv_sort = m_bv->mk_sort(bit_width); - func_decl_ref f(m_ast.mk_fresh_func_decl("val", nullptr, 1, &bv_sort, bv_sort, false), m_ast); - app_ref a(m_ast.mk_app(f, m_bv->mk_numeral(val, bit_width)), m_ast); - enode* s = alloc_enode(a, 0, nullptr, null_var); - set_value_node(s, s); - SASSERT_EQ(get_value(s), val); - return s; - } - - slicing::enode* slicing::mk_interpreted_value_node(enode* s) { - SASSERT(is_value(s)); - // NOTE: how this is used now, the node will not yet be contained in the egraph. - enode* n = alloc_enode(s->get_app()->get_arg(0), 0, nullptr, null_var); - info(n).value_node = s; - n->mark_interpreted(); - SASSERT(n->interpreted()); - SASSERT_EQ(get_value_node(n), s); - return n; - } - - bool slicing::is_value(enode* n) const { - SASSERT(n); - SASSERT(is_app(n->get_expr())); // we only create app nodes at the moment; if this fails just return false here. - app* a = n->get_app(); - return a->get_num_args() == 1 && m_bv->is_numeral(a->get_arg(0)); - } - - rational slicing::get_value(enode* s) const { - SASSERT(is_value(s)); - rational val; - VERIFY(try_get_value(s, val)); - return val; - } - - bool slicing::try_get_value(enode* s, rational& val) const { - if (!s) - return false; - app* a = s->get_app(); - if (a->get_num_args() != 1) - return false; - bool const ok = m_bv->is_numeral(a->get_arg(0), val); - SASSERT_EQ(ok, is_value(s)); - return ok; - } - - void slicing::explain_class(enode* x, enode* y, ptr_vector& out_deps) { - SASSERT_EQ(x->get_root(), y->get_root()); - m_egraph.begin_explain(); - m_egraph.explain_eq(out_deps, nullptr, x, y); - m_egraph.end_explain(); - } - - void slicing::explain_equal(enode* x, enode* y, ptr_vector& out_deps) { - SASSERT(is_equal(x, y)); - SASSERT_EQ(width(x), width(y)); - enode_vector& xs = m_tmp2; - enode_vector& ys = m_tmp3; - SASSERT(xs.empty()); - SASSERT(ys.empty()); - xs.push_back(x); - ys.push_back(y); - while (!xs.empty()) { - SASSERT(!ys.empty()); - enode* const x = xs.back(); xs.pop_back(); - enode* const y = ys.back(); ys.pop_back(); - if (x == y) - continue; - if (width(x) == width(y)) { - enode* const rx = x->get_root(); - enode* const ry = y->get_root(); - if (rx == ry) - explain_class(x, y, out_deps); - else { - xs.push_back(sub_hi(x)); - xs.push_back(sub_lo(x)); - ys.push_back(sub_hi(y)); - ys.push_back(sub_lo(y)); - } - } - else if (width(x) > width(y)) { - xs.push_back(sub_hi(x)); - xs.push_back(sub_lo(x)); - ys.push_back(y); - } - else { - SASSERT(width(x) < width(y)); - xs.push_back(x); - ys.push_back(sub_hi(y)); - ys.push_back(sub_lo(y)); - } - } - SASSERT(ys.empty()); - } - - void slicing::explain_equal(pvar x, pvar y, ptr_vector& out_deps) { - explain_equal(var2slice(x), var2slice(y), out_deps); - } - - void slicing::explain_equal(pvar x, pvar y, std::function const& on_lit) { - SASSERT(m_marked_lits.empty()); - SASSERT(m_tmp_deps.empty()); - explain_equal(x, y, m_tmp_deps); - for (void* dp : m_tmp_deps) { - dep_t const d = dep_t::decode(dp); - if (d.is_null()) - continue; - if (d.is_lit()) { - sat::literal lit = d.lit(); - if (m_marked_lits.contains(lit)) - continue; - m_marked_lits.insert(lit); - on_lit(d.lit()); - } - else { - // equivalence between to variables cannot be due to value assignment - UNREACHABLE(); - } - } - m_marked_lits.reset(); - m_tmp_deps.reset(); - } - - void slicing::explain(ptr_vector& out_deps) { - SASSERT(is_conflict()); - m_egraph.begin_explain(); - if (m_disequality_conflict) { - LOG("Disequality conflict: " << m_disequality_conflict); - enode* eqn = m_disequality_conflict; - SASSERT(eqn->is_equality()); - SASSERT_EQ(eqn->value(), l_false); - SASSERT(eqn->get_lit_justification().is_external()); - SASSERT(m_ast.is_eq(eqn->get_expr())); - SASSERT_EQ(eqn->get_arg(0)->get_root(), eqn->get_arg(1)->get_root()); - m_egraph.explain_eq(out_deps, nullptr, eqn->get_arg(0), eqn->get_arg(1)); - out_deps.push_back(eqn->get_lit_justification().ext()); - } - else { - SASSERT(m_egraph.inconsistent()); - m_egraph.explain(out_deps, nullptr); - } - m_egraph.end_explain(); - } - - clause_ref slicing::build_conflict_clause() { - LOG_H1("slicing: build_conflict_clause"); - // display_tree(verbose_stream()); - - SASSERT(invariant()); - SASSERT(is_conflict()); - SASSERT(m_marked_lits.empty()); - SASSERT(m_tmp_deps.empty()); - explain(m_tmp_deps); - clause_builder cb(m_solver, "slicing"); - - auto add_premise = [this, &cb](sat::literal lit) { - LOG("Premise: " << lit_pp(m_solver, lit)); - cb.insert(~lit); - }; - - auto add_conclusion = [this, &cb](signed_constraint c) { - LOG("Conclusion: " << lit_pp(m_solver, c)); - cb.insert_eval(c); - }; - - pvar x = null_var; enode* sx = nullptr; sat::literal xlit = sat::null_literal; - pvar y = null_var; enode* sy = nullptr; sat::literal ylit = sat::null_literal; - for (void* dp : m_tmp_deps) { - dep_t const d = dep_t::decode(dp); - // LOG("dep: " << dep_pp(*this, d)); - if (d.is_null()) - continue; - if (d.is_lit()) { - sat::literal const lit = d.lit(); - if (m_marked_lits.contains(lit)) - continue; - m_marked_lits.insert(lit); - add_premise(lit); - } - else { - SASSERT(d.is_value()); - pvar const v = get_dep_var(d); - enode* const sv = get_dep_slice(d); - sat::literal const lit = get_dep_lit(d); - if (x == null_var) - x = v, sx = sv, xlit = lit; - else if (y == null_var) - y = v, sy = sv, ylit = lit; - else { - // pvar justifications are only introduced by add_value, i.e., when a variable is assigned in the solver. - // thus there can be at most two pvar justifications in a single conflict. - UNREACHABLE(); - } - } - } - m_marked_lits.reset(); - m_tmp_deps.reset(); - - if (x != null_var && y != null_var && xlit == sat::null_literal && ylit != sat::null_literal) { - using std::swap; - swap(x, y); - swap(sx, sy); - swap(xlit, ylit); - } - - if (x != null_var) { - LOG("Variable v" << x << " with slice " << slice_pp(*this, sx) << " by literal " << lit_pp(m_solver, xlit)); - if (m_solver.is_assigned(x)) - LOG("solver-value " << assignment_pp(m_solver, x, m_solver.get_value(x))); - } - if (y != null_var) { - LOG("Variable v" << y << " with slice " << slice_pp(*this, sy) << " by literal " << lit_pp(m_solver, ylit)); - if (m_solver.is_assigned(y)) - LOG("solver-value " << assignment_pp(m_solver, y, m_solver.get_value(y))); - } - - // conflict has either 0 or 2 vars - VERIFY(x != null_var || y == null_var); - VERIFY(y != null_var || x == null_var); - - if (xlit != sat::null_literal && ylit != sat::null_literal) { - verbose_stream() << "build_conflict_clause (2)" << std::endl; - add_premise(xlit); - add_premise(ylit); - } - else if (xlit != sat::null_literal && ylit == sat::null_literal) { - verbose_stream() << "build_conflict_clause (1)" << std::endl; - add_premise(xlit); - - // rational const x_slice_value = get_value(get_value_node(var2slice(x))); - // LOG("v" << x << " slice_value: " << x_slice_value); - rational const y_slice_value = get_value(get_value_node(var2slice(y))); - LOG("v" << y << " slice_value: " << y_slice_value); - // SASSERT(x_slice_value != y_slice_value); - add_conclusion(~m_solver.eq(m_solver.var(y), y_slice_value)); - -/* - unsigned x_hi, x_lo; - VERIFY(find_range_in_ancestor(sx, var2slice(x), x_hi, x_lo)); - pvar const xx = mk_extract(x, x_hi, x_lo); - LOG("find_range_in_ancestor: v" << x << "[" << x_hi << ":" << x_lo << "] = " << slice_pp(*this, sx) << " --> represented by variable v" << xx); - unsigned y_hi, y_lo; - VERIFY(find_range_in_ancestor(sy, var2slice(y), y_hi, y_lo)); - pvar const yy = mk_extract(y, y_hi, y_lo); - LOG("find_range_in_ancestor: v" << y << "[" << y_hi << ":" << y_lo << "] = " << slice_pp(*this, sy) << " --> represented by variable v" << yy); - // LOG("v" << x << " has solver-value? " << m_solver.is_assigned(x)); - if (m_solver.is_assigned(x)) LOG("v" << x << " has solver-value " << m_solver.get_value(x)); - // LOG("v" << y << " has solver-value? " << m_solver.is_assigned(y)); - if (m_solver.is_assigned(y)) LOG("v" << y << " has solver-value " << m_solver.get_value(y)); - LOG("v" << x << " is slice " << slice_pp(*this, var2slice(x))); - LOG("v" << y << " is slice " << slice_pp(*this, var2slice(y))); - SASSERT_EQ(sy->get_root(), var2slice(yy)->get_root()); - rational const sy_slice_value = get_value(get_value_node(sy)); - // rational const sy_solver_value = mod2k(machine_div2k(m_solver.get_value(y), lo), hi - lo + 1); - // c = m_solver.eq(m_solver.var(yy), sy_slice_value); -*/ - } - else { - verbose_stream() << "build_conflict_clause (0)" << std::endl; - SASSERT(xlit == sat::null_literal); - SASSERT(ylit == sat::null_literal); - - // unsigned x_hi, x_lo, y_hi, y_lo; - // VERIFY(find_range_in_ancestor(sx, var2slice(x), x_hi, x_lo)); - // VERIFY(find_range_in_ancestor(sy, var2slice(y), y_hi, y_lo)); - // pvar const xx = mk_extract(x, x_hi, x_lo); - // pvar const yy = mk_extract(y, y_hi, y_lo); - // SASSERT_EQ(sx->get_root(), var2slice(xx)->get_root()); - // SASSERT_EQ(sy->get_root(), var2slice(yy)->get_root()); - // rational sval = mod2k(machine_div2k(m_solver.get_value(x), x_lo), x_hi - x_lo + 1); - // LOG("find_range_in_ancestor: v" << x << "[" << x_hi << ":" << x_lo << "] = " << slice_pp(*this, sx) << " --> represented by variable v" << xx); - // LOG("find_range_in_ancestor: v" << y << "[" << y_hi << ":" << y_lo << "] = " << slice_pp(*this, sy) << " --> represented by variable v" << yy); - LOG("v" << x << " is slice " << slice_pp(*this, var2slice(x))); - LOG("v" << y << " is slice " << slice_pp(*this, var2slice(y))); - if (m_solver.is_assigned(x)) LOG("v" << x << " has solver-value " << m_solver.get_value(x)); - if (m_solver.is_assigned(y)) LOG("v" << y << " has solver-value " << m_solver.get_value(y)); - // SASSERT(xx != yy); - // c = m_solver.eq(m_solver.var(xx), m_solver.var(yy)); // similar to what Algorithm 1 in BitvectorsMCSAT is doing - // LOG("c: " << lit_pp(m_solver, c)); - - rational const x_slice_value = get_value(get_value_node(var2slice(x))); - LOG("v" << x << " slice-value: " << x_slice_value); - add_conclusion(~m_solver.eq(m_solver.var(x), x_slice_value)); - - rational const y_slice_value = get_value(get_value_node(var2slice(y))); - LOG("v" << y << " slice-value: " << y_slice_value); - add_conclusion(~m_solver.eq(m_solver.var(y), y_slice_value)); - } - - // TODO: we don't need clauses like this ... rather set up the conflict core from it - - return cb.build(); - } - - void slicing::explain_value(enode* s, std::function const& on_lit, std::function const& on_var) { - SASSERT(invariant()); - SASSERT(m_marked_lits.empty()); - - enode* n = get_value_node(s); - SASSERT(is_value(n)); - - SASSERT(m_tmp_deps.empty()); - explain_equal(s, n, m_tmp_deps); - - for (void* dp : m_tmp_deps) { - dep_t const d = dep_t::decode(dp); - if (d.is_null()) - continue; - if (d.is_lit()) { - sat::literal const lit = d.lit(); - if (!m_marked_lits.contains(lit)) { - on_lit(lit); - m_marked_lits.insert(lit); - } - } - else { - SASSERT(d.is_value()); - sat::literal const lit = get_dep_lit(d); - if (lit == sat::null_literal) - on_var(get_dep_var(d)); - else if (!m_marked_lits.contains(lit)) { - on_lit(lit); - m_marked_lits.insert(lit); - } - } - } - m_tmp_deps.reset(); - m_marked_lits.reset(); - } - - void slicing::explain_value(pvar v, std::function const& on_lit, std::function const& on_var) { - explain_value(var2slice(v), on_lit, on_var); - } - - bool slicing::find_range_in_ancestor(enode* s, enode* a, unsigned& out_hi, unsigned& out_lo) { - out_hi = width(s) - 1; - out_lo = 0; - while (true) { - if (s == a) - return true; - enode* p = parent(s); - if (!p) - return false; - if (s == sub_hi(p)) { - unsigned offset = 1 + info(p).cut; - out_hi += offset; - out_lo += offset; - } - else { - SASSERT_EQ(s, sub_lo(p)); - /* range stays unchanged */ - } - s = p; - } - } - - bool slicing::is_extract(pvar x, pvar src, unsigned& out_hi, unsigned& out_lo) { - return find_range_in_ancestor(var2slice(x), var2slice(src), out_hi, out_lo); - } - - void slicing::egraph_on_merge(enode* root, enode* other) { - LOG("on_merge: root " << slice_pp(*this, root) << " other " << slice_pp(*this, other)); - if (root->interpreted()) - return; - if (root->is_equality()) { - SASSERT(other->is_equality()); - return; - } - SASSERT(!other->interpreted()); // by convention, interpreted nodes are always chosen as root - SASSERT(root != other); - SASSERT_EQ(root, root->get_root()); - SASSERT_EQ(root, other->get_root()); - - enode* const v1 = info(root).value_node; // root is the root - enode* const v2 = info(other).value_node; // 'other' was its own root before the merge - if (v1 && v2 && get_value(v1) != get_value(v2)) { - // we have a conflict. add interpreted enodes to make the egraph realize it. - enode* const i1 = mk_interpreted_value_node(v1); - enode* const i2 = mk_interpreted_value_node(v2); - m_egraph.merge(i1, v1, dep_t().encode()); - m_egraph.merge(i2, v2, dep_t().encode()); - SASSERT(is_conflict()); - return; - } - - enode* const v = v1 ? v1 : v2; - if (v && !(v1 && v2)) { - // exactly one node has a value - rational const val = get_value(v); - for (enode* n : euf::enode_class(other)) { - enode* const vn = get_value_node(n); - if (!vn) - set_value_node(n, v); - - pvar const var = slice2var(n); - if (var == null_var) - continue; - if (m_solver.is_assigned(var)) - continue; - LOG("on_merge: v" << var << " := " << val); - m_solver.assign_propagate_by_slicing(var, val); - } - } - } - - void slicing::set_value_node(enode* s, enode* value_node) { - SASSERT(!info(s).value_node); - SASSERT(is_value(value_node)); - info(s).value_node = value_node; - if (s != value_node) { - m_trail.push_back(trail_item::set_value_node); - m_enode_trail.push_back(s); - } - } - - void slicing::undo_set_value_node() { - enode* s = m_enode_trail.back(); - m_enode_trail.pop_back(); - info(s).value_node = nullptr; - } - - void slicing::egraph_on_propagate(enode* lit, enode* ante) { - // ante may be set when symmetric equality is added by congruence - if (ante) - return; - // on_propagate may be called before set_value - if (lit->value() == l_undef) - return; - SASSERT(lit->is_equality()); - SASSERT_EQ(lit->value(), l_false); - SASSERT(lit->get_lit_justification().is_external()); - m_disequality_conflict = lit; - } - - bool slicing::can_propagate() const { - if (use_var_congruences() && !m_needs_congruence.empty()) - return true; - return m_egraph.can_propagate(); - } - - void slicing::propagate() { - // m_egraph.propagate(); - if (is_conflict()) - return; - update_var_congruences(); - m_egraph.propagate(); - } - - bool slicing::egraph_merge(enode* s1, enode* s2, dep_t dep) { - LOG("egraph_merge: " << slice_pp(*this, s1) << " and " << slice_pp(*this, s2) << " by " << dep_pp(*this, dep)); - SASSERT_EQ(width(s1), width(s2)); - if (dep.is_value()) { - if (is_value(s1)) - std::swap(s1, s2); - SASSERT(is_value(s2)); - SASSERT(!is_value(s1)); // we never merge two value slices directly - if (get_dep_slice(dep) != s1) - dep = mk_var_dep(get_dep_var(dep), s1, get_dep_lit(dep)); - } - m_egraph.merge(s1, s2, dep.encode()); - return !is_conflict(); - } - - bool slicing::merge_base(enode* s1, enode* s2, dep_t dep) { - SASSERT(!has_sub(s1)); - SASSERT(!has_sub(s2)); - return egraph_merge(s1, s2, dep); - } - - bool slicing::merge(enode_vector& xs, enode_vector& ys, dep_t dep) { - while (!xs.empty()) { - SASSERT(!ys.empty()); - enode* const x = xs.back(); - enode* const y = ys.back(); - xs.pop_back(); - ys.pop_back(); - if (x == y) - continue; - if (x->get_root() == y->get_root()) { - DEBUG_CODE({ - // invariant: parents merged => base slices merged - enode_vector const x_base = get_base(x); - enode_vector const y_base = get_base(y); - SASSERT_EQ(x_base.size(), y_base.size()); - for (unsigned i = x_base.size(); i-- > 0; ) { - SASSERT_EQ(x_base[i]->get_root(), y_base[i]->get_root()); - } - }); - continue; - } -#if 0 - if (has_sub(x)) { - get_base(x, xs); - x = xs.back(); - xs.pop_back(); - } - if (has_sub(y)) { - get_base(y, ys); - y = ys.back(); - ys.pop_back(); - } - SASSERT(!has_sub(x)); - SASSERT(!has_sub(y)); - if (width(x) == width(y)) { - if (!merge_base(x, y, dep)) { - xs.clear(); - ys.clear(); - return false; - } - } - else if (width(x) > width(y)) { - // need to split x according to y - mk_slice(x, width(y) - 1, 0, xs, true); - ys.push_back(y); - } - else { - SASSERT(width(y) > width(x)); - // need to split y according to x - mk_slice(y, width(x) - 1, 0, ys, true); - xs.push_back(x); - } -#else - if (width(x) == width(y)) { - // We may merge slices if at least one of them doesn't have a subslice yet, - // because in that case all intermediate cut points will be aligned. - // NOTE: it is necessary to merge intermediate slices for value nodes, to ensure downward propagation of assignments. - bool const should_merge = (!has_sub(x) || !has_sub(y)); - // If either slice has a subdivision, we have to cut the other and advance to subslices - if (has_sub(x) || has_sub(y)) { - if (!has_sub(x)) - split(x, get_cut(y)); - if (!has_sub(y)) - split(y, get_cut(x)); - xs.push_back(sub_hi(x)); - xs.push_back(sub_lo(x)); - ys.push_back(sub_hi(y)); - ys.push_back(sub_lo(y)); - } - // We may only merge intermediate nodes after we're done with splitting (since we currently split the whole equivalence class at once) - if (should_merge) { - if (!egraph_merge(x, y, dep)) { - xs.clear(); - ys.clear(); - return false; - } - } - } - else if (width(x) > width(y)) { - if (!has_sub(x)) - split(x, width(y) - 1); - // split(x, has_sub(y) ? get_cut(y) : (width(y) - 1)); - xs.push_back(sub_hi(x)); - xs.push_back(sub_lo(x)); - ys.push_back(y); - } - else { - SASSERT(width(y) > width(x)); - if (!has_sub(y)) - split(y, width(x) - 1); - // split(y, has_sub(x) ? get_cut(x) : (width(x) - 1)); - ys.push_back(sub_hi(y)); - ys.push_back(sub_lo(y)); - xs.push_back(x); - } -#endif - } - SASSERT(ys.empty()); - return true; - } - - bool slicing::merge(enode_vector& xs, enode* y, dep_t dep) { - enode_vector& ys = m_tmp2; - SASSERT(ys.empty()); - ys.push_back(y); - return merge(xs, ys, dep); // will clear xs and ys - } - - bool slicing::merge(enode* x, enode* y, dep_t dep) { - LOG("merge: " << slice_pp(*this, x) << " and " << slice_pp(*this, y)); - SASSERT_EQ(width(x), width(y)); - if (!has_sub(x) && !has_sub(y)) - return merge_base(x, y, dep); - enode_vector& xs = m_tmp2; - enode_vector& ys = m_tmp3; - SASSERT(xs.empty()); - SASSERT(ys.empty()); - xs.push_back(x); - ys.push_back(y); - return merge(xs, ys, dep); // will clear xs and ys - } - - bool slicing::is_equal(enode* x, enode* y) { - SASSERT_EQ(width(x), width(y)); - x = x->get_root(); - y = y->get_root(); - if (x == y) - return true; - enode_vector& xs = m_tmp2; - enode_vector& ys = m_tmp3; - SASSERT(xs.empty()); - SASSERT(ys.empty()); - on_scope_exit clear_vectors([&xs, &ys](){ - xs.clear(); - ys.clear(); - }); - // TODO: we don't always have to collect the full base if intermediate nodes are already equal - get_base(x, xs); - get_base(y, ys); - if (xs.size() != ys.size()) - return false; - for (unsigned i = xs.size(); i-- > 0; ) - if (xs[i]->get_root() != ys[i]->get_root()) - return false; - return true; - } - - void slicing::get_base(enode* src, enode_vector& out_base) const { - enode_vector& todo = m_tmp1; - SASSERT(todo.empty()); - todo.push_back(src); - while (!todo.empty()) { - enode* s = todo.back(); - todo.pop_back(); - if (!has_sub(s)) - out_base.push_back(s); - else { - todo.push_back(sub_lo(s)); - todo.push_back(sub_hi(s)); - } - } - SASSERT(todo.empty()); - } - - slicing::enode_vector slicing::get_base(enode* src) const { - enode_vector out; - get_base(src, out); - return out; - } - - pvar slicing::mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var) { - LOG("mk_extract: src=" << slice_pp(*this, src) << " hi=" << hi << " lo=" << lo); - enode_vector& slices = m_tmp3; - SASSERT(slices.empty()); - mk_slice(src, hi, lo, slices, false, false); - pvar v = null_var; - // try to re-use variable of an existing slice - if (slices.size() == 1) - v = slice2var(slices[0]); - if (replay_var != null_var && v != replay_var) { - // replayed variable should be 'fresh', unless it was a re-used variable - enode* s = var2slice(replay_var); - SASSERT(s->is_root()); - SASSERT_EQ(s->class_size(), 1); - SASSERT(!has_sub(s)); - SASSERT_EQ(width(s), hi - lo + 1); - v = replay_var; - } - // allocate new variable if we cannot reuse it - if (v == null_var) { - v = m_solver.add_var(hi - lo + 1, pvar_kind::internal); -#if 1 - // slice didn't have a variable yet; so we can re-use it for the new variable - // (we end up with a "phantom" enode that was first created for the variable) - if (slices.size() == 1) { - enode* s = slices[0]; - LOG("re-using slice " << slice_pp(*this, s) << " for new variable v" << v); - // display_tree(std::cerr, s, 0, hi, lo); - SASSERT_EQ(info(s).var, null_var); - info(m_var2slice[v]).var = null_var; // disconnect the "phantom" enode - info(s).var = v; - m_var2slice[v] = s; - } -#endif - } - // connect new variable - VERIFY(merge(slices, var2slice(v), dep_t())); - slices.reset(); - return v; - } - - void slicing::replay_extract(extract_args const& args, pvar r) { - LOG("replay_extract"); - SASSERT(r != null_var); - SASSERT(!m_extract_dedup.contains(args)); - VERIFY_EQ(mk_extract(var2slice(args.src), args.hi, args.lo, r), r); - m_extract_dedup.insert(args, r); - m_extract_trail.push_back(args); - m_trail.push_back(trail_item::mk_extract); - } - - pvar slicing::mk_extract(pvar src, unsigned hi, unsigned lo) { - LOG_H3("mk_extract: v" << src << "[" << hi << ":" << lo << "] size(v" << src << ") = " << m_solver.size(src)); - if (m_solver.size(src) == hi - lo + 1) - return src; - extract_args args{src, hi, lo}; - auto it = m_extract_dedup.find_iterator(args); - if (it != m_extract_dedup.end()) - return it->m_value; - pvar const v = mk_extract(var2slice(src), hi, lo); - m_extract_dedup.insert(args, v); - m_extract_trail.push_back(args); - m_trail.push_back(trail_item::mk_extract); - LOG("mk_extract: v" << src << "[" << hi << ":" << lo << "] = v" << v); - return v; - } - - void slicing::undo_mk_extract() { - extract_args args = m_extract_trail.back(); - m_extract_trail.pop_back(); - m_extract_dedup.remove(args); - } - - pvar slicing::mk_concat(unsigned num_args, pvar const* args, pvar replay_var) { - enode_vector& slices = m_tmp3; - SASSERT(slices.empty()); - unsigned total_width = 0; - for (unsigned i = 0; i < num_args; ++i) { - enode* s = var2slice(args[i]); - slices.push_back(s); - total_width += width(s); - } - // NOTE: we use concat nodes to deduplicate (syntactically equal) concat expressions. - // we might end up reusing variables that are not introduced by mk_concat (if we enable the variable re-use optimization in mk_extract), - // but because such congruence nodes are only added over direct descendants, we do not get unwanted dependencies from this re-use. - // (but note that the nodes from mk_concat are not only over direct descendants) - enode* concat = mk_concat_node(slices); - pvar v = slice2var(concat); - if (v != null_var) - return v; - if (replay_var != null_var) { - // replayed variable should be 'fresh' - enode* s = var2slice(replay_var); - SASSERT(s->is_root()); - SASSERT_EQ(s->class_size(), 1); - SASSERT(!has_sub(s)); - SASSERT_EQ(width(s), total_width); - v = replay_var; - } - else - v = m_solver.add_var(total_width, pvar_kind::internal); - enode* sv = var2slice(v); - VERIFY(merge(slices, sv, dep_t())); - // NOTE: add_concat_node must be done after merge to preserve the invariant: "a base slice is never equivalent to a congruence node". - add_concat_node(sv, concat); - slices.reset(); - - // don't mess with the concat_trail during replay - if (replay_var == null_var) { - concat_info ci; - ci.v = v; - ci.num_args = num_args; - ci.args_idx = m_concat_args.size(); - m_concat_trail.push_back(ci); - for (unsigned i = 0; i < num_args; ++i) - m_concat_args.push_back(args[i]); - } - m_trail.push_back(trail_item::mk_concat); - - return v; - } - - void slicing::replay_concat(unsigned num_args, pvar const* args, pvar r) { - SASSERT(r != null_var); - VERIFY_EQ(mk_concat(num_args, args, r), r); - } - - pvar slicing::mk_concat(std::initializer_list args) { - return mk_concat(args.size(), args.begin()); - } - - void slicing::add_constraint(signed_constraint c) { - LOG(c); - SASSERT(!is_conflict()); - if (!add_fixed_bits(c)) - return; - if (c->is_eq()) - add_constraint_eq(c->to_eq(), c.blit()); - } - - bool slicing::add_fixed_bits(signed_constraint c) { - // TODO: what is missing here: - // - we don't prioritize constraints that set larger bit ranges - // e.g., c1 sets 3 lower bits, and c2 sets 5 lower bits. - // slicing may have both {c1,c2} in justifications while previously we always prefer c2. - // - instead of prioritizing constraints (which is annoying to do incrementally), let subsumption take care of this issue. - // if constraint C subsumes constraint D, then we might even want to completely deactivate D in the solver? (not easy if D is on higher level than C). - // - (we could wait until propagate() to add fixed bits to the egraph. but that would only work on a single decision level.) - if (c->vars().size() != 1) - return true; - fixed_bits fb; - if (!get_fixed_bits(c, fb)) - return true; - pvar const x = c->vars()[0]; - return add_fixed_bits(x, fb.hi, fb.lo, fb.value, c.blit()); - } - - bool slicing::add_fixed_bits(pvar x, unsigned hi, unsigned lo, rational const& value, sat::literal lit) { - LOG("add_fixed_bits: v" << x << "[" << hi << ":" << lo << "] = " << value << " by " << lit_pp(m_solver, lit)); - enode_vector& xs = m_tmp3; - SASSERT(xs.empty()); - mk_slice(var2slice(x), hi, lo, xs, false, false); - enode* const sval = mk_value_slice(value, hi - lo + 1); - // 'xs' will be cleared by 'merge'. - // NOTE: the 'nullptr' argument will be fixed by 'egraph_merge' - return merge(xs, sval, mk_var_dep(x, nullptr, lit)); - } - - bool slicing::add_constraint_eq(pdd const& p, sat::literal lit) { - auto& m = p.manager(); - for (auto& [a, x] : p.linear_monomials()) { - if (a != 1 && a != m.max_value()) - continue; - pdd const body = a.is_one() ? (m.mk_var(x) - p) : (m.mk_var(x) + p); - // c is either x = body or x != body, depending on polarity - if (!add_equation(x, body, lit)) { - SASSERT(is_conflict()); - return false; - } - // without this check, when p = x - y we would handle both x = y and y = x separately - if (body.is_unary()) - break; - } - return true; - } - - // TODO: handle equations 2^k x = 2^k y? (lower n-k bits of x and y are equal) - bool slicing::add_equation(pvar x, pdd const& body, sat::literal lit) { - LOG("Equation from lit(" << lit << "): v" << x << (lit.sign() ? " != " : " = ") << body); - if (!lit.sign() && body.is_val()) { - LOG(" simple assignment"); - // Simple assignment x = value - return add_value(x, body.val(), lit); - } - enode* const sx = var2slice(x); - pvar const y = m_solver.m_names.get_name(body); - if (y == null_var) { - if (!body.is_val()) { - // TODO: register name trigger (if a name for value 'body' is created later, then merge x=y at that time) - // could also count how often 'body' was registered and introduce name when more than once. - // maybe better: register x as an existing name for 'body'? question is how to track the dependency on c. - LOG(" skip for now (unnamed body)"); - } else - LOG(" skip for now (disequality with constant)"); - return true; - } - enode* const sy = var2slice(y); - if (!lit.sign()) { - LOG(" merge v" << x << " and v" << y); - return merge(sx, sy, lit); - } - else { - LOG(" store disequality v" << x << " != v" << y); - enode* n = find_or_alloc_disequality(sx, sy, lit); - if (!m_disequality_conflict && is_equal(sx, sy)) { - add_var_congruence_if_needed(x); - add_var_congruence_if_needed(y); - m_disequality_conflict = n; - return false; - } - } - return true; - } - - bool slicing::add_value(pvar v, rational const& value, sat::literal lit) { - enode* const sv = var2slice(v); - if (get_value_node(sv) && get_value(get_value_node(sv)) == value) - return true; - enode* const sval = mk_value_slice(value, width(sv)); - return merge(sv, sval, mk_var_dep(v, sv, lit)); - } - - void slicing::add_value(pvar v, rational const& value) { - LOG("v" << v << " := " << value); - SASSERT(!is_conflict()); - (void)add_value(v, value, sat::null_literal); - } - - void slicing::collect_simple_overlaps(pvar v, pvar_vector& out) { - unsigned const first_out = out.size(); - enode* const sv = var2slice(v); - unsigned const v_width = width(sv); - enode_vector& v_base = m_tmp2; - SASSERT(v_base.empty()); - get_base(var2slice(v), v_base); - - SASSERT(all_of(m_egraph.nodes(), [](enode* n) { return !n->is_marked1(); })); - - // Collect direct sub-slices of v and their equivalences - // (these don't need any extra checks) - for (enode* s = sv; s != nullptr; s = has_sub(s) ? sub_lo(s) : nullptr) { - for (enode* n : euf::enode_class(s)) { - if (!is_proper_slice(n)) - continue; - pvar const w = slice2var(n); - if (w == null_var) - continue; - SASSERT(!n->is_marked1()); - n->mark1(); - out.push_back(w); - } - } - - // lowermost base slice of v - enode* const v_base_lo = v_base.back(); - - svector> candidates; - // Collect all other candidate variables, - // i.e., those who share the lowermost base slice with v. - for (enode* n : euf::enode_class(v_base_lo)) { - if (!is_proper_slice(n)) - continue; - if (n == v_base_lo) - continue; - enode* const n0 = n; - pvar w2 = null_var; // the highest variable we care about from this equivalence class - // examine parents to find variables - SASSERT(!has_sub(n)); - while (true) { - pvar const w = slice2var(n); - if (w != null_var && !n->is_marked1()) - w2 = w; - enode* p = parent(n); - if (!p) - break; - if (sub_lo(p) != n) // we only care about lowermost slices of variables - break; - if (width(p) > v_width) - break; - n = p; - } - if (w2 != null_var) - candidates.push_back({n0, w2}); - } - - // Check candidates - for (auto const& [n0, w2] : candidates) { - // unsigned v_next = v_base.size(); - auto v_it = v_base.rbegin(); - enode* n = n0; - SASSERT_EQ(n->get_root(), (*v_it)->get_root()); - ++v_it; - while (true) { - // here: base of n is equivalent to lower portion of base of v - pvar const w = slice2var(n); - if (w != null_var && !n->is_marked1()) { - n->mark1(); - out.push_back(w); - } - if (w == w2) - break; - // - enode* const p = parent(n); - SASSERT(p); - SASSERT_EQ(sub_lo(p), n); // otherwise not a candidate - // check if base of sub_hi(p) matches the base of v - enode_vector& p_hi_base = m_tmp3; - SASSERT(p_hi_base.empty()); - get_base(sub_hi(p), p_hi_base); - auto p_it = p_hi_base.rbegin(); - bool ok = true; - while (ok && p_it != p_hi_base.rend()) { - if (v_it == v_base.rend()) - ok = false; - else if ((*p_it)->get_root() != (*v_it)->get_root()) - ok = false; - else { - ++p_it; - ++v_it; - } - } - p_hi_base.reset(); - if (!ok) - break; - n = p; - } - } - - v_base.reset(); - for (unsigned i = out.size(); i-- > first_out; ) { - enode* n = var2slice(out[i]); - SASSERT(n->is_marked1()); - n->unmark1(); - } - SASSERT(all_of(m_egraph.nodes(), [](enode* n) { return !n->is_marked1(); })); - } - - void slicing::explain_simple_overlap(pvar v, pvar x, std::function const& on_lit) { - SASSERT(width(var2slice(x)) <= width(var2slice(v))); - SASSERT(m_marked_lits.empty()); - SASSERT(m_tmp_deps.empty()); - - if (v == x) - return; - - enode_vector& v_base = m_tmp4; - SASSERT(v_base.empty()); - get_base(var2slice(v), v_base); - enode_vector& x_base = m_tmp5; - SASSERT(x_base.empty()); - get_base(var2slice(x), x_base); - - auto v_it = v_base.rbegin(); - auto x_it = x_base.rbegin(); - while (x_it != x_base.rend()) { - SASSERT(v_it != v_base.rend()); - enode* nv = *v_it; ++v_it; - enode* nx = *x_it; ++x_it; - SASSERT_EQ(nv->get_root(), nx->get_root()); - explain_equal(nv, nx, m_tmp_deps); - } - - for (void* dp : m_tmp_deps) { - dep_t const d = dep_t::decode(dp); - if (d.is_null()) - continue; - if (d.is_lit()) { - sat::literal lit = d.lit(); - if (m_marked_lits.contains(lit)) - continue; - m_marked_lits.insert(lit); - on_lit(d.lit()); - } - else { - // equivalence between to variables cannot be due to value assignment - UNREACHABLE(); - } - } - m_marked_lits.reset(); - m_tmp_deps.reset(); - } - - void slicing::collect_fixed(pvar v, justified_fixed_bits_vector& out) { - enode_vector& base = m_tmp2; - SASSERT(base.empty()); - get_base(var2slice(v), base); - rational a; - unsigned lo = 0; - for (auto it = base.rbegin(); it != base.rend(); ++it) { - enode* const n = *it; - enode* const nv = get_value_node(n); - unsigned const w = width(n); - unsigned const hi = lo + w - 1; - if (try_get_value(nv, a)) - out.push_back({hi, lo, a, n}); - lo += w; - } - base.reset(); - } - - void slicing::explain_fixed(euf::enode* const n, std::function const& on_lit, std::function const& on_var) { - explain_value(n, on_lit, on_var); - } - - pvar_vector slicing::equivalent_vars(pvar v) const { - pvar_vector xs; - for (enode* n : euf::enode_class(var2slice(v))) { - pvar const x = slice2var(n); - if (x != null_var) - xs.push_back(x); - } - return xs; - } - - std::ostream& slicing::display(std::ostream& out) const { - enode_vector base; - for (pvar v = 0; v < m_var2slice.size(); ++v) { - out << "v" << v << ":"; - base.reset(); - enode* const vs = var2slice(v); - get_base(vs, base); - for (enode* s : base) - display(out << " ", s); - if (enode* vnode = get_value_node(vs)) - out << " [root_value: " << get_value(vnode) << "]"; - out << "\n"; - } - return out; - } - - std::ostream& slicing::display_tree(std::ostream& out) const { - for (pvar v = 0; v < m_var2slice.size(); ++v) { - out << "v" << v << ":\n"; - enode* const s = var2slice(v); - display_tree(out, s, 4, width(s) - 1, 0); - } - out << m_egraph << "\n"; - return out; - } - - std::ostream& slicing::display_tree(std::ostream& out, enode* s, unsigned indent, unsigned hi, unsigned lo) const { - out << std::string(indent, ' ') << "[" << hi << ":" << lo << "]"; - out << " id=" << s->get_id(); - out << " w=" << width(s); - if (slice2var(s) != null_var) - out << " var=v" << slice2var(s); - if (parent(s)) - out << " parent=" << parent(s)->get_id(); - if (!s->is_root()) - out << " root=" << s->get_root_id(); - if (enode* n = get_value_node(s)) - out << " value=" << get_value(n); - // if (is_proper_slice(s)) - // out << " "; - if (is_value(s)) - out << " "; - if (is_concat(s)) - out << " "; - if (is_equality(s)) - out << " "; - out << "\n"; - if (has_sub(s)) { - unsigned cut = info(s).cut; - display_tree(out, sub_hi(s), indent + 4, hi, cut + 1 + lo); - display_tree(out, sub_lo(s), indent + 4, cut + lo, lo); - } - return out; - } - - std::ostream& slicing::display(std::ostream& out, enode* s) const { - out << "{id:" << s->get_id(); - if (is_equality(s)) - return out << ",}"; - out << ",w:" << width(s); - out << ",root:" << s->get_root_id(); - if (slice2var(s) != null_var) - out << ",var:v" << slice2var(s); - if (enode* n = get_value_node(s)) - out << ",value:" << get_value(n); - if (s->interpreted()) - out << ","; - if (is_concat(s)) - out << ","; - out << "}"; - return out; - } - - bool slicing::invariant() const { - VERIFY(m_tmp1.empty()); - VERIFY(m_tmp2.empty()); - VERIFY(m_tmp3.empty()); - if (is_conflict()) // if we break a loop early on conflict, we can't guarantee that all properties are satisfied - return true; - for (enode* s : m_egraph.nodes()) { - // we use equality enodes only to track disequalities - if (s->is_equality()) - continue; - // if the slice is equivalent to a variable, then the variable's slice is in the equivalence class - pvar const v = slice2var(s); - if (v != null_var) { - VERIFY_EQ(var2slice(v)->get_root(), s->get_root()); - } - // if slice has a value, it should be propagated to its sub-slices - if (get_value_node(s) && has_sub(s)) { - VERIFY(get_value_node(sub_hi(s))); - VERIFY(get_value_node(sub_lo(s))); - } - // a base slice is never equivalent to a congruence node - if (is_slice(s) && !has_sub(s)) { - VERIFY(all_of(euf::enode_class(s), [&](enode* n) { return is_slice(n); })); - } - if (is_concat(s)) { - // all concat nodes point to a variable slice - VERIFY(slice2var(s) != null_var); - enode* sv = var2slice(slice2var(s)); // the corresponding variable slice - VERIFY(s != sv); - VERIFY(is_slice(sv)); - VERIFY(s->num_args() >= 2); - } - ///////////////////////////////////////////////////////////////// - // class properties (i.e., skip non-representatives) - if (!s->is_root()) - continue; - bool const sub = has_sub(s); - enode_vector const s_base = get_base(s); - for (enode* n : euf::enode_class(s)) { - // equivalence class only contains slices of equal length - VERIFY_EQ(width(s), width(n)); - // either all nodes in the class have subslices or none do - SASSERT_EQ(sub, has_sub(n)); - // bases of equivalent nodes are equivalent - enode_vector const n_base = get_base(n); - VERIFY_EQ(s_base.size(), n_base.size()); - for (unsigned i = s_base.size(); i-- > 0; ) { - VERIFY_EQ(s_base[i]->get_root(), n_base[i]->get_root()); - } - } - } - return true; - } - -} diff --git a/src/sat/smt/polysat/slicing.h b/src/sat/smt/polysat/slicing.h deleted file mode 100644 index f9f90610b..000000000 --- a/src/sat/smt/polysat/slicing.h +++ /dev/null @@ -1,397 +0,0 @@ -/*++ -Copyright (c) 2023 Microsoft Corporation - -Module Name: - - polysat slicing (relating variables of different bit-widths by extraction) - -Author: - - Jakob Rath 2023-06-01 - -Notation: - - Let x be a bit-vector of width w. - Let l, h indices such that 0 <= l <= h < w. - Then x[h:l] extracts h - l + 1 bits of x. - Shorthands: - - x[h:] stands for x[h:0], and - - x[:l] stands for x[w-1:l]. - - Example: - 0001[0:] = 1 - 0001[2:0] = 001 - ---*/ -#pragma once -#include "ast/euf/euf_egraph.h" -#include "ast/bv_decl_plugin.h" -#include "math/polysat/types.h" -#include "math/polysat/constraint.h" -#include "math/polysat/fixed_bits.h" -#include - -namespace polysat { - - class solver; - - class slicing final { - - friend class test_slicing; - - public: - using enode = euf::enode; - using enode_vector = euf::enode_vector; - using enode_pair = euf::enode_pair; - using enode_pair_vector = euf::enode_pair_vector; - - private: - class dep_t { - std::variant m_data; - public: - dep_t() { SASSERT(is_null()); } - dep_t(sat::literal l): m_data(l) { SASSERT(l != sat::null_literal); SASSERT_EQ(l, lit()); } - explicit dep_t(unsigned idx): m_data(idx) { SASSERT_EQ(idx, value_idx()); } - bool is_null() const { return std::holds_alternative(m_data); } - bool is_lit() const { return std::holds_alternative(m_data); } - bool is_value() const { return std::holds_alternative(m_data); } - sat::literal lit() const { SASSERT(is_lit()); return *std::get_if(&m_data); } - unsigned value_idx() const { SASSERT(is_value()); return *std::get_if(&m_data); } - bool operator==(dep_t other) const { return m_data == other.m_data; } - bool operator!=(dep_t other) const { return !operator==(other); } - void* encode() const; - static dep_t decode(void* p); - }; - - using dep_vector = svector; - - std::ostream& display(std::ostream& out, dep_t d) const; - - dep_t mk_var_dep(pvar v, enode* s, sat::literal lit); - - pvar_vector m_dep_var; - ptr_vector m_dep_slice; - sat::literal_vector m_dep_lit; // optional, value assignment comes from a literal "x == val" rather than a solver assignment - unsigned_vector m_dep_size_trail; - - pvar get_dep_var(dep_t d) const { return m_dep_var[d.value_idx()]; } - sat::literal get_dep_lit(dep_t d) const { return m_dep_lit[d.value_idx()]; } - enode* get_dep_slice(dep_t d) const { return m_dep_slice[d.value_idx()]; } - - static constexpr unsigned null_cut = std::numeric_limits::max(); - - // We use the following kinds of enodes: - // - proper slices (of variables) - // - value slices - // - interpreted value nodes ... these are short-lived, and only created to immediately trigger a conflict inside the egraph - // - virtual concat(...) expressions - // - equalities between enodes (to track disequalities; currently not represented in slice_info) - struct slice_info { - // Cut point: if not null_cut, the slice s has been subdivided into s[|s|-1:cut+1] and s[cut:0]. - // The cut point is relative to the parent slice (rather than a root variable, which might not be unique) - unsigned cut = null_cut; // cut point, or null_cut if no subslices - pvar var = null_var; // slice is equivalent to this variable, if any (without dependencies) - enode* parent = nullptr; // parent slice, only for proper slices (if not null: s == sub_hi(parent(s)) || s == sub_lo(parent(s))) - enode* slice = nullptr; // if enode corresponds to a concat(...) expression, this field links to the represented slice. - enode* sub_hi = nullptr; // upper subslice s[|s|-1:cut+1] - enode* sub_lo = nullptr; // lower subslice s[cut:0] - enode* value_node = nullptr; // the root of an equivalence class stores the value slice here, if any - - void reset() { *this = slice_info(); } - bool has_sub() const { return !!sub_hi; } - void set_cut(unsigned cut, enode* sub_hi, enode* sub_lo) { this->cut = cut; this->sub_hi = sub_hi; this->sub_lo = sub_lo; } - }; - using slice_info_vector = svector; - - // Return true iff n is either a proper slice or a value slice - bool is_slice(enode* n) const; - - bool is_proper_slice(enode* n) const { return !is_value(n) && is_slice(n); } - bool is_value(enode* n) const; - bool is_concat(enode* n) const; - bool is_equality(enode* n) const { return n->is_equality(); } - - solver& m_solver; - - ast_manager m_ast; - scoped_ptr m_bv; - - euf::egraph m_egraph; - slice_info_vector m_info; // indexed by enode::get_id() - enode_vector m_var2slice; // pvar -> slice - tracked_uint_set m_needs_congruence; // set of pvars that need updated concat(...) expressions - enode* m_disequality_conflict = nullptr; - - // Add an equation v = concat(s1, ..., sn) - // for each variable v with base slices s1, ..., sn - void update_var_congruences(); - void add_var_congruence(pvar v); - void add_var_congruence_if_needed(pvar v); - bool use_var_congruences() const; - - func_decl* mk_concat_decl(ptr_vector const& args); - enode* mk_concat_node(enode_vector const& slices); - enode* mk_concat_node(std::initializer_list slices) { return mk_concat_node(slices.size(), slices.begin()); } - enode* mk_concat_node(unsigned num_slices, enode* const* slices); - // Add s = concat(s1, ..., sn) - void add_concat_node(enode* s, enode* concat); - - slice_info& info(euf::enode* n); - slice_info const& info(euf::enode* n) const; - - enode* alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var); - enode* find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var); - enode* alloc_slice(unsigned width, pvar var = null_var); - enode* find_or_alloc_disequality(enode* x, enode* y, sat::literal lit); - - // Find hi, lo such that s = a[hi:lo] - bool find_range_in_ancestor(enode* s, enode* a, unsigned& out_hi, unsigned& out_lo); - - enode* var2slice(pvar v) const { return m_var2slice[v]; } - pvar slice2var(enode* s) const { return info(s).var; } - - unsigned width(enode* s) const; - - enode* parent(enode* s) const { return info(s).parent; } - - enode* get_value_node(enode* s) const { return info(s).value_node; } - void set_value_node(enode* s, enode* value_node); - - unsigned get_cut(enode* s) const { return info(s).cut; } - - bool has_sub(enode* s) const { return info(s).has_sub(); } - - /// Upper subslice (direct child, not necessarily the representative) - enode* sub_hi(enode* s) const { return info(s).sub_hi; } - - /// Lower subslice (direct child, not necessarily the representative) - enode* sub_lo(enode* s) const { return info(s).sub_lo; } - - /// sub_lo(parent(s)) or sub_hi(parent(s)), whichever is different from s. - enode* sibling(enode* s) const; - - // Retrieve (or create) a slice representing the given value. - enode* mk_value_slice(rational const& val, unsigned bit_width); - - // Turn value node into unwrapped BV constant node - enode* mk_interpreted_value_node(enode* value_slice); - - rational get_value(enode* s) const; - bool try_get_value(enode* s, rational& val) const; - - /// Split slice s into s[|s|-1:cut+1] and s[cut:0] - void split(enode* s, unsigned cut); - void split_core(enode* s, unsigned cut); - - /// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... ++ s_n (actual descendant subslices) - void get_base(enode* src, enode_vector& out_base) const; - enode_vector get_base(enode* src) const; - - /// Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n. - /// If output_full_src is true, return the new base for src, i.e., src == s_1 ++ ... ++ s_n. - /// If output_base is false, return coarsest intermediate slices instead of only base slices. - void mk_slice(enode* src, unsigned hi, unsigned lo, enode_vector& out, bool output_full_src = false, bool output_base = true); - - // Extract reason why slices x and y are in the same equivalence class - void explain_class(enode* x, enode* y, ptr_vector& out_deps); - - // Extract reason why slices x and y are equal - // (i.e., x and y have the same base, but are not necessarily in the same equivalence class) - void explain_equal(enode* x, enode* y, ptr_vector& out_deps); - - /** Explain why slice is equivalent to a value */ - void explain_value(enode* s, std::function const& on_lit, std::function const& on_var); - - /** Extract reason for conflict */ - void explain(ptr_vector& out_deps); - - /** Extract reason for x == y */ - void explain_equal(pvar x, pvar y, ptr_vector& out_deps); - - void egraph_on_make(enode* n); - void egraph_on_merge(enode* root, enode* other); - void egraph_on_propagate(enode* lit, enode* ante); - - // Merge slices in the e-graph. - bool egraph_merge(enode* s1, enode* s2, dep_t dep); - - // Merge equivalence classes of two base slices. - // Returns true if merge succeeded without conflict. - [[nodiscard]] bool merge_base(enode* s1, enode* s2, dep_t dep); - - // Merge equality x_1 ++ ... ++ x_n == y_1 ++ ... ++ y_k - // - // Precondition: - // - sequence of slices with equal total width - // - ordered from msb to lsb - // - // The argument vectors will be cleared. - // - // Returns true if merge succeeded without conflict. - [[nodiscard]] bool merge(enode_vector& xs, enode_vector& ys, dep_t dep); - [[nodiscard]] bool merge(enode_vector& xs, enode* y, dep_t dep); - [[nodiscard]] bool merge(enode* x, enode* y, dep_t dep); - - // Check whether two slices are known to be equal - bool is_equal(enode* x, enode* y); - - // deduplication of extract terms - struct extract_args { - pvar src = null_var; - unsigned hi = 0; - unsigned lo = 0; - bool operator==(extract_args const& other) const { return src == other.src && hi == other.hi && lo == other.lo; } - unsigned hash() const { return mk_mix(src, hi, lo); } - }; - using extract_args_eq = default_eq; - using extract_args_hash = obj_hash; - using extract_map = map; - extract_map m_extract_dedup; - // svector m_extract_origin; // pvar -> extract_args - // TODO: add 'm_extract_origin' (pvar -> extract_args)? 1. for dependency tracking when sharing subslice trees; 2. for easily checking if a variable is an extraction of another; 3. also makes the replay easier - // bool is_extract(pvar v) const { return m_extract_origin[v].src != null_var; } - - enum class trail_item : std::uint8_t { - add_var, - split_core, - mk_extract, - mk_concat, - set_value_node, - }; - svector m_trail; - enode_vector m_enode_trail; - svector m_extract_trail; - unsigned_vector m_scopes; - - struct concat_info { - pvar v; - unsigned num_args; - unsigned args_idx; - unsigned next_args_idx() const { return args_idx + num_args; } - }; - svector m_concat_trail; - svector m_concat_args; - - void undo_add_var(); - void undo_split_core(); - void undo_mk_extract(); - void undo_set_value_node(); - - mutable enode_vector m_tmp1; - mutable enode_vector m_tmp2; - mutable enode_vector m_tmp3; - mutable enode_vector m_tmp4; - mutable enode_vector m_tmp5; - ptr_vector m_tmp_deps; - sat::literal_set m_marked_lits; - - /** Get variable representing src[hi:lo] */ - pvar mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var = null_var); - /** Restore r = src[hi:lo] */ - void replay_extract(extract_args const& args, pvar r); - - pvar mk_concat(unsigned num_args, pvar const* args, pvar replay_var); - void replay_concat(unsigned num_args, pvar const* args, pvar r); - - bool add_constraint_eq(pdd const& p, sat::literal lit); - bool add_equation(pvar x, pdd const& body, sat::literal lit); - bool add_value(pvar v, rational const& value, sat::literal lit); - bool add_fixed_bits(signed_constraint c); - bool add_fixed_bits(pvar x, unsigned hi, unsigned lo, rational const& value, sat::literal lit); - - bool invariant() const; - bool invariant_needs_congruence() const; - - std::ostream& display(std::ostream& out, enode* s) const; - std::ostream& display_tree(std::ostream& out, enode* s, unsigned indent, unsigned hi, unsigned lo) const; - - class slice_pp { - slicing const& s; - enode* n; - public: - slice_pp(slicing const& s, enode* n): s(s), n(n) {} - std::ostream& display(std::ostream& out) const { return s.display(out, n); } - }; - friend std::ostream& operator<<(std::ostream& out, slice_pp const& s) { return s.display(out); } - - class dep_pp { - slicing const& s; - dep_t d; - public: - dep_pp(slicing const& s, dep_t d): s(s), d(d) {} - std::ostream& display(std::ostream& out) const { return s.display(out, d); } - }; - friend std::ostream& operator<<(std::ostream& out, dep_pp const& d) { return d.display(out); } - - euf::egraph::e_pp e_pp(enode* n) const { return euf::egraph::e_pp(m_egraph, n); } - - public: - slicing(solver& s); - - void push_scope(); - void pop_scope(unsigned num_scopes = 1); - - void add_var(unsigned bit_width); - - /** Get or create variable representing x[hi:lo] */ - pvar mk_extract(pvar x, unsigned hi, unsigned lo); - - /** Get or create variable representing x1 ++ x2 ++ ... ++ xn */ - pvar mk_concat(unsigned num_args, pvar const* args) { return mk_concat(num_args, args, null_var); } - pvar mk_concat(std::initializer_list args); - - // Find hi, lo such that x = src[hi:lo]. - bool is_extract(pvar x, pvar src, unsigned& out_hi, unsigned& out_lo); - - // Track value assignments to variables (and propagate to subslices) - void add_value(pvar v, rational const& value); - void add_value(pvar v, unsigned value) { add_value(v, rational(value)); } - void add_value(pvar v, int value) { add_value(v, rational(value)); } - void add_constraint(signed_constraint c); - - bool can_propagate() const; - - // update congruences, egraph - void propagate(); - - bool is_conflict() const { return m_disequality_conflict || m_egraph.inconsistent(); } - - /** Extract conflict clause */ - clause_ref build_conflict_clause(); - - /** Explain why slicing has propagated the value assignment for v */ - void explain_value(pvar v, std::function const& on_lit, std::function const& on_var); - - /** For a given variable v, find the set of variables w such that w = v[|w|:0]. */ - void collect_simple_overlaps(pvar v, pvar_vector& out); - void explain_simple_overlap(pvar v, pvar x, std::function const& on_lit); - - struct justified_fixed_bits : public fixed_bits { - enode* just; - - justified_fixed_bits(unsigned hi, unsigned lo, rational value, enode* just): fixed_bits(hi, lo, value), just(just) {} - }; - - using justified_fixed_bits_vector = vector; - - /** Collect fixed portions of the variable v */ - void collect_fixed(pvar v, justified_fixed_bits_vector& out); - void explain_fixed(enode* just, std::function const& on_lit, std::function const& on_var); - - /** - * Collect variables that are equivalent to v (including v itself) - * - * NOTE: this might miss some variables that are equal due to equivalent base slices. With 'polysat.slicing.congruence=true' and after propagate(), it should return all equal variables. - */ - pvar_vector equivalent_vars(pvar v) const; - - /** Explain why variables x and y are equivalent */ - void explain_equal(pvar x, pvar y, std::function const& on_lit); - - std::ostream& display(std::ostream& out) const; - std::ostream& display_tree(std::ostream& out) const; - }; - - inline std::ostream& operator<<(std::ostream& out, slicing const& s) { return s.display(out); } - -} From b56a8fa264808cdc179f660baf08d52fd3093081 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 10:03:43 -0800 Subject: [PATCH 24/89] deal with build errors Signed-off-by: Nikolaj Bjorner --- src/sat/smt/euf_solver.cpp | 6 +++++- src/sat/smt/polysat/polysat_assignment.cpp | 13 +------------ src/sat/smt/polysat/polysat_assignment.h | 3 --- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 51c0518e5..2d4b9847e 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -21,6 +21,7 @@ Author: #include "sat/smt/sat_smt.h" #include "sat/smt/pb_solver.h" #include "sat/smt/bv_solver.h" +#include "sat/smt/polysat_solver.h" #include "sat/smt/euf_solver.h" #include "sat/smt/array_solver.h" #include "sat/smt/arith_solver.h" @@ -134,8 +135,11 @@ namespace euf { special_relations_util sp(m); if (pb.get_family_id() == fid) ext = alloc(pb::solver, *this, fid); - else if (bvu.get_family_id() == fid) + else if (bvu.get_family_id() == fid) { ext = alloc(bv::solver, *this, fid); + dealloc(ext); + ext = alloc(polysat::solver, *this, fid); + } else if (au.get_family_id() == fid) ext = alloc(array::solver, *this, fid); else if (fpa.get_family_id() == fid) diff --git a/src/sat/smt/polysat/polysat_assignment.cpp b/src/sat/smt/polysat/polysat_assignment.cpp index aedf6d409..329733d89 100644 --- a/src/sat/smt/polysat/polysat_assignment.cpp +++ b/src/sat/smt/polysat/polysat_assignment.cpp @@ -12,6 +12,7 @@ Author: --*/ +#include #include "sat/smt/polysat/polysat_assignment.h" #include "sat/smt/polysat/polysat_core.h" @@ -43,18 +44,6 @@ namespace polysat { assignment::assignment(core& s) : m_core(s) { } - - assignment assignment::clone() const { - assignment a(s()); - a.m_pairs = m_pairs; - a.m_subst.reserve(m_subst.size()); - for (unsigned i = m_subst.size(); i-- > 0; ) - if (m_subst[i]) - a.m_subst.set(i, alloc(substitution, *m_subst[i])); - a.m_subst_trail = m_subst_trail; - return a; - } - bool assignment::contains(pvar var) const { return subst(s().size(var)).contains(var); } diff --git a/src/sat/smt/polysat/polysat_assignment.h b/src/sat/smt/polysat/polysat_assignment.h index befaad0b7..559f6dab2 100644 --- a/src/sat/smt/polysat/polysat_assignment.h +++ b/src/sat/smt/polysat/polysat_assignment.h @@ -91,9 +91,6 @@ namespace polysat { // prevent implicit copy, use clone() if you do need a copy assignment(assignment const&) = delete; assignment& operator=(assignment const&) = delete; - assignment(assignment&&) = default; - assignment& operator=(assignment&&) = default; - assignment clone() const; void push(pvar var, rational const& value); void pop(); From d64a2bdbed5c0dbafbe8efd6e0c90359a77d594b Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 10:27:22 -0800 Subject: [PATCH 25/89] include dependency in cmakelist Signed-off-by: Nikolaj Bjorner --- src/sat/smt/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 95d0a5324..26de85168 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -55,6 +55,7 @@ z3_add_component(sat_smt ast euf mbp + polysat smt_params ) From 9931c811caddfd22be60cea5da5174331945b63f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 12:31:10 -0800 Subject: [PATCH 26/89] start intblast solver --- src/sat/smt/CMakeLists.txt | 1 + src/sat/smt/intblast_solver.cpp | 316 ++++++++++++++++++++++++++++++++ src/sat/smt/intblast_solver.h | 64 +++++++ 3 files changed, 381 insertions(+) create mode 100644 src/sat/smt/intblast_solver.cpp create mode 100644 src/sat/smt/intblast_solver.h diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 26de85168..2302a6c39 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -27,6 +27,7 @@ z3_add_component(sat_smt euf_proof_checker.cpp euf_relevancy.cpp euf_solver.cpp + intblast_solver.cpp fpa_solver.cpp pb_card.cpp pb_constraint.cpp diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp new file mode 100644 index 000000000..36ebdbacd --- /dev/null +++ b/src/sat/smt/intblast_solver.cpp @@ -0,0 +1,316 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + intblast_solver.cpp + +Author: + + Nikolaj Bjorner (nbjorner) 2023-12-10 + +--*/ + +#include "ast/ast_util.h" +#include "ast/for_each_expr.h" +#include "sat/smt/intblast_solver.h" +#include "sat/smt/euf_solver.h" + + +namespace intblast { + + solver::solver(euf::solver& ctx): + ctx(ctx), + s(ctx.s()), + m(ctx.get_manager()), + bv(m), + a(m), + m_trail(m) + {} + + lbool solver::check() { + sat::literal_vector literals; + uint_set selected; + for (auto const& clause : s.clauses()) { + if (any_of(*clause, [&](auto lit) { return selected.contains(lit.index()); })) + continue; + if (any_of(*clause, [&](auto lit) { return s.value(lit) == l_true && !is_bv(lit); })) + continue; + sat::literal selected_lit = sat::null_literal; + for (auto lit : *clause) { + if (s.value(lit) != l_true) + continue; + SASSERT(is_bv(lit)); + if (selected_lit == sat::null_literal || s.lvl(selected_lit) > s.lvl(lit)) + selected_lit = lit; + } + if (selected_lit == sat::null_literal) { + UNREACHABLE(); + return l_undef; + } + selected.insert(selected_lit.index()); + literals.push_back(selected_lit); + } + unsigned trail_sz = s.init_trail_size(); + for (unsigned i = 0; i < trail_sz; ++i) { + auto lit = s.trail_literal(i); + if (selected.contains(lit.index()) || !is_bv(lit)) + continue; + selected.insert(lit.index()); + literals.push_back(lit); + } + svector> bin; + s.collect_bin_clauses(bin, false, false); + for (auto [a, b] : bin) { + if (selected.contains(a.index())) + continue; + if (selected.contains(b.index())) + continue; + if (s.value(a) == l_true && !is_bv(a)) + continue; + if (s.value(b) == l_true && !is_bv(b)) + continue; + if (s.value(a) == l_false) + std::swap(a, b); + if (s.value(b) == l_true && s.value(a) == l_true && s.lvl(b) < s.lvl(a)) + std::swap(a, b); + selected.insert(a.index()); + literals.push_back(a); + } + + m_solver = mk_smt2_solver(m, s.params(), symbol::null); + + expr_ref_vector es(m); + for (auto lit : literals) + es.push_back(ctx.literal2expr(lit)); + + translate(es); + + for (auto e : es) + m_solver->assert_expr(e); + + + lbool r = m_solver->check_sat(0, nullptr); + + return r; + }; + + bool solver::is_bv(sat::literal lit) { + expr* e = ctx.bool_var2expr(lit.var()); + if (!e) + return false; + if (m.is_and(e) || m.is_or(e) || m.is_not(e) || m.is_implies(e) || m.is_iff(e)) + return false; + if (is_quantifier(e)) + return false; + return any_of(subterms::all(expr_ref(e, m)), [&](auto* p) { return bv.is_bv_sort(p->get_sort()); }); + } + + void solver::sorted_subterms(expr_ref_vector const& es, ptr_vector& sorted) { + expr_fast_mark1 visited; + for (expr* e : es) { + sorted.push_back(e); + visited.mark(e); + } + for (unsigned i = 0; i < sorted.size(); ++i) { + expr* e = sorted[i]; + if (is_app(e)) { + app* a = to_app(e); + for (expr* arg : *a) { + if (!visited.is_marked(arg)) { + visited.mark(arg); + sorted.push_back(arg); + } + } + } + else if (is_quantifier(e)) { + quantifier* q = to_quantifier(e); + expr* b = q->get_expr(); + if (!visited.is_marked(b)) { + visited.mark(b); + sorted.push_back(b); + } + } + } + } + + void solver::translate(expr_ref_vector& es) { + ptr_vector todo; + obj_map translated; + expr_ref_vector args(m); + m_trail.reset(); + m_vars.reset(); + + sorted_subterms(es, todo); + for (unsigned i = todo.size(); i-- > 0; ) { + expr* e = todo[i]; + if (is_quantifier(e)) { + quantifier* q = to_quantifier(e); + expr* b = q->get_expr(); + m_trail.push_back(m.update_quantifier(q, translated[b])); + translated.insert(e, m_trail.back()); + continue; + } + if (is_var(e)) { + if (bv.is_bv_sort(e->get_sort())) { + expr* v = m.mk_var(to_var(e)->get_idx(), a.mk_int()); + m_trail.push_back(v); + translated.insert(e, m_trail.back()); + } + else { + m_trail.push_back(e); + translated.insert(e, m_trail.back()); + } + continue; + } + app* ap = to_app(e); + args.reset(); + for (auto arg : *ap) + args.push_back(translated[arg]); + + auto bv_size = [&]() { return rational::power_of_two(bv.get_bv_size(e->get_sort())); }; + + auto mk_mod = [&](expr* x) { + if (m_vars.contains(x)) + return x; + return to_expr(a.mk_mod(x, a.mk_int(bv_size()))); + }; + + auto mk_smod = [&](expr* x) { + auto shift = bv_size() / 2; + return a.mk_mod(a.mk_add(x, a.mk_int(shift)), a.mk_int(bv_size())); + }; + + if (m.is_eq(e)) { + bool has_bv_arg = any_of(*ap, [&](expr* arg) { return bv.is_bv(arg); }); + if (has_bv_arg) { + m_trail.push_back(m.mk_eq(mk_mod(args.get(0)), mk_mod(args.get(1)))); + translated.insert(e, m_trail.back()); + } + else { + m_trail.push_back(m.mk_eq(args.get(0), args.get(1))); + translated.insert(e, m_trail.back()); + } + continue; + } + + if (ap->get_family_id() != bv.get_family_id()) { + bool has_bv_arg = any_of(*ap, [&](expr* arg) { return bv.is_bv(arg); }); + bool has_bv_sort = bv.is_bv(e); + func_decl* f = ap->get_decl(); + if (has_bv_arg) { + // need to update args with mod where they are bit-vectors. + NOT_IMPLEMENTED_YET(); + } + + if (has_bv_arg || has_bv_sort) { + ptr_vector domain; + for (auto* arg : *ap) { + sort* s = arg->get_sort(); + domain.push_back(bv.is_bv_sort(s) ? a.mk_int() : s); + } + sort* range = bv.is_bv(e) ? a.mk_int() : e->get_sort(); + f = m.mk_fresh_func_decl(ap->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); + } + + m_trail.push_back(m.mk_app(f, args)); + translated.insert(e, m_trail.back()); + + if (has_bv_sort) + m_vars.insert(e, { m_trail.back(), bv_size() }); + + continue; + } + + switch (ap->get_decl_kind()) { + case OP_BADD: + m_trail.push_back(a.mk_add(args)); + break; + case OP_BSUB: + m_trail.push_back(a.mk_sub(args.size(), args.data())); + break; + case OP_BMUL: + m_trail.push_back(a.mk_mul(args)); + break; + case OP_ULEQ: + m_trail.push_back(a.mk_le(mk_mod(args.get(0)), mk_mod(args.get(1)))); + break; + case OP_UGEQ: + m_trail.push_back(a.mk_ge(mk_mod(args.get(0)), mk_mod(args.get(1)))); + break; + case OP_ULT: + m_trail.push_back(a.mk_lt(mk_mod(args.get(0)), mk_mod(args.get(1)))); + break; + case OP_UGT: + m_trail.push_back(a.mk_gt(mk_mod(args.get(0)), mk_mod(args.get(1)))); + break; + case OP_SLEQ: + m_trail.push_back(a.mk_le(mk_smod(args.get(0)), mk_smod(args.get(1)))); + break; + case OP_SGEQ: + m_trail.push_back(a.mk_ge(mk_smod(args.get(0)), mk_smod(args.get(1)))); + break; + case OP_SLT: + m_trail.push_back(a.mk_lt(mk_smod(args.get(0)), mk_smod(args.get(1)))); + break; + case OP_SGT: + m_trail.push_back(a.mk_gt(mk_smod(args.get(0)), mk_smod(args.get(1)))); + break; + case OP_BNEG: + m_trail.push_back(a.mk_uminus(args.get(0))); + break; + case OP_BNOT: + case OP_BNAND: + case OP_BNOR: + case OP_BXOR: + case OP_BXNOR: + case OP_BCOMP: + case OP_BSHL: + case OP_BLSHR: + case OP_BASHR: + case OP_ROTATE_LEFT: + case OP_ROTATE_RIGHT: + case OP_EXT_ROTATE_LEFT: + case OP_EXT_ROTATE_RIGHT: + case OP_REPEAT: + case OP_ZERO_EXT: + case OP_SIGN_EXT: + case OP_BREDOR: + case OP_BREDAND: + case OP_BUDIV: + case OP_BSDIV: + case OP_BUREM: + case OP_BSREM: + case OP_BSMOD: + case OP_BAND: + NOT_IMPLEMENTED_YET(); + break; + } + translated.insert(e, m_trail.back()); + } + for (unsigned i = 0; i < es.size(); ++i) + es[i] = translated[es.get(i)]; + for (auto const& [src, vi] : m_vars) { + auto const& [v, b] = vi; + es.push_back(a.mk_le(a.mk_int(0), v)); + es.push_back(a.mk_lt(v, a.mk_int(b))); + } + } + + rational solver::get_value(expr* e) const { + SASSERT(bv.is_bv(e)); + model_ref mdl; + m_solver->get_model(mdl); + expr_ref r(m); + var_info vi; + rational val; + if (!m_vars.find(e, vi)) + return rational::zero(); + if (!mdl->eval_expr(vi.dst, r, true)) + return rational::zero(); + if (!a.is_numeral(r, val)) + return rational::zero(); + return val; + } + +} diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h new file mode 100644 index 000000000..f2ec486d5 --- /dev/null +++ b/src/sat/smt/intblast_solver.h @@ -0,0 +1,64 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + intblast_solver.h + +Abstract: + + Int-blast solver. + It assumes a full assignemnt to literals in + irredundant clauses. + It picks a satisfying Boolean assignment and + checks if it is feasible for bit-vectors using + an arithmetic solver. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-12-10 + +--*/ +#pragma once + +#include "ast/arith_decl_plugin.h" +#include "ast/bv_decl_plugin.h" +#include "solver/solver.h" +#include "sat/smt/sat_th.h" + +namespace euf { + class solver; +} + +namespace intblast { + + class solver { + struct var_info { + expr* dst; + rational sz; + }; + + euf::solver& ctx; + sat::solver& s; + ast_manager& m; + bv_util bv; + arith_util a; + scoped_ptr<::solver> m_solver; + obj_map m_vars; + expr_ref_vector m_trail; + + + + bool is_bv(sat::literal lit); + void translate(expr_ref_vector& es); + void sorted_subterms(expr_ref_vector const& es, ptr_vector& sorted); + + public: + solver(euf::solver& ctx); + + lbool check(); + + rational get_value(expr* e) const; + }; + +} From 09c2e0dd6e613704b2e312fc27d7b665d6496153 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 13:00:43 -0800 Subject: [PATCH 27/89] integrate intblast solver --- src/sat/smt/intblast_solver.cpp | 30 +++++++++++++++++++++--------- src/sat/smt/intblast_solver.h | 3 +++ src/sat/smt/polysat_solver.cpp | 30 +++++++++++++++++++++++++++++- src/sat/smt/polysat_solver.h | 4 +++- 4 files changed, 56 insertions(+), 11 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 36ebdbacd..25ec308d8 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -78,6 +78,7 @@ namespace intblast { literals.push_back(a); } + m_core.reset(); m_solver = mk_smt2_solver(m, s.params(), symbol::null); expr_ref_vector es(m); @@ -86,11 +87,23 @@ namespace intblast { translate(es); - for (auto e : es) - m_solver->assert_expr(e); - + for (auto const& [src, vi] : m_vars) { + auto const& [v, b] = vi; + m_solver->assert_expr(a.mk_le(a.mk_int(0), v)); + m_solver->assert_expr(a.mk_lt(v, a.mk_int(b))); + } - lbool r = m_solver->check_sat(0, nullptr); + lbool r = m_solver->check_sat(es); + + if (r == l_false) { + expr_ref_vector core(m); + m_solver->get_unsat_core(core); + obj_map e2index; + for (unsigned i = 0; i < es.size(); ++i) + e2index.insert(es.get(i), i); + for (auto e : core) + m_core.push_back(literals[e2index[e]]); + } return r; }; @@ -290,11 +303,6 @@ namespace intblast { } for (unsigned i = 0; i < es.size(); ++i) es[i] = translated[es.get(i)]; - for (auto const& [src, vi] : m_vars) { - auto const& [v, b] = vi; - es.push_back(a.mk_le(a.mk_int(0), v)); - es.push_back(a.mk_lt(v, a.mk_int(b))); - } } rational solver::get_value(expr* e) const { @@ -313,4 +321,8 @@ namespace intblast { return val; } + sat::literal_vector const& solver::unsat_core() { + return m_core; + } + } diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index f2ec486d5..1df46c300 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -46,6 +46,7 @@ namespace intblast { scoped_ptr<::solver> m_solver; obj_map m_vars; expr_ref_vector m_trail; + sat::literal_vector m_core; @@ -58,6 +59,8 @@ namespace intblast { lbool check(); + sat::literal_vector const& unsat_core(); + rational get_value(expr* e) const; }; diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 460501cd0..8baf05de5 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -31,6 +31,8 @@ The result of polysat::core::check is one of: #include "sat/smt/polysat/polysat_umul_ovfl.h" + + namespace polysat { solver::solver(euf::solver& ctx, theory_id id): @@ -38,6 +40,7 @@ namespace polysat { bv(ctx.get_manager()), m_autil(ctx.get_manager()), m_core(*this), + m_intblast(ctx), m_lemma(ctx.get_manager()) { ctx.get_egraph().add_plugin(alloc(euf::bv_plugin, ctx.get_egraph())); @@ -56,7 +59,31 @@ namespace polysat { } sat::check_result solver::check() { - return m_core.check(); + switch (m_core.check()) { + case sat::check_result::CR_DONE: + return sat::check_result::CR_DONE; + case sat::check_result::CR_CONTINUE: + return sat::check_result::CR_CONTINUE; + case sat::check_result::CR_GIVEUP: { + if (!m.inc()) + return sat::check_result::CR_GIVEUP; + switch (m_intblast.check()) { + case l_true: + trail().push(value_trail(m_use_intblast_model)); + m_use_intblast_model = true; + return sat::check_result::CR_DONE; + case l_false: { + auto core = m_intblast.unsat_core(); + for (auto& lit : core) + lit.neg(); + s().add_clause(core.size(), core.data(), sat::status::th(true, get_id(), nullptr)); + return sat::check_result::CR_CONTINUE; + } + case l_undef: + return sat::check_result::CR_GIVEUP; + } + } + } } void solver::asserted(literal l) { @@ -136,6 +163,7 @@ namespace polysat { unsigned num_scopes = s().scope_lvl() - m_lemma_level; + NOT_IMPLEMENTED_YET(); // s().pop_reinit(num_scopes); sat::literal_vector lits; diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index b5e69c36a..e1a9221e9 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -19,6 +19,7 @@ Author: #include "sat/smt/sat_th.h" #include "math/dd/dd_pdd.h" #include "sat/smt/polysat/polysat_core.h" +#include "sat/smt/intblast_solver.h" namespace euf { class solver; @@ -57,7 +58,8 @@ namespace polysat { arith_util m_autil; stats m_stats; core m_core; - polysat_proof m_proof; + intblast::solver m_intblast; + bool m_use_intblast_model = false; vector m_var2pdd; // theory_var 2 pdd bool_vector m_var2pdd_valid; // valid flag From 701671466b1eab2ddaa4a5e6a7970c266b34c340 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 19:55:25 -0800 Subject: [PATCH 28/89] integrating int-blaster --- src/sat/smt/euf_solver.cpp | 7 +- src/sat/smt/intblast_solver.cpp | 91 +++++++++++++++++-- src/sat/smt/polysat/polysat_constraints.cpp | 40 +++++++-- src/sat/smt/polysat/polysat_constraints.h | 24 ++--- src/sat/smt/polysat/polysat_core.cpp | 2 +- src/sat/smt/polysat/polysat_core.h | 36 ++++---- src/sat/smt/polysat/polysat_ule.cpp | 18 ++++ src/sat/smt/polysat/polysat_ule.h | 2 + src/sat/smt/polysat/polysat_umul_ovfl.h | 2 + src/sat/smt/polysat/polysat_viable.cpp | 4 +- src/sat/smt/polysat/polysat_viable.h | 14 +-- src/sat/smt/polysat_internalize.cpp | 99 ++++++++++++++++++--- src/sat/smt/polysat_solver.cpp | 14 +-- src/sat/smt/polysat_solver.h | 9 +- 14 files changed, 285 insertions(+), 77 deletions(-) diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 2d4b9847e..b6606d4f6 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -135,11 +135,8 @@ namespace euf { special_relations_util sp(m); if (pb.get_family_id() == fid) ext = alloc(pb::solver, *this, fid); - else if (bvu.get_family_id() == fid) { - ext = alloc(bv::solver, *this, fid); - dealloc(ext); - ext = alloc(polysat::solver, *this, fid); - } + else if (bvu.get_family_id() == fid) + ext = alloc(polysat::solver, *this, fid); else if (au.get_family_id() == fid) ext = alloc(array::solver, *this, fid); else if (fpa.get_family_id() == fid) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 25ec308d8..239bb4682 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -79,6 +79,8 @@ namespace intblast { } m_core.reset(); + m_vars.reset(); + m_trail.reset(); m_solver = mk_smt2_solver(m, s.params(), symbol::null); expr_ref_vector es(m); @@ -89,12 +91,19 @@ namespace intblast { for (auto const& [src, vi] : m_vars) { auto const& [v, b] = vi; + verbose_stream() << "asserting " << mk_pp(v, m) << " < " << b << "\n"; m_solver->assert_expr(a.mk_le(a.mk_int(0), v)); m_solver->assert_expr(a.mk_lt(v, a.mk_int(b))); } + + verbose_stream() << "check\n"; + m_solver->display(verbose_stream()); + verbose_stream() << es << "\n"; lbool r = m_solver->check_sat(es); + verbose_stream() << "result " << r << "\n"; + if (r == l_false) { expr_ref_vector core(m); m_solver->get_unsat_core(core); @@ -114,8 +123,6 @@ namespace intblast { return false; if (m.is_and(e) || m.is_or(e) || m.is_not(e) || m.is_implies(e) || m.is_iff(e)) return false; - if (is_quantifier(e)) - return false; return any_of(subterms::all(expr_ref(e, m)), [&](auto* p) { return bv.is_bv_sort(p->get_sort()); }); } @@ -145,22 +152,34 @@ namespace intblast { } } } + std::stable_sort(sorted.begin(), sorted.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); } void solver::translate(expr_ref_vector& es) { ptr_vector todo; obj_map translated; expr_ref_vector args(m); - m_trail.reset(); - m_vars.reset(); sorted_subterms(es, todo); - for (unsigned i = todo.size(); i-- > 0; ) { - expr* e = todo[i]; + for (expr* e : todo) { if (is_quantifier(e)) { quantifier* q = to_quantifier(e); expr* b = q->get_expr(); - m_trail.push_back(m.update_quantifier(q, translated[b])); + + unsigned nd = q->get_num_decls(); + ptr_vector sorts; + for (unsigned i = 0; i < nd; ++i) { + auto s = q->get_decl_sort(i); + if (bv.is_bv_sort(s)) { + NOT_IMPLEMENTED_YET(); + sorts.push_back(a.mk_int()); + } + else + sorts.push_back(s); + } + b = translated[b]; + // TODO if sorts contain integer, then created bounds variables. + m_trail.push_back(m.update_quantifier(q, b)); translated.insert(e, m_trail.back()); continue; } @@ -177,11 +196,12 @@ namespace intblast { continue; } app* ap = to_app(e); + expr* bv_expr = e; args.reset(); for (auto arg : *ap) args.push_back(translated[arg]); - auto bv_size = [&]() { return rational::power_of_two(bv.get_bv_size(e->get_sort())); }; + auto bv_size = [&]() { return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); }; auto mk_mod = [&](expr* x) { if (m_vars.contains(x)) @@ -197,6 +217,7 @@ namespace intblast { if (m.is_eq(e)) { bool has_bv_arg = any_of(*ap, [&](expr* arg) { return bv.is_bv(arg); }); if (has_bv_arg) { + bv_expr = ap->get_arg(0); m_trail.push_back(m.mk_eq(mk_mod(args.get(0)), mk_mod(args.get(1)))); translated.insert(e, m_trail.back()); } @@ -229,6 +250,8 @@ namespace intblast { m_trail.push_back(m.mk_app(f, args)); translated.insert(e, m_trail.back()); + verbose_stream() << "translate " << mk_pp(e, m) << " " << has_bv_sort << "\n"; + if (has_bv_sort) m_vars.insert(e, { m_trail.back(), bv_size() }); @@ -272,6 +295,53 @@ namespace intblast { case OP_BNEG: m_trail.push_back(a.mk_uminus(args.get(0))); break; + case OP_CONCAT: { + expr_ref r(a.mk_int(0), m); + unsigned sz = 0; + for (unsigned i = 0; i < args.size(); ++i) { + expr* old_arg = ap->get_arg(i); + expr* new_arg = args.get(i); + bv_expr = old_arg; + new_arg = mk_mod(new_arg); + if (sz > 0) { + new_arg = a.mk_mul(new_arg, a.mk_int(rational::power_of_two(sz))); + r = a.mk_add(r, new_arg); + } + else + r = new_arg; + sz += bv.get_bv_size(old_arg->get_sort()); + } + m_trail.push_back(r); + break; + } + case OP_EXTRACT: { + unsigned lo, hi; + expr* old_arg; + VERIFY(bv.is_extract(e, lo, hi, old_arg)); + unsigned sz = hi - lo + 1; + expr* new_arg = args.get(0); + if (lo > 0) + new_arg = a.mk_div(new_arg, a.mk_int(rational::power_of_two(lo))); + m_trail.push_back(new_arg); + break; + } + case OP_BV_NUM: { + rational val; + unsigned sz; + VERIFY(bv.is_numeral(e, val, sz)); + m_trail.push_back(a.mk_int(val)); + break; + } + case OP_BUREM_I: { + expr* x = args.get(0), * y = args.get(1); + m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + break; + } + case OP_BUDIV_I: { + expr* x = args.get(0), * y = args.get(1); + m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + break; + } case OP_BNOT: case OP_BNAND: case OP_BNOR: @@ -296,9 +366,14 @@ namespace intblast { case OP_BSREM: case OP_BSMOD: case OP_BAND: + verbose_stream() << mk_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); break; + default: + verbose_stream() << mk_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); } + verbose_stream() << "insert " << mk_pp(e, m) << " -> " << mk_pp(m_trail.back(), m) << "\n"; translated.insert(e, m_trail.back()); } for (unsigned i = 0; i < es.size(); ++i) diff --git a/src/sat/smt/polysat/polysat_constraints.cpp b/src/sat/smt/polysat/polysat_constraints.cpp index 99da7b0db..faeea62a4 100644 --- a/src/sat/smt/polysat/polysat_constraints.cpp +++ b/src/sat/smt/polysat/polysat_constraints.cpp @@ -1,4 +1,4 @@ -/*++ +/*++ Copyright (c) 2021 Microsoft Corporation Module Name: @@ -12,6 +12,7 @@ Author: --*/ +#include "util/log.h" #include "sat/smt/polysat/polysat_core.h" #include "sat/smt/polysat/polysat_constraints.h" #include "sat/smt/polysat/polysat_ule.h" @@ -23,16 +24,34 @@ namespace polysat { pdd lhs = p, rhs = q; bool is_positive = true; ule_constraint::simplify(is_positive, lhs, rhs); - auto* c = alloc(ule_constraint, p, q); - m_trail.push(new_obj_trail(c)); - auto sc = signed_constraint(ckind_t::ule_t, c); + auto* cnstr = alloc(ule_constraint, p, q); + c.trail().push(new_obj_trail(cnstr)); + auto sc = signed_constraint(ckind_t::ule_t, cnstr); return is_positive ? sc : ~sc; } signed_constraint constraints::umul_ovfl(pdd const& p, pdd const& q) { - auto* c = alloc(umul_ovfl_constraint, p, q); - m_trail.push(new_obj_trail(c)); - return signed_constraint(ckind_t::umul_ovfl_t, c); + auto* cnstr = alloc(umul_ovfl_constraint, p, q); + c.trail().push(new_obj_trail(cnstr)); + return signed_constraint(ckind_t::umul_ovfl_t, cnstr); + } + + bool signed_constraint::is_eq(pvar& v, rational& val) { + if (m_sign) + return false; + if (!is_ule()) + return false; + auto const& ule = to_ule(); + auto const& l = ule.lhs(), &r = ule.rhs(); + if (!r.is_zero()) + return false; + if (!l.is_unilinear()) + return false; + if (!l.hi().is_one()) + return false; + v = l.var(); + val = -l.lo().val(); + return true; } lbool signed_constraint::eval(assignment& a) const { @@ -44,4 +63,11 @@ namespace polysat { if (m_sign) out << "~"; return out << *m_constraint; } + + bool signed_constraint::is_always_true() const { + return m_sign ? m_constraint->is_always_false() : m_constraint->is_always_true(); + } + bool signed_constraint::is_always_false() const { + return m_sign ? m_constraint->is_always_true() : m_constraint->is_always_false(); + } } diff --git a/src/sat/smt/polysat/polysat_constraints.h b/src/sat/smt/polysat/polysat_constraints.h index 687b3d91a..fdef902e2 100644 --- a/src/sat/smt/polysat/polysat_constraints.h +++ b/src/sat/smt/polysat/polysat_constraints.h @@ -41,6 +41,8 @@ namespace polysat { virtual std::ostream& display(std::ostream& out) const = 0; virtual lbool eval() const = 0; virtual lbool eval(assignment const& a) const = 0; + virtual bool is_always_true() const = 0; + virtual bool is_always_false() const = 0; }; inline std::ostream& operator<<(std::ostream& out, constraint const& c) { return c.display(out); } @@ -61,6 +63,8 @@ namespace polysat { unsigned_vector const& vars() const { return m_constraint->vars(); } unsigned var(unsigned idx) const { return m_constraint->var(idx); } bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + bool is_always_true() const; + bool is_always_false() const; lbool eval(assignment& a) const; ckind_t op() const { return m_op; } bool is_ule() const { return m_op == ule_t; } @@ -68,27 +72,27 @@ namespace polysat { bool is_smul_fl() const { return m_op == smul_fl_t; } ule_constraint const& to_ule() const { return *reinterpret_cast(m_constraint); } umul_ovfl_constraint const& to_umul_ovfl() const { return *reinterpret_cast(m_constraint); } - bool is_eq(pvar& v, rational& val) { throw default_exception("nyi"); } + bool is_eq(pvar& v, rational& val); std::ostream& display(std::ostream& out) const; }; inline std::ostream& operator<<(std::ostream& out, signed_constraint const& c) { return c.display(out); } class constraints { - trail_stack& m_trail; + core& c; public: - constraints(trail_stack& c) : m_trail(c) {} + constraints(core& c) : c(c) {} signed_constraint eq(pdd const& p) { return ule(p, p.manager().mk_val(0)); } signed_constraint eq(pdd const& p, rational const& v) { return eq(p - p.manager().mk_val(v)); } signed_constraint ule(pdd const& p, pdd const& q); - signed_constraint sle(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint ult(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint slt(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint sle(pdd const& p, pdd const& q) { auto sh = rational::power_of_two(p.power_of_2() - 1); return ule(p + sh, q + sh); } + signed_constraint ult(pdd const& p, pdd const& q) { return ~ule(q, p); } + signed_constraint slt(pdd const& p, pdd const& q) { return ~sle(q, p); } signed_constraint umul_ovfl(pdd const& p, pdd const& q); - signed_constraint smul_ovfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint smul_udfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint bit(pdd const& p, unsigned i) { throw default_exception("nyi"); } + signed_constraint smul_ovfl(pdd const& p, pdd const& q) { throw default_exception("smul ovfl nyi"); } + signed_constraint smul_udfl(pdd const& p, pdd const& q) { throw default_exception("smult-udfl nyi"); } + signed_constraint bit(pdd const& p, unsigned i) { throw default_exception("bit nyi"); } signed_constraint diseq(pdd const& p) { return ~eq(p); } signed_constraint diseq(pdd const& p, pdd const& q) { return diseq(p - q); } @@ -138,4 +142,4 @@ namespace polysat { //signed_constraint even(pdd const& p) { return parity_at_least(p, 1); } //signed_constraint odd(pdd const& p) { return ~even(p); } }; -} \ No newline at end of file +} diff --git a/src/sat/smt/polysat/polysat_core.cpp b/src/sat/smt/polysat/polysat_core.cpp index be25c9af8..f2bb995d7 100644 --- a/src/sat/smt/polysat/polysat_core.cpp +++ b/src/sat/smt/polysat/polysat_core.cpp @@ -80,7 +80,7 @@ namespace polysat { core::core(solver_interface& s) : s(s), m_viable(*this), - m_constraints(s.trail()), + m_constraints(*this), m_assignment(*this), m_var_queue(m_activity) {} diff --git a/src/sat/smt/polysat/polysat_core.h b/src/sat/smt/polysat/polysat_core.h index 144b1256b..5262f2bbe 100644 --- a/src/sat/smt/polysat/polysat_core.h +++ b/src/sat/smt/polysat/polysat_core.h @@ -84,7 +84,8 @@ namespace polysat { void get_bitvector_prefixes(pvar v, pvar_vector& out); void get_fixed_bits(pvar v, svector& fixed_bits); bool inconsistent() const; - + void add_clause(char const* name, std::initializer_list cs, bool is_redundant); + void add_watch(unsigned idx, unsigned var); @@ -114,29 +115,28 @@ namespace polysat { signed_constraint bit(pdd const& p, unsigned i) { return m_constraints.bit(p, i); } - pdd lshr(pdd a, pdd b) { throw default_exception("nyi"); } - pdd ashr(pdd a, pdd b) { throw default_exception("nyi"); } - pdd shl(pdd a, pdd b) { throw default_exception("nyi"); } - pdd band(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bxor(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bor(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bnand(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bxnor(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bnor(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bnot(pdd a) { throw default_exception("nyi"); } - std::pair quot_rem(pdd const& n, pdd const& d) { throw default_exception("nyi"); } - pdd zero_ext(pdd a, unsigned sz) { throw default_exception("nyi"); } - pdd sign_ext(pdd a, unsigned sz) { throw default_exception("nyi"); } - pdd extract(pdd src, unsigned hi, unsigned lo) { throw default_exception("nyi"); } - pdd concat(unsigned n, pdd const* args) { throw default_exception("nyi"); } + pdd lshr(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("lshr nyi"); } + pdd ashr(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("ashr nyi"); } + pdd shl(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("shlh nyi"); } + pdd band(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("band nyi"); } + pdd bxor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bxor nyi"); } + pdd bor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bir ==nyi"); } + pdd bnand(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bnand nyi"); } + pdd bxnor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bxnor nyi"); } + pdd bnor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bnotr nyi"); } + pdd bnot(pdd a) { NOT_IMPLEMENTED_YET(); throw default_exception("bnot nyi"); } + pdd zero_ext(pdd a, unsigned sz) { NOT_IMPLEMENTED_YET(); throw default_exception("zero ext nyi"); } + pdd sign_ext(pdd a, unsigned sz) { NOT_IMPLEMENTED_YET(); throw default_exception("sign ext nyi"); } + pdd extract(pdd src, unsigned hi, unsigned lo) { NOT_IMPLEMENTED_YET(); throw default_exception("extract nyi"); } + pdd concat(unsigned n, pdd const* args) { NOT_IMPLEMENTED_YET(); throw default_exception("concat nyi"); } pvar add_var(unsigned sz); pdd var(pvar p) { return m_vars[p]; } - unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } + unsigned size(pvar v) const { return m_vars[v].power_of_2(); } constraints& cs() { return m_constraints; } trail_stack& trail(); - std::ostream& display(std::ostream& out) const { throw default_exception("nyi"); } + std::ostream& display(std::ostream& out) const { NOT_IMPLEMENTED_YET(); throw default_exception("nyi"); } }; } diff --git a/src/sat/smt/polysat/polysat_ule.cpp b/src/sat/smt/polysat/polysat_ule.cpp index 0fb01bcae..7482ec36a 100644 --- a/src/sat/smt/polysat/polysat_ule.cpp +++ b/src/sat/smt/polysat/polysat_ule.cpp @@ -343,4 +343,22 @@ namespace polysat { return eval(a.apply_to(lhs()), a.apply_to(rhs())); } + bool ule_constraint::is_always_true() const { + if (lhs().is_zero()) + return true; // 0 <= p + if (rhs().is_max()) + return true; // p <= -1 + if (lhs().is_val() && rhs().is_val()) + return lhs().val() <= rhs().val(); + return false; + } + + bool ule_constraint::is_always_false() const { + if (lhs().is_never_zero() && rhs().is_zero()) + return true; // p > 0, q = 0 + if (lhs().is_val() && rhs().is_val()) + return lhs().val() > rhs().val(); + return false; + } + } diff --git a/src/sat/smt/polysat/polysat_ule.h b/src/sat/smt/polysat/polysat_ule.h index e21ed1029..2944bf614 100644 --- a/src/sat/smt/polysat/polysat_ule.h +++ b/src/sat/smt/polysat/polysat_ule.h @@ -35,6 +35,8 @@ namespace polysat { std::ostream& display(std::ostream& out) const override; lbool eval() const override; lbool eval(assignment const& a) const override; + bool is_always_true() const override; + bool is_always_false() const override; bool is_eq() const { return m_rhs.is_zero(); } unsigned power_of_2() const { return m_lhs.power_of_2(); } diff --git a/src/sat/smt/polysat/polysat_umul_ovfl.h b/src/sat/smt/polysat/polysat_umul_ovfl.h index 41972ef59..65a12c031 100644 --- a/src/sat/smt/polysat/polysat_umul_ovfl.h +++ b/src/sat/smt/polysat/polysat_umul_ovfl.h @@ -34,6 +34,8 @@ namespace polysat { std::ostream& display(std::ostream& out) const override; lbool eval() const override; lbool eval(assignment const& a) const override; + bool is_always_true() const override { return false; } // todo port + bool is_always_false() const override { return false; } }; } diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp index 4152956de..35158b889 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -90,6 +90,8 @@ namespace polysat { } lbool viable::find_viable(pvar v, rational& lo, rational& hi) { + return l_undef; + fixed_bits_info fbi; #if 0 @@ -1007,7 +1009,7 @@ namespace polysat { } void viable::log(pvar v) { - throw default_exception("nyi"); + // } diff --git a/src/sat/smt/polysat/polysat_viable.h b/src/sat/smt/polysat/polysat_viable.h index f426dc326..88070a553 100644 --- a/src/sat/smt/polysat/polysat_viable.h +++ b/src/sat/smt/polysat/polysat_viable.h @@ -186,36 +186,36 @@ namespace polysat { template bool refine_viable(pvar v, rational const& val, fixed_bits_info const& fbi) { - throw default_exception("nyi"); + throw default_exception("refine nyi"); } bool refine_viable(pvar v, rational const& val) { - throw default_exception("nyi"); + throw default_exception("refine nyi"); } template bool refine_bits(pvar v, rational const& val, fixed_bits_info const& fbi) { - throw default_exception("nyi"); + throw default_exception("refine nyi"); } template entry* refine_bits(pvar v, rational const& val, unsigned num_bits, fixed_bits_info const& fbi) { - throw default_exception("nyi"); + throw default_exception("refine nyi"); } bool refine_equal_lin(pvar v, rational const& val) { - throw default_exception("nyi"); + throw default_exception("refine nyi"); } bool refine_disequal_lin(pvar v, rational const& val) { - throw default_exception("nyi"); + throw default_exception("refine nyi"); } void set_conflict_by_interval(pvar v, unsigned w, ptr_vector& intervals, unsigned first_interval); bool set_conflict_by_interval_rec(pvar v, unsigned w, entry** intervals, unsigned num_intervals, bool& create_lemma, uint_set& vars_to_explain); std::pair find_value(rational const& val, entry* entries) { - throw default_exception("nyi"); + throw default_exception("fine_value nyi"); } bool collect_bit_information(pvar v, bool add_conflict, fixed_bits_info& out_fbi); diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 95979348d..96f7def0e 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -1,4 +1,4 @@ -/*++ +/*++ Copyright (c) 2022 Microsoft Corporation Module Name: @@ -22,7 +22,9 @@ Author: namespace polysat { euf::theory_var solver::mk_var(euf::enode* n) { - return euf::th_euf_solver::mk_var(n); + theory_var v = euf::th_euf_solver::mk_var(n); + ctx.attach_th_var(n, this, v); + return v; } sat::literal solver::internalize(expr* e, bool sign, bool root) { @@ -115,12 +117,13 @@ namespace polysat { case OP_BSADD_OVFL: case OP_BUSUB_OVFL: case OP_BSSUB_OVFL: + verbose_stream() << mk_pp(a, m) << "\n"; // handled by bv_rewriter for now UNREACHABLE(); break; - case OP_BUDIV_I: internalize_div_rem_i(a, true); break; - case OP_BUREM_I: internalize_div_rem_i(a, false); break; + case OP_BUDIV_I: internalize_udiv_i(a); break; + case OP_BUREM_I: internalize_urem_i(a); break; case OP_BUDIV: internalize_div_rem(a, true); break; case OP_BUREM: internalize_div_rem(a, false); break; @@ -187,17 +190,93 @@ namespace polysat { mk_atom(lit.var(), sc); } - void solver::internalize_div_rem_i(app* e, bool is_div) { - auto p = expr2pdd(e->get_arg(0)); - auto q = expr2pdd(e->get_arg(1)); - auto [quot, rem] = m_core.quot_rem(p, q); - internalize_set(e, is_div ? quot : rem); + void solver::internalize_udiv_i(app* e) { + expr* x, *y; + VERIFY(bv.is_bv_udivi(e, x, y) || bv.is_bv_udiv(e, x, y)); + app_ref rm(bv.mk_bv_urem_i(x, y), m); + internalize(rm); + } + + void solver::internalize_urem_i(app* e) { + expr* x, *y; + if (expr2enode(e)) + return; + VERIFY(bv.is_bv_uremi(e, x, y) || bv.is_bv_urem(e, x, y)); + auto [quot, rem] = quot_rem(x, y); + internalize_set(e, rem); + internalize_set(bv.mk_bv_udiv_i(x, y), quot); + } + + std::pair solver::quot_rem(expr* x, expr* y) { + pdd a = expr2pdd(x); + pdd b = expr2pdd(y); + auto& m = a.manager(); + unsigned sz = m.power_of_2(); + if (b.is_zero()) + // By SMT-LIB specification, b = 0 ==> q = -1, r = a. + return { m.mk_val(m.max_value()), a }; + + if (b.is_one()) + return { a, m.zero() }; + + if (a.is_val() && b.is_val()) { + rational const av = a.val(); + rational const bv = b.val(); + SASSERT(!bv.is_zero()); + rational rv; + rational qv = machine_div_rem(av, bv, rv); + pdd q = m.mk_val(qv); + pdd r = m.mk_val(rv); + SASSERT_EQ(a, b * q + r); + SASSERT(b.val() * q.val() + r.val() <= m.max_value()); + SASSERT(r.val() <= (b * q + r).val()); + SASSERT(r.val() < b.val()); + return { std::move(q), std::move(r) }; + } + + expr* quot = bv.mk_bv_udiv_i(x, y); + expr* rem = bv.mk_bv_urem_i(x, y); + + ctx.internalize(quot); + ctx.internalize(rem); + auto quotv = expr2enode(quot)->get_th_var(get_id()); + auto remv = expr2enode(rem)->get_th_var(get_id()); + + pdd q = var2pdd(quotv); + pdd r = var2pdd(remv); + + // Axioms for quotient/remainder + // + // a = b*q + r + // multiplication does not overflow in b*q + // addition does not overflow in (b*q) + r; for now expressed as: r <= bq+r + // b ≠ 0 ==> r < b + // b = 0 ==> q = -1 + // TODO: when a,b become evaluable, can we actually propagate q,r? doesn't seem like it. + // Maybe we need something like an op_constraint for better propagation. + add_polysat_clause("[axiom] quot_rem 1", { m_core.eq(b * q + r - a) }, false); + add_polysat_clause("[axiom] quot_rem 2", { ~m_core.umul_ovfl(b, q) }, false); + // r <= b*q+r + // { apply equivalence: p <= q <=> q-p <= -p-1 } + // b*q <= -r-1 + add_polysat_clause("[axiom] quot_rem 3", { m_core.ule(b * q, -r - 1) }, false); + + auto c_eq = m_core.eq(b); + if (!c_eq.is_always_true()) + add_polysat_clause("[axiom] quot_rem 4", { c_eq, ~m_core.ule(b, r) }, false); + if (!c_eq.is_always_false()) + add_polysat_clause("[axiom] quot_rem 5", { ~c_eq, m_core.eq(q + 1) }, false); + + return { std::move(q), std::move(r) }; } void solver::internalize_div_rem(app* e, bool is_div) { bv_rewriter_params p(s().params()); if (p.hi_div0()) { - internalize_div_rem_i(e, is_div); + if (bv.is_bv_udivi(e)) + internalize_udiv_i(e); + else + internalize_urem_i(e); return; } expr* arg1 = e->get_arg(0); diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 8baf05de5..e4e022c6e 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -84,6 +84,8 @@ namespace polysat { } } } + UNREACHABLE(); + return sat::check_result::CR_GIVEUP; } void solver::asserted(literal l) { @@ -249,11 +251,11 @@ namespace polysat { return ctx.get_trail_stack(); } - void solver::add_lemma(vector const& lemma) { + void solver::add_polysat_clause(char const* name, std::initializer_list cs, bool is_redundant) { sat::literal_vector lits; - for (auto sc : lemma) + for (auto sc : cs) lits.push_back(ctx.mk_literal(constraint2expr(sc))); - s().add_clause(lits.size(), lits.data(), sat::status::th(true, get_id(), nullptr)); + s().add_clause(lits.size(), lits.data(), sat::status::th(is_redundant, get_id(), nullptr)); } void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) { @@ -271,14 +273,14 @@ namespace polysat { case ckind_t::umul_ovfl_t: { auto l = pdd2expr(sc.to_umul_ovfl().p()); auto r = pdd2expr(sc.to_umul_ovfl().q()); - return expr_ref(bv.mk_bvumul_ovfl(l, r), m); + return expr_ref(m.mk_not(bv.mk_bvumul_no_ovfl(l, r)), m); } case ckind_t::smul_fl_t: case ckind_t::op_t: NOT_IMPLEMENTED_YET(); break; } - throw default_exception("nyi"); + throw default_exception("constraint2expr nyi"); } expr_ref solver::pdd2expr(pdd const& p) { @@ -346,7 +348,7 @@ namespace polysat { continue; for (auto sib : euf::enode_class(p)) { if (bv.is_extract(sib->get_expr(), lo, hi, e) && r == expr2enode(e)->get_root()) { - throw default_exception("nyi"); + throw default_exception("get_fixed nyi"); // TODO // dependency d = dependency(p->get_th_var(get_id()), n->get_th_var(get_id()), s().scope_lvl()); // fixed_bits.push_back({ hi, lo, rational::zero(), null_dependency()}); diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index e1a9221e9..f2c80a6e6 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -113,10 +113,10 @@ namespace polysat { void internalize_extract(app* n); void internalize_repeat(app* n); void internalize_bit2bool(app* n); - void internalize_udiv_i(app* n); template void internalize_le(app* n); - void internalize_div_rem_i(app* e, bool is_div); + void internalize_udiv_i(app* e); + void internalize_urem_i(app* e); void internalize_div_rem(app* e, bool is_div); void internalize_polysat(app* a); void assert_bv2int_axiom(app * n); @@ -126,6 +126,8 @@ namespace polysat { pdd var2pdd(euf::theory_var v); void internalize_set(expr* e, pdd const& p); void internalize_set(euf::theory_var v, pdd const& p); + std::pair quot_rem(expr* x, expr* y); + // callbacks from core void add_eq_literal(pvar v, rational const& val) override; @@ -137,8 +139,7 @@ namespace polysat { bool inconsistent() const override; void get_bitvector_prefixes(pvar v, pvar_vector& out) override; void get_fixed_bits(pvar v, svector& fixed_bits) override; - - void add_lemma(vector const& lemma); + void add_polysat_clause(char const* name, std::initializer_list cs, bool is_redundant); std::pair explain_deps(dependency_vector const& deps); From 21121f14a55111d0b12088a2bd18351e2b139d6b Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 20:48:46 -0800 Subject: [PATCH 29/89] dbg Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/polysat_core.h | 2 +- src/sat/smt/polysat_solver.h | 1 + src/solver/simplifier_solver.cpp | 18 +++++++++++++----- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/sat/smt/polysat/polysat_core.h b/src/sat/smt/polysat/polysat_core.h index 5262f2bbe..4e9c02118 100644 --- a/src/sat/smt/polysat/polysat_core.h +++ b/src/sat/smt/polysat/polysat_core.h @@ -48,13 +48,13 @@ namespace polysat { lbool value; // value assigned by solver }; solver_interface& s; + mutable scoped_ptr_vector m_pdd; viable m_viable; constraints m_constraints; assignment m_assignment; unsigned m_qhead = 0, m_vqhead = 0; svector m_prop_queue; svector m_constraint_index; // index of constraints - mutable scoped_ptr_vector m_pdd; dependency_vector m_unsat_core; diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index f2c80a6e6..acb8f59a8 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -148,6 +148,7 @@ namespace polysat { public: solver(euf::solver& ctx, theory_id id); + ~solver() override {} void set_lookahead(sat::lookahead* s) override { } void init_search() override {} double get_reward(literal l, sat::ext_constraint_idx idx, sat::literal_occs_fun& occs) const override { return 0; } diff --git a/src/solver/simplifier_solver.cpp b/src/solver/simplifier_solver.cpp index 712c42f45..22c20ae93 100644 --- a/src/solver/simplifier_solver.cpp +++ b/src/solver/simplifier_solver.cpp @@ -289,8 +289,19 @@ public: return m_cached_mc; } - unsigned get_num_assertions() const override { return s->get_num_assertions(); } - expr* get_assertion(unsigned idx) const override { return s->get_assertion(idx); } + unsigned get_num_assertions() const override { + unsigned qhead = m_preprocess_state.qhead(); + unsigned qtail = m_preprocess_state.qtail(); + return s->get_num_assertions() + qtail - qhead; + } + expr* get_assertion(unsigned idx) const override { + unsigned qhead = m_preprocess_state.qhead(); + unsigned qtail = m_preprocess_state.qtail(); + if (idx < qtail - qhead) + return m_fmls.get(idx + qhead).fml(); + idx -= qtail - qhead; + return s->get_assertion(idx); + } std::string reason_unknown() const override { return s->reason_unknown(); } void set_reason_unknown(char const* msg) override { s->set_reason_unknown(msg); } void get_labels(svector& r) override { s->get_labels(r); } @@ -364,9 +375,6 @@ public: expr* congruence_root(expr* e) override { return s->congruence_root(e); } expr* congruence_next(expr* e) override { return s->congruence_next(e); } - std::ostream& display(std::ostream& out, unsigned n, expr* const* assumptions) const override { - return s->display(out, n, assumptions); - } void get_units_core(expr_ref_vector& units) override { s->get_units_core(units); } expr_ref_vector get_trail(unsigned max_level) override { return s->get_trail(max_level); } void get_levels(ptr_vector const& vars, unsigned_vector& depth) override { s->get_levels(vars, depth); } From 83c71b4943b8ed3fe7fd201038e58dec859e01c6 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 22:21:14 -0800 Subject: [PATCH 30/89] fix internalization for quot/rem --- src/ast/euf/euf_bv_plugin.cpp | 4 +- src/sat/smt/intblast_solver.cpp | 14 +++++ src/sat/smt/polysat/polysat_core.cpp | 2 + src/sat/smt/polysat/polysat_viable.cpp | 10 ++-- src/sat/smt/polysat_internalize.cpp | 71 ++++++++++++++++---------- src/sat/smt/polysat_solver.h | 2 +- 6 files changed, 69 insertions(+), 34 deletions(-) diff --git a/src/ast/euf/euf_bv_plugin.cpp b/src/ast/euf/euf_bv_plugin.cpp index 99bf8941b..4766cba12 100644 --- a/src/ast/euf/euf_bv_plugin.cpp +++ b/src/ast/euf/euf_bv_plugin.cpp @@ -104,8 +104,8 @@ namespace euf { } void bv_plugin::merge_eh(enode* x, enode* y) { - SASSERT(x == x->get_root()); - SASSERT(x == y->get_root()); + if (!bv.is_bv(x->get_expr())) + return; TRACE("bv", tout << "merge_eh " << g.bpp(x) << " == " << g.bpp(y) << "\n"); SASSERT(!m_internal); diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 239bb4682..c269fee9c 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -36,6 +36,7 @@ namespace intblast { continue; if (any_of(*clause, [&](auto lit) { return s.value(lit) == l_true && !is_bv(lit); })) continue; + // TBD: if we associate "status" with clauses, we can also remove theory axioms from polysat sat::literal selected_lit = sat::null_literal; for (auto lit : *clause) { if (s.value(lit) != l_true) @@ -269,27 +270,34 @@ namespace intblast { m_trail.push_back(a.mk_mul(args)); break; case OP_ULEQ: + bv_expr = ap->get_arg(0); m_trail.push_back(a.mk_le(mk_mod(args.get(0)), mk_mod(args.get(1)))); break; case OP_UGEQ: + bv_expr = ap->get_arg(0); m_trail.push_back(a.mk_ge(mk_mod(args.get(0)), mk_mod(args.get(1)))); break; case OP_ULT: + bv_expr = ap->get_arg(0); m_trail.push_back(a.mk_lt(mk_mod(args.get(0)), mk_mod(args.get(1)))); break; case OP_UGT: + bv_expr = ap->get_arg(0); m_trail.push_back(a.mk_gt(mk_mod(args.get(0)), mk_mod(args.get(1)))); break; case OP_SLEQ: + bv_expr = ap->get_arg(0); m_trail.push_back(a.mk_le(mk_smod(args.get(0)), mk_smod(args.get(1)))); break; case OP_SGEQ: m_trail.push_back(a.mk_ge(mk_smod(args.get(0)), mk_smod(args.get(1)))); break; case OP_SLT: + bv_expr = ap->get_arg(0); m_trail.push_back(a.mk_lt(mk_smod(args.get(0)), mk_smod(args.get(1)))); break; case OP_SGT: + bv_expr = ap->get_arg(0); m_trail.push_back(a.mk_gt(mk_smod(args.get(0)), mk_smod(args.get(1)))); break; case OP_BNEG: @@ -342,6 +350,12 @@ namespace intblast { m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); break; } + case OP_BUMUL_NO_OVFL: { + expr* x = args.get(0), * y = args.get(1); + bv_expr = ap->get_arg(0); + m_trail.push_back(a.mk_lt(a.mk_mul(mk_mod(x), mk_mod(y)), a.mk_int(bv_size()))); + break; + } case OP_BNOT: case OP_BNAND: case OP_BNOR: diff --git a/src/sat/smt/polysat/polysat_core.cpp b/src/sat/smt/polysat/polysat_core.cpp index f2bb995d7..13964720d 100644 --- a/src/sat/smt/polysat/polysat_core.cpp +++ b/src/sat/smt/polysat/polysat_core.cpp @@ -108,6 +108,7 @@ namespace polysat { m_watch.push_back({}); m_var_queue.mk_var_eh(v); m_viable.ensure_var(v); + m_values.push_back(rational::zero()); s.trail().push(mk_add_var(*this)); return v; } @@ -118,6 +119,7 @@ namespace polysat { m_activity.pop_back(); m_justification.pop_back(); m_watch.pop_back(); + m_values.pop_back(); m_var_queue.del_var_eh(v); } diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp index 35158b889..a24683ade 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -67,13 +67,17 @@ namespace polysat { }; viable::entry* viable::alloc_entry(pvar var, unsigned constraint_index) { + entry* e = nullptr; if (m_alloc.empty()) - return alloc(entry); - auto* e = m_alloc.back(); + e = alloc(entry); + else { + e = m_alloc.back(); + m_alloc.pop_back(); + } e->reset(); e->var = var; e->constraint_index = constraint_index; - m_alloc.pop_back(); + return e; } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 96f7def0e..1b3cb1c04 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -192,32 +192,54 @@ namespace polysat { void solver::internalize_udiv_i(app* e) { expr* x, *y; - VERIFY(bv.is_bv_udivi(e, x, y) || bv.is_bv_udiv(e, x, y)); - app_ref rm(bv.mk_bv_urem_i(x, y), m); + expr_ref rm(m); + if (bv.is_bv_udivi(e, x, y)) + rm = bv.mk_bv_urem_i(x, y); + else if (bv.is_bv_udiv(e, x, y)) + rm = bv.mk_bv_urem(x, y); + else + UNREACHABLE(); internalize(rm); } - void solver::internalize_urem_i(app* e) { + void solver::internalize_urem_i(app* rem) { expr* x, *y; - if (expr2enode(e)) + euf::enode* n = expr2enode(rem); + SASSERT(n && n->is_attached_to(get_id())); + theory_var v = n->get_th_var(get_id()); + if (m_var2pdd_valid.get(v, false)) return; - VERIFY(bv.is_bv_uremi(e, x, y) || bv.is_bv_urem(e, x, y)); - auto [quot, rem] = quot_rem(x, y); - internalize_set(e, rem); - internalize_set(bv.mk_bv_udiv_i(x, y), quot); + expr_ref quot(m); + if (bv.is_bv_uremi(rem, x, y)) + quot = bv.mk_bv_udiv_i(x, y); + else if (bv.is_bv_urem(rem, x, y)) + quot = bv.mk_bv_udiv(x, y); + else + UNREACHABLE(); + m_var2pdd_valid.setx(v, true, false); + ctx.internalize(quot); + m_var2pdd_valid.setx(v, false, false); + quot_rem(quot, rem, x, y); } - - std::pair solver::quot_rem(expr* x, expr* y) { + + void solver::quot_rem(expr* quot, expr* rem, expr* x, expr* y) { pdd a = expr2pdd(x); pdd b = expr2pdd(y); + euf::enode* qn = expr2enode(quot); + euf::enode* rn = expr2enode(rem); auto& m = a.manager(); unsigned sz = m.power_of_2(); - if (b.is_zero()) + if (b.is_zero()) { // By SMT-LIB specification, b = 0 ==> q = -1, r = a. - return { m.mk_val(m.max_value()), a }; - - if (b.is_one()) - return { a, m.zero() }; + internalize_set(quot, m.mk_val(m.max_value())); + internalize_set(rem, a); + return; + } + if (b.is_one()) { + internalize_set(quot, a); + internalize_set(rem, m.zero()); + return; + } if (a.is_val() && b.is_val()) { rational const av = a.val(); @@ -231,19 +253,13 @@ namespace polysat { SASSERT(b.val() * q.val() + r.val() <= m.max_value()); SASSERT(r.val() <= (b * q + r).val()); SASSERT(r.val() < b.val()); - return { std::move(q), std::move(r) }; - } + internalize_set(quot, q); + internalize_set(rem, r); + return; + } - expr* quot = bv.mk_bv_udiv_i(x, y); - expr* rem = bv.mk_bv_urem_i(x, y); - - ctx.internalize(quot); - ctx.internalize(rem); - auto quotv = expr2enode(quot)->get_th_var(get_id()); - auto remv = expr2enode(rem)->get_th_var(get_id()); - - pdd q = var2pdd(quotv); - pdd r = var2pdd(remv); + pdd r = var2pdd(rn->get_th_var(get_id())); + pdd q = var2pdd(qn->get_th_var(get_id())); // Axioms for quotient/remainder // @@ -267,7 +283,6 @@ namespace polysat { if (!c_eq.is_always_false()) add_polysat_clause("[axiom] quot_rem 5", { ~c_eq, m_core.eq(q + 1) }, false); - return { std::move(q), std::move(r) }; } void solver::internalize_div_rem(app* e, bool is_div) { diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index acb8f59a8..1d1356b0e 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -126,7 +126,7 @@ namespace polysat { pdd var2pdd(euf::theory_var v); void internalize_set(expr* e, pdd const& p); void internalize_set(euf::theory_var v, pdd const& p); - std::pair quot_rem(expr* x, expr* y); + void quot_rem(expr* quot, expr* rem, expr* x, expr* y); // callbacks from core From 6518d71c6d18f7d43bd7c5b6177df36935c3086e Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 22:47:27 -0800 Subject: [PATCH 31/89] rename polysat files to exclude namespace Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/CMakeLists.txt | 14 +++++++------- .../{polysat_assignment.cpp => assignment.cpp} | 4 ++-- .../polysat/{polysat_assignment.h => assignment.h} | 2 +- .../{polysat_constraints.cpp => constraints.cpp} | 8 ++++---- .../{polysat_constraints.h => constraints.h} | 2 +- src/sat/smt/polysat/{polysat_core.cpp => core.cpp} | 13 ++++++++++++- src/sat/smt/polysat/{polysat_core.h => core.h} | 10 +++++----- src/sat/smt/polysat/fixed_bits.cpp | 2 +- src/sat/smt/polysat/fixed_bits.h | 4 ++-- .../{polysat_fi.cpp => forbidden_intervals.cpp} | 10 +++++----- .../{polysat_fi.h => forbidden_intervals.h} | 6 +++--- .../smt/polysat/{polysat_interval.h => interval.h} | 2 +- src/sat/smt/polysat/{polysat_types.h => types.h} | 0 .../{polysat_ule.cpp => ule_constraint.cpp} | 4 ++-- .../polysat/{polysat_ule.h => ule_constraint.h} | 4 ++-- ...ysat_umul_ovfl.cpp => umul_ovfl_constraint.cpp} | 6 +++--- ...{polysat_umul_ovfl.h => umul_ovfl_constraint.h} | 2 +- .../smt/polysat/{polysat_viable.cpp => viable.cpp} | 6 +++--- src/sat/smt/polysat/{polysat_viable.h => viable.h} | 4 ++-- src/sat/smt/polysat_solver.cpp | 4 ++-- src/sat/smt/polysat_solver.h | 2 +- 21 files changed, 60 insertions(+), 49 deletions(-) rename src/sat/smt/polysat/{polysat_assignment.cpp => assignment.cpp} (97%) rename src/sat/smt/polysat/{polysat_assignment.h => assignment.h} (98%) rename src/sat/smt/polysat/{polysat_constraints.cpp => constraints.cpp} (91%) rename src/sat/smt/polysat/{polysat_constraints.h => constraints.h} (99%) rename src/sat/smt/polysat/{polysat_core.cpp => core.cpp} (95%) rename src/sat/smt/polysat/{polysat_core.h => core.h} (95%) rename src/sat/smt/polysat/{polysat_fi.cpp => forbidden_intervals.cpp} (98%) rename src/sat/smt/polysat/{polysat_fi.h => forbidden_intervals.h} (96%) rename src/sat/smt/polysat/{polysat_interval.h => interval.h} (99%) rename src/sat/smt/polysat/{polysat_types.h => types.h} (100%) rename src/sat/smt/polysat/{polysat_ule.cpp => ule_constraint.cpp} (99%) rename src/sat/smt/polysat/{polysat_ule.h => ule_constraint.h} (95%) rename src/sat/smt/polysat/{polysat_umul_ovfl.cpp => umul_ovfl_constraint.cpp} (92%) rename src/sat/smt/polysat/{polysat_umul_ovfl.h => umul_ovfl_constraint.h} (96%) rename src/sat/smt/polysat/{polysat_viable.cpp => viable.cpp} (99%) rename src/sat/smt/polysat/{polysat_viable.h => viable.h} (99%) diff --git a/src/sat/smt/polysat/CMakeLists.txt b/src/sat/smt/polysat/CMakeLists.txt index 6c8bed74d..1f943f48d 100644 --- a/src/sat/smt/polysat/CMakeLists.txt +++ b/src/sat/smt/polysat/CMakeLists.txt @@ -1,13 +1,13 @@ z3_add_component(polysat SOURCES + assignment.cpp + constraints.cpp + core.cpp fixed_bits.cpp - polysat_assignment.cpp - polysat_constraints.cpp - polysat_core.cpp - polysat_fi.cpp - polysat_ule.cpp - polysat_umul_ovfl.cpp - polysat_viable.cpp + forbidden_intervals.cpp + ule_constraint.cpp + umul_ovfl_constraint.cpp + viable.cpp COMPONENT_DEPENDENCIES util dd diff --git a/src/sat/smt/polysat/polysat_assignment.cpp b/src/sat/smt/polysat/assignment.cpp similarity index 97% rename from src/sat/smt/polysat/polysat_assignment.cpp rename to src/sat/smt/polysat/assignment.cpp index 329733d89..c1620891a 100644 --- a/src/sat/smt/polysat/polysat_assignment.cpp +++ b/src/sat/smt/polysat/assignment.cpp @@ -13,8 +13,8 @@ Author: --*/ #include -#include "sat/smt/polysat/polysat_assignment.h" -#include "sat/smt/polysat/polysat_core.h" +#include "sat/smt/polysat/assignment.h" +#include "sat/smt/polysat/core.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_assignment.h b/src/sat/smt/polysat/assignment.h similarity index 98% rename from src/sat/smt/polysat/polysat_assignment.h rename to src/sat/smt/polysat/assignment.h index 559f6dab2..823d991b9 100644 --- a/src/sat/smt/polysat/polysat_assignment.h +++ b/src/sat/smt/polysat/assignment.h @@ -13,7 +13,7 @@ Author: --*/ #pragma once #include "util/scoped_ptr_vector.h" -#include "sat/smt/polysat/polysat_types.h" +#include "sat/smt/polysat/types.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_constraints.cpp b/src/sat/smt/polysat/constraints.cpp similarity index 91% rename from src/sat/smt/polysat/polysat_constraints.cpp rename to src/sat/smt/polysat/constraints.cpp index faeea62a4..0de987693 100644 --- a/src/sat/smt/polysat/polysat_constraints.cpp +++ b/src/sat/smt/polysat/constraints.cpp @@ -13,10 +13,10 @@ Author: --*/ #include "util/log.h" -#include "sat/smt/polysat/polysat_core.h" -#include "sat/smt/polysat/polysat_constraints.h" -#include "sat/smt/polysat/polysat_ule.h" -#include "sat/smt/polysat/polysat_umul_ovfl.h" +#include "sat/smt/polysat/core.h" +#include "sat/smt/polysat/constraints.h" +#include "sat/smt/polysat/ule_constraint.h" +#include "sat/smt/polysat/umul_ovfl_constraint.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_constraints.h b/src/sat/smt/polysat/constraints.h similarity index 99% rename from src/sat/smt/polysat/polysat_constraints.h rename to src/sat/smt/polysat/constraints.h index fdef902e2..81ba6f6a0 100644 --- a/src/sat/smt/polysat/polysat_constraints.h +++ b/src/sat/smt/polysat/constraints.h @@ -15,7 +15,7 @@ Author: #pragma once #include "util/trail.h" -#include "sat/smt/polysat/polysat_types.h" +#include "sat/smt/polysat/types.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_core.cpp b/src/sat/smt/polysat/core.cpp similarity index 95% rename from src/sat/smt/polysat/polysat_core.cpp rename to src/sat/smt/polysat/core.cpp index 13964720d..8e779923d 100644 --- a/src/sat/smt/polysat/polysat_core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -28,7 +28,7 @@ polysat::core --*/ -#include "sat/smt/polysat/polysat_core.h" +#include "sat/smt/polysat/core.h" namespace polysat { @@ -335,4 +335,15 @@ namespace polysat { return s.trail(); } + std::ostream& core::display(std::ostream& out) const { + if (m_constraint_index.empty() && m_vars.empty()) + return out; + out << "polysat:\n"; + for (auto const& [sc, d, value] : m_constraint_index) + out << sc << " " << d << " := " << value << "\n"; + for (unsigned i = 0; i < m_vars.size(); ++i) + out << "p" << m_vars[i] << " := " << m_values[i] << " " << m_justification[i] << "\n"; + return out; + } + } diff --git a/src/sat/smt/polysat/polysat_core.h b/src/sat/smt/polysat/core.h similarity index 95% rename from src/sat/smt/polysat/polysat_core.h rename to src/sat/smt/polysat/core.h index 4e9c02118..32ad84fa3 100644 --- a/src/sat/smt/polysat/polysat_core.h +++ b/src/sat/smt/polysat/core.h @@ -21,10 +21,10 @@ Author: #include "util/dependency.h" #include "math/dd/dd_pdd.h" #include "sat/sat_extension.h" -#include "sat/smt/polysat/polysat_types.h" -#include "sat/smt/polysat/polysat_constraints.h" -#include "sat/smt/polysat/polysat_viable.h" -#include "sat/smt/polysat/polysat_assignment.h" +#include "sat/smt/polysat/types.h" +#include "sat/smt/polysat/constraints.h" +#include "sat/smt/polysat/viable.h" +#include "sat/smt/polysat/assignment.h" namespace polysat { @@ -136,7 +136,7 @@ namespace polysat { constraints& cs() { return m_constraints; } trail_stack& trail(); - std::ostream& display(std::ostream& out) const { NOT_IMPLEMENTED_YET(); throw default_exception("nyi"); } + std::ostream& display(std::ostream& out) const; }; } diff --git a/src/sat/smt/polysat/fixed_bits.cpp b/src/sat/smt/polysat/fixed_bits.cpp index 9b67c883d..85b35de66 100644 --- a/src/sat/smt/polysat/fixed_bits.cpp +++ b/src/sat/smt/polysat/fixed_bits.cpp @@ -12,7 +12,7 @@ Author: --*/ #include "sat/smt/polysat/fixed_bits.h" -#include "sat/smt/polysat/polysat_ule.h" +#include "sat/smt/polysat/ule_constraint.h" namespace polysat { diff --git a/src/sat/smt/polysat/fixed_bits.h b/src/sat/smt/polysat/fixed_bits.h index 78b4a643f..f07b2c7d2 100644 --- a/src/sat/smt/polysat/fixed_bits.h +++ b/src/sat/smt/polysat/fixed_bits.h @@ -11,8 +11,8 @@ Author: --*/ #pragma once -#include "sat/smt/polysat/polysat_types.h" -#include "sat/smt/polysat/polysat_constraints.h" +#include "sat/smt/polysat/types.h" +#include "sat/smt/polysat/constraints.h" #include "util/vector.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_fi.cpp b/src/sat/smt/polysat/forbidden_intervals.cpp similarity index 98% rename from src/sat/smt/polysat/polysat_fi.cpp rename to src/sat/smt/polysat/forbidden_intervals.cpp index e54fb5cea..1c0494429 100644 --- a/src/sat/smt/polysat/polysat_fi.cpp +++ b/src/sat/smt/polysat/forbidden_intervals.cpp @@ -13,11 +13,11 @@ Author: Nikolaj Bjorner (nbjorner) 2021-03-19 --*/ -#include "sat/smt/polysat/polysat_fi.h" -#include "sat/smt/polysat/polysat_interval.h" -#include "sat/smt/polysat/polysat_umul_ovfl.h" -#include "sat/smt/polysat/polysat_ule.h" -#include "sat/smt/polysat/polysat_core.h" +#include "sat/smt/polysat/forbidden_intervals.h" +#include "sat/smt/polysat/interval.h" +#include "sat/smt/polysat/umul_ovfl_constraint.h" +#include "sat/smt/polysat/ule_constraint.h" +#include "sat/smt/polysat/core.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_fi.h b/src/sat/smt/polysat/forbidden_intervals.h similarity index 96% rename from src/sat/smt/polysat/polysat_fi.h rename to src/sat/smt/polysat/forbidden_intervals.h index e1f876c3c..b790da1c8 100644 --- a/src/sat/smt/polysat/polysat_fi.h +++ b/src/sat/smt/polysat/forbidden_intervals.h @@ -14,9 +14,9 @@ Author: --*/ #pragma once -#include "sat/smt/polysat/polysat_types.h" -#include "sat/smt/polysat/polysat_interval.h" -#include "sat/smt/polysat/polysat_constraints.h" +#include "sat/smt/polysat/types.h" +#include "sat/smt/polysat/interval.h" +#include "sat/smt/polysat/constraints.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_interval.h b/src/sat/smt/polysat/interval.h similarity index 99% rename from src/sat/smt/polysat/polysat_interval.h rename to src/sat/smt/polysat/interval.h index 0299f83b3..27e62eef9 100644 --- a/src/sat/smt/polysat/polysat_interval.h +++ b/src/sat/smt/polysat/interval.h @@ -12,7 +12,7 @@ Author: --*/ #pragma once -#include "sat/smt/polysat/polysat_types.h" +#include "sat/smt/polysat/types.h" #include namespace polysat { diff --git a/src/sat/smt/polysat/polysat_types.h b/src/sat/smt/polysat/types.h similarity index 100% rename from src/sat/smt/polysat/polysat_types.h rename to src/sat/smt/polysat/types.h diff --git a/src/sat/smt/polysat/polysat_ule.cpp b/src/sat/smt/polysat/ule_constraint.cpp similarity index 99% rename from src/sat/smt/polysat/polysat_ule.cpp rename to src/sat/smt/polysat/ule_constraint.cpp index 7482ec36a..3d6240bad 100644 --- a/src/sat/smt/polysat/polysat_ule.cpp +++ b/src/sat/smt/polysat/ule_constraint.cpp @@ -70,8 +70,8 @@ Useful lemmas: --*/ -#include "sat/smt/polysat/polysat_constraints.h" -#include "sat/smt/polysat/polysat_ule.h" +#include "sat/smt/polysat/constraints.h" +#include "sat/smt/polysat/ule_constraint.h" #define LOG(_msg_) verbose_stream() << _msg_ << "\n" diff --git a/src/sat/smt/polysat/polysat_ule.h b/src/sat/smt/polysat/ule_constraint.h similarity index 95% rename from src/sat/smt/polysat/polysat_ule.h rename to src/sat/smt/polysat/ule_constraint.h index 2944bf614..0d481c5ea 100644 --- a/src/sat/smt/polysat/polysat_ule.h +++ b/src/sat/smt/polysat/ule_constraint.h @@ -12,8 +12,8 @@ Author: --*/ #pragma once -#include "sat/smt/polysat/polysat_assignment.h" -#include "sat/smt/polysat/polysat_constraints.h" +#include "sat/smt/polysat/assignment.h" +#include "sat/smt/polysat/constraints.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_umul_ovfl.cpp b/src/sat/smt/polysat/umul_ovfl_constraint.cpp similarity index 92% rename from src/sat/smt/polysat/polysat_umul_ovfl.cpp rename to src/sat/smt/polysat/umul_ovfl_constraint.cpp index dfe400603..e7dc5801c 100644 --- a/src/sat/smt/polysat/polysat_umul_ovfl.cpp +++ b/src/sat/smt/polysat/umul_ovfl_constraint.cpp @@ -10,9 +10,9 @@ Author: Jakob Rath, Nikolaj Bjorner (nbjorner) 2021-12-09 --*/ -#include "sat/smt/polysat/polysat_constraints.h" -#include "sat/smt/polysat/polysat_assignment.h" -#include "sat/smt/polysat/polysat_umul_ovfl.h" +#include "sat/smt/polysat/constraints.h" +#include "sat/smt/polysat/assignment.h" +#include "sat/smt/polysat/umul_ovfl_constraint.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_umul_ovfl.h b/src/sat/smt/polysat/umul_ovfl_constraint.h similarity index 96% rename from src/sat/smt/polysat/polysat_umul_ovfl.h rename to src/sat/smt/polysat/umul_ovfl_constraint.h index 65a12c031..4ac03dfb3 100644 --- a/src/sat/smt/polysat/polysat_umul_ovfl.h +++ b/src/sat/smt/polysat/umul_ovfl_constraint.h @@ -11,7 +11,7 @@ Author: --*/ #pragma once -#include "sat/smt/polysat/polysat_constraints.h" +#include "sat/smt/polysat/constraints.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/viable.cpp similarity index 99% rename from src/sat/smt/polysat/polysat_viable.cpp rename to src/sat/smt/polysat/viable.cpp index a24683ade..20a31b730 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/viable.cpp @@ -18,9 +18,9 @@ Notes: #include "util/debug.h" #include "util/log.h" -#include "sat/smt/polysat/polysat_viable.h" -#include "sat/smt/polysat/polysat_core.h" -#include "sat/smt/polysat/polysat_ule.h" +#include "sat/smt/polysat/viable.h" +#include "sat/smt/polysat/core.h" +#include "sat/smt/polysat/ule_constraint.h" namespace polysat { diff --git a/src/sat/smt/polysat/polysat_viable.h b/src/sat/smt/polysat/viable.h similarity index 99% rename from src/sat/smt/polysat/polysat_viable.h rename to src/sat/smt/polysat/viable.h index 88070a553..5f5af7616 100644 --- a/src/sat/smt/polysat/polysat_viable.h +++ b/src/sat/smt/polysat/viable.h @@ -21,8 +21,8 @@ Author: #include "util/map.h" #include "util/small_object_allocator.h" -#include "sat/smt/polysat/polysat_types.h" -#include "sat/smt/polysat/polysat_fi.h" +#include "sat/smt/polysat/types.h" +#include "sat/smt/polysat/forbidden_intervals.h" namespace polysat { diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index e4e022c6e..82a61486a 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -27,8 +27,8 @@ The result of polysat::core::check is one of: #include "ast/euf/euf_bv_plugin.h" #include "sat/smt/polysat_solver.h" #include "sat/smt/euf_solver.h" -#include "sat/smt/polysat/polysat_ule.h" -#include "sat/smt/polysat/polysat_umul_ovfl.h" +#include "sat/smt/polysat/ule_constraint.h" +#include "sat/smt/polysat/umul_ovfl_constraint.h" diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 1d1356b0e..5a9c26cb8 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -18,7 +18,7 @@ Author: #include "sat/smt/sat_th.h" #include "math/dd/dd_pdd.h" -#include "sat/smt/polysat/polysat_core.h" +#include "sat/smt/polysat/core.h" #include "sat/smt/intblast_solver.h" namespace euf { From dc690307ff97e015d3af048d5dccb20ff22c17a8 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 23:14:24 -0800 Subject: [PATCH 32/89] sign and zero extend --- src/sat/smt/polysat/core.h | 4 --- src/sat/smt/polysat_internalize.cpp | 46 ++++++++++++++++++++++------- src/sat/smt/polysat_solver.cpp | 13 ++++---- src/sat/smt/polysat_solver.h | 2 ++ 4 files changed, 43 insertions(+), 22 deletions(-) diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index 32ad84fa3..b303ff8b8 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -125,10 +125,6 @@ namespace polysat { pdd bxnor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bxnor nyi"); } pdd bnor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bnotr nyi"); } pdd bnot(pdd a) { NOT_IMPLEMENTED_YET(); throw default_exception("bnot nyi"); } - pdd zero_ext(pdd a, unsigned sz) { NOT_IMPLEMENTED_YET(); throw default_exception("zero ext nyi"); } - pdd sign_ext(pdd a, unsigned sz) { NOT_IMPLEMENTED_YET(); throw default_exception("sign ext nyi"); } - pdd extract(pdd src, unsigned hi, unsigned lo) { NOT_IMPLEMENTED_YET(); throw default_exception("extract nyi"); } - pdd concat(unsigned n, pdd const* args) { NOT_IMPLEMENTED_YET(); throw default_exception("concat nyi"); } pvar add_var(unsigned sz); pdd var(pvar p) { return m_vars[p]; } unsigned size(pvar v) const { return m_vars[v].power_of_2(); } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 1b3cb1c04..387b2e5a7 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -135,8 +135,8 @@ namespace polysat { case OP_EXTRACT: internalize_extract(a); break; case OP_CONCAT: internalize_concat(a); break; - case OP_ZERO_EXT: internalize_par_unary(a, [&](pdd const& p, unsigned sz) { return m_core.zero_ext(p, sz); }); break; - case OP_SIGN_EXT: internalize_par_unary(a, [&](pdd const& p, unsigned sz) { return m_core.sign_ext(p, sz); }); break; + case OP_ZERO_EXT: internalize_zero_extend(a); break; + case OP_SIGN_EXT: internalize_sign_extend(a); break; // polysat::solver should also support at least: case OP_BREDAND: // x == 2^K - 1 unary, return single bit, 1 if all input bits are set. @@ -282,7 +282,38 @@ namespace polysat { add_polysat_clause("[axiom] quot_rem 4", { c_eq, ~m_core.ule(b, r) }, false); if (!c_eq.is_always_false()) add_polysat_clause("[axiom] quot_rem 5", { ~c_eq, m_core.eq(q + 1) }, false); + } + void solver::internalize_sign_extend(app* e) { + expr* arg = e->get_arg(0); + unsigned sz = bv.get_bv_size(e); + unsigned arg_sz = bv.get_bv_size(arg); + unsigned sz2 = sz - arg_sz; + + var2pdd(expr2enode(e)->get_th_var(get_id())); + + if (arg_sz == sz) + add_clause(eq_internalize(e, arg), false); + else { + sat::literal lt0 = ctx.mk_literal(bv.mk_slt(arg, bv.mk_numeral(0, arg_sz))); + // arg < 0 ==> e = concat(arg, 1...1) + // arg >= 0 ==> e = concat(arg, 0...0) + add_clause(lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(rational::power_of_two(sz2) - 1, sz2))), false); + add_clause(~lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz - arg_sz))), false); + } + } + + void solver::internalize_zero_extend(app* e) { + expr* arg = e->get_arg(0); + unsigned sz = bv.get_bv_size(e); + unsigned arg_sz = bv.get_bv_size(arg); + unsigned sz2 = sz - arg_sz; + var2pdd(expr2enode(e)->get_th_var(get_id())); + if (arg_sz == sz) + add_clause(eq_internalize(e, arg), false); + else + // e = concat(arg, 0...0) + add_clause(eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz - arg_sz))), false); } void solver::internalize_div_rem(app* e, bool is_div) { @@ -332,20 +363,13 @@ namespace polysat { } void solver::internalize_extract(app* e) { - unsigned const hi = bv.get_extract_high(e); - unsigned const lo = bv.get_extract_low(e); - auto const src = expr2pdd(e->get_arg(0)); - auto const p = m_core.extract(src, hi, lo); - SASSERT_EQ(p.power_of_2(), hi - lo + 1); + auto p = var2pdd(expr2enode(e)->get_th_var(get_id())); internalize_set(e, p); } void solver::internalize_concat(app* e) { SASSERT(bv.is_concat(e)); - vector args; - for (expr* arg : *e) - args.push_back(expr2pdd(arg)); - auto const p = m_core.concat(args.size(), args.data()); + auto p = var2pdd(expr2enode(e)->get_th_var(get_id())); internalize_set(e, p); } diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 82a61486a..ad4beb561 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -30,9 +30,6 @@ The result of polysat::core::check is one of: #include "sat/smt/polysat/ule_constraint.h" #include "sat/smt/polysat/umul_ovfl_constraint.h" - - - namespace polysat { solver::solver(euf::solver& ctx, theory_id id): @@ -288,11 +285,13 @@ namespace polysat { expr* n = bv.mk_numeral(p.val(), p.power_of_2()); return expr_ref(n, m); } - auto lo = pdd2expr(p.lo()); - auto hi = pdd2expr(p.hi()); auto v = var2enode(m_pddvar2var[p.var()]); - hi = bv.mk_bv_mul(v->get_expr(), hi); - return expr_ref(bv.mk_bv_add(lo, hi), m); + expr* r = v->get_expr(); + if (!p.hi().is_one()) + r = bv.mk_bv_mul(r, pdd2expr(p.hi())); + if (!p.lo().is_zero()) + r = bv.mk_bv_add(r, pdd2expr(p.lo())); + return expr_ref(r, m); } // walk the egraph starting with pvar for overlaps. diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 5a9c26cb8..3b5eb27ba 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -115,6 +115,8 @@ namespace polysat { void internalize_bit2bool(app* n); template void internalize_le(app* n); + void internalize_zero_extend(app* n); + void internalize_sign_extend(app* n); void internalize_udiv_i(app* e); void internalize_urem_i(app* e); void internalize_div_rem(app* e, bool is_div); From 286932684a4bdea218f402fdbd425fdddf8fe0b9 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 23:15:03 -0800 Subject: [PATCH 33/89] sign and zero extend Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat_internalize.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 387b2e5a7..709ae5da3 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -289,9 +289,7 @@ namespace polysat { unsigned sz = bv.get_bv_size(e); unsigned arg_sz = bv.get_bv_size(arg); unsigned sz2 = sz - arg_sz; - var2pdd(expr2enode(e)->get_th_var(get_id())); - if (arg_sz == sz) add_clause(eq_internalize(e, arg), false); else { @@ -299,7 +297,7 @@ namespace polysat { // arg < 0 ==> e = concat(arg, 1...1) // arg >= 0 ==> e = concat(arg, 0...0) add_clause(lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(rational::power_of_two(sz2) - 1, sz2))), false); - add_clause(~lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz - arg_sz))), false); + add_clause(~lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz2))), false); } } @@ -313,7 +311,7 @@ namespace polysat { add_clause(eq_internalize(e, arg), false); else // e = concat(arg, 0...0) - add_clause(eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz - arg_sz))), false); + add_clause(eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz2))), false); } void solver::internalize_div_rem(app* e, bool is_div) { From f89de2b45534857e57716ac95cf60e4af501ee88 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 23:49:23 -0800 Subject: [PATCH 34/89] more internalize cases --- src/ast/bv_decl_plugin.h | 6 ++++ src/sat/smt/polysat/core.h | 9 ++---- src/sat/smt/polysat_internalize.cpp | 49 +++++++++++++++++++++++++---- src/sat/smt/polysat_solver.h | 6 ++++ 4 files changed, 58 insertions(+), 12 deletions(-) diff --git a/src/ast/bv_decl_plugin.h b/src/ast/bv_decl_plugin.h index 9126b97b7..4eeac49ee 100644 --- a/src/ast/bv_decl_plugin.h +++ b/src/ast/bv_decl_plugin.h @@ -411,6 +411,12 @@ public: MATCH_BINARY(is_bv_sdiv); MATCH_BINARY(is_bv_udiv); MATCH_BINARY(is_bv_smod); + MATCH_BINARY(is_bv_and); + MATCH_BINARY(is_bv_or); + MATCH_BINARY(is_bv_xor); + MATCH_BINARY(is_bv_nand); + MATCH_BINARY(is_bv_nor); + MATCH_BINARY(is_bv_uremi); MATCH_BINARY(is_bv_sremi); diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index b303ff8b8..b70e70d9b 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -119,12 +119,9 @@ namespace polysat { pdd ashr(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("ashr nyi"); } pdd shl(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("shlh nyi"); } pdd band(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("band nyi"); } - pdd bxor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bxor nyi"); } - pdd bor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bir ==nyi"); } - pdd bnand(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bnand nyi"); } - pdd bxnor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bxnor nyi"); } - pdd bnor(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("bnotr nyi"); } - pdd bnot(pdd a) { NOT_IMPLEMENTED_YET(); throw default_exception("bnot nyi"); } + pdd bnot(pdd a) { return -a - 1; } + + pvar add_var(unsigned sz); pdd var(pvar p) { return m_vars[p]; } unsigned size(pvar v) const { return m_vars[v].power_of_2(); } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 709ae5da3..0afcf06d4 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -87,11 +87,11 @@ namespace polysat { case OP_BLSHR: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.lshr(p, q); }); break; case OP_BSHL: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.shl(p, q); }); break; case OP_BAND: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.band(p, q); }); break; - case OP_BOR: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.bor(p, q); }); break; - case OP_BXOR: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.bxor(p, q); }); break; - case OP_BNAND: if_unary(m_core.bnot); internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.bnand(p, q); }); break; - case OP_BNOR: if_unary(m_core.bnot); internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.bnor(p, q); }); break; - case OP_BXNOR: if_unary(m_core.bnot); internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.bxnor(p, q); }); break; + case OP_BOR: internalize_bor(a); break; + case OP_BXOR: internalize_bxor(a); break; + case OP_BNAND: if_unary(m_core.bnot); internalize_bnand(a); break; + case OP_BNOR: if_unary(m_core.bnot); internalize_bnor(a); break; + case OP_BXNOR: if_unary(m_core.bnot); internalize_bxnor(a); break; case OP_BNOT: internalize_unary(a, [&](pdd const& p) { return m_core.bnot(p); }); break; case OP_BNEG: internalize_unary(a, [&](pdd const& p) { return -p; }); break; case OP_MKBV: internalize_mkbv(a); break; @@ -202,6 +202,34 @@ namespace polysat { internalize(rm); } + // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; + // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 + // (p + q) - band(p, q); + void solver::internalize_bor(app* n) { + internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_sub(bv.mk_bv_add(x, y), bv.mk_bv_and(x, y)); }); + } + + // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; + // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 + // (p + q) - 2*band(p, q); + void solver::internalize_bxor(app* n) { + internalize_binary(n, [&](expr* const& x, expr* const& y) { + return bv.mk_bv_sub(bv.mk_bv_add(x, y), bv.mk_bv_add(bv.mk_bv_and(x, y), bv.mk_bv_and(x, y))); + }); + } + + void solver::internalize_bnor(app* n) { + internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_not(bv.mk_bv_or(x, y)); }); + } + + void solver::internalize_bnand(app* n) { + internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_not(bv.mk_bv_and(x, y)); }); + } + + void solver::internalize_bxnor(app* n) { + internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_not(bv.mk_bv_xor(x, y)); }); + } + void solver::internalize_urem_i(app* rem) { expr* x, *y; euf::enode* n = expr2enode(rem); @@ -317,7 +345,7 @@ namespace polysat { void solver::internalize_div_rem(app* e, bool is_div) { bv_rewriter_params p(s().params()); if (p.hi_div0()) { - if (bv.is_bv_udivi(e)) + if (is_div) internalize_udiv_i(e); else internalize_urem_i(e); @@ -385,6 +413,15 @@ namespace polysat { internalize_set(e, p); } + void solver::internalize_binary(app* e, std::function const& fn) { + SASSERT(e->get_num_args() >= 1); + expr* r = e->get_arg(0); + for (unsigned i = 1; i < e->get_num_args(); ++i) + r = fn(r, e->get_arg(i)); + ctx.internalize(r); + internalize_set(e, var2pdd(expr2enode(r)->get_th_var(get_id()))); + } + void solver::internalize_unary(app* e, std::function const& fn) { SASSERT(e->get_num_args() == 1); auto p = expr2pdd(e->get_arg(0)); diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 3b5eb27ba..f3435f364 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -98,6 +98,7 @@ namespace polysat { void add_def(sat::literal def, sat::literal l); void internalize_unary(app* e, std::function const& fn); void internalize_binary(app* e, std::function const& fn); + void internalize_binary(app* e, std::function const& fn); void internalize_binaryc(app* e, std::function const& fn); void internalize_par_unary(app* e, std::function const& fn); void internalize_novfl(app* n, std::function& fn); @@ -113,6 +114,11 @@ namespace polysat { void internalize_extract(app* n); void internalize_repeat(app* n); void internalize_bit2bool(app* n); + void internalize_bor(app* n); + void internalize_bxor(app* n); + void internalize_bnor(app* n); + void internalize_bnand(app* n); + void internalize_bxnor(app* n); template void internalize_le(app* n); void internalize_zero_extend(app* n); From 9373e1b7f50016c5d71a009260a17b9defc4e5ca Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 11 Dec 2023 10:00:11 -0800 Subject: [PATCH 35/89] intblast debugging --- src/ast/euf/euf_bv_plugin.cpp | 3 ++- src/sat/smt/intblast_solver.cpp | 36 ++++++++++++++++++++----- src/sat/smt/intblast_solver.h | 3 +++ src/sat/smt/polysat/core.h | 11 ++++---- src/sat/smt/polysat_internalize.cpp | 40 ++++++++++++++++++++++----- src/sat/smt/polysat_model.cpp | 42 ++++++++++++++++++++++++++++- src/sat/smt/polysat_solver.cpp | 2 +- src/sat/smt/polysat_solver.h | 6 ++++- 8 files changed, 121 insertions(+), 22 deletions(-) diff --git a/src/ast/euf/euf_bv_plugin.cpp b/src/ast/euf/euf_bv_plugin.cpp index 4766cba12..f640b3ed0 100644 --- a/src/ast/euf/euf_bv_plugin.cpp +++ b/src/ast/euf/euf_bv_plugin.cpp @@ -343,7 +343,8 @@ namespace euf { enode* hi = mk_extract(n, cut, w - 1); enode* lo = mk_extract(n, 0, cut - 1); auto& i = info(n); - SASSERT(i.value); + if (!i.value) + i.value = n; i.hi = hi; i.lo = lo; i.cut = cut; diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index c269fee9c..0a38597e2 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -82,6 +82,7 @@ namespace intblast { m_core.reset(); m_vars.reset(); m_trail.reset(); + m_new_funs.reset(); m_solver = mk_smt2_solver(m, s.params(), symbol::null); expr_ref_vector es(m); @@ -228,12 +229,19 @@ namespace intblast { } continue; } + + if (m.is_ite(e)) { + m_trail.push_back(m.mk_ite(args.get(0), args.get(1), args.get(2))); + translated.insert(e, m_trail.back()); + continue; + } if (ap->get_family_id() != bv.get_family_id()) { bool has_bv_arg = any_of(*ap, [&](expr* arg) { return bv.is_bv(arg); }); bool has_bv_sort = bv.is_bv(e); func_decl* f = ap->get_decl(); if (has_bv_arg) { + verbose_stream() << mk_pp(ap, m) << "\n"; // need to update args with mod where they are bit-vectors. NOT_IMPLEMENTED_YET(); } @@ -245,14 +253,17 @@ namespace intblast { domain.push_back(bv.is_bv_sort(s) ? a.mk_int() : s); } sort* range = bv.is_bv(e) ? a.mk_int() : e->get_sort(); - f = m.mk_fresh_func_decl(ap->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); + func_decl* g = nullptr; + if (!m_new_funs.find(f, g)) { + g = m.mk_fresh_func_decl(ap->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); + m_new_funs.insert(f, g); + } + f = g; } m_trail.push_back(m.mk_app(f, args)); translated.insert(e, m_trail.back()); - verbose_stream() << "translate " << mk_pp(e, m) << " " << has_bv_sort << "\n"; - if (has_bv_sort) m_vars.insert(e, { m_trail.back(), bv_size() }); @@ -329,7 +340,7 @@ namespace intblast { unsigned sz = hi - lo + 1; expr* new_arg = args.get(0); if (lo > 0) - new_arg = a.mk_div(new_arg, a.mk_int(rational::power_of_two(lo))); + new_arg = a.mk_idiv(new_arg, a.mk_int(rational::power_of_two(lo))); m_trail.push_back(new_arg); break; } @@ -386,12 +397,19 @@ namespace intblast { default: verbose_stream() << mk_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); - } - verbose_stream() << "insert " << mk_pp(e, m) << " -> " << mk_pp(m_trail.back(), m) << "\n"; + } translated.insert(e, m_trail.back()); } + + TRACE("bv", + for (unsigned i = 0; i < es.size(); ++i) + tout << mk_pp(es.get(i), m) << " -> " << mk_pp(translated[es.get(i)], m) << "\n"; + ); + for (unsigned i = 0; i < es.size(); ++i) es[i] = translated[es.get(i)]; + + } rational solver::get_value(expr* e) const { @@ -414,4 +432,10 @@ namespace intblast { return m_core; } + std::ostream& solver::display(std::ostream& out) const { + if (m_solver) + m_solver->display(out); + return out; + } + } diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 1df46c300..33d024be5 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -45,6 +45,7 @@ namespace intblast { arith_util a; scoped_ptr<::solver> m_solver; obj_map m_vars; + obj_map m_new_funs; expr_ref_vector m_trail; sat::literal_vector m_core; @@ -62,6 +63,8 @@ namespace intblast { sat::literal_vector const& unsat_core(); rational get_value(expr* e) const; + + std::ostream& display(std::ostream& out) const; }; } diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index b70e70d9b..e43633620 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -115,11 +115,12 @@ namespace polysat { signed_constraint bit(pdd const& p, unsigned i) { return m_constraints.bit(p, i); } - pdd lshr(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("lshr nyi"); } - pdd ashr(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("ashr nyi"); } - pdd shl(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("shlh nyi"); } - pdd band(pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("band nyi"); } - pdd bnot(pdd a) { return -a - 1; } + void lshr(pdd r, pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("lshr nyi"); } + void ashr(pdd r, pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("ashr nyi"); } + void shl(pdd r, pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("shlh nyi"); } + void band(pdd r, pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("band nyi"); } + + pdd bnot(pdd p) { return -p - 1; } pvar add_var(unsigned sz); diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 0afcf06d4..4eeec3da4 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -84,9 +84,9 @@ namespace polysat { case OP_BMUL: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p * q; }); break; case OP_BADD: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p + q; }); break; case OP_BSUB: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p - q; }); break; - case OP_BLSHR: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.lshr(p, q); }); break; - case OP_BSHL: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.shl(p, q); }); break; - case OP_BAND: internalize_binary(a, [&](pdd const& p, pdd const& q) { return m_core.band(p, q); }); break; + case OP_BLSHR: internalize_lshr(a); break; + case OP_BSHL: internalize_shl(a); break; + case OP_BAND: internalize_band(a); break; case OP_BOR: internalize_bor(a); break; case OP_BXOR: internalize_bxor(a); break; case OP_BNAND: if_unary(m_core.bnot); internalize_bnand(a); break; @@ -230,6 +230,34 @@ namespace polysat { internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_not(bv.mk_bv_xor(x, y)); }); } + void solver::internalize_band(app* n) { + if (n->get_num_args() == 2) { + expr* x, * y; + VERIFY(bv.is_bv_and(n, x, y)); + m_core.band(expr2pdd(n), expr2pdd(x), expr2pdd(y)); + } + else { + expr_ref z(n->get_arg(0), m); + for (unsigned i = 1; i < n->get_num_args(); ++i) { + z = bv.mk_bv_and(z, n->get_arg(i)); + ctx.internalize(z); + } + internalize_set(n, expr2pdd(z)); + } + } + + void solver::internalize_lshr(app* n) { + expr* x, * y; + VERIFY(bv.is_bv_lshr(n, x, y)); + m_core.lshr(expr2pdd(n), expr2pdd(x), expr2pdd(y)); + } + + void solver::internalize_shl(app* n) { + expr* x, * y; + VERIFY(bv.is_bv_shl(n, x, y)); + m_core.shl(expr2pdd(n), expr2pdd(x), expr2pdd(y)); + } + void solver::internalize_urem_i(app* rem) { expr* x, *y; euf::enode* n = expr2enode(rem); @@ -389,14 +417,12 @@ namespace polysat { } void solver::internalize_extract(app* e) { - auto p = var2pdd(expr2enode(e)->get_th_var(get_id())); - internalize_set(e, p); + var2pdd(expr2enode(e)->get_th_var(get_id())); } void solver::internalize_concat(app* e) { SASSERT(bv.is_concat(e)); - auto p = var2pdd(expr2enode(e)->get_th_var(get_id())); - internalize_set(e, p); + var2pdd(expr2enode(e)->get_th_var(get_id())); } void solver::internalize_par_unary(app* e, std::function const& fn) { diff --git a/src/sat/smt/polysat_model.cpp b/src/sat/smt/polysat_model.cpp index 383a3f692..9a44e0abf 100644 --- a/src/sat/smt/polysat_model.cpp +++ b/src/sat/smt/polysat_model.cpp @@ -18,11 +18,34 @@ Author: #include "params/bv_rewriter_params.hpp" #include "sat/smt/polysat_solver.h" #include "sat/smt/euf_solver.h" +#include "ast/rewriter/bv_rewriter.h" namespace polysat { void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + + if (m_use_intblast_model) { + expr_ref value(m); + if (n->interpreted()) + value = n->get_expr(); + else if (to_app(n->get_expr())->get_family_id() == bv.get_family_id()) { + bv_rewriter rw(m); + expr_ref_vector args(m); + for (auto arg : euf::enode_args(n)) + args.push_back(values.get(arg->get_root_id())); + rw.mk_app(n->get_decl(), args.size(), args.data(), value); + VERIFY(value); + } + else { + rational r = m_intblast.get_value(n->get_expr()); + verbose_stream() << ctx.bpp(n) << " := " << r << "\n"; + value = bv.mk_numeral(r, get_bv_size(n)); + } + values.set(n->get_root_id(), value); + TRACE("model", tout << "add_value " << ctx.bpp(n) << " := " << value << "\n"); + return; + } #if 0 auto p = expr2pdd(n->get_expr()); rational val; @@ -31,9 +54,24 @@ namespace polysat { #endif } + bool solver::add_dep(euf::enode* n, top_sort& dep) { + if (!is_app(n->get_expr())) + return false; + app* e = to_app(n->get_expr()); + if (n->num_args() == 0) { + dep.insert(n, nullptr); + return true; + } + if (e->get_family_id() != bv.get_family_id()) + return false; + for (euf::enode* arg : euf::enode_args(n)) + dep.add(n, arg->get_root()); + return true; + } + bool solver::check_model(sat::model const& m) const { - return false; + return true; } void solver::finalize_model(model& mdl) { @@ -53,6 +91,8 @@ namespace polysat { for (unsigned v = 0; v < get_num_vars(); ++v) if (m_var2pdd_valid.get(v, false)) out << ctx.bpp(var2enode(v)) << " := " << m_var2pdd[v] << "\n"; + if (m_use_intblast_model) + m_intblast.display(out); return out; } } diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index ad4beb561..43f156c7d 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -40,7 +40,7 @@ namespace polysat { m_intblast(ctx), m_lemma(ctx.get_manager()) { - ctx.get_egraph().add_plugin(alloc(euf::bv_plugin, ctx.get_egraph())); + // ctx.get_egraph().add_plugin(alloc(euf::bv_plugin, ctx.get_egraph())); } unsigned solver::get_bv_size(euf::enode* n) { diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index f3435f364..7cf176b0c 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -119,6 +119,9 @@ namespace polysat { void internalize_bnor(app* n); void internalize_bnand(app* n); void internalize_bxnor(app* n); + void internalize_band(app* n); + void internalize_lshr(app* n); + void internalize_shl(app* n); template void internalize_le(app* n); void internalize_zero_extend(app* n); @@ -172,7 +175,7 @@ namespace polysat { std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override; std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override; void collect_statistics(statistics& st) const override {} - euf::th_solver* clone(euf::solver& ctx) override { throw default_exception("nyi"); } + euf::th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx, get_id()); } extension* copy(sat::solver* s) override { throw default_exception("nyi"); } void find_mutexes(literal_vector& lits, vector & mutexes) override {} void gc() override {} @@ -190,6 +193,7 @@ namespace polysat { bool unit_propagate() override; void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; + bool add_dep(euf::enode* n, top_sort& dep) override; bool extract_pb(std::function& card, std::function& pb) override { return false; } From 5622b13ed345c04c1c8bf8e635344cc056531792 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 11 Dec 2023 12:53:51 -0800 Subject: [PATCH 36/89] working on model extraction Signed-off-by: Nikolaj Bjorner --- src/sat/smt/euf_model.cpp | 5 +-- src/sat/smt/intblast_solver.cpp | 47 +++++++++++++++++++++++------ src/sat/smt/intblast_solver.h | 3 +- src/sat/smt/polysat_internalize.cpp | 10 +++--- 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/src/sat/smt/euf_model.cpp b/src/sat/smt/euf_model.cpp index b117ac1e3..073d164be 100644 --- a/src/sat/smt/euf_model.cpp +++ b/src/sat/smt/euf_model.cpp @@ -302,7 +302,7 @@ namespace euf { if (mval != sval) { if (r->bool_var() != sat::null_bool_var) out << "b" << r->bool_var() << " "; - out << bpp(r) << " :=\neval: " << sval << "\nmval: " << mval << "\n"; + out << bpp(r) << " :=\nvalue obtained from model: " << sval << "\nvalue of the root expression: " << mval << "\n"; continue; } if (!m.is_bool(val)) @@ -310,7 +310,7 @@ namespace euf { auto bval = s().value(r->bool_var()); bool tt = l_true == bval; if (tt != m.is_true(sval)) - out << bpp(r) << " :=\neval: " << sval << "\nmval: " << bval << "\n"; + out << bpp(r) << " :=\nvalue according to model: " << sval << "\nvalue of Boolean literal: " << bval << "\n"; } for (euf::enode* r : nodes) if (r) @@ -357,6 +357,7 @@ namespace euf { if (!tt && !mdl.is_true(e)) continue; CTRACE("euf", first, display_validation_failure(tout, mdl, n);); + CTRACE("euf", first, display(tout)); IF_VERBOSE(0, display_validation_failure(verbose_stream(), mdl, n);); (void)first; first = false; diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 0a38597e2..a6fd38213 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -93,18 +93,17 @@ namespace intblast { for (auto const& [src, vi] : m_vars) { auto const& [v, b] = vi; - verbose_stream() << "asserting " << mk_pp(v, m) << " < " << b << "\n"; m_solver->assert_expr(a.mk_le(a.mk_int(0), v)); m_solver->assert_expr(a.mk_lt(v, a.mk_int(b))); } - verbose_stream() << "check\n"; - m_solver->display(verbose_stream()); - verbose_stream() << es << "\n"; + IF_VERBOSE(10, verbose_stream() << "check\n"; + m_solver->display(verbose_stream()); + verbose_stream() << es << "\n"); lbool r = m_solver->check_sat(es); - verbose_stream() << "result " << r << "\n"; + IF_VERBOSE(2, verbose_stream() << "(sat.intblast :result " << r << ")\n"); if (r == l_false) { expr_ref_vector core(m); @@ -112,8 +111,13 @@ namespace intblast { obj_map e2index; for (unsigned i = 0; i < es.size(); ++i) e2index.insert(es.get(i), i); - for (auto e : core) - m_core.push_back(literals[e2index[e]]); + for (auto e : core) { + unsigned idx = e2index[e]; + if (idx < literals.size()) + m_core.push_back(literals[idx]); + else + m_core.push_back(ctx.mk_literal(e)); + } } return r; @@ -128,7 +132,7 @@ namespace intblast { return any_of(subterms::all(expr_ref(e, m)), [&](auto* p) { return bv.is_bv_sort(p->get_sort()); }); } - void solver::sorted_subterms(expr_ref_vector const& es, ptr_vector& sorted) { + void solver::sorted_subterms(expr_ref_vector& es, ptr_vector& sorted) { expr_fast_mark1 visited; for (expr* e : es) { sorted.push_back(e); @@ -144,6 +148,28 @@ namespace intblast { sorted.push_back(arg); } } + + // + // Add ground equalities to ensure the model is valid with respect to the current case splits. + // This may cause more conflicts than necessary. Instead could use intblast on the base level, but using literal + // assignment from complete level. + // E.g., force the solver to completely backtrack, check satisfiability using the assignment obtained under a complete assignment. + // If intblast is SAT, then force the model and literal assignment on the rest of the literals. + // + if (!is_ground(e)) + continue; + euf::enode* n = ctx.get_enode(e); + if (!n) + continue; + if (n == n->get_root()) + continue; + expr* r = n->get_root()->get_expr(); + es.push_back(m.mk_eq(e, r)); + r = es.back(); + if (!visited.is_marked(r)) { + visited.mark(r); + sorted.push_back(r); + } } else if (is_quantifier(e)) { quantifier* q = to_quantifier(e); @@ -163,6 +189,7 @@ namespace intblast { expr_ref_vector args(m); sorted_subterms(es, todo); + for (expr* e : todo) { if (is_quantifier(e)) { quantifier* q = to_quantifier(e); @@ -402,8 +429,8 @@ namespace intblast { } TRACE("bv", - for (unsigned i = 0; i < es.size(); ++i) - tout << mk_pp(es.get(i), m) << " -> " << mk_pp(translated[es.get(i)], m) << "\n"; + for (expr* e : es) + tout << mk_pp(e, m) << "\n->\n" << mk_pp(translated[e], m) << "\n"; ); for (unsigned i = 0; i < es.size(); ++i) diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 33d024be5..c165e1562 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -53,7 +53,8 @@ namespace intblast { bool is_bv(sat::literal lit); void translate(expr_ref_vector& es); - void sorted_subterms(expr_ref_vector const& es, ptr_vector& sorted); + void add_root_equations(expr_ref_vector& es, ptr_vector& sorted); + void sorted_subterms(expr_ref_vector& es, ptr_vector& sorted); public: solver(euf::solver& ctx); diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 4eeec3da4..46c1e293f 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -347,13 +347,13 @@ namespace polysat { unsigned sz2 = sz - arg_sz; var2pdd(expr2enode(e)->get_th_var(get_id())); if (arg_sz == sz) - add_clause(eq_internalize(e, arg), false); + add_clause(eq_internalize(e, arg), nullptr); else { sat::literal lt0 = ctx.mk_literal(bv.mk_slt(arg, bv.mk_numeral(0, arg_sz))); // arg < 0 ==> e = concat(arg, 1...1) // arg >= 0 ==> e = concat(arg, 0...0) - add_clause(lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(rational::power_of_two(sz2) - 1, sz2))), false); - add_clause(~lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz2))), false); + add_clause(lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(rational::power_of_two(sz2) - 1, sz2))), nullptr); + add_clause(~lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz2))), nullptr); } } @@ -364,10 +364,10 @@ namespace polysat { unsigned sz2 = sz - arg_sz; var2pdd(expr2enode(e)->get_th_var(get_id())); if (arg_sz == sz) - add_clause(eq_internalize(e, arg), false); + add_clause(eq_internalize(e, arg), nullptr); else // e = concat(arg, 0...0) - add_clause(eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz2))), false); + add_clause(eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz2))), nullptr); } void solver::internalize_div_rem(app* e, bool is_div) { From 17c480f8370bf91a0ae5315a210ae6c4c30c9238 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 11 Dec 2023 14:51:21 -0800 Subject: [PATCH 37/89] adding band Signed-off-by: Nikolaj Bjorner --- src/ast/arith_decl_plugin.cpp | 14 + src/ast/arith_decl_plugin.h | 6 + src/sat/smt/arith_internalize.cpp | 5 + src/sat/smt/intblast_solver.cpp | 323 ++++++++++++--------- src/sat/smt/polysat/CMakeLists.txt | 1 + src/sat/smt/polysat/constraints.cpp | 31 +- src/sat/smt/polysat/constraints.h | 11 +- src/sat/smt/polysat/core.h | 8 +- src/sat/smt/polysat/ule_constraint.cpp | 18 -- src/sat/smt/polysat/ule_constraint.h | 2 - src/sat/smt/polysat/umul_ovfl_constraint.h | 2 - src/sat/smt/polysat_internalize.cpp | 8 +- 12 files changed, 246 insertions(+), 183 deletions(-) diff --git a/src/ast/arith_decl_plugin.cpp b/src/ast/arith_decl_plugin.cpp index ba2b4b4a7..2d830d510 100644 --- a/src/ast/arith_decl_plugin.cpp +++ b/src/ast/arith_decl_plugin.cpp @@ -523,6 +523,12 @@ func_decl * arith_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters return m_manager->mk_func_decl(symbol("divisible"), 1, &m_int_decl, m_manager->mk_bool_sort(), func_decl_info(m_family_id, k, num_parameters, parameters)); } + if (k == OP_ARITH_BAND) { + if (arity != 2 || domain[0] != m_int_decl || domain[1] != m_int_decl || num_parameters != 1 || !parameters[0].is_int()) + m_manager->raise_exception("invalid bitwise and application. Expects integer parameter and two arguments of sort integer"); + return m_manager->mk_func_decl(symbol("band"), 2, domain, m_int_decl, + func_decl_info(m_family_id, k, num_parameters, parameters)); + } if (m_manager->int_real_coercions() && use_coercion(k)) { return mk_func_decl(fix_kind(k, arity), has_real_arg(arity, domain, m_real_decl)); @@ -548,6 +554,14 @@ func_decl * arith_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters return m_manager->mk_func_decl(symbol("divisible"), 1, &m_int_decl, m_manager->mk_bool_sort(), func_decl_info(m_family_id, k, num_parameters, parameters)); } + if (k == OP_ARITH_BAND) { + if (num_args != 2 || args[0]->get_sort() != m_int_decl || args[1]->get_sort() != m_int_decl || num_parameters != 1 || !parameters[0].is_int()) + m_manager->raise_exception("invalid bitwise and application. Expects integer parameter and two arguments of sort integer"); + sort* domain[2] = { m_int_decl, m_int_decl }; + return m_manager->mk_func_decl(symbol("band"), 2, domain, m_int_decl, + func_decl_info(m_family_id, k, num_parameters, parameters)); + } + if (m_manager->int_real_coercions() && use_coercion(k)) { return mk_func_decl(fix_kind(k, num_args), has_real_arg(m_manager, num_args, args, m_real_decl)); } diff --git a/src/ast/arith_decl_plugin.h b/src/ast/arith_decl_plugin.h index fa359a9a7..a5ab60731 100644 --- a/src/ast/arith_decl_plugin.h +++ b/src/ast/arith_decl_plugin.h @@ -70,6 +70,8 @@ enum arith_op_kind { OP_ASINH, OP_ACOSH, OP_ATANH, + // Bit-vector functions + OP_ARITH_BAND, // constants OP_PI, OP_E, @@ -309,6 +311,8 @@ public: bool is_int_real(sort const * s) const { return s->get_family_id() == arith_family_id; } bool is_int_real(expr const * n) const { return is_int_real(n->get_sort()); } + bool is_band(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_BAND); } + bool is_sin(expr const* n) const { return is_app_of(n, arith_family_id, OP_SIN); } bool is_cos(expr const* n) const { return is_app_of(n, arith_family_id, OP_COS); } bool is_tan(expr const* n) const { return is_app_of(n, arith_family_id, OP_TAN); } @@ -471,6 +475,8 @@ public: app * mk_power(expr* arg1, expr* arg2) { return m_manager.mk_app(arith_family_id, OP_POWER, arg1, arg2); } app * mk_power0(expr* arg1, expr* arg2) { return m_manager.mk_app(arith_family_id, OP_POWER0, arg1, arg2); } + app* mk_band(unsigned n, expr* arg1, expr* arg2) { parameter p(n); expr* args[2] = { arg1, arg2 }; return m_manager.mk_app(arith_family_id, OP_ARITH_BAND, 1, &p, 2, args); } + app * mk_sin(expr * arg) { return m_manager.mk_app(arith_family_id, OP_SIN, arg); } app * mk_cos(expr * arg) { return m_manager.mk_app(arith_family_id, OP_COS, arg); } app * mk_tan(expr * arg) { return m_manager.mk_app(arith_family_id, OP_TAN, arg); } diff --git a/src/sat/smt/arith_internalize.cpp b/src/sat/smt/arith_internalize.cpp index 3c6286317..66b25cb34 100644 --- a/src/sat/smt/arith_internalize.cpp +++ b/src/sat/smt/arith_internalize.cpp @@ -252,6 +252,11 @@ namespace arith { st.to_ensure_var().push_back(n1); st.to_ensure_var().push_back(n2); } + else if (a.is_band(n)) { + // unsupported for now. + found_unsupported(n); + ensure_arg_vars(to_app(n)); + } else if (!a.is_div0(n) && !a.is_mod0(n) && !a.is_idiv0(n) && !a.is_rem0(n) && !a.is_power0(n)) { found_unsupported(n); ensure_arg_vars(to_app(n)); diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index a6fd38213..8bb3faa49 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -19,15 +19,15 @@ Author: namespace intblast { - solver::solver(euf::solver& ctx): - ctx(ctx), + solver::solver(euf::solver& ctx) : + ctx(ctx), s(ctx.s()), m(ctx.get_manager()), bv(m), a(m), m_trail(m) {} - + lbool solver::check() { sat::literal_vector literals; uint_set selected; @@ -76,7 +76,7 @@ namespace intblast { if (s.value(b) == l_true && s.value(a) == l_true && s.lvl(b) < s.lvl(a)) std::swap(a, b); selected.insert(a.index()); - literals.push_back(a); + literals.push_back(a); } m_core.reset(); @@ -98,9 +98,9 @@ namespace intblast { } IF_VERBOSE(10, verbose_stream() << "check\n"; - m_solver->display(verbose_stream()); - verbose_stream() << es << "\n"); - + m_solver->display(verbose_stream()); + verbose_stream() << es << "\n"); + lbool r = m_solver->check_sat(es); IF_VERBOSE(2, verbose_stream() << "(sat.intblast :result " << r << ")\n"); @@ -116,14 +116,14 @@ namespace intblast { if (idx < literals.size()) m_core.push_back(literals[idx]); else - m_core.push_back(ctx.mk_literal(e)); + m_core.push_back(ctx.mk_literal(e)); } } return r; }; - bool solver::is_bv(sat::literal lit) { + bool solver::is_bv(sat::literal lit) { expr* e = ctx.bool_var2expr(lit.var()); if (!e) return false; @@ -185,9 +185,9 @@ namespace intblast { void solver::translate(expr_ref_vector& es) { ptr_vector todo; - obj_map translated; + obj_map translated; expr_ref_vector args(m); - + sorted_subterms(es, todo); for (expr* e : todo) { @@ -236,12 +236,12 @@ namespace intblast { if (m_vars.contains(x)) return x; return to_expr(a.mk_mod(x, a.mk_int(bv_size()))); - }; + }; auto mk_smod = [&](expr* x) { auto shift = bv_size() / 2; return a.mk_mod(a.mk_add(x, a.mk_int(shift)), a.mk_int(bv_size())); - }; + }; if (m.is_eq(e)) { bool has_bv_arg = any_of(*ap, [&](expr* arg) { return bv.is_bv(arg); }); @@ -256,7 +256,7 @@ namespace intblast { } continue; } - + if (m.is_ite(e)) { m_trail.push_back(m.mk_ite(args.get(0), args.get(1), args.get(2))); translated.insert(e, m_trail.back()); @@ -287,144 +287,179 @@ namespace intblast { } f = g; } - + m_trail.push_back(m.mk_app(f, args)); translated.insert(e, m_trail.back()); - if (has_bv_sort) + if (has_bv_sort) m_vars.insert(e, { m_trail.back(), bv_size() }); - + continue; } + auto bnot = [&](expr* e) { + return a.mk_sub(a.mk_int(-1), e); + }; + + auto band = [&](expr_ref_vector const& args) { + expr * r = args.get(0); + for (unsigned i = 1; i < args.size(); ++i) + r = a.mk_band(bv.get_bv_size(e), r, args.get(i)); + return r; + }; + switch (ap->get_decl_kind()) { - case OP_BADD: - m_trail.push_back(a.mk_add(args)); - break; - case OP_BSUB: - m_trail.push_back(a.mk_sub(args.size(), args.data())); - break; - case OP_BMUL: - m_trail.push_back(a.mk_mul(args)); - break; - case OP_ULEQ: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_le(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_UGEQ: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_ge(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_ULT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_lt(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_UGT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_gt(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_SLEQ: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_le(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_SGEQ: - m_trail.push_back(a.mk_ge(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_SLT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_lt(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_SGT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_gt(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_BNEG: - m_trail.push_back(a.mk_uminus(args.get(0))); - break; - case OP_CONCAT: { - expr_ref r(a.mk_int(0), m); - unsigned sz = 0; - for (unsigned i = 0; i < args.size(); ++i) { - expr* old_arg = ap->get_arg(i); - expr* new_arg = args.get(i); - bv_expr = old_arg; - new_arg = mk_mod(new_arg); - if (sz > 0) { - new_arg = a.mk_mul(new_arg, a.mk_int(rational::power_of_two(sz))); - r = a.mk_add(r, new_arg); - } - else - r = new_arg; - sz += bv.get_bv_size(old_arg->get_sort()); + case OP_BADD: + m_trail.push_back(a.mk_add(args)); + break; + case OP_BSUB: + m_trail.push_back(a.mk_sub(args.size(), args.data())); + break; + case OP_BMUL: + m_trail.push_back(a.mk_mul(args)); + break; + case OP_ULEQ: + bv_expr = ap->get_arg(0); + m_trail.push_back(a.mk_le(mk_mod(args.get(0)), mk_mod(args.get(1)))); + break; + case OP_UGEQ: + bv_expr = ap->get_arg(0); + m_trail.push_back(a.mk_ge(mk_mod(args.get(0)), mk_mod(args.get(1)))); + break; + case OP_ULT: + bv_expr = ap->get_arg(0); + m_trail.push_back(a.mk_lt(mk_mod(args.get(0)), mk_mod(args.get(1)))); + break; + case OP_UGT: + bv_expr = ap->get_arg(0); + m_trail.push_back(a.mk_gt(mk_mod(args.get(0)), mk_mod(args.get(1)))); + break; + case OP_SLEQ: + bv_expr = ap->get_arg(0); + m_trail.push_back(a.mk_le(mk_smod(args.get(0)), mk_smod(args.get(1)))); + break; + case OP_SGEQ: + m_trail.push_back(a.mk_ge(mk_smod(args.get(0)), mk_smod(args.get(1)))); + break; + case OP_SLT: + bv_expr = ap->get_arg(0); + m_trail.push_back(a.mk_lt(mk_smod(args.get(0)), mk_smod(args.get(1)))); + break; + case OP_SGT: + bv_expr = ap->get_arg(0); + m_trail.push_back(a.mk_gt(mk_smod(args.get(0)), mk_smod(args.get(1)))); + break; + case OP_BNEG: + m_trail.push_back(a.mk_uminus(args.get(0))); + break; + case OP_CONCAT: { + expr_ref r(a.mk_int(0), m); + unsigned sz = 0; + for (unsigned i = 0; i < args.size(); ++i) { + expr* old_arg = ap->get_arg(i); + expr* new_arg = args.get(i); + bv_expr = old_arg; + new_arg = mk_mod(new_arg); + if (sz > 0) { + new_arg = a.mk_mul(new_arg, a.mk_int(rational::power_of_two(sz))); + r = a.mk_add(r, new_arg); } - m_trail.push_back(r); - break; + else + r = new_arg; + sz += bv.get_bv_size(old_arg->get_sort()); } - case OP_EXTRACT: { - unsigned lo, hi; - expr* old_arg; - VERIFY(bv.is_extract(e, lo, hi, old_arg)); - unsigned sz = hi - lo + 1; - expr* new_arg = args.get(0); - if (lo > 0) - new_arg = a.mk_idiv(new_arg, a.mk_int(rational::power_of_two(lo))); - m_trail.push_back(new_arg); - break; - } - case OP_BV_NUM: { - rational val; - unsigned sz; - VERIFY(bv.is_numeral(e, val, sz)); - m_trail.push_back(a.mk_int(val)); - break; - } - case OP_BUREM_I: { - expr* x = args.get(0), * y = args.get(1); - m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); - break; - } - case OP_BUDIV_I: { - expr* x = args.get(0), * y = args.get(1); - m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); - break; - } - case OP_BUMUL_NO_OVFL: { - expr* x = args.get(0), * y = args.get(1); - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_lt(a.mk_mul(mk_mod(x), mk_mod(y)), a.mk_int(bv_size()))); - break; - } - case OP_BNOT: - case OP_BNAND: - case OP_BNOR: - case OP_BXOR: - case OP_BXNOR: - case OP_BCOMP: - case OP_BSHL: - case OP_BLSHR: - case OP_BASHR: - case OP_ROTATE_LEFT: - case OP_ROTATE_RIGHT: - case OP_EXT_ROTATE_LEFT: - case OP_EXT_ROTATE_RIGHT: - case OP_REPEAT: - case OP_ZERO_EXT: - case OP_SIGN_EXT: - case OP_BREDOR: - case OP_BREDAND: - case OP_BUDIV: - case OP_BSDIV: - case OP_BUREM: - case OP_BSREM: - case OP_BSMOD: - case OP_BAND: - verbose_stream() << mk_pp(e, m) << "\n"; - NOT_IMPLEMENTED_YET(); - break; - default: - verbose_stream() << mk_pp(e, m) << "\n"; - NOT_IMPLEMENTED_YET(); - } + m_trail.push_back(r); + break; + } + case OP_EXTRACT: { + unsigned lo, hi; + expr* old_arg; + VERIFY(bv.is_extract(e, lo, hi, old_arg)); + unsigned sz = hi - lo + 1; + expr* new_arg = args.get(0); + if (lo > 0) + new_arg = a.mk_idiv(new_arg, a.mk_int(rational::power_of_two(lo))); + m_trail.push_back(new_arg); + break; + } + case OP_BV_NUM: { + rational val; + unsigned sz; + VERIFY(bv.is_numeral(e, val, sz)); + m_trail.push_back(a.mk_int(val)); + break; + } + case OP_BUREM_I: { + expr* x = args.get(0), * y = args.get(1); + m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + break; + } + case OP_BUDIV_I: { + expr* x = args.get(0), * y = args.get(1); + m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + break; + } + case OP_BUMUL_NO_OVFL: { + expr* x = args.get(0), * y = args.get(1); + bv_expr = ap->get_arg(0); + m_trail.push_back(a.mk_lt(a.mk_mul(mk_mod(x), mk_mod(y)), a.mk_int(bv_size()))); + break; + } + case OP_BSHL: { + expr* x = args.get(0), * y = args.get(1); + expr* r = a.mk_int(0); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + m_trail.push_back(r); + break; + } + case OP_BNOT: + m_trail.push_back(bnot(args.get(0))); + break; + case OP_BLSHR: { + expr* x = args.get(0), * y = args.get(1); + expr* r = a.mk_int(0); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); + m_trail.push_back(r); + break; + } + case OP_BOR: + for (unsigned i = 0; i < args.size(); ++i) + args[i] = bnot(args.get(i)); + m_trail.push_back(bnot(band(args))); + break; + case OP_BNAND: + m_trail.push_back(bnot(band(args))); + break; + case OP_BAND: + m_trail.push_back(band(args)); + break; + case OP_BXOR: + case OP_BXNOR: + case OP_BCOMP: + case OP_BASHR: + case OP_ROTATE_LEFT: + case OP_ROTATE_RIGHT: + case OP_EXT_ROTATE_LEFT: + case OP_EXT_ROTATE_RIGHT: + case OP_REPEAT: + case OP_ZERO_EXT: + case OP_SIGN_EXT: + case OP_BREDOR: + case OP_BREDAND: + case OP_BUDIV: + case OP_BSDIV: + case OP_BUREM: + case OP_BSREM: + case OP_BSMOD: + verbose_stream() << mk_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + break; + default: + verbose_stream() << mk_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + } translated.insert(e, m_trail.back()); } @@ -433,7 +468,7 @@ namespace intblast { tout << mk_pp(e, m) << "\n->\n" << mk_pp(translated[e], m) << "\n"; ); - for (unsigned i = 0; i < es.size(); ++i) + for (unsigned i = 0; i < es.size(); ++i) es[i] = translated[es.get(i)]; @@ -445,7 +480,7 @@ namespace intblast { m_solver->get_model(mdl); expr_ref r(m); var_info vi; - rational val; + rational val; if (!m_vars.find(e, vi)) return rational::zero(); if (!mdl->eval_expr(vi.dst, r, true)) diff --git a/src/sat/smt/polysat/CMakeLists.txt b/src/sat/smt/polysat/CMakeLists.txt index 1f943f48d..c7f5e49d5 100644 --- a/src/sat/smt/polysat/CMakeLists.txt +++ b/src/sat/smt/polysat/CMakeLists.txt @@ -5,6 +5,7 @@ z3_add_component(polysat core.cpp fixed_bits.cpp forbidden_intervals.cpp + op_constraint.cpp ule_constraint.cpp umul_ovfl_constraint.cpp viable.cpp diff --git a/src/sat/smt/polysat/constraints.cpp b/src/sat/smt/polysat/constraints.cpp index 0de987693..83476160a 100644 --- a/src/sat/smt/polysat/constraints.cpp +++ b/src/sat/smt/polysat/constraints.cpp @@ -17,6 +17,7 @@ Author: #include "sat/smt/polysat/constraints.h" #include "sat/smt/polysat/ule_constraint.h" #include "sat/smt/polysat/umul_ovfl_constraint.h" +#include "sat/smt/polysat/op_constraint.h" namespace polysat { @@ -36,6 +37,30 @@ namespace polysat { return signed_constraint(ckind_t::umul_ovfl_t, cnstr); } + signed_constraint constraints::lshr(pdd const& a, pdd const& b, pdd const& r) { + auto* cnstr = alloc(op_constraint, op_constraint::code::lshr_op, a, b, r); + c.trail().push(new_obj_trail(cnstr)); + return signed_constraint(ckind_t::op_t, cnstr); + } + + signed_constraint constraints::ashr(pdd const& a, pdd const& b, pdd const& r) { + auto* cnstr = alloc(op_constraint, op_constraint::code::ashr_op, a, b, r); + c.trail().push(new_obj_trail(cnstr)); + return signed_constraint(ckind_t::op_t, cnstr); + } + + signed_constraint constraints::shl(pdd const& a, pdd const& b, pdd const& r) { + auto* cnstr = alloc(op_constraint, op_constraint::code::shl_op, a, b, r); + c.trail().push(new_obj_trail(cnstr)); + return signed_constraint(ckind_t::op_t, cnstr); + } + + signed_constraint constraints::band(pdd const& a, pdd const& b, pdd const& r) { + auto* cnstr = alloc(op_constraint, op_constraint::code::and_op, a, b, r); + c.trail().push(new_obj_trail(cnstr)); + return signed_constraint(ckind_t::op_t, cnstr); + } + bool signed_constraint::is_eq(pvar& v, rational& val) { if (m_sign) return false; @@ -64,10 +89,4 @@ namespace polysat { return out << *m_constraint; } - bool signed_constraint::is_always_true() const { - return m_sign ? m_constraint->is_always_false() : m_constraint->is_always_true(); - } - bool signed_constraint::is_always_false() const { - return m_sign ? m_constraint->is_always_true() : m_constraint->is_always_false(); - } } diff --git a/src/sat/smt/polysat/constraints.h b/src/sat/smt/polysat/constraints.h index 81ba6f6a0..15d8dfa09 100644 --- a/src/sat/smt/polysat/constraints.h +++ b/src/sat/smt/polysat/constraints.h @@ -41,8 +41,6 @@ namespace polysat { virtual std::ostream& display(std::ostream& out) const = 0; virtual lbool eval() const = 0; virtual lbool eval(assignment const& a) const = 0; - virtual bool is_always_true() const = 0; - virtual bool is_always_false() const = 0; }; inline std::ostream& operator<<(std::ostream& out, constraint const& c) { return c.display(out); } @@ -63,9 +61,10 @@ namespace polysat { unsigned_vector const& vars() const { return m_constraint->vars(); } unsigned var(unsigned idx) const { return m_constraint->var(idx); } bool contains_var(pvar v) const { return m_constraint->contains_var(v); } - bool is_always_true() const; - bool is_always_false() const; + bool is_always_true() const { return eval() == l_true; } + bool is_always_false() const { return eval() == l_false; } lbool eval(assignment& a) const; + lbool eval() const { return m_sign ? ~m_constraint->eval() : m_constraint->eval();} ckind_t op() const { return m_op; } bool is_ule() const { return m_op == ule_t; } bool is_umul_ovfl() const { return m_op == umul_ovfl_t; } @@ -138,6 +137,10 @@ namespace polysat { signed_constraint umul_ovfl(int p, pdd const& q) { return umul_ovfl(rational(p), q); } signed_constraint umul_ovfl(unsigned p, pdd const& q) { return umul_ovfl(rational(p), q); } + signed_constraint lshr(pdd const& a, pdd const& b, pdd const& r); + signed_constraint ashr(pdd const& a, pdd const& b, pdd const& r); + signed_constraint shl(pdd const& a, pdd const& b, pdd const& r); + signed_constraint band(pdd const& a, pdd const& b, pdd const& r); //signed_constraint even(pdd const& p) { return parity_at_least(p, 1); } //signed_constraint odd(pdd const& p) { return ~even(p); } diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index e43633620..c6442a290 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -115,10 +115,10 @@ namespace polysat { signed_constraint bit(pdd const& p, unsigned i) { return m_constraints.bit(p, i); } - void lshr(pdd r, pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("lshr nyi"); } - void ashr(pdd r, pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("ashr nyi"); } - void shl(pdd r, pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("shlh nyi"); } - void band(pdd r, pdd a, pdd b) { NOT_IMPLEMENTED_YET(); throw default_exception("band nyi"); } + signed_constraint lshr(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.lshr(a, b, r); } + signed_constraint ashr(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.ashr(a, b, r); } + signed_constraint shl(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.shl(a, b, r); } + signed_constraint band(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.band(a, b, r); } pdd bnot(pdd p) { return -p - 1; } diff --git a/src/sat/smt/polysat/ule_constraint.cpp b/src/sat/smt/polysat/ule_constraint.cpp index 3d6240bad..185dad0ee 100644 --- a/src/sat/smt/polysat/ule_constraint.cpp +++ b/src/sat/smt/polysat/ule_constraint.cpp @@ -343,22 +343,4 @@ namespace polysat { return eval(a.apply_to(lhs()), a.apply_to(rhs())); } - bool ule_constraint::is_always_true() const { - if (lhs().is_zero()) - return true; // 0 <= p - if (rhs().is_max()) - return true; // p <= -1 - if (lhs().is_val() && rhs().is_val()) - return lhs().val() <= rhs().val(); - return false; - } - - bool ule_constraint::is_always_false() const { - if (lhs().is_never_zero() && rhs().is_zero()) - return true; // p > 0, q = 0 - if (lhs().is_val() && rhs().is_val()) - return lhs().val() > rhs().val(); - return false; - } - } diff --git a/src/sat/smt/polysat/ule_constraint.h b/src/sat/smt/polysat/ule_constraint.h index 0d481c5ea..aa53e6a4f 100644 --- a/src/sat/smt/polysat/ule_constraint.h +++ b/src/sat/smt/polysat/ule_constraint.h @@ -35,8 +35,6 @@ namespace polysat { std::ostream& display(std::ostream& out) const override; lbool eval() const override; lbool eval(assignment const& a) const override; - bool is_always_true() const override; - bool is_always_false() const override; bool is_eq() const { return m_rhs.is_zero(); } unsigned power_of_2() const { return m_lhs.power_of_2(); } diff --git a/src/sat/smt/polysat/umul_ovfl_constraint.h b/src/sat/smt/polysat/umul_ovfl_constraint.h index 4ac03dfb3..c9d03fb01 100644 --- a/src/sat/smt/polysat/umul_ovfl_constraint.h +++ b/src/sat/smt/polysat/umul_ovfl_constraint.h @@ -34,8 +34,6 @@ namespace polysat { std::ostream& display(std::ostream& out) const override; lbool eval() const override; lbool eval(assignment const& a) const override; - bool is_always_true() const override { return false; } // todo port - bool is_always_false() const override { return false; } }; } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 46c1e293f..4496dc759 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -234,7 +234,9 @@ namespace polysat { if (n->get_num_args() == 2) { expr* x, * y; VERIFY(bv.is_bv_and(n, x, y)); - m_core.band(expr2pdd(n), expr2pdd(x), expr2pdd(y)); + auto sc = m_core.band(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + // auto index = m_core.register_constraint(sc, dependency::axiom()); + // } else { expr_ref z(n->get_arg(0), m); @@ -249,13 +251,13 @@ namespace polysat { void solver::internalize_lshr(app* n) { expr* x, * y; VERIFY(bv.is_bv_lshr(n, x, y)); - m_core.lshr(expr2pdd(n), expr2pdd(x), expr2pdd(y)); + auto sc = m_core.lshr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } void solver::internalize_shl(app* n) { expr* x, * y; VERIFY(bv.is_bv_shl(n, x, y)); - m_core.shl(expr2pdd(n), expr2pdd(x), expr2pdd(y)); + auto sc = m_core.shl(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } void solver::internalize_urem_i(app* rem) { From 727a738958e28ab13ffefc5ca92ae0a5728e8420 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 11 Dec 2023 14:51:31 -0800 Subject: [PATCH 38/89] new files --- src/sat/smt/polysat/op_constraint.cpp | 663 ++++++++++++++++++++++++++ src/sat/smt/polysat/op_constraint.h | 94 ++++ 2 files changed, 757 insertions(+) create mode 100644 src/sat/smt/polysat/op_constraint.cpp create mode 100644 src/sat/smt/polysat/op_constraint.h diff --git a/src/sat/smt/polysat/op_constraint.cpp b/src/sat/smt/polysat/op_constraint.cpp new file mode 100644 index 000000000..401b1ca52 --- /dev/null +++ b/src/sat/smt/polysat/op_constraint.cpp @@ -0,0 +1,663 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat constraints for bit operations. + +Author: + + Jakob Rath, Nikolaj Bjorner (nbjorner) 2021-12-09 + +Notes: + +Additional possible functionality on constraints: + +- activate - when operation is first activated. It may be created and only activated later. +- bit-wise assignments - narrow based on bit assignment, not entire word assignment. +- integration with congruence tables +- integration with conflict resolution + +--*/ + +#include "sat/smt/polysat/op_constraint.h" +#include "sat/smt/polysat/core.h" + +namespace polysat { + + op_constraint::op_constraint(code c, pdd const& p, pdd const& q, pdd const& r) : + m_op(c), m_p(p), m_q(q), m_r(r) { + vars().append(p.free_vars()); + for (auto v : q.free_vars()) + if (!vars().contains(v)) + vars().push_back(v); + for (auto v : r.free_vars()) + if (!vars().contains(v)) + vars().push_back(v); + + switch (c) { + case code::and_op: + if (p.index() > q.index()) + std::swap(m_p, m_q); + break; + case code::inv_op: + SASSERT(q.is_zero()); + default: + break; + } + VERIFY(r.is_var()); + } + + lbool op_constraint::eval() const { + return eval(p(), q(), r()); + } + + lbool op_constraint::eval(assignment const& a) const { + return eval(a.apply_to(p()), a.apply_to(q()), a.apply_to(r())); + } + + lbool op_constraint::eval(pdd const& p, pdd const& q, pdd const& r) const { + switch (m_op) { + case code::lshr_op: + return eval_lshr(p, q, r); + case code::shl_op: + return eval_shl(p, q, r); + case code::and_op: + return eval_and(p, q, r); + case code::inv_op: + return eval_inv(p, r); + default: + return l_undef; + } + } + + + /** Evaluate constraint: r == p >> q */ + lbool op_constraint::eval_lshr(pdd const& p, pdd const& q, pdd const& r) { + auto& m = p.manager(); + + if (q.is_zero() && p == r) + return l_true; + + if (q.is_val() && q.val() >= m.power_of_2() && r.is_val()) + return to_lbool(r.is_zero()); + + if (p.is_val() && q.is_val() && r.is_val()) { + SASSERT(q.val().is_unsigned()); // otherwise, previous condition should have been triggered + return to_lbool(r.val() == machine_div2k(p.val(), q.val().get_unsigned())); + } + + // TODO: other cases when we know lower bound of q, + // e.g, q = 2^k*q1 + q2, where q2 is a constant. + return l_undef; + } + + /** Evaluate constraint: r == p << q */ + lbool op_constraint::eval_shl(pdd const& p, pdd const& q, pdd const& r) { + auto& m = p.manager(); + + if (q.is_zero() && p == r) + return l_true; + + if (q.is_val() && q.val() >= m.power_of_2() && r.is_val()) + return to_lbool(r.is_zero()); + + if (p.is_val() && q.is_val() && r.is_val()) { + SASSERT(q.val().is_unsigned()); // otherwise, previous condition should have been triggered + // TODO: use left-shift operation instead of multiplication? + auto factor = rational::power_of_two(q.val().get_unsigned()); + return to_lbool(r == p * m.mk_val(factor)); + } + + // TODO: other cases when we know lower bound of q, + // e.g, q = 2^k*q1 + q2, where q2 is a constant. + // (bounds should be tracked by viable, then just use min_viable here) + return l_undef; + } + + /** Evaluate constraint: r == p & q */ + lbool op_constraint::eval_and(pdd const& p, pdd const& q, pdd const& r) { + if ((p.is_zero() || q.is_zero()) && r.is_zero()) + return l_true; + + if (p.is_val() && q.is_val() && r.is_val()) + return r.val() == bitwise_and(p.val(), q.val()) ? l_true : l_false; + + return l_undef; + } + + /** Evaluate constraint: r == inv p */ + lbool op_constraint::eval_inv(pdd const& p, pdd const& r) { + if (!p.is_val() || !r.is_val()) + return l_undef; + + if (p.is_zero() || r.is_zero()) // the inverse of 0 is 0 (by arbitrary definition). Just to have some unique value + return to_lbool(p.is_zero() && r.is_zero()); + + return to_lbool(p.val().pseudo_inverse(p.power_of_2()) == r.val()); + } + + std::ostream& op_constraint::display(std::ostream& out, lbool status) const { + switch (status) { + case l_true: return display(out, "=="); + case l_false: return display(out, "!="); + default: return display(out, "?="); + } + } + + std::ostream& operator<<(std::ostream& out, op_constraint::code c) { + switch (c) { + case op_constraint::code::ashr_op: + return out << ">>a"; + case op_constraint::code::lshr_op: + return out << ">>"; + case op_constraint::code::shl_op: + return out << "<<"; + case op_constraint::code::and_op: + return out << "&"; + case op_constraint::code::inv_op: + return out << "inv"; + + default: + UNREACHABLE(); + return out; + } + return out; + } + + std::ostream& op_constraint::display(std::ostream& out) const { + return display(out, l_true); + } + + std::ostream& op_constraint::display(std::ostream& out, char const* eq) const { + if (m_op == code::inv_op) + return out << r() << " " << eq << " " << m_op << " " << p(); + + return out << r() << " " << eq << " " << p() << " " << m_op << " " << q(); + } + +#if 0 + /** + * Produce lemmas that contradict the given assignment. + * + * We can assume that op_constraint is only asserted positive. + */ + clause_ref op_constraint::produce_lemma(solver& s, assignment const& a, bool is_positive) { + SASSERT(is_positive); + + if (is_currently_true(a, is_positive)) + return {}; + + return produce_lemma(s, a); + } + + clause_ref op_constraint::produce_lemma(solver& s, assignment const& a) { + switch (m_op) { + case code::lshr_op: + return lemma_lshr(s, a); + case code::shl_op: + return lemma_shl(s, a); + case code::and_op: + return lemma_and(s, a); + case code::inv_op: + return lemma_inv(s, a); + default: + NOT_IMPLEMENTED_YET(); + return {}; + } + } + + /** + * Enforce basic axioms for r == p >> q: + * + * q >= N -> r = 0 + * q >= k -> r[i] = 0 for N - k <= i < N (bit indices range from 0 to N-1, inclusive) + * q >= k -> r <= 2^{N-k} - 1 + * q = k -> r[i] = p[i+k] for 0 <= i < N - k + * r <= p + * q != 0 -> r <= p (subsumed by previous axiom) + * q != 0 /\ p > 0 -> r < p + * q = 0 -> r = p + * p = q -> r = 0 + * + * when q is a constant, several axioms can be enforced at activation time. + * + * Enforce also inferences and bounds + * + * TODO: use also + * s.m_viable.min_viable(); + * s.m_viable.max_viable() + * when r, q are variables. + */ + clause_ref op_constraint::lemma_lshr(solver& s, assignment const& a) { + auto& m = p().manager(); + auto const pv = a.apply_to(p()); + auto const qv = a.apply_to(q()); + auto const rv = a.apply_to(r()); + unsigned const N = m.power_of_2(); + + signed_constraint const lshr(this, true); + + if (pv.is_val() && rv.is_val() && rv.val() > pv.val()) + // r <= p + return s.mk_clause(~lshr, s.ule(r(), p()), true); + else if (qv.is_val() && qv.val() >= N && rv.is_val() && !rv.is_zero()) + // TODO: instead of rv.is_val() && !rv.is_zero(), we should use !is_forced_zero(r) which checks whether eval(r) = 0 or bvalue(r=0) = true; see saturation.cpp + // q >= N -> r = 0 + return s.mk_clause(~lshr, ~s.ule(N, q()), s.eq(r()), true); + else if (qv.is_zero() && pv.is_val() && rv.is_val() && pv != rv) + // q = 0 -> p = r + return s.mk_clause(~lshr, ~s.eq(q()), s.eq(p(), r()), true); + else if (qv.is_val() && !qv.is_zero() && pv.is_val() && rv.is_val() && !pv.is_zero() && rv.val() >= pv.val()) + // q != 0 & p > 0 -> r < p + return s.mk_clause(~lshr, s.eq(q()), s.ule(p(), 0), s.ult(r(), p()), true); + else if (qv.is_val() && !qv.is_zero() && qv.val() < N && rv.is_val() && rv.val() > rational::power_of_two(N - qv.val().get_unsigned()) - 1) + // q >= k -> r <= 2^{N-k} - 1 + return s.mk_clause(~lshr, ~s.ule(qv.val(), q()), s.ule(r(), rational::power_of_two(N - qv.val().get_unsigned()) - 1), true); + // else if (pv == qv && !rv.is_zero()) + // return s.mk_clause(~lshr, ~s.eq(p(), q()), s.eq(r()), true); + else if (pv.is_val() && rv.is_val() && qv.is_val() && !qv.is_zero()) { + unsigned k = qv.val().get_unsigned(); + // q = k -> r[i] = p[i+k] for 0 <= i < N - k + for (unsigned i = 0; i < N - k; ++i) { + if (rv.val().get_bit(i) && !pv.val().get_bit(i + k)) { + return s.mk_clause(~lshr, ~s.eq(q(), k), ~s.bit(r(), i), s.bit(p(), i + k), true); + } + if (!rv.val().get_bit(i) && pv.val().get_bit(i + k)) { + return s.mk_clause(~lshr, ~s.eq(q(), k), s.bit(r(), i), ~s.bit(p(), i + k), true); + } + } + } + else { + // forward propagation + SASSERT(!(pv.is_val() && qv.is_val() && rv.is_val())); + // LOG(p() << " = " << pv << " and " << q() << " = " << qv << " yields [>>] " << r() << " = " << (qv.val().is_unsigned() ? machine_div2k(pv.val(), qv.val().get_unsigned()) : rational::zero())); + if (qv.is_val() && !rv.is_val()) { + rational const& q_val = qv.val(); + if (q_val >= N) + // q >= N ==> r = 0 + return s.mk_clause(~lshr, ~s.ule(N, q()), s.eq(r()), true); + if (pv.is_val()) { + SASSERT(q_val.is_unsigned()); + // p = p_val & q = q_val ==> r = p_val / 2^q_val + rational const r_val = machine_div2k(pv.val(), q_val.get_unsigned()); + return s.mk_clause(~lshr, ~s.eq(p(), pv), ~s.eq(q(), qv), s.eq(r(), r_val), true); + } + } + } + return {}; + } + + + /** + * Enforce axioms for constraint: r == p << q + * + * q >= N -> r = 0 + * q >= k -> r = 0 \/ r >= 2^k + * q >= k -> r[i] = 0 for i < k + * q = k -> r[i+k] = p[i] for 0 <= i < N - k + * q = 0 -> r = p + */ + clause_ref op_constraint::lemma_shl(solver& s, assignment const& a) { + auto& m = p().manager(); + auto const pv = a.apply_to(p()); + auto const qv = a.apply_to(q()); + auto const rv = a.apply_to(r()); + unsigned const N = m.power_of_2(); + + signed_constraint const shl(this, true); + + if (qv.is_val() && qv.val() >= N && rv.is_val() && !rv.is_zero()) + // q >= N -> r = 0 + return s.mk_clause(~shl, ~s.ule(N, q()), s.eq(r()), true); + else if (qv.is_zero() && pv.is_val() && rv.is_val() && rv != pv) + // q = 0 -> r = p + return s.mk_clause(~shl, ~s.eq(q()), s.eq(r(), p()), true); + else if (qv.is_val() && !qv.is_zero() && qv.val() < N && rv.is_val() && + !rv.is_zero() && rv.val() < rational::power_of_two(qv.val().get_unsigned())) + // q >= k -> r = 0 \/ r >= 2^k (intuitive version) + // q >= k -> r - 1 >= 2^k - 1 (equivalent unit constraint to better support narrowing) + return s.mk_clause(~shl, ~s.ule(qv.val(), q()), s.ule(rational::power_of_two(qv.val().get_unsigned()) - 1, r() - 1), true); + else if (pv.is_val() && rv.is_val() && qv.is_val() && !qv.is_zero()) { + unsigned k = qv.val().get_unsigned(); + // q = k -> r[i+k] = p[i] for 0 <= i < N - k + for (unsigned i = 0; i < N - k; ++i) { + if (rv.val().get_bit(i + k) && !pv.val().get_bit(i)) { + return s.mk_clause(~shl, ~s.eq(q(), k), ~s.bit(r(), i + k), s.bit(p(), i), true); + } + if (!rv.val().get_bit(i + k) && pv.val().get_bit(i)) { + return s.mk_clause(~shl, ~s.eq(q(), k), s.bit(r(), i + k), ~s.bit(p(), i), true); + } + } + } + else { + // forward propagation + SASSERT(!(pv.is_val() && qv.is_val() && rv.is_val())); + // LOG(p() << " = " << pv << " and " << q() << " = " << qv << " yields [<<] " << r() << " = " << (qv.val().is_unsigned() ? rational::power_of_two(qv.val().get_unsigned()) * pv.val() : rational::zero())); + if (qv.is_val() && !rv.is_val()) { + rational const& q_val = qv.val(); + if (q_val >= N) + // q >= N ==> r = 0 + return s.mk_clause("shl forward 1", {~shl, ~s.ule(N, q()), s.eq(r())}, true); + if (pv.is_val()) { + SASSERT(q_val.is_unsigned()); + // p = p_val & q = q_val ==> r = p_val * 2^q_val + rational const r_val = pv.val() * rational::power_of_two(q_val.get_unsigned()); + return s.mk_clause("shl forward 2", {~shl, ~s.eq(p(), pv), ~s.eq(q(), qv), s.eq(r(), r_val)}, true); + } + } + } + return {}; + } + + + + void op_constraint::activate_and(solver& s) { + auto x = p(), y = q(); + if (x.is_val()) + std::swap(x, y); + if (!y.is_val()) + return; + auto& m = x.manager(); + auto yv = y.val(); + if (!(yv + 1).is_power_of_two()) + return; + signed_constraint const andc(this, true); + if (yv == m.max_value()) + s.add_clause(~andc, s.eq(x, r()), false); + else if (yv == 0) + s.add_clause(~andc, s.eq(r()), false); + else { + unsigned N = m.power_of_2(); + unsigned k = yv.get_num_bits(); + SASSERT(k < N); + rational exp = rational::power_of_two(N - k); + s.add_clause(~andc, s.eq(x * exp, r() * exp), false); + s.add_clause(~andc, s.ule(r(), y), false); // maybe always activate these constraints regardless? + } + } + + /** + * Produce lemmas for constraint: r == p & q + * r <= p + * r <= q + * p = q => r = p + * p[i] && q[i] = r[i] + * p = 2^N - 1 => q = r + * q = 2^N - 1 => p = r + * p = 2^k - 1 => r*2^{N - k} = q*2^{N - k} + * q = 2^k - 1 => r*2^{N - k} = p*2^{N - k} + * p = 2^k - 1 && r = 0 && q != 0 => q >= 2^k + * q = 2^k - 1 && r = 0 && p != 0 => p >= 2^k + */ + clause_ref op_constraint::lemma_and(solver& s, assignment const& a) { + auto& m = p().manager(); + auto pv = a.apply_to(p()); + auto qv = a.apply_to(q()); + auto rv = a.apply_to(r()); + + signed_constraint const andc(this, true); // op_constraints are always true + + // r <= p + if (pv.is_val() && rv.is_val() && rv.val() > pv.val()) + return s.mk_clause(~andc, s.ule(r(), p()), true); + // r <= q + if (qv.is_val() && rv.is_val() && rv.val() > qv.val()) + return s.mk_clause(~andc, s.ule(r(), q()), true); + // p = q => r = p + if (pv.is_val() && qv.is_val() && rv.is_val() && pv == qv && rv != pv) + return s.mk_clause(~andc, ~s.eq(p(), q()), s.eq(r(), p()), true); + if (pv.is_val() && qv.is_val() && rv.is_val()) { + // p = -1 => r = q + if (pv.is_max() && qv != rv) + return s.mk_clause(~andc, ~s.eq(p(), m.max_value()), s.eq(q(), r()), true); + // q = -1 => r = p + if (qv.is_max() && pv != rv) + return s.mk_clause(~andc, ~s.eq(q(), m.max_value()), s.eq(p(), r()), true); + + unsigned const N = m.power_of_2(); + unsigned pow; + if ((pv.val() + 1).is_power_of_two(pow)) { + // p = 2^k - 1 && r = 0 && q != 0 => q >= 2^k + if (rv.is_zero() && !qv.is_zero() && qv.val() <= pv.val()) + return s.mk_clause(~andc, ~s.eq(p(), pv), ~s.eq(r()), s.eq(q()), s.ule(pv + 1, q()), true); + // p = 2^k - 1 ==> r*2^{N - k} = q*2^{N - k} + if (rv != qv) + return s.mk_clause(~andc, ~s.eq(p(), pv), s.eq(r() * rational::power_of_two(N - pow), q() * rational::power_of_two(N - pow)), true); + } + if ((qv.val() + 1).is_power_of_two(pow)) { + // q = 2^k - 1 && r = 0 && p != 0 ==> p >= 2^k + if (rv.is_zero() && !pv.is_zero() && pv.val() <= qv.val()) + return s.mk_clause(~andc, ~s.eq(q(), qv), ~s.eq(r()), s.eq(p()), s.ule(qv + 1, p()), true); + // q = 2^k - 1 ==> r*2^{N - k} = p*2^{N - k} + if (rv != pv) + return s.mk_clause(~andc, ~s.eq(q(), qv), s.eq(r() * rational::power_of_two(N - pow), p() * rational::power_of_two(N - pow)), true); + } + + for (unsigned i = 0; i < N; ++i) { + bool pb = pv.val().get_bit(i); + bool qb = qv.val().get_bit(i); + bool rb = rv.val().get_bit(i); + if (rb == (pb && qb)) + continue; + if (pb && qb && !rb) + return s.mk_clause(~andc, ~s.bit(p(), i), ~s.bit(q(), i), s.bit(r(), i), true); + else if (!pb && rb) + return s.mk_clause(~andc, s.bit(p(), i), ~s.bit(r(), i), true); + else if (!qb && rb) + return s.mk_clause(~andc, s.bit(q(), i), ~s.bit(r(), i), true); + else + UNREACHABLE(); + } + return {}; + } + + // Propagate r if p or q are 0 + if (pv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated + return s.mk_clause(~andc, s.ule(r(), p()), true); + if (qv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated + return s.mk_clause(~andc, s.ule(r(), q()), true); + // p = a && q = b ==> r = a & b + if (pv.is_val() && qv.is_val() && !rv.is_val()) { + // Just assign by this very weak justification. It will be strengthened in saturation in case of a conflict + LOG(p() << " = " << pv << " and " << q() << " = " << qv << " yields [band] " << r() << " = " << bitwise_and(pv.val(), qv.val())); + return s.mk_clause(~andc, ~s.eq(p(), pv), ~s.eq(q(), qv), s.eq(r(), bitwise_and(pv.val(), qv.val())), true); + } + + return {}; + } + + + + /** + * Produce lemmas for constraint: r == inv p + * p = 0 ==> r = 0 + * r = 0 ==> p = 0 + * p != 0 ==> odd(r) + * parity(p) >= k ==> p * r >= 2^k + * parity(p) < k ==> p * r <= 2^k - 1 + * parity(p) < k ==> r <= 2^(N - k) - 1 (because r is the smallest pseudo-inverse) + */ + clause_ref op_constraint::lemma_inv(solver& s, assignment const& a) { + auto& m = p().manager(); + auto pv = a.apply_to(p()); + auto rv = a.apply_to(r()); + + if (eval_inv(pv, rv) == l_true) + return {}; + + signed_constraint const invc(this, true); + + // p = 0 ==> r = 0 + if (pv.is_zero()) + return s.mk_clause(~invc, ~s.eq(p()), s.eq(r()), true); + // r = 0 ==> p = 0 + if (rv.is_zero()) + return s.mk_clause(~invc, ~s.eq(r()), s.eq(p()), true); + + // forward propagation: p assigned ==> r = pseudo_inverse(eval(p)) + // TODO: (later) this should be propagated instead of adding a clause + /*if (pv.is_val() && !rv.is_val()) + return s.mk_clause(~invc, ~s.eq(p(), pv), s.eq(r(), pv.val().pseudo_inverse(m.power_of_2())), true);*/ + + if (!pv.is_val() || !rv.is_val()) + return {}; + + unsigned parity_pv = pv.val().trailing_zeros(); + unsigned parity_rv = rv.val().trailing_zeros(); + + LOG("p: " << p() << " := " << pv << " parity " << parity_pv); + LOG("r: " << r() << " := " << rv << " parity " << parity_rv); + + // p != 0 ==> odd(r) + if (parity_rv != 0) + return s.mk_clause("r = inv p & p != 0 ==> odd(r)", {~invc, s.eq(p()), s.odd(r())}, true); + + pdd prod = p() * r(); + rational prodv = (pv * rv).val(); +// if (prodv != rational::power_of_two(parity_pv)) +// verbose_stream() << prodv << " " << rational::power_of_two(parity_pv) << " " << parity_pv << " " << pv << " " << rv << "\n"; + unsigned lower = 0, upper = m.power_of_2(); + // binary search for the parity (otw. we would have justifications like "parity_at_most(k) && parity_at_least(k)" for at most "k" widths + while (lower + 1 < upper) { + unsigned middle = (upper + lower) / 2; + LOG("Splitting on " << middle); + if (parity_pv >= middle) { // parity at least middle + lower = middle; + LOG("Its in [" << lower << "; " << upper << ")"); + // parity(p) >= k ==> p * r >= 2^k + if (prodv < rational::power_of_two(middle)) + return s.mk_clause("r = inv p & parity(p) >= k ==> p*r >= 2^k", + {~invc, ~s.parity_at_least(p(), middle), s.uge(prod, rational::power_of_two(middle))}, false); + // parity(p) >= k ==> r <= 2^(N - k) - 1 (because r is the smallest pseudo-inverse) + rational const max_rv = rational::power_of_two(m.power_of_2() - middle) - 1; + if (rv.val() > max_rv) + return s.mk_clause("r = inv p & parity(p) >= k ==> r <= 2^(N - k) - 1", + {~invc, ~s.parity_at_least(p(), middle), s.ule(r(), max_rv)}, false); + } + else { // parity less than middle + SASSERT(parity_pv < middle); + upper = middle; + LOG("Its in [" << lower << "; " << upper << ")"); + // parity(p) < k ==> p * r <= 2^k - 1 + if (prodv > rational::power_of_two(middle)) + return s.mk_clause("r = inv p & parity(p) < k ==> p*r <= 2^k - 1", + {~invc, s.parity_at_least(p(), middle), s.ule(prod, rational::power_of_two(middle) - 1)}, false); + } + } + // Why did it evaluate to false in this case? + UNREACHABLE(); + return {}; + } + + + + void op_constraint::activate_udiv(solver& s) { + // signed_constraint const udivc(this, true); Do we really need this premiss? We anyway assert these constraints as unit clauses + + pdd const& quot = r(); + pdd const& rem = m_linked->r(); + + // Axioms for quotient/remainder: + // a = b*q + r + // multiplication does not overflow in b*q + // addition does not overflow in (b*q) + r; for now expressed as: r <= bq+r + // b ≠ 0 ==> r < b + // b = 0 ==> q = -1 + // TODO: when a,b become evaluable, can we actually propagate q,r? doesn't seem like it. + // Maybe we need something like an op_constraint for better propagation. + s.add_clause(s.eq(q() * quot + rem - p()), false); + s.add_clause(~s.umul_ovfl(q(), quot), false); + // r <= b*q+r + // { apply equivalence: p <= q <=> q-p <= -p-1 } + // b*q <= -r-1 + s.add_clause(s.ule(q() * quot, -rem - 1), false); + + auto c_eq = s.eq(q()); + s.add_clause(c_eq, s.ult(rem, q()), false); + s.add_clause(~c_eq, s.eq(quot + 1), false); + } + + /** + * Produce lemmas for constraint: r == p / q + * q = 0 ==> r = max_value + * p = 0 ==> r = 0 || r = max_value + * q = 1 ==> r = p + */ + clause_ref op_constraint::lemma_udiv(solver& s, assignment const& a) { + auto pv = a.apply_to(p()); + auto qv = a.apply_to(q()); + auto rv = a.apply_to(r()); + + if (eval_udiv(pv, qv, rv) == l_true) + return {}; + + signed_constraint const udivc(this, true); + + if (qv.is_zero() && !rv.is_val()) + return s.mk_clause(~udivc, ~s.eq(q()), s.eq(r(), r().manager().max_value()), true); + if (pv.is_zero() && !rv.is_val()) + return s.mk_clause(~udivc, ~s.eq(p()), s.eq(r()), s.eq(r(), r().manager().max_value()), true); + if (qv.is_one()) + return s.mk_clause(~udivc, ~s.eq(q(), 1), s.eq(r(), p()), true); + + if (pv.is_val() && qv.is_val() && !rv.is_val()) { + SASSERT(!qv.is_zero()); + // TODO: We could actually propagate an interval. Instead of p = 9 & q = 4 => r = 2 we could do p >= 8 && p < 12 && q = 4 => r = 2 + return s.mk_clause(~udivc, ~s.eq(p(), pv.val()), ~s.eq(q(), qv.val()), s.eq(r(), div(pv.val(), qv.val())), true); + } + + return {}; + } + + + /** + * Produce lemmas for constraint: r == p % q + * p = 0 ==> r = 0 + * q = 1 ==> r = 0 + * q = 0 ==> r = p + */ + clause_ref op_constraint::lemma_urem(solver& s, assignment const& a) { + auto pv = a.apply_to(p()); + auto qv = a.apply_to(q()); + auto rv = a.apply_to(r()); + + if (eval_urem(pv, qv, rv) == l_true) + return {}; + + signed_constraint const urem(this, true); + + if (pv.is_zero() && !rv.is_val()) + return s.mk_clause(~urem, ~s.eq(p()), s.eq(r()), true); + if (qv.is_one() && !rv.is_val()) + return s.mk_clause(~urem, ~s.eq(q(), 1), s.eq(r()), true); + if (qv.is_zero()) + return s.mk_clause(~urem, ~s.eq(q()), s.eq(r(), p()), true); + + if (pv.is_val() && qv.is_val() && !rv.is_val()) { + SASSERT(!qv.is_zero()); + return s.mk_clause(~urem, ~s.eq(p(), pv.val()), ~s.eq(q(), qv.val()), s.eq(r(), mod(pv.val(), qv.val())), true); + } + + return {}; + } + + /** Evaluate constraint: r == p % q */ + lbool op_constraint::eval_urem(pdd const& p, pdd const& q, pdd const& r) { + + if (q.is_one() && r.is_val()) { + return r.val().is_zero() ? l_true : l_false; + } + if (q.is_zero()) { + if (r == p) + return l_true; + } + + if (!p.is_val() || !q.is_val() || !r.is_val()) + return l_undef; + + return r.val() == mod(p.val(), q.val()) ? l_true : l_false; // mod == rem as we know hat q > 0 + } + +#endif +} diff --git a/src/sat/smt/polysat/op_constraint.h b/src/sat/smt/polysat/op_constraint.h new file mode 100644 index 000000000..a33f1b705 --- /dev/null +++ b/src/sat/smt/polysat/op_constraint.h @@ -0,0 +1,94 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Op constraint. + + lshr: r == p >> q + ashr: r == p >>a q + lshl: r == p << q + and: r == p & q + not: r == ~p + +Author: + + Jakob Rath, Nikolaj Bjorner (nbjorner) 2021-12-09 + +--*/ +#pragma once +#include "sat/smt/polysat/constraints.h" +#include + +namespace polysat { + + class core; + + class op_constraint final : public constraint { + public: + enum class code { + /// r is the logical right shift of p by q. + lshr_op, + /// r is the arithmetic right shift of p by q. + ashr_op, + /// r is the left shift of p by q. + shl_op, + /// r is the bit-wise 'and' of p and q. + and_op, + /// r is the smallest multiplicative pseudo-inverse of p; + /// by definition we set r == 0 when p == 0. + /// Note that in general, there are 2^parity(p) many pseudo-inverses of p. + inv_op, + }; + protected: + friend class constraints; + + code m_op; + pdd m_p; // operand1 + pdd m_q; // operand2 + pdd m_r; // result + + op_constraint(code c, pdd const& r, pdd const& p, pdd const& q); + lbool eval(pdd const& r, pdd const& p, pdd const& q) const; +// clause_ref produce_lemma(core& s, assignment const& a); + + // clause_ref lemma_lshr(core& s, assignment const& a); + static lbool eval_lshr(pdd const& p, pdd const& q, pdd const& r); + + // clause_ref lemma_shl(core& s, assignment const& a); + static lbool eval_shl(pdd const& p, pdd const& q, pdd const& r); + + // clause_ref lemma_and(core& s, assignment const& a); + static lbool eval_and(pdd const& p, pdd const& q, pdd const& r); + + // clause_ref lemma_inv(core& s, assignment const& a); + static lbool eval_inv(pdd const& p, pdd const& r); + + // clause_ref lemma_udiv(core& s, assignment const& a); + static lbool eval_udiv(pdd const& p, pdd const& q, pdd const& r); + + // clause_ref lemma_urem(core& s, assignment const& a); + static lbool eval_urem(pdd const& p, pdd const& q, pdd const& r); + + std::ostream& display(std::ostream& out, char const* eq) const; + + void activate(core& s); + + void activate_and(core& s); + void activate_udiv(core& s); + + public: + ~op_constraint() override {} + pdd const& p() const { return m_p; } + pdd const& q() const { return m_q; } + pdd const& r() const { return m_r; } + code get_op() const { return m_op; } + std::ostream& display(std::ostream& out, lbool status) const override; + std::ostream& display(std::ostream& out) const override; + lbool eval() const override; + lbool eval(assignment const& a) const override; + bool is_always_true() const { return false; } + bool is_always_false() const { return false; } + }; + +} From 15bae80ceafd880b56963a6d53e1bb191548c35f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 11 Dec 2023 15:00:06 -0800 Subject: [PATCH 39/89] handle more intblast cases Signed-off-by: Nikolaj Bjorner --- src/sat/smt/intblast_solver.cpp | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 8bb3faa49..8a3738cec 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -423,7 +423,8 @@ namespace intblast { r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); m_trail.push_back(r); break; - } + } + // Or use (p + q) - band(p, q)? case OP_BOR: for (unsigned i = 0; i < args.size(); ++i) args[i] = bnot(args.get(i)); @@ -435,8 +436,22 @@ namespace intblast { case OP_BAND: m_trail.push_back(band(args)); break; - case OP_BXOR: + // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; + // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 + // (p + q) - 2*band(p, q); case OP_BXNOR: + case OP_BXOR: { + unsigned sz = bv.get_bv_size(e); + expr* p = args.get(0); + for (unsigned i = 1; i < args.size(); ++i) { + expr* q = args.get(i); + p = a.mk_sub(a.mk_add(p, q), a.mk_mul(a.mk_int(2), a.mk_band(sz, p, q))); + } + if (ap->get_decl_kind() == OP_BXNOR) + p = bnot(p); + m_trail.push_back(p); + break; + } case OP_BCOMP: case OP_BASHR: case OP_ROTATE_LEFT: From b72575148ff59c10d9a584b7558f84754bdf61fe Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 11 Dec 2023 15:45:54 -0800 Subject: [PATCH 40/89] axioms for b-and Signed-off-by: Nikolaj Bjorner --- src/ast/arith_decl_plugin.h | 8 ++++ src/sat/smt/arith_axioms.cpp | 65 +++++++++++++++++++++++++++++++ src/sat/smt/arith_internalize.cpp | 4 +- src/sat/smt/arith_solver.cpp | 2 + src/sat/smt/arith_solver.h | 4 ++ src/sat/smt/intblast_solver.cpp | 27 +++++++++++-- src/sat/smt/intblast_solver.h | 4 +- 7 files changed, 106 insertions(+), 8 deletions(-) diff --git a/src/ast/arith_decl_plugin.h b/src/ast/arith_decl_plugin.h index a5ab60731..b073e205e 100644 --- a/src/ast/arith_decl_plugin.h +++ b/src/ast/arith_decl_plugin.h @@ -312,6 +312,14 @@ public: bool is_int_real(expr const * n) const { return is_int_real(n->get_sort()); } bool is_band(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_BAND); } + bool is_band(expr const* n, unsigned& sz, expr*& x, expr*& y) { + if (!is_band(n)) + return false; + x = to_app(n)->get_arg(0); + y = to_app(n)->get_arg(1); + sz = to_app(n)->get_parameter(0).get_int(); + return true; + } bool is_sin(expr const* n) const { return is_app_of(n, arith_family_id, OP_SIN); } bool is_cos(expr const* n) const { return is_app_of(n, arith_family_id, OP_COS); } diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 173ae28c8..b8bffa5f2 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -205,6 +205,71 @@ namespace arith { add_clause(dgez, neg); } + bool solver::check_band_term(app* n) { + unsigned sz; + expr* x, * y; + VERIFY(a.is_band(n, sz, x, y)); + if (use_nra_model()) { + found_unsupported(n); + return true; + } + theory_var vx = expr2enode(x)->get_th_var(get_id()); + theory_var vy = expr2enode(y)->get_th_var(get_id()); + theory_var xn = expr2enode(n)->get_th_var(get_id()); + rational valx = get_value(vx); + rational valy = get_value(vy); + rational valn = get_value(xn); + + // x mod 2^{i + 1} >= 2^i means the i'th bit is 1. + auto bitof = [&](expr* x, unsigned i) { + expr_ref r(m); + r = a.mk_ge(a.mk_mod(x, a.mk_int(rational::power_of_two(i+1))), a.mk_int(rational::power_of_two(i))); + return mk_literal(r); + }; + for (unsigned i = 0; i < sz; ++i) { + bool xb = valx.get_bit(i); + bool yb = valy.get_bit(i); + bool nb = valn.get_bit(i); + if (xb && yb && !nb) { + add_clause(~bitof(x, i), ~bitof(y, i), bitof(n, i)); + return false; + } + if (nb && !xb) { + add_clause(~bitof(n, i), bitof(x, i)); + return false; + } + if (nb && !yb) { + add_clause(~bitof(n, i), bitof(y, i)); + return false; + } + } + return true; + } + + bool solver::check_band_terms() { + for (app* n : m_band_terms) { + if (!check_band_term(n)) + return false; + } + return true; + } + + /* + * 0 <= x&y < 2^sz + * x&y <= x + * x&y <= y + */ + void solver::mk_band_axiom(app* n) { + unsigned sz; + expr* x, * y; + VERIFY(a.is_band(n, sz, x, y)); + rational N = rational::power_of_two(sz); + add_clause(mk_literal(a.mk_ge(n, a.mk_int(0)))); + add_clause(mk_literal(a.mk_le(n, a.mk_int(N - 1)))); + add_clause(mk_literal(a.mk_le(n, a.mk_mod(x, a.mk_int(N))))); + add_clause(mk_literal(a.mk_le(n, a.mk_mod(y, a.mk_int(N))))); + } + void solver::mk_bound_axioms(api_bound& b) { theory_var v = b.get_var(); lp_api::bound_kind kind1 = b.get_bound_kind(); diff --git a/src/sat/smt/arith_internalize.cpp b/src/sat/smt/arith_internalize.cpp index 66b25cb34..4d0943d65 100644 --- a/src/sat/smt/arith_internalize.cpp +++ b/src/sat/smt/arith_internalize.cpp @@ -254,7 +254,9 @@ namespace arith { } else if (a.is_band(n)) { // unsupported for now. - found_unsupported(n); + m_band_terms.push_back(to_app(n)); + mk_band_axiom(to_app(n)); + ctx.push(push_back_vector(m_band_terms)); ensure_arg_vars(to_app(n)); } else if (!a.is_div0(n) && !a.is_mod0(n) && !a.is_idiv0(n) && !a.is_rem0(n) && !a.is_power0(n)) { diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 2be9b6b60..9e03bbee4 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1197,6 +1197,8 @@ namespace arith { default: UNREACHABLE(); } + if (lia_check == l_true && !check_band_terms()) + lia_check = l_false; return lia_check; } diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 20ae599c2..50cdc63ef 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -214,6 +214,7 @@ namespace arith { expr* m_not_handled = nullptr; ptr_vector m_underspecified; ptr_vector m_idiv_terms; + ptr_vector m_band_terms; vector > m_use_list; // bounds where variables are used. // attributes for incremental version: @@ -317,6 +318,7 @@ namespace arith { void mk_bound_axioms(api_bound& b); void mk_bound_axiom(api_bound& b1, api_bound& b2); void mk_power0_axioms(app* t, app* n); + void mk_band_axiom(app* n); void flush_bound_axioms(); void add_farkas_clause(sat::literal l1, sat::literal l2); @@ -408,6 +410,8 @@ namespace arith { bool check_delayed_eqs(); lbool check_lia(); lbool check_nla(); + bool check_band_terms(); + bool check_band_term(app* n); void add_lemmas(); void propagate_nla(); void add_equality(lpvar v, rational const& k, lp::explanation const& exp); diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 8a3738cec..65dc56e00 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -13,6 +13,7 @@ Author: #include "ast/ast_util.h" #include "ast/for_each_expr.h" +#include "params/bv_rewriter_params.hpp" #include "sat/smt/intblast_solver.h" #include "sat/smt/euf_solver.h" @@ -25,7 +26,8 @@ namespace intblast { m(ctx.get_manager()), bv(m), a(m), - m_trail(m) + m_trail(m), + m_pinned(m) {} lbool solver::check() { @@ -82,7 +84,6 @@ namespace intblast { m_core.reset(); m_vars.reset(); m_trail.reset(); - m_new_funs.reset(); m_solver = mk_smt2_solver(m, s.params(), symbol::null); expr_ref_vector es(m); @@ -284,6 +285,8 @@ namespace intblast { if (!m_new_funs.find(f, g)) { g = m.mk_fresh_func_decl(ap->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); m_new_funs.insert(f, g); + m_pinned.push_back(f); + m_pinned.push_back(g); } f = g; } @@ -452,6 +455,24 @@ namespace intblast { m_trail.push_back(p); break; } + case OP_BUDIV: { + bv_rewriter_params p(ctx.s().params()); + expr* x = args.get(0), * y = args.get(1); + if (p.hi_div0()) + m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + else + m_trail.push_back(a.mk_idiv(x, y)); + break; + } + case OP_BUREM: { + bv_rewriter_params p(ctx.s().params()); + expr* x = args.get(0), * y = args.get(1); + if (p.hi_div0()) + m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + else + m_trail.push_back(a.mk_mod(x, y)); + break; + } case OP_BCOMP: case OP_BASHR: case OP_ROTATE_LEFT: @@ -463,9 +484,7 @@ namespace intblast { case OP_SIGN_EXT: case OP_BREDOR: case OP_BREDAND: - case OP_BUDIV: case OP_BSDIV: - case OP_BUREM: case OP_BSREM: case OP_BSMOD: verbose_stream() << mk_pp(e, m) << "\n"; diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index c165e1562..a093713bb 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -47,13 +47,11 @@ namespace intblast { obj_map m_vars; obj_map m_new_funs; expr_ref_vector m_trail; + ast_ref_vector m_pinned; sat::literal_vector m_core; - - bool is_bv(sat::literal lit); void translate(expr_ref_vector& es); - void add_root_equations(expr_ref_vector& es, ptr_vector& sorted); void sorted_subterms(expr_ref_vector& es, ptr_vector& sorted); public: From c72780d9b92bea096b98c83a6abc031305637665 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 11 Dec 2023 20:22:23 -0800 Subject: [PATCH 41/89] b-and, stats, reinsert variable to heap, debugging --- src/math/lp/lp_api.h | 2 ++ src/sat/smt/arith_axioms.cpp | 34 ++++++++++-------- src/sat/smt/arith_internalize.cpp | 1 - src/sat/smt/arith_solver.cpp | 5 +-- src/sat/smt/intblast_solver.cpp | 55 +++++++++++++++++++++++++++-- src/sat/smt/intblast_solver.h | 8 ++++- src/sat/smt/polysat/core.cpp | 18 ++++++++-- src/sat/smt/polysat/core.h | 3 ++ src/sat/smt/polysat_internalize.cpp | 9 ++++- src/sat/smt/polysat_model.cpp | 35 +++++++----------- src/sat/smt/polysat_solver.cpp | 6 ++++ src/sat/smt/polysat_solver.h | 3 +- 12 files changed, 132 insertions(+), 47 deletions(-) diff --git a/src/math/lp/lp_api.h b/src/math/lp/lp_api.h index 2a4e5058d..0eb8b6b37 100644 --- a/src/math/lp/lp_api.h +++ b/src/math/lp/lp_api.h @@ -108,6 +108,7 @@ namespace lp_api { unsigned m_gomory_cuts; unsigned m_assume_eqs; unsigned m_branch; + unsigned m_band_axioms; stats() { reset(); } void reset() { memset(this, 0, sizeof(*this)); @@ -128,6 +129,7 @@ namespace lp_api { st.update("arith-gomory-cuts", m_gomory_cuts); st.update("arith-assume-eqs", m_assume_eqs); st.update("arith-branch", m_branch); + st.update("arith-band-axioms", m_band_axioms); } }; diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index b8bffa5f2..046470000 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -215,10 +215,15 @@ namespace arith { } theory_var vx = expr2enode(x)->get_th_var(get_id()); theory_var vy = expr2enode(y)->get_th_var(get_id()); - theory_var xn = expr2enode(n)->get_th_var(get_id()); - rational valx = get_value(vx); - rational valy = get_value(vy); - rational valn = get_value(xn); + theory_var vn = expr2enode(n)->get_th_var(get_id()); + rational N = rational::power_of_two(sz); + SASSERT(get_value(vx).is_int()); + SASSERT(get_value(vy).is_int()); + SASSERT(get_value(vn).is_int()); + rational valx = mod(get_value(vx), N); + rational valy = mod(get_value(vy), N); + rational valn = get_value(vn); + SASSERT(0 <= valn && valn < N); // x mod 2^{i + 1} >= 2^i means the i'th bit is 1. auto bitof = [&](expr* x, unsigned i) { @@ -230,26 +235,25 @@ namespace arith { bool xb = valx.get_bit(i); bool yb = valy.get_bit(i); bool nb = valn.get_bit(i); - if (xb && yb && !nb) { + if (xb && yb && !nb) add_clause(~bitof(x, i), ~bitof(y, i), bitof(n, i)); - return false; - } - if (nb && !xb) { + else if (nb && !xb) add_clause(~bitof(n, i), bitof(x, i)); - return false; - } - if (nb && !yb) { + else if (nb && !yb) add_clause(~bitof(n, i), bitof(y, i)); - return false; - } + else + continue; + return false; } return true; } bool solver::check_band_terms() { for (app* n : m_band_terms) { - if (!check_band_term(n)) - return false; + if (!check_band_term(n)) { + ++m_stats.m_band_axioms; + return false; + } } return true; } diff --git a/src/sat/smt/arith_internalize.cpp b/src/sat/smt/arith_internalize.cpp index 4d0943d65..decd49019 100644 --- a/src/sat/smt/arith_internalize.cpp +++ b/src/sat/smt/arith_internalize.cpp @@ -253,7 +253,6 @@ namespace arith { st.to_ensure_var().push_back(n2); } else if (a.is_band(n)) { - // unsupported for now. m_band_terms.push_back(to_app(n)); mk_band_axiom(to_app(n)); ctx.push(push_back_vector(m_band_terms)); diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 9e03bbee4..306a6cce0 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1042,6 +1042,9 @@ namespace arith { if (!check_delayed_eqs()) return sat::check_result::CR_CONTINUE; + if (!check_band_terms()) + return sat::check_result::CR_CONTINUE; + if (ctx.get_config().m_arith_ignore_int && int_undef) return sat::check_result::CR_GIVEUP; if (m_not_handled != nullptr) { @@ -1197,8 +1200,6 @@ namespace arith { default: UNREACHABLE(); } - if (lia_check == l_true && !check_band_terms()) - lia_check = l_false; return lia_check; } diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 65dc56e00..db5798236 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -13,6 +13,7 @@ Author: #include "ast/ast_util.h" #include "ast/for_each_expr.h" +#include "ast/rewriter/bv_rewriter.h" #include "params/bv_rewriter_params.hpp" #include "sat/smt/intblast_solver.h" #include "sat/smt/euf_solver.h" @@ -104,6 +105,8 @@ namespace intblast { lbool r = m_solver->check_sat(es); + m_solver->collect_statistics(m_stats); + IF_VERBOSE(2, verbose_stream() << "(sat.intblast :result " << r << ")\n"); if (r == l_false) { @@ -472,9 +475,32 @@ namespace intblast { else m_trail.push_back(a.mk_mod(x, y)); break; - } + } + // + // ashr(x, y) + // if y = k & x >= 0 -> x / 2^k + // if y = k & x < 0 -> - (x / 2^k) + // + + case OP_BASHR: { + expr* x = args.get(0), * y = args.get(1); + rational N = rational::power_of_two(bv.get_bv_size(e)); + bv_expr = ap; + x = mk_mod(x); + y = mk_mod(y); + expr* signbit = a.mk_ge(x, a.mk_int(N/2)); + expr* r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { + expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); + r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), + m.mk_ite(signbit, a.mk_uminus(d), d), + r); + } + m_trail.push_back(r); + break; + } case OP_BCOMP: - case OP_BASHR: + case OP_ROTATE_LEFT: case OP_ROTATE_RIGHT: case OP_EXT_ROTATE_LEFT: @@ -524,6 +550,27 @@ namespace intblast { return val; } + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + expr_ref value(m); + if (n->interpreted()) + value = n->get_expr(); + else if (to_app(n->get_expr())->get_family_id() == bv.get_family_id()) { + bv_rewriter rw(m); + expr_ref_vector args(m); + for (auto arg : euf::enode_args(n)) + args.push_back(values.get(arg->get_root_id())); + rw.mk_app(n->get_decl(), args.size(), args.data(), value); + VERIFY(value); + } + else { + rational r = get_value(n->get_expr()); + verbose_stream() << ctx.bpp(n) << " := " << r << "\n"; + value = bv.mk_numeral(r, bv.get_bv_size(n->get_expr())); + } + values.set(n->get_root_id(), value); + TRACE("model", tout << "add_value " << ctx.bpp(n) << " := " << value << "\n"); + } + sat::literal_vector const& solver::unsat_core() { return m_core; } @@ -534,4 +581,8 @@ namespace intblast { return out; } + void solver::collect_statistics(statistics& st) const { + st.copy(m_stats); + } + } diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index a093713bb..b87724cc8 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -25,6 +25,7 @@ Author: #include "ast/bv_decl_plugin.h" #include "solver/solver.h" #include "sat/smt/sat_th.h" +#include "util/statistics.h" namespace euf { class solver; @@ -49,11 +50,14 @@ namespace intblast { expr_ref_vector m_trail; ast_ref_vector m_pinned; sat::literal_vector m_core; + statistics m_stats; bool is_bv(sat::literal lit); void translate(expr_ref_vector& es); void sorted_subterms(expr_ref_vector& es, ptr_vector& sorted); + rational get_value(expr* e) const; + public: solver(euf::solver& ctx); @@ -61,9 +65,11 @@ namespace intblast { sat::literal_vector const& unsat_core(); - rational get_value(expr* e) const; + void add_value(euf::enode* n, model& mdl, expr_ref_vector& values); std::ostream& display(std::ostream& out) const; + + void collect_statistics(statistics& st) const; }; } diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index 8e779923d..a552bb9ab 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -119,7 +119,7 @@ namespace polysat { m_activity.pop_back(); m_justification.pop_back(); m_watch.pop_back(); - m_values.pop_back(); + m_values.pop_back(); m_var_queue.del_var_eh(v); } @@ -160,6 +160,7 @@ namespace polysat { s.add_eq_literal(m_var, m_value); return sat::check_result::CR_CONTINUE; case find_t::resource_out: + m_var_queue.unassign_var_eh(m_var); return sat::check_result::CR_GIVEUP; } UNREACHABLE(); @@ -342,8 +343,21 @@ namespace polysat { for (auto const& [sc, d, value] : m_constraint_index) out << sc << " " << d << " := " << value << "\n"; for (unsigned i = 0; i < m_vars.size(); ++i) - out << "p" << m_vars[i] << " := " << m_values[i] << " " << m_justification[i] << "\n"; + out << m_vars[i] << " := " << m_values[i] << " " << m_justification[i] << "\n"; + m_var_queue.display(out << "vars ") << "\n"; return out; } + bool core::try_eval(pdd const& p, rational& r) { + auto q = subst(p); + if (!q.is_val()) + return false; + r = q.val(); + return true; + } + + void core::collect_statistics(statistics& st) const { + + } + } diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index c6442a290..c3dddfece 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -104,6 +104,9 @@ namespace polysat { pdd value(rational const& v, unsigned sz); pdd subst(pdd const&); + bool try_eval(pdd const& p, rational& r); + + void collect_statistics(statistics& st) const; signed_constraint eq(pdd const& p) { return m_constraints.eq(p); } signed_constraint eq(pdd const& p, pdd const& q) { return m_constraints.eq(p - q); } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 4496dc759..5e5647bd3 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -86,6 +86,7 @@ namespace polysat { case OP_BSUB: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p - q; }); break; case OP_BLSHR: internalize_lshr(a); break; case OP_BSHL: internalize_shl(a); break; + case OP_BASHR: internalize_ashr(a); break; case OP_BAND: internalize_band(a); break; case OP_BOR: internalize_bor(a); break; case OP_BXOR: internalize_bxor(a); break; @@ -148,7 +149,7 @@ namespace polysat { case OP_BSDIV_I: case OP_BSREM_I: case OP_BSMOD_I: - case OP_BASHR: + IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); NOT_IMPLEMENTED_YET(); return; @@ -254,6 +255,12 @@ namespace polysat { auto sc = m_core.lshr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } + void solver::internalize_ashr(app* n) { + expr* x, * y; + VERIFY(bv.is_bv_ashr(n, x, y)); + auto sc = m_core.ashr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + } + void solver::internalize_shl(app* n) { expr* x, * y; VERIFY(bv.is_bv_shl(n, x, y)); diff --git a/src/sat/smt/polysat_model.cpp b/src/sat/smt/polysat_model.cpp index 9a44e0abf..5bd8d4dc9 100644 --- a/src/sat/smt/polysat_model.cpp +++ b/src/sat/smt/polysat_model.cpp @@ -26,32 +26,18 @@ namespace polysat { void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { if (m_use_intblast_model) { - expr_ref value(m); - if (n->interpreted()) - value = n->get_expr(); - else if (to_app(n->get_expr())->get_family_id() == bv.get_family_id()) { - bv_rewriter rw(m); - expr_ref_vector args(m); - for (auto arg : euf::enode_args(n)) - args.push_back(values.get(arg->get_root_id())); - rw.mk_app(n->get_decl(), args.size(), args.data(), value); - VERIFY(value); - } - else { - rational r = m_intblast.get_value(n->get_expr()); - verbose_stream() << ctx.bpp(n) << " := " << r << "\n"; - value = bv.mk_numeral(r, get_bv_size(n)); - } - values.set(n->get_root_id(), value); - TRACE("model", tout << "add_value " << ctx.bpp(n) << " := " << value << "\n"); + m_intblast.add_value(n, mdl, values); return; } -#if 0 auto p = expr2pdd(n->get_expr()); rational val; - VERIFY(m_polysat.try_eval(p, val)); - values[n->get_root_id()] = bv.mk_numeral(val, get_bv_size(n)); -#endif + if (!m_core.try_eval(p, val)) { + ctx.s().display(verbose_stream()); + verbose_stream() << ctx.bpp(n) << " := " << p << "\n"; + UNREACHABLE(); + } + VERIFY(m_core.try_eval(p, val)); + values.set(n->get_root_id(), bv.mk_numeral(val, get_bv_size(n))); } bool solver::add_dep(euf::enode* n, top_sort& dep) { @@ -78,6 +64,11 @@ namespace polysat { } + void solver::collect_statistics(statistics& st) const { + m_intblast.collect_statistics(st); + m_core.collect_statistics(st); + } + std::ostream& solver::display_justification(std::ostream& out, sat::ext_justification_idx idx) const { return out; } diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 43f156c7d..9f185b22d 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -184,6 +184,9 @@ namespace polysat { void solver::new_eq_eh(euf::th_eq const& eq) { auto v1 = eq.v1(), v2 = eq.v2(); + euf::enode* n = var2enode(v1); + if (!bv.is_bv(n->get_expr())) + return; pdd p = var2pdd(v1); pdd q = var2pdd(v2); auto sc = m_core.eq(p, q); @@ -197,6 +200,9 @@ namespace polysat { void solver::new_diseq_eh(euf::th_eq const& ne) { euf::theory_var v1 = ne.v1(), v2 = ne.v2(); + euf::enode* n = var2enode(v1); + if (!bv.is_bv(n->get_expr())) + return; pdd p = var2pdd(v1); pdd q = var2pdd(v2); auto sc = ~m_core.eq(p, q); diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 7cf176b0c..f54bafb1c 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -121,6 +121,7 @@ namespace polysat { void internalize_bxnor(app* n); void internalize_band(app* n); void internalize_lshr(app* n); + void internalize_ashr(app* n); void internalize_shl(app* n); template void internalize_le(app* n); @@ -174,7 +175,7 @@ namespace polysat { std::ostream& display(std::ostream& out) const override; std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override; std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override; - void collect_statistics(statistics& st) const override {} + void collect_statistics(statistics& st) const override; euf::th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx, get_id()); } extension* copy(sat::solver* s) override { throw default_exception("nyi"); } void find_mutexes(literal_vector& lits, vector & mutexes) override {} From 4cadf6d9f2a55679446c3bf887a25668cbe46799 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 12 Dec 2023 11:11:37 -0800 Subject: [PATCH 42/89] preparing intblaster as self-contained solver. add activate and propagate to constraints support axiomatized operators band, lsh, rshl, rsha --- src/sat/smt/intblast_solver.cpp | 763 +++++++++++-------- src/sat/smt/intblast_solver.h | 76 +- src/sat/smt/polysat/constraints.h | 6 + src/sat/smt/polysat/core.cpp | 21 +- src/sat/smt/polysat/core.h | 19 +- src/sat/smt/polysat/op_constraint.cpp | 431 ++++------- src/sat/smt/polysat/op_constraint.h | 26 +- src/sat/smt/polysat/types.h | 12 +- src/sat/smt/polysat/ule_constraint.cpp | 15 +- src/sat/smt/polysat/ule_constraint.h | 2 + src/sat/smt/polysat/umul_ovfl_constraint.cpp | 82 ++ src/sat/smt/polysat/umul_ovfl_constraint.h | 4 + src/sat/smt/polysat_internalize.cpp | 10 +- src/sat/smt/polysat_solver.cpp | 23 +- src/sat/smt/polysat_solver.h | 2 +- 15 files changed, 831 insertions(+), 661 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index db5798236..0505eaa92 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -22,16 +22,109 @@ Author: namespace intblast { solver::solver(euf::solver& ctx) : + th_euf_solver(ctx, symbol("intblast"), ctx.get_manager().get_family_id("bv")), ctx(ctx), s(ctx.s()), m(ctx.get_manager()), bv(m), a(m), - m_trail(m), + m_args(m), + m_translate(m), m_pinned(m) {} - lbool solver::check() { + euf::theory_var solver::mk_var(euf::enode* n) { + auto r = euf::th_euf_solver::mk_var(n); + ctx.attach_th_var(n, this, r); + TRACE("bv", tout << "mk-var: v" << r << " " << ctx.bpp(n) << "\n";); + return r; + } + + sat::literal solver::internalize(expr* e, bool sign, bool root) { + force_push(); + SASSERT(m.is_bool(e)); + if (!visit_rec(m, e, sign, root)) + return sat::null_literal; + sat::literal lit = expr2literal(e); + if (sign) + lit.neg(); + return lit; + } + + void solver::internalize(expr* e) { + force_push(); + visit_rec(m, e, false, false); + } + + bool solver::visit(expr* e) { + if (!is_app(e) || to_app(e)->get_family_id() != get_id()) { + ctx.internalize(e); + return true; + } + m_stack.push_back(sat::eframe(e)); + return false; + } + + bool solver::visited(expr* e) { + euf::enode* n = expr2enode(e); + return n && n->is_attached_to(get_id()); + } + + bool solver::post_visit(expr* e, bool sign, bool root) { + euf::enode* n = expr2enode(e); + app* a = to_app(e); + if (visited(e)) + return true; + SASSERT(!n || !n->is_attached_to(get_id())); + if (!n) + n = mk_enode(e, false); + SASSERT(!n->is_attached_to(get_id())); + mk_var(n); + SASSERT(n->is_attached_to(get_id())); + internalize_bv(a); + return true; + } + + void solver::internalize_bv(app* e) { + ensure_args(e); + m_args.reset(); + for (auto arg : *e) + m_args.push_back(translated(arg)); + translate_bv(e); + if (m.is_bool(e)) + add_equiv(expr2literal(e), mk_literal(translated(e))); + } + + void solver::ensure_args(app* e) { + ptr_vector todo; + ast_fast_mark1 visited; + for (auto arg : *e) { + if (!m_translate.get(arg->get_id(), nullptr)) + todo.push_back(arg); + } + if (todo.empty()) + return; + for (unsigned i = 0; i < todo.size(); ++i) { + expr* e = todo[i]; + if (is_app(e)) { + for (auto arg : *to_app(e)) + if (!visited.is_marked(arg)) { + visited.mark(arg); + todo.push_back(arg); + } + } + else if (is_quantifier(e) && !visited.is_marked(to_quantifier(e)->get_expr())) { + visited.mark(to_quantifier(e)->get_expr()); + todo.push_back(to_quantifier(e)->get_expr()); + } + } + + std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); + for (expr* e : todo) + translate_expr(e); + } + + lbool solver::check_solver_state() { sat::literal_vector literals; uint_set selected; for (auto const& clause : s.clauses()) { @@ -84,7 +177,7 @@ namespace intblast { m_core.reset(); m_vars.reset(); - m_trail.reset(); + m_translate.reset(); m_solver = mk_smt2_solver(m, s.params(), symbol::null); expr_ref_vector es(m); @@ -123,7 +216,6 @@ namespace intblast { m_core.push_back(ctx.mk_literal(e)); } } - return r; }; @@ -189,349 +281,348 @@ namespace intblast { void solver::translate(expr_ref_vector& es) { ptr_vector todo; - obj_map translated; - expr_ref_vector args(m); sorted_subterms(es, todo); - for (expr* e : todo) { - if (is_quantifier(e)) { - quantifier* q = to_quantifier(e); - expr* b = q->get_expr(); - - unsigned nd = q->get_num_decls(); - ptr_vector sorts; - for (unsigned i = 0; i < nd; ++i) { - auto s = q->get_decl_sort(i); - if (bv.is_bv_sort(s)) { - NOT_IMPLEMENTED_YET(); - sorts.push_back(a.mk_int()); - } - else - sorts.push_back(s); - } - b = translated[b]; - // TODO if sorts contain integer, then created bounds variables. - m_trail.push_back(m.update_quantifier(q, b)); - translated.insert(e, m_trail.back()); - continue; - } - if (is_var(e)) { - if (bv.is_bv_sort(e->get_sort())) { - expr* v = m.mk_var(to_var(e)->get_idx(), a.mk_int()); - m_trail.push_back(v); - translated.insert(e, m_trail.back()); - } - else { - m_trail.push_back(e); - translated.insert(e, m_trail.back()); - } - continue; - } - app* ap = to_app(e); - expr* bv_expr = e; - args.reset(); - for (auto arg : *ap) - args.push_back(translated[arg]); - - auto bv_size = [&]() { return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); }; - - auto mk_mod = [&](expr* x) { - if (m_vars.contains(x)) - return x; - return to_expr(a.mk_mod(x, a.mk_int(bv_size()))); - }; - - auto mk_smod = [&](expr* x) { - auto shift = bv_size() / 2; - return a.mk_mod(a.mk_add(x, a.mk_int(shift)), a.mk_int(bv_size())); - }; - - if (m.is_eq(e)) { - bool has_bv_arg = any_of(*ap, [&](expr* arg) { return bv.is_bv(arg); }); - if (has_bv_arg) { - bv_expr = ap->get_arg(0); - m_trail.push_back(m.mk_eq(mk_mod(args.get(0)), mk_mod(args.get(1)))); - translated.insert(e, m_trail.back()); - } - else { - m_trail.push_back(m.mk_eq(args.get(0), args.get(1))); - translated.insert(e, m_trail.back()); - } - continue; - } - - if (m.is_ite(e)) { - m_trail.push_back(m.mk_ite(args.get(0), args.get(1), args.get(2))); - translated.insert(e, m_trail.back()); - continue; - } - - if (ap->get_family_id() != bv.get_family_id()) { - bool has_bv_arg = any_of(*ap, [&](expr* arg) { return bv.is_bv(arg); }); - bool has_bv_sort = bv.is_bv(e); - func_decl* f = ap->get_decl(); - if (has_bv_arg) { - verbose_stream() << mk_pp(ap, m) << "\n"; - // need to update args with mod where they are bit-vectors. - NOT_IMPLEMENTED_YET(); - } - - if (has_bv_arg || has_bv_sort) { - ptr_vector domain; - for (auto* arg : *ap) { - sort* s = arg->get_sort(); - domain.push_back(bv.is_bv_sort(s) ? a.mk_int() : s); - } - sort* range = bv.is_bv(e) ? a.mk_int() : e->get_sort(); - func_decl* g = nullptr; - if (!m_new_funs.find(f, g)) { - g = m.mk_fresh_func_decl(ap->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); - m_new_funs.insert(f, g); - m_pinned.push_back(f); - m_pinned.push_back(g); - } - f = g; - } - - m_trail.push_back(m.mk_app(f, args)); - translated.insert(e, m_trail.back()); - - if (has_bv_sort) - m_vars.insert(e, { m_trail.back(), bv_size() }); - - continue; - } - - auto bnot = [&](expr* e) { - return a.mk_sub(a.mk_int(-1), e); - }; - - auto band = [&](expr_ref_vector const& args) { - expr * r = args.get(0); - for (unsigned i = 1; i < args.size(); ++i) - r = a.mk_band(bv.get_bv_size(e), r, args.get(i)); - return r; - }; - - switch (ap->get_decl_kind()) { - case OP_BADD: - m_trail.push_back(a.mk_add(args)); - break; - case OP_BSUB: - m_trail.push_back(a.mk_sub(args.size(), args.data())); - break; - case OP_BMUL: - m_trail.push_back(a.mk_mul(args)); - break; - case OP_ULEQ: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_le(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_UGEQ: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_ge(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_ULT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_lt(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_UGT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_gt(mk_mod(args.get(0)), mk_mod(args.get(1)))); - break; - case OP_SLEQ: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_le(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_SGEQ: - m_trail.push_back(a.mk_ge(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_SLT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_lt(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_SGT: - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_gt(mk_smod(args.get(0)), mk_smod(args.get(1)))); - break; - case OP_BNEG: - m_trail.push_back(a.mk_uminus(args.get(0))); - break; - case OP_CONCAT: { - expr_ref r(a.mk_int(0), m); - unsigned sz = 0; - for (unsigned i = 0; i < args.size(); ++i) { - expr* old_arg = ap->get_arg(i); - expr* new_arg = args.get(i); - bv_expr = old_arg; - new_arg = mk_mod(new_arg); - if (sz > 0) { - new_arg = a.mk_mul(new_arg, a.mk_int(rational::power_of_two(sz))); - r = a.mk_add(r, new_arg); - } - else - r = new_arg; - sz += bv.get_bv_size(old_arg->get_sort()); - } - m_trail.push_back(r); - break; - } - case OP_EXTRACT: { - unsigned lo, hi; - expr* old_arg; - VERIFY(bv.is_extract(e, lo, hi, old_arg)); - unsigned sz = hi - lo + 1; - expr* new_arg = args.get(0); - if (lo > 0) - new_arg = a.mk_idiv(new_arg, a.mk_int(rational::power_of_two(lo))); - m_trail.push_back(new_arg); - break; - } - case OP_BV_NUM: { - rational val; - unsigned sz; - VERIFY(bv.is_numeral(e, val, sz)); - m_trail.push_back(a.mk_int(val)); - break; - } - case OP_BUREM_I: { - expr* x = args.get(0), * y = args.get(1); - m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); - break; - } - case OP_BUDIV_I: { - expr* x = args.get(0), * y = args.get(1); - m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); - break; - } - case OP_BUMUL_NO_OVFL: { - expr* x = args.get(0), * y = args.get(1); - bv_expr = ap->get_arg(0); - m_trail.push_back(a.mk_lt(a.mk_mul(mk_mod(x), mk_mod(y)), a.mk_int(bv_size()))); - break; - } - case OP_BSHL: { - expr* x = args.get(0), * y = args.get(1); - expr* r = a.mk_int(0); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); - m_trail.push_back(r); - break; - } - case OP_BNOT: - m_trail.push_back(bnot(args.get(0))); - break; - case OP_BLSHR: { - expr* x = args.get(0), * y = args.get(1); - expr* r = a.mk_int(0); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); - m_trail.push_back(r); - break; - } - // Or use (p + q) - band(p, q)? - case OP_BOR: - for (unsigned i = 0; i < args.size(); ++i) - args[i] = bnot(args.get(i)); - m_trail.push_back(bnot(band(args))); - break; - case OP_BNAND: - m_trail.push_back(bnot(band(args))); - break; - case OP_BAND: - m_trail.push_back(band(args)); - break; - // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; - // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 - // (p + q) - 2*band(p, q); - case OP_BXNOR: - case OP_BXOR: { - unsigned sz = bv.get_bv_size(e); - expr* p = args.get(0); - for (unsigned i = 1; i < args.size(); ++i) { - expr* q = args.get(i); - p = a.mk_sub(a.mk_add(p, q), a.mk_mul(a.mk_int(2), a.mk_band(sz, p, q))); - } - if (ap->get_decl_kind() == OP_BXNOR) - p = bnot(p); - m_trail.push_back(p); - break; - } - case OP_BUDIV: { - bv_rewriter_params p(ctx.s().params()); - expr* x = args.get(0), * y = args.get(1); - if (p.hi_div0()) - m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); - else - m_trail.push_back(a.mk_idiv(x, y)); - break; - } - case OP_BUREM: { - bv_rewriter_params p(ctx.s().params()); - expr* x = args.get(0), * y = args.get(1); - if (p.hi_div0()) - m_trail.push_back(m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); - else - m_trail.push_back(a.mk_mod(x, y)); - break; - } - // - // ashr(x, y) - // if y = k & x >= 0 -> x / 2^k - // if y = k & x < 0 -> - (x / 2^k) - // - - case OP_BASHR: { - expr* x = args.get(0), * y = args.get(1); - rational N = rational::power_of_two(bv.get_bv_size(e)); - bv_expr = ap; - x = mk_mod(x); - y = mk_mod(y); - expr* signbit = a.mk_ge(x, a.mk_int(N/2)); - expr* r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { - expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), - m.mk_ite(signbit, a.mk_uminus(d), d), - r); - } - m_trail.push_back(r); - break; - } - case OP_BCOMP: - - case OP_ROTATE_LEFT: - case OP_ROTATE_RIGHT: - case OP_EXT_ROTATE_LEFT: - case OP_EXT_ROTATE_RIGHT: - case OP_REPEAT: - case OP_ZERO_EXT: - case OP_SIGN_EXT: - case OP_BREDOR: - case OP_BREDAND: - case OP_BSDIV: - case OP_BSREM: - case OP_BSMOD: - verbose_stream() << mk_pp(e, m) << "\n"; - NOT_IMPLEMENTED_YET(); - break; - default: - verbose_stream() << mk_pp(e, m) << "\n"; - NOT_IMPLEMENTED_YET(); - } - translated.insert(e, m_trail.back()); - } + for (expr* e : todo) + translate_expr(e); TRACE("bv", for (expr* e : es) - tout << mk_pp(e, m) << "\n->\n" << mk_pp(translated[e], m) << "\n"; + tout << mk_pp(e, m) << "\n->\n" << mk_pp(translated(e), m) << "\n"; ); for (unsigned i = 0; i < es.size(); ++i) - es[i] = translated[es.get(i)]; + es[i] = translated(es.get(i)); + } + expr* solver::mk_mod(expr* x) { + if (m_vars.contains(x)) + return x; + return to_expr(a.mk_mod(x, a.mk_int(bv_size()))); + } + expr* solver::mk_smod(expr* x) { + auto shift = bv_size() / 2; + return a.mk_mod(a.mk_add(x, a.mk_int(shift)), a.mk_int(bv_size())); + } + + rational solver::bv_size() { + return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); + } + + void solver::translate_expr(expr* e) { + if (is_quantifier(e)) + translate_quantifier(to_quantifier(e)); + else if (is_var(e)) + translate_var(to_var(e)); + else { + app* ap = to_app(e); + bv_expr = e; + m_args.reset(); + for (auto arg : *ap) + m_args.push_back(translated(arg)); + + if (ap->get_family_id() == basic_family_id) + translate_basic(ap); + else if (ap->get_family_id() == bv.get_family_id()) + translate_bv(ap); + else + translate_app(ap); + } + } + + void solver::translate_quantifier(quantifier* q) { + expr* b = q->get_expr(); + unsigned nd = q->get_num_decls(); + ptr_vector sorts; + for (unsigned i = 0; i < nd; ++i) { + auto s = q->get_decl_sort(i); + if (bv.is_bv_sort(s)) { + NOT_IMPLEMENTED_YET(); + sorts.push_back(a.mk_int()); + } + else + sorts.push_back(s); + } + b = translated(b); + // TODO if sorts contain integer, then created bounds variables. + set_translated(q, m.update_quantifier(q, b)); + } + + void solver::translate_var(var* v) { + if (bv.is_bv_sort(v->get_sort())) + set_translated(v, m.mk_var(v->get_idx(), a.mk_int())); + else + set_translated(v, v); + } + + void solver::translate_app(app* e) { + bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); + bool has_bv_sort = bv.is_bv(e); + func_decl* f = e->get_decl(); + if (has_bv_arg) { + verbose_stream() << mk_pp(e, m) << "\n"; + // need to update args with mod where they are bit-vectors. + NOT_IMPLEMENTED_YET(); + } + + if (has_bv_arg || has_bv_sort) { + ptr_vector domain; + for (auto* arg : *e) { + sort* s = arg->get_sort(); + domain.push_back(bv.is_bv_sort(s) ? a.mk_int() : s); + } + sort* range = bv.is_bv(e) ? a.mk_int() : e->get_sort(); + func_decl* g = nullptr; + if (!m_new_funs.find(f, g)) { + g = m.mk_fresh_func_decl(e->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); + m_new_funs.insert(f, g); + m_pinned.push_back(f); + m_pinned.push_back(g); + } + f = g; + } + + set_translated(e, m.mk_app(f, m_args)); + + if (has_bv_sort) + m_vars.insert(e, { translated(e), bv_size()}); + } + + void solver::translate_bv(app* e) { + + auto bnot = [&](expr* e) { + return a.mk_sub(a.mk_int(-1), e); + }; + + auto band = [&](expr_ref_vector const& args) { + expr* r = arg(0); + for (unsigned i = 1; i < args.size(); ++i) + r = a.mk_band(bv.get_bv_size(e), r, arg(i)); + return r; + }; + + bv_expr = e; + expr* r = nullptr; + auto const& args = m_args; + switch (e->get_decl_kind()) { + case OP_BADD: + r = (a.mk_add(args)); + break; + case OP_BSUB: + r = (a.mk_sub(args.size(), args.data())); + break; + case OP_BMUL: + r = (a.mk_mul(args)); + break; + case OP_ULEQ: + bv_expr = e->get_arg(0); + r = (a.mk_le(mk_mod(arg(0)), mk_mod(arg(1)))); + break; + case OP_UGEQ: + bv_expr = e->get_arg(0); + r = (a.mk_ge(mk_mod(arg(0)), mk_mod(arg(1)))); + break; + case OP_ULT: + bv_expr = e->get_arg(0); + r = (a.mk_lt(mk_mod(arg(0)), mk_mod(arg(1)))); + break; + case OP_UGT: + bv_expr = e->get_arg(0); + r = (a.mk_gt(mk_mod(arg(0)), mk_mod(arg(1)))); + break; + case OP_SLEQ: + bv_expr = e->get_arg(0); + r = (a.mk_le(mk_smod(arg(0)), mk_smod(arg(1)))); + break; + case OP_SGEQ: + r = (a.mk_ge(mk_smod(arg(0)), mk_smod(arg(1)))); + break; + case OP_SLT: + bv_expr = e->get_arg(0); + r = (a.mk_lt(mk_smod(arg(0)), mk_smod(arg(1)))); + break; + case OP_SGT: + bv_expr = e->get_arg(0); + r = (a.mk_gt(mk_smod(arg(0)), mk_smod(arg(1)))); + break; + case OP_BNEG: + r = (a.mk_uminus(arg(0))); + break; + case OP_CONCAT: { + r = a.mk_int(0); + unsigned sz = 0; + for (unsigned i = 0; i < args.size(); ++i) { + expr* old_arg = e->get_arg(i); + expr* new_arg = arg(i); + bv_expr = old_arg; + new_arg = mk_mod(new_arg); + if (sz > 0) { + new_arg = a.mk_mul(new_arg, a.mk_int(rational::power_of_two(sz))); + r = a.mk_add(r, new_arg); + } + else + r = new_arg; + sz += bv.get_bv_size(old_arg->get_sort()); + } + break; + } + case OP_EXTRACT: { + unsigned lo, hi; + expr* old_arg; + VERIFY(bv.is_extract(e, lo, hi, old_arg)); + unsigned sz = hi - lo + 1; + expr* r = arg(0); + if (lo > 0) + r = a.mk_idiv(r, a.mk_int(rational::power_of_two(lo))); + break; + } + case OP_BV_NUM: { + rational val; + unsigned sz; + VERIFY(bv.is_numeral(e, val, sz)); + r = (a.mk_int(val)); + break; + } + case OP_BUREM_I: { + expr* x = arg(0), * y = arg(1); + r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + break; + } + case OP_BUDIV_I: { + expr* x = arg(0), * y = arg(1); + r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + break; + } + case OP_BUMUL_NO_OVFL: { + expr* x = arg(0), * y = arg(1); + bv_expr = e->get_arg(0); + r = (a.mk_lt(a.mk_mul(mk_mod(x), mk_mod(y)), a.mk_int(bv_size()))); + break; + } + case OP_BSHL: { + expr* x = arg(0), * y = arg(1); + r = a.mk_int(0); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + break; + } + case OP_BNOT: + r = (bnot(arg(0))); + break; + case OP_BLSHR: { + expr* x = arg(0), * y = arg(1); + r = a.mk_int(0); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); + break; + } + // Or use (p + q) - band(p, q)? + case OP_BOR: { + r = arg(0); + for (unsigned i = 1; i < args.size(); ++i) + r = a.mk_sub(a.mk_add(r, arg(i)), a.mk_band(bv.get_bv_size(e), r, arg(i))); + break; + } + case OP_BNAND: + r = (bnot(band(args))); + break; + case OP_BAND: + r = (band(args)); + break; + // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; + // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 + // (p + q) - 2*band(p, q); + case OP_BXNOR: + case OP_BXOR: { + unsigned sz = bv.get_bv_size(e); + expr* p = arg(0); + for (unsigned i = 1; i < args.size(); ++i) { + expr* q = arg(i); + p = a.mk_sub(a.mk_add(p, q), a.mk_mul(a.mk_int(2), a.mk_band(sz, p, q))); + } + if (e->get_decl_kind() == OP_BXNOR) + p = bnot(p); + r = (p); + break; + } + case OP_BUDIV: { + bv_rewriter_params p(ctx.s().params()); + expr* x = arg(0), * y = arg(1); + if (p.hi_div0()) + r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + else + r = (a.mk_idiv(x, y)); + break; + } + case OP_BUREM: { + bv_rewriter_params p(ctx.s().params()); + expr* x = arg(0), * y = arg(1); + if (p.hi_div0()) + r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + else + r = (a.mk_mod(x, y)); + break; + } + // + // ashr(x, y) + // if y = k & x >= 0 -> x / 2^k + // if y = k & x < 0 -> - (x / 2^k) + // + + case OP_BASHR: { + expr* x = arg(0), * y = arg(1); + rational N = rational::power_of_two(bv.get_bv_size(e)); + bv_expr = e; + x = mk_mod(x); + y = mk_mod(y); + expr* signbit = a.mk_ge(x, a.mk_int(N / 2)); + r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { + expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); + r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), + m.mk_ite(signbit, a.mk_uminus(d), d), + r); + } + break; + } + case OP_BCOMP: + + case OP_ROTATE_LEFT: + case OP_ROTATE_RIGHT: + case OP_EXT_ROTATE_LEFT: + case OP_EXT_ROTATE_RIGHT: + case OP_REPEAT: + case OP_ZERO_EXT: + case OP_SIGN_EXT: + case OP_BREDOR: + case OP_BREDAND: + case OP_BSDIV: + case OP_BSREM: + case OP_BSMOD: + verbose_stream() << mk_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + break; + default: + verbose_stream() << mk_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + } + set_translated(e, r); + } + + void solver::translate_basic(app* e) { + if (m.is_eq(e)) { + bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); + if (has_bv_arg) { + bv_expr = e->get_arg(0); + set_translated(e, m.mk_eq(mk_mod(arg(0)), mk_mod(arg(1)))); + } + else + set_translated(e, m.mk_eq(arg(0), arg(1))); + } + else + set_translated(e, m.mk_app(e->get_decl(), m_args)); } rational solver::get_value(expr* e) const { diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index b87724cc8..037b009a3 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -8,12 +8,24 @@ Module Name: Abstract: Int-blast solver. - It assumes a full assignemnt to literals in + + check_solver_state assumes a full assignment to literals in irredundant clauses. It picks a satisfying Boolean assignment and checks if it is feasible for bit-vectors using an arithmetic solver. + The solver plugin is self-contained. + + Internalize: + - internalize bit-vector terms bottom-up by updating m_translate. + - add axioms of the form: + - ule(b,a) <=> translate(ule(b, a)) + - let arithmetic solver handle bit-vector constraints. + - For shared b + - Ensure: int2bv(translate(b)) = b + - but avoid bit-blasting by ensuring int2bv is injective (mod N) during final check + Author: Nikolaj Bjorner (nbjorner) 2023-12-10 @@ -33,7 +45,7 @@ namespace euf { namespace intblast { - class solver { + class solver : public euf::th_euf_solver { struct var_info { expr* dst; rational sz; @@ -47,7 +59,7 @@ namespace intblast { scoped_ptr<::solver> m_solver; obj_map m_vars; obj_map m_new_funs; - expr_ref_vector m_trail; + expr_ref_vector m_translate, m_args; ast_ref_vector m_pinned; sat::literal_vector m_core; statistics m_stats; @@ -58,18 +70,68 @@ namespace intblast { rational get_value(expr* e) const; + expr* translated(expr* e) { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } + void set_translated(expr* e, expr* r) { m_translate.setx(e->get_id(), r); } + expr* arg(unsigned i) { return m_args.get(i); } + + expr* mk_mod(expr* x); + expr* mk_smod(expr* x); + expr* bv_expr = nullptr; + rational bv_size(); + + void translate_expr(expr* e); + void translate_bv(app* e); + void translate_basic(app* e); + void translate_app(app* e); + void translate_quantifier(quantifier* q); + void translate_var(var* v); + + void ensure_args(app* e); + void internalize_bv(app* e); + + euf::theory_var mk_var(euf::enode* n) override; + public: solver(euf::solver& ctx); - lbool check(); + ~solver() override {} + + lbool check_solver_state(); sat::literal_vector const& unsat_core(); - void add_value(euf::enode* n, model& mdl, expr_ref_vector& values); + void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; - std::ostream& display(std::ostream& out) const; + std::ostream& display(std::ostream& out) const override; + + void collect_statistics(statistics& st) const override; + + + + bool unit_propagate() override { return false; } + + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override {} + + sat::check_result check() override { return sat::check_result::CR_DONE; } + + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return out; } + + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { return out; } + + euf::th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); } + + void internalize(expr* e) override; + + bool visited(expr* e) override; + + bool post_visit(expr* e, bool sign, bool root) override; + + bool visit(expr* e) override; + + sat::literal internalize(expr* e, bool, bool) override; + + void eq_internalized(euf::enode* n) override {} - void collect_statistics(statistics& st) const; }; } diff --git a/src/sat/smt/polysat/constraints.h b/src/sat/smt/polysat/constraints.h index 15d8dfa09..a9ec63165 100644 --- a/src/sat/smt/polysat/constraints.h +++ b/src/sat/smt/polysat/constraints.h @@ -41,6 +41,8 @@ namespace polysat { virtual std::ostream& display(std::ostream& out) const = 0; virtual lbool eval() const = 0; virtual lbool eval(assignment const& a) const = 0; + virtual void activate(core& c, bool sign, dependency const& d) = 0; + virtual void propagate(core& c, lbool value, dependency const& d) = 0; }; inline std::ostream& operator<<(std::ostream& out, constraint const& c) { return c.display(out); } @@ -61,6 +63,8 @@ namespace polysat { unsigned_vector const& vars() const { return m_constraint->vars(); } unsigned var(unsigned idx) const { return m_constraint->var(idx); } bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + void activate(core& c, dependency const& d) { m_constraint->activate(c, m_sign, d); } + void propagate(core& c, lbool value, dependency const& d) { m_constraint->propagate(c, value, d); } bool is_always_true() const { return eval() == l_true; } bool is_always_false() const { return eval() == l_false; } lbool eval(assignment& a) const; @@ -84,6 +88,8 @@ namespace polysat { signed_constraint eq(pdd const& p) { return ule(p, p.manager().mk_val(0)); } signed_constraint eq(pdd const& p, rational const& v) { return eq(p - p.manager().mk_val(v)); } + signed_constraint eq(pdd const& p, unsigned v) { return eq(p - p.manager().mk_val(v)); } + signed_constraint eq(pdd const& p, pdd const& q) { return eq(p - q); } signed_constraint ule(pdd const& p, pdd const& q); signed_constraint sle(pdd const& p, pdd const& q) { auto sh = rational::power_of_two(p.power_of_2() - 1); return ule(p + sh, q + sh); } signed_constraint ult(pdd const& p, pdd const& q) { return ~ule(q, p); } diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index a552bb9ab..dd30e8227 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -187,12 +187,13 @@ namespace polysat { return sc; } - void core::propagate_assignment(prop_item& dc) { auto [idx, sign, dep] = dc; auto sc = get_constraint(idx, sign); if (sc.is_eq(m_var, m_value)) propagate_assignment(m_var, m_value, dep); + else + sc.activate(*this, dep); } void core::add_watch(unsigned idx, unsigned var) { @@ -216,7 +217,7 @@ namespace polysat { // for entries where there is only one free variable left add to viable set unsigned j = 0; for (auto idx : m_watch[v]) { - auto [sc, as, value] = m_constraint_index[idx]; + auto [sc, dep, value] = m_constraint_index[idx]; auto& vars = sc.vars(); if (vars[0] != v) std::swap(vars[0], vars[1]); @@ -231,6 +232,8 @@ namespace polysat { } } + sc.propagate(*this, value, dep); + SASSERT(!swapped || vars.size() <= 1 || (!is_assigned(vars[0]) && !is_assigned(vars[1]))); if (swapped) continue; @@ -262,6 +265,11 @@ namespace polysat { default: break; } + // propagate current assignment for sc + sc.propagate(*this, to_lbool(!sign), dep); + if (s.inconsistent()) + return; + // if sc is v == value, then check the watch list for v to propagate truth assignments if (sc.is_eq(m_var, m_value)) { for (auto idx1 : m_watch[m_var]) { @@ -360,4 +368,13 @@ namespace polysat { } + void core::add_axiom(signed_constraint sc) { + auto idx = register_constraint(sc, dependency::axiom()); + assign_eh(idx, false, dependency::axiom()); + } + + void core::add_clause(char const* name, core_vector const& cs, bool is_redundant) { + s.add_polysat_clause(name, cs, is_redundant); + } + } diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index c3dddfece..6297e567e 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -84,7 +84,7 @@ namespace polysat { void get_bitvector_prefixes(pvar v, pvar_vector& out); void get_fixed_bits(pvar v, svector& fixed_bits); bool inconsistent() const; - void add_clause(char const* name, std::initializer_list cs, bool is_redundant); + void add_watch(unsigned idx, unsigned var); @@ -94,6 +94,8 @@ namespace polysat { lbool eval(signed_constraint const& sc); dependency_vector explain_eval(signed_constraint const& sc); + void add_axiom(signed_constraint sc); + public: core(solver_interface& s); @@ -118,13 +120,20 @@ namespace polysat { signed_constraint bit(pdd const& p, unsigned i) { return m_constraints.bit(p, i); } - signed_constraint lshr(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.lshr(a, b, r); } - signed_constraint ashr(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.ashr(a, b, r); } - signed_constraint shl(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.shl(a, b, r); } - signed_constraint band(pdd const& a, pdd const& b, pdd const& r) { return m_constraints.band(a, b, r); } + void lshr(pdd const& a, pdd const& b, pdd const& r) { add_axiom(m_constraints.lshr(a, b, r)); } + void ashr(pdd const& a, pdd const& b, pdd const& r) { add_axiom(m_constraints.ashr(a, b, r)); } + void shl(pdd const& a, pdd const& b, pdd const& r) { add_axiom(m_constraints.shl(a, b, r)); } + void band(pdd const& a, pdd const& b, pdd const& r) { add_axiom(m_constraints.band(a, b, r)); } pdd bnot(pdd p) { return -p - 1; } + + /* + * Add a named clause. Dependencies are assumed, signed constraints are guaranteeed. + * In other words, the clause represents the formula /\ d_i -> \/ sc_j + * Where d_i are logical interpretations of dependencies and sc_j are signed constraints. + */ + void add_clause(char const* name, core_vector const& cs, bool is_redundant); pvar add_var(unsigned sz); pdd var(pvar p) { return m_vars[p]; } diff --git a/src/sat/smt/polysat/op_constraint.cpp b/src/sat/smt/polysat/op_constraint.cpp index 401b1ca52..234ddea04 100644 --- a/src/sat/smt/polysat/op_constraint.cpp +++ b/src/sat/smt/polysat/op_constraint.cpp @@ -13,13 +13,13 @@ Notes: Additional possible functionality on constraints: -- activate - when operation is first activated. It may be created and only activated later. - bit-wise assignments - narrow based on bit assignment, not entire word assignment. - integration with congruence tables - integration with conflict resolution --*/ +#include "util/log.h" #include "sat/smt/polysat/op_constraint.h" #include "sat/smt/polysat/core.h" @@ -157,7 +157,6 @@ namespace polysat { return out << "&"; case op_constraint::code::inv_op: return out << "inv"; - default: UNREACHABLE(); return out; @@ -176,96 +175,95 @@ namespace polysat { return out << r() << " " << eq << " " << p() << " " << m_op << " " << q(); } -#if 0 - /** - * Produce lemmas that contradict the given assignment. - * - * We can assume that op_constraint is only asserted positive. - */ - clause_ref op_constraint::produce_lemma(solver& s, assignment const& a, bool is_positive) { - SASSERT(is_positive); - - if (is_currently_true(a, is_positive)) - return {}; - - return produce_lemma(s, a); - } - - clause_ref op_constraint::produce_lemma(solver& s, assignment const& a) { + void op_constraint::activate(core& c, bool sign, dependency const& dep) { + SASSERT(!sign); switch (m_op) { - case code::lshr_op: - return lemma_lshr(s, a); - case code::shl_op: - return lemma_shl(s, a); case code::and_op: - return lemma_and(s, a); - case code::inv_op: - return lemma_inv(s, a); + activate_and(c, dep); + break; default: - NOT_IMPLEMENTED_YET(); - return {}; + break; } } + void op_constraint::propagate(core& c, lbool value, dependency const& dep) { + SASSERT(value == l_true); + switch (m_op) { + case code::lshr_op: + propagate_lshr(c, dep); + break; + case code::shl_op: + propagate_shl(c, dep); + break; + case code::and_op: + propagate_and(c, dep); + break; + case code::inv_op: + propagate_inv(c, dep); + break; + default: + NOT_IMPLEMENTED_YET(); + break; + } + } + + void op_constraint::propagate_inv(core& s, dependency const& dep) { + + } + /** - * Enforce basic axioms for r == p >> q: - * - * q >= N -> r = 0 - * q >= k -> r[i] = 0 for N - k <= i < N (bit indices range from 0 to N-1, inclusive) - * q >= k -> r <= 2^{N-k} - 1 - * q = k -> r[i] = p[i+k] for 0 <= i < N - k - * r <= p - * q != 0 -> r <= p (subsumed by previous axiom) - * q != 0 /\ p > 0 -> r < p - * q = 0 -> r = p - * p = q -> r = 0 - * - * when q is a constant, several axioms can be enforced at activation time. - * - * Enforce also inferences and bounds - * - * TODO: use also - * s.m_viable.min_viable(); - * s.m_viable.max_viable() - * when r, q are variables. - */ - clause_ref op_constraint::lemma_lshr(solver& s, assignment const& a) { + * Enforce basic axioms for r == p >> q: + * + * q >= N -> r = 0 + * q >= k -> r[i] = 0 for N - k <= i < N (bit indices range from 0 to N-1, inclusive) + * q >= k -> r <= 2^{N-k} - 1 + * q = k -> r[i] = p[i+k] for 0 <= i < N - k + * r <= p + * q != 0 -> r <= p (subsumed by previous axiom) + * q != 0 /\ p > 0 -> r < p + * q = 0 -> r = p + * p = q -> r = 0 + * + * when q is a constant, several axioms can be enforced at activation time. + * + * Enforce also inferences and bounds + * + * TODO: use also + * s.m_viable.min_viable(); + * s.m_viable.max_viable() + * when r, q are variables. + */ + void op_constraint::propagate_lshr(core& c, dependency const& d) { auto& m = p().manager(); - auto const pv = a.apply_to(p()); - auto const qv = a.apply_to(q()); - auto const rv = a.apply_to(r()); + auto const pv = c.subst(p()); + auto const qv = c.subst(q()); + auto const rv = c.subst(r()); unsigned const N = m.power_of_2(); - signed_constraint const lshr(this, true); + + signed_constraint const lshr(polysat::ckind_t::op_t, this); + auto& C = c.cs(); if (pv.is_val() && rv.is_val() && rv.val() > pv.val()) - // r <= p - return s.mk_clause(~lshr, s.ule(r(), p()), true); + c.add_clause("lshr 1", { d, C.ule(r(), p()) }, false); + else if (qv.is_val() && qv.val() >= N && rv.is_val() && !rv.is_zero()) // TODO: instead of rv.is_val() && !rv.is_zero(), we should use !is_forced_zero(r) which checks whether eval(r) = 0 or bvalue(r=0) = true; see saturation.cpp - // q >= N -> r = 0 - return s.mk_clause(~lshr, ~s.ule(N, q()), s.eq(r()), true); + c.add_clause("q >= N -> r = 0", { d, ~C.ule(N, q()), C.eq(r()) }, true); else if (qv.is_zero() && pv.is_val() && rv.is_val() && pv != rv) - // q = 0 -> p = r - return s.mk_clause(~lshr, ~s.eq(q()), s.eq(p(), r()), true); + c.add_clause("q = 0 -> p = r", { d, ~C.eq(q()), C.eq(p(), r()) } , true); else if (qv.is_val() && !qv.is_zero() && pv.is_val() && rv.is_val() && !pv.is_zero() && rv.val() >= pv.val()) - // q != 0 & p > 0 -> r < p - return s.mk_clause(~lshr, s.eq(q()), s.ule(p(), 0), s.ult(r(), p()), true); + c.add_clause("q != 0 & p > 0 -> r < p", { d, C.eq(q()), C.ule(p(), 0), C.ult(r(), p()) }, true); else if (qv.is_val() && !qv.is_zero() && qv.val() < N && rv.is_val() && rv.val() > rational::power_of_two(N - qv.val().get_unsigned()) - 1) - // q >= k -> r <= 2^{N-k} - 1 - return s.mk_clause(~lshr, ~s.ule(qv.val(), q()), s.ule(r(), rational::power_of_two(N - qv.val().get_unsigned()) - 1), true); - // else if (pv == qv && !rv.is_zero()) - // return s.mk_clause(~lshr, ~s.eq(p(), q()), s.eq(r()), true); + c.add_clause("q >= k -> r <= 2^{N-k} - 1", { d, ~C.ule(qv.val(), q()), C.ule(r(), rational::power_of_two(N - qv.val().get_unsigned()) - 1)}, true); else if (pv.is_val() && rv.is_val() && qv.is_val() && !qv.is_zero()) { unsigned k = qv.val().get_unsigned(); - // q = k -> r[i] = p[i+k] for 0 <= i < N - k for (unsigned i = 0; i < N - k; ++i) { - if (rv.val().get_bit(i) && !pv.val().get_bit(i + k)) { - return s.mk_clause(~lshr, ~s.eq(q(), k), ~s.bit(r(), i), s.bit(p(), i + k), true); - } - if (!rv.val().get_bit(i) && pv.val().get_bit(i + k)) { - return s.mk_clause(~lshr, ~s.eq(q(), k), s.bit(r(), i), ~s.bit(p(), i + k), true); - } + if (rv.val().get_bit(i) && !pv.val().get_bit(i + k)) + c.add_clause("q = k -> r[i] = p[i+k] for 0 <= i < N - k", { d, ~C.eq(q(), k), ~C.bit(r(), i), C.bit(p(), i + k) }, true); + + if (!rv.val().get_bit(i) && pv.val().get_bit(i + k)) + c.add_clause("q = k -> r[i] = p[i+k] for 0 <= i < N - k", { d, ~C.eq(q(), k), C.bit(r(), i), ~C.bit(p(), i + k) }, true); } } else { @@ -276,19 +274,44 @@ namespace polysat { rational const& q_val = qv.val(); if (q_val >= N) // q >= N ==> r = 0 - return s.mk_clause(~lshr, ~s.ule(N, q()), s.eq(r()), true); - if (pv.is_val()) { + c.add_clause("q >= N ==> r = 0", { d, ~C.ule(N, q()), C.eq(r()) }, true); + else if (pv.is_val()) { SASSERT(q_val.is_unsigned()); - // p = p_val & q = q_val ==> r = p_val / 2^q_val + // rational const r_val = machine_div2k(pv.val(), q_val.get_unsigned()); - return s.mk_clause(~lshr, ~s.eq(p(), pv), ~s.eq(q(), qv), s.eq(r(), r_val), true); + c.add_clause("p = p_val & q = q_val ==> r = p_val / 2^q_val", { d, ~C.eq(p(), pv), ~C.eq(q(), qv), C.eq(r(), r_val) }, true); } } } - return {}; } + void op_constraint::activate_and(core& c, dependency const& d) { + auto x = p(), y = q(); + auto& C = c.cs(); + if (x.is_val()) + std::swap(x, y); + if (!y.is_val()) + return; + auto& m = x.manager(); + auto yv = y.val(); + if (!(yv + 1).is_power_of_two()) + return; + if (yv == m.max_value()) + c.add_clause("band-mask-true", { d, C.eq(x, r()) }, false); + else if (yv == 0) + c.add_clause("band-mask-false", { d, C.eq(r()) }, false); + else { + unsigned N = m.power_of_2(); + unsigned k = yv.get_num_bits(); + SASSERT(k < N); + rational exp = rational::power_of_two(N - k); + c.add_clause("band-mask 1", { d, C.eq(x * exp, r() * exp) }, false); + c.add_clause("band-mask 2", { d, C.ule(r(), y) }, false); // maybe always activate these constraints regardless? + } + } + + /** * Enforce axioms for constraint: r == p << q * @@ -298,35 +321,33 @@ namespace polysat { * q = k -> r[i+k] = p[i] for 0 <= i < N - k * q = 0 -> r = p */ - clause_ref op_constraint::lemma_shl(solver& s, assignment const& a) { + void op_constraint::propagate_shl(core& c, dependency const& d) { auto& m = p().manager(); - auto const pv = a.apply_to(p()); - auto const qv = a.apply_to(q()); - auto const rv = a.apply_to(r()); + auto const pv = c.subst(p()); + auto const qv = c.subst(q()); + auto const rv = c.subst(r()); unsigned const N = m.power_of_2(); + auto& C = c.cs(); - signed_constraint const shl(this, true); - - if (qv.is_val() && qv.val() >= N && rv.is_val() && !rv.is_zero()) - // q >= N -> r = 0 - return s.mk_clause(~shl, ~s.ule(N, q()), s.eq(r()), true); + if (qv.is_val() && qv.val() >= N && rv.is_val() && !rv.is_zero()) + c.add_clause("q >= N -> r = 0", { d, ~C.ule(N, q()), C.eq(r()) }, true); else if (qv.is_zero() && pv.is_val() && rv.is_val() && rv != pv) - // q = 0 -> r = p - return s.mk_clause(~shl, ~s.eq(q()), s.eq(r(), p()), true); + // + c.add_clause("q = 0 -> r = p", { d, ~C.eq(q()), C.eq(r(), p()) }, true); else if (qv.is_val() && !qv.is_zero() && qv.val() < N && rv.is_val() && !rv.is_zero() && rv.val() < rational::power_of_two(qv.val().get_unsigned())) // q >= k -> r = 0 \/ r >= 2^k (intuitive version) // q >= k -> r - 1 >= 2^k - 1 (equivalent unit constraint to better support narrowing) - return s.mk_clause(~shl, ~s.ule(qv.val(), q()), s.ule(rational::power_of_two(qv.val().get_unsigned()) - 1, r() - 1), true); + c.add_clause("q >= k -> r - 1 >= 2^k - 1", { d, ~C.ule(qv.val(), q()), C.ule(rational::power_of_two(qv.val().get_unsigned()) - 1, r() - 1) }, true); else if (pv.is_val() && rv.is_val() && qv.is_val() && !qv.is_zero()) { unsigned k = qv.val().get_unsigned(); // q = k -> r[i+k] = p[i] for 0 <= i < N - k for (unsigned i = 0; i < N - k; ++i) { if (rv.val().get_bit(i + k) && !pv.val().get_bit(i)) { - return s.mk_clause(~shl, ~s.eq(q(), k), ~s.bit(r(), i + k), s.bit(p(), i), true); + c.add_clause("q = k -> r[i+k] = p[i] for 0 <= i < N - k", { d, ~C.eq(q(), k), ~C.bit(r(), i + k), C.bit(p(), i) }, true); } if (!rv.val().get_bit(i + k) && pv.val().get_bit(i)) { - return s.mk_clause(~shl, ~s.eq(q(), k), s.bit(r(), i + k), ~s.bit(p(), i), true); + c.add_clause("q = k -> r[i+k] = p[i] for 0 <= i < N - k", { d, ~C.eq(q(), k), C.bit(r(), i + k), ~C.bit(p(), i) }, true); } } } @@ -338,43 +359,15 @@ namespace polysat { rational const& q_val = qv.val(); if (q_val >= N) // q >= N ==> r = 0 - return s.mk_clause("shl forward 1", {~shl, ~s.ule(N, q()), s.eq(r())}, true); + c.add_clause("shl forward 1", {d, ~C.ule(N, q()), C.eq(r())}, true); if (pv.is_val()) { SASSERT(q_val.is_unsigned()); // p = p_val & q = q_val ==> r = p_val * 2^q_val rational const r_val = pv.val() * rational::power_of_two(q_val.get_unsigned()); - return s.mk_clause("shl forward 2", {~shl, ~s.eq(p(), pv), ~s.eq(q(), qv), s.eq(r(), r_val)}, true); + c.add_clause("shl forward 2", {d, ~C.eq(p(), pv), ~C.eq(q(), qv), C.eq(r(), r_val)}, true); } } } - return {}; - } - - - - void op_constraint::activate_and(solver& s) { - auto x = p(), y = q(); - if (x.is_val()) - std::swap(x, y); - if (!y.is_val()) - return; - auto& m = x.manager(); - auto yv = y.val(); - if (!(yv + 1).is_power_of_two()) - return; - signed_constraint const andc(this, true); - if (yv == m.max_value()) - s.add_clause(~andc, s.eq(x, r()), false); - else if (yv == 0) - s.add_clause(~andc, s.eq(r()), false); - else { - unsigned N = m.power_of_2(); - unsigned k = yv.get_num_bits(); - SASSERT(k < N); - rational exp = rational::power_of_two(N - k); - s.add_clause(~andc, s.eq(x * exp, r() * exp), false); - s.add_clause(~andc, s.ule(r(), y), false); // maybe always activate these constraints regardless? - } } /** @@ -390,48 +383,39 @@ namespace polysat { * p = 2^k - 1 && r = 0 && q != 0 => q >= 2^k * q = 2^k - 1 && r = 0 && p != 0 => p >= 2^k */ - clause_ref op_constraint::lemma_and(solver& s, assignment const& a) { + void op_constraint::propagate_and(core& c, dependency const& d) { auto& m = p().manager(); - auto pv = a.apply_to(p()); - auto qv = a.apply_to(q()); - auto rv = a.apply_to(r()); + auto pv = c.subst(p()); + auto qv = c.subst(q()); + auto rv = c.subst(r()); + auto& C = c.cs(); - signed_constraint const andc(this, true); // op_constraints are always true - - // r <= p if (pv.is_val() && rv.is_val() && rv.val() > pv.val()) - return s.mk_clause(~andc, s.ule(r(), p()), true); - // r <= q - if (qv.is_val() && rv.is_val() && rv.val() > qv.val()) - return s.mk_clause(~andc, s.ule(r(), q()), true); - // p = q => r = p - if (pv.is_val() && qv.is_val() && rv.is_val() && pv == qv && rv != pv) - return s.mk_clause(~andc, ~s.eq(p(), q()), s.eq(r(), p()), true); - if (pv.is_val() && qv.is_val() && rv.is_val()) { - // p = -1 => r = q + c.add_clause("p&q <= p", { d, C.ule(r(), p()) }, true); + else if (qv.is_val() && rv.is_val() && rv.val() > qv.val()) + c.add_clause("p&q <= q", { d, C.ule(r(), q()) }, true); + else if (pv.is_val() && qv.is_val() && rv.is_val() && pv == qv && rv != pv) + c.add_clause("p = q => r = p", { d, ~C.eq(p(), q()), C.eq(r(), p()) }, true); + else if (pv.is_val() && qv.is_val() && rv.is_val()) { if (pv.is_max() && qv != rv) - return s.mk_clause(~andc, ~s.eq(p(), m.max_value()), s.eq(q(), r()), true); - // q = -1 => r = p + c.add_clause("p = -1 => r = q", { d, ~C.eq(p(), m.max_value()), C.eq(q(), r()) }, true); if (qv.is_max() && pv != rv) - return s.mk_clause(~andc, ~s.eq(q(), m.max_value()), s.eq(p(), r()), true); + c.add_clause("q = -1 => r = p", { d, ~C.eq(q(), m.max_value()), C.eq(p(), r()) }, true); unsigned const N = m.power_of_2(); unsigned pow; if ((pv.val() + 1).is_power_of_two(pow)) { - // p = 2^k - 1 && r = 0 && q != 0 => q >= 2^k if (rv.is_zero() && !qv.is_zero() && qv.val() <= pv.val()) - return s.mk_clause(~andc, ~s.eq(p(), pv), ~s.eq(r()), s.eq(q()), s.ule(pv + 1, q()), true); - // p = 2^k - 1 ==> r*2^{N - k} = q*2^{N - k} + c.add_clause("p = 2^k - 1 && r = 0 && q != 0 => q >= 2^k", { d, ~C.eq(p(), pv), ~C.eq(r()), C.eq(q()), C.ule(pv + 1, q()) }, true); if (rv != qv) - return s.mk_clause(~andc, ~s.eq(p(), pv), s.eq(r() * rational::power_of_two(N - pow), q() * rational::power_of_two(N - pow)), true); + c.add_clause("p = 2^k - 1 ==> r*2^{N - k} = q*2^{N - k}", { d, ~C.eq(p(), pv), C.eq(r() * rational::power_of_two(N - pow), q() * rational::power_of_two(N - pow)) }, true); } if ((qv.val() + 1).is_power_of_two(pow)) { - // q = 2^k - 1 && r = 0 && p != 0 ==> p >= 2^k if (rv.is_zero() && !pv.is_zero() && pv.val() <= qv.val()) - return s.mk_clause(~andc, ~s.eq(q(), qv), ~s.eq(r()), s.eq(p()), s.ule(qv + 1, p()), true); - // q = 2^k - 1 ==> r*2^{N - k} = p*2^{N - k} + c.add_clause("q = 2^k - 1 && r = 0 && p != 0 ==> p >= 2^k", { d, ~C.eq(q(), qv), ~C.eq(r()), C.eq(p()), C.ule(qv + 1, p()) }, true); + // if (rv != pv) - return s.mk_clause(~andc, ~s.eq(q(), qv), s.eq(r() * rational::power_of_two(N - pow), p() * rational::power_of_two(N - pow)), true); + c.add_clause("q = 2^k - 1 ==> r*2^{N - k} = p*2^{N - k}", { d, ~C.eq(q(), qv), C.eq(r() * rational::power_of_two(N - pow), p() * rational::power_of_two(N - pow)) }, true); } for (unsigned i = 0; i < N; ++i) { @@ -441,33 +425,31 @@ namespace polysat { if (rb == (pb && qb)) continue; if (pb && qb && !rb) - return s.mk_clause(~andc, ~s.bit(p(), i), ~s.bit(q(), i), s.bit(r(), i), true); + c.add_clause("p&q[i] = p[i]&q[i]", { d, ~C.bit(p(), i), ~C.bit(q(), i), C.bit(r(), i) }, true); else if (!pb && rb) - return s.mk_clause(~andc, s.bit(p(), i), ~s.bit(r(), i), true); + c.add_clause("p&q[i] = p[i]&q[i]", { d, C.bit(p(), i), ~C.bit(r(), i) }, true); else if (!qb && rb) - return s.mk_clause(~andc, s.bit(q(), i), ~s.bit(r(), i), true); + c.add_clause("p&q[i] = p[i]&q[i]", { d, C.bit(q(), i), ~C.bit(r(), i) }, true); else UNREACHABLE(); } - return {}; + return; } // Propagate r if p or q are 0 - if (pv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated - return s.mk_clause(~andc, s.ule(r(), p()), true); - if (qv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated - return s.mk_clause(~andc, s.ule(r(), q()), true); + else if (pv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated + c.add_clause("p = 0 -> p&q = 0", { d, C.ule(r(), p()) }, true); + else if (qv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated + c.add_clause("q = 0 -> p&q = 0", { d, C.ule(r(), q()) }, true); // p = a && q = b ==> r = a & b - if (pv.is_val() && qv.is_val() && !rv.is_val()) { + else if (pv.is_val() && qv.is_val() && !rv.is_val()) { // Just assign by this very weak justification. It will be strengthened in saturation in case of a conflict LOG(p() << " = " << pv << " and " << q() << " = " << qv << " yields [band] " << r() << " = " << bitwise_and(pv.val(), qv.val())); - return s.mk_clause(~andc, ~s.eq(p(), pv), ~s.eq(q(), qv), s.eq(r(), bitwise_and(pv.val(), qv.val())), true); + c.add_clause("p = a & q = b => r = a&b", { d, ~C.eq(p(), pv), ~C.eq(q(), qv), C.eq(r(), bitwise_and(pv.val(), qv.val())) }, true); } - - return {}; } - +#if 0 /** * Produce lemmas for constraint: r == inv p @@ -490,15 +472,15 @@ namespace polysat { // p = 0 ==> r = 0 if (pv.is_zero()) - return s.mk_clause(~invc, ~s.eq(p()), s.eq(r()), true); + c.add_clause(~invc, ~C.eq(p()), C.eq(r()), true); // r = 0 ==> p = 0 if (rv.is_zero()) - return s.mk_clause(~invc, ~s.eq(r()), s.eq(p()), true); + c.add_clause(~invc, ~C.eq(r()), C.eq(p()), true); // forward propagation: p assigned ==> r = pseudo_inverse(eval(p)) // TODO: (later) this should be propagated instead of adding a clause /*if (pv.is_val() && !rv.is_val()) - return s.mk_clause(~invc, ~s.eq(p(), pv), s.eq(r(), pv.val().pseudo_inverse(m.power_of_2())), true);*/ + c.add_clause(~invc, ~C.eq(p(), pv), C.eq(r(), pv.val().pseudo_inverse(m.power_of_2())), true);*/ if (!pv.is_val() || !rv.is_val()) return {}; @@ -511,7 +493,7 @@ namespace polysat { // p != 0 ==> odd(r) if (parity_rv != 0) - return s.mk_clause("r = inv p & p != 0 ==> odd(r)", {~invc, s.eq(p()), s.odd(r())}, true); + c.add_clause("r = inv p & p != 0 ==> odd(r)", {~invc, C.eq(p()), s.odd(r())}, true); pdd prod = p() * r(); rational prodv = (pv * rv).val(); @@ -527,13 +509,13 @@ namespace polysat { LOG("Its in [" << lower << "; " << upper << ")"); // parity(p) >= k ==> p * r >= 2^k if (prodv < rational::power_of_two(middle)) - return s.mk_clause("r = inv p & parity(p) >= k ==> p*r >= 2^k", + c.add_clause("r = inv p & parity(p) >= k ==> p*r >= 2^k", {~invc, ~s.parity_at_least(p(), middle), s.uge(prod, rational::power_of_two(middle))}, false); // parity(p) >= k ==> r <= 2^(N - k) - 1 (because r is the smallest pseudo-inverse) rational const max_rv = rational::power_of_two(m.power_of_2() - middle) - 1; if (rv.val() > max_rv) - return s.mk_clause("r = inv p & parity(p) >= k ==> r <= 2^(N - k) - 1", - {~invc, ~s.parity_at_least(p(), middle), s.ule(r(), max_rv)}, false); + c.add_clause("r = inv p & parity(p) >= k ==> r <= 2^(N - k) - 1", + {~invc, ~s.parity_at_least(p(), middle), C.ule(r(), max_rv)}, false); } else { // parity less than middle SASSERT(parity_pv < middle); @@ -541,8 +523,8 @@ namespace polysat { LOG("Its in [" << lower << "; " << upper << ")"); // parity(p) < k ==> p * r <= 2^k - 1 if (prodv > rational::power_of_two(middle)) - return s.mk_clause("r = inv p & parity(p) < k ==> p*r <= 2^k - 1", - {~invc, s.parity_at_least(p(), middle), s.ule(prod, rational::power_of_two(middle) - 1)}, false); + c.add_clause("r = inv p & parity(p) < k ==> p*r <= 2^k - 1", + {~invc, s.parity_at_least(p(), middle), C.ule(prod, rational::power_of_two(middle) - 1)}, false); } } // Why did it evaluate to false in this case? @@ -550,114 +532,5 @@ namespace polysat { return {}; } - - - void op_constraint::activate_udiv(solver& s) { - // signed_constraint const udivc(this, true); Do we really need this premiss? We anyway assert these constraints as unit clauses - - pdd const& quot = r(); - pdd const& rem = m_linked->r(); - - // Axioms for quotient/remainder: - // a = b*q + r - // multiplication does not overflow in b*q - // addition does not overflow in (b*q) + r; for now expressed as: r <= bq+r - // b ≠ 0 ==> r < b - // b = 0 ==> q = -1 - // TODO: when a,b become evaluable, can we actually propagate q,r? doesn't seem like it. - // Maybe we need something like an op_constraint for better propagation. - s.add_clause(s.eq(q() * quot + rem - p()), false); - s.add_clause(~s.umul_ovfl(q(), quot), false); - // r <= b*q+r - // { apply equivalence: p <= q <=> q-p <= -p-1 } - // b*q <= -r-1 - s.add_clause(s.ule(q() * quot, -rem - 1), false); - - auto c_eq = s.eq(q()); - s.add_clause(c_eq, s.ult(rem, q()), false); - s.add_clause(~c_eq, s.eq(quot + 1), false); - } - - /** - * Produce lemmas for constraint: r == p / q - * q = 0 ==> r = max_value - * p = 0 ==> r = 0 || r = max_value - * q = 1 ==> r = p - */ - clause_ref op_constraint::lemma_udiv(solver& s, assignment const& a) { - auto pv = a.apply_to(p()); - auto qv = a.apply_to(q()); - auto rv = a.apply_to(r()); - - if (eval_udiv(pv, qv, rv) == l_true) - return {}; - - signed_constraint const udivc(this, true); - - if (qv.is_zero() && !rv.is_val()) - return s.mk_clause(~udivc, ~s.eq(q()), s.eq(r(), r().manager().max_value()), true); - if (pv.is_zero() && !rv.is_val()) - return s.mk_clause(~udivc, ~s.eq(p()), s.eq(r()), s.eq(r(), r().manager().max_value()), true); - if (qv.is_one()) - return s.mk_clause(~udivc, ~s.eq(q(), 1), s.eq(r(), p()), true); - - if (pv.is_val() && qv.is_val() && !rv.is_val()) { - SASSERT(!qv.is_zero()); - // TODO: We could actually propagate an interval. Instead of p = 9 & q = 4 => r = 2 we could do p >= 8 && p < 12 && q = 4 => r = 2 - return s.mk_clause(~udivc, ~s.eq(p(), pv.val()), ~s.eq(q(), qv.val()), s.eq(r(), div(pv.val(), qv.val())), true); - } - - return {}; - } - - - /** - * Produce lemmas for constraint: r == p % q - * p = 0 ==> r = 0 - * q = 1 ==> r = 0 - * q = 0 ==> r = p - */ - clause_ref op_constraint::lemma_urem(solver& s, assignment const& a) { - auto pv = a.apply_to(p()); - auto qv = a.apply_to(q()); - auto rv = a.apply_to(r()); - - if (eval_urem(pv, qv, rv) == l_true) - return {}; - - signed_constraint const urem(this, true); - - if (pv.is_zero() && !rv.is_val()) - return s.mk_clause(~urem, ~s.eq(p()), s.eq(r()), true); - if (qv.is_one() && !rv.is_val()) - return s.mk_clause(~urem, ~s.eq(q(), 1), s.eq(r()), true); - if (qv.is_zero()) - return s.mk_clause(~urem, ~s.eq(q()), s.eq(r(), p()), true); - - if (pv.is_val() && qv.is_val() && !rv.is_val()) { - SASSERT(!qv.is_zero()); - return s.mk_clause(~urem, ~s.eq(p(), pv.val()), ~s.eq(q(), qv.val()), s.eq(r(), mod(pv.val(), qv.val())), true); - } - - return {}; - } - - /** Evaluate constraint: r == p % q */ - lbool op_constraint::eval_urem(pdd const& p, pdd const& q, pdd const& r) { - - if (q.is_one() && r.is_val()) { - return r.val().is_zero() ? l_true : l_false; - } - if (q.is_zero()) { - if (r == p) - return l_true; - } - - if (!p.is_val() || !q.is_val() || !r.is_val()) - return l_undef; - - return r.val() == mod(p.val(), q.val()) ? l_true : l_false; // mod == rem as we know hat q > 0 - } - #endif } diff --git a/src/sat/smt/polysat/op_constraint.h b/src/sat/smt/polysat/op_constraint.h index a33f1b705..d7b6be392 100644 --- a/src/sat/smt/polysat/op_constraint.h +++ b/src/sat/smt/polysat/op_constraint.h @@ -50,32 +50,22 @@ namespace polysat { op_constraint(code c, pdd const& r, pdd const& p, pdd const& q); lbool eval(pdd const& r, pdd const& p, pdd const& q) const; -// clause_ref produce_lemma(core& s, assignment const& a); - // clause_ref lemma_lshr(core& s, assignment const& a); static lbool eval_lshr(pdd const& p, pdd const& q, pdd const& r); - - // clause_ref lemma_shl(core& s, assignment const& a); static lbool eval_shl(pdd const& p, pdd const& q, pdd const& r); - - // clause_ref lemma_and(core& s, assignment const& a); static lbool eval_and(pdd const& p, pdd const& q, pdd const& r); - - // clause_ref lemma_inv(core& s, assignment const& a); static lbool eval_inv(pdd const& p, pdd const& r); + + void propagate_lshr(core& s, dependency const& dep); + void propagate_shl(core& s, dependency const& dep); + void propagate_and(core& s, dependency const& dep); + void propagate_inv(core& s, dependency const& dep); + - // clause_ref lemma_udiv(core& s, assignment const& a); - static lbool eval_udiv(pdd const& p, pdd const& q, pdd const& r); - - // clause_ref lemma_urem(core& s, assignment const& a); - static lbool eval_urem(pdd const& p, pdd const& q, pdd const& r); std::ostream& display(std::ostream& out, char const* eq) const; - void activate(core& s); - - void activate_and(core& s); - void activate_udiv(core& s); + void activate_and(core& s, dependency const& d); public: ~op_constraint() override {} @@ -89,6 +79,8 @@ namespace polysat { lbool eval(assignment const& a) const override; bool is_always_true() const { return false; } bool is_always_false() const { return false; } + void activate(core& c, bool sign, dependency const& dep) override; + void propagate(core& c, lbool value, dependency const& dep) override; }; } diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index 207ea091e..6f855f98b 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -29,12 +29,17 @@ namespace polysat { class signed_constraint; class dependency { - std::variant> m_data; + struct axiom_t {}; + std::variant> m_data; unsigned m_level; + dependency(): m_data(axiom_t()), m_level(0) {} public: dependency(sat::literal lit, unsigned level) : m_data(lit), m_level(level) {} dependency(theory_var v1, theory_var v2, unsigned level) : m_data(std::make_pair(v1, v2)), m_level(level) {} + static dependency axiom() { return dependency(); } bool is_null() const { return is_literal() && *std::get_if(&m_data) == sat::null_literal; } + bool is_axiom() const { return std::holds_alternative(m_data); } + bool is_eq() const { return std::holds_alternative>(m_data); } bool is_literal() const { return std::holds_alternative(m_data); } sat::literal literal() const { SASSERT(is_literal()); return *std::get_if(&m_data); } std::pair eq() const { SASSERT(!is_literal()); return *std::get_if>(&m_data); } @@ -46,6 +51,8 @@ namespace polysat { inline std::ostream& operator<<(std::ostream& out, dependency d) { if (d.is_null()) return out << "null"; + else if (d.is_axiom()) + return out << "axiom@" << d.level(); else if (d.is_literal()) return out << d.literal() << "@" << d.level(); else @@ -87,7 +94,7 @@ namespace polysat { using dependency_vector = vector; - using core_vector = vector>; + using core_vector = std::initializer_list>; @@ -101,6 +108,7 @@ namespace polysat { virtual void add_eq_literal(pvar v, rational const& val) = 0; virtual void set_conflict(dependency_vector const& core) = 0; virtual void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) = 0; + virtual void add_polysat_clause(char const* name, core_vector cs, bool redundant) = 0; virtual dependency propagate(signed_constraint sc, dependency_vector const& deps) = 0; virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; virtual trail_stack& trail() = 0; diff --git a/src/sat/smt/polysat/ule_constraint.cpp b/src/sat/smt/polysat/ule_constraint.cpp index 185dad0ee..bdfcb7c5f 100644 --- a/src/sat/smt/polysat/ule_constraint.cpp +++ b/src/sat/smt/polysat/ule_constraint.cpp @@ -70,6 +70,8 @@ Useful lemmas: --*/ +#include "util/log.h" +#include "sat/smt/polysat/core.h" #include "sat/smt/polysat/constraints.h" #include "sat/smt/polysat/ule_constraint.h" @@ -314,8 +316,6 @@ namespace polysat { return display(out, l_true, m_lhs, m_rhs); } - - // Evaluate lhs <= rhs lbool ule_constraint::eval(pdd const& lhs, pdd const& rhs) { // NOTE: don't assume simplifications here because we also call this on partially substituted constraints @@ -343,4 +343,15 @@ namespace polysat { return eval(a.apply_to(lhs()), a.apply_to(rhs())); } + void ule_constraint::activate(core& c, bool sign, dependency const& d) { + auto p = c.subst(lhs()); + auto q = c.subst(rhs()); + auto& C = c.cs(); + if (sign && !lhs().is_val() && !rhs().is_val()) { + c.add_clause("lhs > rhs ==> -1 > rhs", { d, C.ult(rhs(), -1) }, false); + c.add_clause("lhs > rhs ==> lhs > 0", { d, C.ult(0, lhs()) }, false); + } + } + + } diff --git a/src/sat/smt/polysat/ule_constraint.h b/src/sat/smt/polysat/ule_constraint.h index aa53e6a4f..81a0b64c5 100644 --- a/src/sat/smt/polysat/ule_constraint.h +++ b/src/sat/smt/polysat/ule_constraint.h @@ -35,6 +35,8 @@ namespace polysat { std::ostream& display(std::ostream& out) const override; lbool eval() const override; lbool eval(assignment const& a) const override; + void activate(core& c, bool sign, dependency const& dep); + void propagate(core& c, lbool value, dependency const& dep) {} bool is_eq() const { return m_rhs.is_zero(); } unsigned power_of_2() const { return m_lhs.power_of_2(); } diff --git a/src/sat/smt/polysat/umul_ovfl_constraint.cpp b/src/sat/smt/polysat/umul_ovfl_constraint.cpp index e7dc5801c..5d185e7ee 100644 --- a/src/sat/smt/polysat/umul_ovfl_constraint.cpp +++ b/src/sat/smt/polysat/umul_ovfl_constraint.cpp @@ -10,6 +10,8 @@ Author: Jakob Rath, Nikolaj Bjorner (nbjorner) 2021-12-09 --*/ +#include "util/log.h" +#include "sat/smt/polysat/core.h" #include "sat/smt/polysat/constraints.h" #include "sat/smt/polysat/assignment.h" #include "sat/smt/polysat/umul_ovfl_constraint.h" @@ -70,4 +72,84 @@ namespace polysat { return eval(a.apply_to(p()), a.apply_to(q())); } + void umul_ovfl_constraint::activate(core& c, bool sign, dependency const& dep) { + + } + + void umul_ovfl_constraint::propagate(core& c, lbool value, dependency const& dep) { + auto& C = c.cs(); + auto p1 = c.subst(p()); + auto q1 = c.subst(q()); + if (narrow_bound(c, value == l_true, p(), q(), p1, q1, dep)) + return; + if (narrow_bound(c, value == l_true, q(), p(), q1, p1, dep)) + return; + } + + /** + * if p constant, q, propagate inequality + */ + bool umul_ovfl_constraint::narrow_bound(core& c, bool is_positive, pdd const& p0, pdd const& q0, pdd const& p, pdd const& q, dependency const& d) { + LOG("p: " << p0 << " := " << p); + LOG("q: " << q0 << " := " << q); + + if (!p.is_val()) + return false; + VERIFY(!p.is_zero() && !p.is_one()); // evaluation should catch this case + + rational const& M = p.manager().two_to_N(); + auto& C = c.cs(); + + // q_bound + // = min q . Ovfl(p_val, q) + // = min q . p_val * q >= M + // = min q . q >= M / p_val + // = ceil(M / p_val) + rational const q_bound = ceil(M / p.val()); + SASSERT(2 <= q_bound && q_bound <= M / 2); + SASSERT(p.val() * q_bound >= M); + SASSERT(p.val() * (q_bound - 1) < M); + // LOG("q_bound: " << q.manager().mk_val(q_bound)); + + // We need the following properties for the bounds: + // + // p_bound * (q_bound - 1) < M + // p_bound * q_bound >= M + // + // With these properties we get: + // + // p <= p_bound & q < q_bound ==> ~Ovfl(p, q) + // p >= p_bound & q >= q_bound ==> Ovfl(p, q) + // + // Written as lemmas: + // + // Ovfl(p, q) & p <= p_bound ==> q >= q_bound + // ~Ovfl(p, q) & p >= p_bound ==> q < q_bound + // + if (is_positive) { + // Find largest bound for p such that q_bound is still correct. + // p_bound = max p . (q_bound - 1)*p < M + // = max p . p < M / (q_bound - 1) + // = ceil(M / (q_bound - 1)) - 1 + rational const p_bound = ceil(M / (q_bound - 1)) - 1; + SASSERT(p.val() <= p_bound); + SASSERT(p_bound * q_bound >= M); + SASSERT(p_bound * (q_bound - 1) < M); + // LOG("p_bound: " << p.manager().mk_val(p_bound)); + c.add_clause("~Ovfl(p, q) & p <= p_bound ==> q < q_bound", { d, ~C.ule(p0, p_bound), C.ule(q_bound, q0) }, false); + } + else { + // Find lowest bound for p such that q_bound is still correct. + // p_bound = min p . Ovfl(p, q_bound) = ceil(M / q_bound) + rational const p_bound = ceil(M / q_bound); + SASSERT(p_bound <= p.val()); + SASSERT(p_bound * q_bound >= M); + SASSERT(p_bound * (q_bound - 1) < M); + // LOG("p_bound: " << p.manager().mk_val(p_bound)); + c.add_clause("~Ovfl(p, q) & p >= p_bound ==> q < q_bound", { d, ~C.ule(p_bound, p0), C.ult(q0, q_bound) }, false); + } + return true; + } + + } diff --git a/src/sat/smt/polysat/umul_ovfl_constraint.h b/src/sat/smt/polysat/umul_ovfl_constraint.h index c9d03fb01..374d346f7 100644 --- a/src/sat/smt/polysat/umul_ovfl_constraint.h +++ b/src/sat/smt/polysat/umul_ovfl_constraint.h @@ -25,6 +25,8 @@ namespace polysat { static bool is_always_false(bool is_positive, pdd const& p, pdd const& q) { return is_always_true(!is_positive, p, q); } static lbool eval(pdd const& p, pdd const& q); + bool narrow_bound(core& c, bool is_positive, pdd const& p0, pdd const& q0, pdd const& p, pdd const& q, dependency const& d); + public: umul_ovfl_constraint(pdd const& p, pdd const& q); ~umul_ovfl_constraint() override {} @@ -34,6 +36,8 @@ namespace polysat { std::ostream& display(std::ostream& out) const override; lbool eval() const override; lbool eval(assignment const& a) const override; + void activate(core& c, bool sign, dependency const& dep) override; + void propagate(core& c, lbool value, dependency const& dep) override; }; } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 5e5647bd3..ef469fe6f 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -235,9 +235,7 @@ namespace polysat { if (n->get_num_args() == 2) { expr* x, * y; VERIFY(bv.is_bv_and(n, x, y)); - auto sc = m_core.band(expr2pdd(x), expr2pdd(y), expr2pdd(n)); - // auto index = m_core.register_constraint(sc, dependency::axiom()); - // + m_core.band(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } else { expr_ref z(n->get_arg(0), m); @@ -252,19 +250,19 @@ namespace polysat { void solver::internalize_lshr(app* n) { expr* x, * y; VERIFY(bv.is_bv_lshr(n, x, y)); - auto sc = m_core.lshr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + m_core.lshr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } void solver::internalize_ashr(app* n) { expr* x, * y; VERIFY(bv.is_bv_ashr(n, x, y)); - auto sc = m_core.ashr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + m_core.ashr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } void solver::internalize_shl(app* n) { expr* x, * y; VERIFY(bv.is_bv_shl(n, x, y)); - auto sc = m_core.shl(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + m_core.shl(expr2pdd(x), expr2pdd(y), expr2pdd(n)); } void solver::internalize_urem_i(app* rem) { diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 9f185b22d..690548aaa 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -64,7 +64,7 @@ namespace polysat { case sat::check_result::CR_GIVEUP: { if (!m.inc()) return sat::check_result::CR_GIVEUP; - switch (m_intblast.check()) { + switch (m_intblast.check_solver_state()) { case l_true: trail().push(value_trail(m_use_intblast_model)); m_use_intblast_model = true; @@ -254,10 +254,25 @@ namespace polysat { return ctx.get_trail_stack(); } - void solver::add_polysat_clause(char const* name, std::initializer_list cs, bool is_redundant) { + void solver::add_polysat_clause(char const* name, core_vector cs, bool is_redundant) { sat::literal_vector lits; - for (auto sc : cs) - lits.push_back(ctx.mk_literal(constraint2expr(sc))); + for (auto e : cs) { + if (std::holds_alternative(e)) { + auto d = *std::get_if(&e); + SASSERT(!d.is_null()); + if (d.is_literal()) + lits.push_back(~d.literal()); + else if (d.is_eq()) { + auto [v1, v2] = d.eq(); + lits.push_back(~eq_internalize(var2enode(v1), var2enode(v2))); + } + else { + SASSERT(d.is_axiom()); + } + } + else if (std::holds_alternative(e)) + lits.push_back(ctx.mk_literal(constraint2expr(*std::get_if(&e)))); + } s().add_clause(lits.size(), lits.data(), sat::status::th(is_redundant, get_id(), nullptr)); } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index f54bafb1c..a04c76618 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -151,7 +151,7 @@ namespace polysat { bool inconsistent() const override; void get_bitvector_prefixes(pvar v, pvar_vector& out) override; void get_fixed_bits(pvar v, svector& fixed_bits) override; - void add_polysat_clause(char const* name, std::initializer_list cs, bool is_redundant); + void add_polysat_clause(char const* name, core_vector cs, bool redundant) override; std::pair explain_deps(dependency_vector const& deps); From 06ebf9a02af5720ba6d55cd65f9c6462dce5522c Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 12 Dec 2023 14:41:31 -0800 Subject: [PATCH 43/89] n/a --- src/ast/bv_decl_plugin.cpp | 5 + src/ast/bv_decl_plugin.h | 1 + src/sat/smt/intblast_solver.cpp | 314 +++++++++++++------ src/sat/smt/intblast_solver.h | 31 +- src/sat/smt/polysat/umul_ovfl_constraint.cpp | 2 + src/sat/smt/polysat_solver.cpp | 26 ++ 6 files changed, 272 insertions(+), 107 deletions(-) diff --git a/src/ast/bv_decl_plugin.cpp b/src/ast/bv_decl_plugin.cpp index f725fefc5..30cfe4cdb 100644 --- a/src/ast/bv_decl_plugin.cpp +++ b/src/ast/bv_decl_plugin.cpp @@ -942,3 +942,8 @@ app * bv_util::mk_bv2int(expr* e) { parameter p(s); return m_manager.mk_app(get_fid(), OP_BV2INT, 1, &p, 1, &e); } + +app* bv_util::mk_int2bv(unsigned sz, expr* e) { + parameter p(sz); + return m_manager.mk_app(get_fid(), OP_INT2BV, 1, &p, 1, &e); +} diff --git a/src/ast/bv_decl_plugin.h b/src/ast/bv_decl_plugin.h index 4eeac49ee..cb1f63881 100644 --- a/src/ast/bv_decl_plugin.h +++ b/src/ast/bv_decl_plugin.h @@ -522,6 +522,7 @@ public: app * mk_bv_lshr(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_BLSHR, arg1, arg2); } app * mk_bv2int(expr* e); + app * mk_int2bv(unsigned sz, expr* e); // TODO: all these binary ops commute (right?) but it'd be more logical to swap `n` & `m` in the `return` app * mk_bvsmul_no_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_NO_OVFL, n, m); } diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 0505eaa92..250e279cd 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -29,8 +29,7 @@ namespace intblast { bv(m), a(m), m_args(m), - m_translate(m), - m_pinned(m) + m_translate(m) {} euf::theory_var solver::mk_var(euf::enode* n) { @@ -85,14 +84,38 @@ namespace intblast { return true; } + void solver::eq_internalized(euf::enode* n) { + expr* e = n->get_expr(); + expr* x, * y; + VERIFY(m.is_eq(n->get_expr(), x, y)); + m_args.reset(); + m_args.push_back(translated(x)); + m_args.push_back(translated(y)); + add_equiv(expr2literal(e), eq_internalize(umod(x, 0), umod(x, 1))); + } + void solver::internalize_bv(app* e) { ensure_args(e); m_args.reset(); for (auto arg : *e) m_args.push_back(translated(arg)); translate_bv(e); - if (m.is_bool(e)) - add_equiv(expr2literal(e), mk_literal(translated(e))); + if (m.is_bool(e)) + add_equiv(expr2literal(e), mk_literal(translated(e))); + add_bound_axioms(); + } + + void solver::add_bound_axioms() { + if (m_vars_qhead == m_vars.size()) + return; + ctx.push(value_trail(m_vars_qhead)); + for (; m_vars_qhead < m_vars.size(); ++m_vars_qhead) { + auto v = m_vars[m_vars_qhead]; + auto w = translated(v); + auto sz = rational::power_of_two(bv.get_bv_size(v->get_sort())); + add_unit(ctx.mk_literal(a.mk_ge(w, a.mk_int(0)))); + add_unit(ctx.mk_literal(a.mk_le(w, a.mk_int(sz - 1)))); + } } void solver::ensure_args(app* e) { @@ -106,17 +129,17 @@ namespace intblast { return; for (unsigned i = 0; i < todo.size(); ++i) { expr* e = todo[i]; - if (is_app(e)) { + if (m.is_bool(e)) + continue; + else if (is_app(e)) { for (auto arg : *to_app(e)) if (!visited.is_marked(arg)) { visited.mark(arg); todo.push_back(arg); } } - else if (is_quantifier(e) && !visited.is_marked(to_quantifier(e)->get_expr())) { - visited.mark(to_quantifier(e)->get_expr()); - todo.push_back(to_quantifier(e)->get_expr()); - } + else if (is_lambda(e)) + throw default_exception("lambdas are not supported in intblaster"); } std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); @@ -176,8 +199,8 @@ namespace intblast { } m_core.reset(); - m_vars.reset(); m_translate.reset(); + m_is_plugin = false; m_solver = mk_smt2_solver(m, s.params(), symbol::null); expr_ref_vector es(m); @@ -186,8 +209,9 @@ namespace intblast { translate(es); - for (auto const& [src, vi] : m_vars) { - auto const& [v, b] = vi; + for (auto e : m_vars) { + auto v = translated(e); + auto b = rational::power_of_two(bv.get_bv_size(e)); m_solver->assert_expr(a.mk_le(a.mk_int(0), v)); m_solver->assert_expr(a.mk_lt(v, a.mk_int(b))); } @@ -296,19 +320,68 @@ namespace intblast { es[i] = translated(es.get(i)); } - expr* solver::mk_mod(expr* x) { - if (m_vars.contains(x)) + sat::check_result solver::check() { + // ensure that bv2int is injective + for (auto e : m_bv2int) { + euf::enode* n = expr2enode(e); + euf::enode* r1 = n->get_arg(0)->get_root(); + for (auto sib : euf::enode_class(n)) { + if (sib == n) + continue; + if (!bv.is_bv2int(sib->get_expr())) + continue; + if (sib->get_arg(0)->get_root() == r1) + continue; + add_clause(~eq_internalize(n, sib), eq_internalize(sib->get_arg(0), n->get_arg(0)), nullptr); + return sat::check_result::CR_CONTINUE; + } + } + // ensure that int2bv respects values + // bv2int(int2bv(x)) = x mod N + for (auto e : m_int2bv) { + auto n = expr2enode(e); + auto x = n->get_arg(0)->get_expr(); + auto bv2int = bv.mk_bv2int(e); + ctx.internalize(bv2int); + auto N = rational::power_of_two(bv.get_bv_size(e)); + auto xModN = a.mk_mod(x, a.mk_int(N)); + ctx.internalize(xModN); + auto nBv2int = ctx.get_enode(bv2int); + auto nxModN = ctx.get_enode(xModN); + if (nBv2int->get_root() != nxModN->get_root()) { + add_unit(eq_internalize(nBv2int, nxModN)); + return sat::check_result::CR_CONTINUE; + } + } + return sat::check_result::CR_DONE; + } + + expr* solver::umod(expr* bv_expr, unsigned i) { + expr* x = arg(i); + rational r; + rational N = bv_size(bv_expr); + if (a.is_numeral(x, r)) { + if (0 <= r && r < N) + return x; + return a.mk_int(mod(r, N)); + } + if (any_of(m_vars, [&](expr* v) { return translated(v) == x; })) return x; - return to_expr(a.mk_mod(x, a.mk_int(bv_size()))); + return to_expr(a.mk_mod(x, a.mk_int(N))); } - expr* solver::mk_smod(expr* x) { - auto shift = bv_size() / 2; - return a.mk_mod(a.mk_add(x, a.mk_int(shift)), a.mk_int(bv_size())); + expr* solver::smod(expr* bv_expr, unsigned i) { + expr* x = arg(i); + auto N = bv_size(bv_expr); + auto shift = N / 2; + rational r; + if (a.is_numeral(x, r)) + return a.mk_int(mod(r + shift, N)); + return a.mk_mod(a.mk_add(x, a.mk_int(shift)), a.mk_int(N)); } - rational solver::bv_size() { - return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); + rational solver::bv_size(expr* bv_expr) { + return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); } void solver::translate_expr(expr* e) { @@ -318,7 +391,6 @@ namespace intblast { translate_var(to_var(e)); else { app* ap = to_app(e); - bv_expr = e; m_args.reset(); for (auto arg : *ap) m_args.push_back(translated(arg)); @@ -333,6 +405,12 @@ namespace intblast { } void solver::translate_quantifier(quantifier* q) { + if (is_lambda(q)) + throw default_exception("lambdas are not supported in intblaster"); + if (m_is_plugin) { + set_translated(q, q); + return; + } expr* b = q->get_expr(); unsigned nd = q->get_num_decls(); ptr_vector sorts; @@ -357,37 +435,47 @@ namespace intblast { set_translated(v, v); } + // Translate functions that are not built-in or bit-vectors. + // Base method uses fresh functions. + // Other method could use bv2int, int2bv axioms and coercions. + // f(args) = bv2int(f(int2bv(args')) + // + void solver::translate_app(app* e) { - bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); - bool has_bv_sort = bv.is_bv(e); - func_decl* f = e->get_decl(); - if (has_bv_arg) { - verbose_stream() << mk_pp(e, m) << "\n"; - // need to update args with mod where they are bit-vectors. - NOT_IMPLEMENTED_YET(); + + if (m_is_plugin && m.is_bool(e)) { + set_translated(e, e); + return; } - if (has_bv_arg || has_bv_sort) { - ptr_vector domain; - for (auto* arg : *e) { - sort* s = arg->get_sort(); - domain.push_back(bv.is_bv_sort(s) ? a.mk_int() : s); + bool has_bv_sort = bv.is_bv(e); + func_decl* f = e->get_decl(); + + for (unsigned i = 0; i < m_args.size(); ++i) + if (bv.is_bv(e->get_arg(i))) + m_args[i] = bv.mk_int2bv(bv.get_bv_size(e->get_arg(i)), m_args.get(i)); + + if (has_bv_sort) + m_vars.push_back(e); + + if (m_is_plugin) { + expr* r = m.mk_app(f, m_args); + if (has_bv_sort) { + ctx.push(push_back_vector(m_vars)); + r = bv.mk_bv2int(r); } - sort* range = bv.is_bv(e) ? a.mk_int() : e->get_sort(); + set_translated(e, r); + return; + } + else if (has_bv_sort) { func_decl* g = nullptr; if (!m_new_funs.find(f, g)) { - g = m.mk_fresh_func_decl(e->get_decl()->get_name(), symbol("bv"), domain.size(), domain.data(), range); + g = m.mk_fresh_func_decl(e->get_decl()->get_name(), symbol("bv"), f->get_arity(), f->get_domain(), a.mk_int()); m_new_funs.insert(f, g); - m_pinned.push_back(f); - m_pinned.push_back(g); } f = g; } - - set_translated(e, m.mk_app(f, m_args)); - - if (has_bv_sort) - m_vars.insert(e, { translated(e), bv_size()}); + set_translated(e, m.mk_app(f, m_args)); } void solver::translate_bv(app* e) { @@ -403,61 +491,59 @@ namespace intblast { return r; }; - bv_expr = e; + expr* bv_expr = e; expr* r = nullptr; auto const& args = m_args; switch (e->get_decl_kind()) { case OP_BADD: - r = (a.mk_add(args)); + r = a.mk_add(args); break; case OP_BSUB: - r = (a.mk_sub(args.size(), args.data())); + r = a.mk_sub(args.size(), args.data()); break; case OP_BMUL: - r = (a.mk_mul(args)); + r = a.mk_mul(args); break; case OP_ULEQ: bv_expr = e->get_arg(0); - r = (a.mk_le(mk_mod(arg(0)), mk_mod(arg(1)))); + r = a.mk_le(umod(bv_expr, 0), umod(bv_expr, 1)); break; case OP_UGEQ: bv_expr = e->get_arg(0); - r = (a.mk_ge(mk_mod(arg(0)), mk_mod(arg(1)))); + r = a.mk_ge(umod(bv_expr, 0), umod(bv_expr, 1)); break; case OP_ULT: bv_expr = e->get_arg(0); - r = (a.mk_lt(mk_mod(arg(0)), mk_mod(arg(1)))); + r = a.mk_lt(umod(bv_expr, 0), umod(bv_expr, 1)); break; case OP_UGT: bv_expr = e->get_arg(0); - r = (a.mk_gt(mk_mod(arg(0)), mk_mod(arg(1)))); + r = a.mk_gt(umod(bv_expr, 0), umod(bv_expr, 1)); break; case OP_SLEQ: bv_expr = e->get_arg(0); - r = (a.mk_le(mk_smod(arg(0)), mk_smod(arg(1)))); + r = a.mk_le(smod(bv_expr, 0), smod(bv_expr, 1)); break; case OP_SGEQ: - r = (a.mk_ge(mk_smod(arg(0)), mk_smod(arg(1)))); + r = a.mk_ge(smod(bv_expr, 0), smod(bv_expr, 1)); break; case OP_SLT: bv_expr = e->get_arg(0); - r = (a.mk_lt(mk_smod(arg(0)), mk_smod(arg(1)))); + r = a.mk_lt(smod(bv_expr, 0), smod(bv_expr, 1)); break; case OP_SGT: bv_expr = e->get_arg(0); - r = (a.mk_gt(mk_smod(arg(0)), mk_smod(arg(1)))); + r = a.mk_gt(smod(bv_expr, 0), smod(bv_expr, 1)); break; case OP_BNEG: - r = (a.mk_uminus(arg(0))); + r = a.mk_uminus(arg(0)); break; case OP_CONCAT: { r = a.mk_int(0); unsigned sz = 0; for (unsigned i = 0; i < args.size(); ++i) { expr* old_arg = e->get_arg(i); - expr* new_arg = arg(i); - bv_expr = old_arg; - new_arg = mk_mod(new_arg); + expr* new_arg = umod(old_arg, i); if (sz > 0) { new_arg = a.mk_mul(new_arg, a.mk_int(rational::power_of_two(sz))); r = a.mk_add(r, new_arg); @@ -482,23 +568,22 @@ namespace intblast { rational val; unsigned sz; VERIFY(bv.is_numeral(e, val, sz)); - r = (a.mk_int(val)); + r = a.mk_int(val); break; } case OP_BUREM_I: { expr* x = arg(0), * y = arg(1); - r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y)); break; } case OP_BUDIV_I: { expr* x = arg(0), * y = arg(1); - r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, umod(bv_expr, 1))); break; } case OP_BUMUL_NO_OVFL: { - expr* x = arg(0), * y = arg(1); bv_expr = e->get_arg(0); - r = (a.mk_lt(a.mk_mul(mk_mod(x), mk_mod(y)), a.mk_int(bv_size()))); + r = a.mk_lt(a.mk_mul(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(bv_size(bv_expr))); break; } case OP_BSHL: { @@ -509,7 +594,7 @@ namespace intblast { break; } case OP_BNOT: - r = (bnot(arg(0))); + r = bnot(arg(0)); break; case OP_BLSHR: { expr* x = arg(0), * y = arg(1); @@ -518,7 +603,7 @@ namespace intblast { r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); break; } - // Or use (p + q) - band(p, q)? + // Or use (p + q) - band(p, q)? case OP_BOR: { r = arg(0); for (unsigned i = 1; i < args.size(); ++i) @@ -537,46 +622,43 @@ namespace intblast { case OP_BXNOR: case OP_BXOR: { unsigned sz = bv.get_bv_size(e); - expr* p = arg(0); + r = arg(0); for (unsigned i = 1; i < args.size(); ++i) { expr* q = arg(i); - p = a.mk_sub(a.mk_add(p, q), a.mk_mul(a.mk_int(2), a.mk_band(sz, p, q))); + r = a.mk_sub(a.mk_add(r, q), a.mk_mul(a.mk_int(2), a.mk_band(sz, r, q))); } if (e->get_decl_kind() == OP_BXNOR) - p = bnot(p); - r = (p); + r = bnot(r); break; } case OP_BUDIV: { bv_rewriter_params p(ctx.s().params()); expr* x = arg(0), * y = arg(1); if (p.hi_div0()) - r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y)); else - r = (a.mk_idiv(x, y)); + r = a.mk_idiv(x, y); break; } case OP_BUREM: { bv_rewriter_params p(ctx.s().params()); expr* x = arg(0), * y = arg(1); if (p.hi_div0()) - r = (m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y)); else - r = (a.mk_mod(x, y)); + r = a.mk_mod(x, y); break; } + // // ashr(x, y) // if y = k & x >= 0 -> x / 2^k // if y = k & x < 0 -> - (x / 2^k) // - case OP_BASHR: { - expr* x = arg(0), * y = arg(1); rational N = rational::power_of_two(bv.get_bv_size(e)); - bv_expr = e; - x = mk_mod(x); - y = mk_mod(y); + expr* x = umod(e, 0); + expr* y = umod(e, 1); expr* signbit = a.mk_ge(x, a.mk_int(N / 2)); r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { @@ -587,15 +669,39 @@ namespace intblast { } break; } + case OP_ZERO_EXT: + bv_expr = e->get_arg(0); + r = umod(bv_expr, 0); + SASSERT(bv.get_bv_size(e) >= bv.get_bv_size(bv_expr)); + break; + case OP_SIGN_EXT: { + bv_expr = e->get_arg(0); + r = umod(bv_expr, 0); + SASSERT(bv.get_bv_size(e) >= bv.get_bv_size(bv_expr)); + unsigned arg_sz = bv.get_bv_size(bv_expr); + unsigned sz = bv.get_bv_size(e); + rational N = rational::power_of_two(sz); + rational M = rational::power_of_two(arg_sz); + expr* signbit = a.mk_ge(r, a.mk_int(M / 2)); + r = m.mk_ite(signbit, a.mk_uminus(r), r); + break; + } + case OP_INT2BV: + m_int2bv.push_back(e); + ctx.push(push_back_vector(m_int2bv)); + r = arg(0); + break; + case OP_BV2INT: + m_bv2int.push_back(e); + ctx.push(push_back_vector(m_bv2int)); + r = arg(0); + break; case OP_BCOMP: - case OP_ROTATE_LEFT: case OP_ROTATE_RIGHT: case OP_EXT_ROTATE_LEFT: case OP_EXT_ROTATE_RIGHT: case OP_REPEAT: - case OP_ZERO_EXT: - case OP_SIGN_EXT: case OP_BREDOR: case OP_BREDAND: case OP_BSDIV: @@ -610,19 +716,19 @@ namespace intblast { } set_translated(e, r); } - + void solver::translate_basic(app* e) { if (m.is_eq(e)) { bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); if (has_bv_arg) { - bv_expr = e->get_arg(0); - set_translated(e, m.mk_eq(mk_mod(arg(0)), mk_mod(arg(1)))); + expr* bv_expr = e->get_arg(0); + set_translated(e, m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1))); } - else - set_translated(e, m.mk_eq(arg(0), arg(1))); + else + set_translated(e, m.mk_eq(arg(0), arg(1))); } - else - set_translated(e, m.mk_app(e->get_decl(), m_args)); + else + set_translated(e, m.mk_app(e->get_decl(), m_args)); } rational solver::get_value(expr* e) const { @@ -630,11 +736,9 @@ namespace intblast { model_ref mdl; m_solver->get_model(mdl); expr_ref r(m); - var_info vi; + r = translated(e); rational val; - if (!m_vars.find(e, vi)) - return rational::zero(); - if (!mdl->eval_expr(vi.dst, r, true)) + if (!mdl->eval_expr(r, r, true)) return rational::zero(); if (!a.is_numeral(r, val)) return rational::zero(); @@ -642,6 +746,32 @@ namespace intblast { } void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + if (m_is_plugin) + add_value_plugin(n, mdl, values); + else + add_value_solver(n, mdl, values); + } + + bool solver::add_dep(euf::enode* n, top_sort& dep) { + // bv2int + auto e = ctx.get_enode(translated(n->get_expr())); + if (!e) + return false; + dep.add(n, e); + } + + // TODO: handle dependencies properly by using arithmetical model to retrieve values of translated + // bit-vectors directly. + void solver::add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values) { + SASSERT(bv.is_bv(n->get_expr())); + rational N = rational::power_of_two(bv.get_bv_size(n->get_expr())); + auto e = ctx.get_enode(translated(n->get_expr())); + expr_ref value(m); + value = values.get(e->get_root_id()); + values.setx(n->get_root_id(), value); + } + + void solver::add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values) { expr_ref value(m); if (n->interpreted()) value = n->get_expr(); diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 037b009a3..707f53832 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -46,23 +46,18 @@ namespace euf { namespace intblast { class solver : public euf::th_euf_solver { - struct var_info { - expr* dst; - rational sz; - }; - euf::solver& ctx; sat::solver& s; ast_manager& m; bv_util bv; arith_util a; scoped_ptr<::solver> m_solver; - obj_map m_vars; obj_map m_new_funs; expr_ref_vector m_translate, m_args; - ast_ref_vector m_pinned; sat::literal_vector m_core; + ptr_vector m_bv2int, m_int2bv; statistics m_stats; + bool m_is_plugin = true; // when the solver is used as a plugin, then do not translate below quantifiers. bool is_bv(sat::literal lit); void translate(expr_ref_vector& es); @@ -70,14 +65,13 @@ namespace intblast { rational get_value(expr* e) const; - expr* translated(expr* e) { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } + expr* translated(expr* e) const { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } void set_translated(expr* e, expr* r) { m_translate.setx(e->get_id(), r); } expr* arg(unsigned i) { return m_args.get(i); } - expr* mk_mod(expr* x); - expr* mk_smod(expr* x); - expr* bv_expr = nullptr; - rational bv_size(); + expr* umod(expr* bv_expr, unsigned i); + expr* smod(expr* bv_expr, unsigned i); + rational bv_size(expr* bv_expr); void translate_expr(expr* e); void translate_bv(app* e); @@ -89,8 +83,15 @@ namespace intblast { void ensure_args(app* e); void internalize_bv(app* e); + unsigned m_vars_qhead = 0; + ptr_vector m_vars; + void add_bound_axioms(); + euf::theory_var mk_var(euf::enode* n) override; + void add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values); + void add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values); + public: solver(euf::solver& ctx); @@ -102,12 +103,12 @@ namespace intblast { void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; + bool add_dep(euf::enode* n, top_sort& dep) override; + std::ostream& display(std::ostream& out) const override; void collect_statistics(statistics& st) const override; - - bool unit_propagate() override { return false; } void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override {} @@ -130,7 +131,7 @@ namespace intblast { sat::literal internalize(expr* e, bool, bool) override; - void eq_internalized(euf::enode* n) override {} + void eq_internalized(euf::enode* n) override; }; diff --git a/src/sat/smt/polysat/umul_ovfl_constraint.cpp b/src/sat/smt/polysat/umul_ovfl_constraint.cpp index 5d185e7ee..445169c2f 100644 --- a/src/sat/smt/polysat/umul_ovfl_constraint.cpp +++ b/src/sat/smt/polysat/umul_ovfl_constraint.cpp @@ -77,6 +77,8 @@ namespace polysat { } void umul_ovfl_constraint::propagate(core& c, lbool value, dependency const& dep) { + if (value == l_undef) + return; auto& C = c.cs(); auto p1 = c.subst(p()); auto q1 = c.subst(q()); diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 690548aaa..c14ca1d14 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -256,6 +256,32 @@ namespace polysat { void solver::add_polysat_clause(char const* name, core_vector cs, bool is_redundant) { sat::literal_vector lits; + signed_constraint sc; + unsigned constraint_count = 0; + for (auto e : cs) { + if (std::holds_alternative(e)) { + sc = *std::get_if(&e); + constraint_count++; + } + } + if (constraint_count == 1) { + auto lit = ctx.mk_literal(constraint2expr(sc)); + svector eqs; + for (auto e : cs) { + if (std::holds_alternative(e)) { + auto d = *std::get_if(&e); + if (d.is_literal()) + lits.push_back(d.literal()); + else if (d.is_eq()) { + auto [v1, v2] = d.eq(); + eqs.push_back({ var2enode(v1), var2enode(v2) }); + } + } + } + ctx.propagate(lit, euf::th_explain::propagate(*this, lits, eqs, lit, nullptr)); + return; + } + for (auto e : cs) { if (std::holds_alternative(e)) { auto d = *std::get_if(&e); From 7247bbb78f3d5629ef65404b9f0d04643dfbea88 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 12 Dec 2023 14:42:34 -0800 Subject: [PATCH 44/89] na/ --- src/sat/smt/intblast_solver.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 707f53832..2e0d2017f 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -113,7 +113,7 @@ namespace intblast { void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override {} - sat::check_result check() override { return sat::check_result::CR_DONE; } + sat::check_result check() override; std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return out; } From 9b435eda904c48d5db1b3ad2fcfc1c3f85d0de29 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 12 Dec 2023 14:53:10 -0800 Subject: [PATCH 45/89] fixes --- src/sat/smt/intblast_solver.cpp | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 250e279cd..4008ee602 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -28,8 +28,8 @@ namespace intblast { m(ctx.get_manager()), bv(m), a(m), - m_args(m), - m_translate(m) + m_translate(m), + m_args(m) {} euf::theory_var solver::mk_var(euf::enode* n) { @@ -100,6 +100,8 @@ namespace intblast { for (auto arg : *e) m_args.push_back(translated(arg)); translate_bv(e); + + // possibly wait until propagation? if (m.is_bool(e)) add_equiv(expr2literal(e), mk_literal(translated(e))); add_bound_axioms(); @@ -468,6 +470,8 @@ namespace intblast { return; } else if (has_bv_sort) { + if (f->get_family_id() != null_family_id) + throw default_exception("conversion for interpreted functions is not supported by intblast solver"); func_decl* g = nullptr; if (!m_new_funs.find(f, g)) { g = m.mk_fresh_func_decl(e->get_decl()->get_name(), symbol("bv"), f->get_arity(), f->get_domain(), a.mk_int()); @@ -558,7 +562,6 @@ namespace intblast { unsigned lo, hi; expr* old_arg; VERIFY(bv.is_extract(e, lo, hi, old_arg)); - unsigned sz = hi - lo + 1; expr* r = arg(0); if (lo > 0) r = a.mk_idiv(r, a.mk_int(rational::power_of_two(lo))); @@ -611,10 +614,10 @@ namespace intblast { break; } case OP_BNAND: - r = (bnot(band(args))); + r = bnot(band(args)); break; case OP_BAND: - r = (band(args)); + r = band(args); break; // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 @@ -635,18 +638,18 @@ namespace intblast { bv_rewriter_params p(ctx.s().params()); expr* x = arg(0), * y = arg(1); if (p.hi_div0()) - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, umod(e, 1))); else - r = a.mk_idiv(x, y); + r = a.mk_idiv(x, umod(e, 1)); break; } case OP_BUREM: { bv_rewriter_params p(ctx.s().params()); expr* x = arg(0), * y = arg(1); if (p.hi_div0()) - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, umod(e, 1))); else - r = a.mk_mod(x, y); + r = a.mk_mod(x, umod(e, 1)); break; } @@ -694,7 +697,7 @@ namespace intblast { case OP_BV2INT: m_bv2int.push_back(e); ctx.push(push_back_vector(m_bv2int)); - r = arg(0); + r = umod(e->get_arg(0), 0); break; case OP_BCOMP: case OP_ROTATE_LEFT: @@ -758,6 +761,7 @@ namespace intblast { if (!e) return false; dep.add(n, e); + return true; } // TODO: handle dependencies properly by using arithmetical model to retrieve values of translated From 40e93d747835d9cd378cbe57cfc78a2765a8c870 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 12 Dec 2023 15:09:08 -0800 Subject: [PATCH 46/89] n/a Signed-off-by: Nikolaj Bjorner --- src/sat/smt/intblast_solver.cpp | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 4008ee602..46a42b0d9 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -636,29 +636,28 @@ namespace intblast { } case OP_BUDIV: { bv_rewriter_params p(ctx.s().params()); - expr* x = arg(0), * y = arg(1); + expr* x = arg(0), * y = umod(e, 1); if (p.hi_div0()) - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, umod(e, 1))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y)); else - r = a.mk_idiv(x, umod(e, 1)); + r = a.mk_idiv(x, y); break; } case OP_BUREM: { bv_rewriter_params p(ctx.s().params()); - expr* x = arg(0), * y = arg(1); + expr* x = arg(0), * y = umod(e, 1); if (p.hi_div0()) - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, umod(e, 1))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), y)); else - r = a.mk_mod(x, umod(e, 1)); + r = a.mk_mod(x, y); break; } - - // - // ashr(x, y) - // if y = k & x >= 0 -> x / 2^k - // if y = k & x < 0 -> - (x / 2^k) - // case OP_BASHR: { + // + // ashr(x, y) + // if y = k & x >= 0 -> x / 2^k + // if y = k & x < 0 -> - (x / 2^k) + // rational N = rational::power_of_two(bv.get_bv_size(e)); expr* x = umod(e, 0); expr* y = umod(e, 1); @@ -700,6 +699,9 @@ namespace intblast { r = umod(e->get_arg(0), 0); break; case OP_BCOMP: + bv_expr = e->get_arg(0); + r = m.mk_ite(m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(1), a.mk_int(0)); + break; case OP_ROTATE_LEFT: case OP_ROTATE_RIGHT: case OP_EXT_ROTATE_LEFT: From 35eb95b447e26474df55ce5eeb121d4277658f9c Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 12 Dec 2023 15:42:39 -0800 Subject: [PATCH 47/89] na Signed-off-by: Nikolaj Bjorner --- src/sat/smt/euf_solver.cpp | 20 ++++++++++++++++++++ src/sat/smt/euf_solver.h | 1 + src/sat/smt/intblast_solver.cpp | 2 +- src/sat/smt/polysat/core.cpp | 9 +++++---- src/sat/smt/polysat/types.h | 2 +- src/sat/smt/polysat_solver.cpp | 7 ++++++- src/sat/smt/polysat_solver.h | 2 +- 7 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index b6606d4f6..b95e44d74 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -291,6 +291,26 @@ namespace euf { } } + void solver::get_eq_antecedents(enode* a, enode* b, literal_vector& r) { + m_egraph.begin_explain(); + m_explain.reset(); + m_egraph.explain_eq(m_explain, nullptr, a, b); + for (unsigned qhead = 0; qhead < m_explain.size(); ++qhead) { + size_t* e = m_explain[qhead]; + if (is_literal(e)) + r.push_back(get_literal(e)); + else { + size_t idx = get_justification(e); + auto* ext = sat::constraint_base::to_extension(idx); + SASSERT(ext != this); + sat::literal lit = sat::null_literal; + ext->get_antecedents(lit, idx, r, true); + } + } + m_egraph.end_explain(); + } + + void solver::get_th_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing) { for (auto lit : euf::th_explain::lits(jst)) r.push_back(lit); diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 9cac6e02a..7d2d01473 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -369,6 +369,7 @@ namespace euf { void flush_roots() override; void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) override; + void get_eq_antecedents(enode* a, enode* b, literal_vector& r); void get_th_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing); void add_eq_antecedent(bool probing, enode* a, enode* b); void explain_diseq(ptr_vector& ex, cc_justification* cc, enode* a, enode* b); diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 46a42b0d9..0b9f08ec8 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -647,7 +647,7 @@ namespace intblast { bv_rewriter_params p(ctx.s().params()); expr* x = arg(0), * y = umod(e, 1); if (p.hi_div0()) - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), y)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), y); else r = a.mk_mod(x, y); break; diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index dd30e8227..793db7392 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -149,13 +149,14 @@ namespace polysat { m_var = m_var_queue.next_var(); s.trail().push(mk_dqueue_var(m_var, *this)); switch (m_viable.find_viable(m_var, m_value)) { - case find_t::empty: - s.set_lemma(m_viable.get_core(), 0, m_viable.explain()); - // propagate_unsat_core(); + case find_t::empty: + s.set_lemma(m_viable.get_core(), m_viable.explain()); + // propagate_unsat_core(); return sat::check_result::CR_CONTINUE; - case find_t::singleton: + case find_t::singleton: { s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); return sat::check_result::CR_CONTINUE; + } case find_t::multiple: s.add_eq_literal(m_var, m_value); return sat::check_result::CR_CONTINUE; diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index 6f855f98b..e7beb3eb1 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -107,7 +107,7 @@ namespace polysat { virtual ~solver_interface() {} virtual void add_eq_literal(pvar v, rational const& val) = 0; virtual void set_conflict(dependency_vector const& core) = 0; - virtual void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) = 0; + virtual void set_lemma(core_vector const& aux_core, dependency_vector const& core) = 0; virtual void add_polysat_clause(char const* name, core_vector cs, bool redundant) = 0; virtual dependency propagate(signed_constraint sc, dependency_vector const& deps) = 0; virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index c14ca1d14..a713cc4a4 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -130,8 +130,13 @@ namespace polysat { return { core, eqs }; } - void solver::set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) { + void solver::set_lemma(core_vector const& aux_core, dependency_vector const& core) { auto [lits, eqs] = explain_deps(core); + unsigned level = 0; + for (auto const& [n1, n2] : eqs) + ctx.get_eq_antecedents(n1, n2, lits); + for (auto lit : lits) + level = std::max(level, s().lvl(lit)); auto ex = euf::th_explain::conflict(*this, lits, eqs, nullptr); ctx.push(value_trail(m_has_lemma)); m_has_lemma = true; diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index a04c76618..e88eafdd2 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -144,7 +144,7 @@ namespace polysat { // callbacks from core void add_eq_literal(pvar v, rational const& val) override; void set_conflict(dependency_vector const& core) override; - void set_lemma(core_vector const& aux_core, unsigned level, dependency_vector const& core) override; + void set_lemma(core_vector const& aux_core, dependency_vector const& core) override; dependency propagate(signed_constraint sc, dependency_vector const& deps) override; void propagate(dependency const& d, bool sign, dependency_vector const& deps) override; trail_stack& trail() override; From 34229eaa8e63fcfe9d66249adb67621ff1c97d36 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 12 Dec 2023 16:37:39 -0800 Subject: [PATCH 48/89] na Signed-off-by: Nikolaj Bjorner --- src/sat/smt/intblast_solver.cpp | 2 +- src/sat/smt/intblast_solver.h | 2 +- src/sat/smt/polysat/core.cpp | 27 ++++++++++++++++++++++---- src/sat/smt/polysat/ule_constraint.cpp | 2 -- src/sat/smt/polysat_solver.cpp | 7 ++++++- 5 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 0b9f08ec8..21207ff83 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -562,7 +562,7 @@ namespace intblast { unsigned lo, hi; expr* old_arg; VERIFY(bv.is_extract(e, lo, hi, old_arg)); - expr* r = arg(0); + r = arg(0); if (lo > 0) r = a.mk_idiv(r, a.mk_int(rational::power_of_two(lo))); break; diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 2e0d2017f..ee85ed6e2 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -66,7 +66,7 @@ namespace intblast { rational get_value(expr* e) const; expr* translated(expr* e) const { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } - void set_translated(expr* e, expr* r) { m_translate.setx(e->get_id(), r); } + void set_translated(expr* e, expr* r) { SASSERT(r); m_translate.setx(e->get_id(), r); } expr* arg(unsigned i) { return m_args.get(i); } expr* umod(expr* bv_expr, unsigned i); diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index 793db7392..ac3870635 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -66,9 +66,15 @@ namespace polysat { core& c; public: mk_add_watch(core& c) : c(c) {} - void undo() override { + void undo() override { auto& [sc, lit, val] = c.m_constraint_index.back(); auto& vars = sc.vars(); + verbose_stream() << "undo add watch " << sc << " "; + if (vars.size() > 0) verbose_stream() << "(" << c.m_constraint_index.size() -1 << ": " << c.m_watch[vars[0]] << ") "; + if (vars.size() > 1) verbose_stream() << "(" << c.m_constraint_index.size() -1 << ": " << c.m_watch[vars[1]] << ") "; + verbose_stream() << "\n"; + SASSERT(vars.size() <= 0 || c.m_watch[vars[0]].back() == c.m_constraint_index.size() - 1); + SASSERT(vars.size() <= 1 || c.m_watch[vars[1]].back() == c.m_constraint_index.size() - 1); if (vars.size() > 0) c.m_watch[vars[0]].pop_back(); if (vars.size() > 1) @@ -135,6 +141,10 @@ namespace polysat { add_watch(idx, vars[0]); if (vars.size() > 1) add_watch(idx, vars[1]); + verbose_stream() << "add watch " << sc << " " << vars << " "; + if (vars.size() > 0) verbose_stream() << "( " << idx << " : " << m_watch[vars[0]] << ") "; + if (vars.size() > 1) verbose_stream() << "( " << idx << " : " << m_watch[vars[1]] << ") "; + verbose_stream() << "\n"; s.trail().push(mk_add_watch(*this)); return idx; } @@ -213,6 +223,15 @@ namespace polysat { m_assignment.push(v , value); s.trail().push(mk_assign_var(v, *this)); + return; + // to debug: + unsigned sz = m_watch[v].size(); + for (unsigned i = 0; i < sz; ++i) { + auto idx = m_watch[v][i]; + auto [sc, dep, value] = m_constraint_index[idx]; + sc.propagate(*this, value, dep); + } + // update the watch lists for pvars // remove constraints from m_watch[v] that have more than 2 free variables. // for entries where there is only one free variable left add to viable set @@ -226,15 +245,14 @@ namespace polysat { bool swapped = false; for (unsigned i = vars.size(); i-- > 2; ) { if (!is_assigned(vars[i])) { + verbose_stream() << "watch instead " << idx << " " << vars[i] << "instead of " << vars[0] << "\n"; add_watch(idx, vars[i]); std::swap(vars[i], vars[0]); swapped = true; break; } } - - sc.propagate(*this, value, dep); - + SASSERT(!swapped || vars.size() <= 1 || (!is_assigned(vars[0]) && !is_assigned(vars[1]))); if (swapped) continue; @@ -249,6 +267,7 @@ namespace polysat { m_viable.add_unitary(v1, idx); } m_watch[v].shrink(j); + verbose_stream() << "new watch " << v << ": " << m_watch[v] << "\n"; } void core::propagate_value(prop_item const& dc) { diff --git a/src/sat/smt/polysat/ule_constraint.cpp b/src/sat/smt/polysat/ule_constraint.cpp index bdfcb7c5f..bc363084b 100644 --- a/src/sat/smt/polysat/ule_constraint.cpp +++ b/src/sat/smt/polysat/ule_constraint.cpp @@ -75,8 +75,6 @@ Useful lemmas: #include "sat/smt/polysat/constraints.h" #include "sat/smt/polysat/ule_constraint.h" -#define LOG(_msg_) verbose_stream() << _msg_ << "\n" - namespace polysat { // Simplify lhs <= rhs. diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index a713cc4a4..271b6986e 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -138,6 +138,10 @@ namespace polysat { for (auto lit : lits) level = std::max(level, s().lvl(lit)); auto ex = euf::th_explain::conflict(*this, lits, eqs, nullptr); + if (level == 0) { + ctx.set_conflict(ex); + return; + } ctx.push(value_trail(m_has_lemma)); m_has_lemma = true; m_lemma_level = level; @@ -165,7 +169,8 @@ namespace polysat { if (!m_has_lemma) return l_undef; - unsigned num_scopes = s().scope_lvl() - m_lemma_level; + SASSERT(m_lemma_level > 0); + unsigned num_scopes = s().scope_lvl() - m_lemma_level - 1; NOT_IMPLEMENTED_YET(); // s().pop_reinit(num_scopes); From c6a8ae1e8c9644560d4a45d3976b29708515f771 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 12 Dec 2023 18:00:43 -0800 Subject: [PATCH 49/89] include nyis --- src/sat/smt/intblast_solver.cpp | 56 ++++++++++++++++----------- src/sat/smt/polysat/core.cpp | 11 +++--- src/sat/smt/polysat/op_constraint.cpp | 9 +++++ src/sat/smt/polysat/op_constraint.h | 1 + 4 files changed, 50 insertions(+), 27 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 21207ff83..e87944a91 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -574,14 +574,16 @@ namespace intblast { r = a.mk_int(val); break; } + case OP_BUREM: case OP_BUREM_I: { expr* x = arg(0), * y = arg(1); - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_mod(x, y)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, a.mk_mod(x, y)); break; } + case OP_BUDIV: case OP_BUDIV_I: { expr* x = arg(0), * y = arg(1); - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, umod(bv_expr, 1))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(1), a.mk_idiv(x, umod(bv_expr, 1))); break; } case OP_BUMUL_NO_OVFL: { @@ -634,24 +636,6 @@ namespace intblast { r = bnot(r); break; } - case OP_BUDIV: { - bv_rewriter_params p(ctx.s().params()); - expr* x = arg(0), * y = umod(e, 1); - if (p.hi_div0()) - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), a.mk_idiv(x, y)); - else - r = a.mk_idiv(x, y); - break; - } - case OP_BUREM: { - bv_rewriter_params p(ctx.s().params()); - expr* x = arg(0), * y = umod(e, 1); - if (p.hi_div0()) - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(0), y); - else - r = a.mk_mod(x, y); - break; - } case OP_BASHR: { // // ashr(x, y) @@ -702,6 +686,36 @@ namespace intblast { bv_expr = e->get_arg(0); r = m.mk_ite(m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(1), a.mk_int(0)); break; + case OP_BSMOD: { + bv_expr = e; + expr* x = umod(bv_expr, 0), *y = umod(bv_expr, 0); + rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); + expr* signx = a.mk_ge(x, a.mk_int(N/2)); + expr* signy = a.mk_ge(y, a.mk_int(N/2)); + expr* u = a.mk_mod(x, y); + // x < 0, y < 0 -> r = -u + // x < 0, y >= 0 -> r = y - u + // x >= 0, y < 0 -> r = y + u + // x >= 0, y >= 0 -> r = u + // u = 0 -> r = 0 + // y = 0 -> r = x + r = a.mk_uminus(u); + r = m.mk_ite(m.mk_and(m.mk_not(signx), signy), a.mk_add(u, y), r); + r = m.mk_ite(m.mk_and(signx, m.mk_not(signy)), a.mk_sub(y, u), r); + r = m.mk_ite(m.mk_and(m.mk_not(signx), m.mk_not(signy)), u, r); + r = m.mk_ite(m.mk_eq(u, a.mk_int(0)), a.mk_int(0), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); + break; + } + case OP_BSDIV: { + // y = 0, x > 0 -> 1 + // y = 0, x <= 0 -> -1 + // y != 0 -> machine_div(x, y) +#if 0 + +#endif + } + case OP_ROTATE_LEFT: case OP_ROTATE_RIGHT: case OP_EXT_ROTATE_LEFT: @@ -709,9 +723,7 @@ namespace intblast { case OP_REPEAT: case OP_BREDOR: case OP_BREDAND: - case OP_BSDIV: case OP_BSREM: - case OP_BSMOD: verbose_stream() << mk_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); break; diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index ac3870635..de2fedb5a 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -69,10 +69,11 @@ namespace polysat { void undo() override { auto& [sc, lit, val] = c.m_constraint_index.back(); auto& vars = sc.vars(); + IF_VERBOSE(10, verbose_stream() << "undo add watch " << sc << " "; if (vars.size() > 0) verbose_stream() << "(" << c.m_constraint_index.size() -1 << ": " << c.m_watch[vars[0]] << ") "; if (vars.size() > 1) verbose_stream() << "(" << c.m_constraint_index.size() -1 << ": " << c.m_watch[vars[1]] << ") "; - verbose_stream() << "\n"; + verbose_stream() << "\n"); SASSERT(vars.size() <= 0 || c.m_watch[vars[0]].back() == c.m_constraint_index.size() - 1); SASSERT(vars.size() <= 1 || c.m_watch[vars[1]].back() == c.m_constraint_index.size() - 1); if (vars.size() > 0) @@ -141,10 +142,10 @@ namespace polysat { add_watch(idx, vars[0]); if (vars.size() > 1) add_watch(idx, vars[1]); - verbose_stream() << "add watch " << sc << " " << vars << " "; - if (vars.size() > 0) verbose_stream() << "( " << idx << " : " << m_watch[vars[0]] << ") "; - if (vars.size() > 1) verbose_stream() << "( " << idx << " : " << m_watch[vars[1]] << ") "; - verbose_stream() << "\n"; + IF_VERBOSE(10, verbose_stream() << "add watch " << sc << " " << vars << " "; + if (vars.size() > 0) verbose_stream() << "( " << idx << " : " << m_watch[vars[0]] << ") "; + if (vars.size() > 1) verbose_stream() << "( " << idx << " : " << m_watch[vars[1]] << ") "; + verbose_stream() << "\n"); s.trail().push(mk_add_watch(*this)); return idx; } diff --git a/src/sat/smt/polysat/op_constraint.cpp b/src/sat/smt/polysat/op_constraint.cpp index 234ddea04..c971fe1cd 100644 --- a/src/sat/smt/polysat/op_constraint.cpp +++ b/src/sat/smt/polysat/op_constraint.cpp @@ -42,6 +42,7 @@ namespace polysat { break; case code::inv_op: SASSERT(q.is_zero()); + break; default: break; } @@ -192,6 +193,9 @@ namespace polysat { case code::lshr_op: propagate_lshr(c, dep); break; + case code::ashr_op: + propagate_ashr(c, dep); + break; case code::shl_op: propagate_shl(c, dep); break; @@ -202,6 +206,7 @@ namespace polysat { propagate_inv(c, dep); break; default: + verbose_stream() << "not implemented yet: " << *this << "\n"; NOT_IMPLEMENTED_YET(); break; } @@ -311,6 +316,10 @@ namespace polysat { } } + void op_constraint::propagate_ashr(core& s, dependency const& dep) { + + } + /** * Enforce axioms for constraint: r == p << q diff --git a/src/sat/smt/polysat/op_constraint.h b/src/sat/smt/polysat/op_constraint.h index d7b6be392..1aec1c486 100644 --- a/src/sat/smt/polysat/op_constraint.h +++ b/src/sat/smt/polysat/op_constraint.h @@ -57,6 +57,7 @@ namespace polysat { static lbool eval_inv(pdd const& p, pdd const& r); void propagate_lshr(core& s, dependency const& dep); + void propagate_ashr(core& s, dependency const& dep); void propagate_shl(core& s, dependency const& dep); void propagate_and(core& s, dependency const& dep); void propagate_inv(core& s, dependency const& dep); From 5fdfd4f3f4e533ed4787dbb60f0f9260f389ce33 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 08:17:21 -0800 Subject: [PATCH 50/89] n/a Signed-off-by: Nikolaj Bjorner --- src/sat/smt/arith_axioms.cpp | 7 ++ src/sat/smt/euf_solver.cpp | 13 ++- src/sat/smt/intblast_solver.cpp | 128 +++++++++++++++++---------- src/sat/smt/intblast_solver.h | 4 +- src/sat/smt/polysat/constraints.h | 5 ++ src/sat/smt/polysat/core.cpp | 34 +++---- src/smt/params/smt_params_helper.pyg | 1 + src/smt/params/theory_bv_params.cpp | 2 + src/smt/params/theory_bv_params.h | 1 + 9 files changed, 128 insertions(+), 67 deletions(-) diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 046470000..09db74f75 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -208,6 +208,8 @@ namespace arith { bool solver::check_band_term(app* n) { unsigned sz; expr* x, * y; + if (!ctx.is_relevant(expr2enode(n))) + return true; VERIFY(a.is_band(n, sz, x, y)); if (use_nra_model()) { found_unsupported(n); @@ -217,6 +219,11 @@ namespace arith { theory_var vy = expr2enode(y)->get_th_var(get_id()); theory_var vn = expr2enode(n)->get_th_var(get_id()); rational N = rational::power_of_two(sz); + if (!get_value(vx).is_int() || !get_value(vy).is_int()) { + + s().display(verbose_stream()); + verbose_stream() << vx << " " << vy << " " << mk_pp(n, m) << "\n"; + } SASSERT(get_value(vx).is_int()); SASSERT(get_value(vy).is_int()); SASSERT(get_value(vn).is_int()); diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index b95e44d74..f1ccb6879 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -22,6 +22,7 @@ Author: #include "sat/smt/pb_solver.h" #include "sat/smt/bv_solver.h" #include "sat/smt/polysat_solver.h" +#include "sat/smt/intblast_solver.h" #include "sat/smt/euf_solver.h" #include "sat/smt/array_solver.h" #include "sat/smt/arith_solver.h" @@ -135,8 +136,16 @@ namespace euf { special_relations_util sp(m); if (pb.get_family_id() == fid) ext = alloc(pb::solver, *this, fid); - else if (bvu.get_family_id() == fid) - ext = alloc(polysat::solver, *this, fid); + else if (bvu.get_family_id() == fid) { + if (get_config().m_bv_solver == 0) + ext = alloc(bv::solver, *this, fid); + else if (get_config().m_bv_solver == 1) + ext = alloc(polysat::solver, *this, fid); + else if (get_config().m_bv_solver == 2) + ext = alloc(intblast::solver, *this); + else + throw default_exception("unknown bit-vector solver. Accepted values 0 (bit blast), 1 (polysat), 2 (int blast)"); + } else if (au.get_family_id() == fid) ext = alloc(array::solver, *this, fid); else if (fpa.get_family_id() == fid) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index e87944a91..32bf52f79 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -88,22 +88,25 @@ namespace intblast { expr* e = n->get_expr(); expr* x, * y; VERIFY(m.is_eq(n->get_expr(), x, y)); + SASSERT(bv.is_bv(x)); + ensure_translated(x); + ensure_translated(y); m_args.reset(); - m_args.push_back(translated(x)); - m_args.push_back(translated(y)); - add_equiv(expr2literal(e), eq_internalize(umod(x, 0), umod(x, 1))); + m_args.push_back(a.mk_sub(translated(x), translated(y))); + expr_ref lhs(umod(x, 0), m); + ctx.get_rewriter()(lhs); + add_equiv(expr2literal(e), eq_internalize(lhs, a.mk_int(0))); } void solver::internalize_bv(app* e) { - ensure_args(e); - m_args.reset(); - for (auto arg : *e) - m_args.push_back(translated(arg)); - translate_bv(e); + ensure_translated(e); // possibly wait until propagation? - if (m.is_bool(e)) - add_equiv(expr2literal(e), mk_literal(translated(e))); + if (m.is_bool(e)) { + expr_ref r(translated(e), m); + ctx.get_rewriter()(r); + add_equiv(expr2literal(e), mk_literal(r)); + } add_bound_axioms(); } @@ -120,32 +123,28 @@ namespace intblast { } } - void solver::ensure_args(app* e) { + void solver::ensure_translated(expr* e) { + if (m_translate.get(e->get_id(), nullptr)) + return; ptr_vector todo; ast_fast_mark1 visited; - for (auto arg : *e) { - if (!m_translate.get(arg->get_id(), nullptr)) - todo.push_back(arg); - } - if (todo.empty()) - return; + todo.push_back(e); + visited.mark(e); for (unsigned i = 0; i < todo.size(); ++i) { expr* e = todo[i]; - if (m.is_bool(e)) + if (!is_app(e)) continue; - else if (is_app(e)) { - for (auto arg : *to_app(e)) - if (!visited.is_marked(arg)) { - visited.mark(arg); - todo.push_back(arg); - } - } - else if (is_lambda(e)) - throw default_exception("lambdas are not supported in intblaster"); + app* a = to_app(e); + if (m.is_bool(e) && a->get_family_id() != bv.get_family_id()) + continue; + for (auto arg : *a) + if (!visited.is_marked(arg) && !m_translate.get(arg->get_id(), nullptr)) { + visited.mark(arg); + todo.push_back(arg); + } } - std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); - for (expr* e : todo) + for (expr* e : todo) translate_expr(e); } @@ -369,7 +368,7 @@ namespace intblast { } if (any_of(m_vars, [&](expr* v) { return translated(v) == x; })) return x; - return to_expr(a.mk_mod(x, a.mk_int(N))); + return a.mk_mod(x, a.mk_int(N)); } expr* solver::smod(expr* bv_expr, unsigned i) { @@ -393,6 +392,10 @@ namespace intblast { translate_var(to_var(e)); else { app* ap = to_app(e); + if (m_is_plugin && ap->get_family_id() == basic_family_id && m.is_bool(ap)) { + set_translated(e, e); + return; + } m_args.reset(); for (auto arg : *ap) m_args.push_back(translated(arg)); @@ -543,7 +546,6 @@ namespace intblast { r = a.mk_uminus(arg(0)); break; case OP_CONCAT: { - r = a.mk_int(0); unsigned sz = 0; for (unsigned i = 0; i < args.size(); ++i) { expr* old_arg = e->get_arg(i); @@ -595,7 +597,7 @@ namespace intblast { expr* x = arg(0), * y = arg(1); r = a.mk_int(0); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); break; } case OP_BNOT: @@ -605,7 +607,7 @@ namespace intblast { expr* x = arg(0), * y = arg(1); r = a.mk_int(0); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); break; } // Or use (p + q) - band(p, q)? @@ -649,7 +651,7 @@ namespace intblast { r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); - r = m.mk_ite(a.mk_eq(y, a.mk_int(i)), + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), m.mk_ite(signbit, a.mk_uminus(d), d), r); } @@ -686,6 +688,7 @@ namespace intblast { bv_expr = e->get_arg(0); r = m.mk_ite(m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(1), a.mk_int(0)); break; + case OP_BSMOD_I: case OP_BSMOD: { bv_expr = e; expr* x = umod(bv_expr, 0), *y = umod(bv_expr, 0); @@ -693,12 +696,12 @@ namespace intblast { expr* signx = a.mk_ge(x, a.mk_int(N/2)); expr* signy = a.mk_ge(y, a.mk_int(N/2)); expr* u = a.mk_mod(x, y); - // x < 0, y < 0 -> r = -u - // x < 0, y >= 0 -> r = y - u - // x >= 0, y < 0 -> r = y + u - // x >= 0, y >= 0 -> r = u - // u = 0 -> r = 0 - // y = 0 -> r = x + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u r = a.mk_uminus(u); r = m.mk_ite(m.mk_and(m.mk_not(signx), signy), a.mk_add(u, y), r); r = m.mk_ite(m.mk_and(signx, m.mk_not(signy)), a.mk_sub(y, u), r); @@ -707,15 +710,41 @@ namespace intblast { r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); break; } + case OP_BSDIV_I: case OP_BSDIV: { + // d = udiv(x mod N, y mod N) // y = 0, x > 0 -> 1 // y = 0, x <= 0 -> -1 - // y != 0 -> machine_div(x, y) -#if 0 - -#endif + // x = 0, y != 0 -> 0 + // x < 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + bv_expr = e; + expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 0); + rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); + expr* d = a.mk_idiv(x, y); + r = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), m.mk_ite(signx, a.mk_int(-1), a.mk_int(1)), r); + break; + } + case OP_BSREM_I: + case OP_BSREM: { + // y = 0 -> x + // else x - sdiv(x, y) * y + bv_expr = e; + expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 0); + rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); + expr* d = a.mk_idiv(x, y); + d = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); + r = a.mk_sub(x, a.mk_mul(d, y)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); + break; } - case OP_ROTATE_LEFT: case OP_ROTATE_RIGHT: case OP_EXT_ROTATE_LEFT: @@ -723,7 +752,7 @@ namespace intblast { case OP_REPEAT: case OP_BREDOR: case OP_BREDAND: - case OP_BSREM: + verbose_stream() << mk_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); break; @@ -739,11 +768,16 @@ namespace intblast { bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); if (has_bv_arg) { expr* bv_expr = e->get_arg(0); - set_translated(e, m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1))); + m_args[0] = a.mk_sub(arg(0), arg(1)); + set_translated(e, m.mk_eq(umod(bv_expr, 0), a.mk_int(0))); } else set_translated(e, m.mk_eq(arg(0), arg(1))); } + else if (m.is_ite(e)) + set_translated(e, m.mk_ite(arg(0), arg(1), arg(2))); + else if (m_is_plugin) + set_translated(e, e); else set_translated(e, m.mk_app(e->get_decl(), m_args)); } diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index ee85ed6e2..7dd37d5a7 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -57,7 +57,7 @@ namespace intblast { sat::literal_vector m_core; ptr_vector m_bv2int, m_int2bv; statistics m_stats; - bool m_is_plugin = true; // when the solver is used as a plugin, then do not translate below quantifiers. + bool m_is_plugin = true; // when the solver is used as a plugin, then do not translate below quantifiers. bool is_bv(sat::literal lit); void translate(expr_ref_vector& es); @@ -80,7 +80,7 @@ namespace intblast { void translate_quantifier(quantifier* q); void translate_var(var* v); - void ensure_args(app* e); + void ensure_translated(expr* e); void internalize_bv(app* e); unsigned m_vars_qhead = 0; diff --git a/src/sat/smt/polysat/constraints.h b/src/sat/smt/polysat/constraints.h index a9ec63165..47c9beb49 100644 --- a/src/sat/smt/polysat/constraints.h +++ b/src/sat/smt/polysat/constraints.h @@ -31,12 +31,15 @@ namespace polysat { class constraint { unsigned_vector m_vars; + unsigned m_num_watch = 0; public: virtual ~constraint() {} unsigned_vector& vars() { return m_vars; } unsigned_vector const& vars() const { return m_vars; } unsigned var(unsigned idx) const { return m_vars[idx]; } bool contains_var(pvar v) const { return m_vars.contains(v); } + unsigned num_watch() const { return m_num_watch; } + void set_num_watch(unsigned n) { SASSERT(n <= 2); m_num_watch = n; } virtual std::ostream& display(std::ostream& out, lbool status) const = 0; virtual std::ostream& display(std::ostream& out) const = 0; virtual lbool eval() const = 0; @@ -63,6 +66,8 @@ namespace polysat { unsigned_vector const& vars() const { return m_constraint->vars(); } unsigned var(unsigned idx) const { return m_constraint->var(idx); } bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + unsigned num_watch() const { return m_constraint->num_watch(); } + void set_num_watch(unsigned n) { m_constraint->set_num_watch(n); } void activate(core& c, dependency const& d) { m_constraint->activate(c, m_sign, d); } void propagate(core& c, lbool value, dependency const& d) { m_constraint->propagate(c, value, d); } bool is_always_true() const { return eval() == l_true; } diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index de2fedb5a..c0b56a3d8 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -74,11 +74,13 @@ namespace polysat { if (vars.size() > 0) verbose_stream() << "(" << c.m_constraint_index.size() -1 << ": " << c.m_watch[vars[0]] << ") "; if (vars.size() > 1) verbose_stream() << "(" << c.m_constraint_index.size() -1 << ": " << c.m_watch[vars[1]] << ") "; verbose_stream() << "\n"); - SASSERT(vars.size() <= 0 || c.m_watch[vars[0]].back() == c.m_constraint_index.size() - 1); - SASSERT(vars.size() <= 1 || c.m_watch[vars[1]].back() == c.m_constraint_index.size() - 1); - if (vars.size() > 0) + unsigned n = sc.num_watch(); + SASSERT(n <= vars.size()); + SASSERT(n <= 0 || c.m_watch[vars[0]].back() == c.m_constraint_index.size() - 1); + SASSERT(n <= 1 || c.m_watch[vars[1]].back() == c.m_constraint_index.size() - 1); + if (n > 0) c.m_watch[vars[0]].pop_back(); - if (vars.size() > 1) + if (n > 1) c.m_watch[vars[1]].pop_back(); c.m_constraint_index.pop_back(); } @@ -138,9 +140,10 @@ namespace polysat { for (; i < sz && j < 2; ++i) if (!is_assigned(vars[i])) std::swap(vars[i], vars[j++]); - if (vars.size() > 0) + sc.set_num_watch(i); + if (i > 0) add_watch(idx, vars[0]); - if (vars.size() > 1) + if (i > 1) add_watch(idx, vars[1]); IF_VERBOSE(10, verbose_stream() << "add watch " << sc << " " << vars << " "; if (vars.size() > 0) verbose_stream() << "( " << idx << " : " << m_watch[vars[0]] << ") "; @@ -225,19 +228,12 @@ namespace polysat { s.trail().push(mk_assign_var(v, *this)); return; - // to debug: - unsigned sz = m_watch[v].size(); - for (unsigned i = 0; i < sz; ++i) { - auto idx = m_watch[v][i]; - auto [sc, dep, value] = m_constraint_index[idx]; - sc.propagate(*this, value, dep); - } - // update the watch lists for pvars // remove constraints from m_watch[v] that have more than 2 free variables. // for entries where there is only one free variable left add to viable set - unsigned j = 0; - for (auto idx : m_watch[v]) { + unsigned j = 0, sz = m_watch[v].size(); + for (unsigned k = 0; k < sz; ++k) { + auto idx = m_watch[v][k]; auto [sc, dep, value] = m_constraint_index[idx]; auto& vars = sc.vars(); if (vars[0] != v) @@ -253,6 +249,11 @@ namespace polysat { break; } } + + // this can create fresh literals and update m_watch, but + // will not update m_watch[v] (other than copy constructor for m_watch) + // because v has been assigned a value. + sc.propagate(*this, value, dep); SASSERT(!swapped || vars.size() <= 1 || (!is_assigned(vars[0]) && !is_assigned(vars[1]))); if (swapped) @@ -267,6 +268,7 @@ namespace polysat { // detect unitary, add to viable, detect conflict? m_viable.add_unitary(v1, idx); } + SASSERT(m_watch[v].size() == sz && "size of watch list was not changed"); m_watch[v].shrink(j); verbose_stream() << "new watch " << v << ": " << m_watch[v] << "\n"; } diff --git a/src/smt/params/smt_params_helper.pyg b/src/smt/params/smt_params_helper.pyg index 300bef1fb..b882c1abf 100644 --- a/src/smt/params/smt_params_helper.pyg +++ b/src/smt/params/smt_params_helper.pyg @@ -54,6 +54,7 @@ def_module_params(module_name='smt', ('bv.watch_diseq', BOOL, False, 'use watch lists instead of eager axioms for bit-vectors'), ('bv.delay', BOOL, False, 'delay internalize expensive bit-vector operations'), ('bv.size_reduce', BOOL, False, 'pre-processing; turn assertions that set the upper bits of a bit-vector to constants into a substitution that replaces the bit-vector with constant bits. Useful for minimizing circuits as many input bits to circuits are constant'), + ('bv.solver', UINT, 1, 'bit-vector solver engine: 0 - bit-blasting, 1 - polysat, 2 - intblast, requires sat.smt=true'), ('arith.random_initial_value', BOOL, False, 'use random initial values in the simplex-based procedure for linear arithmetic'), ('arith.solver', UINT, 6, 'arithmetic solver: 0 - no solver, 1 - bellman-ford based solver (diff. logic only), 2 - simplex based solver, 3 - floyd-warshall based solver (diff. logic only) and no theory combination 4 - utvpi, 5 - infinitary lra, 6 - lra solver'), ('arith.nl', BOOL, True, '(incomplete) nonlinear arithmetic support based on Groebner basis and interval propagation, relevant only if smt.arith.solver=2'), diff --git a/src/smt/params/theory_bv_params.cpp b/src/smt/params/theory_bv_params.cpp index 734a983fb..8a3ddcf37 100644 --- a/src/smt/params/theory_bv_params.cpp +++ b/src/smt/params/theory_bv_params.cpp @@ -28,6 +28,7 @@ void theory_bv_params::updt_params(params_ref const & _p) { m_bv_enable_int2bv2int = p.bv_enable_int2bv(); m_bv_delay = p.bv_delay(); m_bv_size_reduce = p.bv_size_reduce(); + m_bv_solver = p.bv_solver(); } #define DISPLAY_PARAM(X) out << #X"=" << X << '\n'; @@ -42,4 +43,5 @@ void theory_bv_params::display(std::ostream & out) const { DISPLAY_PARAM(m_bv_enable_int2bv2int); DISPLAY_PARAM(m_bv_delay); DISPLAY_PARAM(m_bv_size_reduce); + DISPLAY_PARAM(m_bv_solver); } diff --git a/src/smt/params/theory_bv_params.h b/src/smt/params/theory_bv_params.h index 523459f09..97428c8ba 100644 --- a/src/smt/params/theory_bv_params.h +++ b/src/smt/params/theory_bv_params.h @@ -36,6 +36,7 @@ struct theory_bv_params { bool m_bv_watch_diseq = false; bool m_bv_delay = true; bool m_bv_size_reduce = false; + unsigned m_bv_solver = 0; theory_bv_params(params_ref const & p = params_ref()) { updt_params(p); } From 5dfe86fc2d73aaf0fd48048114cbb097972962df Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 14:13:16 -0800 Subject: [PATCH 51/89] bugfixes in intblast solver Signed-off-by: Nikolaj Bjorner --- src/math/lp/int_solver.cpp | 6 +- src/sat/smt/CMakeLists.txt | 1 + src/sat/smt/arith_axioms.cpp | 1 + src/sat/smt/arith_solver.cpp | 35 +++++-- src/sat/smt/arith_solver.h | 2 + src/sat/smt/dt_solver.cpp | 2 +- src/sat/smt/intblast_solver.cpp | 171 ++++++++++++++++++++++---------- src/sat/smt/intblast_solver.h | 13 ++- src/smt/theory_datatype.cpp | 2 +- src/util/trail.h | 6 +- 10 files changed, 163 insertions(+), 76 deletions(-) diff --git a/src/math/lp/int_solver.cpp b/src/math/lp/int_solver.cpp index c324af5b6..9cbc765d4 100644 --- a/src/math/lp/int_solver.cpp +++ b/src/math/lp/int_solver.cpp @@ -207,8 +207,10 @@ namespace lp { #endif m_cut_vars.reset(); - if (r == lia_move::undef) r = int_branch(*this)(); - if (settings().get_cancel_flag()) r = lia_move::undef; + if (settings().get_cancel_flag()) + return lia_move::undef; + if (r == lia_move::undef) + r = int_branch(*this)(); return r; } diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 2302a6c39..1ed9c05ca 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -5,6 +5,7 @@ z3_add_component(sat_smt arith_internalize.cpp arith_sls.cpp arith_solver.cpp + arith_value.cpp array_axioms.cpp array_diagnostics.cpp array_internalize.cpp diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 09db74f75..0150824b2 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -250,6 +250,7 @@ namespace arith { add_clause(~bitof(n, i), bitof(y, i)); else continue; + verbose_stream() << "added b-and clause\n"; return false; } return true; diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 306a6cce0..37aef2bf8 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -619,17 +619,20 @@ namespace arith { } } - void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + bool solver::get_value(euf::enode* n, expr_ref& value) { theory_var v = n->get_th_var(get_id()); expr* o = n->get_expr(); - expr_ref value(m); + if (m.is_value(n->get_root()->get_expr())) { value = n->get_root()->get_expr(); } else if (use_nra_model() && lp().external_to_local(v) != lp::null_lpvar) { anum const& an = nl_value(v, m_nla->tmp1()); + + + if (a.is_int(o) && !m_nla->am().is_int(an)) - value = a.mk_numeral(rational::zero(), a.is_int(o)); + value = a.mk_numeral(rational::zero(), a.is_int(o)); else value = a.mk_numeral(m_nla->am(), nl_value(v, m_nla->tmp1()), a.is_int(o)); } @@ -637,24 +640,35 @@ namespace arith { rational r = get_value(v); TRACE("arith", tout << mk_pp(o, m) << " v" << v << " := " << r << "\n";); SASSERT("integer variables should have integer values: " && (ctx.get_config().m_arith_ignore_int || !a.is_int(o) || r.is_int() || m_not_handled != nullptr || m.limit().is_canceled())); - if (a.is_int(o) && !r.is_int()) + if (a.is_int(o) && !r.is_int()) r = floor(r); value = a.mk_numeral(r, o->get_sort()); } + else + return false; + + return true; + } + + + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { + expr_ref value(m); + expr* o = n->get_expr(); + if (get_value(n, value)) + ; else if (a.is_arith_expr(o) && reflect(o)) { expr_ref_vector args(m); for (auto* arg : *to_app(o)) { if (m.is_value(arg)) args.push_back(arg); - else + else args.push_back(values.get(ctx.get_enode(arg)->get_root_id())); } value = m.mk_app(to_app(o)->get_decl(), args.size(), args.data()); ctx.get_rewriter()(value); } - else { - value = mdl.get_fresh_value(o->get_sort()); - } + else + value = mdl.get_fresh_value(n->get_sort()); mdl.register_value(value); values.set(n->get_root_id(), value); } @@ -1042,7 +1056,7 @@ namespace arith { if (!check_delayed_eqs()) return sat::check_result::CR_CONTINUE; - if (!check_band_terms()) + if (!int_undef && !check_band_terms()) return sat::check_result::CR_CONTINUE; if (ctx.get_config().m_arith_ignore_int && int_undef) @@ -1195,7 +1209,8 @@ namespace arith { lia_check = l_undef; break; case lp::lia_move::continue_with_check: - lia_check = l_undef; + TRACE("arith", tout << "continue-with-check\n"); + lia_check = l_false; break; default: UNREACHABLE(); diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 50cdc63ef..022dbeaea 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -526,6 +526,8 @@ namespace arith { bool add_eq(lpvar u, lpvar v, lp::explanation const& e, bool is_fixed); void consume(rational const& v, lp::constraint_index j); bool bound_is_interesting(unsigned vi, lp::lconstraint_kind kind, const rational& bval) const; + + bool get_value(euf::enode* n, expr_ref& val); }; diff --git a/src/sat/smt/dt_solver.cpp b/src/sat/smt/dt_solver.cpp index daecb7325..52c4ed953 100644 --- a/src/sat/smt/dt_solver.cpp +++ b/src/sat/smt/dt_solver.cpp @@ -400,7 +400,7 @@ namespace dt { return; } SASSERT(val == l_undef || (val == l_false && !d->m_constructor)); - ctx.push(set_vector_idx_trail(d->m_recognizers, c_idx)); + ctx.push(set_vector_idx_trail(d->m_recognizers, c_idx)); d->m_recognizers[c_idx] = recognizer; if (val == l_false) propagate_recognizer(v, recognizer); diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 32bf52f79..9960197fb 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -17,6 +17,7 @@ Author: #include "params/bv_rewriter_params.hpp" #include "sat/smt/intblast_solver.h" #include "sat/smt/euf_solver.h" +#include "sat/smt/arith_value.h" namespace intblast { @@ -29,7 +30,8 @@ namespace intblast { bv(m), a(m), m_translate(m), - m_args(m) + m_args(m), + m_pinned(m) {} euf::theory_var solver::mk_var(euf::enode* n) { @@ -89,40 +91,70 @@ namespace intblast { expr* x, * y; VERIFY(m.is_eq(n->get_expr(), x, y)); SASSERT(bv.is_bv(x)); - ensure_translated(x); - ensure_translated(y); - m_args.reset(); - m_args.push_back(a.mk_sub(translated(x), translated(y))); - expr_ref lhs(umod(x, 0), m); - ctx.get_rewriter()(lhs); - add_equiv(expr2literal(e), eq_internalize(lhs, a.mk_int(0))); + if (!is_translated(e)) { + ensure_translated(x); + ensure_translated(y); + m_args.reset(); + m_args.push_back(a.mk_sub(translated(x), translated(y))); + set_translated(e, m.mk_eq(umod(x, 0), a.mk_int(0))); + } + m_preds.push_back(e); + ctx.push(push_back_vector(m_preds)); + } + + void solver::set_translated(expr* e, expr* r) { + SASSERT(r); + SASSERT(!is_translated(e)); + m_translate.setx(e->get_id(), r); + ctx.push(set_vector_idx_trail(m_translate, e->get_id())); } void solver::internalize_bv(app* e) { ensure_translated(e); - - // possibly wait until propagation? if (m.is_bool(e)) { - expr_ref r(translated(e), m); - ctx.get_rewriter()(r); - add_equiv(expr2literal(e), mk_literal(r)); + m_preds.push_back(e); + ctx.push(push_back_vector(m_preds)); } - add_bound_axioms(); } - void solver::add_bound_axioms() { + bool solver::add_bound_axioms() { if (m_vars_qhead == m_vars.size()) - return; + return false; ctx.push(value_trail(m_vars_qhead)); for (; m_vars_qhead < m_vars.size(); ++m_vars_qhead) { auto v = m_vars[m_vars_qhead]; auto w = translated(v); auto sz = rational::power_of_two(bv.get_bv_size(v->get_sort())); - add_unit(ctx.mk_literal(a.mk_ge(w, a.mk_int(0)))); - add_unit(ctx.mk_literal(a.mk_le(w, a.mk_int(sz - 1)))); + auto lo = ctx.mk_literal(a.mk_ge(w, a.mk_int(0))); + auto hi = ctx.mk_literal(a.mk_le(w, a.mk_int(sz - 1))); + ctx.mark_relevant(lo); + ctx.mark_relevant(hi); + add_unit(lo); + add_unit(hi); } + return true; } + bool solver::add_predicate_axioms() { + if (m_preds_qhead == m_preds.size()) + return false; + ctx.push(value_trail(m_preds_qhead)); + for (; m_preds_qhead < m_preds.size(); ++m_preds_qhead) { + expr* e = m_preds[m_preds_qhead]; + expr_ref r(translated(e), m); + ctx.get_rewriter()(r); + auto a = expr2literal(e); + auto b = mk_literal(r); + ctx.mark_relevant(b); + add_equiv(a, b); + } + return true; + } + + bool solver::unit_propagate() { + return add_bound_axioms() || add_predicate_axioms(); + } + void solver::ensure_translated(expr* e) { if (m_translate.get(e->get_id(), nullptr)) return; @@ -200,7 +232,6 @@ namespace intblast { } m_core.reset(); - m_translate.reset(); m_is_plugin = false; m_solver = mk_smt2_solver(m, s.params(), symbol::null); @@ -256,6 +287,8 @@ namespace intblast { void solver::sorted_subterms(expr_ref_vector& es, ptr_vector& sorted) { expr_fast_mark1 visited; for (expr* e : es) { + if (is_translated(e)) + continue; sorted.push_back(e); visited.mark(e); } @@ -264,7 +297,7 @@ namespace intblast { if (is_app(e)) { app* a = to_app(e); for (expr* arg : *a) { - if (!visited.is_marked(arg)) { + if (!visited.is_marked(arg) && !is_translated(arg)) { visited.mark(arg); sorted.push_back(arg); } @@ -287,7 +320,7 @@ namespace intblast { expr* r = n->get_root()->get_expr(); es.push_back(m.mk_eq(e, r)); r = es.back(); - if (!visited.is_marked(r)) { + if (!visited.is_marked(r) && !is_translated(r)) { visited.mark(r); sorted.push_back(r); } @@ -295,7 +328,7 @@ namespace intblast { else if (is_quantifier(e)) { quantifier* q = to_quantifier(e); expr* b = q->get_expr(); - if (!visited.is_marked(b)) { + if (!visited.is_marked(b) && !is_translated(b)) { visited.mark(b); sorted.push_back(b); } @@ -333,7 +366,11 @@ namespace intblast { continue; if (sib->get_arg(0)->get_root() == r1) continue; - add_clause(~eq_internalize(n, sib), eq_internalize(sib->get_arg(0), n->get_arg(0)), nullptr); + auto a = eq_internalize(n, sib); + auto b = eq_internalize(sib->get_arg(0), n->get_arg(0)); + ctx.mark_relevant(a); + ctx.mark_relevant(b); + add_clause(~a, b, nullptr); return sat::check_result::CR_CONTINUE; } } @@ -350,7 +387,9 @@ namespace intblast { auto nBv2int = ctx.get_enode(bv2int); auto nxModN = ctx.get_enode(xModN); if (nBv2int->get_root() != nxModN->get_root()) { - add_unit(eq_internalize(nBv2int, nxModN)); + auto a = eq_internalize(nBv2int, nxModN); + ctx.mark_relevant(a); + add_unit(a); return sat::check_result::CR_CONTINUE; } } @@ -366,7 +405,7 @@ namespace intblast { return x; return a.mk_int(mod(r, N)); } - if (any_of(m_vars, [&](expr* v) { return translated(v) == x; })) + if (any_of(m_vars, [&](expr* v) { return translated(v) == x && bv.get_bv_size(v) == bv.get_bv_size(bv_expr); })) return x; return a.mk_mod(x, a.mk_int(N)); } @@ -481,6 +520,7 @@ namespace intblast { m_new_funs.insert(f, g); } f = g; + m_pinned.push_back(f); } set_translated(e, m.mk_app(f, m_args)); } @@ -578,14 +618,14 @@ namespace intblast { } case OP_BUREM: case OP_BUREM_I: { - expr* x = arg(0), * y = arg(1); + expr* x = arg(0), * y = umod(e, 1); r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, a.mk_mod(x, y)); break; } case OP_BUDIV: case OP_BUDIV_I: { - expr* x = arg(0), * y = arg(1); - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(1), a.mk_idiv(x, umod(bv_expr, 1))); + expr* x = arg(0), * y = umod(e, 1); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(1), a.mk_idiv(x, y)); break; } case OP_BUMUL_NO_OVFL: { @@ -594,24 +634,24 @@ namespace intblast { break; } case OP_BSHL: { - expr* x = arg(0), * y = arg(1); + expr* x = arg(0), * y = umod(e, 1); r = a.mk_int(0); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); break; } case OP_BNOT: r = bnot(arg(0)); break; case OP_BLSHR: { - expr* x = arg(0), * y = arg(1); + expr* x = arg(0), * y = umod(e, 1); r = a.mk_int(0); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); break; - } - // Or use (p + q) - band(p, q)? + } case OP_BOR: { + // p | q := (p + q) - band(p, q) r = arg(0); for (unsigned i = 1; i < args.size(); ++i) r = a.mk_sub(a.mk_add(r, arg(i)), a.mk_band(bv.get_bv_size(e), r, arg(i))); @@ -623,11 +663,9 @@ namespace intblast { case OP_BAND: r = band(args); break; - // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; - // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 - // (p + q) - 2*band(p, q); case OP_BXNOR: case OP_BXOR: { + // p ^ q := (p + q) - 2*band(p, q); unsigned sz = bv.get_bv_size(e); r = arg(0); for (unsigned i = 1; i < args.size(); ++i) { @@ -691,7 +729,7 @@ namespace intblast { case OP_BSMOD_I: case OP_BSMOD: { bv_expr = e; - expr* x = umod(bv_expr, 0), *y = umod(bv_expr, 0); + expr* x = umod(bv_expr, 0), *y = umod(bv_expr, 1); rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); expr* signx = a.mk_ge(x, a.mk_int(N/2)); expr* signy = a.mk_ge(y, a.mk_int(N/2)); @@ -721,7 +759,7 @@ namespace intblast { // x > 0, y > 0 -> d // x < 0, y < 0 -> d bv_expr = e; - expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 0); + expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 1); rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); expr* signx = a.mk_ge(x, a.mk_int(N / 2)); expr* signy = a.mk_ge(y, a.mk_int(N / 2)); @@ -735,7 +773,7 @@ namespace intblast { // y = 0 -> x // else x - sdiv(x, y) * y bv_expr = e; - expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 0); + expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 1); rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); expr* signx = a.mk_ge(x, a.mk_int(N / 2)); expr* signy = a.mk_ge(y, a.mk_int(N / 2)); @@ -751,8 +789,7 @@ namespace intblast { case OP_EXT_ROTATE_RIGHT: case OP_REPEAT: case OP_BREDOR: - case OP_BREDAND: - + case OP_BREDAND: verbose_stream() << mk_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); break; @@ -804,26 +841,46 @@ namespace intblast { } bool solver::add_dep(euf::enode* n, top_sort& dep) { - // bv2int - auto e = ctx.get_enode(translated(n->get_expr())); - if (!e) + if (!is_app(n->get_expr())) return false; - dep.add(n, e); + app* e = to_app(n->get_expr()); + if (n->num_args() == 0) { + dep.insert(n, nullptr); + return true; + } + if (e->get_family_id() != bv.get_family_id()) + return false; + for (euf::enode* arg : euf::enode_args(n)) + dep.add(n, arg->get_root()); return true; } // TODO: handle dependencies properly by using arithmetical model to retrieve values of translated // bit-vectors directly. - void solver::add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values) { - SASSERT(bv.is_bv(n->get_expr())); - rational N = rational::power_of_two(bv.get_bv_size(n->get_expr())); - auto e = ctx.get_enode(translated(n->get_expr())); + void solver::add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values) { + expr* e = n->get_expr(); + SASSERT(bv.is_bv(e)); + + if (bv.is_numeral(e)) { + values.setx(n->get_root_id(), e); + return; + } + + rational r, N = rational::power_of_two(bv.get_bv_size(e)); + expr* te = translated(e); + model_ref mdlr; + m_solver->get_model(mdlr); expr_ref value(m); - value = values.get(e->get_root_id()); - values.setx(n->get_root_id(), value); + if (mdlr->eval_expr(te, value, true) && a.is_numeral(value, r)) { + values.setx(n->get_root_id(), bv.mk_numeral(mod(r, N), bv.get_bv_size(e))); + return; + } + ctx.s().display(verbose_stream()); + verbose_stream() << "failed to evaluate " << mk_pp(te, m) << " " << value << "\n"; + UNREACHABLE(); } - void solver::add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values) { + void solver::add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values) { expr_ref value(m); if (n->interpreted()) value = n->get_expr(); @@ -833,10 +890,16 @@ namespace intblast { for (auto arg : euf::enode_args(n)) args.push_back(values.get(arg->get_root_id())); rw.mk_app(n->get_decl(), args.size(), args.data(), value); - VERIFY(value); } else { - rational r = get_value(n->get_expr()); + expr_ref bv2int(bv.mk_bv2int(n->get_expr()), m); + euf::enode* b2i = ctx.get_enode(bv2int); + if (!b2i) verbose_stream() << bv2int << "\n"; + SASSERT(b2i); + VERIFY(b2i); + arith::arith_value av(ctx); + rational r; + VERIFY(av.get_value(b2i->get_expr(), r)); verbose_stream() << ctx.bpp(n) << " := " << r << "\n"; value = bv.mk_numeral(r, bv.get_bv_size(n->get_expr())); } diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 7dd37d5a7..493b1f3c5 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -54,6 +54,7 @@ namespace intblast { scoped_ptr<::solver> m_solver; obj_map m_new_funs; expr_ref_vector m_translate, m_args; + ast_ref_vector m_pinned; sat::literal_vector m_core; ptr_vector m_bv2int, m_int2bv; statistics m_stats; @@ -65,8 +66,9 @@ namespace intblast { rational get_value(expr* e) const; + bool is_translated(expr* e) const { return !!m_translate.get(e->get_id(), nullptr); } expr* translated(expr* e) const { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } - void set_translated(expr* e, expr* r) { SASSERT(r); m_translate.setx(e->get_id(), r); } + void set_translated(expr* e, expr* r); expr* arg(unsigned i) { return m_args.get(i); } expr* umod(expr* bv_expr, unsigned i); @@ -83,9 +85,10 @@ namespace intblast { void ensure_translated(expr* e); void internalize_bv(app* e); - unsigned m_vars_qhead = 0; - ptr_vector m_vars; - void add_bound_axioms(); + unsigned m_vars_qhead = 0, m_preds_qhead = 0; + ptr_vector m_vars, m_preds; + bool add_bound_axioms(); + bool add_predicate_axioms(); euf::theory_var mk_var(euf::enode* n) override; @@ -109,7 +112,7 @@ namespace intblast { void collect_statistics(statistics& st) const override; - bool unit_propagate() override { return false; } + bool unit_propagate() override; void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override {} diff --git a/src/smt/theory_datatype.cpp b/src/smt/theory_datatype.cpp index cfc1f06f2..b794a44b5 100644 --- a/src/smt/theory_datatype.cpp +++ b/src/smt/theory_datatype.cpp @@ -915,7 +915,7 @@ namespace smt { } SASSERT(val == l_undef || (val == l_false && d->m_constructor == nullptr)); d->m_recognizers[c_idx] = recognizer; - m_trail_stack.push(set_vector_idx_trail(d->m_recognizers, c_idx)); + m_trail_stack.push(set_vector_idx_trail(d->m_recognizers, c_idx)); if (val == l_false) { propagate_recognizer(v, recognizer); } diff --git a/src/util/trail.h b/src/util/trail.h index 1aa7e4441..43e698234 100644 --- a/src/util/trail.h +++ b/src/util/trail.h @@ -219,12 +219,12 @@ public: } }; -template +template class set_vector_idx_trail : public trail { - ptr_vector & m_vector; + V & m_vector; unsigned m_idx; public: - set_vector_idx_trail(ptr_vector & v, unsigned idx): + set_vector_idx_trail(V & v, unsigned idx): m_vector(v), m_idx(idx) { } From 03730b2aad9f0bbf976a2b4d5321a5c61ba3bfd8 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 14:16:35 -0800 Subject: [PATCH 52/89] new files --- src/sat/smt/arith_value.cpp | 145 ++ src/sat/smt/arith_value.h | 52 + src/sat/smt/polysat/saturation.cpp | 2190 ++++++++++++++++++++++++++++ src/sat/smt/polysat/saturation.h | 241 +++ 4 files changed, 2628 insertions(+) create mode 100644 src/sat/smt/arith_value.cpp create mode 100644 src/sat/smt/arith_value.h create mode 100644 src/sat/smt/polysat/saturation.cpp create mode 100644 src/sat/smt/polysat/saturation.h diff --git a/src/sat/smt/arith_value.cpp b/src/sat/smt/arith_value.cpp new file mode 100644 index 000000000..bb301808e --- /dev/null +++ b/src/sat/smt/arith_value.cpp @@ -0,0 +1,145 @@ +/*++ +Copyright (c) 2018 Microsoft Corporation + +Module Name: + + smt_arith_value.cpp + +Abstract: + + Utility to extract arithmetic values from context. + +Author: + + Nikolaj Bjorner (nbjorner) 2018-12-08. + +Revision History: + +--*/ + +#include "ast/ast_pp.h" +#include "sat/smt/arith_value.h" +#include "sat/smt/euf_solver.h" +#include "sat/smt/arith_solver.h" + +namespace arith { + + arith_value::arith_value(euf::solver& s) : + s(s), m(s.get_manager()), a(m) {} + + void arith_value::init() { + if (!as) + as = dynamic_cast(s.fid2solver(a.get_family_id())); + } + + bool arith_value::get_value(expr* e, rational& val) { + auto n = s.get_enode(e); + expr_ref _val(m); + init(); + return n && as->get_value(n, _val) && a.is_numeral(_val, val); + } + +#if 0 + bool arith_value::get_lo_equiv(expr* e, rational& lo, bool& is_strict) { + if (!s.get_enode(e)) + return false; + init(); + is_strict = false; + bool found = false; + bool is_strict1; + rational lo1; + for (auto sib : euf::enode_class(s.get_enode(e))) { + if (!as->get_lower(sib, lo1, is_strict1)) + continue; + if (!found || lo1 > lo || lo == lo1 && is_strict1) + lo = lo1, is_strict = is_strict1; + found = true; + } + CTRACE("arith_value", !found, tout << "value not found for " << mk_pp(e, m) << "\n";); + return found; + } + + bool arith_value::get_up_equiv(expr* e, rational& hi, bool& is_strict) { + if (!s.get_enode(e)) + return false; + init(); + is_strict = false; + bool found = false; + bool is_strict1; + rational hi1; + for (auto sib : euf::enode_class(s.get_enode(e))) { + if (!as->get_upper(sib, hi1, is_strict1)) + continue; + if (!found || hi1 < hi || hi == hi1 && is_strict1) + hi = hi1, is_strict = is_strict1; + found = true; + } + CTRACE("arith_value", !found, tout << "value not found for " << mk_pp(e, m) << "\n";); + return found; + } + + bool arith_value::get_up(expr* e, rational& up, bool& is_strict) const { + init(); + enode* n = s.get_enode(e); + is_strict = false; + return n && as->get_upper(n, up, is_strict); + } + + bool arith_value::get_lo(expr* e, rational& lo, bool& is_strict) const { + init(); + enode* n = s.get_enode(e); + is_strict = false; + return n && as->get_lower(n, lo, is_strict); + } + +#endif + + +#if 0 + + + bool arith_value::get_value_equiv(expr* e, rational& val) const { + if (!m_ctx->e_internalized(e)) return false; + expr_ref _val(m); + enode* next = m_ctx->get_enode(e), * n = next; + do { + e = next->get_expr(); + if (m_tha && m_tha->get_value(next, _val) && a.is_numeral(_val, val)) return true; + if (m_thi && m_thi->get_value(next, _val) && a.is_numeral(_val, val)) return true; + if (m_thr && m_thr->get_value(next, val)) return true; + next = next->get_next(); + } while (next != n); + TRACE("arith_value", tout << "value not found for " << mk_pp(e, m_ctx->get_manager()) << "\n";); + return false; + } + + expr_ref arith_value::get_lo(expr* e) const { + rational lo; + bool s = false; + if ((a.is_int_real(e) || b.is_bv(e)) && get_lo(e, lo, s) && !s) { + return expr_ref(a.mk_numeral(lo, e->get_sort()), m); + } + return expr_ref(e, m); + } + + expr_ref arith_value::get_up(expr* e) const { + rational up; + bool s = false; + if ((a.is_int_real(e) || b.is_bv(e)) && get_up(e, up, s) && !s) { + return expr_ref(a.mk_numeral(up, e->get_sort()), m); + } + return expr_ref(e, m); + } + + expr_ref arith_value::get_fixed(expr* e) const { + rational lo, up; + bool s = false; + if (a.is_int_real(e) && get_lo(e, lo, s) && !s && get_up(e, up, s) && !s && lo == up) { + return expr_ref(a.mk_numeral(lo, e->get_sort()), m); + } + return expr_ref(e, m); + } + +#endif + +}; diff --git a/src/sat/smt/arith_value.h b/src/sat/smt/arith_value.h new file mode 100644 index 000000000..b858ff896 --- /dev/null +++ b/src/sat/smt/arith_value.h @@ -0,0 +1,52 @@ + +/*++ +Copyright (c) 2018 Microsoft Corporation + +Module Name: + + arith_value.h + +Abstract: + + Utility to extract arithmetic values from context. + +Author: + + Nikolaj Bjorner (nbjorner) 2018-12-08. + +Revision History: + +--*/ +#pragma once + +#include "ast/arith_decl_plugin.h" + +namespace euf { + class solver; +} +namespace arith { + + class solver; + + class arith_value { + euf::solver& s; + ast_manager& m; + arith_util a; + solver* as = nullptr; + void init(); + public: + arith_value(euf::solver& s); + bool get_value(expr* e, rational& value); + +#if 0 + bool get_lo_equiv(expr* e, rational& lo, bool& strict); + bool get_up_equiv(expr* e, rational& up, bool& strict); + bool get_lo(expr* e, rational& lo, bool& strict); + bool get_up(expr* e, rational& up, bool& strict); + bool get_value_equiv(expr* e, rational& value); + expr_ref get_lo(expr* e); + expr_ref get_up(expr* e); + expr_ref get_fixed(expr* e); +#endif + }; +}; diff --git a/src/sat/smt/polysat/saturation.cpp b/src/sat/smt/polysat/saturation.cpp new file mode 100644 index 000000000..81fd6f221 --- /dev/null +++ b/src/sat/smt/polysat/saturation.cpp @@ -0,0 +1,2190 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Polysat core saturation + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-6 + + +TODO: preserve falsification +- each rule selects a certain premises that are problematic. + If the problematic premise is false under the current assignment, the newly inferred + literal should also be false in the assignment in order to preserve conflicts. + + +TODO: when we check that 'x' is "unary": +- in principle, 'x' could be any polynomial. However, we need to divide the lhs by x, and we don't have general polynomial division yet. + so for now we just allow the form 'value*variable'. + (extension to arbitrary monomials for 'x' should be fairly easy too) + +--*/ +#include "sat/smt/polysat/core.h" +#include "sat/smt/polysat/saturation.h" +#include "sat/smt/polysat/umul_ovfl_constraint.h" +#include "sat/smt/polysat/ule_constraint.h" +#include "util/log.h" + + +namespace polysat { + + saturation::saturation(core& c) : c(c), C(c.cs()) {} + + void saturation::perform(pvar v) { + for (signed_constraint c : core) + perform(v, sc, core); + } + + bool saturation::perform(pvar v, signed_constraint sc) { + if (sc.is_currently_true(c)) + return false; + + if (sc.is_ule()) { + auto i = inequality::from_ule(sc); + return try_inequality(v, i); + } + if (sc.is_umul_ovfl()) + return try_umul_ovfl(v, sc); + + if (sc.is_op()) + return try_op(v, sc); + + return false; + } + + bool saturation::try_inequality(pvar v, inequality const& i, conflict& core) { + bool prop = false; + if (s.size(v) != i.lhs().power_of_2()) + return false; + if (try_nonzero_upper_extract(v, core, i)) + prop = true; + if (try_congruence(v, core, i)) + prop = true; + if (try_mul_bounds(v, core, i)) + prop = true; + if (try_parity(v, core, i)) + prop = true; + if (try_parity_diseq(v, core, i)) + prop = true; + if (try_transitivity(v, core, i)) + prop = true; + if (try_factor_equality(v, core, i)) + prop = true; + if (try_infer_equality(v, core, i)) + prop = true; + if (try_add_overflow_bound(v, core, i)) + prop = true; + if (try_add_mul_bound(v, core, i)) + prop = true; + if (try_infer_parity_equality(v, core, i)) + prop = true; + if (try_mul_eq_bound(v, core, i)) + prop = true; + if (try_ugt_x(v, core, i)) + prop = true; + if (try_ugt_y(v, core, i)) + prop = true; + if (try_ugt_z(v, core, i)) + prop = true; + if (try_y_l_ax_and_x_l_z(v, core, i)) + prop = true; + if (false && try_tangent(v, core, i)) + prop = true; + return prop; + } + + bool saturation::try_congruence(pvar x, conflict& core, inequality const& i) { + set_rule("egraph(x == y) & C(x,y) ==> C(y,y)"); + // TODO: generalize to other constraint types? + // if (!i.is_strict()) + // return false; + // if (!core.vars().contains(x)) + // return false; + if (!i.as_signed_constraint().contains_var(x)) + return false; + for (pvar y : s.m_slicing.equivalent_vars(x)) { + if (x == y) + continue; + if (!s.is_assigned(y)) + continue; + if (!core.vars().contains(y)) + continue; + if (!i.as_signed_constraint().contains_var(y)) + continue; + SASSERT(s.m_search.get_pvar_index(y) < s.m_search.get_pvar_index(x)); // y was the earlier one since we are currently resolving x + pdd const lhs = i.lhs().subst_pdd(x, s.var(y)); + pdd const rhs = i.rhs().subst_pdd(x, s.var(y)); + signed_constraint c = ineq(true, lhs, rhs); + m_lemma.reset(); + s.m_slicing.explain_equal(x, y, [&](sat::literal lit) { m_lemma.insert(~lit); }); + if (propagate(x, core, i, c)) + return true; + } + return false; + } + + bool saturation::try_nonzero_upper_extract(pvar y, conflict& core, inequality const& i) { + set_rule("y = x[h:l] & y != 0 ==> x >= 2^l"); + if (!s.m_justification[y].is_propagation_by_slicing()) + return false; + if (!s.get_value(y).is_zero()) + return false; + if (!is_nonzero_by(y, i)) + return false; + for (pvar x : core.vars()) { + if (!s.get_value(x).is_zero()) + continue; + unsigned hi, lo; + if (!s.m_slicing.is_extract(y, x, hi, lo)) // TODO: generalize; use is_equal to check this and if yes, extract the explanation. otherwise it will only work in very limited cases. + continue; + if (propagate(y, core, i, s.ule(rational::power_of_two(lo), s.var(x)))) + return true; + } + return false; + } + + // TODO: can be generalized + bool saturation::is_nonzero_by(pvar x, inequality const& i) { + if (i.is_strict() && i.lhs().is_zero()) { + // 0 < p + pdd const& p = i.rhs(); + if (p.is_val()) + return false; + if (!p.lo().is_zero()) + return false; + if (!p.hi().is_val()) + return false; + SASSERT(!p.hi().is_zero()); + // 0 < a*x for a != 0 + return true; + } + return false; + } + + bool saturation::try_umul_ovfl(pvar v, signed_constraint c, conflict& core) { + SASSERT(c->is_umul_ovfl()); + bool prop = false; + if (try_umul_ovfl_noovfl(v, c, core)) + prop = true; +#if 0 + if (c.is_positive()) { + prop = try_umul_ovfl_bounds(v, c, core); + } + else { + prop = try_umul_noovfl_bounds(v, c, core); + if (false && try_umul_noovfl_lo(v, c, core)) + prop = true; + } +#endif + return prop; + } + + // Ovfl(x, y) & ~Ovfl(y, z) ==> x > z + // TODO: Ovfl(x, y1) & ~Ovfl(y2, z) & y1 <= y2 ==> x > z + bool saturation::try_umul_ovfl_noovfl(pvar v, signed_constraint c1, conflict& core) { + set_rule("[y] ovfl(x, y) & ~ovfl(y, z) ==> x > z"); + SASSERT(c1->is_umul_ovfl()); + if (!c1.is_positive()) // since we search for both premises in the core, break symmetry by requiring c1 to be positive + return false; + pdd p = c1->to_umul_ovfl().p(); + pdd q = c1->to_umul_ovfl().q(); + for (auto c2 : core) { + if (!c2.is_negative()) + continue; + if (!c2->is_umul_ovfl()) + continue; + pdd r = c2->to_umul_ovfl().p(); + pdd u = c2->to_umul_ovfl().q(); + if (p == u || q == u) + swap(r, u); + if (q == r || q == u) + swap(p, q); + if (p != r) + continue; + m_lemma.reset(); + m_lemma.insert(~c1); + m_lemma.insert(~c2); + if (propagate(v, core, s.ult(u, q))) + return true; + } + return false; + } + + bool saturation::try_umul_noovfl_lo(pvar v, signed_constraint c, conflict& core) { + set_rule("[x] ~ovfl(x, y) => y = 0 or x <= x * y"); + SASSERT(c->is_umul_ovfl()); + if (!c.is_negative()) + return false; + auto const& ovfl = c->to_umul_ovfl(); + auto V = s.var(v); + auto p = ovfl.p(), q = ovfl.q(); + // TODO could relax condition to be that V occurs in p + if (q == V) + std::swap(p, q); + signed_constraint q_eq_0; + if (p == V && is_forced_diseq(q, 0, q_eq_0)) { + // ~ovfl(V,q) => q = 0 or V <= V*q + m_lemma.reset(); + m_lemma.insert_eval(q_eq_0); + if (propagate(v, core, c, s.ule(p, p * q))) + return true; + } + return false; + } + + /** + * ~ovfl(p, q) & p >= k => q < 2^N/k + * TODO: subsumed by narrowing inferences? + */ + bool saturation::try_umul_noovfl_bounds(pvar x, signed_constraint c, conflict& core) { + set_rule("[x] ~ovfl(x, q) & x >= k => q <= (2^N-1)/k"); + SASSERT(c->is_umul_ovfl()); + SASSERT(c.is_negative()); + auto const& ovfl = c->to_umul_ovfl(); + auto p = ovfl.p(), q = ovfl.q(); + auto X = s.var(x); + auto& m = p.manager(); + rational p_val, q_val; + if (q == X) + std::swap(p, q); + if (p == X) { + vector x_ge_bound; + if (!s.try_eval(q, q_val)) + return false; + if (!has_lower_bound(x, core, p_val, x_ge_bound)) + return false; + if (p_val * q_val <= m.max_value()) + return false; + m_lemma.reset(); + m_lemma.insert_eval(~s.uge(X, p_val)); + signed_constraint conseq = s.ule(q, floor(m.max_value()/p_val)); + return propagate(x, core, c, conseq); + } + if (!s.try_eval(p, p_val) || !s.try_eval(q, q_val)) + return false; + if (p_val * q_val <= m.max_value()) + return false; + m_lemma.reset(); + m_lemma.insert_eval(~s.uge(q, q_val)); + signed_constraint conseq = s.ule(p, floor(m.max_value()/q_val)); + return propagate(x, core, c, conseq); + } + + /** + * ovfl(p, q) & p <= k => q > 2^N/k + */ + bool saturation::try_umul_ovfl_bounds(pvar v, signed_constraint c, conflict& core) { + SASSERT(c->is_umul_ovfl()); + SASSERT(c.is_positive()); + auto const& ovfl = c->to_umul_ovfl(); + auto p = ovfl.p(), q = ovfl.q(); + rational p_val, q_val; + return false; + } + + bool saturation::try_op(pvar v, signed_constraint c, conflict& core) { + set_rule("try_op"); + SASSERT(c->is_op()); + SASSERT(c.is_positive()); + clause_ref correction = c.produce_lemma(s, s.get_assignment()); + if (correction) { + IF_LOGGING( + LOG("correcting op_constraint: " << *correction); + for (sat::literal lit : *correction) { + LOG("\t" << lit_pp(s, lit)); + } + ); + + for (sat::literal lit : *correction) + if (!s.m_bvars.is_assigned(lit) && s.lit2cnstr(lit).is_currently_false(s)) + s.assign_eval(~lit); + core.add_lemma(correction); + log_lemma(v, core); + return true; + } + return false; + } + + signed_constraint saturation::ineq(bool is_strict, pdd const& lhs, pdd const& rhs) { + if (is_strict) + return s.ult(lhs, rhs); + else + return s.ule(lhs, rhs); + } + + bool saturation::propagate(pvar v, conflict& core, inequality const& crit, signed_constraint c) { + return propagate(v, core, crit.as_signed_constraint(), c); + } + + bool saturation::propagate(pvar v, conflict& core, signed_constraint crit, signed_constraint c) { + m_lemma.insert(~crit); + return propagate(v, core, c); + } + + bool saturation::propagate(pvar v, conflict& core, signed_constraint c) { + if (is_forced_true(c)) + return false; + + SASSERT(all_of(m_lemma, [this](sat::literal lit) { return is_forced_false(s.lit2cnstr(lit)); })); + + m_lemma.insert(c); + m_lemma.set_name(m_rule); + core.add_lemma(m_lemma.build()); + log_lemma(v, core); + return true; + } + + bool saturation::add_conflict(pvar v, conflict& core, inequality const& crit1, signed_constraint c) { + return add_conflict(v, core, crit1, crit1, c); + } + + bool saturation::add_conflict(pvar v, conflict& core, inequality const& _crit1, inequality const& _crit2, signed_constraint const c) { + auto crit1 = _crit1.as_signed_constraint(); + auto crit2 = _crit2.as_signed_constraint(); + m_lemma.insert(~crit1); + if (crit1 != crit2) + m_lemma.insert(~crit2); + + LOG("critical " << m_rule << " " << crit1); + LOG("consequent " << c << " value: " << c.bvalue(s) << " is-false: " << c.is_currently_false(s)); + + SASSERT(all_of(m_lemma, [this](sat::literal lit) { return s.m_bvars.value(lit) == l_false; })); + + // Ensure lemma is a conflict lemma + if (!is_forced_false(c)) + return false; + + // Constraint c is already on the search stack, so the lemma will not derive anything new. + if (c.bvalue(s) == l_true) + return false; + + m_lemma.insert_eval(c); + m_lemma.set_name(m_rule); + core.add_lemma(m_lemma.build()); + log_lemma(v, core); + return true; + } + + bool saturation::is_non_overflow(pdd const& x, pdd const& y, signed_constraint& c) { + + if (is_non_overflow(x, y)) { + c = ~s.umul_ovfl(x, y); + return true; + } + + // TODO: do we really search the stack or can we just create the literal s.umul_ovfl(x, y) + // and check if it is assigned, or not even create the literal but look up whether it is assigned? + // constraint_manager uses m_dedup, alloc + // but to probe whether a literal occurs these are not needed. + // m_dedup.constraints.contains(&c); + + for (auto si : s.m_search) { + if (!si.is_boolean()) + continue; + if (si.is_resolved()) + continue; + auto d = s.lit2cnstr(si.lit()); + if (!d->is_umul_ovfl() || !d.is_negative()) + continue; + auto const& ovfl = d->to_umul_ovfl(); + if (x != ovfl.p() && x != ovfl.q()) + continue; + if (y != ovfl.p() && y != ovfl.q()) + continue; + c = d; + return true; + } + return false; + } + + /* + * Match [v] .. <= v + */ + bool saturation::is_l_v(pvar v, inequality const& i) { + return i.rhs() == s.var(v); + } + + /* + * Match [v] v <= ... + */ + bool saturation::is_g_v(pvar v, inequality const& i) { + return i.lhs() == s.var(v); + } + + /* + * Match [x] x <= y + */ + bool saturation::is_x_l_Y(pvar x, inequality const& i, pdd& y) { + y.reset(i.rhs().manager()); + y = i.rhs(); + return is_g_v(x, i); + } + + /* + * Match [x] y <= x + */ + bool saturation::is_Y_l_x(pvar x, inequality const& i, pdd& y) { + y.reset(i.lhs().manager()); + y = i.lhs(); + return is_l_v(x, i); + } + + /* + * Match [x] y <= a*x + */ + bool saturation::is_Y_l_Ax(pvar x, inequality const& i, pdd& a, pdd& y) { + y.reset(i.lhs().manager()); + y = i.lhs(); + return is_xY(x, i.rhs(), a); + } + + bool saturation::verify_Y_l_Ax(pvar x, inequality const& i, pdd const& a, pdd const& y) { + return i.lhs() == y && i.rhs() == a * s.var(x); + } + + /** + * Match [x] a*x <= y + */ + bool saturation::is_Ax_l_Y(pvar x, inequality const& i, pdd& a, pdd& y) { + y.reset(i.rhs().manager()); + y = i.rhs(); + return is_xY(x, i.lhs(), a); + } + + bool saturation::verify_Ax_l_Y(pvar x, inequality const& i, pdd const& a, pdd const& y) { + return i.rhs() == y && i.lhs() == a * s.var(x); + } + + /** + * Match [x] a*x + b <= y + */ + bool saturation::is_AxB_l_Y(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { + y.reset(i.rhs().manager()); + y = i.rhs(); + return i.lhs().degree(x) == 1 && (i.lhs().factor(x, 1, a, b), true); + } + + bool saturation::verify_AxB_l_Y(pvar x, inequality const& i, pdd const& a, pdd const& b, pdd const& y) { + return i.rhs() == y && i.lhs() == a * s.var(x) + b; + } + + + bool saturation::is_Y_l_AxB(pvar x, inequality const& i, pdd& y, pdd& a, pdd& b) { + y.reset(i.lhs().manager()); + y = i.lhs(); + return i.rhs().degree(x) == 1 && (i.rhs().factor(x, 1, a, b), true); + } + + bool saturation::verify_Y_l_AxB(pvar x, inequality const& i, pdd const& y, pdd const& a, pdd& b) { + return i.lhs() == y && i.rhs() == a * s.var(x) + b; + } + + + /** + * Match [x] a*x + b <= y, val(y) = 0 + */ + bool saturation::is_AxB_eq_0(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { + y.reset(i.rhs().manager()); + y = i.rhs(); + rational y_val; + if (!s.try_eval(y, y_val) || y_val != 0) + return false; + return i.lhs().degree(x) == 1 && (i.lhs().factor(x, 1, a, b), true); + } + + bool saturation::verify_AxB_eq_0(pvar x, inequality const& i, pdd const& a, pdd const& b, pdd const& y) { + return y.is_val() && y.val() == 0 && i.rhs() == y && i.lhs() == a * s.var(x) + b; + } + + bool saturation::is_AxB_diseq_0(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { + if (!i.is_strict()) + return false; + y.reset(i.lhs().manager()); + y = i.lhs(); + if (i.rhs().is_val() && i.rhs().val() == 1) + return false; + rational y_val; + if (!s.try_eval(y, y_val) || y_val != 0) + return false; + a.reset(i.rhs().manager()); + b.reset(i.rhs().manager()); + return i.rhs().degree(x) == 1 && (i.rhs().factor(x, 1, a, b), true); + } + + /** + * Match [coeff*x] coeff*x*Y where x is a variable + */ + bool saturation::is_coeffxY(pdd const& x, pdd const& p, pdd& y) { + pdd xy = x.manager().zero(); + return x.is_unary() && p.try_div(x.hi().val(), xy) && xy.factor(x.var(), 1, y); + } + + /** + * Determine whether values of x * y is non-overflowing. + */ + bool saturation::is_non_overflow(pdd const& x, pdd const& y) { + rational x_val, y_val; + rational bound = x.manager().two_to_N(); + return s.try_eval(x, x_val) && s.try_eval(y, y_val) && x_val * y_val < bound; + } + + /** + * Match [v] v*x <= z*x with x a variable + */ + bool saturation::is_Xy_l_XZ(pvar v, inequality const& i, pdd& x, pdd& z) { + return is_xY(v, i.lhs(), x) && is_coeffxY(x, i.rhs(), z); + } + + bool saturation::verify_Xy_l_XZ(pvar v, inequality const& i, pdd const& x, pdd const& z) { + return i.lhs() == s.var(v) * x && i.rhs() == z * x; + } + + /** + * Match [z] yx <= zx with x a variable + */ + bool saturation::is_YX_l_zX(pvar z, inequality const& c, pdd& x, pdd& y) { + return is_xY(z, c.rhs(), x) && is_coeffxY(x, c.lhs(), y); + } + + bool saturation::verify_YX_l_zX(pvar z, inequality const& c, pdd const& x, pdd const& y) { + return c.lhs() == y * x && c.rhs() == s.var(z) * x; + } + + /** + * Match [x] xY <= xZ + */ + bool saturation::is_xY_l_xZ(pvar x, inequality const& c, pdd& y, pdd& z) { + return is_xY(x, c.lhs(), y) && is_xY(x, c.rhs(), z); + } + + /** + * Match xy = x * Y + */ + bool saturation::is_xY(pvar x, pdd const& xy, pdd& y) { + return xy.degree(x) == 1 && xy.factor(x, 1, y); + } + + // + // overall comment: we use value propagation to check if p is val + // but we could also use literal propagation and establish there is a literal p = 0 that is true. + // in this way the value of p doesn't have to be fixed. + // + // is_forced_diseq already creates a literal. + // is_non_overflow also creates a literal + // + // The condition that p = val may be indirect. + // it could be a literal + // it could be by propagation of literals + // Example: + // -35: v90 + v89*v43 + -1*v87 != 0 [ l_false bprop@0 pwatched ] + // 36: ovfl*(v43, v89) [ l_false bprop@0 pwatched ] + // -218: v90 + -1*v87 + -1 != 0 [ l_false eval@6 pwatched ] + // + // what should we "pay" to establish this condition? + // or do we just afford us to add this lemma? + // + + bool saturation::is_forced_eq(pdd const& p, rational const& val) { + rational pv; + if (s.try_eval(p, pv) && pv == val) + return true; + return false; + } + + bool saturation::is_forced_diseq(pdd const& p, rational const& val, signed_constraint& c) { + c = s.eq(p, val); + return is_forced_false(c); + } + + bool saturation::is_forced_odd(pdd const& p, signed_constraint& c) { + c = s.odd(p); + return is_forced_true(c); + } + + bool saturation::is_forced_false(signed_constraint const& c) { + return c.bvalue(s) == l_false || c.is_currently_false(s); + } + + bool saturation::is_forced_true(signed_constraint const& c) { + return c.bvalue(s) == l_true || c.is_currently_true(s); + } + + /** + * Implement the inferences + * [x] yx < zx ==> Ω*(x,y) \/ y < z + * [x] yx <= zx ==> Ω*(x,y) \/ y <= z \/ x = 0 + */ + bool saturation::try_ugt_x(pvar v, conflict& core, inequality const& xy_l_xz) { + set_rule("[x] yx <= zx"); + pdd x = s.var(v); + pdd y = x; + pdd z = x; + signed_constraint non_ovfl; + + if (!is_xY_l_xZ(v, xy_l_xz, y, z)) + return false; + if (!xy_l_xz.is_strict() && s.is_assigned(v) && s.get_value(v).is_zero()) + return false; + if (!is_non_overflow(x, y, non_ovfl)) + return false; + m_lemma.reset(); + m_lemma.insert_eval(~non_ovfl); + if (!xy_l_xz.is_strict()) + m_lemma.insert_eval(s.eq(x)); + return add_conflict(v, core, xy_l_xz, ineq(xy_l_xz.is_strict(), y, z)); + } + + /** + * [y] z' <= y /\ yx <= zx ==> Ω*(x,y) \/ z'x <= zx + * [y] z' <= y /\ yx < zx ==> Ω*(x,y) \/ z'x < zx + * [y] z' < y /\ yx <= zx ==> Ω*(x,y) \/ z'x <= zx + * [y] z' < y /\ yx < zx ==> Ω*(x,y) \/ z'x < zx + * [y] z' < y /\ yx < zx ==> Ω*(x,y) \/ z'x + 1 < zx (TODO?) + * [y] z' < y /\ yx < zx ==> Ω*(x,y) \/ (z' + 1)x < zx (TODO?) + */ + bool saturation::try_ugt_y(pvar v, conflict& core, inequality const& yx_l_zx) { + set_rule("[y] z' <= y & yx <= zx"); + auto& m = s.var2pdd(v); + pdd x = m.zero(); + pdd z = m.zero(); + if (!is_Xy_l_XZ(v, yx_l_zx, x, z)) + return false; + for (auto si : s.m_search) { + if (!si.is_boolean()) + continue; + if (si.is_resolved()) + continue; + auto d = s.lit2cnstr(si.lit()); + if (!d->is_ule()) + continue; + auto l_y = inequality::from_ule(d); + if (is_l_v(v, l_y) && try_ugt_y(v, core, l_y, yx_l_zx, x, z)) + return true; + } + return false; + } + + bool saturation::try_ugt_y(pvar v, conflict& core, inequality const& l_y, inequality const& yx_l_zx, pdd const& x, pdd const& z) { + SASSERT(is_l_v(v, l_y)); + SASSERT(verify_Xy_l_XZ(v, yx_l_zx, x, z)); + pdd const y = s.var(v); + signed_constraint non_ovfl; + if (!is_non_overflow(x, y, non_ovfl)) + return false; + pdd const& z_prime = l_y.lhs(); + m_lemma.reset(); + m_lemma.insert_eval(~non_ovfl); + return add_conflict(v, core, l_y, yx_l_zx, ineq(yx_l_zx.is_strict(), z_prime * x, z * x)); + } + + /** + * [z] z <= y' /\ yx <= zx ==> Ω*(x,y') \/ yx <= y'x + * [z] z <= y' /\ yx < zx ==> Ω*(x,y') \/ yx < y'x + * [z] z < y' /\ yx <= zx ==> Ω*(x,y') \/ yx <= y'x + * [z] z < y' /\ yx < zx ==> Ω*(x,y') \/ yx < y'x + * [z] z < y' /\ yx < zx ==> Ω*(x,y') \/ yx+1 < y'x (TODO?) + * [z] z < y' /\ yx < zx ==> Ω*(x,y') \/ (y+1)x < y'x (TODO?) + */ + bool saturation::try_ugt_z(pvar z, conflict& core, inequality const& yx_l_zx) { + set_rule("[z] z <= y' && yx <= zx"); + auto& m = s.var2pdd(z); + pdd y = m.zero(); + pdd x = m.zero(); + if (!is_YX_l_zX(z, yx_l_zx, x, y)) + return false; + for (auto si : s.m_search) { + if (!si.is_boolean()) + continue; + if (si.is_resolved()) + continue; + auto d = s.lit2cnstr(si.lit()); + if (!d->is_ule()) + continue; + auto z_l_y = inequality::from_ule(d); + if (is_g_v(z, z_l_y) && try_ugt_z(z, core, z_l_y, yx_l_zx, x, y)) + return true; + } + return false; + } + + bool saturation::try_ugt_z(pvar z, conflict& core, inequality const& z_l_y, inequality const& yx_l_zx, pdd const& x, pdd const& y) { + SASSERT(is_g_v(z, z_l_y)); + SASSERT(verify_YX_l_zX(z, yx_l_zx, x, y)); + pdd const& y_prime = z_l_y.rhs(); + signed_constraint non_ovfl; + if (!is_non_overflow(x, y_prime, non_ovfl)) + return false; + m_lemma.reset(); + m_lemma.insert_eval(~non_ovfl); + return add_conflict(z, core, yx_l_zx, z_l_y, ineq(yx_l_zx.is_strict(), y * x, y_prime * x)); + } + + /** + * [x] y <= ax /\ x <= z (non-overflow case) + * ==> Ω*(a, z) \/ y <= az + * ... (other combinations of <, <=) + */ + bool saturation::try_y_l_ax_and_x_l_z(pvar x, conflict& core, inequality const& y_l_ax) { + set_rule("[x] y <= ax & x <= z"); + auto& m = s.var2pdd(x); + pdd y = m.zero(); + pdd a = m.zero(); + if (!is_Y_l_Ax(x, y_l_ax, a, y)) + return false; + if (a.is_one()) + return false; + for (auto si : s.m_search) { + if (!si.is_boolean()) + continue; + if (si.is_resolved()) + continue; + auto d = s.lit2cnstr(si.lit()); + if (!d->is_ule()) + continue; + auto x_l_z = inequality::from_ule(d); + if (is_g_v(x, x_l_z) && try_y_l_ax_and_x_l_z(x, core, y_l_ax, x_l_z, a, y)) + return true; + } + return false; + } + + bool saturation::try_y_l_ax_and_x_l_z(pvar x, conflict& core, inequality const& y_l_ax, inequality const& x_l_z, pdd const& a, pdd const& y) { + SASSERT(is_g_v(x, x_l_z)); + SASSERT(verify_Y_l_Ax(x, y_l_ax, a, y)); + pdd const& z = x_l_z.rhs(); + signed_constraint non_ovfl; + if (!is_non_overflow(a, z, non_ovfl)) + return false; + m_lemma.reset(); + m_lemma.insert_eval(~non_ovfl); + return add_conflict(x, core, y_l_ax, x_l_z, ineq(x_l_z.is_strict() || y_l_ax.is_strict(), y, a * z)); + } + + /** + * [x] a <= k & a*x + b = 0 & b = 0 => a = 0 or x = 0 or x >= 2^K/k + * [x] x <= k & a*x + b = 0 & b = 0 => x = 0 or a = 0 or a >= 2^K/k + * Better? + * [x] a*x + b = 0 & b = 0 => a = 0 or x = 0 or Ω*(a, x) + * We need up to four versions of this for all sign combinations of a, x + */ + bool saturation::try_mul_bounds(pvar x, conflict& core, inequality const& axb_l_y) { + set_rule("[x] a*x + b = 0 & b = 0 => a = 0 or x = 0 or ovfl(a, x)"); + auto& m = s.var2pdd(x); + pdd y = m.zero(); + pdd a = m.zero(); + pdd b = m.zero(); + pdd k = m.zero(); + pdd X = s.var(x); + rational k_val; + if (!is_AxB_eq_0(x, axb_l_y, a, b, y)) + return false; + if (a.is_val()) + return false; + if (!is_forced_eq(b, 0)) + return false; + + signed_constraint x_eq_0, a_eq_0; + if (!is_forced_diseq(X, 0, x_eq_0)) + return false; + if (!is_forced_diseq(a, 0, a_eq_0)) + return false; + + auto prop1 = [&](signed_constraint c) { + m_lemma.reset(); + m_lemma.insert_eval(~s.eq(b)); + m_lemma.insert_eval(~s.eq(y)); + m_lemma.insert_eval(x_eq_0); + m_lemma.insert_eval(a_eq_0); + return propagate(x, core, axb_l_y, c); + }; + + auto prop2 = [&](signed_constraint ante, signed_constraint c) { + m_lemma.reset(); + m_lemma.insert_eval(~s.eq(b)); + m_lemma.insert_eval(~s.eq(y)); + m_lemma.insert_eval(x_eq_0); + m_lemma.insert_eval(a_eq_0); + m_lemma.insert_eval(~ante); + return propagate(x, core, axb_l_y, c); + }; + + pdd minus_a = -a; + pdd minus_X = -X; + pdd Y = X; + for (auto si : s.m_search) { + if (!si.is_boolean()) + continue; + if (si.is_resolved()) + continue; + auto d = s.lit2cnstr(si.lit()); + if (!d->is_ule()) + continue; + auto u_l_k = inequality::from_ule(d); + if (u_l_k.rhs().power_of_2() != m.power_of_2()) + continue; + // a <= k or x <= k + k = u_l_k.rhs(); + if (!k.is_val()) + continue; + k_val = k.val(); + if (u_l_k.is_strict()) + k_val -= 1; + if (k_val <= 1) + continue; + if (u_l_k.lhs() == a || u_l_k.lhs() == minus_a) + Y = X; + else if (u_l_k.lhs() == X || u_l_k.lhs() == minus_X) + Y = a; + else + continue; + // + // NSB review: should we handle cases where k_val >= 2^{K-1}, but exploit that x*y = 0 iff -x*y = 0? + // + // IF_VERBOSE(0, verbose_stream() << "mult-bounds2 " << Y << " " << axb_l_y << " " << u_l_k<< " \n"); + rational bound = ceil(rational::power_of_two(m.power_of_2()) / k_val); + if (prop2(d, s.uge(Y, bound))) + return true; + if (prop2(d, s.uge(-Y, bound))) + return true; + } + + // IF_VERBOSE(0, verbose_stream() << "mult-bounds1 " << a << " " << axb_l_y << " \n"); + // IF_VERBOSE(0, verbose_stream() << core << "\n"); + if (prop1(s.umul_ovfl(a, X))) + return true; + if (prop1(s.umul_ovfl(a, -X))) + return true; + if (prop1(s.umul_ovfl(-a, X))) + return true; + if (prop1(s.umul_ovfl(-a, -X))) + return true; + + return false; + } + + + // bench 5 + // fairly ad-hoc rule derived from bench 5. + // The clause could also be added whenever narrowing the literal 2^k*x = 2^k*y + // It can be expected to be relatively common because these equalities come from + // bit-masking. + // + // A bigger hammer for detecting such propagations may be through LIA or a variant + // + // a*x - a*y + b*z = 0 0 <= x < b/a, 0 <= y < b/a => z = 0 + // and then => x = y + // + // a general lemma is that the linear term a*p = 0 is such that a*p does not overflow + // and therefore p = 0 + // + // the rule would also be subsumed by equality rewriting modulo parity + // + // TBD: encode the general lemma instead of this special case. + // + bool saturation::try_mul_eq_bound(pvar x, conflict& core, inequality const& axb_l_y) { + set_rule("[x] 2^k*x = 2^k*y & x < 2^N-k => y = x or y >= 2^{N-k}"); + auto& m = s.var2pdd(x); + unsigned const N = m.power_of_2(); + pdd y = m.zero(); + pdd a = y, b = y, a2 = y; + pdd const X = s.var(x); + // Match ax + b <= y with eval(y) = 0 + if (!is_AxB_eq_0(x, axb_l_y, a, b, y)) + return false; + if (!a.is_val()) + return false; + if (!a.val().is_power_of_two()) + return false; + unsigned const k = a.val().trailing_zeros(); + if (k == 0) + return false; + // x*2^k = b, x <= a2 < 2^{N-k} + rational const bound = rational::power_of_two(N - k); + b = -b; + if (b.leading_coefficient() != a.val()) + return false; + pdd Y = m.zero(); + if (!b.try_div(b.leading_coefficient(), Y)) + return false; + rational Y_val; + if (!s.try_eval(Y, Y_val) || Y_val >= bound) + return false; + for (auto c : core) { + if (!c->is_ule()) + continue; + auto i = inequality::from_ule(c); + if (!is_x_l_Y(x, i, a2)) + continue; + if (!a2.is_val()) + continue; + if (i.is_strict() && a2.val() >= bound) + continue; + if (!i.is_strict() && a2.val() > bound) + continue; + signed_constraint le = s.ule(Y, bound - 1); + m_lemma.reset(); + m_lemma.insert_eval(~le); + m_lemma.insert_eval(~s.eq(y)); + m_lemma.insert(~c); + if (propagate(x, core, axb_l_y, s.eq(X, Y))) + return true; + } + return false; + } + + /* + * x*y = 1 & ~ovfl(x,y) => x = 1 + * x*y = -1 & ~ovfl(-x,y) => -x = 1 + */ + bool saturation::try_mul_eq_1(pvar x, conflict& core, inequality const& axb_l_y) { + set_rule("[x] ax + b <= y & y = 0 & b = -1 & ~ovfl(a,x) => x = 1"); + auto& m = s.var2pdd(x); + pdd y = m.zero(); + pdd a = m.zero(); + pdd b = m.zero(); + pdd X = s.var(x); + signed_constraint non_ovfl; + if (!is_AxB_eq_0(x, axb_l_y, a, b, y)) + return false; + if (!is_forced_eq(b, -1)) + return false; + if (!is_non_overflow(a, X, non_ovfl)) + return false; + m_lemma.reset(); + m_lemma.insert_eval(~s.eq(b, rational(-1))); + m_lemma.insert_eval(~s.eq(y)); + m_lemma.insert_eval(~non_ovfl); + if (propagate(x, core, axb_l_y, s.eq(X, 1))) + return true; + if (propagate(x, core, axb_l_y, s.eq(a, 1))) + return true; + return false; + } + + /** + * odd(x*y) => odd(x) + * even(x) => even(x*y) + * + * parity(x) <= parity(x*y) + * parity(x) = k & parity(x*y) = k + j => parity(y) = j + * parity(x) = k & parity(y) = j => parity(x*y) = k + j + * + * odd(x) & even(y) => x + y != 0 + * + * Special case rule: a*x + y = 0 => (odd(y) <=> odd(a) & odd(x)) + * + * General rule: + * + * a*x + y = 0 => min(K, parity(a) + parity(x)) = parity(y) + * + * using inequalities: + * + * parity(x) <= i, parity(a) <= j => parity(y) <= i + j + * parity(x) >= i, parity(a) >= j => parity(y) >= i + j + * parity(x) <= i, parity(y) >= j => parity(a) >= j - i + * parity(x) >= i, parity(y) <= j => parity(a) <= j - i + * symmetric rules for swapping x, a + * + * min_parity(x) = N if x = 0 + * min_parity(x) = number of trailing bits of x if x is a non-zero value + * min_parity(x) = k if 2^{N-k}*x == 0 is forced for max k + * min_parity(x1*x2) = min_parity(x1) + min_parity(x2) + * min_parity(x) = 0, otherwise + * + * max_parity(x) = N if x = 0 + * max_parity(x) = number of trailing bits of x if x is a non-zero value + * max_parity(x) = k if 2^{N-k-1}*x != 0 for min k + * max_parity(x1*x2) = max_parity(x1) + max_parity(x2) + * max_parity(x) = N, otherwise + * + */ + unsigned saturation::min_parity(pdd const& p, vector& explain) { + auto& m = p.manager(); + unsigned const N = m.power_of_2(); + if (p.is_val()) + return p.val().parity(N); + + rational val; + if (s.try_eval(p, val)) { + unsigned k = val.parity(N); + if (k > 0) + explain.push_back(s.parity_at_least(p, k)); + return k; + } + + unsigned min = 0; + unsigned const sz = explain.size(); + if (!p.is_var()) { + // parity of a product => sum of parities + // parity of sum => minimum of monomial's minimal parities + min = N; + for (auto const& monomial : p) { + SASSERT(!monomial.coeff.is_zero()); + unsigned parity_sum = monomial.coeff.trailing_zeros(); + for (pvar v : monomial.vars) + parity_sum += min_parity(m.mk_var(v), explain); + min = std::min(min, parity_sum); + } + } + SASSERT(min <= N); + + for (unsigned j = N; j > min; --j) + if (is_forced_true(s.parity_at_least(p, j))) { + explain.shrink(sz); + explain.push_back(s.parity_at_least(p, j)); + return j; + } + return min; + } + + unsigned saturation::max_parity(pdd const& p, vector& explain) { + auto& m = p.manager(); + unsigned N = m.power_of_2(); + rational val; + if (p.is_val()) + return p.val().parity(N); + + if (s.try_eval(p, val)) { + unsigned k = val.parity(N); + if (k != N) + explain.push_back(s.parity_at_most(p, k)); + return k; + } + + unsigned max = N; + unsigned sz = explain.size(); + if (!p.is_var() && p.is_monomial()) { + // it's just a product => sum them up + // the case of a sum is harder as the lower bound (because of carry bits) + // ==> restricted for now to monomials + dd::pdd_monomial monomial = *p.begin(); + max = monomial.coeff.trailing_zeros(); + for (pvar c : monomial.vars) + max += max_parity(m.mk_var(c), explain); + } + + for (unsigned j = 0; j < max; ++j) + if (is_forced_true(s.parity_at_most(p, j))) { + explain.shrink(sz); + explain.push_back(s.parity_at_most(p, j)); + return j; + } + return max; + } + + bool saturation::try_parity(pvar x, conflict& core, inequality const& axb_l_y) { + auto& m = s.var2pdd(x); + unsigned N = m.power_of_2(); + pdd y = m.zero(); + pdd a = y, b = y; + pdd X = s.var(x); + if (!is_AxB_eq_0(x, axb_l_y, a, b, y)) + return false; + if (a.is_max() && b.is_var()) // x == y, we propagate values in each direction and don't need a lemma + return false; + if (a.is_one() && (-b).is_var()) // y == x + return false; + if (a.is_one()) // TODO: Sure this is correct? + return false; + if (a.is_val() && b.is_zero()) + return false; + + auto propagate1 = [&](vector const& premise, signed_constraint conseq) { + IF_VERBOSE(1, verbose_stream() << "propagate " << axb_l_y << " " << premise << " => " << conseq << "\n"); + m_lemma.reset(); + m_lemma.insert_eval(~s.eq(y)); + for (auto const& c : premise) { + if (is_forced_false(c)) + return false; + m_lemma.insert_eval(~c); + } + return propagate(x, core, axb_l_y, conseq); + }; + + auto propagate2 = [&](vector const& premise1, vector const& premise2, signed_constraint conseq) { + IF_VERBOSE(1, verbose_stream() << "propagate " << axb_l_y << " " << premise1 << " " << premise2 << " => " << conseq << "\n"); + m_lemma.reset(); + m_lemma.insert_eval(~s.eq(y)); + for (auto const& c : premise1) { + if (is_forced_false(c)) + return false; + m_lemma.insert_eval(~c); + } + for (auto const& c : premise2) { + if (is_forced_false(c)) + return false; + m_lemma.insert_eval(~c); + } + return propagate(x, core, axb_l_y, conseq); + }; + + auto correct_parity = [&](vector const& at_least, vector const& at_most) { + m_lemma.reset(); + m_lemma.insert_eval(~s.eq(y)); + for (auto const& c : at_least) { + if (is_forced_false(c)) + return false; + m_lemma.insert_eval(~c); + } + for (auto const& c : at_most) { + if (is_forced_false(c)) + return false; + m_lemma.insert_eval(~c); + } + return propagate(x, core, axb_l_y, s.f()); // TODO: Conflict overload + }; + + vector at_least_x, at_most_x, at_least_b, at_most_b, at_least_a, at_most_a; + + set_rule("[x] min_parity(t[x], j1) > max_parity(t[x], j2) => (!j1 || !j2)"); + + bool failed = false; + unsigned min_x = min_parity(X, at_least_x), max_x = max_parity(X, at_most_x); + unsigned min_b = min_parity(b, at_least_b), max_b = max_parity(b, at_most_b); + unsigned min_a = min_parity(a, at_least_a), max_a = max_parity(a, at_most_a); + + // correct min_parity(x) > max_parity(x) + if (min_x > max_x) { + failed = true; + correct_parity(at_least_x, at_most_x); + } + if (min_b > max_b) { + failed = true; + correct_parity(at_least_b, at_most_b); + } + if (min_a > max_a) { + failed = true; + correct_parity(at_least_a, at_most_a); + } + + if (failed) + // we propagated at least one parity correction lemma but there is no reason to proceed + return true; + + SASSERT(max_x <= N); + SASSERT(max_b <= N); + SASSERT(max_a <= N); + + IF_VERBOSE(2, + verbose_stream() << "try parity v" << x << " " << axb_l_y << "\n"; + verbose_stream() << "x " << X << " " << min_x << " " << max_x << "\n"; + verbose_stream() << "a " << a << " " << min_a << " " << max_a << "\n"; + verbose_stream() << "b " << b << " " << min_b << " " << max_b << "\n"); + + if (min_x >= N || min_a >= N) + return false; + + auto at_most = [&](pdd const& p, unsigned k) { + VERIFY(k < N); + return s.parity_at_most(p, k); + }; + + auto at_least = [&](pdd const& p, unsigned k) { + VERIFY(k != 0); + VERIFY(k <= N); + return s.parity_at_least(p, k); + }; + + set_rule("[x] a*x + b = 0 => (odd(a) & odd(x) <=> odd(b)) 1"); + if (!b.is_val() && max_b > max_a + max_x && propagate2(at_most_a, at_most_x, at_most(b, max_x + max_a))) + return true; + set_rule("[x] a*x + b = 0 => (odd(a) & odd(x) <=> odd(b)) 2"); + if (!b.is_val() && min_x > min_b && propagate1(at_least_x, at_least(b, min_x))) + return true; + set_rule("[x] a*x + b = 0 => (odd(a) & odd(x) <=> odd(b)) 3"); + if (!b.is_val() && min_a > min_b && propagate1(at_least_a, at_least(b, min_a))) + return true; + set_rule("[x] a*x + b = 0 => parity(b) >= parity(a) + parity(x)"); + if (!b.is_val() && min_x > 0 && min_a > 0 && min_x + min_a > min_b && N > min_b && propagate2(at_least_a, at_least_x, at_least(b, std::min(N, min_a + min_x)))) + return true; + set_rule("[x] a*x + b = 0 => (odd(a) & odd(x) <=> odd(b)) 5"); + if (!a.is_val() && max_x <= min_b && min_a < min_b - max_x && propagate2(at_most_x, at_least_b, at_least(a, min_b - max_x))) + return true; + set_rule("[x] a*x + b = 0 => (odd(a) & odd(x) <=> odd(b)) 6"); + if (max_a <= min_b && min_x < min_b - max_a && propagate2(at_most_a, at_least_b, at_least(X, min_b - max_a))) + return true; + set_rule("[x] a*x + b = 0 => (odd(a) & odd(x) <=> odd(b)) 7"); + if (max_b < N && !a.is_val() && min_x > 0 && min_x <= max_b && max_a > max_b - min_x && propagate2(at_least_x, at_most_b, at_most(a, max_b - min_x))) + return true; + set_rule("[x] a*x + b = 0 => (odd(a) & odd(x) <=> odd(b)) 8"); + if (max_b < N && min_a > 0 && min_a <= max_b && max_x > max_b - min_a && propagate2(at_least_a, at_most_b, at_most(X, max_b - min_a))) + return true; + + return false; + } + + /** + * 2^{N-1}*x*y != 0 => odd(x) & odd(y) + * 2^k*x != 0 => parity(x) < N - k + * 2^k*x*y != 0 => parity(x) + parity(y) < N - k + * + * 2^k*x + b != 0 & parity(x) >= N - k => b != 0 & 2^k*x = 0 (rewriting constraints modulo parity is more powerful and subsumes this) + */ + bool saturation::try_parity_diseq(pvar x, conflict& core, inequality const& axb_l_y) { + set_rule("[x] p(x,y) != 0 => constraints on parity(x), parity(y)"); + auto& m = s.var2pdd(x); + unsigned N = m.power_of_2(); + pdd y = m.zero(); + pdd a = y, b = y; + pdd X = s.var(x); + if (!is_AxB_diseq_0(x, axb_l_y, a, b, y)) + return false; + if (is_forced_eq(b, 0)) { + auto coeff = a.leading_coefficient(); + if (coeff.is_odd()) + return false; + SASSERT(coeff != 0); + unsigned k = coeff.trailing_zeros(); + m_lemma.reset(); + m_lemma.insert_eval(~s.eq(y)); + m_lemma.insert_eval(~s.eq(b)); + if (propagate(x, core, axb_l_y, ~s.parity_at_least(X, N - k))) + return true; + // TODO parity on a (without leading coefficient?) + } + if (a.is_val()) { + auto coeff = a.val(); + unsigned k = coeff.trailing_zeros(); + vector at_least_x; + unsigned p_x = min_parity(X, at_least_x); + if (k + p_x >= N) { + // ax + b != 0 + m_lemma.reset(); + m_lemma.insert_eval(~s.eq(y)); + for (auto c : at_least_x) + m_lemma.insert_eval(~c); + if (propagate(x, core, axb_l_y, ~s.eq(b))) + return true; + } + } + + return false; + } + + /** + * a*x = 0 => a = 0 or even(x) + * a*x = 0 => a = 0 or x = 0 or even(a) + */ + bool saturation::try_mul_odd(pvar x, conflict& core, inequality const& axb_l_y) { + set_rule("[x] ax = 0 => a = 0 or even(x)"); + auto& m = s.var2pdd(x); + pdd y = m.zero(); + pdd a = m.zero(); + pdd b = m.zero(); + pdd X = s.var(x); + signed_constraint a_eq_0, x_eq_0; + if (!is_AxB_eq_0(x, axb_l_y, a, b, y)) + return false; + if (!is_forced_eq(b, 0)) + return false; + if (!is_forced_diseq(a, 0, a_eq_0)) + return false; + m_lemma.reset(); + m_lemma.insert_eval(s.eq(y)); + m_lemma.insert_eval(~s.eq(b)); + m_lemma.insert_eval(a_eq_0); + if (propagate(x, core, axb_l_y, s.even(X))) + return true; + if (!is_forced_diseq(X, 0, x_eq_0)) + return false; + m_lemma.insert_eval(x_eq_0); + if (propagate(x, core, axb_l_y, s.even(a))) + return true; + return false; + } + + /** + * TODO If both inequalities are strict, then the implied inequality has a gap of 2 + * a < b, b < c => a + 1 < c & a + 1 != 0 + */ + bool saturation::try_transitivity(pvar x, conflict& core, inequality const& a_l_b) { + set_rule("[x] q < x & x <= p => q < p"); + auto& m = s.var2pdd(x); + pdd p = m.zero(); + pdd a = p, b = p, q = p; + // x <= p + if (!is_Ax_l_Y(x, a_l_b, a, p)) + return false; + if (!is_forced_eq(a, 1)) + return false; + for (auto c : core) { + if (!c->is_ule()) + continue; + if (c->to_ule().power_of_2() != m.power_of_2()) + continue; + auto i = inequality::from_ule(c); + if (c == a_l_b.as_signed_constraint()) + continue; + if (!is_Y_l_Ax(x, i, b, q)) + continue; + if (!is_forced_eq(b, 1)) + continue; + m_lemma.reset(); + m_lemma.insert_eval(~s.eq(a, 1)); + m_lemma.insert_eval(~s.eq(b, 1)); + m_lemma.insert(~c); + auto ineq = i.is_strict() || a_l_b.is_strict() ? (p.is_val() ? s.ule(q, p - 1) : s.ult(q, p)) : s.ule(q, p); + if (propagate(x, core, a_l_b, ineq)) + return true; + } + + return false; + } + + /** + * p <= q, q <= p => p - q = 0 + */ + bool saturation::try_infer_equality(pvar x, conflict& core, inequality const& a_l_b) { + set_rule("[x] p <= q, q <= p => p - q = 0"); + if (a_l_b.is_strict()) + return false; + if (a_l_b.lhs().degree(x) == 0 && a_l_b.rhs().degree(x) == 0) + return false; + for (auto c : core) { + if (!c->is_ule()) + continue; + auto i = inequality::from_ule(c); + if (i.lhs() == a_l_b.rhs() && i.rhs() == a_l_b.lhs() && !i.is_strict()) { + m_lemma.reset(); + m_lemma.insert(~c); + if (propagate(x, core, a_l_b, s.eq(i.lhs() - i.rhs()))) { + IF_VERBOSE(1, verbose_stream() << "infer equality " << s.eq(i.lhs() - i.rhs()) << "\n"); + return true; + } + } + } + return false; + } + + lbool saturation::get_multiple(const pdd& p1, const pdd& p2, pdd& out) { + LOG("Check if " << p2 << " can be multiplied with something to get " << p1); + if (p1.is_zero()) { // TODO: use the evaluated parity (max_parity) instead? + out = p1.manager().zero(); + return l_true; + } + if (p2.is_one()) { + out = p1; + return l_true; + } + if (!p1.is_monomial() || !p2.is_monomial()) + // TODO: Actually, this could work as well. (4a*d + 6b*c*d) is a multiple of (2a + 3b*c) although none of them is a monomial + return l_undef; + + vector maxp1, minp2; + unsigned max_parity_p1 = max_parity(p1, maxp1); + unsigned min_parity_p2 = min_parity(p2, minp2); + + if (min_parity_p2 > max_parity_p1) + return l_false; + + dd::pdd_monomial p1m = *p1.begin(); + dd::pdd_monomial p2m = *p2.begin(); + + m_occ_cnt.reserve(s.m_vars.size(), (unsigned)0); // TODO: Are there duplicates in the list (e.g., v1 * v1)?) + + for (const auto& v1 : p1m.vars) { + if (m_occ_cnt[v1] == 0) + m_occ.push_back(v1); + m_occ_cnt[v1]++; + } + for (const auto& v2 : p2m.vars) { + if (m_occ_cnt[v2] == 0) { + for (const auto& occ : m_occ) + m_occ_cnt[occ] = 0; + m_occ.clear(); + return l_undef; // p2 contains more v2 than p1; we need more information (assignments) + } + m_occ_cnt[v2]--; + } + + unsigned tz1 = p1m.coeff.trailing_zeros(); + unsigned tz2 = p2m.coeff.trailing_zeros(); + if (tz2 > tz1) + return l_undef; + + rational odd = div(p2m.coeff, rational::power_of_two(tz2)); + rational inv; + VERIFY(odd.mult_inverse(p1.power_of_2() - tz2, inv)); // we divided by the even part, so it has to be odd/invertible now + inv *= div(p1m.coeff, rational::power_of_two(tz2)); + + out = p1.manager().mk_val(inv); + for (const auto& occ : m_occ) { + for (unsigned i = 0; i < m_occ_cnt[occ]; i++) + out *= s.var(occ); + m_occ_cnt[occ] = 0; + } + m_occ.clear(); + LOG("Found multiple: " << out); + return l_true; + } + + bool saturation::try_factor_equality(pvar x, conflict& core, inequality const& a_l_b) { + set_rule("[x] ax + b = 0 & C[x] => C[-inv(a)*b]"); + auto& m = s.var2pdd(x); + pdd y = m.zero(); + pdd a = y, b = y, a1 = y, b1 = y, mul_fac = y; + if (!is_AxB_eq_0(x, a_l_b, a, b, y)) // TODO: Is the restriction to linear "x" too restrictive? + return false; + + bool prop = false; + + for (auto c : core) { + if (c == a_l_b) + continue; + LOG("Trying to eliminate v" << x << " in " << c << " by using equation " << a_l_b.as_signed_constraint()); + if (c->is_ule()) { + set_rule("[x] ax + b = 0 & C[x] => C[-inv(a)*b] ule"); + // If both are equalities this boils down to polynomial superposition => Might generate the same lemma twice + auto const& ule = c->to_ule(); + m_lemma.reset(); + auto [lhs_new, changed_lhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ule.lhs(), m_lemma); + auto [rhs_new, changed_rhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ule.rhs(), m_lemma); + if (!changed_lhs && !changed_rhs) + continue; // nothing changed - no reason for propagating lemmas + m_lemma.insert(~c); + m_lemma.insert_eval(~s.eq(y)); + + if (propagate(x, core, a_l_b, c.is_positive() ? s.ule(lhs_new, rhs_new) : ~s.ule(lhs_new, rhs_new))) + prop = true; + } + else if (c->is_umul_ovfl()) { + set_rule("[x] ax + b = 0 & C[x] => C[-inv(a)*b] umul_ovfl"); + auto const& ovf = c->to_umul_ovfl(); + m_lemma.reset(); + auto [lhs_new, changed_lhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ovf.p(), m_lemma); + auto [rhs_new, changed_rhs] = m_parity_tracker.eliminate_variable(*this, x, a, b, ovf.q(), m_lemma); + if (!changed_lhs && !changed_rhs) + continue; + m_lemma.insert(~c); + m_lemma.insert_eval(~s.eq(y)); + + if (propagate(x, core, a_l_b, c.is_positive() ? s.umul_ovfl(lhs_new, rhs_new) : ~s.umul_ovfl(lhs_new, rhs_new))) + prop = true; + } + } + return prop; + } + + + /** + * x >= x + y & x <= n ==> y >= M - n or y = 0 + * x > x + y & x <= n ==> y >= M - n + * -x <= -x - y & x <= n ==> y >= M - n or y = 0 or x = 0 + * -x < -x - y & x <= n ==> y >= M - n or x = 0 + * + * NOTE: x + y <= x <=> -y <= x <=> -x-1 <= y-1 + * x <= x + y <=> x <= -y-1 <=> y <= -x-1 + * (see notes on equivalent forms in ule_constraint.cpp) + * + * p <= q ==> p = 0 or -q <= -p + */ + bool saturation::try_add_overflow_bound(pvar x, conflict& core, inequality const& i) { + set_rule("[x] x >= x + y & x <= n => y = 0 or y >= 2^N - n"); + signed_constraint y_eq_0, x_eq_0; + vector x_le_bound; + auto& m = s.var2pdd(x); + pdd y = m.zero(); + bool is_minus; + if (!is_add_overflow(x, i, y, is_minus)) + return false; + if (!i.is_strict() && !is_forced_diseq(y, 0, y_eq_0)) + return false; + if (is_minus && !is_forced_diseq(s.var(x), 0, x_eq_0)) + return false; + rational bound; + if (!has_upper_bound(x, core, bound, x_le_bound)) + return false; + SASSERT(bound != 0); + m_lemma.reset(); + if (!i.is_strict()) + m_lemma.insert_eval(y_eq_0); + if (is_minus) + m_lemma.insert_eval(x_eq_0); + for (auto c : x_le_bound) + m_lemma.insert_eval(~c); + return propagate(x, core, i, s.uge(y, m.two_to_N() - bound)); + } + + /** + * Match one of the patterns: + * x >= x + y + * x > x + y + * -x <= -x - y + * -x < -x - y + */ + bool saturation::is_add_overflow(pvar x, inequality const& i, pdd& y, bool& is_minus) { + pdd const X = s.var(x); + pdd a = X; + if (i.lhs().degree(x) != 1 || i.rhs().degree(x) != 1) + return false; + if (i.rhs() == X) { + i.lhs().factor(x, 1, a, y); + if (a.is_one()) { + is_minus = false; + return true; + } + } + if (i.lhs() == -X) { + i.rhs().factor(x, 1, a, y); + if ((-a).is_one()) { + is_minus = true; + y = -y; + return true; + } + } + return false; + } + + bool saturation::has_upper_bound(pvar x, conflict& core, rational& bound, vector& x_le_bound) { + return s.m_viable.has_upper_bound(x, bound, x_le_bound); + } + + bool saturation::has_lower_bound(pvar x, conflict& core, rational& bound, vector& x_ge_bound) { + return s.m_viable.has_lower_bound(x, bound, x_ge_bound); + } + + rational saturation::round(rational const& M, rational const& x) { + SASSERT(0 <= x && x < M); + if (x + M/2 > M) + return x - M; + else + return x; + } + + bool saturation::eval_round(rational const& M, pdd const& p, rational& r) { + if (!s.try_eval(p, r)) + return false; + r = round(M, r); + return true; + } + + /** + * Write as q := a*y + b + * + * If y == null_var, chooses some variable y from q (if one exists). + */ + bool saturation::extract_linear_form(pdd const& q, pvar& y, rational& a, rational& b) { + auto& m = q.manager(); + rational const& M = m.two_to_N(); + + if (q.is_val()) { + a = 0; + b = round(M, q.val()); + return true; + } + if (y == null_var) { + // choose the top variable + y = q.var(); + if (!q.hi().is_val() && q.hi().var() == y) + return false; + if (!eval_round(M, q.hi(), a)) + return false; + if (!eval_round(M, q.lo(), b)) + return false; + return true; + } + else { + // factor according to given variable + SASSERT(y != null_var); + switch (q.degree(y)) { + case 0: + if (!eval_round(M, q, b)) + return false; + a = 0; + return true; + case 1: { + pdd a1(m), b1(m); + q.factor(y, 1, a1, b1); + if (!eval_round(M, a1, a)) + return false; + if (!eval_round(M, b1, b)) + return false; + return true; + } + default: + return false; + } + } + } + + /** + * Write as p := a*x*y + b*x + c*y + d + * + * If y == null_var, chooses some variable y != x from p (if one exists). + */ + bool saturation::extract_bilinear_form(pvar x, pdd const& p, pvar& y, bilinear& b) { + auto& m = s.var2pdd(x); + rational const& M = m.two_to_N(); + switch (p.degree(x)) { + case 0: + if (!s.try_eval(p, b.d)) + return false; + b.a = b.b = b.c = 0; + return true; + case 1: { + pdd q = p, r = p, u = p, v = p; + p.factor(x, 1, q, r); + if (!extract_linear_form(q, y, b.a, b.b)) + return false; + if (b.a == 0) { + b.c = 0; + return eval_round(M, r, b.d); + } + SASSERT(y != null_var); + switch (r.degree(y)) { + case 0: + if (!eval_round(M, r, b.d)) + return false; + b.c = 0; + return true; + case 1: + r.factor(y, 1, u, v); + if (!eval_round(M, u, b.c)) + return false; + if (!eval_round(M, v, b.d)) + return false; + return true; + default: + return false; + } + return false; + } + default: + return false; + } + } + + /** + * Update d such that -M < a*x*y0 + b*x + c*y0 + d < M for every value x_min <= x <= x_max, return x_split such that [x_min,x_split[ and [x_split,x_max] can fit into [0,M[ + * return false if there is no such d. + */ + bool saturation::adjust_bound(rational const& x_min, rational const& x_max, rational const& y0, + rational const& M, bilinear& b, rational& x) { + SASSERT(x_min <= x_max); + rational A = b.a*y0 + b.b; + rational B = b.c*y0 + b.d; + rational max = A >= 0 ? x_max * A + B : x_min * A + B; + rational min = A >= 0 ? x_min * A + B : x_max * A + B; + VERIFY(min <= max); + if (max - min >= M) { + IF_VERBOSE(10, verbose_stream() << "adjust_bound: abort because max - min >= M\n"); + return false; + } + + // k0 = min k. val + kM >= 0 + // = min k. k >= -val/M + // = ceil(-val/M) = -floor(val/M) + rational offset = rational::zero(); + if (max < 0 || max >= M) + offset = -M * floor(max / M); + b.d += offset; + + // If min + offset < 0, then [min,max] contains a multiple of M. + if (min + offset < 0) { + // A*x_split + B + offset = 0 + // x_split = -(B+offset)/A + if (A >= 0) { + x = ceil((-offset - B) / A); + // [x_min; x_split-1] maps to interval < 0 + // [x_split; x_max] maps to interval >= 0 + VERIFY(b.eval(x, y0) >= 0); + VERIFY(b.eval(x-1, y0) < 0); + VERIFY(x_min <= x && x <= x_max); + } + else { + x = floor((-offset - B) / A) + 1; + // [x_min; x_split-1] maps to interval >= 0 + // [x_split; x_max] maps to interval < 0 + VERIFY(b.eval(x, y0) < 0); + VERIFY(b.eval(x-1, y0) >= 0); + VERIFY(x_min <= x && x <= x_max); + } + } + + VERIFY(-M < b.eval(x_min, y0)); + VERIFY(b.eval(x_min, y0) < M); + VERIFY(-M < b.eval(x_max, y0)); + VERIFY(b.eval(x_max, y0) < M); + return true; + } + + /** + * Based on a*x*y + b*x + c*y + d >= 0 + * update lower bound for y + */ + bool saturation::update_min(rational& y_min, rational const& x_min, rational const& x_max, + bilinear const& b) { + if (b.a == 0 && b.c == 0) + return true; + + rational x_bound; + if (b.a >= 0 && b.b >= 0) + x_bound = x_min; + else if (b.a <= 0 && b.b <= 0) + x_bound = x_max; + else + return false; + + // a*x_bound*y + b*x_bound + c*y + d >= 0 + // (a*x_bound + c)*y >= -d - b*x_bound + // if a*x_bound + c > 0 + rational A = b.a*x_bound + b.c; + if (A <= 0) + return true; + rational y1 = ceil((- b.d - b.b*x_bound)/A); + if (y1 > y_min) + y_min = y1; + return true; + } + + bool saturation::update_max(rational& y_max, rational const& x_min, rational const& x_max, + bilinear const& b) { + if (b.a == 0 && b.c == 0) + return true; + + rational x_bound; + if (b.a >= 0 && b.b >= 0) + x_bound = x_min; + else if (b.a <= 0 && b.b <= 0) + x_bound = x_max; + else + return false; + + // a*x_bound*y + b*x_bound + c*y + d >= 0 + // (a*x_bound + c)*y >= -d - b*x_bound + // if a*x_bound + c < 0 + rational A = b.a*x_bound + b.c; + if (A >= 0) + return true; + rational y1 = floor((- b.d - b.b*x_bound)/A); + if (y1 < y_max) + y_max = y1; + return true; + } + + void saturation::fix_values(pvar y, pdd const& p) { + if (p.degree(y) == 0) { + rational p_val; + VERIFY(s.try_eval(p, p_val)); + m_lemma.insert_eval(~s.eq(p, p_val)); + } + else { + pdd q = p, r = p; + p.factor(y, 1, q, r); + fix_values(y, q); + fix_values(y, r); + } + } + + void saturation::fix_values(pvar x, pvar y, pdd const& p) { + if (p.degree(x) == 0) + fix_values(y, p); + else { + pdd q = p, r = p; + p.factor(x, 1, q, r); + fix_values(x, y, q); + fix_values(x, y, r); + } + } + + bool saturation::update_bounds_for_xs(rational const& x_min, rational const& x_max, rational& y_min, rational& y_max, rational const& y0, bilinear b1, bilinear b2, rational const& M, inequality const& a_l_b) { + + VERIFY(x_min <= x_max); + + if (b1.eval(x_min, y0) < 0) + b1.d += M; + if (b2.eval(x_min, y0) < 0) + b2.d += M; + + IF_VERBOSE(2, + verbose_stream() << "Adjusted for x in [" << x_min << "; " << x_max << "]\n"; + verbose_stream() << "p ... " << b1 << " " << b1.eval(x_min, y0) << "\n"; + verbose_stream() << "q ... " << b2 << " " << b2.eval(x_min, y0) << "\n"; + ); + + // Precondition: forall x . x_min <= x <= x_max ==> p(x,y0) > q(x,y0) + // check the endpoints + // + // the pre-condition could be false if the interval x_min..x_max + // is not defined by a_l_b, but different constraints. + // + if (b1.eval(x_min, y0) < b2.eval(x_min, y0) + (a_l_b.is_strict() ? 0 : 1)) + return false; + if (b1.eval(x_max, y0) < b2.eval(x_max, y0) + (a_l_b.is_strict() ? 0 : 1)) + return false; + + if (!update_min(y_min, x_min, x_max, b1)) + return false; + if (!update_min(y_min, x_min, x_max, b2)) + return false; + //verbose_stream() << "min-max: x := v" << x << " [" << x_min << "," << x_max << "] y := v" << y << " [" << y_min << ", " << y_max << "] y0 " << y0 << "\n"; + VERIFY(y_min <= y0 && y0 <= y_max); + if (!update_max(y_max, x_min, x_max, b1)) + return false; + if (!update_max(y_max, x_min, x_max, b2)) + return false; + //verbose_stream() << "min-max: x := v" << x << " [" << x_min << "," << x_max << "] y := v" << y << " [" << y_min << ", " << y_max << "] y0 " << y0 << "\n"; + VERIFY(y_min <= y0 && y0 <= y_max); + // p < M iff -p > -M iff -p + M - 1 >= 0 + if (!update_min(y_min, x_min, x_max, -b1 + (M - 1))) + return false; + if (!update_min(y_min, x_min, x_max, -b2 + (M - 1))) + return false; + if (!update_max(y_max, x_min, x_max, -b1 + (M - 1))) + return false; + if (!update_max(y_max, x_min, x_max, -b2 + (M - 1))) + return false; + VERIFY(y_min <= y0 && y0 <= y_max); + // p <= q or p < q is false + // so p > q or p >= q + // p - q - 1 >= 0 or p - q >= 0 + // min-max for p - q - 1 or p - q are non-negative + if (!update_min(y_min, x_min, x_max, b1 - b2 - (a_l_b.is_strict() ? 0 : 1))) + return false; + if (!update_max(y_max, x_min, x_max, b1 - b2 - (a_l_b.is_strict() ? 0 : 1))) + return false; + return true; + } + + // wip - outline of what should be a more general approach + bool saturation::try_add_mul_bound(pvar x, conflict& core, inequality const& a_l_b) { + set_rule("[x] mul-bound2 ax + b <= y, ... => a >= u_a"); + + // comment out for dev + return false; + + auto& m = s.var2pdd(x); + pdd p = a_l_b.lhs(), q = a_l_b.rhs(); + // add this filter to remove useless bounds + if (q.is_zero()) + return false; + if (p.degree(x) > 1 || q.degree(x) > 1) + return false; + if (p.degree(x) == 0 && q.degree(x) == 0) + return false; + + pvar y = null_var; + bilinear b1, b2; + if (!extract_bilinear_form(x, p, y, b1)) + return false; + if (!extract_bilinear_form(x, q, y, b2)) + return false; + if (y == null_var) + return false; + if (!s.is_assigned(y)) + return false; + rational y0 = s.get_value(y); + + vector bounds; + rational x_min, x_max; + if (!s.m_viable.has_max_forbidden(x, a_l_b, x_max, x_min, bounds)) + return false; + + // retrieved maximal forbidden interval [x_max, x_min[. + // [x_min, x_max[ is the allowed interval. + // compute [x_min, x_max - 1] + VERIFY(x_min != x_max); + SASSERT(0 <= x_min && x_min <= m.max_value()); + SASSERT(0 <= x_max && x_max <= m.max_value()); + rational const& M = m.two_to_N(); + x_max = x_max == 0 ? m.max_value() : x_max - 1; + if (x_min == x_max) + return false; + if (x_min > x_max) + x_min -= M; + // else if (m.max_value() - x_max < x_min) { + // TODO: deal with large x values like this? + // x_min -= M; + // x_max -= M; + // } + SASSERT(x_min <= x_max); + + IF_VERBOSE(2, + verbose_stream() << "\n---\n\n"; + verbose_stream() << "constraint " << lit_pp(s, a_l_b) << "\n"; + verbose_stream() << "x = v" << x << "\n"; + verbose_stream() << "y = v" << y << "\n"; + s.m_viable.display(verbose_stream() << "\nx-intervals:\n", x, "\n") << "\n"; + verbose_stream() << "\n"; + verbose_stream() << "x_min " << x_min << " x_max " << x_max << "\n"; + verbose_stream() << "v" << y << " " << y0 << "\n"; + verbose_stream() << p << " ... " << b1 << "\n"; + verbose_stream() << q << " ... " << b2 << "\n"); + + rational x_sp1 = x_min; + rational x_sp2 = x_min; + + if (!adjust_bound(x_min, x_max, y0, M, b1, x_sp1)) + return false; + if (!adjust_bound(x_min, x_max, y0, M, b2, x_sp2)) + return false; + + if (x_sp1 > x_sp2) + std::swap(x_sp1, x_sp2); + SASSERT(x_min <= x_sp1 && x_sp1 <= x_sp2 && x_sp2 <= x_max); + + IF_VERBOSE(2, + verbose_stream() << "Adjusted\n"; + verbose_stream() << p << " ... " << b1 << "\n"; + verbose_stream() << q << " ... " << b2 << "\n"; + // verbose_stream() << "p(x_min,y0) = " << b1.eval(x_min, y0) << "\n"; + // verbose_stream() << "q(x_min,y0) = " << b2.eval(x_min, y0) << "\n"; + // verbose_stream() << "p(x_max,y0) = " << b1.eval(x_max, y0) << "\n"; + // verbose_stream() << "q(x_max,y0) = " << b2.eval(x_max, y0) << "\n"; + ); + + rational y_min(0), y_max(M-1); + if (x_min != x_sp1 && !update_bounds_for_xs(x_min, x_sp1-1, y_min, y_max, y0, b1, b2, M, a_l_b)) + return false; + IF_VERBOSE(11, verbose_stream() << "min-max: x := v" << x << " [" << x_min << "," << x_max << "] y := v" << y << " [" << y_min << ", " << y_max << "] y0 " << y0 << "\n"); + if (x_sp1 != x_sp2 && !update_bounds_for_xs(x_sp1, x_sp2-1, y_min, y_max, y0, b1, b2, M, a_l_b)) + return false; + IF_VERBOSE(11, verbose_stream() << "min-max: x := v" << x << " [" << x_min << "," << x_max << "] y := v" << y << " [" << y_min << ", " << y_max << "] y0 " << y0 << "\n"); + if (!update_bounds_for_xs(x_sp2, x_max, y_min, y_max, y0, b1, b2, M, a_l_b)) + return false; + IF_VERBOSE(11, verbose_stream() << "min-max: x := v" << x << " [" << x_min << "," << x_max << "] y := v" << y << " [" << y_min << ", " << y_max << "] y0 " << y0 << "\n"); + + SASSERT(y_min <= y0 && y0 <= y_max); + VERIFY(y_min <= y0 && y0 <= y_max); + if (y_min == y_max) + return false; + + m_lemma.reset(); + for (auto const& c : bounds) + m_lemma.insert_eval(~c); + fix_values(x, y, p); + fix_values(x, y, q); + if (y_max != M - 1) { + if (y_min != 0) + m_lemma.insert_eval(s.ult(s.var(y), y_min)); + return propagate(x, core, a_l_b, s.ult(y_max, s.var(y))); + } + if (y_min != 0) + return propagate(x, core, a_l_b, s.ult(s.var(y), y_min)); + else + return false; + } + + /** + * p >= q & q*2^k = 0 & p < 2^{K-k} => q = 0 + * More generally + * p >= q + r & q*2^k = 0 & p < 2^{K-k} & r < 2^{K-k} => q = 0 & p >= r + * + * The parity constraint on q entails that the low K-k bits of q must be 0 + * and therefore q is either 0 or at or above 2^{K-k}. + * Since p is blow 2^{K-k} the only intersection between the viable + * intervals imposed by p and possible for q is 0. + * + */ + bool saturation::try_infer_parity_equality(pvar x, conflict& core, inequality const& a_l_b) { + return false; + set_rule("[x] p > q & 2^k*q = 0 & p < 2^{K-k} => q = 0"); + auto& m = s.var2pdd(x); + auto p = a_l_b.rhs(), q = a_l_b.lhs(); + if (q.is_val()) + return false; + if (p.is_val() && p.val() == 0) + return false; + rational p_val; + if (!s.try_eval(p, p_val)) + return false; + vector at_least_k; + unsigned k = min_parity(q, at_least_k); + unsigned N = m.power_of_2(); + if (k == N) + return false; + if (rational::power_of_two(k) > p_val) { + // verbose_stream() << k << " " << p_val << " " << a_l_b << "\n"; + m_lemma.reset(); + for (auto const& c : at_least_k) + m_lemma.insert_eval(~c); + m_lemma.insert_eval(~s.ult(p, rational::power_of_two(k))); + return propagate(x, core, a_l_b, s.eq(q)); + } + return false; + } + + + /** + * let q1 = x1 / y1, q2 = x2 / y2 + * x1 <= x2 & y1 >= y2 => q1 <= q2 + * y1 <= y2 & q1 < q2 => (x2 - x1) >= (q2 - q1 - 1) * y1 + * + * Limitation/assumption: + * Values of x1, y1, q1 have to be available for the rule to apply. + * If not all values are present, the rule isn't going to be used. + * The arithmetic solver uses complete assignments because it + * builds on top of an integer feasible state (or feasible over rationals) + * Lemmas are false under that assignment. They don't necessarily propagate, though. + * PolySAT isn't (yet) set up to work with complete assignments and thereforce misses such lemmas. + * - should we force complete assignments by computing first a model that is feasible modulo linear constraints + * (ignore non-linear constraints in linear mode)? + * - should we detect forcing relations x1 <= x2, y2 <= y1 based on the constraints (not on assignments)? + * other saturation rules already do this, but it is highly syntactic whether they apply. + * + * + * Other rules: + * x < y div z => x * z < y + * + * Or just: + * (y div z) * z <= y, + * ~overfl((y div z) * z) + * + * ~overfl(x * y), z <= y => x * z <= x * y + * + */ + bool saturation::try_div_monotonicity(conflict& core) { + bool propagated = false; + + auto log = [&](auto& x1, auto& x1_val, auto& y1, auto& y1_val, auto& q1, auto& q1_val, + auto& x2, auto& x2_val, auto& y2, auto& y2_val, auto& q2, auto& q2_val) { + IF_VERBOSE(1, verbose_stream() << "Division monotonicity: [" << x1 << "] (" << x1_val << ") / [" << y1 << "] (" << y1_val << ") = " + << s.var(q1) << "\n"); + }; + +#if 0 + // monotonicity0 lemma should be asserted eagerly. + auto monotonicity0 = [&](auto& x1, auto& x1_val, auto& y1, auto& y1_val, auto& q1, auto& q1_val) { + if (q1_val * y1_val <= x1_val) + return; + // q1*y1 + r1 = x1, q1*y1 <= -r1 - 1, q1*y1 <= x1 + propagated = true; + set_rule("[x1, y1] (x1 / y1) * y1 <= x1"); + m_lemma.reset(); + propagate(q1, core, s.ule(s.var(q1) * y1, x1)); + }; +#endif + + auto monotonicity1 = [&](auto& x1, auto& x1_val, auto& y1, auto& y1_val, auto& q1, auto& q1_val, + auto& x2, auto& x2_val, auto& y2, auto& y2_val, auto& q2, auto& q2_val) { + if (!(x1_val <= x2_val && y1_val >= y2_val && q1_val > q2_val)) + return; + propagated = true; + set_rule("[x1, y1, x2, y2] x1 <= x2 & y2 <= y1 => x1 / y1 <= x2 / y2"); + log(x1, x1_val, y1, y1_val, q1, q1_val, x2, x2_val, y2, y2_val, q2, q2_val); + m_lemma.reset(); + m_lemma.insert_eval(~s.ule(x1, x2)); + m_lemma.insert_eval(~s.ule(y2, y1)); + propagate(q1, core, s.ule(s.var(q1), s.var(q2))); + }; + + auto monotonicity2 = [&](auto& x1, auto& x1_val, auto& y1, auto& y1_val, auto& q1, auto& q1_val, + auto& x2, auto& x2_val, auto& y2, auto& y2_val, auto& q2, auto& q2_val) { + if (!(y1_val <= y2_val && q1_val < q2_val && (x2_val - x1_val < (q2_val - q1_val - 1) * y1_val))) + return; + propagated = true; + set_rule("[x1, y1, x2, y2] y2 >= y1 & q2 > q1 => x2 - x1 >= ((x2 / y2) - (x1 / y1) - 1) * y1"); + log(x1, x1_val, y1, y1_val, q1, q1_val, x2, x2_val, y2, y2_val, q2, q2_val); + m_lemma.reset(); + m_lemma.insert_eval(~s.uge(y2, y1)); + m_lemma.insert_eval(~s.ult(s.var(q1), s.var(q2))); + propagate(q1, core, s.uge(x2 - x1, (s.var(q2) - s.var(q1) - 1) * y1)); + }; + + + for (auto const& [x1, y1, q1, r1] : s.m_constraints.div_constraints()) { + rational x1_val, y1_val, q1_val; + if (!s.try_eval(x1, x1_val) || !s.try_eval(y1, y1_val) || !s.is_assigned(q1)) + continue; + q1_val = s.get_value(q1); + rational expected1 = y1_val.is_zero() ? y1.manager().max_value() : div(x1_val, y1_val); + + if (q1_val == expected1) + continue; + + // force that q1 * y1 <= x1 if it isn't the case. + // monotonicity0(x1, x1_val, y1, y1_val, q1, q1_val); + + for (auto const& [x2, y2, q2, r2] : s.m_constraints.div_constraints()) { + if (x1 == x2 && y1 == y2) + continue; + if (x1.power_of_2() != x2.power_of_2()) + continue; + rational x2_val, y2_val, q2_val; + if (!s.try_eval(x2, x2_val) || !s.try_eval(y2, y2_val) || !s.is_assigned(q2)) + continue; + q2_val = s.get_value(q2); + monotonicity1(x1, x1_val, y1, y1_val, q1, q1_val, x2, x2_val, y2, y2_val, q2, q2_val); + monotonicity1(x2, x2_val, y2, y2_val, q2, q2_val, x1, x1_val, y1, y1_val, q1, q1_val); + monotonicity2(x1, x1_val, y1, y1_val, q1, q1_val, x2, x2_val, y2, y2_val, q2, q2_val); + monotonicity2(x2, x2_val, y2, y2_val, q2, q2_val, x1, x1_val, y1, y1_val, q1, q1_val); + } + } + return propagated; + } + + /* + * TODO + * + * Maybe also + * x*y = k => \/_{j is such that there is j', j*j' = k} x = j + * x*y = k & ~ovfl(x,y) & x = j => y = k/j where j is a divisor of k + */ + + + /** + * [x] p(x) <= q(x) where value(p) > value(q) + * ==> q <= value(q) => p <= value(q) + * + * for strict? + * p(x) < q(x) where value(p) >= value(q) + * ==> value(p) <= p => value(p) < q + */ + bool saturation::try_tangent(pvar v, conflict& core, inequality const& c) { + set_rule("[x] p(x) <= q(x) where value(p) > value(q)"); + // if (c.is_strict()) + // return false; + if (!c.as_signed_constraint()->contains_var(v)) + return false; + if (c.lhs().is_val() || c.rhs().is_val()) + return false; + + auto& m = s.var2pdd(v); + pdd q_l(m), e_l(m), q_r(m), e_r(m); + bool is_linear = true; + is_linear &= c.lhs().degree(v) <= 1; + is_linear &= c.rhs().degree(v) <= 1; + if (c.lhs().degree(v) == 1) { + c.lhs().factor(v, 1, q_l, e_l); + is_linear &= q_l.is_val(); + } + if (c.rhs().degree(v) == 1) { + c.rhs().factor(v, 1, q_r, e_r); + is_linear &= q_r.is_val(); + } + if (is_linear) + return false; + + if (!c.as_signed_constraint().is_currently_false(s)) + return false; + rational l_val, r_val; + if (!s.try_eval(c.lhs(), l_val)) + return false; + if (!s.try_eval(c.rhs(), r_val)) + return false; + SASSERT(c.is_strict() || l_val > r_val); + SASSERT(!c.is_strict() || l_val >= r_val); + m_lemma.reset(); + if (c.is_strict()) { + auto d = s.ule(l_val, c.lhs()); + if (d.bvalue(s) == l_false) // it is a different value conflict that contains v + return false; + m_lemma.insert_eval(~d); + auto conseq = s.ult(r_val, c.rhs()); + return add_conflict(v, core, c, conseq); + } + else { + auto d = s.ule(c.rhs(), r_val); + if (d.bvalue(s) == l_false) // it is a different value conflict that contains v + return false; + m_lemma.insert_eval(~d); + auto conseq = s.ule(c.lhs(), r_val); + return add_conflict(v, core, c, conseq); + } + } + +} diff --git a/src/sat/smt/polysat/saturation.h b/src/sat/smt/polysat/saturation.h new file mode 100644 index 000000000..f0dcc56ce --- /dev/null +++ b/src/sat/smt/polysat/saturation.h @@ -0,0 +1,241 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Polysat core saturation + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-6 + +--*/ +#pragma once + +#include "math/polysat/constraints.h" + +namespace polysat { + + struct bilinear { + rational a, b, c, d; + + rational eval(rational const& x, rational const& y) const { + return a*x*y + b*x + c*y + d; + } + + bilinear operator-() const { + bilinear r(*this); + r.a = -r.a; + r.b = -r.b; + r.c = -r.c; + r.d = -r.d; + return r; + } + + bilinear operator-(bilinear const& other) const { + bilinear r(*this); + r.a -= other.a; + r.b -= other.b; + r.c -= other.c; + r.d -= other.d; + return r; + } + + bilinear operator+(rational const& d) const { + bilinear r(*this); + r.d += d; + return r; + } + + bilinear operator-(rational const& d) const { + bilinear r(*this); + r.d -= d; + return r; + } + + bilinear operator-(int d) const { + bilinear r(*this); + r.d -= d; + return r; + } +}; + + inline std::ostream& operator<<(std::ostream& out, bilinear const& b) { + return out << b.a << "*x*y + " << b.b << "*x + " << b.c << "*y + " << b.d; + } + + /** + * Introduce lemmas that derive new (simpler) constraints from the current conflict and partial model. + */ + class saturation { + + core& c; + constraints& C; + char const* m_rule = nullptr; + +#if 0 + parity_tracker m_parity_tracker; + unsigned_vector m_occ; + unsigned_vector m_occ_cnt; + + void set_rule(char const* r) { m_rule = r; } + + bool is_non_overflow(pdd const& x, pdd const& y, signed_constraint& c); + signed_constraint ineq(bool strict, pdd const& lhs, pdd const& rhs); + + void log_lemma(pvar v, conflict& core); + bool propagate(pvar v, conflict& core, signed_constraint crit1, signed_constraint c); + bool propagate(pvar v, conflict& core, inequality const& crit1, signed_constraint c); + bool propagate(pvar v, conflict& core, signed_constraint c); + bool add_conflict(pvar v, conflict& core, inequality const& crit1, signed_constraint c); + bool add_conflict(pvar v, conflict& core, inequality const& crit1, inequality const& crit2, signed_constraint c); + + bool try_ugt_x(pvar v, conflict& core, inequality const& c); + + bool try_ugt_y(pvar v, conflict& core, inequality const& c); + bool try_ugt_y(pvar v, conflict& core, inequality const& l_y, inequality const& yx_l_zx, pdd const& x, pdd const& z); + + bool try_y_l_ax_and_x_l_z(pvar x, conflict& core, inequality const& c); + bool try_y_l_ax_and_x_l_z(pvar x, conflict& core, inequality const& x_l_z, inequality const& y_l_ax, pdd const& a, pdd const& y); + + bool try_ugt_z(pvar z, conflict& core, inequality const& c); + bool try_ugt_z(pvar z, conflict& core, inequality const& x_l_z0, inequality const& yz_l_xz, pdd const& y, pdd const& x); + + bool try_parity(pvar x, conflict& core, inequality const& axb_l_y); + bool try_parity_diseq(pvar x, conflict& core, inequality const& axb_l_y); + bool try_mul_bounds(pvar x, conflict& core, inequality const& axb_l_y); + bool try_factor_equality(pvar x, conflict& core, inequality const& a_l_b); + bool try_infer_equality(pvar x, conflict& core, inequality const& a_l_b); + bool try_mul_eq_1(pvar x, conflict& core, inequality const& axb_l_y); + bool try_mul_odd(pvar x, conflict& core, inequality const& axb_l_y); + bool try_mul_eq_bound(pvar x, conflict& core, inequality const& axb_l_y); + bool try_transitivity(pvar x, conflict& core, inequality const& axb_l_y); + bool try_tangent(pvar v, conflict& core, inequality const& c); + bool try_add_overflow_bound(pvar x, conflict& core, inequality const& axb_l_y); + bool try_add_mul_bound(pvar x, conflict& core, inequality const& axb_l_y); + bool try_infer_parity_equality(pvar x, conflict& core, inequality const& a_l_b); + bool try_div_monotonicity(conflict& core); + + bool try_nonzero_upper_extract(pvar v, conflict& core, inequality const& i); + bool try_congruence(pvar v, conflict& core, inequality const& i); + + + rational round(rational const& M, rational const& x); + bool eval_round(rational const& M, pdd const& p, rational& r); + bool extract_linear_form(pdd const& q, pvar& y, rational& a, rational& b); + bool extract_bilinear_form(pvar x, pdd const& p, pvar& y, bilinear& b); + bool adjust_bound(rational const& x_min, rational const& x_max, rational const& y0, rational const& M, + bilinear& b, rational& x_split); + bool update_min(rational& y_min, rational const& x_min, rational const& x_max, + bilinear const& b); + bool update_max(rational& y_max, rational const& x_min, rational const& x_max, + bilinear const& b); + bool update_bounds_for_xs(rational const& x_min, rational const& x_max, rational& y_min, rational& y_max, + rational const& y0, bilinear b1, bilinear b2, + rational const& M, inequality const& a_l_b); + void fix_values(pvar x, pvar y, pdd const& p); + void fix_values(pvar y, pdd const& p); + + // c := lhs ~ v + // where ~ is < or <= + bool is_l_v(pvar v, inequality const& c); + + // c := v ~ rhs + bool is_g_v(pvar v, inequality const& c); + + // c := x ~ Y + bool is_x_l_Y(pvar x, inequality const& i, pdd& y); + + // c := Y ~ x + bool is_Y_l_x(pvar x, inequality const& i, pdd& y); + + // c := X*y ~ X*Z + bool is_Xy_l_XZ(pvar y, inequality const& c, pdd& x, pdd& z); + bool verify_Xy_l_XZ(pvar y, inequality const& c, pdd const& x, pdd const& z); + + // c := Y ~ Ax + bool is_Y_l_Ax(pvar x, inequality const& c, pdd& a, pdd& y); + bool verify_Y_l_Ax(pvar x, inequality const& c, pdd const& a, pdd const& y); + + // c := Ax ~ Y + bool is_Ax_l_Y(pvar x, inequality const& c, pdd& a, pdd& y); + bool verify_Ax_l_Y(pvar x, inequality const& c, pdd const& a, pdd const& y); + + // c := Ax + B ~ Y + bool is_AxB_l_Y(pvar x, inequality const& c, pdd& a, pdd& b, pdd& y); + bool verify_AxB_l_Y(pvar x, inequality const& c, pdd const& a, pdd const& b, pdd const& y); + + // c := Y ~ Ax + B + bool is_Y_l_AxB(pvar x, inequality const& c, pdd& y, pdd& a, pdd& b); + bool verify_Y_l_AxB(pvar x, inequality const& c, pdd const& y, pdd const& a, pdd& b); + + // c := Ax + B ~ Y, val(Y) = 0 + bool is_AxB_eq_0(pvar x, inequality const& c, pdd& a, pdd& b, pdd& y); + bool verify_AxB_eq_0(pvar x, inequality const& c, pdd const& a, pdd const& b, pdd const& y); + + // c := Ax + B != Y, val(Y) = 0 + bool is_AxB_diseq_0(pvar x, inequality const& c, pdd& a, pdd& b, pdd& y); + + // c := Y*X ~ z*X + bool is_YX_l_zX(pvar z, inequality const& c, pdd& x, pdd& y); + bool verify_YX_l_zX(pvar z, inequality const& c, pdd const& x, pdd const& y); + + // c := xY <= xZ + bool is_xY_l_xZ(pvar x, inequality const& c, pdd& y, pdd& z); + + // xy := x * Y + bool is_xY(pvar x, pdd const& xy, pdd& y); + + // a * b does not overflow + bool is_non_overflow(pdd const& a, pdd const& b); + + // p := coeff*x*y where coeff_x = coeff*x, x a variable + bool is_coeffxY(pdd const& coeff_x, pdd const& p, pdd& y); + + bool is_add_overflow(pvar x, inequality const& i, pdd& y, bool& is_minus); + + bool has_upper_bound(pvar x, conflict& core, rational& bound, vector& x_ge_bound); + + bool has_lower_bound(pvar x, conflict& core, rational& bound, vector& x_le_bound); + + // inequality i implies x != 0 + bool is_nonzero_by(pvar x, inequality const& i); + + // determine min/max parity of polynomial + unsigned min_parity(pdd const& p, vector& explain); + unsigned max_parity(pdd const& p, vector& explain); + unsigned min_parity(pdd const& p) { vector ex; return min_parity(p, ex); } + unsigned max_parity(pdd const& p) { vector ex; return max_parity(p, ex); } + + lbool get_multiple(const pdd& p1, const pdd& p2, pdd& out); + + bool is_forced_eq(pdd const& p, rational const& val); + bool is_forced_eq(pdd const& p, int i) { return is_forced_eq(p, rational(i)); } + + bool is_forced_diseq(pdd const& p, rational const& val, signed_constraint& c); + bool is_forced_diseq(pdd const& p, int i, signed_constraint& c) { return is_forced_diseq(p, rational(i), c); } + + bool is_forced_odd(pdd const& p, signed_constraint& c); + + bool is_forced_false(signed_constraint const& sc); + + bool is_forced_true(signed_constraint const& sc); + + bool try_inequality(pvar v, inequality const& i); + + bool try_umul_ovfl(pvar v, signed_constraint c); + bool try_umul_ovfl_noovfl(pvar v, signed_constraint c); + bool try_umul_noovfl_lo(pvar v, signed_constraint c); + bool try_umul_noovfl_bounds(pvar v, signed_constraint c); + bool try_umul_ovfl_bounds(pvar v, signed_constraint c); + + bool try_op(pvar v, signed_constraint c); +#endif + + public: + saturation(core& c); + void perform(pvar v); + bool perform(pvar v, signed_constraint sc); + }; +} From 236ec01b78fb37fd20c71bcadd1809b18025317b Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 14:17:19 -0800 Subject: [PATCH 53/89] disable from python build Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/{saturation.cpp => saturation.cpp.disabled} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/sat/smt/polysat/{saturation.cpp => saturation.cpp.disabled} (100%) diff --git a/src/sat/smt/polysat/saturation.cpp b/src/sat/smt/polysat/saturation.cpp.disabled similarity index 100% rename from src/sat/smt/polysat/saturation.cpp rename to src/sat/smt/polysat/saturation.cpp.disabled From e5375c4071b2791f9a87ea30b14aff4c26517a54 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 17:19:16 -0800 Subject: [PATCH 54/89] fuzz fixes to semantics Signed-off-by: Nikolaj Bjorner --- src/sat/smt/intblast_solver.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 9960197fb..869e388ff 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -146,6 +146,7 @@ namespace intblast { auto a = expr2literal(e); auto b = mk_literal(r); ctx.mark_relevant(b); +// verbose_stream() << "add-predicate-axiom: " << mk_pp(e, m) << " == " << r << "\n"; add_equiv(a, b); } return true; @@ -625,7 +626,7 @@ namespace intblast { case OP_BUDIV: case OP_BUDIV_I: { expr* x = arg(0), * y = umod(e, 1); - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(1), a.mk_idiv(x, y)); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), a.mk_int(-1), a.mk_idiv(x, y)); break; } case OP_BUMUL_NO_OVFL: { @@ -727,10 +728,9 @@ namespace intblast { r = m.mk_ite(m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(1), a.mk_int(0)); break; case OP_BSMOD_I: - case OP_BSMOD: { - bv_expr = e; - expr* x = umod(bv_expr, 0), *y = umod(bv_expr, 1); - rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); + case OP_BSMOD: { + expr* x = umod(e, 0), *y = umod(e, 1); + rational N = bv_size(e); expr* signx = a.mk_ge(x, a.mk_int(N/2)); expr* signy = a.mk_ge(y, a.mk_int(N/2)); expr* u = a.mk_mod(x, y); @@ -750,22 +750,23 @@ namespace intblast { } case OP_BSDIV_I: case OP_BSDIV: { - // d = udiv(x mod N, y mod N) + // d = udiv(abs(x), abs(y)) // y = 0, x > 0 -> 1 // y = 0, x <= 0 -> -1 // x = 0, y != 0 -> 0 - // x < 0, y < 0 -> -d + // x > 0, y < 0 -> -d // x < 0, y > 0 -> -d // x > 0, y > 0 -> d // x < 0, y < 0 -> d - bv_expr = e; - expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 1); - rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); + expr* x = umod(e, 0), * y = umod(e, 1); + rational N = bv_size(e); expr* signx = a.mk_ge(x, a.mk_int(N / 2)); expr* signy = a.mk_ge(y, a.mk_int(N / 2)); + x = m.mk_ite(signx, a.mk_sub(a.mk_int(N), x), x); + y = m.mk_ite(signy, a.mk_sub(a.mk_int(N), y), y); expr* d = a.mk_idiv(x, y); r = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); - r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), m.mk_ite(signx, a.mk_int(-1), a.mk_int(1)), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), m.mk_ite(signx, a.mk_int(1), a.mk_int(-1)), r); break; } case OP_BSREM_I: From f91655ce153777c57878f664c8004928961ac549 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 7 Dec 2023 13:34:35 -0800 Subject: [PATCH 55/89] fix divergence reported by Guido Martinez --- src/math/lp/emonics.cpp | 15 +++++++++++++++ src/math/lp/emonics.h | 1 + src/math/lp/monic.h | 3 +++ src/math/lp/nla_core.cpp | 3 +++ 4 files changed, 22 insertions(+) diff --git a/src/math/lp/emonics.cpp b/src/math/lp/emonics.cpp index 9a9e4566b..5d0e664f2 100644 --- a/src/math/lp/emonics.cpp +++ b/src/math/lp/emonics.cpp @@ -611,4 +611,19 @@ void emonics::set_propagated(monic const& m) { m_u_f_stack.push(set_unpropagated(*this, m.var())); } +void emonics::set_bound_propagated(monic const& m) { + struct set_bound_unpropagated : public trail { + emonics& em; + unsigned var; + public: + set_bound_unpropagated(emonics& em, unsigned var): em(em), var(var) {} + void undo() override { + em[var].set_bound_propagated(false); + } + }; + SASSERT(!m.is_bound_propagated()); + (*this)[m.var()].set_bound_propagated(true); + m_u_f_stack.push(set_bound_unpropagated(*this, m.var())); +} + } diff --git a/src/math/lp/emonics.h b/src/math/lp/emonics.h index fe0b19117..55086515d 100644 --- a/src/math/lp/emonics.h +++ b/src/math/lp/emonics.h @@ -143,6 +143,7 @@ public: void after_merge_eh(unsigned r2, unsigned r1, unsigned v2, unsigned v1) {} void set_propagated(monic const& m); + void set_bound_propagated(monic const& m); // this method is required by union_find trail_stack & get_trail_stack() { return m_u_f_stack; } diff --git a/src/math/lp/monic.h b/src/math/lp/monic.h index d981b2042..19137cd31 100644 --- a/src/math/lp/monic.h +++ b/src/math/lp/monic.h @@ -59,6 +59,7 @@ class monic: public mon_eq { bool m_rsign; mutable unsigned m_visited; bool m_propagated = false; + bool m_bound_propagated = false; public: // constructors monic(lpvar v, unsigned sz, lpvar const* vs, unsigned idx): @@ -77,6 +78,8 @@ public: void sort_rvars() { std::sort(m_rvars.begin(), m_rvars.end()); } void set_propagated(bool p) { m_propagated = p; } bool is_propagated() const { return m_propagated; } + void set_bound_propagated(bool p) { m_bound_propagated = p; } + bool is_bound_propagated() const { return m_bound_propagated; } svector::const_iterator begin() const { return vars().begin(); } svector::const_iterator end() const { return vars().end(); } diff --git a/src/math/lp/nla_core.cpp b/src/math/lp/nla_core.cpp index 96f1b4a30..f36fab52e 100644 --- a/src/math/lp/nla_core.cpp +++ b/src/math/lp/nla_core.cpp @@ -1517,6 +1517,9 @@ void core::add_bounds() { for (lpvar j : m.vars()) { if (!var_is_free(j)) continue; + if (m.is_bound_propagated()) + continue; + m_emons.set_bound_propagated(m); // split the free variable (j <= 0, or j > 0), and return m_literals.push_back(ineq(j, lp::lconstraint_kind::EQ, rational::zero())); TRACE("nla_solver", print_ineq(m_literals.back(), tout) << "\n"); From 96f84c6b4478f83900edd728e39640f0255c28a2 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 11:39:55 -0800 Subject: [PATCH 56/89] kludge to fixup osver in python for Mac Signed-off-by: Nikolaj Bjorner --- src/api/python/setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/api/python/setup.py b/src/api/python/setup.py index a20ff53a3..54992156f 100644 --- a/src/api/python/setup.py +++ b/src/api/python/setup.py @@ -313,6 +313,8 @@ if 'bdist_wheel' in sys.argv and '--plat-name' not in sys.argv: osver = RELEASE_METADATA[3] if osver.count('.') > 1: osver = '.'.join(osver.split('.')[:2]) + if osver.startswith("11"): + osver = "11_0" if arch == 'x64': plat_name ='macosx_%s_x86_64' % osver.replace('.', '_') elif arch == 'arm64': From dc83c5b28d6925ae0f4e0c26e48ec4acff8213b3 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 13:05:21 -0800 Subject: [PATCH 57/89] fix #7049 Signed-off-by: Nikolaj Bjorner --- src/ast/polymorphism_inst.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ast/polymorphism_inst.cpp b/src/ast/polymorphism_inst.cpp index 4da83ee10..aa9b1e5fe 100644 --- a/src/ast/polymorphism_inst.cpp +++ b/src/ast/polymorphism_inst.cpp @@ -94,7 +94,10 @@ namespace polymorphism { t.push(value_trail(m_decl_qhead)); for (; m_decl_qhead < num_decls; ++m_decl_qhead) { func_decl* p = m_decl_queue.get(m_decl_qhead); - for (expr* e : m_occurs[m.poly_root(p)]) + func_decl* r = m.poly_root(p); + if (!m_occurs.contains(r)) + continue; + for (expr* e : m_occurs[r]) instantiate(p, e, instances); } } From 2323a5f9d2009fd1ea0bc8b6d18137022452bb32 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 13:12:05 -0800 Subject: [PATCH 58/89] try fix suggested in #7041 Signed-off-by: Nikolaj Bjorner --- src/api/python/pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/api/python/pyproject.toml b/src/api/python/pyproject.toml index a9f2676a7..aa4c50adf 100644 --- a/src/api/python/pyproject.toml +++ b/src/api/python/pyproject.toml @@ -1,3 +1,6 @@ [build-system] requires = ["setuptools>=46.4.0", "wheel", "cmake"] build-backend = "setuptools.build_meta" + +[project] +dependencies = ["importlib-resources" ] \ No newline at end of file From 7b145f36bd8fdfb0324e61c3381d57d443f53c1a Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 14:57:45 -0800 Subject: [PATCH 59/89] try add name to project Signed-off-by: Nikolaj Bjorner --- src/api/python/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/api/python/pyproject.toml b/src/api/python/pyproject.toml index aa4c50adf..4c199fb26 100644 --- a/src/api/python/pyproject.toml +++ b/src/api/python/pyproject.toml @@ -3,4 +3,5 @@ requires = ["setuptools>=46.4.0", "wheel", "cmake"] build-backend = "setuptools.build_meta" [project] -dependencies = ["importlib-resources" ] \ No newline at end of file +name = "z3-solver" +dependencies = ["importlib-resources", ] \ No newline at end of file From dd271563d361df1376c9d436e487c73cae982485 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 15:50:09 -0800 Subject: [PATCH 60/89] add version Signed-off-by: Nikolaj Bjorner --- src/api/python/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/src/api/python/pyproject.toml b/src/api/python/pyproject.toml index 4c199fb26..2f441d000 100644 --- a/src/api/python/pyproject.toml +++ b/src/api/python/pyproject.toml @@ -4,4 +4,5 @@ build-backend = "setuptools.build_meta" [project] name = "z3-solver" +version = "4" dependencies = ["importlib-resources", ] \ No newline at end of file From 165d81cac4ae590a6d3df0eb996dd789701b97ee Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 18:38:29 -0800 Subject: [PATCH 61/89] follow error message to put dependencies in setup args Signed-off-by: Nikolaj Bjorner --- src/api/python/pyproject.toml | 5 ----- src/api/python/setup.py | 1 + 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/api/python/pyproject.toml b/src/api/python/pyproject.toml index 2f441d000..a9f2676a7 100644 --- a/src/api/python/pyproject.toml +++ b/src/api/python/pyproject.toml @@ -1,8 +1,3 @@ [build-system] requires = ["setuptools>=46.4.0", "wheel", "cmake"] build-backend = "setuptools.build_meta" - -[project] -name = "z3-solver" -version = "4" -dependencies = ["importlib-resources", ] \ No newline at end of file diff --git a/src/api/python/setup.py b/src/api/python/setup.py index 54992156f..325fb4230 100644 --- a/src/api/python/setup.py +++ b/src/api/python/setup.py @@ -341,6 +341,7 @@ setup( license='MIT License', keywords=['z3', 'smt', 'sat', 'prover', 'theorem'], packages=['z3'], + install_requires = ['importlib-resources'], include_package_data=True, package_data={ 'z3': [os.path.join('lib', '*'), os.path.join('include', '*.h'), os.path.join('include', 'c++', '*.h')] From c33859d7290ba375145f0b2f93c68cd6ba087511 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 10:22:35 -0800 Subject: [PATCH 62/89] try adding readme again Signed-off-by: Nikolaj Bjorner --- scripts/mk_nuget_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/mk_nuget_task.py b/scripts/mk_nuget_task.py index ef41051d8..8e67bd65e 100644 --- a/scripts/mk_nuget_task.py +++ b/scripts/mk_nuget_task.py @@ -149,7 +149,7 @@ class Env: unpack(self.packages, self.symbols, self.arch) mk_targets(self.source_root) mk_icon(self.source_root) -# mk_readme(self.source_root) + mk_readme(self.source_root) create_nuget_spec(self.version, self.repo, self.branch, self.commit, self.symbols, self.arch) def main(): From 14935529b80c7a1ed1ee8e6ede2b17f8600e1392 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 11:11:37 -0800 Subject: [PATCH 63/89] add readme under content Signed-off-by: Nikolaj Bjorner --- scripts/mk_nuget_task.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/mk_nuget_task.py b/scripts/mk_nuget_task.py index 8e67bd65e..fe06bdba2 100644 --- a/scripts/mk_nuget_task.py +++ b/scripts/mk_nuget_task.py @@ -91,7 +91,7 @@ def mk_icon(source_root): def mk_readme(source_root): mk_dir("out/content") - shutil.copy(f"{source_root}/src/api/dotnet/README.md", "out/README.md") + shutil.copy(f"{source_root}/src/api/dotnet/README.md", "out/content/README.md") @@ -112,6 +112,7 @@ Linux Dependencies: © Microsoft Corporation. All rights reserved. smt constraint solver theorem prover content/icon.jpg + content/README>md https://github.com/Z3Prover/z3 MIT @@ -121,6 +122,9 @@ Linux Dependencies: + + + """.format(version, repo, branch, commit, arch) print(contents) sym = "sym." if symbols else "" From 46baa449b3746c226cf11124e41f8db62b1a8733 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 12:33:29 -0800 Subject: [PATCH 64/89] nuget spec: does this work? Signed-off-by: Nikolaj Bjorner --- scripts/mk_nuget_task.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/scripts/mk_nuget_task.py b/scripts/mk_nuget_task.py index fe06bdba2..5e6d3e039 100644 --- a/scripts/mk_nuget_task.py +++ b/scripts/mk_nuget_task.py @@ -88,9 +88,6 @@ def mk_targets(source_root): def mk_icon(source_root): mk_dir("out/content") shutil.copy(f"{source_root}/resources/icon.jpg", "out/content/icon.jpg") - -def mk_readme(source_root): - mk_dir("out/content") shutil.copy(f"{source_root}/src/api/dotnet/README.md", "out/content/README.md") @@ -124,6 +121,7 @@ Linux Dependencies: + """.format(version, repo, branch, commit, arch) print(contents) @@ -153,7 +151,6 @@ class Env: unpack(self.packages, self.symbols, self.arch) mk_targets(self.source_root) mk_icon(self.source_root) - mk_readme(self.source_root) create_nuget_spec(self.version, self.repo, self.branch, self.commit, self.symbols, self.arch) def main(): From b76eabb5871827418595f0dcbcd268170ba7b572 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 10 Dec 2023 20:50:32 -0800 Subject: [PATCH 65/89] fix character Signed-off-by: Nikolaj Bjorner --- scripts/mk_nuget_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/mk_nuget_task.py b/scripts/mk_nuget_task.py index 5e6d3e039..d3f402773 100644 --- a/scripts/mk_nuget_task.py +++ b/scripts/mk_nuget_task.py @@ -109,7 +109,7 @@ Linux Dependencies: © Microsoft Corporation. All rights reserved. smt constraint solver theorem prover content/icon.jpg - content/README>md + content/README.md https://github.com/Z3Prover/z3 MIT From a614ac7d95e21b432f54a44b411cda03260946cc Mon Sep 17 00:00:00 2001 From: Bruce Mitchener Date: Thu, 14 Dec 2023 00:36:41 +0700 Subject: [PATCH 66/89] tptr.h: Include `` once rather than twice. (#7051) --- src/util/tptr.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/util/tptr.h b/src/util/tptr.h index 50e9417fe..37b6f64fe 100644 --- a/src/util/tptr.h +++ b/src/util/tptr.h @@ -21,7 +21,6 @@ Revision History: #include #include "util/machine.h" -#include #define TAG_SHIFT PTR_ALIGNMENT #define ALIGNMENT_VALUE (1 << PTR_ALIGNMENT) From 1a39def7a1002d3bdd6656973dd5d6722e6de17f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 7 Dec 2023 15:53:07 -0800 Subject: [PATCH 67/89] v2 of polysat --- src/sat/smt/polysat_assignment.cpp | 119 +++++++ src/sat/smt/polysat_assignment.h | 120 +++++++ src/sat/smt/polysat_constraints.cpp | 25 ++ src/sat/smt/polysat_constraints.h | 128 +++++++ src/sat/smt/polysat_core.cpp | 276 +++++++++++++++ src/sat/smt/polysat_core.h | 128 +++++++ src/sat/smt/polysat_internalize.cpp | 526 ---------------------------- src/sat/smt/polysat_substitution.h | 212 +++++++++++ src/sat/smt/polysat_types.h | 45 +++ src/sat/smt/polysat_viable.h | 55 +++ 10 files changed, 1108 insertions(+), 526 deletions(-) create mode 100644 src/sat/smt/polysat_assignment.cpp create mode 100644 src/sat/smt/polysat_assignment.h create mode 100644 src/sat/smt/polysat_constraints.cpp create mode 100644 src/sat/smt/polysat_constraints.h create mode 100644 src/sat/smt/polysat_core.cpp create mode 100644 src/sat/smt/polysat_core.h delete mode 100644 src/sat/smt/polysat_internalize.cpp create mode 100644 src/sat/smt/polysat_substitution.h create mode 100644 src/sat/smt/polysat_types.h create mode 100644 src/sat/smt/polysat_viable.h diff --git a/src/sat/smt/polysat_assignment.cpp b/src/sat/smt/polysat_assignment.cpp new file mode 100644 index 000000000..a985188fa --- /dev/null +++ b/src/sat/smt/polysat_assignment.cpp @@ -0,0 +1,119 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat substitution and assignment + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ + +#include "sat/smt/polysat_assignment.h" +#include "sat/smt/polysat_core.h" + +namespace polysat { + + substitution::substitution(pdd p) + : m_subst(std::move(p)) { } + + substitution::substitution(dd::pdd_manager& m) + : m_subst(m.one()) { } + + substitution substitution::add(pvar var, rational const& value) const { + return {m_subst.subst_add(var, value)}; + } + + pdd substitution::apply_to(pdd const& p) const { + return p.subst_val(m_subst); + } + + bool substitution::contains(pvar var) const { + rational out_value; + return value(var, out_value); + } + + bool substitution::value(pvar var, rational& out_value) const { + return m_subst.subst_get(var, out_value); + } + + assignment::assignment(core& s) + : m_core(s) { } + + + assignment assignment::clone() const { + assignment a(s()); + a.m_pairs = m_pairs; + a.m_subst.reserve(m_subst.size()); + for (unsigned i = m_subst.size(); i-- > 0; ) + if (m_subst[i]) + a.m_subst.set(i, alloc(substitution, *m_subst[i])); + a.m_subst_trail = m_subst_trail; + return a; + } + + bool assignment::contains(pvar var) const { + return subst(s().size(var)).contains(var); + } + + bool assignment::value(pvar var, rational& out_value) const { + return subst(s().size(var)).value(var, out_value); + } + + substitution& assignment::subst(unsigned sz) { + return const_cast(std::as_const(*this).subst(sz)); + } + + substitution const& assignment::subst(unsigned sz) const { + m_subst.reserve(sz + 1); + if (!m_subst[sz]) + m_subst.set(sz, alloc(substitution, s().sz2pdd(sz))); + return *m_subst[sz]; + } + + void assignment::push(pvar var, rational const& value) { + SASSERT(all_of(m_pairs, [var](assignment_item_t const& item) { return item.first != var; })); + m_pairs.push_back({var, value}); + unsigned const sz = s().size(var); + substitution& sub = subst(sz); + m_subst_trail.push_back(sub); + sub = sub.add(var, value); + SASSERT_EQ(sub, *m_subst[sz]); + } + + void assignment::pop() { + substitution& sub = m_subst_trail.back(); + unsigned sz = sub.bit_width(); + SASSERT_EQ(sz, s().size(m_pairs.back().first)); + *m_subst[sz] = sub; + m_subst_trail.pop_back(); + m_pairs.pop_back(); + } + + pdd assignment::apply_to(pdd const& p) const { + unsigned const sz = p.power_of_2(); + return subst(sz).apply_to(p); + } + + std::ostream& substitution::display(std::ostream& out) const { + char const* delim = ""; + pdd p = m_subst; + while (!p.is_val()) { + SASSERT(p.lo().is_val()); + out << delim << "v" << p.var() << " := " << p.lo(); + delim = " "; + p = p.hi(); + } + return out; + } + + std::ostream& assignment::display(std::ostream& out) const { + char const* delim = ""; + for (auto const& [var, value] : m_pairs) + out << delim << var << " == " << value, delim = " "; + return out; + } +} diff --git a/src/sat/smt/polysat_assignment.h b/src/sat/smt/polysat_assignment.h new file mode 100644 index 000000000..daff03dd5 --- /dev/null +++ b/src/sat/smt/polysat_assignment.h @@ -0,0 +1,120 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat substitution and assignment + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once +#include "util/scoped_ptr_vector.h" +#include "sat/smt/polysat_types.h" + +namespace polysat { + + class core; + + using assignment_item_t = std::pair; + + class substitution_iterator { + pdd m_current; + substitution_iterator(pdd current) : m_current(std::move(current)) {} + friend class substitution; + + public: + using value_type = assignment_item_t; + using difference_type = std::ptrdiff_t; + using pointer = value_type const*; + using reference = value_type const&; + using iterator_category = std::input_iterator_tag; + + substitution_iterator& operator++() { + SASSERT(!m_current.is_val()); + m_current = m_current.hi(); + return *this; + } + + value_type operator*() const { + SASSERT(!m_current.is_val()); + return { m_current.var(), m_current.lo().val() }; + } + + bool operator==(substitution_iterator const& other) const { return m_current == other.m_current; } + bool operator!=(substitution_iterator const& other) const { return !operator==(other); } + }; + + /** Substitution for a single bit width. */ + class substitution { + pdd m_subst; + + substitution(pdd p); + + public: + substitution(dd::pdd_manager& m); + [[nodiscard]] substitution add(pvar var, rational const& value) const; + [[nodiscard]] pdd apply_to(pdd const& p) const; + + [[nodiscard]] bool contains(pvar var) const; + [[nodiscard]] bool value(pvar var, rational& out_value) const; + + [[nodiscard]] bool empty() const { return m_subst.is_one(); } + + pdd const& to_pdd() const { return m_subst; } + unsigned bit_width() const { return to_pdd().power_of_2(); } + + bool operator==(substitution const& other) const { return &m_subst.manager() == &other.m_subst.manager() && m_subst == other.m_subst; } + bool operator!=(substitution const& other) const { return !operator==(other); } + + std::ostream& display(std::ostream& out) const; + + using const_iterator = substitution_iterator; + const_iterator begin() const { return {m_subst}; } + const_iterator end() const { return {m_subst.manager().one()}; } + }; + + /** Full variable assignment, may include variables of varying bit widths. */ + class assignment { + core& m_core; + vector m_pairs; + mutable scoped_ptr_vector m_subst; + vector m_subst_trail; + + substitution& subst(unsigned sz); + core& s() const { return m_core; } + public: + assignment(core& s); + // prevent implicit copy, use clone() if you do need a copy + assignment(assignment const&) = delete; + assignment& operator=(assignment const&) = delete; + assignment(assignment&&) = default; + assignment& operator=(assignment&&) = default; + assignment clone() const; + + void push(pvar var, rational const& value); + void pop(); + + pdd apply_to(pdd const& p) const; + + bool contains(pvar var) const; + bool value(pvar var, rational& out_value) const; + rational value(pvar var) const { rational val; VERIFY(value(var, val)); return val; } + bool empty() const { return pairs().empty(); } + substitution const& subst(unsigned sz) const; + vector const& pairs() const { return m_pairs; } + using const_iterator = decltype(m_pairs)::const_iterator; + const_iterator begin() const { return pairs().begin(); } + const_iterator end() const { return pairs().end(); } + + std::ostream& display(std::ostream& out) const; + }; + + inline std::ostream& operator<<(std::ostream& out, substitution const& sub) { return sub.display(out); } + + inline std::ostream& operator<<(std::ostream& out, assignment const& a) { return a.display(out); } +} + diff --git a/src/sat/smt/polysat_constraints.cpp b/src/sat/smt/polysat_constraints.cpp new file mode 100644 index 000000000..1c9de327c --- /dev/null +++ b/src/sat/smt/polysat_constraints.cpp @@ -0,0 +1,25 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat constraints + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#include "sat/smt/polysat_core.h" +#include "sat/smt/polysat_solver.h" +#include "sat/smt/polysat_constraints.h" + +namespace polysat { + + signed_constraint constraints::ule(pdd const& p, pdd const& q) { + auto* c = alloc(ule_constraint, p, q); + m_trail.push(new_obj_trail(c)); + return signed_constraint(ckind_t::ule_t, c); + } +} diff --git a/src/sat/smt/polysat_constraints.h b/src/sat/smt/polysat_constraints.h new file mode 100644 index 000000000..24c7f9a11 --- /dev/null +++ b/src/sat/smt/polysat_constraints.h @@ -0,0 +1,128 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat constraints + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ + + +#pragma once +#include "sat/smt/polysat_types.h" + +namespace polysat { + + class core; + + using pdd = dd::pdd; + using pvar = unsigned; + + enum ckind_t { ule_t, umul_ovfl_t, smul_fl_t, op_t }; + + class constraint { + unsigned_vector m_vars; + public: + virtual ~constraint() {} + unsigned_vector& vars() { return m_vars; } + unsigned_vector const& vars() const { return m_vars; } + unsigned var(unsigned idx) const { return m_vars[idx]; } + bool contains_var(pvar v) const { return m_vars.contains(v); } + }; + + class ule_constraint : public constraint { + pdd m_lhs, m_rhs; + public: + ule_constraint(pdd const& lhs, pdd const& rhs) : m_lhs(lhs), m_rhs(rhs) {} + }; + + class signed_constraint { + bool m_sign = false; + ckind_t m_op = ule_t; + constraint* m_constraint = nullptr; + public: + signed_constraint() {} + signed_constraint(ckind_t c, constraint* p) : m_op(c), m_constraint(p) {} + signed_constraint operator~() const { signed_constraint r(*this); r.m_sign = !r.m_sign; return r; } + bool sign() const { return m_sign; } + unsigned_vector& vars() { return m_constraint->vars(); } + unsigned_vector const& vars() const { return m_constraint->vars(); } + unsigned var(unsigned idx) const { return m_constraint->var(idx); } + bool contains_var(pvar v) const { return m_constraint->contains_var(v); } + bool is_ule() const { return m_op == ule_t; } + ule_constraint& to_ule() { return *reinterpret_cast(m_constraint); } + bool is_eq(pvar& v, rational& val) { throw default_exception("nyi"); } + }; + + using dependent_constraint = std::pair; + + class constraints { + trail_stack& m_trail; + public: + constraints(trail_stack& c) : m_trail(c) {} + + signed_constraint eq(pdd const& p) { return ule(p, p.manager().mk_val(0)); } + signed_constraint eq(pdd const& p, rational const& v) { return eq(p - p.manager().mk_val(v)); } + signed_constraint ule(pdd const& p, pdd const& q); + signed_constraint sle(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint ult(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint slt(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint umul_ovfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint smul_ovfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint smul_udfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } + signed_constraint bit(pdd const& p, unsigned i) { throw default_exception("nyi"); } + + signed_constraint diseq(pdd const& p) { return ~eq(p); } + signed_constraint diseq(pdd const& p, pdd const& q) { return diseq(p - q); } + signed_constraint diseq(pdd const& p, rational const& q) { return diseq(p - q); } + signed_constraint diseq(pdd const& p, int q) { return diseq(p, rational(q)); } + signed_constraint diseq(pdd const& p, unsigned q) { return diseq(p, rational(q)); } + + signed_constraint ule(pdd const& p, rational const& q) { return ule(p, p.manager().mk_val(q)); } + signed_constraint ule(rational const& p, pdd const& q) { return ule(q.manager().mk_val(p), q); } + signed_constraint ule(pdd const& p, int q) { return ule(p, rational(q)); } + signed_constraint ule(pdd const& p, unsigned q) { return ule(p, rational(q)); } + signed_constraint ule(int p, pdd const& q) { return ule(rational(p), q); } + signed_constraint ule(unsigned p, pdd const& q) { return ule(rational(p), q); } + + signed_constraint uge(pdd const& p, pdd const& q) { return ule(q, p); } + signed_constraint uge(pdd const& p, rational const& q) { return ule(q, p); } + + signed_constraint ult(pdd const& p, rational const& q) { return ult(p, p.manager().mk_val(q)); } + signed_constraint ult(rational const& p, pdd const& q) { return ult(q.manager().mk_val(p), q); } + signed_constraint ult(int p, pdd const& q) { return ult(rational(p), q); } + signed_constraint ult(unsigned p, pdd const& q) { return ult(rational(p), q); } + signed_constraint ult(pdd const& p, int q) { return ult(p, rational(q)); } + signed_constraint ult(pdd const& p, unsigned q) { return ult(p, rational(q)); } + + signed_constraint slt(pdd const& p, rational const& q) { return slt(p, p.manager().mk_val(q)); } + signed_constraint slt(rational const& p, pdd const& q) { return slt(q.manager().mk_val(p), q); } + signed_constraint slt(pdd const& p, int q) { return slt(p, rational(q)); } + signed_constraint slt(pdd const& p, unsigned q) { return slt(p, rational(q)); } + signed_constraint slt(int p, pdd const& q) { return slt(rational(p), q); } + signed_constraint slt(unsigned p, pdd const& q) { return slt(rational(p), q); } + + + signed_constraint sgt(pdd const& p, pdd const& q) { return slt(q, p); } + signed_constraint sgt(pdd const& p, int q) { return slt(q, p); } + signed_constraint sgt(pdd const& p, unsigned q) { return slt(q, p); } + signed_constraint sgt(int p, pdd const& q) { return slt(q, p); } + signed_constraint sgt(unsigned p, pdd const& q) { return slt(q, p); } + + signed_constraint umul_ovfl(pdd const& p, rational const& q) { return umul_ovfl(p, p.manager().mk_val(q)); } + signed_constraint umul_ovfl(rational const& p, pdd const& q) { return umul_ovfl(q.manager().mk_val(p), q); } + signed_constraint umul_ovfl(pdd const& p, int q) { return umul_ovfl(p, rational(q)); } + signed_constraint umul_ovfl(pdd const& p, unsigned q) { return umul_ovfl(p, rational(q)); } + signed_constraint umul_ovfl(int p, pdd const& q) { return umul_ovfl(rational(p), q); } + signed_constraint umul_ovfl(unsigned p, pdd const& q) { return umul_ovfl(rational(p), q); } + + + //signed_constraint even(pdd const& p) { return parity_at_least(p, 1); } + //signed_constraint odd(pdd const& p) { return ~even(p); } + }; +} \ No newline at end of file diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp new file mode 100644 index 000000000..27d6ee731 --- /dev/null +++ b/src/sat/smt/polysat_core.cpp @@ -0,0 +1,276 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + polysat_core.cpp + +Abstract: + + PolySAT core functionality + +Author: + + Nikolaj Bjorner (nbjorner) 2022-01-26 + Jakob Rath 2021-04-06 + +Notes: + +polysat::solver +- adds assignments +- calls propagation and check + +polysat::core +- propagates literals +- crates case splits by value assignment (equalities) +- detects conflicts based on Literal assignmets +- adds lemmas based on projections + +--*/ + +#include "params/bv_rewriter_params.hpp" +#include "sat/smt/polysat_solver.h" +#include "sat/smt/euf_solver.h" + +namespace polysat { + + class core::mk_assign_var : public trail { + pvar m_var; + core& c; + public: + mk_assign_var(pvar v, core& c) : m_var(v), c(c) {} + void undo() { + c.m_justification[m_var] = nullptr; + c.m_assignment.pop(); + } + }; + + class core::mk_dqueue_var : public trail { + pvar m_var; + core& c; + public: + mk_dqueue_var(pvar v, core& c) : m_var(v), c(c) {} + void undo() { + c.m_var_queue.unassign_var_eh(m_var); + } + }; + + class core::mk_add_var : public trail { + core& c; + public: + mk_add_var(core& c) : c(c) {} + void undo() override { + c.del_var(); + } + }; + + class core::mk_add_watch : public trail { + core& c; + unsigned m_idx; + public: + mk_add_watch(core& c, unsigned idx) : c(c), m_idx(idx) {} + void undo() override { + auto& sc = c.m_prop_queue[m_idx].first; + auto& vars = sc.vars(); + if (vars.size() > 0) + c.m_watch[vars[0]].pop_back(); + if (vars.size() > 1) + c.m_watch[vars[1]].pop_back(); + } + }; + + core::core(solver& s) : + s(s), + m_viable(*this), + m_constraints(s.get_trail_stack()), + m_assignment(*this), + m_dep(s.get_region()), + m_var_queue(m_activity) + {} + + pdd core::value(rational const& v, unsigned sz) { + return sz2pdd(sz).mk_val(v); + } + + dd::pdd_manager& core::sz2pdd(unsigned sz) const { + m_pdd.reserve(sz + 1); + if (!m_pdd[sz]) + m_pdd.set(sz, alloc(dd::pdd_manager, 1000, dd::pdd_manager::semantics::mod2N_e, sz)); + return *m_pdd[sz]; + } + + dd::pdd_manager& core::var2pdd(pvar v) const { + return sz2pdd(size(v)); + } + + pvar core::add_var(unsigned sz) { + unsigned v = m_vars.size(); + m_vars.push_back(sz2pdd(sz).mk_var(v)); + m_activity.push_back({ sz, 0 }); + m_justification.push_back(nullptr); + m_watch.push_back({}); + m_var_queue.mk_var_eh(v); + s.ctx.push(mk_add_var(*this)); + return v; + } + + void core::del_var() { + unsigned v = m_vars.size() - 1; + m_vars.pop_back(); + m_activity.pop_back(); + m_justification.pop_back(); + m_watch.pop_back(); + m_var_queue.del_var_eh(v); + } + + // case split on unassigned variables until all are assigned values. + // create equality literal for unassigned variable. + // return new_eq if there is a new literal. + + sat::check_result core::check() { + if (m_var_queue.empty()) + return sat::check_result::CR_DONE; + m_var = m_var_queue.next_var(); + s.ctx.push(mk_dqueue_var(m_var, *this)); + switch (m_viable.find_viable(m_var, m_value)) { + case find_t::empty: + m_unsat_core = m_viable.explain(); + propagate_unsat_core(); + return sat::check_result::CR_CONTINUE; + case find_t::singleton: + s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); + return sat::check_result::CR_CONTINUE; + case find_t::multiple: + return sat::check_result::CR_CONTINUE; + case find_t::resource_out: + return sat::check_result::CR_GIVEUP; + } + UNREACHABLE(); + return sat::check_result::CR_GIVEUP; + } + + // First propagate Boolean assignment, then propagate value assignment + bool core::propagate() { + if (m_qhead == m_prop_queue.size() && m_vqhead == m_prop_queue.size()) + return false; + s.ctx.push(value_trail(m_qhead)); + for (; m_qhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_qhead) + propagate_constraint(m_qhead, m_prop_queue[m_qhead]); + s.ctx.push(value_trail(m_vqhead)); + for (; m_vqhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_vqhead) + propagate_value(m_vqhead, m_prop_queue[m_vqhead]); + return true; + } + + void core::propagate_constraint(unsigned idx, dependent_constraint& dc) { + auto [sc, dep] = dc; + if (sc.is_eq(m_var, m_value)) { + propagate_assignment(m_var, m_value, dep); + return; + } + add_watch(idx, sc); + } + + void core::add_watch(unsigned idx, signed_constraint& sc) { + auto& vars = sc.vars(); + unsigned i = 0, j = 0, sz = vars.size(); + for (; i < sz && j < 2; ++i) + if (!is_assigned(vars[i])) + std::swap(vars[i], vars[j++]); + if (vars.size() > 0) + add_watch(idx, vars[0]); + if (vars.size() > 1) + add_watch(idx, vars[1]); + s.ctx.push(mk_add_watch(*this, idx)); + } + + void core::add_watch(unsigned idx, unsigned var) { + m_watch[var].push_back(idx); + } + + void core::propagate_assignment(pvar v, rational const& value, stacked_dependency* dep) { + if (is_assigned(v)) + return; + if (m_var_queue.contains(v)) { + m_var_queue.del_var_eh(v); + s.ctx.push(mk_dqueue_var(v, *this)); + } + m_values[v] = value; + m_justification[v] = dep; + m_assignment.push(v , value); + s.ctx.push(mk_assign_var(v, *this)); + + // update the watch lists for pvars + // remove constraints from m_watch[v] that have more than 2 free variables. + // for entries where there is only one free variable left add to viable set + unsigned j = 0; + for (auto idx : m_watch[v]) { + auto [sc, dep] = m_prop_queue[idx]; + auto& vars = sc.vars(); + if (vars[0] != v) + std::swap(vars[0], vars[1]); + SASSERT(vars[0] == v); + bool swapped = false; + for (unsigned i = vars.size(); i-- > 2; ) { + if (!is_assigned(vars[i])) { + add_watch(idx, vars[i]); + std::swap(vars[i], vars[0]); + swapped = true; + break; + } + } + if (!swapped) { + m_watch[v][j++] = idx; + } + + // constraint is unitary, add to viable set + if (vars.size() >= 2 && is_assigned(vars[0]) && !is_assigned(vars[1])) { + // m_viable.add_unitary(vars[1], idx); + } + } + m_watch[v].shrink(j); + } + + void core::propagate_value(unsigned idx, dependent_constraint const& dc) { + auto [sc, dep] = dc; + // check if sc evaluates to false + switch (eval(sc)) { + case l_true: + return; + case l_false: + m_unsat_core = explain_eval(dc); + propagate_unsat_core(); + return; + default: + break; + } + // if sc is v == value, then check the watch list for v if they evaluate to false. + if (sc.is_eq(m_var, m_value)) { + for (auto idx : m_watch[m_var]) { + auto [sc, dep] = m_prop_queue[idx]; + if (eval(sc) == l_false) { + m_unsat_core = explain_eval({ sc, dep }); + propagate_unsat_core(); + return; + } + } + } + + throw default_exception("nyi"); + } + + bool core::propagate_unsat_core() { + // default is to use unsat core: + s.set_conflict(m_unsat_core); + // if core is based on viable, use s.set_lemma(); + throw default_exception("nyi"); + } + + void core::assign_eh(signed_constraint const& sc, dependency const& dep) { + m_prop_queue.push_back({ sc, m_dep.mk_leaf(dep) }); + s.ctx.push(push_back_vector(m_prop_queue)); + } + + + +} diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h new file mode 100644 index 000000000..7fdf8c88c --- /dev/null +++ b/src/sat/smt/polysat_core.h @@ -0,0 +1,128 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + polysat_core.h + +Abstract: + + Core solver for polysat + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-30 + Jakob Rath 2021-04-06 + +--*/ +#pragma once + +#include "util/dependency.h" +#include "math/dd/dd_pdd.h" +#include "sat/smt/sat_th.h" +#include "sat/smt/polysat_types.h" +#include "sat/smt/polysat_constraints.h" +#include "sat/smt/polysat_viable.h" +#include "sat/smt/polysat_assignment.h" + +namespace polysat { + + class core; + class solver; + + class core { + class mk_add_var; + class mk_dqueue_var; + class mk_assign_var; + class mk_add_watch; + typedef svector> activity; + friend class viable; + friend class constraints; + friend class assignment; + + solver& s; + viable m_viable; + constraints m_constraints; + assignment m_assignment; + unsigned m_qhead = 0, m_vqhead = 0; + svector m_prop_queue; + stacked_dependency_manager m_dep; + mutable scoped_ptr_vector m_pdd; + dependency_vector m_unsat_core; + + + // attributes associated with variables + vector m_vars; // for each variable a pdd + vector m_values; // current value of assigned variable + ptr_vector m_justification; // justification for assignment + activity m_activity; // activity of variables + var_queue m_var_queue; // priority queue of variables to assign + vector m_watch; // watch lists for variables for constraints on m_prop_queue where they occur + + vector m_subst; // substitution, one for each size. + + // values to split on + rational m_value; + pvar m_var = 0; + + dd::pdd_manager& sz2pdd(unsigned sz) const; + dd::pdd_manager& var2pdd(pvar v) const; + unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } + void del_var(); + + bool is_assigned(pvar v) { return nullptr != m_justification[v]; } + void propagate_constraint(unsigned idx, dependent_constraint& dc); + void propagate_value(unsigned idx, dependent_constraint const& dc); + void propagate_assignment(pvar v, rational const& value, stacked_dependency* dep); + bool propagate_unsat_core(); + + void add_watch(unsigned idx, signed_constraint& sc); + void add_watch(unsigned idx, unsigned var); + + lbool eval(signed_constraint sc) { throw default_exception("nyi"); } + dependency_vector explain_eval(dependent_constraint const& dc) { throw default_exception("nyi"); } + + public: + core(solver& s); + + sat::check_result check(); + + bool propagate(); + void assign_eh(signed_constraint const& sc, dependency const& dep); + + expr_ref constraint2expr(signed_constraint const& sc) const { throw default_exception("nyi"); } + + pdd value(rational const& v, unsigned sz); + + signed_constraint eq(pdd const& p) { return m_constraints.eq(p); } + signed_constraint eq(pdd const& p, pdd const& q) { return m_constraints.eq(p - q); } + signed_constraint ule(pdd const& p, pdd const& q) { return m_constraints.ule(p, q); } + signed_constraint sle(pdd const& p, pdd const& q) { return m_constraints.sle(p, q); } + signed_constraint umul_ovfl(pdd const& p, pdd const& q) { return m_constraints.umul_ovfl(p, q); } + signed_constraint smul_ovfl(pdd const& p, pdd const& q) { return m_constraints.smul_ovfl(p, q); } + signed_constraint smul_udfl(pdd const& p, pdd const& q) { return m_constraints.smul_udfl(p, q); } + signed_constraint bit(pdd const& p, unsigned i) { return m_constraints.bit(p, i); } + + + pdd lshr(pdd a, pdd b) { throw default_exception("nyi"); } + pdd ashr(pdd a, pdd b) { throw default_exception("nyi"); } + pdd shl(pdd a, pdd b) { throw default_exception("nyi"); } + pdd band(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bxor(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bor(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bnand(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bxnor(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bnor(pdd a, pdd b) { throw default_exception("nyi"); } + pdd bnot(pdd a) { throw default_exception("nyi"); } + std::pair quot_rem(pdd const& n, pdd const& d) { throw default_exception("nyi"); } + pdd zero_ext(pdd a, unsigned sz) { throw default_exception("nyi"); } + pdd sign_ext(pdd a, unsigned sz) { throw default_exception("nyi"); } + pdd extract(pdd src, unsigned hi, unsigned lo) { throw default_exception("nyi"); } + pdd concat(unsigned n, pdd const* args) { throw default_exception("nyi"); } + pvar add_var(unsigned sz); + pdd var(pvar p) { return m_vars[p]; } + + std::ostream& display(std::ostream& out) const { throw default_exception("nyi"); } + }; + +} diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp deleted file mode 100644 index ef469fe6f..000000000 --- a/src/sat/smt/polysat_internalize.cpp +++ /dev/null @@ -1,526 +0,0 @@ -/*++ -Copyright (c) 2022 Microsoft Corporation - -Module Name: - - polysat_internalize.cpp - -Abstract: - - PolySAT internalize - -Author: - - Nikolaj Bjorner (nbjorner) 2022-01-26 - ---*/ - -#include "params/bv_rewriter_params.hpp" -#include "sat/smt/polysat_solver.h" -#include "sat/smt/euf_solver.h" - -namespace polysat { - - euf::theory_var solver::mk_var(euf::enode* n) { - theory_var v = euf::th_euf_solver::mk_var(n); - ctx.attach_th_var(n, this, v); - return v; - } - - sat::literal solver::internalize(expr* e, bool sign, bool root) { - force_push(); - SASSERT(m.is_bool(e)); - if (!visit_rec(m, e, sign, root)) - return sat::null_literal; - sat::literal lit = expr2literal(e); - if (sign) - lit.neg(); - return lit; - } - - void solver::internalize(expr* e) { - force_push(); - visit_rec(m, e, false, false); - } - - bool solver::visit(expr* e) { - force_push(); - if (!is_app(e) || to_app(e)->get_family_id() != get_id()) { - ctx.internalize(e); - return true; - } - m_stack.push_back(sat::eframe(e)); - return false; - } - - bool solver::visited(expr* e) { - euf::enode* n = expr2enode(e); - return n && n->is_attached_to(get_id()); - } - - bool solver::post_visit(expr* e, bool sign, bool root) { - euf::enode* n = expr2enode(e); - app* a = to_app(e); - - if (visited(e)) - return true; - - SASSERT(!n || !n->is_attached_to(get_id())); - if (!n) - n = mk_enode(e, false); - - SASSERT(!n->is_attached_to(get_id())); - mk_var(n); - SASSERT(n->is_attached_to(get_id())); - internalize_polysat(a); - return true; - } - - void solver::internalize_polysat(app* a) { - -#define if_unary(F) if (a->get_num_args() == 1) { internalize_unary(a, [&](pdd const& p) { return F(p); }); break; } - - switch (a->get_decl_kind()) { - case OP_BMUL: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p * q; }); break; - case OP_BADD: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p + q; }); break; - case OP_BSUB: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p - q; }); break; - case OP_BLSHR: internalize_lshr(a); break; - case OP_BSHL: internalize_shl(a); break; - case OP_BASHR: internalize_ashr(a); break; - case OP_BAND: internalize_band(a); break; - case OP_BOR: internalize_bor(a); break; - case OP_BXOR: internalize_bxor(a); break; - case OP_BNAND: if_unary(m_core.bnot); internalize_bnand(a); break; - case OP_BNOR: if_unary(m_core.bnot); internalize_bnor(a); break; - case OP_BXNOR: if_unary(m_core.bnot); internalize_bxnor(a); break; - case OP_BNOT: internalize_unary(a, [&](pdd const& p) { return m_core.bnot(p); }); break; - case OP_BNEG: internalize_unary(a, [&](pdd const& p) { return -p; }); break; - case OP_MKBV: internalize_mkbv(a); break; - case OP_BV_NUM: internalize_num(a); break; - case OP_ULEQ: internalize_le(a); break; - case OP_SLEQ: internalize_le(a); break; - case OP_UGEQ: internalize_le(a); break; - case OP_SGEQ: internalize_le(a); break; - case OP_ULT: internalize_le(a); break; - case OP_SLT: internalize_le(a); break; - case OP_UGT: internalize_le(a); break; - case OP_SGT: internalize_le(a); break; - - case OP_BUMUL_NO_OVFL: internalize_binaryc(a, [&](pdd const& p, pdd const& q) { return m_core.umul_ovfl(p, q); }); break; - case OP_BSMUL_NO_OVFL: internalize_binaryc(a, [&](pdd const& p, pdd const& q) { return m_core.smul_ovfl(p, q); }); break; - case OP_BSMUL_NO_UDFL: internalize_binaryc(a, [&](pdd const& p, pdd const& q) { return m_core.smul_udfl(p, q); }); break; - - case OP_BUMUL_OVFL: - case OP_BSMUL_OVFL: - case OP_BSDIV_OVFL: - case OP_BNEG_OVFL: - case OP_BUADD_OVFL: - case OP_BSADD_OVFL: - case OP_BUSUB_OVFL: - case OP_BSSUB_OVFL: - verbose_stream() << mk_pp(a, m) << "\n"; - // handled by bv_rewriter for now - UNREACHABLE(); - break; - - case OP_BUDIV_I: internalize_udiv_i(a); break; - case OP_BUREM_I: internalize_urem_i(a); break; - - case OP_BUDIV: internalize_div_rem(a, true); break; - case OP_BUREM: internalize_div_rem(a, false); break; - case OP_BSDIV0: UNREACHABLE(); break; - case OP_BUDIV0: UNREACHABLE(); break; - case OP_BSREM0: UNREACHABLE(); break; - case OP_BUREM0: UNREACHABLE(); break; - case OP_BSMOD0: UNREACHABLE(); break; - - case OP_EXTRACT: internalize_extract(a); break; - case OP_CONCAT: internalize_concat(a); break; - case OP_ZERO_EXT: internalize_zero_extend(a); break; - case OP_SIGN_EXT: internalize_sign_extend(a); break; - - // polysat::solver should also support at least: - case OP_BREDAND: // x == 2^K - 1 unary, return single bit, 1 if all input bits are set. - case OP_BREDOR: // x > 0 unary, return single bit, 1 if at least one input bit is set. - case OP_BCOMP: // x == y binary, return single bit, 1 if the arguments are equal. - case OP_BSDIV: - case OP_BSREM: - case OP_BSMOD: - case OP_BSDIV_I: - case OP_BSREM_I: - case OP_BSMOD_I: - - IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); - NOT_IMPLEMENTED_YET(); - return; - default: - IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); - NOT_IMPLEMENTED_YET(); - return; - } -#undef if_unary - } - - class solver::mk_atom_trail : public trail { - solver& th; - sat::bool_var m_var; - public: - mk_atom_trail(sat::bool_var v, solver& th) : th(th), m_var(v) {} - void undo() override { - th.erase_bv2a(m_var); - } - }; - - void solver::mk_atom(sat::bool_var bv, signed_constraint& sc) { - if (get_bv2a(bv)) - return; - sat::literal lit(bv, false); - auto index = m_core.register_constraint(sc, dependency(lit, 0)); - auto a = new (get_region()) atom(bv, index); - insert_bv2a(bv, a); - ctx.push(mk_atom_trail(bv, *this)); - } - - void solver::internalize_binaryc(app* e, std::function const& fn) { - auto p = expr2pdd(e->get_arg(0)); - auto q = expr2pdd(e->get_arg(1)); - auto sc = ~fn(p, q); - sat::literal lit = expr2literal(e); - if (lit.sign()) - sc = ~sc; - mk_atom(lit.var(), sc); - } - - void solver::internalize_udiv_i(app* e) { - expr* x, *y; - expr_ref rm(m); - if (bv.is_bv_udivi(e, x, y)) - rm = bv.mk_bv_urem_i(x, y); - else if (bv.is_bv_udiv(e, x, y)) - rm = bv.mk_bv_urem(x, y); - else - UNREACHABLE(); - internalize(rm); - } - - // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; - // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 - // (p + q) - band(p, q); - void solver::internalize_bor(app* n) { - internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_sub(bv.mk_bv_add(x, y), bv.mk_bv_and(x, y)); }); - } - - // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; - // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 - // (p + q) - 2*band(p, q); - void solver::internalize_bxor(app* n) { - internalize_binary(n, [&](expr* const& x, expr* const& y) { - return bv.mk_bv_sub(bv.mk_bv_add(x, y), bv.mk_bv_add(bv.mk_bv_and(x, y), bv.mk_bv_and(x, y))); - }); - } - - void solver::internalize_bnor(app* n) { - internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_not(bv.mk_bv_or(x, y)); }); - } - - void solver::internalize_bnand(app* n) { - internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_not(bv.mk_bv_and(x, y)); }); - } - - void solver::internalize_bxnor(app* n) { - internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_not(bv.mk_bv_xor(x, y)); }); - } - - void solver::internalize_band(app* n) { - if (n->get_num_args() == 2) { - expr* x, * y; - VERIFY(bv.is_bv_and(n, x, y)); - m_core.band(expr2pdd(x), expr2pdd(y), expr2pdd(n)); - } - else { - expr_ref z(n->get_arg(0), m); - for (unsigned i = 1; i < n->get_num_args(); ++i) { - z = bv.mk_bv_and(z, n->get_arg(i)); - ctx.internalize(z); - } - internalize_set(n, expr2pdd(z)); - } - } - - void solver::internalize_lshr(app* n) { - expr* x, * y; - VERIFY(bv.is_bv_lshr(n, x, y)); - m_core.lshr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); - } - - void solver::internalize_ashr(app* n) { - expr* x, * y; - VERIFY(bv.is_bv_ashr(n, x, y)); - m_core.ashr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); - } - - void solver::internalize_shl(app* n) { - expr* x, * y; - VERIFY(bv.is_bv_shl(n, x, y)); - m_core.shl(expr2pdd(x), expr2pdd(y), expr2pdd(n)); - } - - void solver::internalize_urem_i(app* rem) { - expr* x, *y; - euf::enode* n = expr2enode(rem); - SASSERT(n && n->is_attached_to(get_id())); - theory_var v = n->get_th_var(get_id()); - if (m_var2pdd_valid.get(v, false)) - return; - expr_ref quot(m); - if (bv.is_bv_uremi(rem, x, y)) - quot = bv.mk_bv_udiv_i(x, y); - else if (bv.is_bv_urem(rem, x, y)) - quot = bv.mk_bv_udiv(x, y); - else - UNREACHABLE(); - m_var2pdd_valid.setx(v, true, false); - ctx.internalize(quot); - m_var2pdd_valid.setx(v, false, false); - quot_rem(quot, rem, x, y); - } - - void solver::quot_rem(expr* quot, expr* rem, expr* x, expr* y) { - pdd a = expr2pdd(x); - pdd b = expr2pdd(y); - euf::enode* qn = expr2enode(quot); - euf::enode* rn = expr2enode(rem); - auto& m = a.manager(); - unsigned sz = m.power_of_2(); - if (b.is_zero()) { - // By SMT-LIB specification, b = 0 ==> q = -1, r = a. - internalize_set(quot, m.mk_val(m.max_value())); - internalize_set(rem, a); - return; - } - if (b.is_one()) { - internalize_set(quot, a); - internalize_set(rem, m.zero()); - return; - } - - if (a.is_val() && b.is_val()) { - rational const av = a.val(); - rational const bv = b.val(); - SASSERT(!bv.is_zero()); - rational rv; - rational qv = machine_div_rem(av, bv, rv); - pdd q = m.mk_val(qv); - pdd r = m.mk_val(rv); - SASSERT_EQ(a, b * q + r); - SASSERT(b.val() * q.val() + r.val() <= m.max_value()); - SASSERT(r.val() <= (b * q + r).val()); - SASSERT(r.val() < b.val()); - internalize_set(quot, q); - internalize_set(rem, r); - return; - } - - pdd r = var2pdd(rn->get_th_var(get_id())); - pdd q = var2pdd(qn->get_th_var(get_id())); - - // Axioms for quotient/remainder - // - // a = b*q + r - // multiplication does not overflow in b*q - // addition does not overflow in (b*q) + r; for now expressed as: r <= bq+r - // b ≠ 0 ==> r < b - // b = 0 ==> q = -1 - // TODO: when a,b become evaluable, can we actually propagate q,r? doesn't seem like it. - // Maybe we need something like an op_constraint for better propagation. - add_polysat_clause("[axiom] quot_rem 1", { m_core.eq(b * q + r - a) }, false); - add_polysat_clause("[axiom] quot_rem 2", { ~m_core.umul_ovfl(b, q) }, false); - // r <= b*q+r - // { apply equivalence: p <= q <=> q-p <= -p-1 } - // b*q <= -r-1 - add_polysat_clause("[axiom] quot_rem 3", { m_core.ule(b * q, -r - 1) }, false); - - auto c_eq = m_core.eq(b); - if (!c_eq.is_always_true()) - add_polysat_clause("[axiom] quot_rem 4", { c_eq, ~m_core.ule(b, r) }, false); - if (!c_eq.is_always_false()) - add_polysat_clause("[axiom] quot_rem 5", { ~c_eq, m_core.eq(q + 1) }, false); - } - - void solver::internalize_sign_extend(app* e) { - expr* arg = e->get_arg(0); - unsigned sz = bv.get_bv_size(e); - unsigned arg_sz = bv.get_bv_size(arg); - unsigned sz2 = sz - arg_sz; - var2pdd(expr2enode(e)->get_th_var(get_id())); - if (arg_sz == sz) - add_clause(eq_internalize(e, arg), nullptr); - else { - sat::literal lt0 = ctx.mk_literal(bv.mk_slt(arg, bv.mk_numeral(0, arg_sz))); - // arg < 0 ==> e = concat(arg, 1...1) - // arg >= 0 ==> e = concat(arg, 0...0) - add_clause(lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(rational::power_of_two(sz2) - 1, sz2))), nullptr); - add_clause(~lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz2))), nullptr); - } - } - - void solver::internalize_zero_extend(app* e) { - expr* arg = e->get_arg(0); - unsigned sz = bv.get_bv_size(e); - unsigned arg_sz = bv.get_bv_size(arg); - unsigned sz2 = sz - arg_sz; - var2pdd(expr2enode(e)->get_th_var(get_id())); - if (arg_sz == sz) - add_clause(eq_internalize(e, arg), nullptr); - else - // e = concat(arg, 0...0) - add_clause(eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz2))), nullptr); - } - - void solver::internalize_div_rem(app* e, bool is_div) { - bv_rewriter_params p(s().params()); - if (p.hi_div0()) { - if (is_div) - internalize_udiv_i(e); - else - internalize_urem_i(e); - return; - } - expr* arg1 = e->get_arg(0); - expr* arg2 = e->get_arg(1); - unsigned sz = bv.get_bv_size(e); - expr_ref zero(bv.mk_numeral(0, sz), m); - sat::literal eqZ = eq_internalize(arg2, zero); - sat::literal eqU = eq_internalize(e, is_div ? bv.mk_bv_udiv0(arg1) : bv.mk_bv_urem0(arg1)); - sat::literal eqI = eq_internalize(e, is_div ? bv.mk_bv_udiv_i(arg1, arg2) : bv.mk_bv_urem_i(arg1, arg2)); - add_clause(~eqZ, eqU); - add_clause(eqZ, eqI); - ctx.add_aux(~eqZ, eqU); - ctx.add_aux(eqZ, eqI); - } - - void solver::internalize_num(app* a) { - rational val; - unsigned sz = 0; - VERIFY(bv.is_numeral(a, val, sz)); - auto p = m_core.value(val, sz); - internalize_set(a, p); - } - - // TODO - test that internalize works with recursive call on bit2bool - void solver::internalize_mkbv(app* a) { - unsigned i = 0; - for (expr* arg : *a) { - expr_ref b2b(m); - b2b = bv.mk_bit2bool(a, i); - sat::literal bit_i = ctx.internalize(b2b, false, false); - sat::literal lit = expr2literal(arg); - add_equiv(lit, bit_i); -#if 0 - ctx.add_aux_equiv(lit, bit_i); -#endif - ++i; - } - } - - void solver::internalize_extract(app* e) { - var2pdd(expr2enode(e)->get_th_var(get_id())); - } - - void solver::internalize_concat(app* e) { - SASSERT(bv.is_concat(e)); - var2pdd(expr2enode(e)->get_th_var(get_id())); - } - - void solver::internalize_par_unary(app* e, std::function const& fn) { - pdd const p = expr2pdd(e->get_arg(0)); - unsigned const par = e->get_parameter(0).get_int(); - internalize_set(e, fn(p, par)); - } - - void solver::internalize_binary(app* e, std::function const& fn) { - SASSERT(e->get_num_args() >= 1); - auto p = expr2pdd(e->get_arg(0)); - for (unsigned i = 1; i < e->get_num_args(); ++i) - p = fn(p, expr2pdd(e->get_arg(i))); - internalize_set(e, p); - } - - void solver::internalize_binary(app* e, std::function const& fn) { - SASSERT(e->get_num_args() >= 1); - expr* r = e->get_arg(0); - for (unsigned i = 1; i < e->get_num_args(); ++i) - r = fn(r, e->get_arg(i)); - ctx.internalize(r); - internalize_set(e, var2pdd(expr2enode(r)->get_th_var(get_id()))); - } - - void solver::internalize_unary(app* e, std::function const& fn) { - SASSERT(e->get_num_args() == 1); - auto p = expr2pdd(e->get_arg(0)); - internalize_set(e, fn(p)); - } - - template - void solver::internalize_le(app* e) { - SASSERT(e->get_num_args() == 2); - auto p = expr2pdd(e->get_arg(0)); - auto q = expr2pdd(e->get_arg(1)); - if (Rev) - std::swap(p, q); - auto sc = Signed ? m_core.sle(p, q) : m_core.ule(p, q); - if (Negated) - sc = ~sc; - - sat::literal lit = expr2literal(e); - if (lit.sign()) - sc = ~sc; - mk_atom(lit.var(), sc); - } - - dd::pdd solver::expr2pdd(expr* e) { - return var2pdd(get_th_var(e)); - } - - dd::pdd solver::var2pdd(euf::theory_var v) { - if (!m_var2pdd_valid.get(v, false)) { - unsigned bv_size = get_bv_size(v); - pvar pv = m_core.add_var(bv_size); - m_pddvar2var.setx(pv, v, UINT_MAX); - pdd p = m_core.var(pv); - internalize_set(v, p); - return p; - } - return m_var2pdd[v]; - } - - void solver::apply_sort_cnstr(euf::enode* n, sort* s) { - if (!bv.is_bv(n->get_expr())) - return; - theory_var v = n->get_th_var(get_id()); - if (v == euf::null_theory_var) - v = mk_var(n); - var2pdd(v); - } - - void solver::internalize_set(expr* e, pdd const& p) { - internalize_set(get_th_var(e), p); - } - - void solver::internalize_set(euf::theory_var v, pdd const& p) { - SASSERT_EQ(get_bv_size(v), p.power_of_2()); - m_var2pdd.reserve(get_num_vars(), p); - m_var2pdd_valid.reserve(get_num_vars(), false); - ctx.push(set_bitvector_trail(m_var2pdd_valid, v)); -#if 0 - m_var2pdd[v].reset(p.manager()); -#endif - m_var2pdd[v] = p; - } - - void solver::eq_internalized(euf::enode* n) { - SASSERT(m.is_eq(n->get_expr())); - } - - -} diff --git a/src/sat/smt/polysat_substitution.h b/src/sat/smt/polysat_substitution.h new file mode 100644 index 000000000..a30c6b710 --- /dev/null +++ b/src/sat/smt/polysat_substitution.h @@ -0,0 +1,212 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + polysat substitution + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once +#include "sat/smt/polysat_types.h" + +namespace polysat { + + using assignment_item_t = std::pair; + + class substitution_iterator { + pdd m_current; + substitution_iterator(pdd current) : m_current(std::move(current)) {} + friend class substitution; + + public: + using value_type = assignment_item_t; + using difference_type = std::ptrdiff_t; + using pointer = value_type const*; + using reference = value_type const&; + using iterator_category = std::input_iterator_tag; + + substitution_iterator& operator++() { + SASSERT(!m_current.is_val()); + m_current = m_current.hi(); + return *this; + } + + value_type operator*() const { + SASSERT(!m_current.is_val()); + return { m_current.var(), m_current.lo().val() }; + } + + bool operator==(substitution_iterator const& other) const { return m_current == other.m_current; } + bool operator!=(substitution_iterator const& other) const { return !operator==(other); } + }; + + /** Substitution for a single bit width. */ + class substitution { + pdd m_subst; + + substitution(pdd p); + + public: + substitution(dd::pdd_manager& m); + [[nodiscard]] substitution add(pvar var, rational const& value) const; + [[nodiscard]] pdd apply_to(pdd const& p) const; + + [[nodiscard]] bool contains(pvar var) const; + [[nodiscard]] bool value(pvar var, rational& out_value) const; + + [[nodiscard]] bool empty() const { return m_subst.is_one(); } + + pdd const& to_pdd() const { return m_subst; } + unsigned bit_width() const { return to_pdd().power_of_2(); } + + bool operator==(substitution const& other) const { return &m_subst.manager() == &other.m_subst.manager() && m_subst == other.m_subst; } + bool operator!=(substitution const& other) const { return !operator==(other); } + + std::ostream& display(std::ostream& out) const; + + using const_iterator = substitution_iterator; + const_iterator begin() const { return {m_subst}; } + const_iterator end() const { return {m_subst.manager().one()}; } + }; + + /** Full variable assignment, may include variables of varying bit widths. */ + class assignment { + vector m_pairs; + mutable scoped_ptr_vector m_subst; + vector m_subst_trail; + + substitution& subst(unsigned sz); + solver& s() const { return *m_solver; } + public: + assignment(solver& s); + // prevent implicit copy, use clone() if you do need a copy + assignment(assignment const&) = delete; + assignment& operator=(assignment const&) = delete; + assignment(assignment&&) = default; + assignment& operator=(assignment&&) = default; + assignment clone() const; + + void push(pvar var, rational const& value); + void pop(); + + pdd apply_to(pdd const& p) const; + + bool contains(pvar var) const; + bool value(pvar var, rational& out_value) const; + rational value(pvar var) const { rational val; VERIFY(value(var, val)); return val; } + bool empty() const { return pairs().empty(); } + substitution const& subst(unsigned sz) const; + vector const& pairs() const { return m_pairs; } + using const_iterator = decltype(m_pairs)::const_iterator; + const_iterator begin() const { return pairs().begin(); } + const_iterator end() const { return pairs().end(); } + + std::ostream& display(std::ostream& out) const; + }; + + inline std::ostream& operator<<(std::ostream& out, substitution const& sub) { return sub.display(out); } + + inline std::ostream& operator<<(std::ostream& out, assignment const& a) { return a.display(out); } +} + +namespace polysat { + + enum class search_item_k + { + assignment, + boolean, + }; + + class search_item { + search_item_k m_kind; + union { + pvar m_var; + sat::literal m_lit; + }; + bool m_resolved = false; // when marked as resolved it is no longer valid to reduce the conflict state + + search_item(pvar var): m_kind(search_item_k::assignment), m_var(var) {} + search_item(sat::literal lit): m_kind(search_item_k::boolean), m_lit(lit) {} + public: + static search_item assignment(pvar var) { return search_item(var); } + static search_item boolean(sat::literal lit) { return search_item(lit); } + bool is_assignment() const { return m_kind == search_item_k::assignment; } + bool is_boolean() const { return m_kind == search_item_k::boolean; } + bool is_resolved() const { return m_resolved; } + search_item_k kind() const { return m_kind; } + pvar var() const { SASSERT(is_assignment()); return m_var; } + sat::literal lit() const { SASSERT(is_boolean()); return m_lit; } + void set_resolved() { m_resolved = true; } + }; + + class search_state { + solver& s; + + vector m_items; + assignment m_assignment; + + // store index into m_items + unsigned_vector m_pvar_to_idx; + unsigned_vector m_bool_to_idx; + + bool value(pvar v, rational& r) const; + + public: + search_state(solver& s): s(s), m_assignment(s) {} + unsigned size() const { return m_items.size(); } + search_item const& back() const { return m_items.back(); } + search_item const& operator[](unsigned i) const { return m_items[i]; } + + assignment const& get_assignment() const { return m_assignment; } + substitution const& subst(unsigned sz) const { return m_assignment.subst(sz); } + + // TODO: implement the following method if we actually need the assignments without resolved items already during conflict resolution + // (no separate trail needed, just a second m_subst and an index into the trail, I think) + // (update on set_resolved? might be one iteration too early, looking at the old solver::resolve_conflict loop) + substitution const& unresolved_assignment(unsigned sz) const; + + void push_assignment(pvar v, rational const& r); + void push_boolean(sat::literal lit); + void pop(); + + unsigned get_pvar_index(pvar v) const; + unsigned get_bool_index(sat::bool_var var) const; + unsigned get_bool_index(sat::literal lit) const { return get_bool_index(lit.var()); } + + void set_resolved(unsigned i) { m_items[i].set_resolved(); } + + using const_iterator = decltype(m_items)::const_iterator; + const_iterator begin() const { return m_items.begin(); } + const_iterator end() const { return m_items.end(); } + + std::ostream& display(std::ostream& out) const; + std::ostream& display(search_item const& item, std::ostream& out) const; + std::ostream& display_verbose(std::ostream& out) const; + std::ostream& display_verbose(search_item const& item, std::ostream& out) const; + }; + + struct search_state_pp { + search_state const& s; + bool verbose; + search_state_pp(search_state const& s, bool verbose = false) : s(s), verbose(verbose) {} + }; + + struct search_item_pp { + search_state const& s; + search_item const& i; + bool verbose; + search_item_pp(search_state const& s, search_item const& i, bool verbose = false) : s(s), i(i), verbose(verbose) {} + }; + + inline std::ostream& operator<<(std::ostream& out, search_state const& s) { return s.display(out); } + + inline std::ostream& operator<<(std::ostream& out, search_state_pp const& p) { return p.verbose ? p.s.display_verbose(out) : p.s.display(out); } + + inline std::ostream& operator<<(std::ostream& out, search_item_pp const& p) { return p.verbose ? p.s.display_verbose(p.i, out) : p.s.display(p.i, out); } + +} diff --git a/src/sat/smt/polysat_types.h b/src/sat/smt/polysat_types.h new file mode 100644 index 000000000..4296a8247 --- /dev/null +++ b/src/sat/smt/polysat_types.h @@ -0,0 +1,45 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once + +#include "math/dd/dd_pdd.h" +#include "util/sat_literal.h" +#include "util/dependency.h" + +namespace polysat { + + using pdd = dd::pdd; + using pvar = unsigned; + + + class dependency { + unsigned m_index; + unsigned m_level; + public: + dependency(sat::literal lit, unsigned level) : m_index(2 * lit.index()), m_level(level) {} + dependency(unsigned var_idx, unsigned level) : m_index(1 + 2 * var_idx), m_level(level) {} + bool is_literal() const { return m_index % 2 == 0; } + sat::literal literal() const { SASSERT(is_literal()); return sat::to_literal(m_index / 2); } + unsigned index() const { SASSERT(!is_literal()); return (m_index - 1) / 2; } + unsigned level() const { return m_level; } + }; + + using stacked_dependency = stacked_dependency_manager::dependency; + + inline std::ostream& operator<<(std::ostream& out, dependency d) { + if (d.is_literal()) + return out << d.literal() << "@" << d.level(); + else + return out << "v" << d.index() << "@" << d.level(); + } + + using dependency_vector = vector; + +} diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat_viable.h new file mode 100644 index 000000000..def069652 --- /dev/null +++ b/src/sat/smt/polysat_viable.h @@ -0,0 +1,55 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + maintain viable domains + It uses the interval extraction functions from forbidden intervals. + An empty viable set corresponds directly to a conflict that does not rely on + the non-viable variable. + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +--*/ +#pragma once + +#include "util/rational.h" +#include "sat/smt/polysat_types.h" + +namespace polysat { + + enum class find_t { + empty, + singleton, + multiple, + resource_out, + }; + + class core; + + class viable { + core& c; + public: + viable(core& c) : c(c) {} + + /** + * Find a next viable value for variable. + */ + find_t find_viable(pvar v, rational& out_val) { throw default_exception("nyi"); } + + /* + * Explain why the current variable is not viable or signleton. + */ + dependency_vector explain() { throw default_exception("nyi"); } + + /* + * Register constraint at index 'idx' as unitary in v. + */ + void add_unitary(pvar v, unsigned idx) { throw default_exception("nyi"); } + + }; + +} From edfa18f8cc513dea6d343efda77cd5f27bf2948f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 14:50:33 -0800 Subject: [PATCH 68/89] porting viable --- src/sat/smt/polysat_core.cpp | 276 ----------------------------------- src/sat/smt/polysat_types.h | 45 ------ src/sat/smt/polysat_viable.h | 77 +++++++++- 3 files changed, 71 insertions(+), 327 deletions(-) delete mode 100644 src/sat/smt/polysat_core.cpp delete mode 100644 src/sat/smt/polysat_types.h diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp deleted file mode 100644 index 27d6ee731..000000000 --- a/src/sat/smt/polysat_core.cpp +++ /dev/null @@ -1,276 +0,0 @@ -/*++ -Copyright (c) 2022 Microsoft Corporation - -Module Name: - - polysat_core.cpp - -Abstract: - - PolySAT core functionality - -Author: - - Nikolaj Bjorner (nbjorner) 2022-01-26 - Jakob Rath 2021-04-06 - -Notes: - -polysat::solver -- adds assignments -- calls propagation and check - -polysat::core -- propagates literals -- crates case splits by value assignment (equalities) -- detects conflicts based on Literal assignmets -- adds lemmas based on projections - ---*/ - -#include "params/bv_rewriter_params.hpp" -#include "sat/smt/polysat_solver.h" -#include "sat/smt/euf_solver.h" - -namespace polysat { - - class core::mk_assign_var : public trail { - pvar m_var; - core& c; - public: - mk_assign_var(pvar v, core& c) : m_var(v), c(c) {} - void undo() { - c.m_justification[m_var] = nullptr; - c.m_assignment.pop(); - } - }; - - class core::mk_dqueue_var : public trail { - pvar m_var; - core& c; - public: - mk_dqueue_var(pvar v, core& c) : m_var(v), c(c) {} - void undo() { - c.m_var_queue.unassign_var_eh(m_var); - } - }; - - class core::mk_add_var : public trail { - core& c; - public: - mk_add_var(core& c) : c(c) {} - void undo() override { - c.del_var(); - } - }; - - class core::mk_add_watch : public trail { - core& c; - unsigned m_idx; - public: - mk_add_watch(core& c, unsigned idx) : c(c), m_idx(idx) {} - void undo() override { - auto& sc = c.m_prop_queue[m_idx].first; - auto& vars = sc.vars(); - if (vars.size() > 0) - c.m_watch[vars[0]].pop_back(); - if (vars.size() > 1) - c.m_watch[vars[1]].pop_back(); - } - }; - - core::core(solver& s) : - s(s), - m_viable(*this), - m_constraints(s.get_trail_stack()), - m_assignment(*this), - m_dep(s.get_region()), - m_var_queue(m_activity) - {} - - pdd core::value(rational const& v, unsigned sz) { - return sz2pdd(sz).mk_val(v); - } - - dd::pdd_manager& core::sz2pdd(unsigned sz) const { - m_pdd.reserve(sz + 1); - if (!m_pdd[sz]) - m_pdd.set(sz, alloc(dd::pdd_manager, 1000, dd::pdd_manager::semantics::mod2N_e, sz)); - return *m_pdd[sz]; - } - - dd::pdd_manager& core::var2pdd(pvar v) const { - return sz2pdd(size(v)); - } - - pvar core::add_var(unsigned sz) { - unsigned v = m_vars.size(); - m_vars.push_back(sz2pdd(sz).mk_var(v)); - m_activity.push_back({ sz, 0 }); - m_justification.push_back(nullptr); - m_watch.push_back({}); - m_var_queue.mk_var_eh(v); - s.ctx.push(mk_add_var(*this)); - return v; - } - - void core::del_var() { - unsigned v = m_vars.size() - 1; - m_vars.pop_back(); - m_activity.pop_back(); - m_justification.pop_back(); - m_watch.pop_back(); - m_var_queue.del_var_eh(v); - } - - // case split on unassigned variables until all are assigned values. - // create equality literal for unassigned variable. - // return new_eq if there is a new literal. - - sat::check_result core::check() { - if (m_var_queue.empty()) - return sat::check_result::CR_DONE; - m_var = m_var_queue.next_var(); - s.ctx.push(mk_dqueue_var(m_var, *this)); - switch (m_viable.find_viable(m_var, m_value)) { - case find_t::empty: - m_unsat_core = m_viable.explain(); - propagate_unsat_core(); - return sat::check_result::CR_CONTINUE; - case find_t::singleton: - s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); - return sat::check_result::CR_CONTINUE; - case find_t::multiple: - return sat::check_result::CR_CONTINUE; - case find_t::resource_out: - return sat::check_result::CR_GIVEUP; - } - UNREACHABLE(); - return sat::check_result::CR_GIVEUP; - } - - // First propagate Boolean assignment, then propagate value assignment - bool core::propagate() { - if (m_qhead == m_prop_queue.size() && m_vqhead == m_prop_queue.size()) - return false; - s.ctx.push(value_trail(m_qhead)); - for (; m_qhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_qhead) - propagate_constraint(m_qhead, m_prop_queue[m_qhead]); - s.ctx.push(value_trail(m_vqhead)); - for (; m_vqhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_vqhead) - propagate_value(m_vqhead, m_prop_queue[m_vqhead]); - return true; - } - - void core::propagate_constraint(unsigned idx, dependent_constraint& dc) { - auto [sc, dep] = dc; - if (sc.is_eq(m_var, m_value)) { - propagate_assignment(m_var, m_value, dep); - return; - } - add_watch(idx, sc); - } - - void core::add_watch(unsigned idx, signed_constraint& sc) { - auto& vars = sc.vars(); - unsigned i = 0, j = 0, sz = vars.size(); - for (; i < sz && j < 2; ++i) - if (!is_assigned(vars[i])) - std::swap(vars[i], vars[j++]); - if (vars.size() > 0) - add_watch(idx, vars[0]); - if (vars.size() > 1) - add_watch(idx, vars[1]); - s.ctx.push(mk_add_watch(*this, idx)); - } - - void core::add_watch(unsigned idx, unsigned var) { - m_watch[var].push_back(idx); - } - - void core::propagate_assignment(pvar v, rational const& value, stacked_dependency* dep) { - if (is_assigned(v)) - return; - if (m_var_queue.contains(v)) { - m_var_queue.del_var_eh(v); - s.ctx.push(mk_dqueue_var(v, *this)); - } - m_values[v] = value; - m_justification[v] = dep; - m_assignment.push(v , value); - s.ctx.push(mk_assign_var(v, *this)); - - // update the watch lists for pvars - // remove constraints from m_watch[v] that have more than 2 free variables. - // for entries where there is only one free variable left add to viable set - unsigned j = 0; - for (auto idx : m_watch[v]) { - auto [sc, dep] = m_prop_queue[idx]; - auto& vars = sc.vars(); - if (vars[0] != v) - std::swap(vars[0], vars[1]); - SASSERT(vars[0] == v); - bool swapped = false; - for (unsigned i = vars.size(); i-- > 2; ) { - if (!is_assigned(vars[i])) { - add_watch(idx, vars[i]); - std::swap(vars[i], vars[0]); - swapped = true; - break; - } - } - if (!swapped) { - m_watch[v][j++] = idx; - } - - // constraint is unitary, add to viable set - if (vars.size() >= 2 && is_assigned(vars[0]) && !is_assigned(vars[1])) { - // m_viable.add_unitary(vars[1], idx); - } - } - m_watch[v].shrink(j); - } - - void core::propagate_value(unsigned idx, dependent_constraint const& dc) { - auto [sc, dep] = dc; - // check if sc evaluates to false - switch (eval(sc)) { - case l_true: - return; - case l_false: - m_unsat_core = explain_eval(dc); - propagate_unsat_core(); - return; - default: - break; - } - // if sc is v == value, then check the watch list for v if they evaluate to false. - if (sc.is_eq(m_var, m_value)) { - for (auto idx : m_watch[m_var]) { - auto [sc, dep] = m_prop_queue[idx]; - if (eval(sc) == l_false) { - m_unsat_core = explain_eval({ sc, dep }); - propagate_unsat_core(); - return; - } - } - } - - throw default_exception("nyi"); - } - - bool core::propagate_unsat_core() { - // default is to use unsat core: - s.set_conflict(m_unsat_core); - // if core is based on viable, use s.set_lemma(); - throw default_exception("nyi"); - } - - void core::assign_eh(signed_constraint const& sc, dependency const& dep) { - m_prop_queue.push_back({ sc, m_dep.mk_leaf(dep) }); - s.ctx.push(push_back_vector(m_prop_queue)); - } - - - -} diff --git a/src/sat/smt/polysat_types.h b/src/sat/smt/polysat_types.h deleted file mode 100644 index 4296a8247..000000000 --- a/src/sat/smt/polysat_types.h +++ /dev/null @@ -1,45 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-06 - ---*/ -#pragma once - -#include "math/dd/dd_pdd.h" -#include "util/sat_literal.h" -#include "util/dependency.h" - -namespace polysat { - - using pdd = dd::pdd; - using pvar = unsigned; - - - class dependency { - unsigned m_index; - unsigned m_level; - public: - dependency(sat::literal lit, unsigned level) : m_index(2 * lit.index()), m_level(level) {} - dependency(unsigned var_idx, unsigned level) : m_index(1 + 2 * var_idx), m_level(level) {} - bool is_literal() const { return m_index % 2 == 0; } - sat::literal literal() const { SASSERT(is_literal()); return sat::to_literal(m_index / 2); } - unsigned index() const { SASSERT(!is_literal()); return (m_index - 1) / 2; } - unsigned level() const { return m_level; } - }; - - using stacked_dependency = stacked_dependency_manager::dependency; - - inline std::ostream& operator<<(std::ostream& out, dependency d) { - if (d.is_literal()) - return out << d.literal() << "@" << d.level(); - else - return out << "v" << d.index() << "@" << d.level(); - } - - using dependency_vector = vector; - -} diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat_viable.h index def069652..31c88c62f 100644 --- a/src/sat/smt/polysat_viable.h +++ b/src/sat/smt/polysat_viable.h @@ -17,7 +17,12 @@ Author: #pragma once #include "util/rational.h" +#include "util/dlist.h" +#include "util/map.h" +#include "util/small_object_allocator.h" + #include "sat/smt/polysat_types.h" +#include "sat/smt/polysat_fi.h" namespace polysat { @@ -29,26 +34,86 @@ namespace polysat { }; class core; + class constraints; class viable { core& c; - public: - viable(core& c) : c(c) {} + constraints& cs; + forbidden_intervals m_forbidden_intervals; - /** + struct entry final : public dll_base, public fi_record { + /// whether the entry has been created by refinement (from constraints in 'fi_record::src') + bool refined = false; + /// whether the entry is part of the current set of intervals, or stashed away for backtracking + bool active = true; + bool valid_for_lemma = true; + pvar var = null_var; + + void reset() { + // dll_base::init(this); // we never did this in alloc_entry either + fi_record::reset(); + refined = false; + active = true; + valid_for_lemma = true; + var = null_var; + } + }; + + enum class entry_kind { unit_e, equal_e, diseq_e }; + + struct layer final { + entry* entries = nullptr; + unsigned bit_width = 0; + layer(unsigned bw) : bit_width(bw) {} + }; + + class layers final { + svector m_layers; + public: + svector const& get_layers() const { return m_layers; } + layer& ensure_layer(unsigned bit_width); + layer* get_layer(unsigned bit_width); + layer* get_layer(entry* e) { return get_layer(e->bit_width); } + layer const* get_layer(unsigned bit_width) const; + layer const* get_layer(entry* e) const { return get_layer(e->bit_width); } + entry* get_entries(unsigned bit_width) const { layer const* l = get_layer(bit_width); return l ? l->entries : nullptr; } + }; + + ptr_vector m_alloc; + vector m_units; // set of viable values based on unit multipliers, layered by bit-width in descending order + ptr_vector m_equal_lin; // entries that have non-unit multipliers, but are equal + ptr_vector m_diseq_lin; // entries that have distinct non-zero multipliers + + entry* alloc_entry(pvar v); + + std::ostream& display_one(std::ostream& out, pvar v, entry const* e) const; + std::ostream& display_all(std::ostream& out, pvar v, entry const* e, char const* delimiter = "") const; + void log(); + void log(pvar v); + + void insert(entry* e, pvar v, ptr_vector& entries, entry_kind k); + + void intersect(pvar v, entry* e); + + + public: + viable(core& c); + ~viable(); + + /** * Find a next viable value for variable. */ - find_t find_viable(pvar v, rational& out_val) { throw default_exception("nyi"); } + find_t find_viable(pvar v, rational& out_val); /* * Explain why the current variable is not viable or signleton. */ - dependency_vector explain() { throw default_exception("nyi"); } + dependency_vector explain(); /* * Register constraint at index 'idx' as unitary in v. */ - void add_unitary(pvar v, unsigned idx) { throw default_exception("nyi"); } + void add_unitary(pvar v, unsigned idx); }; From 660ce31538c7cc236e7f5fd2b839d06eacc49ef6 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 8 Dec 2023 15:28:05 -0800 Subject: [PATCH 69/89] porting viable --- src/sat/smt/polysat_core.h | 7 + src/sat/smt/polysat_viable.cpp | 475 +++++++++++++++++++++++++++++++++ src/sat/smt/polysat_viable.h | 12 +- 3 files changed, 493 insertions(+), 1 deletion(-) create mode 100644 src/sat/smt/polysat_viable.cpp diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h index 7fdf8c88c..92d2090ee 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat_core.h @@ -121,6 +121,13 @@ namespace polysat { pdd concat(unsigned n, pdd const* args) { throw default_exception("nyi"); } pvar add_var(unsigned sz); pdd var(pvar p) { return m_vars[p]; } +<<<<<<< HEAD +======= + unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } + + constraints& cs() { return m_constraints; } + trail_stack& trail(); +>>>>>>> c7945af45 (porting viable) std::ostream& display(std::ostream& out) const { throw default_exception("nyi"); } }; diff --git a/src/sat/smt/polysat_viable.cpp b/src/sat/smt/polysat_viable.cpp new file mode 100644 index 000000000..d68822563 --- /dev/null +++ b/src/sat/smt/polysat_viable.cpp @@ -0,0 +1,475 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + maintain viable domains + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-06 + +Notes: + + +--*/ + + +#include "util/debug.h" +#include "util/log.h" +#include "sat/smt/polysat_viable.h" +#include "sat/smt/polysat_core.h" + +namespace polysat { + + using dd::val_pp; + + viable::viable(core& c) : c(c), cs(c.cs()), m_forbidden_intervals(c) {} + + viable::~viable() { + for (auto* e : m_alloc) + dealloc(e); + } + + std::ostream& operator<<(std::ostream& out, find_t f) { + switch (f) { + case find_t::empty: return out << "empty"; + case find_t::singleton: return out << "singleton"; + case find_t::multiple: return out << "multiple"; + case find_t::resource_out: return out << "resource-out"; + default: return out << ""; + } + } + + struct viable::pop_viable_trail : public trail { + viable& m_s; + entry* e; + pvar v; + entry_kind k; + public: + pop_viable_trail(viable& s, entry* e, pvar v, entry_kind k) + : m_s(s), e(e), v(v), k(k) {} + void undo() override { + m_s.pop_viable(e, v, k); + } + }; + + struct viable::push_viable_trail : public trail { + viable& m_s; + entry* e; + pvar v; + entry_kind k; + public: + push_viable_trail(viable& s, entry* e, pvar v, entry_kind k) + : m_s(s), e(e), v(v), k(k) {} + void undo() override { + m_s.push_viable(e, v, k); + } + }; + + viable::entry* viable::alloc_entry(pvar var) { + if (m_alloc.empty()) + return alloc(entry); + auto* e = m_alloc.back(); + e->reset(); + e->var = var; + m_alloc.pop_back(); + return e; + } + + find_t viable::find_viable(pvar v, rational& out_val) { + ensure_var(v); + throw default_exception("nyi"); + } + + /* + * Explain why the current variable is not viable or signleton. + */ + dependency_vector viable::explain() { throw default_exception("nyi"); } + + /* + * Register constraint at index 'idx' as unitary in v. + */ + void viable::add_unitary(pvar v, unsigned idx) { + + ensure_var(v); + + if (c.is_assigned(v)) + return; + auto [sc, d] = c.m_constraint_trail[idx]; + + entry* ne = alloc_entry(v); + if (!m_forbidden_intervals.get_interval(sc, v, *ne)) { + m_alloc.push_back(ne); + return; + } + + if (ne->interval.is_currently_empty()) { + m_alloc.push_back(ne); + return; + } + + if (ne->coeff == 1) { + intersect(v, ne); + return; + } + else if (ne->coeff == -1) { + insert(ne, v, m_diseq_lin, entry_kind::diseq_e); + return; + } + else { + unsigned const w = c.size(v); + unsigned const k = ne->coeff.parity(w); + // unsigned const lo_parity = ne->interval.lo_val().parity(w); + // unsigned const hi_parity = ne->interval.hi_val().parity(w); + + display_one(std::cerr << "try to reduce entry: ", v, ne) << "\n"; + + if (k > 0 && ne->coeff.is_power_of_two()) { + // reduction of coeff gives us a unit entry + // + // 2^k a x \not\in [ lo ; hi [ + // + // new_lo = lo[w-1:k] if lo[k-1:0] = 0 + // lo[w-1:k] + 1 otherwise + // + // new_hi = hi[w-1:k] if hi[k-1:0] = 0 + // hi[w-1:k] + 1 otherwise + // + // Reference: Fig. 1 (dtrim) in BitvectorsMCSAT + // + pdd const& pdd_lo = ne->interval.lo(); + pdd const& pdd_hi = ne->interval.hi(); + rational const& lo = ne->interval.lo_val(); + rational const& hi = ne->interval.hi_val(); + + rational new_lo = machine_div2k(lo, k); + if (mod2k(lo, k).is_zero()) + ne->side_cond.push_back(cs.eq(pdd_lo * rational::power_of_two(w - k))); + else { + new_lo += 1; + ne->side_cond.push_back(~cs.eq(pdd_lo * rational::power_of_two(w - k))); + } + + rational new_hi = machine_div2k(hi, k); + if (mod2k(hi, k).is_zero()) + ne->side_cond.push_back(cs.eq(pdd_hi * rational::power_of_two(w - k))); + else { + new_hi += 1; + ne->side_cond.push_back(~cs.eq(pdd_hi * rational::power_of_two(w - k))); + } + + // we have to update also the pdd bounds accordingly, but it seems not worth introducing new variables for this eagerly + // new_lo = lo[:k] etc. + // TODO: for now just disable the FI-lemma if this case occurs + ne->valid_for_lemma = false; + + if (new_lo == new_hi) { + // empty or full + // if (ne->interval.currently_contains(rational::zero())) + NOT_IMPLEMENTED_YET(); + } + + ne->coeff = machine_div2k(ne->coeff, k); + ne->interval = eval_interval::proper(pdd_lo, new_lo, pdd_hi, new_hi); + ne->bit_width -= k; + display_one(std::cerr << "reduced entry: ", v, ne) << "\n"; + LOG("reduced entry to unit in bitwidth " << ne->bit_width); + intersect(v, ne); + } + + // TODO: later, can reduce according to shared_parity + // unsigned const shared_parity = std::min(coeff_parity, std::min(lo_parity, hi_parity)); + + insert(ne, v, m_equal_lin, entry_kind::equal_e); + return; + } + } + + void viable::ensure_var(pvar v) { + while (v >= m_units.size()) { + m_units.push_back(layers()); + m_equal_lin.push_back(nullptr); + m_diseq_lin.push_back(nullptr); + } + } + + bool viable::intersect(pvar v, entry* ne) { + SASSERT(!c.is_assigned(v)); + SASSERT(!ne->src.empty()); + entry*& entries = m_units[v].ensure_layer(ne->bit_width).entries; + entry* e = entries; + if (e && e->interval.is_full()) { + m_alloc.push_back(ne); + return false; + } + + if (ne->interval.is_currently_empty()) { + m_alloc.push_back(ne); + return false; + } + + auto create_entry = [&]() { + c.trail().push(pop_viable_trail(*this, ne, v, entry_kind::unit_e)); + ne->init(ne); + return ne; + }; + + auto remove_entry = [&](entry* e) { + c.trail().push(push_viable_trail(*this, e, v, entry_kind::unit_e)); + e->remove_from(entries, e); + e->active = false; + }; + + if (ne->interval.is_full()) { + // for (auto const& l : m_units[v].get_layers()) + // while (l.entries) + // remove_entry(l.entries); + while (entries) + remove_entry(entries); + entries = create_entry(); + return true; + } + + if (!e) + entries = create_entry(); + else { + entry* first = e; + do { + if (e->interval.currently_contains(ne->interval)) { + m_alloc.push_back(ne); + return false; + } + while (ne->interval.currently_contains(e->interval)) { + entry* n = e->next(); + remove_entry(e); + if (!entries) { + entries = create_entry(); + return true; + } + if (e == first) + first = n; + e = n; + } + SASSERT(e->interval.lo_val() != ne->interval.lo_val()); + if (e->interval.lo_val() > ne->interval.lo_val()) { + if (first->prev()->interval.currently_contains(ne->interval)) { + m_alloc.push_back(ne); + return false; + } + e->insert_before(create_entry()); + if (e == first) + entries = e->prev(); + SASSERT(well_formed(m_units[v])); + return true; + } + e = e->next(); + } while (e != first); + // otherwise, append to end of list + first->insert_before(create_entry()); + } + SASSERT(well_formed(m_units[v])); + return true; + } + + void viable::log() { + for (pvar v = 0; v < m_units.size(); ++v) + log(v); + } + + void viable::log(pvar v) { + throw default_exception("nyi"); + } + + + viable::layer& viable::layers::ensure_layer(unsigned bit_width) { + for (unsigned i = 0; i < m_layers.size(); ++i) { + layer& l = m_layers[i]; + if (l.bit_width == bit_width) + return l; + else if (l.bit_width < bit_width) { + m_layers.push_back(layer(0)); + for (unsigned j = m_layers.size(); --j > i; ) + m_layers[j] = m_layers[j - 1]; + m_layers[i] = layer(bit_width); + return m_layers[i]; + } + } + m_layers.push_back(layer(bit_width)); + return m_layers.back(); + } + + viable::layer* viable::layers::get_layer(unsigned bit_width) { + return const_cast(std::as_const(*this).get_layer(bit_width)); + } + + viable::layer const* viable::layers::get_layer(unsigned bit_width) const { + for (layer const& l : m_layers) + if (l.bit_width == bit_width) + return &l; + return nullptr; + } + + void viable::pop_viable(entry* e, pvar v, entry_kind k) { + SASSERT(well_formed(m_units[v])); + SASSERT(e->active); + e->active = false; + switch (k) { + case entry_kind::unit_e: + entry::remove_from(m_units[v].get_layer(e)->entries, e); + SASSERT(well_formed(m_units[v])); + break; + case entry_kind::equal_e: + entry::remove_from(m_equal_lin[v], e); + break; + case entry_kind::diseq_e: + entry::remove_from(m_diseq_lin[v], e); + break; + default: + UNREACHABLE(); + break; + } + m_alloc.push_back(e); + } + + void viable::push_viable(entry* e, pvar v, entry_kind k) { + // display_one(verbose_stream() << "Push entry: ", v, e) << "\n"; + entry*& entries = m_units[v].get_layer(e)->entries; + SASSERT(e->prev() != e || !entries); + SASSERT(e->prev() != e || e->next() == e); + SASSERT(k == entry_kind::unit_e); + SASSERT(!e->active); + e->active = true; + (void)k; + SASSERT(well_formed(m_units[v])); + if (e->prev() != e) { + entry* pos = e->prev(); + e->init(e); + pos->insert_after(e); + if (e->interval.lo_val() < entries->interval.lo_val()) + entries = e; + } + else + entries = e; + SASSERT(well_formed(m_units[v])); + } + + void viable::insert(entry* e, pvar v, ptr_vector& entries, entry_kind k) { + SASSERT(well_formed(m_units[v])); + + c.trail().push(pop_viable_trail(*this, e, v, k)); + + e->init(e); + if (!entries[v]) + entries[v] = e; + else + e->insert_after(entries[v]); + SASSERT(entries[v]->invariant()); + SASSERT(well_formed(m_units[v])); + } + + + std::ostream& viable::display_one(std::ostream& out, pvar v, entry const* e) const { + auto& m = c.var2pdd(v); + if (e->coeff == -1) { + // p*val + q > r*val + s if e->src.is_positive() + // p*val + q >= r*val + s if e->src.is_negative() + // Note that e->interval is meaningless in this case, + // we just use it to transport the values p,q,r,s + rational const& p = e->interval.lo_val(); + rational const& q_ = e->interval.lo().val(); + rational const& r = e->interval.hi_val(); + rational const& s_ = e->interval.hi().val(); + out << "[ "; + out << val_pp(m, p, true) << "*v" << v << " + " << val_pp(m, q_); + out << (e->src[0].is_positive() ? " > " : " >= "); + out << val_pp(m, r, true) << "*v" << v << " + " << val_pp(m, s_); + out << " ] "; + } + else if (e->coeff != 1) + out << e->coeff << " * v" << v << " " << e->interval << " "; + else + out << e->interval << " "; + if (e->side_cond.size() <= 5) + out << e->side_cond << " "; + else + out << e->side_cond.size() << " side-conditions "; + unsigned count = 0; + for (const auto& src : e->src) { + ++count; + out << src << "; "; + if (count > 10) { + out << " ..."; + break; + } + } + return out; + } + + std::ostream& viable::display_all(std::ostream& out, pvar v, entry const* e, char const* delimiter) const { + if (!e) + return out; + entry const* first = e; + unsigned count = 0; + do { + display_one(out, v, e) << delimiter; + e = e->next(); + ++count; + if (count > 10) { + out << " ..."; + break; + } + } + while (e != first); + return out; + } + + /* + * Lower bounds are strictly ascending. + * Intervals don't contain each-other (since lower bounds are ascending, it suffices to check containment in one direction). + */ + bool viable::well_formed(entry* e) { + if (!e) + return true; + entry* first = e; + while (true) { + if (!e->active) + return false; + + if (e->interval.is_full()) + return e->next() == e; + if (e->interval.is_currently_empty()) + return false; + + auto* n = e->next(); + if (n != e && e->interval.currently_contains(n->interval)) + return false; + + if (n == first) + break; + if (e->interval.lo_val() >= n->interval.lo_val()) + return false; + e = n; + } + return true; + } + + /* + * Layers are ordered in strictly descending bit-width. + * Entries in each layer are well-formed. + */ + bool viable::well_formed(layers const& ls) { + unsigned prev_width = std::numeric_limits::max(); + for (layer const& l : ls.get_layers()) { + if (!well_formed(l.entries)) + return false; + if (!all_of(dll_elements(l.entries), [&l](entry const& e) { return e.bit_width == l.bit_width; })) + return false; + if (prev_width <= l.bit_width) + return false; + prev_width = l.bit_width; + } + return true; + } +} diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat_viable.h index 31c88c62f..e268d67a3 100644 --- a/src/sat/smt/polysat_viable.h +++ b/src/sat/smt/polysat_viable.h @@ -84,6 +84,9 @@ namespace polysat { ptr_vector m_equal_lin; // entries that have non-unit multipliers, but are equal ptr_vector m_diseq_lin; // entries that have distinct non-zero multipliers + bool well_formed(entry* e); + bool well_formed(layers const& ls); + entry* alloc_entry(pvar v); std::ostream& display_one(std::ostream& out, pvar v, entry const* e) const; @@ -91,13 +94,20 @@ namespace polysat { void log(); void log(pvar v); + struct pop_viable_trail; + void pop_viable(entry* e, pvar v, entry_kind k); + struct push_viable_trail; + void push_viable(entry* e, pvar v, entry_kind k); + void insert(entry* e, pvar v, ptr_vector& entries, entry_kind k); - void intersect(pvar v, entry* e); + bool intersect(pvar v, entry* e); + void ensure_var(pvar v); public: viable(core& c); + ~viable(); /** From 179d89295879478ace43da1a9b0a8057c6c53565 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 9 Dec 2023 13:10:47 -0800 Subject: [PATCH 70/89] working on viable --- src/sat/smt/polysat_core.h | 135 ---------- src/sat/smt/polysat_viable.cpp | 475 --------------------------------- 2 files changed, 610 deletions(-) delete mode 100644 src/sat/smt/polysat_core.h delete mode 100644 src/sat/smt/polysat_viable.cpp diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h deleted file mode 100644 index 92d2090ee..000000000 --- a/src/sat/smt/polysat_core.h +++ /dev/null @@ -1,135 +0,0 @@ -/*++ -Copyright (c) 2020 Microsoft Corporation - -Module Name: - - polysat_core.h - -Abstract: - - Core solver for polysat - -Author: - - Nikolaj Bjorner (nbjorner) 2020-08-30 - Jakob Rath 2021-04-06 - ---*/ -#pragma once - -#include "util/dependency.h" -#include "math/dd/dd_pdd.h" -#include "sat/smt/sat_th.h" -#include "sat/smt/polysat_types.h" -#include "sat/smt/polysat_constraints.h" -#include "sat/smt/polysat_viable.h" -#include "sat/smt/polysat_assignment.h" - -namespace polysat { - - class core; - class solver; - - class core { - class mk_add_var; - class mk_dqueue_var; - class mk_assign_var; - class mk_add_watch; - typedef svector> activity; - friend class viable; - friend class constraints; - friend class assignment; - - solver& s; - viable m_viable; - constraints m_constraints; - assignment m_assignment; - unsigned m_qhead = 0, m_vqhead = 0; - svector m_prop_queue; - stacked_dependency_manager m_dep; - mutable scoped_ptr_vector m_pdd; - dependency_vector m_unsat_core; - - - // attributes associated with variables - vector m_vars; // for each variable a pdd - vector m_values; // current value of assigned variable - ptr_vector m_justification; // justification for assignment - activity m_activity; // activity of variables - var_queue m_var_queue; // priority queue of variables to assign - vector m_watch; // watch lists for variables for constraints on m_prop_queue where they occur - - vector m_subst; // substitution, one for each size. - - // values to split on - rational m_value; - pvar m_var = 0; - - dd::pdd_manager& sz2pdd(unsigned sz) const; - dd::pdd_manager& var2pdd(pvar v) const; - unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } - void del_var(); - - bool is_assigned(pvar v) { return nullptr != m_justification[v]; } - void propagate_constraint(unsigned idx, dependent_constraint& dc); - void propagate_value(unsigned idx, dependent_constraint const& dc); - void propagate_assignment(pvar v, rational const& value, stacked_dependency* dep); - bool propagate_unsat_core(); - - void add_watch(unsigned idx, signed_constraint& sc); - void add_watch(unsigned idx, unsigned var); - - lbool eval(signed_constraint sc) { throw default_exception("nyi"); } - dependency_vector explain_eval(dependent_constraint const& dc) { throw default_exception("nyi"); } - - public: - core(solver& s); - - sat::check_result check(); - - bool propagate(); - void assign_eh(signed_constraint const& sc, dependency const& dep); - - expr_ref constraint2expr(signed_constraint const& sc) const { throw default_exception("nyi"); } - - pdd value(rational const& v, unsigned sz); - - signed_constraint eq(pdd const& p) { return m_constraints.eq(p); } - signed_constraint eq(pdd const& p, pdd const& q) { return m_constraints.eq(p - q); } - signed_constraint ule(pdd const& p, pdd const& q) { return m_constraints.ule(p, q); } - signed_constraint sle(pdd const& p, pdd const& q) { return m_constraints.sle(p, q); } - signed_constraint umul_ovfl(pdd const& p, pdd const& q) { return m_constraints.umul_ovfl(p, q); } - signed_constraint smul_ovfl(pdd const& p, pdd const& q) { return m_constraints.smul_ovfl(p, q); } - signed_constraint smul_udfl(pdd const& p, pdd const& q) { return m_constraints.smul_udfl(p, q); } - signed_constraint bit(pdd const& p, unsigned i) { return m_constraints.bit(p, i); } - - - pdd lshr(pdd a, pdd b) { throw default_exception("nyi"); } - pdd ashr(pdd a, pdd b) { throw default_exception("nyi"); } - pdd shl(pdd a, pdd b) { throw default_exception("nyi"); } - pdd band(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bxor(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bor(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bnand(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bxnor(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bnor(pdd a, pdd b) { throw default_exception("nyi"); } - pdd bnot(pdd a) { throw default_exception("nyi"); } - std::pair quot_rem(pdd const& n, pdd const& d) { throw default_exception("nyi"); } - pdd zero_ext(pdd a, unsigned sz) { throw default_exception("nyi"); } - pdd sign_ext(pdd a, unsigned sz) { throw default_exception("nyi"); } - pdd extract(pdd src, unsigned hi, unsigned lo) { throw default_exception("nyi"); } - pdd concat(unsigned n, pdd const* args) { throw default_exception("nyi"); } - pvar add_var(unsigned sz); - pdd var(pvar p) { return m_vars[p]; } -<<<<<<< HEAD -======= - unsigned size(pvar v) const { return var2pdd(v).power_of_2(); } - - constraints& cs() { return m_constraints; } - trail_stack& trail(); ->>>>>>> c7945af45 (porting viable) - - std::ostream& display(std::ostream& out) const { throw default_exception("nyi"); } - }; - -} diff --git a/src/sat/smt/polysat_viable.cpp b/src/sat/smt/polysat_viable.cpp deleted file mode 100644 index d68822563..000000000 --- a/src/sat/smt/polysat_viable.cpp +++ /dev/null @@ -1,475 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Module Name: - - maintain viable domains - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-06 - -Notes: - - ---*/ - - -#include "util/debug.h" -#include "util/log.h" -#include "sat/smt/polysat_viable.h" -#include "sat/smt/polysat_core.h" - -namespace polysat { - - using dd::val_pp; - - viable::viable(core& c) : c(c), cs(c.cs()), m_forbidden_intervals(c) {} - - viable::~viable() { - for (auto* e : m_alloc) - dealloc(e); - } - - std::ostream& operator<<(std::ostream& out, find_t f) { - switch (f) { - case find_t::empty: return out << "empty"; - case find_t::singleton: return out << "singleton"; - case find_t::multiple: return out << "multiple"; - case find_t::resource_out: return out << "resource-out"; - default: return out << ""; - } - } - - struct viable::pop_viable_trail : public trail { - viable& m_s; - entry* e; - pvar v; - entry_kind k; - public: - pop_viable_trail(viable& s, entry* e, pvar v, entry_kind k) - : m_s(s), e(e), v(v), k(k) {} - void undo() override { - m_s.pop_viable(e, v, k); - } - }; - - struct viable::push_viable_trail : public trail { - viable& m_s; - entry* e; - pvar v; - entry_kind k; - public: - push_viable_trail(viable& s, entry* e, pvar v, entry_kind k) - : m_s(s), e(e), v(v), k(k) {} - void undo() override { - m_s.push_viable(e, v, k); - } - }; - - viable::entry* viable::alloc_entry(pvar var) { - if (m_alloc.empty()) - return alloc(entry); - auto* e = m_alloc.back(); - e->reset(); - e->var = var; - m_alloc.pop_back(); - return e; - } - - find_t viable::find_viable(pvar v, rational& out_val) { - ensure_var(v); - throw default_exception("nyi"); - } - - /* - * Explain why the current variable is not viable or signleton. - */ - dependency_vector viable::explain() { throw default_exception("nyi"); } - - /* - * Register constraint at index 'idx' as unitary in v. - */ - void viable::add_unitary(pvar v, unsigned idx) { - - ensure_var(v); - - if (c.is_assigned(v)) - return; - auto [sc, d] = c.m_constraint_trail[idx]; - - entry* ne = alloc_entry(v); - if (!m_forbidden_intervals.get_interval(sc, v, *ne)) { - m_alloc.push_back(ne); - return; - } - - if (ne->interval.is_currently_empty()) { - m_alloc.push_back(ne); - return; - } - - if (ne->coeff == 1) { - intersect(v, ne); - return; - } - else if (ne->coeff == -1) { - insert(ne, v, m_diseq_lin, entry_kind::diseq_e); - return; - } - else { - unsigned const w = c.size(v); - unsigned const k = ne->coeff.parity(w); - // unsigned const lo_parity = ne->interval.lo_val().parity(w); - // unsigned const hi_parity = ne->interval.hi_val().parity(w); - - display_one(std::cerr << "try to reduce entry: ", v, ne) << "\n"; - - if (k > 0 && ne->coeff.is_power_of_two()) { - // reduction of coeff gives us a unit entry - // - // 2^k a x \not\in [ lo ; hi [ - // - // new_lo = lo[w-1:k] if lo[k-1:0] = 0 - // lo[w-1:k] + 1 otherwise - // - // new_hi = hi[w-1:k] if hi[k-1:0] = 0 - // hi[w-1:k] + 1 otherwise - // - // Reference: Fig. 1 (dtrim) in BitvectorsMCSAT - // - pdd const& pdd_lo = ne->interval.lo(); - pdd const& pdd_hi = ne->interval.hi(); - rational const& lo = ne->interval.lo_val(); - rational const& hi = ne->interval.hi_val(); - - rational new_lo = machine_div2k(lo, k); - if (mod2k(lo, k).is_zero()) - ne->side_cond.push_back(cs.eq(pdd_lo * rational::power_of_two(w - k))); - else { - new_lo += 1; - ne->side_cond.push_back(~cs.eq(pdd_lo * rational::power_of_two(w - k))); - } - - rational new_hi = machine_div2k(hi, k); - if (mod2k(hi, k).is_zero()) - ne->side_cond.push_back(cs.eq(pdd_hi * rational::power_of_two(w - k))); - else { - new_hi += 1; - ne->side_cond.push_back(~cs.eq(pdd_hi * rational::power_of_two(w - k))); - } - - // we have to update also the pdd bounds accordingly, but it seems not worth introducing new variables for this eagerly - // new_lo = lo[:k] etc. - // TODO: for now just disable the FI-lemma if this case occurs - ne->valid_for_lemma = false; - - if (new_lo == new_hi) { - // empty or full - // if (ne->interval.currently_contains(rational::zero())) - NOT_IMPLEMENTED_YET(); - } - - ne->coeff = machine_div2k(ne->coeff, k); - ne->interval = eval_interval::proper(pdd_lo, new_lo, pdd_hi, new_hi); - ne->bit_width -= k; - display_one(std::cerr << "reduced entry: ", v, ne) << "\n"; - LOG("reduced entry to unit in bitwidth " << ne->bit_width); - intersect(v, ne); - } - - // TODO: later, can reduce according to shared_parity - // unsigned const shared_parity = std::min(coeff_parity, std::min(lo_parity, hi_parity)); - - insert(ne, v, m_equal_lin, entry_kind::equal_e); - return; - } - } - - void viable::ensure_var(pvar v) { - while (v >= m_units.size()) { - m_units.push_back(layers()); - m_equal_lin.push_back(nullptr); - m_diseq_lin.push_back(nullptr); - } - } - - bool viable::intersect(pvar v, entry* ne) { - SASSERT(!c.is_assigned(v)); - SASSERT(!ne->src.empty()); - entry*& entries = m_units[v].ensure_layer(ne->bit_width).entries; - entry* e = entries; - if (e && e->interval.is_full()) { - m_alloc.push_back(ne); - return false; - } - - if (ne->interval.is_currently_empty()) { - m_alloc.push_back(ne); - return false; - } - - auto create_entry = [&]() { - c.trail().push(pop_viable_trail(*this, ne, v, entry_kind::unit_e)); - ne->init(ne); - return ne; - }; - - auto remove_entry = [&](entry* e) { - c.trail().push(push_viable_trail(*this, e, v, entry_kind::unit_e)); - e->remove_from(entries, e); - e->active = false; - }; - - if (ne->interval.is_full()) { - // for (auto const& l : m_units[v].get_layers()) - // while (l.entries) - // remove_entry(l.entries); - while (entries) - remove_entry(entries); - entries = create_entry(); - return true; - } - - if (!e) - entries = create_entry(); - else { - entry* first = e; - do { - if (e->interval.currently_contains(ne->interval)) { - m_alloc.push_back(ne); - return false; - } - while (ne->interval.currently_contains(e->interval)) { - entry* n = e->next(); - remove_entry(e); - if (!entries) { - entries = create_entry(); - return true; - } - if (e == first) - first = n; - e = n; - } - SASSERT(e->interval.lo_val() != ne->interval.lo_val()); - if (e->interval.lo_val() > ne->interval.lo_val()) { - if (first->prev()->interval.currently_contains(ne->interval)) { - m_alloc.push_back(ne); - return false; - } - e->insert_before(create_entry()); - if (e == first) - entries = e->prev(); - SASSERT(well_formed(m_units[v])); - return true; - } - e = e->next(); - } while (e != first); - // otherwise, append to end of list - first->insert_before(create_entry()); - } - SASSERT(well_formed(m_units[v])); - return true; - } - - void viable::log() { - for (pvar v = 0; v < m_units.size(); ++v) - log(v); - } - - void viable::log(pvar v) { - throw default_exception("nyi"); - } - - - viable::layer& viable::layers::ensure_layer(unsigned bit_width) { - for (unsigned i = 0; i < m_layers.size(); ++i) { - layer& l = m_layers[i]; - if (l.bit_width == bit_width) - return l; - else if (l.bit_width < bit_width) { - m_layers.push_back(layer(0)); - for (unsigned j = m_layers.size(); --j > i; ) - m_layers[j] = m_layers[j - 1]; - m_layers[i] = layer(bit_width); - return m_layers[i]; - } - } - m_layers.push_back(layer(bit_width)); - return m_layers.back(); - } - - viable::layer* viable::layers::get_layer(unsigned bit_width) { - return const_cast(std::as_const(*this).get_layer(bit_width)); - } - - viable::layer const* viable::layers::get_layer(unsigned bit_width) const { - for (layer const& l : m_layers) - if (l.bit_width == bit_width) - return &l; - return nullptr; - } - - void viable::pop_viable(entry* e, pvar v, entry_kind k) { - SASSERT(well_formed(m_units[v])); - SASSERT(e->active); - e->active = false; - switch (k) { - case entry_kind::unit_e: - entry::remove_from(m_units[v].get_layer(e)->entries, e); - SASSERT(well_formed(m_units[v])); - break; - case entry_kind::equal_e: - entry::remove_from(m_equal_lin[v], e); - break; - case entry_kind::diseq_e: - entry::remove_from(m_diseq_lin[v], e); - break; - default: - UNREACHABLE(); - break; - } - m_alloc.push_back(e); - } - - void viable::push_viable(entry* e, pvar v, entry_kind k) { - // display_one(verbose_stream() << "Push entry: ", v, e) << "\n"; - entry*& entries = m_units[v].get_layer(e)->entries; - SASSERT(e->prev() != e || !entries); - SASSERT(e->prev() != e || e->next() == e); - SASSERT(k == entry_kind::unit_e); - SASSERT(!e->active); - e->active = true; - (void)k; - SASSERT(well_formed(m_units[v])); - if (e->prev() != e) { - entry* pos = e->prev(); - e->init(e); - pos->insert_after(e); - if (e->interval.lo_val() < entries->interval.lo_val()) - entries = e; - } - else - entries = e; - SASSERT(well_formed(m_units[v])); - } - - void viable::insert(entry* e, pvar v, ptr_vector& entries, entry_kind k) { - SASSERT(well_formed(m_units[v])); - - c.trail().push(pop_viable_trail(*this, e, v, k)); - - e->init(e); - if (!entries[v]) - entries[v] = e; - else - e->insert_after(entries[v]); - SASSERT(entries[v]->invariant()); - SASSERT(well_formed(m_units[v])); - } - - - std::ostream& viable::display_one(std::ostream& out, pvar v, entry const* e) const { - auto& m = c.var2pdd(v); - if (e->coeff == -1) { - // p*val + q > r*val + s if e->src.is_positive() - // p*val + q >= r*val + s if e->src.is_negative() - // Note that e->interval is meaningless in this case, - // we just use it to transport the values p,q,r,s - rational const& p = e->interval.lo_val(); - rational const& q_ = e->interval.lo().val(); - rational const& r = e->interval.hi_val(); - rational const& s_ = e->interval.hi().val(); - out << "[ "; - out << val_pp(m, p, true) << "*v" << v << " + " << val_pp(m, q_); - out << (e->src[0].is_positive() ? " > " : " >= "); - out << val_pp(m, r, true) << "*v" << v << " + " << val_pp(m, s_); - out << " ] "; - } - else if (e->coeff != 1) - out << e->coeff << " * v" << v << " " << e->interval << " "; - else - out << e->interval << " "; - if (e->side_cond.size() <= 5) - out << e->side_cond << " "; - else - out << e->side_cond.size() << " side-conditions "; - unsigned count = 0; - for (const auto& src : e->src) { - ++count; - out << src << "; "; - if (count > 10) { - out << " ..."; - break; - } - } - return out; - } - - std::ostream& viable::display_all(std::ostream& out, pvar v, entry const* e, char const* delimiter) const { - if (!e) - return out; - entry const* first = e; - unsigned count = 0; - do { - display_one(out, v, e) << delimiter; - e = e->next(); - ++count; - if (count > 10) { - out << " ..."; - break; - } - } - while (e != first); - return out; - } - - /* - * Lower bounds are strictly ascending. - * Intervals don't contain each-other (since lower bounds are ascending, it suffices to check containment in one direction). - */ - bool viable::well_formed(entry* e) { - if (!e) - return true; - entry* first = e; - while (true) { - if (!e->active) - return false; - - if (e->interval.is_full()) - return e->next() == e; - if (e->interval.is_currently_empty()) - return false; - - auto* n = e->next(); - if (n != e && e->interval.currently_contains(n->interval)) - return false; - - if (n == first) - break; - if (e->interval.lo_val() >= n->interval.lo_val()) - return false; - e = n; - } - return true; - } - - /* - * Layers are ordered in strictly descending bit-width. - * Entries in each layer are well-formed. - */ - bool viable::well_formed(layers const& ls) { - unsigned prev_width = std::numeric_limits::max(); - for (layer const& l : ls.get_layers()) { - if (!well_formed(l.entries)) - return false; - if (!all_of(dll_elements(l.entries), [&l](entry const& e) { return e.bit_width == l.bit_width; })) - return false; - if (prev_width <= l.bit_width) - return false; - prev_width = l.bit_width; - } - return true; - } -} From f69c75af592bd325185031f0a2e6374826c67dad Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 17:32:00 -0800 Subject: [PATCH 71/89] na Signed-off-by: Nikolaj Bjorner --- src/sat/smt/intblast_solver.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 869e388ff..fcb4699f9 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -619,7 +619,7 @@ namespace intblast { } case OP_BUREM: case OP_BUREM_I: { - expr* x = arg(0), * y = umod(e, 1); + expr* x = umod(e, 0), * y = umod(e, 1); r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, a.mk_mod(x, y)); break; } From 6c3890eee3ab56790b3359f094ee288b777e27a2 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 19:55:16 -0800 Subject: [PATCH 72/89] merge Signed-off-by: Nikolaj Bjorner --- src/sat/smt/intblast_solver.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index fcb4699f9..ca8389557 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -71,6 +71,8 @@ namespace intblast { return n && n->is_attached_to(get_id()); } + + bool solver::post_visit(expr* e, bool sign, bool root) { euf::enode* n = expr2enode(e); app* a = to_app(e); From 54160d2efe64a6045d7ee9f4733c6def79f90413 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 20:28:03 -0800 Subject: [PATCH 73/89] merge Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat_internalize.cpp | 526 ++++++++++++++++++++++++++++ 1 file changed, 526 insertions(+) create mode 100644 src/sat/smt/polysat_internalize.cpp diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp new file mode 100644 index 000000000..ef469fe6f --- /dev/null +++ b/src/sat/smt/polysat_internalize.cpp @@ -0,0 +1,526 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + polysat_internalize.cpp + +Abstract: + + PolySAT internalize + +Author: + + Nikolaj Bjorner (nbjorner) 2022-01-26 + +--*/ + +#include "params/bv_rewriter_params.hpp" +#include "sat/smt/polysat_solver.h" +#include "sat/smt/euf_solver.h" + +namespace polysat { + + euf::theory_var solver::mk_var(euf::enode* n) { + theory_var v = euf::th_euf_solver::mk_var(n); + ctx.attach_th_var(n, this, v); + return v; + } + + sat::literal solver::internalize(expr* e, bool sign, bool root) { + force_push(); + SASSERT(m.is_bool(e)); + if (!visit_rec(m, e, sign, root)) + return sat::null_literal; + sat::literal lit = expr2literal(e); + if (sign) + lit.neg(); + return lit; + } + + void solver::internalize(expr* e) { + force_push(); + visit_rec(m, e, false, false); + } + + bool solver::visit(expr* e) { + force_push(); + if (!is_app(e) || to_app(e)->get_family_id() != get_id()) { + ctx.internalize(e); + return true; + } + m_stack.push_back(sat::eframe(e)); + return false; + } + + bool solver::visited(expr* e) { + euf::enode* n = expr2enode(e); + return n && n->is_attached_to(get_id()); + } + + bool solver::post_visit(expr* e, bool sign, bool root) { + euf::enode* n = expr2enode(e); + app* a = to_app(e); + + if (visited(e)) + return true; + + SASSERT(!n || !n->is_attached_to(get_id())); + if (!n) + n = mk_enode(e, false); + + SASSERT(!n->is_attached_to(get_id())); + mk_var(n); + SASSERT(n->is_attached_to(get_id())); + internalize_polysat(a); + return true; + } + + void solver::internalize_polysat(app* a) { + +#define if_unary(F) if (a->get_num_args() == 1) { internalize_unary(a, [&](pdd const& p) { return F(p); }); break; } + + switch (a->get_decl_kind()) { + case OP_BMUL: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p * q; }); break; + case OP_BADD: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p + q; }); break; + case OP_BSUB: internalize_binary(a, [&](pdd const& p, pdd const& q) { return p - q; }); break; + case OP_BLSHR: internalize_lshr(a); break; + case OP_BSHL: internalize_shl(a); break; + case OP_BASHR: internalize_ashr(a); break; + case OP_BAND: internalize_band(a); break; + case OP_BOR: internalize_bor(a); break; + case OP_BXOR: internalize_bxor(a); break; + case OP_BNAND: if_unary(m_core.bnot); internalize_bnand(a); break; + case OP_BNOR: if_unary(m_core.bnot); internalize_bnor(a); break; + case OP_BXNOR: if_unary(m_core.bnot); internalize_bxnor(a); break; + case OP_BNOT: internalize_unary(a, [&](pdd const& p) { return m_core.bnot(p); }); break; + case OP_BNEG: internalize_unary(a, [&](pdd const& p) { return -p; }); break; + case OP_MKBV: internalize_mkbv(a); break; + case OP_BV_NUM: internalize_num(a); break; + case OP_ULEQ: internalize_le(a); break; + case OP_SLEQ: internalize_le(a); break; + case OP_UGEQ: internalize_le(a); break; + case OP_SGEQ: internalize_le(a); break; + case OP_ULT: internalize_le(a); break; + case OP_SLT: internalize_le(a); break; + case OP_UGT: internalize_le(a); break; + case OP_SGT: internalize_le(a); break; + + case OP_BUMUL_NO_OVFL: internalize_binaryc(a, [&](pdd const& p, pdd const& q) { return m_core.umul_ovfl(p, q); }); break; + case OP_BSMUL_NO_OVFL: internalize_binaryc(a, [&](pdd const& p, pdd const& q) { return m_core.smul_ovfl(p, q); }); break; + case OP_BSMUL_NO_UDFL: internalize_binaryc(a, [&](pdd const& p, pdd const& q) { return m_core.smul_udfl(p, q); }); break; + + case OP_BUMUL_OVFL: + case OP_BSMUL_OVFL: + case OP_BSDIV_OVFL: + case OP_BNEG_OVFL: + case OP_BUADD_OVFL: + case OP_BSADD_OVFL: + case OP_BUSUB_OVFL: + case OP_BSSUB_OVFL: + verbose_stream() << mk_pp(a, m) << "\n"; + // handled by bv_rewriter for now + UNREACHABLE(); + break; + + case OP_BUDIV_I: internalize_udiv_i(a); break; + case OP_BUREM_I: internalize_urem_i(a); break; + + case OP_BUDIV: internalize_div_rem(a, true); break; + case OP_BUREM: internalize_div_rem(a, false); break; + case OP_BSDIV0: UNREACHABLE(); break; + case OP_BUDIV0: UNREACHABLE(); break; + case OP_BSREM0: UNREACHABLE(); break; + case OP_BUREM0: UNREACHABLE(); break; + case OP_BSMOD0: UNREACHABLE(); break; + + case OP_EXTRACT: internalize_extract(a); break; + case OP_CONCAT: internalize_concat(a); break; + case OP_ZERO_EXT: internalize_zero_extend(a); break; + case OP_SIGN_EXT: internalize_sign_extend(a); break; + + // polysat::solver should also support at least: + case OP_BREDAND: // x == 2^K - 1 unary, return single bit, 1 if all input bits are set. + case OP_BREDOR: // x > 0 unary, return single bit, 1 if at least one input bit is set. + case OP_BCOMP: // x == y binary, return single bit, 1 if the arguments are equal. + case OP_BSDIV: + case OP_BSREM: + case OP_BSMOD: + case OP_BSDIV_I: + case OP_BSREM_I: + case OP_BSMOD_I: + + IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); + NOT_IMPLEMENTED_YET(); + return; + default: + IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); + NOT_IMPLEMENTED_YET(); + return; + } +#undef if_unary + } + + class solver::mk_atom_trail : public trail { + solver& th; + sat::bool_var m_var; + public: + mk_atom_trail(sat::bool_var v, solver& th) : th(th), m_var(v) {} + void undo() override { + th.erase_bv2a(m_var); + } + }; + + void solver::mk_atom(sat::bool_var bv, signed_constraint& sc) { + if (get_bv2a(bv)) + return; + sat::literal lit(bv, false); + auto index = m_core.register_constraint(sc, dependency(lit, 0)); + auto a = new (get_region()) atom(bv, index); + insert_bv2a(bv, a); + ctx.push(mk_atom_trail(bv, *this)); + } + + void solver::internalize_binaryc(app* e, std::function const& fn) { + auto p = expr2pdd(e->get_arg(0)); + auto q = expr2pdd(e->get_arg(1)); + auto sc = ~fn(p, q); + sat::literal lit = expr2literal(e); + if (lit.sign()) + sc = ~sc; + mk_atom(lit.var(), sc); + } + + void solver::internalize_udiv_i(app* e) { + expr* x, *y; + expr_ref rm(m); + if (bv.is_bv_udivi(e, x, y)) + rm = bv.mk_bv_urem_i(x, y); + else if (bv.is_bv_udiv(e, x, y)) + rm = bv.mk_bv_urem(x, y); + else + UNREACHABLE(); + internalize(rm); + } + + // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; + // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 + // (p + q) - band(p, q); + void solver::internalize_bor(app* n) { + internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_sub(bv.mk_bv_add(x, y), bv.mk_bv_and(x, y)); }); + } + + // From "Hacker's Delight", section 2-2. Addition Combined with Logical Operations; + // found via Int-Blasting paper; see https://doi.org/10.1007/978-3-030-94583-1_24 + // (p + q) - 2*band(p, q); + void solver::internalize_bxor(app* n) { + internalize_binary(n, [&](expr* const& x, expr* const& y) { + return bv.mk_bv_sub(bv.mk_bv_add(x, y), bv.mk_bv_add(bv.mk_bv_and(x, y), bv.mk_bv_and(x, y))); + }); + } + + void solver::internalize_bnor(app* n) { + internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_not(bv.mk_bv_or(x, y)); }); + } + + void solver::internalize_bnand(app* n) { + internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_not(bv.mk_bv_and(x, y)); }); + } + + void solver::internalize_bxnor(app* n) { + internalize_binary(n, [&](expr* const& x, expr* const& y) { return bv.mk_bv_not(bv.mk_bv_xor(x, y)); }); + } + + void solver::internalize_band(app* n) { + if (n->get_num_args() == 2) { + expr* x, * y; + VERIFY(bv.is_bv_and(n, x, y)); + m_core.band(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + } + else { + expr_ref z(n->get_arg(0), m); + for (unsigned i = 1; i < n->get_num_args(); ++i) { + z = bv.mk_bv_and(z, n->get_arg(i)); + ctx.internalize(z); + } + internalize_set(n, expr2pdd(z)); + } + } + + void solver::internalize_lshr(app* n) { + expr* x, * y; + VERIFY(bv.is_bv_lshr(n, x, y)); + m_core.lshr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + } + + void solver::internalize_ashr(app* n) { + expr* x, * y; + VERIFY(bv.is_bv_ashr(n, x, y)); + m_core.ashr(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + } + + void solver::internalize_shl(app* n) { + expr* x, * y; + VERIFY(bv.is_bv_shl(n, x, y)); + m_core.shl(expr2pdd(x), expr2pdd(y), expr2pdd(n)); + } + + void solver::internalize_urem_i(app* rem) { + expr* x, *y; + euf::enode* n = expr2enode(rem); + SASSERT(n && n->is_attached_to(get_id())); + theory_var v = n->get_th_var(get_id()); + if (m_var2pdd_valid.get(v, false)) + return; + expr_ref quot(m); + if (bv.is_bv_uremi(rem, x, y)) + quot = bv.mk_bv_udiv_i(x, y); + else if (bv.is_bv_urem(rem, x, y)) + quot = bv.mk_bv_udiv(x, y); + else + UNREACHABLE(); + m_var2pdd_valid.setx(v, true, false); + ctx.internalize(quot); + m_var2pdd_valid.setx(v, false, false); + quot_rem(quot, rem, x, y); + } + + void solver::quot_rem(expr* quot, expr* rem, expr* x, expr* y) { + pdd a = expr2pdd(x); + pdd b = expr2pdd(y); + euf::enode* qn = expr2enode(quot); + euf::enode* rn = expr2enode(rem); + auto& m = a.manager(); + unsigned sz = m.power_of_2(); + if (b.is_zero()) { + // By SMT-LIB specification, b = 0 ==> q = -1, r = a. + internalize_set(quot, m.mk_val(m.max_value())); + internalize_set(rem, a); + return; + } + if (b.is_one()) { + internalize_set(quot, a); + internalize_set(rem, m.zero()); + return; + } + + if (a.is_val() && b.is_val()) { + rational const av = a.val(); + rational const bv = b.val(); + SASSERT(!bv.is_zero()); + rational rv; + rational qv = machine_div_rem(av, bv, rv); + pdd q = m.mk_val(qv); + pdd r = m.mk_val(rv); + SASSERT_EQ(a, b * q + r); + SASSERT(b.val() * q.val() + r.val() <= m.max_value()); + SASSERT(r.val() <= (b * q + r).val()); + SASSERT(r.val() < b.val()); + internalize_set(quot, q); + internalize_set(rem, r); + return; + } + + pdd r = var2pdd(rn->get_th_var(get_id())); + pdd q = var2pdd(qn->get_th_var(get_id())); + + // Axioms for quotient/remainder + // + // a = b*q + r + // multiplication does not overflow in b*q + // addition does not overflow in (b*q) + r; for now expressed as: r <= bq+r + // b ≠ 0 ==> r < b + // b = 0 ==> q = -1 + // TODO: when a,b become evaluable, can we actually propagate q,r? doesn't seem like it. + // Maybe we need something like an op_constraint for better propagation. + add_polysat_clause("[axiom] quot_rem 1", { m_core.eq(b * q + r - a) }, false); + add_polysat_clause("[axiom] quot_rem 2", { ~m_core.umul_ovfl(b, q) }, false); + // r <= b*q+r + // { apply equivalence: p <= q <=> q-p <= -p-1 } + // b*q <= -r-1 + add_polysat_clause("[axiom] quot_rem 3", { m_core.ule(b * q, -r - 1) }, false); + + auto c_eq = m_core.eq(b); + if (!c_eq.is_always_true()) + add_polysat_clause("[axiom] quot_rem 4", { c_eq, ~m_core.ule(b, r) }, false); + if (!c_eq.is_always_false()) + add_polysat_clause("[axiom] quot_rem 5", { ~c_eq, m_core.eq(q + 1) }, false); + } + + void solver::internalize_sign_extend(app* e) { + expr* arg = e->get_arg(0); + unsigned sz = bv.get_bv_size(e); + unsigned arg_sz = bv.get_bv_size(arg); + unsigned sz2 = sz - arg_sz; + var2pdd(expr2enode(e)->get_th_var(get_id())); + if (arg_sz == sz) + add_clause(eq_internalize(e, arg), nullptr); + else { + sat::literal lt0 = ctx.mk_literal(bv.mk_slt(arg, bv.mk_numeral(0, arg_sz))); + // arg < 0 ==> e = concat(arg, 1...1) + // arg >= 0 ==> e = concat(arg, 0...0) + add_clause(lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(rational::power_of_two(sz2) - 1, sz2))), nullptr); + add_clause(~lt0, eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz2))), nullptr); + } + } + + void solver::internalize_zero_extend(app* e) { + expr* arg = e->get_arg(0); + unsigned sz = bv.get_bv_size(e); + unsigned arg_sz = bv.get_bv_size(arg); + unsigned sz2 = sz - arg_sz; + var2pdd(expr2enode(e)->get_th_var(get_id())); + if (arg_sz == sz) + add_clause(eq_internalize(e, arg), nullptr); + else + // e = concat(arg, 0...0) + add_clause(eq_internalize(e, bv.mk_concat(arg, bv.mk_numeral(0, sz2))), nullptr); + } + + void solver::internalize_div_rem(app* e, bool is_div) { + bv_rewriter_params p(s().params()); + if (p.hi_div0()) { + if (is_div) + internalize_udiv_i(e); + else + internalize_urem_i(e); + return; + } + expr* arg1 = e->get_arg(0); + expr* arg2 = e->get_arg(1); + unsigned sz = bv.get_bv_size(e); + expr_ref zero(bv.mk_numeral(0, sz), m); + sat::literal eqZ = eq_internalize(arg2, zero); + sat::literal eqU = eq_internalize(e, is_div ? bv.mk_bv_udiv0(arg1) : bv.mk_bv_urem0(arg1)); + sat::literal eqI = eq_internalize(e, is_div ? bv.mk_bv_udiv_i(arg1, arg2) : bv.mk_bv_urem_i(arg1, arg2)); + add_clause(~eqZ, eqU); + add_clause(eqZ, eqI); + ctx.add_aux(~eqZ, eqU); + ctx.add_aux(eqZ, eqI); + } + + void solver::internalize_num(app* a) { + rational val; + unsigned sz = 0; + VERIFY(bv.is_numeral(a, val, sz)); + auto p = m_core.value(val, sz); + internalize_set(a, p); + } + + // TODO - test that internalize works with recursive call on bit2bool + void solver::internalize_mkbv(app* a) { + unsigned i = 0; + for (expr* arg : *a) { + expr_ref b2b(m); + b2b = bv.mk_bit2bool(a, i); + sat::literal bit_i = ctx.internalize(b2b, false, false); + sat::literal lit = expr2literal(arg); + add_equiv(lit, bit_i); +#if 0 + ctx.add_aux_equiv(lit, bit_i); +#endif + ++i; + } + } + + void solver::internalize_extract(app* e) { + var2pdd(expr2enode(e)->get_th_var(get_id())); + } + + void solver::internalize_concat(app* e) { + SASSERT(bv.is_concat(e)); + var2pdd(expr2enode(e)->get_th_var(get_id())); + } + + void solver::internalize_par_unary(app* e, std::function const& fn) { + pdd const p = expr2pdd(e->get_arg(0)); + unsigned const par = e->get_parameter(0).get_int(); + internalize_set(e, fn(p, par)); + } + + void solver::internalize_binary(app* e, std::function const& fn) { + SASSERT(e->get_num_args() >= 1); + auto p = expr2pdd(e->get_arg(0)); + for (unsigned i = 1; i < e->get_num_args(); ++i) + p = fn(p, expr2pdd(e->get_arg(i))); + internalize_set(e, p); + } + + void solver::internalize_binary(app* e, std::function const& fn) { + SASSERT(e->get_num_args() >= 1); + expr* r = e->get_arg(0); + for (unsigned i = 1; i < e->get_num_args(); ++i) + r = fn(r, e->get_arg(i)); + ctx.internalize(r); + internalize_set(e, var2pdd(expr2enode(r)->get_th_var(get_id()))); + } + + void solver::internalize_unary(app* e, std::function const& fn) { + SASSERT(e->get_num_args() == 1); + auto p = expr2pdd(e->get_arg(0)); + internalize_set(e, fn(p)); + } + + template + void solver::internalize_le(app* e) { + SASSERT(e->get_num_args() == 2); + auto p = expr2pdd(e->get_arg(0)); + auto q = expr2pdd(e->get_arg(1)); + if (Rev) + std::swap(p, q); + auto sc = Signed ? m_core.sle(p, q) : m_core.ule(p, q); + if (Negated) + sc = ~sc; + + sat::literal lit = expr2literal(e); + if (lit.sign()) + sc = ~sc; + mk_atom(lit.var(), sc); + } + + dd::pdd solver::expr2pdd(expr* e) { + return var2pdd(get_th_var(e)); + } + + dd::pdd solver::var2pdd(euf::theory_var v) { + if (!m_var2pdd_valid.get(v, false)) { + unsigned bv_size = get_bv_size(v); + pvar pv = m_core.add_var(bv_size); + m_pddvar2var.setx(pv, v, UINT_MAX); + pdd p = m_core.var(pv); + internalize_set(v, p); + return p; + } + return m_var2pdd[v]; + } + + void solver::apply_sort_cnstr(euf::enode* n, sort* s) { + if (!bv.is_bv(n->get_expr())) + return; + theory_var v = n->get_th_var(get_id()); + if (v == euf::null_theory_var) + v = mk_var(n); + var2pdd(v); + } + + void solver::internalize_set(expr* e, pdd const& p) { + internalize_set(get_th_var(e), p); + } + + void solver::internalize_set(euf::theory_var v, pdd const& p) { + SASSERT_EQ(get_bv_size(v), p.power_of_2()); + m_var2pdd.reserve(get_num_vars(), p); + m_var2pdd_valid.reserve(get_num_vars(), false); + ctx.push(set_bitvector_trail(m_var2pdd_valid, v)); +#if 0 + m_var2pdd[v].reset(p.manager()); +#endif + m_var2pdd[v] = p; + } + + void solver::eq_internalized(euf::enode* n) { + SASSERT(m.is_eq(n->get_expr())); + } + + +} From 7bcb4936c7ecb176d08db4ed2b22d898dca2d3a0 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 20:45:33 -0800 Subject: [PATCH 74/89] remove stale files Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat_assignment.cpp | 119 ---------------- src/sat/smt/polysat_assignment.h | 120 ---------------- src/sat/smt/polysat_constraints.cpp | 25 ---- src/sat/smt/polysat_constraints.h | 128 ----------------- src/sat/smt/polysat_substitution.h | 212 ---------------------------- src/sat/smt/polysat_viable.h | 130 ----------------- 6 files changed, 734 deletions(-) delete mode 100644 src/sat/smt/polysat_assignment.cpp delete mode 100644 src/sat/smt/polysat_assignment.h delete mode 100644 src/sat/smt/polysat_constraints.cpp delete mode 100644 src/sat/smt/polysat_constraints.h delete mode 100644 src/sat/smt/polysat_substitution.h delete mode 100644 src/sat/smt/polysat_viable.h diff --git a/src/sat/smt/polysat_assignment.cpp b/src/sat/smt/polysat_assignment.cpp deleted file mode 100644 index a985188fa..000000000 --- a/src/sat/smt/polysat_assignment.cpp +++ /dev/null @@ -1,119 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Module Name: - - polysat substitution and assignment - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-06 - ---*/ - -#include "sat/smt/polysat_assignment.h" -#include "sat/smt/polysat_core.h" - -namespace polysat { - - substitution::substitution(pdd p) - : m_subst(std::move(p)) { } - - substitution::substitution(dd::pdd_manager& m) - : m_subst(m.one()) { } - - substitution substitution::add(pvar var, rational const& value) const { - return {m_subst.subst_add(var, value)}; - } - - pdd substitution::apply_to(pdd const& p) const { - return p.subst_val(m_subst); - } - - bool substitution::contains(pvar var) const { - rational out_value; - return value(var, out_value); - } - - bool substitution::value(pvar var, rational& out_value) const { - return m_subst.subst_get(var, out_value); - } - - assignment::assignment(core& s) - : m_core(s) { } - - - assignment assignment::clone() const { - assignment a(s()); - a.m_pairs = m_pairs; - a.m_subst.reserve(m_subst.size()); - for (unsigned i = m_subst.size(); i-- > 0; ) - if (m_subst[i]) - a.m_subst.set(i, alloc(substitution, *m_subst[i])); - a.m_subst_trail = m_subst_trail; - return a; - } - - bool assignment::contains(pvar var) const { - return subst(s().size(var)).contains(var); - } - - bool assignment::value(pvar var, rational& out_value) const { - return subst(s().size(var)).value(var, out_value); - } - - substitution& assignment::subst(unsigned sz) { - return const_cast(std::as_const(*this).subst(sz)); - } - - substitution const& assignment::subst(unsigned sz) const { - m_subst.reserve(sz + 1); - if (!m_subst[sz]) - m_subst.set(sz, alloc(substitution, s().sz2pdd(sz))); - return *m_subst[sz]; - } - - void assignment::push(pvar var, rational const& value) { - SASSERT(all_of(m_pairs, [var](assignment_item_t const& item) { return item.first != var; })); - m_pairs.push_back({var, value}); - unsigned const sz = s().size(var); - substitution& sub = subst(sz); - m_subst_trail.push_back(sub); - sub = sub.add(var, value); - SASSERT_EQ(sub, *m_subst[sz]); - } - - void assignment::pop() { - substitution& sub = m_subst_trail.back(); - unsigned sz = sub.bit_width(); - SASSERT_EQ(sz, s().size(m_pairs.back().first)); - *m_subst[sz] = sub; - m_subst_trail.pop_back(); - m_pairs.pop_back(); - } - - pdd assignment::apply_to(pdd const& p) const { - unsigned const sz = p.power_of_2(); - return subst(sz).apply_to(p); - } - - std::ostream& substitution::display(std::ostream& out) const { - char const* delim = ""; - pdd p = m_subst; - while (!p.is_val()) { - SASSERT(p.lo().is_val()); - out << delim << "v" << p.var() << " := " << p.lo(); - delim = " "; - p = p.hi(); - } - return out; - } - - std::ostream& assignment::display(std::ostream& out) const { - char const* delim = ""; - for (auto const& [var, value] : m_pairs) - out << delim << var << " == " << value, delim = " "; - return out; - } -} diff --git a/src/sat/smt/polysat_assignment.h b/src/sat/smt/polysat_assignment.h deleted file mode 100644 index daff03dd5..000000000 --- a/src/sat/smt/polysat_assignment.h +++ /dev/null @@ -1,120 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Module Name: - - polysat substitution and assignment - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-06 - ---*/ -#pragma once -#include "util/scoped_ptr_vector.h" -#include "sat/smt/polysat_types.h" - -namespace polysat { - - class core; - - using assignment_item_t = std::pair; - - class substitution_iterator { - pdd m_current; - substitution_iterator(pdd current) : m_current(std::move(current)) {} - friend class substitution; - - public: - using value_type = assignment_item_t; - using difference_type = std::ptrdiff_t; - using pointer = value_type const*; - using reference = value_type const&; - using iterator_category = std::input_iterator_tag; - - substitution_iterator& operator++() { - SASSERT(!m_current.is_val()); - m_current = m_current.hi(); - return *this; - } - - value_type operator*() const { - SASSERT(!m_current.is_val()); - return { m_current.var(), m_current.lo().val() }; - } - - bool operator==(substitution_iterator const& other) const { return m_current == other.m_current; } - bool operator!=(substitution_iterator const& other) const { return !operator==(other); } - }; - - /** Substitution for a single bit width. */ - class substitution { - pdd m_subst; - - substitution(pdd p); - - public: - substitution(dd::pdd_manager& m); - [[nodiscard]] substitution add(pvar var, rational const& value) const; - [[nodiscard]] pdd apply_to(pdd const& p) const; - - [[nodiscard]] bool contains(pvar var) const; - [[nodiscard]] bool value(pvar var, rational& out_value) const; - - [[nodiscard]] bool empty() const { return m_subst.is_one(); } - - pdd const& to_pdd() const { return m_subst; } - unsigned bit_width() const { return to_pdd().power_of_2(); } - - bool operator==(substitution const& other) const { return &m_subst.manager() == &other.m_subst.manager() && m_subst == other.m_subst; } - bool operator!=(substitution const& other) const { return !operator==(other); } - - std::ostream& display(std::ostream& out) const; - - using const_iterator = substitution_iterator; - const_iterator begin() const { return {m_subst}; } - const_iterator end() const { return {m_subst.manager().one()}; } - }; - - /** Full variable assignment, may include variables of varying bit widths. */ - class assignment { - core& m_core; - vector m_pairs; - mutable scoped_ptr_vector m_subst; - vector m_subst_trail; - - substitution& subst(unsigned sz); - core& s() const { return m_core; } - public: - assignment(core& s); - // prevent implicit copy, use clone() if you do need a copy - assignment(assignment const&) = delete; - assignment& operator=(assignment const&) = delete; - assignment(assignment&&) = default; - assignment& operator=(assignment&&) = default; - assignment clone() const; - - void push(pvar var, rational const& value); - void pop(); - - pdd apply_to(pdd const& p) const; - - bool contains(pvar var) const; - bool value(pvar var, rational& out_value) const; - rational value(pvar var) const { rational val; VERIFY(value(var, val)); return val; } - bool empty() const { return pairs().empty(); } - substitution const& subst(unsigned sz) const; - vector const& pairs() const { return m_pairs; } - using const_iterator = decltype(m_pairs)::const_iterator; - const_iterator begin() const { return pairs().begin(); } - const_iterator end() const { return pairs().end(); } - - std::ostream& display(std::ostream& out) const; - }; - - inline std::ostream& operator<<(std::ostream& out, substitution const& sub) { return sub.display(out); } - - inline std::ostream& operator<<(std::ostream& out, assignment const& a) { return a.display(out); } -} - diff --git a/src/sat/smt/polysat_constraints.cpp b/src/sat/smt/polysat_constraints.cpp deleted file mode 100644 index 1c9de327c..000000000 --- a/src/sat/smt/polysat_constraints.cpp +++ /dev/null @@ -1,25 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Module Name: - - polysat constraints - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-06 - ---*/ -#include "sat/smt/polysat_core.h" -#include "sat/smt/polysat_solver.h" -#include "sat/smt/polysat_constraints.h" - -namespace polysat { - - signed_constraint constraints::ule(pdd const& p, pdd const& q) { - auto* c = alloc(ule_constraint, p, q); - m_trail.push(new_obj_trail(c)); - return signed_constraint(ckind_t::ule_t, c); - } -} diff --git a/src/sat/smt/polysat_constraints.h b/src/sat/smt/polysat_constraints.h deleted file mode 100644 index 24c7f9a11..000000000 --- a/src/sat/smt/polysat_constraints.h +++ /dev/null @@ -1,128 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Module Name: - - polysat constraints - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-06 - ---*/ - - -#pragma once -#include "sat/smt/polysat_types.h" - -namespace polysat { - - class core; - - using pdd = dd::pdd; - using pvar = unsigned; - - enum ckind_t { ule_t, umul_ovfl_t, smul_fl_t, op_t }; - - class constraint { - unsigned_vector m_vars; - public: - virtual ~constraint() {} - unsigned_vector& vars() { return m_vars; } - unsigned_vector const& vars() const { return m_vars; } - unsigned var(unsigned idx) const { return m_vars[idx]; } - bool contains_var(pvar v) const { return m_vars.contains(v); } - }; - - class ule_constraint : public constraint { - pdd m_lhs, m_rhs; - public: - ule_constraint(pdd const& lhs, pdd const& rhs) : m_lhs(lhs), m_rhs(rhs) {} - }; - - class signed_constraint { - bool m_sign = false; - ckind_t m_op = ule_t; - constraint* m_constraint = nullptr; - public: - signed_constraint() {} - signed_constraint(ckind_t c, constraint* p) : m_op(c), m_constraint(p) {} - signed_constraint operator~() const { signed_constraint r(*this); r.m_sign = !r.m_sign; return r; } - bool sign() const { return m_sign; } - unsigned_vector& vars() { return m_constraint->vars(); } - unsigned_vector const& vars() const { return m_constraint->vars(); } - unsigned var(unsigned idx) const { return m_constraint->var(idx); } - bool contains_var(pvar v) const { return m_constraint->contains_var(v); } - bool is_ule() const { return m_op == ule_t; } - ule_constraint& to_ule() { return *reinterpret_cast(m_constraint); } - bool is_eq(pvar& v, rational& val) { throw default_exception("nyi"); } - }; - - using dependent_constraint = std::pair; - - class constraints { - trail_stack& m_trail; - public: - constraints(trail_stack& c) : m_trail(c) {} - - signed_constraint eq(pdd const& p) { return ule(p, p.manager().mk_val(0)); } - signed_constraint eq(pdd const& p, rational const& v) { return eq(p - p.manager().mk_val(v)); } - signed_constraint ule(pdd const& p, pdd const& q); - signed_constraint sle(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint ult(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint slt(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint umul_ovfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint smul_ovfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint smul_udfl(pdd const& p, pdd const& q) { throw default_exception("nyi"); } - signed_constraint bit(pdd const& p, unsigned i) { throw default_exception("nyi"); } - - signed_constraint diseq(pdd const& p) { return ~eq(p); } - signed_constraint diseq(pdd const& p, pdd const& q) { return diseq(p - q); } - signed_constraint diseq(pdd const& p, rational const& q) { return diseq(p - q); } - signed_constraint diseq(pdd const& p, int q) { return diseq(p, rational(q)); } - signed_constraint diseq(pdd const& p, unsigned q) { return diseq(p, rational(q)); } - - signed_constraint ule(pdd const& p, rational const& q) { return ule(p, p.manager().mk_val(q)); } - signed_constraint ule(rational const& p, pdd const& q) { return ule(q.manager().mk_val(p), q); } - signed_constraint ule(pdd const& p, int q) { return ule(p, rational(q)); } - signed_constraint ule(pdd const& p, unsigned q) { return ule(p, rational(q)); } - signed_constraint ule(int p, pdd const& q) { return ule(rational(p), q); } - signed_constraint ule(unsigned p, pdd const& q) { return ule(rational(p), q); } - - signed_constraint uge(pdd const& p, pdd const& q) { return ule(q, p); } - signed_constraint uge(pdd const& p, rational const& q) { return ule(q, p); } - - signed_constraint ult(pdd const& p, rational const& q) { return ult(p, p.manager().mk_val(q)); } - signed_constraint ult(rational const& p, pdd const& q) { return ult(q.manager().mk_val(p), q); } - signed_constraint ult(int p, pdd const& q) { return ult(rational(p), q); } - signed_constraint ult(unsigned p, pdd const& q) { return ult(rational(p), q); } - signed_constraint ult(pdd const& p, int q) { return ult(p, rational(q)); } - signed_constraint ult(pdd const& p, unsigned q) { return ult(p, rational(q)); } - - signed_constraint slt(pdd const& p, rational const& q) { return slt(p, p.manager().mk_val(q)); } - signed_constraint slt(rational const& p, pdd const& q) { return slt(q.manager().mk_val(p), q); } - signed_constraint slt(pdd const& p, int q) { return slt(p, rational(q)); } - signed_constraint slt(pdd const& p, unsigned q) { return slt(p, rational(q)); } - signed_constraint slt(int p, pdd const& q) { return slt(rational(p), q); } - signed_constraint slt(unsigned p, pdd const& q) { return slt(rational(p), q); } - - - signed_constraint sgt(pdd const& p, pdd const& q) { return slt(q, p); } - signed_constraint sgt(pdd const& p, int q) { return slt(q, p); } - signed_constraint sgt(pdd const& p, unsigned q) { return slt(q, p); } - signed_constraint sgt(int p, pdd const& q) { return slt(q, p); } - signed_constraint sgt(unsigned p, pdd const& q) { return slt(q, p); } - - signed_constraint umul_ovfl(pdd const& p, rational const& q) { return umul_ovfl(p, p.manager().mk_val(q)); } - signed_constraint umul_ovfl(rational const& p, pdd const& q) { return umul_ovfl(q.manager().mk_val(p), q); } - signed_constraint umul_ovfl(pdd const& p, int q) { return umul_ovfl(p, rational(q)); } - signed_constraint umul_ovfl(pdd const& p, unsigned q) { return umul_ovfl(p, rational(q)); } - signed_constraint umul_ovfl(int p, pdd const& q) { return umul_ovfl(rational(p), q); } - signed_constraint umul_ovfl(unsigned p, pdd const& q) { return umul_ovfl(rational(p), q); } - - - //signed_constraint even(pdd const& p) { return parity_at_least(p, 1); } - //signed_constraint odd(pdd const& p) { return ~even(p); } - }; -} \ No newline at end of file diff --git a/src/sat/smt/polysat_substitution.h b/src/sat/smt/polysat_substitution.h deleted file mode 100644 index a30c6b710..000000000 --- a/src/sat/smt/polysat_substitution.h +++ /dev/null @@ -1,212 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Module Name: - - polysat substitution - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-06 - ---*/ -#pragma once -#include "sat/smt/polysat_types.h" - -namespace polysat { - - using assignment_item_t = std::pair; - - class substitution_iterator { - pdd m_current; - substitution_iterator(pdd current) : m_current(std::move(current)) {} - friend class substitution; - - public: - using value_type = assignment_item_t; - using difference_type = std::ptrdiff_t; - using pointer = value_type const*; - using reference = value_type const&; - using iterator_category = std::input_iterator_tag; - - substitution_iterator& operator++() { - SASSERT(!m_current.is_val()); - m_current = m_current.hi(); - return *this; - } - - value_type operator*() const { - SASSERT(!m_current.is_val()); - return { m_current.var(), m_current.lo().val() }; - } - - bool operator==(substitution_iterator const& other) const { return m_current == other.m_current; } - bool operator!=(substitution_iterator const& other) const { return !operator==(other); } - }; - - /** Substitution for a single bit width. */ - class substitution { - pdd m_subst; - - substitution(pdd p); - - public: - substitution(dd::pdd_manager& m); - [[nodiscard]] substitution add(pvar var, rational const& value) const; - [[nodiscard]] pdd apply_to(pdd const& p) const; - - [[nodiscard]] bool contains(pvar var) const; - [[nodiscard]] bool value(pvar var, rational& out_value) const; - - [[nodiscard]] bool empty() const { return m_subst.is_one(); } - - pdd const& to_pdd() const { return m_subst; } - unsigned bit_width() const { return to_pdd().power_of_2(); } - - bool operator==(substitution const& other) const { return &m_subst.manager() == &other.m_subst.manager() && m_subst == other.m_subst; } - bool operator!=(substitution const& other) const { return !operator==(other); } - - std::ostream& display(std::ostream& out) const; - - using const_iterator = substitution_iterator; - const_iterator begin() const { return {m_subst}; } - const_iterator end() const { return {m_subst.manager().one()}; } - }; - - /** Full variable assignment, may include variables of varying bit widths. */ - class assignment { - vector m_pairs; - mutable scoped_ptr_vector m_subst; - vector m_subst_trail; - - substitution& subst(unsigned sz); - solver& s() const { return *m_solver; } - public: - assignment(solver& s); - // prevent implicit copy, use clone() if you do need a copy - assignment(assignment const&) = delete; - assignment& operator=(assignment const&) = delete; - assignment(assignment&&) = default; - assignment& operator=(assignment&&) = default; - assignment clone() const; - - void push(pvar var, rational const& value); - void pop(); - - pdd apply_to(pdd const& p) const; - - bool contains(pvar var) const; - bool value(pvar var, rational& out_value) const; - rational value(pvar var) const { rational val; VERIFY(value(var, val)); return val; } - bool empty() const { return pairs().empty(); } - substitution const& subst(unsigned sz) const; - vector const& pairs() const { return m_pairs; } - using const_iterator = decltype(m_pairs)::const_iterator; - const_iterator begin() const { return pairs().begin(); } - const_iterator end() const { return pairs().end(); } - - std::ostream& display(std::ostream& out) const; - }; - - inline std::ostream& operator<<(std::ostream& out, substitution const& sub) { return sub.display(out); } - - inline std::ostream& operator<<(std::ostream& out, assignment const& a) { return a.display(out); } -} - -namespace polysat { - - enum class search_item_k - { - assignment, - boolean, - }; - - class search_item { - search_item_k m_kind; - union { - pvar m_var; - sat::literal m_lit; - }; - bool m_resolved = false; // when marked as resolved it is no longer valid to reduce the conflict state - - search_item(pvar var): m_kind(search_item_k::assignment), m_var(var) {} - search_item(sat::literal lit): m_kind(search_item_k::boolean), m_lit(lit) {} - public: - static search_item assignment(pvar var) { return search_item(var); } - static search_item boolean(sat::literal lit) { return search_item(lit); } - bool is_assignment() const { return m_kind == search_item_k::assignment; } - bool is_boolean() const { return m_kind == search_item_k::boolean; } - bool is_resolved() const { return m_resolved; } - search_item_k kind() const { return m_kind; } - pvar var() const { SASSERT(is_assignment()); return m_var; } - sat::literal lit() const { SASSERT(is_boolean()); return m_lit; } - void set_resolved() { m_resolved = true; } - }; - - class search_state { - solver& s; - - vector m_items; - assignment m_assignment; - - // store index into m_items - unsigned_vector m_pvar_to_idx; - unsigned_vector m_bool_to_idx; - - bool value(pvar v, rational& r) const; - - public: - search_state(solver& s): s(s), m_assignment(s) {} - unsigned size() const { return m_items.size(); } - search_item const& back() const { return m_items.back(); } - search_item const& operator[](unsigned i) const { return m_items[i]; } - - assignment const& get_assignment() const { return m_assignment; } - substitution const& subst(unsigned sz) const { return m_assignment.subst(sz); } - - // TODO: implement the following method if we actually need the assignments without resolved items already during conflict resolution - // (no separate trail needed, just a second m_subst and an index into the trail, I think) - // (update on set_resolved? might be one iteration too early, looking at the old solver::resolve_conflict loop) - substitution const& unresolved_assignment(unsigned sz) const; - - void push_assignment(pvar v, rational const& r); - void push_boolean(sat::literal lit); - void pop(); - - unsigned get_pvar_index(pvar v) const; - unsigned get_bool_index(sat::bool_var var) const; - unsigned get_bool_index(sat::literal lit) const { return get_bool_index(lit.var()); } - - void set_resolved(unsigned i) { m_items[i].set_resolved(); } - - using const_iterator = decltype(m_items)::const_iterator; - const_iterator begin() const { return m_items.begin(); } - const_iterator end() const { return m_items.end(); } - - std::ostream& display(std::ostream& out) const; - std::ostream& display(search_item const& item, std::ostream& out) const; - std::ostream& display_verbose(std::ostream& out) const; - std::ostream& display_verbose(search_item const& item, std::ostream& out) const; - }; - - struct search_state_pp { - search_state const& s; - bool verbose; - search_state_pp(search_state const& s, bool verbose = false) : s(s), verbose(verbose) {} - }; - - struct search_item_pp { - search_state const& s; - search_item const& i; - bool verbose; - search_item_pp(search_state const& s, search_item const& i, bool verbose = false) : s(s), i(i), verbose(verbose) {} - }; - - inline std::ostream& operator<<(std::ostream& out, search_state const& s) { return s.display(out); } - - inline std::ostream& operator<<(std::ostream& out, search_state_pp const& p) { return p.verbose ? p.s.display_verbose(out) : p.s.display(out); } - - inline std::ostream& operator<<(std::ostream& out, search_item_pp const& p) { return p.verbose ? p.s.display_verbose(p.i, out) : p.s.display(p.i, out); } - -} diff --git a/src/sat/smt/polysat_viable.h b/src/sat/smt/polysat_viable.h deleted file mode 100644 index e268d67a3..000000000 --- a/src/sat/smt/polysat_viable.h +++ /dev/null @@ -1,130 +0,0 @@ -/*++ -Copyright (c) 2021 Microsoft Corporation - -Module Name: - - maintain viable domains - It uses the interval extraction functions from forbidden intervals. - An empty viable set corresponds directly to a conflict that does not rely on - the non-viable variable. - -Author: - - Nikolaj Bjorner (nbjorner) 2021-03-19 - Jakob Rath 2021-04-06 - ---*/ -#pragma once - -#include "util/rational.h" -#include "util/dlist.h" -#include "util/map.h" -#include "util/small_object_allocator.h" - -#include "sat/smt/polysat_types.h" -#include "sat/smt/polysat_fi.h" - -namespace polysat { - - enum class find_t { - empty, - singleton, - multiple, - resource_out, - }; - - class core; - class constraints; - - class viable { - core& c; - constraints& cs; - forbidden_intervals m_forbidden_intervals; - - struct entry final : public dll_base, public fi_record { - /// whether the entry has been created by refinement (from constraints in 'fi_record::src') - bool refined = false; - /// whether the entry is part of the current set of intervals, or stashed away for backtracking - bool active = true; - bool valid_for_lemma = true; - pvar var = null_var; - - void reset() { - // dll_base::init(this); // we never did this in alloc_entry either - fi_record::reset(); - refined = false; - active = true; - valid_for_lemma = true; - var = null_var; - } - }; - - enum class entry_kind { unit_e, equal_e, diseq_e }; - - struct layer final { - entry* entries = nullptr; - unsigned bit_width = 0; - layer(unsigned bw) : bit_width(bw) {} - }; - - class layers final { - svector m_layers; - public: - svector const& get_layers() const { return m_layers; } - layer& ensure_layer(unsigned bit_width); - layer* get_layer(unsigned bit_width); - layer* get_layer(entry* e) { return get_layer(e->bit_width); } - layer const* get_layer(unsigned bit_width) const; - layer const* get_layer(entry* e) const { return get_layer(e->bit_width); } - entry* get_entries(unsigned bit_width) const { layer const* l = get_layer(bit_width); return l ? l->entries : nullptr; } - }; - - ptr_vector m_alloc; - vector m_units; // set of viable values based on unit multipliers, layered by bit-width in descending order - ptr_vector m_equal_lin; // entries that have non-unit multipliers, but are equal - ptr_vector m_diseq_lin; // entries that have distinct non-zero multipliers - - bool well_formed(entry* e); - bool well_formed(layers const& ls); - - entry* alloc_entry(pvar v); - - std::ostream& display_one(std::ostream& out, pvar v, entry const* e) const; - std::ostream& display_all(std::ostream& out, pvar v, entry const* e, char const* delimiter = "") const; - void log(); - void log(pvar v); - - struct pop_viable_trail; - void pop_viable(entry* e, pvar v, entry_kind k); - struct push_viable_trail; - void push_viable(entry* e, pvar v, entry_kind k); - - void insert(entry* e, pvar v, ptr_vector& entries, entry_kind k); - - bool intersect(pvar v, entry* e); - - void ensure_var(pvar v); - - public: - viable(core& c); - - ~viable(); - - /** - * Find a next viable value for variable. - */ - find_t find_viable(pvar v, rational& out_val); - - /* - * Explain why the current variable is not viable or signleton. - */ - dependency_vector explain(); - - /* - * Register constraint at index 'idx' as unitary in v. - */ - void add_unitary(pvar v, unsigned idx); - - }; - -} From ec6cab377ad7f1282a7f1dbf3107b142e80d4a76 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Wed, 13 Dec 2023 21:16:02 -0800 Subject: [PATCH 75/89] bv semantics Signed-off-by: Nikolaj Bjorner --- src/sat/smt/intblast_solver.cpp | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index ca8389557..5d9608163 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -541,6 +541,23 @@ namespace intblast { return r; }; + auto rotate_left = [&](unsigned n) { + auto sz = bv.get_bv_size(e); + expr* r = arg(0); + if (n != 0 && sz != 1) { + // r[sz - n - 1 : 0] ++ r[sz - 1 : sz - n] + // r * 2^(sz - n) + (r / 2^(sz - n)) mod n??? + NOT_IMPLEMENTED_YET(); + auto N = bv_size(e); + auto A = rational::power_of_two(sz - n); + auto B = rational::power_of_two(n); + auto hi = a.mk_mul(r, a.mk_int(B)); + auto lo = a.mk_idiv(a.mk_mod(r, a.mk_int(B)), a.mk_int(A)); + r = a.mk_add(hi, lo); + } + return r; + }; + expr* bv_expr = e; expr* r = nullptr; auto const& args = m_args; @@ -590,7 +607,7 @@ namespace intblast { break; case OP_CONCAT: { unsigned sz = 0; - for (unsigned i = 0; i < args.size(); ++i) { + for (unsigned i = args.size(); i-- > 0; ++i) { expr* old_arg = e->get_arg(i); expr* new_arg = umod(old_arg, i); if (sz > 0) { @@ -786,8 +803,18 @@ namespace intblast { r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); break; } - case OP_ROTATE_LEFT: - case OP_ROTATE_RIGHT: + case OP_ROTATE_LEFT: { + auto n = e->get_parameter(0).get_int(); + r = rotate_left(n); + break; + } + case OP_ROTATE_RIGHT: { + auto n = e->get_parameter(0).get_int(); + unsigned sz = bv.get_bv_size(e); + n = n % sz; + r = rotate_left(sz - n); + break; + } case OP_EXT_ROTATE_LEFT: case OP_EXT_ROTATE_RIGHT: case OP_REPEAT: From 4af6238f1c006d917cbc5691976d44dd3ebccf4f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 14 Dec 2023 10:35:13 -0800 Subject: [PATCH 76/89] weed out some bugs, add more bv op support in intblast and polysat solvers Signed-off-by: Nikolaj Bjorner --- src/sat/smt/intblast_solver.cpp | 57 ++++++++++---- src/sat/smt/polysat/core.cpp | 13 ++-- src/sat/smt/polysat/types.h | 18 ----- src/sat/smt/polysat/viable.h | 20 +++++ src/sat/smt/polysat_internalize.cpp | 110 +++++++++++++++++++++++++--- src/sat/smt/polysat_solver.cpp | 28 +------ src/sat/smt/polysat_solver.h | 6 ++ 7 files changed, 175 insertions(+), 77 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 5d9608163..592c8c0f4 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -543,16 +543,17 @@ namespace intblast { auto rotate_left = [&](unsigned n) { auto sz = bv.get_bv_size(e); + n = n % sz; expr* r = arg(0); if (n != 0 && sz != 1) { // r[sz - n - 1 : 0] ++ r[sz - 1 : sz - n] - // r * 2^(sz - n) + (r / 2^(sz - n)) mod n??? - NOT_IMPLEMENTED_YET(); + // r * 2^(sz - n) + (r div 2^n) mod 2^(sz - n)??? + // r * A + (r div B) mod A auto N = bv_size(e); auto A = rational::power_of_two(sz - n); auto B = rational::power_of_two(n); - auto hi = a.mk_mul(r, a.mk_int(B)); - auto lo = a.mk_idiv(a.mk_mod(r, a.mk_int(B)), a.mk_int(A)); + auto hi = a.mk_mul(r, a.mk_int(A)); + auto lo = a.mk_mod(a.mk_idiv(umod(e, 0), a.mk_int(B)), a.mk_int(A)); r = a.mk_add(hi, lo); } return r; @@ -607,7 +608,7 @@ namespace intblast { break; case OP_CONCAT: { unsigned sz = 0; - for (unsigned i = args.size(); i-- > 0; ++i) { + for (unsigned i = args.size(); i-- > 0;) { expr* old_arg = e->get_arg(i); expr* new_arg = umod(old_arg, i); if (sz > 0) { @@ -809,20 +810,48 @@ namespace intblast { break; } case OP_ROTATE_RIGHT: { - auto n = e->get_parameter(0).get_int(); unsigned sz = bv.get_bv_size(e); - n = n % sz; + auto n = e->get_parameter(0).get_int(); r = rotate_left(sz - n); break; } - case OP_EXT_ROTATE_LEFT: - case OP_EXT_ROTATE_RIGHT: - case OP_REPEAT: - case OP_BREDOR: - case OP_BREDAND: - verbose_stream() << mk_pp(e, m) << "\n"; - NOT_IMPLEMENTED_YET(); + case OP_EXT_ROTATE_LEFT: { + unsigned sz = bv.get_bv_size(e); + expr* y = umod(e, 1); + r = a.mk_int(0); + for (unsigned i = 0; i < sz; ++i) + r = m.mk_ite(m.mk_eq(a.mk_int(i), y), rotate_left(i), r); break; + } + case OP_EXT_ROTATE_RIGHT: { + unsigned sz = bv.get_bv_size(e); + expr* y = umod(e, 1); + r = a.mk_int(0); + for (unsigned i = 0; i < sz; ++i) + r = m.mk_ite(m.mk_eq(a.mk_int(i), y), rotate_left(sz - i), r); + break; + } + case OP_REPEAT: { + unsigned n = e->get_parameter(0).get_int(); + expr* x = umod(e->get_arg(0), 0); + r = x; + rational N = bv_size(e->get_arg(0)); + rational N0 = N; + for (unsigned i = 1; i < n; ++i) + r = a.mk_add(a.mk_mul(a.mk_int(N), x), r), N *= N0; + break; + } + case OP_BREDOR: { + r = umod(e->get_arg(0), 0); + r = m.mk_not(m.mk_eq(r, a.mk_int(0))); + break; + } + case OP_BREDAND: { + rational N = bv_size(e->get_arg(0)); + r = umod(e->get_arg(0), 0); + r = m.mk_not(m.mk_eq(r, a.mk_int(N - 1))); + break; + } default: verbose_stream() << mk_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index c0b56a3d8..b3f8474b7 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -140,14 +140,14 @@ namespace polysat { for (; i < sz && j < 2; ++i) if (!is_assigned(vars[i])) std::swap(vars[i], vars[j++]); - sc.set_num_watch(i); - if (i > 0) + sc.set_num_watch(j); + if (j > 0) add_watch(idx, vars[0]); - if (i > 1) + if (j > 1) add_watch(idx, vars[1]); IF_VERBOSE(10, verbose_stream() << "add watch " << sc << " " << vars << " "; - if (vars.size() > 0) verbose_stream() << "( " << idx << " : " << m_watch[vars[0]] << ") "; - if (vars.size() > 1) verbose_stream() << "( " << idx << " : " << m_watch[vars[1]] << ") "; + if (j > 0) verbose_stream() << "( " << idx << " : " << m_watch[vars[0]] << ") "; + if (j > 1) verbose_stream() << "( " << idx << " : " << m_watch[vars[1]] << ") "; verbose_stream() << "\n"); s.trail().push(mk_add_watch(*this)); return idx; @@ -227,7 +227,6 @@ namespace polysat { m_assignment.push(v , value); s.trail().push(mk_assign_var(v, *this)); - return; // update the watch lists for pvars // remove constraints from m_watch[v] that have more than 2 free variables. // for entries where there is only one free variable left add to viable set @@ -242,7 +241,7 @@ namespace polysat { bool swapped = false; for (unsigned i = vars.size(); i-- > 2; ) { if (!is_assigned(vars[i])) { - verbose_stream() << "watch instead " << idx << " " << vars[i] << "instead of " << vars[0] << "\n"; + verbose_stream() << "watch instead " << vars[i] << " instead of " << vars[0] << " for " << idx << "\n"; add_watch(idx, vars[i]); std::swap(vars[i], vars[0]); swapped = true; diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index e7beb3eb1..d0b5f7bca 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -59,24 +59,6 @@ namespace polysat { return out << "v" << d.eq().first << " == v" << d.eq().second << "@" << d.level(); } - struct trailing_bits { - unsigned length; - rational bits; - bool positive; - unsigned src_idx; - }; - - struct leading_bits { - unsigned length; - bool positive; // either all 0 or all 1 - unsigned src_idx; - }; - - struct single_bit { - bool positive; - unsigned position; - unsigned src_idx; - }; struct fixed_bits { unsigned hi = 0; diff --git a/src/sat/smt/polysat/viable.h b/src/sat/smt/polysat/viable.h index 5f5af7616..64a8d3194 100644 --- a/src/sat/smt/polysat/viable.h +++ b/src/sat/smt/polysat/viable.h @@ -33,6 +33,26 @@ namespace polysat { resource_out, }; + struct trailing_bits { + unsigned length; + rational bits; + bool positive; + unsigned src_idx; + }; + + struct leading_bits { + unsigned length; + bool positive; // either all 0 or all 1 + unsigned src_idx; + }; + + struct single_bit { + bool positive; + unsigned position; + unsigned src_idx; + }; + + class core; class constraints; diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index ef469fe6f..0ba59de0a 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -139,20 +139,21 @@ namespace polysat { case OP_ZERO_EXT: internalize_zero_extend(a); break; case OP_SIGN_EXT: internalize_sign_extend(a); break; - // polysat::solver should also support at least: + case OP_BSREM: + case OP_BSREM_I: + case OP_BSMOD: + case OP_BSMOD_I: + case OP_BSDIV: + case OP_BSDIV_I: + expr2pdd(a); + m_delayed_axioms.push_back(a); + ctx.push(push_back_vector(m_delayed_axioms)); + break; + case OP_BREDAND: // x == 2^K - 1 unary, return single bit, 1 if all input bits are set. case OP_BREDOR: // x > 0 unary, return single bit, 1 if at least one input bit is set. case OP_BCOMP: // x == y binary, return single bit, 1 if the arguments are equal. - case OP_BSDIV: - case OP_BSREM: - case OP_BSMOD: - case OP_BSDIV_I: - case OP_BSREM_I: - case OP_BSMOD_I: - IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); - NOT_IMPLEMENTED_YET(); - return; default: IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); NOT_IMPLEMENTED_YET(); @@ -263,7 +264,94 @@ namespace polysat { expr* x, * y; VERIFY(bv.is_bv_shl(n, x, y)); m_core.shl(expr2pdd(x), expr2pdd(y), expr2pdd(n)); - } + } + + bool solver::propagate_delayed_axioms() { + if (m_delayed_axioms_qhead == m_delayed_axioms.size()) + return false; + ctx.push(value_trail(m_delayed_axioms_qhead)); + for (; m_delayed_axioms_qhead < m_delayed_axioms.size() && !inconsistent(); ++m_delayed_axioms_qhead) { + app* e = m_delayed_axioms[m_delayed_axioms_qhead]; + expr* x, *y; + if (bv.is_bv_sdiv(e, x, y)) + axiomatize_sdiv(e, x, y); + else if (bv.is_bv_sdivi(e, x, y)) + axiomatize_sdiv(e, x, y); + else if (bv.is_bv_srem(e, x, y)) + axiomatize_srem(e, x, y); + else if (bv.is_bv_sremi(e, x, y)) + axiomatize_srem(e, x, y); + else if (bv.is_bv_smod(e, x, y)) + axiomatize_smod(e, x, y); + else if (bv.is_bv_smodi(e, x, y)) + axiomatize_smod(e, x, y); + else + UNREACHABLE(); + } + return true; + } + + // y = 0 -> x + // else x - sdiv(x, y) * y + void solver::axiomatize_srem(app* e, expr* x, expr* y) { + unsigned sz = bv.get_bv_size(x); + sat::literal y_eq0 = eq_internalize(y, bv.mk_zero(sz)); + add_clause(~y_eq0, eq_internalize(e, x)); + add_clause(y_eq0, eq_internalize(e, bv.mk_bv_mul(bv.mk_bv_sdiv(x, y), y))); + } + + // u := umod(x, y) + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u + void solver::axiomatize_smod(app* e, expr* x, expr* y) { + unsigned sz = bv.get_bv_size(x); + expr* u = bv.mk_bv_urem(x, y); + rational N = rational::power_of_two(bv.get_bv_size(x)); + expr* signx = bv.mk_ule(bv.mk_numeral(N / 2, sz), x); + expr* signy = bv.mk_ule(bv.mk_numeral(N / 2, sz), y); + sat::literal lsignx = mk_literal(signx); + sat::literal lsigny = mk_literal(signy); + sat::literal u_eq0 = eq_internalize(u, bv.mk_zero(sz)); + sat::literal y_eq0 = eq_internalize(y, bv.mk_zero(sz)); + add_clause(~u_eq0, eq_internalize(e, bv.mk_zero(sz))); + add_clause(u_eq0, ~y_eq0, eq_internalize(e, x)); + add_clause(~lsignx, ~lsigny, eq_internalize(e, bv.mk_bv_neg(u))); + add_clause(y_eq0, ~lsignx, lsigny, eq_internalize(e, bv.mk_bv_sub(y, u))); + add_clause(y_eq0, lsignx, ~lsigny, eq_internalize(e, bv.mk_bv_add(y, u))); + add_clause(y_eq0, lsignx, lsigny, eq_internalize(e, u)); + } + + + // d = udiv(abs(x), abs(y)) + // y = 0, x > 0 -> 1 + // y = 0, x <= 0 -> -1 + // x = 0, y != 0 -> 0 + // x > 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + void solver::axiomatize_sdiv(app* e, expr* x, expr* y) { + unsigned sz = bv.get_bv_size(x); + rational N = rational::power_of_two(bv.get_bv_size(x)); + expr* signx = bv.mk_ule(bv.mk_numeral(N/2, sz), x); + expr* signy = bv.mk_ule(bv.mk_numeral(N/2, sz), y); + expr* absx = m.mk_ite(signx, bv.mk_bv_sub(bv.mk_numeral(N-1, sz), x), x); + expr* absy = m.mk_ite(signy, bv.mk_bv_sub(bv.mk_numeral(N-1, sz), y), y); + expr* d = bv.mk_bv_udiv(absx, absy); + sat::literal lsignx = mk_literal(signx); + sat::literal lsigny = mk_literal(signy); + sat::literal y_eq0 = eq_internalize(y, bv.mk_zero(sz)); + add_clause(~y_eq0, ~lsignx, eq_internalize(e, bv.mk_numeral(1, sz))); + add_clause(~y_eq0, lsignx, eq_internalize(e, bv.mk_numeral(N-1, sz))); + add_clause(y_eq0, lsignx, ~lsigny, eq_internalize(e, bv.mk_bv_neg(d))); + add_clause(y_eq0, ~lsignx, lsigny, eq_internalize(e, bv.mk_bv_neg(d))); + add_clause(y_eq0, lsignx, lsigny, eq_internalize(e, d)); + add_clause(y_eq0, ~lsignx, ~lsigny, eq_internalize(e, d)); + } void solver::internalize_urem_i(app* rem) { expr* x, *y; diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 271b6986e..0fa4ab8e6 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -52,7 +52,7 @@ namespace polysat { } bool solver::unit_propagate() { - return m_core.propagate(); + return m_core.propagate() || propagate_delayed_axioms(); } sat::check_result solver::check() { @@ -266,32 +266,6 @@ namespace polysat { void solver::add_polysat_clause(char const* name, core_vector cs, bool is_redundant) { sat::literal_vector lits; - signed_constraint sc; - unsigned constraint_count = 0; - for (auto e : cs) { - if (std::holds_alternative(e)) { - sc = *std::get_if(&e); - constraint_count++; - } - } - if (constraint_count == 1) { - auto lit = ctx.mk_literal(constraint2expr(sc)); - svector eqs; - for (auto e : cs) { - if (std::holds_alternative(e)) { - auto d = *std::get_if(&e); - if (d.is_literal()) - lits.push_back(d.literal()); - else if (d.is_eq()) { - auto [v1, v2] = d.eq(); - eqs.push_back({ var2enode(v1), var2enode(v2) }); - } - } - } - ctx.propagate(lit, euf::th_explain::propagate(*this, lits, eqs, lit, nullptr)); - return; - } - for (auto e : cs) { if (std::holds_alternative(e)) { auto d = *std::get_if(&e); diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index e88eafdd2..c8d9e314d 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -130,6 +130,12 @@ namespace polysat { void internalize_udiv_i(app* e); void internalize_urem_i(app* e); void internalize_div_rem(app* e, bool is_div); + void axiomatize_srem(app* e, expr* x, expr* y); + void axiomatize_smod(app* e, expr* x, expr* y); + void axiomatize_sdiv(app* e, expr* x, expr* y); + unsigned m_delayed_axioms_qhead = 0; + ptr_vector m_delayed_axioms; + bool propagate_delayed_axioms(); void internalize_polysat(app* a); void assert_bv2int_axiom(app * n); void assert_int2bv_axiom(app* n); From 2de63b89c520b0d751d594000b0048e0cab79791 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 14 Dec 2023 12:12:11 -0800 Subject: [PATCH 77/89] weed out some bugs, add more bv op support in intblast and polysat solvers Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat_internalize.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 0ba59de0a..68e5e4cc6 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -152,8 +152,11 @@ namespace polysat { case OP_BREDAND: // x == 2^K - 1 unary, return single bit, 1 if all input bits are set. case OP_BREDOR: // x > 0 unary, return single bit, 1 if at least one input bit is set. - case OP_BCOMP: // x == y binary, return single bit, 1 if the arguments are equal. - + case OP_BCOMP: // x == y ? 1 : 0 binary, return single bit, 1 if the arguments are equal. + case OP_ROTATE_LEFT: + case OP_ROTATE_RIGHT: + case OP_EXT_ROTATE_LEFT: + case OP_EXT_ROTATE_RIGHT: default: IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); NOT_IMPLEMENTED_YET(); From 54ee098cfd9261727db035f07dfeba0a2395ae3b Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 14 Dec 2023 17:22:33 -0800 Subject: [PATCH 78/89] more fixes --- src/ast/bv_decl_plugin.h | 22 +++++++ src/sat/sat_solver.cpp | 16 +++-- src/sat/smt/arith_axioms.cpp | 1 - src/sat/smt/euf_internalize.cpp | 1 - src/sat/smt/polysat_internalize.cpp | 96 +++++++++++++++++++++++++++-- src/sat/smt/polysat_solver.h | 10 +++ 6 files changed, 133 insertions(+), 13 deletions(-) diff --git a/src/ast/bv_decl_plugin.h b/src/ast/bv_decl_plugin.h index cb1f63881..89588ee0e 100644 --- a/src/ast/bv_decl_plugin.h +++ b/src/ast/bv_decl_plugin.h @@ -386,9 +386,31 @@ public: bool is_bv_shl(expr const * e) const { return is_app_of(e, get_fid(), OP_BSHL); } bool is_sign_ext(expr const * e) const { return is_app_of(e, get_fid(), OP_SIGN_EXT); } bool is_bv_umul_no_ovfl(expr const* e) const { return is_app_of(e, get_fid(), OP_BUMUL_NO_OVFL); } + bool is_redand(expr const* e) const { return is_app_of(e, get_fid(), OP_BREDAND); } + bool is_redor(expr const* e) const { return is_app_of(e, get_fid(), OP_BREDOR); } + bool is_comp(expr const* e) const { return is_app_of(e, get_fid(), OP_BCOMP); } + bool is_rotate_left(expr const* e) const { return is_app_of(e, get_fid(), OP_ROTATE_LEFT); } + bool is_rotate_right(expr const* e) const { return is_app_of(e, get_fid(), OP_ROTATE_RIGHT); } + bool is_ext_rotate_left(expr const* e) const { return is_app_of(e, get_fid(), OP_EXT_ROTATE_LEFT); } + bool is_ext_rotate_right(expr const* e) const { return is_app_of(e, get_fid(), OP_EXT_ROTATE_RIGHT); } + + bool is_rotate_left(expr const* e, unsigned& n, expr*& x) const { + return is_rotate_left(e) && (n = to_app(e)->get_parameter(0).get_int(), x = to_app(e)->get_arg(0), true); + } + bool is_rotate_right(expr const* e, unsigned& n, expr*& x) const { + return is_rotate_right(e) && (n = to_app(e)->get_parameter(0).get_int(), x = to_app(e)->get_arg(0), true); + } + bool is_int2bv(expr const* e, unsigned& n, expr*& x) const { + return is_int2bv(e) && (n = to_app(e)->get_parameter(0).get_int(), x = to_app(e)->get_arg(0), true); + } MATCH_UNARY(is_bv_not); + MATCH_UNARY(is_redand); + MATCH_UNARY(is_redor); + MATCH_BINARY(is_ext_rotate_left); + MATCH_BINARY(is_ext_rotate_right); + MATCH_BINARY(is_comp); MATCH_BINARY(is_bv_add); MATCH_BINARY(is_bv_sub); MATCH_BINARY(is_bv_mul); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 716a8effe..a1b88bed1 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -879,7 +879,6 @@ namespace sat { m_conflict = c; m_not_l = not_l; TRACE("sat", display(display_justification(tout << "conflict " << not_l << " ", c) << "\n")); - TRACE("sat", display_watches(tout)); } void solver::assign_core(literal l, justification j) { @@ -1720,6 +1719,9 @@ namespace sat { if (next == null_bool_var) return false; } + else { + SASSERT(value(next) == l_undef); + } push(); m_stats.m_decision++; @@ -1729,11 +1731,14 @@ namespace sat { phase = guess(next) ? l_true: l_false; literal next_lit(next, false); + SASSERT(value(next_lit) == l_undef); if (m_ext && m_ext->decide(next, phase)) { + if (used_queue) m_case_split_queue.unassign_var_eh(next); next_lit = literal(next, false); + SASSERT(value(next_lit) == l_undef); } if (phase == l_undef) @@ -2553,7 +2558,8 @@ namespace sat { } SASSERT(lvl(c_var) < m_conflict_lvl); } - CTRACE("sat", idx == 0, + CTRACE("sat", idx == 0, + tout << "conflict level " << m_conflict_lvl << "\n"; for (literal lit : m_trail) if (is_marked(lit.var())) tout << "missed " << lit << "@" << lvl(lit) << "\n";); @@ -2808,8 +2814,9 @@ namespace sat { unsigned level = 0; if (not_l != null_literal) { - level = lvl(not_l); + level = lvl(not_l); } + TRACE("sat", tout << "level " << not_l << " is " << level << " " << js << "\n"); switch (js.get_kind()) { case justification::NONE: @@ -3484,11 +3491,10 @@ namespace sat { // // ----------------------- void solver::push() { + SASSERT(!m_ext || !m_ext->can_propagate()); SASSERT(!inconsistent()); TRACE("sat_verbose", tout << "q:" << m_qhead << " trail: " << m_trail.size() << "\n";); SASSERT(m_qhead == m_trail.size()); - if (m_ext) - m_ext->unit_propagate(); m_scopes.push_back(scope()); scope & s = m_scopes.back(); m_scope_lvl++; diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 0150824b2..09db74f75 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -250,7 +250,6 @@ namespace arith { add_clause(~bitof(n, i), bitof(y, i)); else continue; - verbose_stream() << "added b-and clause\n"; return false; } return true; diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index 943d0b324..f750f186d 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -106,7 +106,6 @@ namespace euf { attach_node(mk_enode(e, 0, nullptr)); return true; } - bool solver::post_visit(expr* e, bool sign, bool root) { unsigned num = is_app(e) ? to_app(e)->get_num_args() : 0; m_args.reset(); diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 68e5e4cc6..b2302a22d 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -145,18 +145,21 @@ namespace polysat { case OP_BSMOD_I: case OP_BSDIV: case OP_BSDIV_I: - expr2pdd(a); - m_delayed_axioms.push_back(a); - ctx.push(push_back_vector(m_delayed_axioms)); - break; - - case OP_BREDAND: // x == 2^K - 1 unary, return single bit, 1 if all input bits are set. case OP_BREDOR: // x > 0 unary, return single bit, 1 if at least one input bit is set. + case OP_BREDAND: // x == 2^K - 1 unary, return single bit, 1 if all input bits are set. case OP_BCOMP: // x == y ? 1 : 0 binary, return single bit, 1 if the arguments are equal. case OP_ROTATE_LEFT: case OP_ROTATE_RIGHT: case OP_EXT_ROTATE_LEFT: case OP_EXT_ROTATE_RIGHT: + case OP_INT2BV: + case OP_BV2INT: + if (bv.is_bv(a)) + expr2pdd(a); + m_delayed_axioms.push_back(a); + ctx.push(push_back_vector(m_delayed_axioms)); + break; + default: IF_VERBOSE(0, verbose_stream() << mk_pp(a, m) << "\n"); NOT_IMPLEMENTED_YET(); @@ -276,6 +279,7 @@ namespace polysat { for (; m_delayed_axioms_qhead < m_delayed_axioms.size() && !inconsistent(); ++m_delayed_axioms_qhead) { app* e = m_delayed_axioms[m_delayed_axioms_qhead]; expr* x, *y; + unsigned n = 0; if (bv.is_bv_sdiv(e, x, y)) axiomatize_sdiv(e, x, y); else if (bv.is_bv_sdivi(e, x, y)) @@ -288,12 +292,92 @@ namespace polysat { axiomatize_smod(e, x, y); else if (bv.is_bv_smodi(e, x, y)) axiomatize_smod(e, x, y); + else if (bv.is_redand(e, x)) + axiomatize_redand(e, x); + else if (bv.is_redor(e, x)) + axiomatize_redor(e, x); + else if (bv.is_comp(e, x, y)) + axiomatize_comp(e, x, y); + else if (bv.is_rotate_left(e, n, x)) + axiomatize_rotate_left(e, n, x); + else if (bv.is_rotate_right(e, n, x)) + axiomatize_rotate_right(e, n, x); + else if (bv.is_ext_rotate_left(e, x, y)) + axiomatize_ext_rotate_left(e, x, y); + else if (bv.is_ext_rotate_right(e, x, y)) + axiomatize_ext_rotate_right(e, x, y); + else if (bv.is_bv2int(e, x)) + axiomatize_bv2int(e, x); + else if (bv.is_int2bv(e, n, x)) + axiomatize_int2bv(e, n, x); else UNREACHABLE(); } return true; } + void solver::axiomatize_int2bv(app* e, unsigned & sz, expr* x) { + NOT_IMPLEMENTED_YET(); + + } + + void solver::axiomatize_bv2int(app* e, expr* x) { + NOT_IMPLEMENTED_YET(); + } + + + expr* solver::rotate_left(app* e, unsigned n, expr* x) { + unsigned sz = bv.get_bv_size(x); + n = n % sz; + if (n == 0 || sz == 1) + return x; + else + return bv.mk_concat(bv.mk_extract(n, 0, x), bv.mk_extract(sz - 1, sz - n - 1, x)); + } + + void solver::axiomatize_rotate_left(app* e, unsigned n, expr* x) { + // e = x[n : 0] ++ x[sz - 1, sz - n - 1] + add_unit(eq_internalize(e, rotate_left(e, n, x))); + } + + void solver::axiomatize_rotate_right(app* e, unsigned n, expr* x) { + unsigned sz = bv.get_bv_size(x); + axiomatize_rotate_left(e, sz - n, x); + } + + void solver::axiomatize_ext_rotate_left(app* e, expr* x, expr* y) { + unsigned sz = bv.get_bv_size(x); + for (unsigned i = 0; i < sz; ++i) + add_clause(~eq_internalize(bv.mk_numeral(i, sz), y), eq_internalize(e, rotate_left(e, i, x))); + add_clause(~mk_literal(bv.mk_ule(bv.mk_numeral(sz, sz), y)), eq_internalize(e, bv.mk_zero(sz))); + } + + void solver::axiomatize_ext_rotate_right(app* e, expr* x, expr* y) { + unsigned sz = bv.get_bv_size(x); + for (unsigned i = 0; i < sz; ++i) + add_clause(~eq_internalize(bv.mk_numeral(i, sz), y), eq_internalize(e, rotate_left(e, sz - i, x))); + add_clause(~mk_literal(bv.mk_ule(bv.mk_numeral(sz, sz), y)), eq_internalize(e, bv.mk_zero(sz))); + } + + // x = N - 1 + void solver::axiomatize_redand(app* e, expr* x) { + unsigned sz = bv.get_bv_size(x); + rational N = rational::power_of_two(sz); + add_equiv(expr2literal(e), eq_internalize(x, bv.mk_numeral(N - 1, sz))); + } + + void solver::axiomatize_redor(app* e, expr* x) { + unsigned sz = bv.get_bv_size(x); + add_equiv(expr2literal(e), ~eq_internalize(x, bv.mk_zero(sz))); + } + + void solver::axiomatize_comp(app* e, expr* x, expr* y) { + unsigned sz = bv.get_bv_size(x); + auto eq = eq_internalize(x, y); + add_clause(~eq, eq_internalize(e, bv.mk_numeral(1, sz))); + add_clause(eq, eq_internalize(e, bv.mk_numeral(0, sz))); + } + // y = 0 -> x // else x - sdiv(x, y) * y void solver::axiomatize_srem(app* e, expr* x, expr* y) { diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index c8d9e314d..8038cc4bb 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -133,6 +133,16 @@ namespace polysat { void axiomatize_srem(app* e, expr* x, expr* y); void axiomatize_smod(app* e, expr* x, expr* y); void axiomatize_sdiv(app* e, expr* x, expr* y); + void axiomatize_redand(app* e, expr* x); + void axiomatize_redor(app* e, expr* x); + void axiomatize_comp(app* e, expr* x, expr* y); + void axiomatize_rotate_left(app* e, unsigned n, expr* x); + void axiomatize_rotate_right(app* e, unsigned n, expr* x); + void axiomatize_ext_rotate_left(app* e, expr* x, expr* y); + void axiomatize_ext_rotate_right(app* e, expr* x, expr* y); + void axiomatize_int2bv(app* e, unsigned & sz, expr* x); + void axiomatize_bv2int(app* e, expr* x); + expr* rotate_left(app* e, unsigned n, expr* x); unsigned m_delayed_axioms_qhead = 0; ptr_vector m_delayed_axioms; bool propagate_delayed_axioms(); From ce1acd8c414225f6944f38be772539b25092cd19 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 14 Dec 2023 19:30:21 -0800 Subject: [PATCH 79/89] fix encoding bugs Signed-off-by: Nikolaj Bjorner --- src/sat/smt/intblast_solver.cpp | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 592c8c0f4..765cc0678 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -468,6 +468,7 @@ namespace intblast { sorts.push_back(a.mk_int()); } else + sorts.push_back(s); } b = translated(b); @@ -701,17 +702,17 @@ namespace intblast { // // ashr(x, y) // if y = k & x >= 0 -> x / 2^k - // if y = k & x < 0 -> - (x / 2^k) + // if y = k & x < 0 -> (x / 2^k) - 1 + 2^{N-k} // - rational N = rational::power_of_two(bv.get_bv_size(e)); - expr* x = umod(e, 0); - expr* y = umod(e, 1); - expr* signbit = a.mk_ge(x, a.mk_int(N / 2)); - r = m.mk_ite(signbit, a.mk_int(N - 1), a.mk_int(0)); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) { - expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); + unsigned sz = bv.get_bv_size(e); + rational N = bv_size(e); + expr* x = umod(e, 0), *y = umod(e, 1); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + r = m.mk_ite(signx, a.mk_int(N - 1), a.mk_int(0)); + for (unsigned i = 0; i < sz; ++i) { + expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), - m.mk_ite(signbit, a.mk_uminus(d), d), + m.mk_ite(signx, a.mk_add(d, a.mk_int(N - rational::power_of_two(sz-i))), d), r); } break; @@ -793,12 +794,13 @@ namespace intblast { case OP_BSREM: { // y = 0 -> x // else x - sdiv(x, y) * y - bv_expr = e; - expr* x = umod(bv_expr, 0), * y = umod(bv_expr, 1); - rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); + expr* x = umod(e, 0), * y = umod(e, 1); + rational N = bv_size(e); expr* signx = a.mk_ge(x, a.mk_int(N / 2)); expr* signy = a.mk_ge(y, a.mk_int(N / 2)); - expr* d = a.mk_idiv(x, y); + expr* absx = m.mk_ite(signx, a.mk_sub(a.mk_int(N), x), x); + expr* absy = m.mk_ite(signy, a.mk_sub(a.mk_int(N), y), y); + expr* d = a.mk_idiv(absx, absy); d = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); r = a.mk_sub(x, a.mk_mul(d, y)); r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); From 3c21e3ae423855acc28587d9dc7e10b2a3d3a85d Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 14 Dec 2023 20:12:09 -0800 Subject: [PATCH 80/89] add and fix axioms --- src/sat/smt/intblast_solver.cpp | 4 +- src/sat/smt/polysat/op_constraint.cpp | 167 ++++++++++++++------------ src/sat/smt/polysat/op_constraint.h | 10 +- src/sat/smt/polysat_internalize.cpp | 19 ++- src/sat/smt/polysat_solver.h | 2 +- 5 files changed, 115 insertions(+), 87 deletions(-) diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 765cc0678..9d03d0ad0 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -708,11 +708,11 @@ namespace intblast { rational N = bv_size(e); expr* x = umod(e, 0), *y = umod(e, 1); expr* signx = a.mk_ge(x, a.mk_int(N / 2)); - r = m.mk_ite(signx, a.mk_int(N - 1), a.mk_int(0)); + r = m.mk_ite(signx, a.mk_int(- 1), a.mk_int(0)); for (unsigned i = 0; i < sz; ++i) { expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), - m.mk_ite(signx, a.mk_add(d, a.mk_int(N - rational::power_of_two(sz-i))), d), + m.mk_ite(signx, a.mk_add(d, a.mk_int(- rational::power_of_two(sz-i))), d), r); } break; diff --git a/src/sat/smt/polysat/op_constraint.cpp b/src/sat/smt/polysat/op_constraint.cpp index c971fe1cd..b7d312d55 100644 --- a/src/sat/smt/polysat/op_constraint.cpp +++ b/src/sat/smt/polysat/op_constraint.cpp @@ -25,8 +25,8 @@ Additional possible functionality on constraints: namespace polysat { - op_constraint::op_constraint(code c, pdd const& p, pdd const& q, pdd const& r) : - m_op(c), m_p(p), m_q(q), m_r(r) { + op_constraint::op_constraint(code c, pdd const& _p, pdd const& _q, pdd const& _r) : + m_op(c), p(_p), q(_q), r(_r) { vars().append(p.free_vars()); for (auto v : q.free_vars()) if (!vars().contains(v)) @@ -38,7 +38,7 @@ namespace polysat { switch (c) { case code::and_op: if (p.index() > q.index()) - std::swap(m_p, m_q); + std::swap(p, q); break; case code::inv_op: SASSERT(q.is_zero()); @@ -50,11 +50,11 @@ namespace polysat { } lbool op_constraint::eval() const { - return eval(p(), q(), r()); + return eval(p, q, r); } lbool op_constraint::eval(assignment const& a) const { - return eval(a.apply_to(p()), a.apply_to(q()), a.apply_to(r())); + return eval(a.apply_to(p), a.apply_to(q), a.apply_to(r)); } lbool op_constraint::eval(pdd const& p, pdd const& q, pdd const& r) const { @@ -67,6 +67,8 @@ namespace polysat { return eval_and(p, q, r); case code::inv_op: return eval_inv(p, r); + case code::ashr_op: + return eval_ashr(p, q, r); default: return l_undef; } @@ -93,6 +95,11 @@ namespace polysat { return l_undef; } + lbool op_constraint::eval_ashr(pdd const& p, pdd const& q, pdd const& r) { + NOT_IMPLEMENTED_YET(); + return l_undef; + } + /** Evaluate constraint: r == p << q */ lbool op_constraint::eval_shl(pdd const& p, pdd const& q, pdd const& r) { auto& m = p.manager(); @@ -171,9 +178,9 @@ namespace polysat { std::ostream& op_constraint::display(std::ostream& out, char const* eq) const { if (m_op == code::inv_op) - return out << r() << " " << eq << " " << m_op << " " << p(); + return out << r << " " << eq << " " << m_op << " " << p; - return out << r() << " " << eq << " " << p() << " " << m_op << " " << q(); + return out << r << " " << eq << " " << p << " " << m_op << " " << q; } void op_constraint::activate(core& c, bool sign, dependency const& dep) { @@ -239,52 +246,50 @@ namespace polysat { * when r, q are variables. */ void op_constraint::propagate_lshr(core& c, dependency const& d) { - auto& m = p().manager(); - auto const pv = c.subst(p()); - auto const qv = c.subst(q()); - auto const rv = c.subst(r()); + auto& m = p.manager(); + auto const pv = c.subst(p); + auto const qv = c.subst(q); + auto const rv = c.subst(r); unsigned const N = m.power_of_2(); - - signed_constraint const lshr(polysat::ckind_t::op_t, this); auto& C = c.cs(); if (pv.is_val() && rv.is_val() && rv.val() > pv.val()) - c.add_clause("lshr 1", { d, C.ule(r(), p()) }, false); + c.add_clause("lshr 1", { C.ule(r, p) }, false); else if (qv.is_val() && qv.val() >= N && rv.is_val() && !rv.is_zero()) // TODO: instead of rv.is_val() && !rv.is_zero(), we should use !is_forced_zero(r) which checks whether eval(r) = 0 or bvalue(r=0) = true; see saturation.cpp - c.add_clause("q >= N -> r = 0", { d, ~C.ule(N, q()), C.eq(r()) }, true); + c.add_clause("q >= N -> r = 0", { ~C.ule(N, q), C.eq(r) }, true); else if (qv.is_zero() && pv.is_val() && rv.is_val() && pv != rv) - c.add_clause("q = 0 -> p = r", { d, ~C.eq(q()), C.eq(p(), r()) } , true); + c.add_clause("q = 0 -> p = r", { ~C.eq(q), C.eq(p, r) } , true); else if (qv.is_val() && !qv.is_zero() && pv.is_val() && rv.is_val() && !pv.is_zero() && rv.val() >= pv.val()) - c.add_clause("q != 0 & p > 0 -> r < p", { d, C.eq(q()), C.ule(p(), 0), C.ult(r(), p()) }, true); + c.add_clause("q != 0 & p > 0 -> r < p", { C.eq(q), C.ule(p, 0), C.ult(r, p) }, true); else if (qv.is_val() && !qv.is_zero() && qv.val() < N && rv.is_val() && rv.val() > rational::power_of_two(N - qv.val().get_unsigned()) - 1) - c.add_clause("q >= k -> r <= 2^{N-k} - 1", { d, ~C.ule(qv.val(), q()), C.ule(r(), rational::power_of_two(N - qv.val().get_unsigned()) - 1)}, true); + c.add_clause("q >= k -> r <= 2^{N-k} - 1", { ~C.ule(qv.val(), q), C.ule(r, rational::power_of_two(N - qv.val().get_unsigned()) - 1)}, true); else if (pv.is_val() && rv.is_val() && qv.is_val() && !qv.is_zero()) { unsigned k = qv.val().get_unsigned(); for (unsigned i = 0; i < N - k; ++i) { if (rv.val().get_bit(i) && !pv.val().get_bit(i + k)) - c.add_clause("q = k -> r[i] = p[i+k] for 0 <= i < N - k", { d, ~C.eq(q(), k), ~C.bit(r(), i), C.bit(p(), i + k) }, true); + c.add_clause("q = k -> r[i] = p[i+k] for 0 <= i < N - k", { ~C.eq(q, k), ~C.bit(r, i), C.bit(p, i + k) }, true); if (!rv.val().get_bit(i) && pv.val().get_bit(i + k)) - c.add_clause("q = k -> r[i] = p[i+k] for 0 <= i < N - k", { d, ~C.eq(q(), k), C.bit(r(), i), ~C.bit(p(), i + k) }, true); + c.add_clause("q = k -> r[i] = p[i+k] for 0 <= i < N - k", { ~C.eq(q, k), C.bit(r, i), ~C.bit(p, i + k) }, true); } } else { // forward propagation SASSERT(!(pv.is_val() && qv.is_val() && rv.is_val())); - // LOG(p() << " = " << pv << " and " << q() << " = " << qv << " yields [>>] " << r() << " = " << (qv.val().is_unsigned() ? machine_div2k(pv.val(), qv.val().get_unsigned()) : rational::zero())); + // LOG(p << " = " << pv << " and " << q << " = " << qv << " yields [>>] " << r << " = " << (qv.val().is_unsigned() ? machine_div2k(pv.val(), qv.val().get_unsigned()) : rational::zero())); if (qv.is_val() && !rv.is_val()) { rational const& q_val = qv.val(); if (q_val >= N) // q >= N ==> r = 0 - c.add_clause("q >= N ==> r = 0", { d, ~C.ule(N, q()), C.eq(r()) }, true); + c.add_clause("q >= N ==> r = 0", { ~C.ule(N, q), C.eq(r) }, true); else if (pv.is_val()) { SASSERT(q_val.is_unsigned()); // rational const r_val = machine_div2k(pv.val(), q_val.get_unsigned()); - c.add_clause("p = p_val & q = q_val ==> r = p_val / 2^q_val", { d, ~C.eq(p(), pv), ~C.eq(q(), qv), C.eq(r(), r_val) }, true); + c.add_clause("p = p_val & q = q_val ==> r = p_val / 2^q_val", { ~C.eq(p, pv), ~C.eq(q, qv), C.eq(r, r_val) }, true); } } } @@ -292,7 +297,7 @@ namespace polysat { void op_constraint::activate_and(core& c, dependency const& d) { - auto x = p(), y = q(); + auto x = p, y = q; auto& C = c.cs(); if (x.is_val()) std::swap(x, y); @@ -303,21 +308,35 @@ namespace polysat { if (!(yv + 1).is_power_of_two()) return; if (yv == m.max_value()) - c.add_clause("band-mask-true", { d, C.eq(x, r()) }, false); + c.add_clause("band-mask-true", { C.eq(x, r) }, false); else if (yv == 0) - c.add_clause("band-mask-false", { d, C.eq(r()) }, false); + c.add_clause("band-mask-false", { C.eq(r) }, false); else { unsigned N = m.power_of_2(); unsigned k = yv.get_num_bits(); SASSERT(k < N); rational exp = rational::power_of_two(N - k); - c.add_clause("band-mask 1", { d, C.eq(x * exp, r() * exp) }, false); - c.add_clause("band-mask 2", { d, C.ule(r(), y) }, false); // maybe always activate these constraints regardless? + c.add_clause("band-mask 1", { C.eq(x * exp, r * exp) }, false); + c.add_clause("band-mask 2", { C.ule(r, y) }, false); // maybe always activate these constraints regardless? } } - void op_constraint::propagate_ashr(core& s, dependency const& dep) { + void op_constraint::propagate_ashr(core& c, dependency const& dep) { + // + // ashr(x, y) + // if q >= N & p < 0 -> -1 + // if q >= N & p >= 0 -> 0 + // if q = k & p >= 0 -> p / 2^k + // if q = k & p < 0 -> (p / 2^k) - 1 + 2^{N-k} + // + auto& m = p.manager(); + auto const pv = c.subst(p); + auto const qv = c.subst(q); + auto const rv = c.subst(r); + unsigned const N = m.power_of_2(); + + NOT_IMPLEMENTED_YET(); } @@ -331,49 +350,49 @@ namespace polysat { * q = 0 -> r = p */ void op_constraint::propagate_shl(core& c, dependency const& d) { - auto& m = p().manager(); - auto const pv = c.subst(p()); - auto const qv = c.subst(q()); - auto const rv = c.subst(r()); + auto& m = p.manager(); + auto const pv = c.subst(p); + auto const qv = c.subst(q); + auto const rv = c.subst(r); unsigned const N = m.power_of_2(); auto& C = c.cs(); if (qv.is_val() && qv.val() >= N && rv.is_val() && !rv.is_zero()) - c.add_clause("q >= N -> r = 0", { d, ~C.ule(N, q()), C.eq(r()) }, true); + c.add_clause("q >= N -> r = 0", { ~C.ule(N, q), C.eq(r) }, true); else if (qv.is_zero() && pv.is_val() && rv.is_val() && rv != pv) // - c.add_clause("q = 0 -> r = p", { d, ~C.eq(q()), C.eq(r(), p()) }, true); + c.add_clause("q = 0 -> r = p", { ~C.eq(q), C.eq(r, p) }, true); else if (qv.is_val() && !qv.is_zero() && qv.val() < N && rv.is_val() && !rv.is_zero() && rv.val() < rational::power_of_two(qv.val().get_unsigned())) // q >= k -> r = 0 \/ r >= 2^k (intuitive version) // q >= k -> r - 1 >= 2^k - 1 (equivalent unit constraint to better support narrowing) - c.add_clause("q >= k -> r - 1 >= 2^k - 1", { d, ~C.ule(qv.val(), q()), C.ule(rational::power_of_two(qv.val().get_unsigned()) - 1, r() - 1) }, true); + c.add_clause("q >= k -> r - 1 >= 2^k - 1", { ~C.ule(qv.val(), q), C.ule(rational::power_of_two(qv.val().get_unsigned()) - 1, r - 1) }, true); else if (pv.is_val() && rv.is_val() && qv.is_val() && !qv.is_zero()) { unsigned k = qv.val().get_unsigned(); // q = k -> r[i+k] = p[i] for 0 <= i < N - k for (unsigned i = 0; i < N - k; ++i) { if (rv.val().get_bit(i + k) && !pv.val().get_bit(i)) { - c.add_clause("q = k -> r[i+k] = p[i] for 0 <= i < N - k", { d, ~C.eq(q(), k), ~C.bit(r(), i + k), C.bit(p(), i) }, true); + c.add_clause("q = k -> r[i+k] = p[i] for 0 <= i < N - k", { ~C.eq(q, k), ~C.bit(r, i + k), C.bit(p, i) }, true); } if (!rv.val().get_bit(i + k) && pv.val().get_bit(i)) { - c.add_clause("q = k -> r[i+k] = p[i] for 0 <= i < N - k", { d, ~C.eq(q(), k), C.bit(r(), i + k), ~C.bit(p(), i) }, true); + c.add_clause("q = k -> r[i+k] = p[i] for 0 <= i < N - k", { ~C.eq(q, k), C.bit(r, i + k), ~C.bit(p, i) }, true); } } } else { // forward propagation SASSERT(!(pv.is_val() && qv.is_val() && rv.is_val())); - // LOG(p() << " = " << pv << " and " << q() << " = " << qv << " yields [<<] " << r() << " = " << (qv.val().is_unsigned() ? rational::power_of_two(qv.val().get_unsigned()) * pv.val() : rational::zero())); + // LOG(p << " = " << pv << " and " << q << " = " << qv << " yields [<<] " << r << " = " << (qv.val().is_unsigned() ? rational::power_of_two(qv.val().get_unsigned()) * pv.val() : rational::zero())); if (qv.is_val() && !rv.is_val()) { rational const& q_val = qv.val(); if (q_val >= N) // q >= N ==> r = 0 - c.add_clause("shl forward 1", {d, ~C.ule(N, q()), C.eq(r())}, true); + c.add_clause("shl forward 1", {~C.ule(N, q), C.eq(r)}, true); if (pv.is_val()) { SASSERT(q_val.is_unsigned()); // p = p_val & q = q_val ==> r = p_val * 2^q_val rational const r_val = pv.val() * rational::power_of_two(q_val.get_unsigned()); - c.add_clause("shl forward 2", {d, ~C.eq(p(), pv), ~C.eq(q(), qv), C.eq(r(), r_val)}, true); + c.add_clause("shl forward 2", {~C.eq(p, pv), ~C.eq(q, qv), C.eq(r, r_val)}, true); } } } @@ -393,38 +412,38 @@ namespace polysat { * q = 2^k - 1 && r = 0 && p != 0 => p >= 2^k */ void op_constraint::propagate_and(core& c, dependency const& d) { - auto& m = p().manager(); - auto pv = c.subst(p()); - auto qv = c.subst(q()); - auto rv = c.subst(r()); + auto& m = p.manager(); + auto pv = c.subst(p); + auto qv = c.subst(q); + auto rv = c.subst(r); auto& C = c.cs(); if (pv.is_val() && rv.is_val() && rv.val() > pv.val()) - c.add_clause("p&q <= p", { d, C.ule(r(), p()) }, true); + c.add_clause("p&q <= p", { C.ule(r, p) }, true); else if (qv.is_val() && rv.is_val() && rv.val() > qv.val()) - c.add_clause("p&q <= q", { d, C.ule(r(), q()) }, true); + c.add_clause("p&q <= q", { C.ule(r, q) }, true); else if (pv.is_val() && qv.is_val() && rv.is_val() && pv == qv && rv != pv) - c.add_clause("p = q => r = p", { d, ~C.eq(p(), q()), C.eq(r(), p()) }, true); + c.add_clause("p = q => r = p", { ~C.eq(p, q), C.eq(r, p) }, true); else if (pv.is_val() && qv.is_val() && rv.is_val()) { if (pv.is_max() && qv != rv) - c.add_clause("p = -1 => r = q", { d, ~C.eq(p(), m.max_value()), C.eq(q(), r()) }, true); + c.add_clause("p = -1 => r = q", { ~C.eq(p, m.max_value()), C.eq(q, r) }, true); if (qv.is_max() && pv != rv) - c.add_clause("q = -1 => r = p", { d, ~C.eq(q(), m.max_value()), C.eq(p(), r()) }, true); + c.add_clause("q = -1 => r = p", { ~C.eq(q, m.max_value()), C.eq(p, r) }, true); unsigned const N = m.power_of_2(); unsigned pow; if ((pv.val() + 1).is_power_of_two(pow)) { if (rv.is_zero() && !qv.is_zero() && qv.val() <= pv.val()) - c.add_clause("p = 2^k - 1 && r = 0 && q != 0 => q >= 2^k", { d, ~C.eq(p(), pv), ~C.eq(r()), C.eq(q()), C.ule(pv + 1, q()) }, true); + c.add_clause("p = 2^k - 1 && r = 0 && q != 0 => q >= 2^k", { ~C.eq(p, pv), ~C.eq(r), C.eq(q), C.ule(pv + 1, q) }, true); if (rv != qv) - c.add_clause("p = 2^k - 1 ==> r*2^{N - k} = q*2^{N - k}", { d, ~C.eq(p(), pv), C.eq(r() * rational::power_of_two(N - pow), q() * rational::power_of_two(N - pow)) }, true); + c.add_clause("p = 2^k - 1 ==> r*2^{N - k} = q*2^{N - k}", { ~C.eq(p, pv), C.eq(r * rational::power_of_two(N - pow), q * rational::power_of_two(N - pow)) }, true); } if ((qv.val() + 1).is_power_of_two(pow)) { if (rv.is_zero() && !pv.is_zero() && pv.val() <= qv.val()) - c.add_clause("q = 2^k - 1 && r = 0 && p != 0 ==> p >= 2^k", { d, ~C.eq(q(), qv), ~C.eq(r()), C.eq(p()), C.ule(qv + 1, p()) }, true); + c.add_clause("q = 2^k - 1 && r = 0 && p != 0 ==> p >= 2^k", { ~C.eq(q, qv), ~C.eq(r), C.eq(p), C.ule(qv + 1, p) }, true); // if (rv != pv) - c.add_clause("q = 2^k - 1 ==> r*2^{N - k} = p*2^{N - k}", { d, ~C.eq(q(), qv), C.eq(r() * rational::power_of_two(N - pow), p() * rational::power_of_two(N - pow)) }, true); + c.add_clause("q = 2^k - 1 ==> r*2^{N - k} = p*2^{N - k}", { ~C.eq(q, qv), C.eq(r * rational::power_of_two(N - pow), p * rational::power_of_two(N - pow)) }, true); } for (unsigned i = 0; i < N; ++i) { @@ -434,11 +453,11 @@ namespace polysat { if (rb == (pb && qb)) continue; if (pb && qb && !rb) - c.add_clause("p&q[i] = p[i]&q[i]", { d, ~C.bit(p(), i), ~C.bit(q(), i), C.bit(r(), i) }, true); + c.add_clause("p&q[i] = p[i]&q[i]", { ~C.bit(p, i), ~C.bit(q, i), C.bit(r, i) }, true); else if (!pb && rb) - c.add_clause("p&q[i] = p[i]&q[i]", { d, C.bit(p(), i), ~C.bit(r(), i) }, true); + c.add_clause("p&q[i] = p[i]&q[i]", { C.bit(p, i), ~C.bit(r, i) }, true); else if (!qb && rb) - c.add_clause("p&q[i] = p[i]&q[i]", { d, C.bit(q(), i), ~C.bit(r(), i) }, true); + c.add_clause("p&q[i] = p[i]&q[i]", { C.bit(q, i), ~C.bit(r, i) }, true); else UNREACHABLE(); } @@ -447,14 +466,14 @@ namespace polysat { // Propagate r if p or q are 0 else if (pv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated - c.add_clause("p = 0 -> p&q = 0", { d, C.ule(r(), p()) }, true); + c.add_clause("p = 0 -> p&q = 0", { C.ule(r, p) }, true); else if (qv.is_zero() && !rv.is_zero()) // rv not necessarily fully evaluated - c.add_clause("q = 0 -> p&q = 0", { d, C.ule(r(), q()) }, true); + c.add_clause("q = 0 -> p&q = 0", { C.ule(r, q) }, true); // p = a && q = b ==> r = a & b else if (pv.is_val() && qv.is_val() && !rv.is_val()) { // Just assign by this very weak justification. It will be strengthened in saturation in case of a conflict - LOG(p() << " = " << pv << " and " << q() << " = " << qv << " yields [band] " << r() << " = " << bitwise_and(pv.val(), qv.val())); - c.add_clause("p = a & q = b => r = a&b", { d, ~C.eq(p(), pv), ~C.eq(q(), qv), C.eq(r(), bitwise_and(pv.val(), qv.val())) }, true); + LOG(p << " = " << pv << " and " << q << " = " << qv << " yields [band] " << r << " = " << bitwise_and(pv.val(), qv.val())); + c.add_clause("p = a & q = b => r = a&b", { ~C.eq(p, pv), ~C.eq(q, qv), C.eq(r, bitwise_and(pv.val(), qv.val())) }, true); } } @@ -470,9 +489,9 @@ namespace polysat { * parity(p) < k ==> r <= 2^(N - k) - 1 (because r is the smallest pseudo-inverse) */ clause_ref op_constraint::lemma_inv(solver& s, assignment const& a) { - auto& m = p().manager(); - auto pv = a.apply_to(p()); - auto rv = a.apply_to(r()); + auto& m = p.manager(); + auto pv = a.apply_to(p); + auto rv = a.apply_to(r); if (eval_inv(pv, rv) == l_true) return {}; @@ -481,15 +500,15 @@ namespace polysat { // p = 0 ==> r = 0 if (pv.is_zero()) - c.add_clause(~invc, ~C.eq(p()), C.eq(r()), true); + c.add_clause(~invc, ~C.eq(p), C.eq(r), true); // r = 0 ==> p = 0 if (rv.is_zero()) - c.add_clause(~invc, ~C.eq(r()), C.eq(p()), true); + c.add_clause(~invc, ~C.eq(r), C.eq(p), true); // forward propagation: p assigned ==> r = pseudo_inverse(eval(p)) // TODO: (later) this should be propagated instead of adding a clause /*if (pv.is_val() && !rv.is_val()) - c.add_clause(~invc, ~C.eq(p(), pv), C.eq(r(), pv.val().pseudo_inverse(m.power_of_2())), true);*/ + c.add_clause(~invc, ~C.eq(p, pv), C.eq(r, pv.val().pseudo_inverse(m.power_of_2())), true);*/ if (!pv.is_val() || !rv.is_val()) return {}; @@ -497,14 +516,14 @@ namespace polysat { unsigned parity_pv = pv.val().trailing_zeros(); unsigned parity_rv = rv.val().trailing_zeros(); - LOG("p: " << p() << " := " << pv << " parity " << parity_pv); - LOG("r: " << r() << " := " << rv << " parity " << parity_rv); + LOG("p: " << p << " := " << pv << " parity " << parity_pv); + LOG("r: " << r << " := " << rv << " parity " << parity_rv); // p != 0 ==> odd(r) if (parity_rv != 0) - c.add_clause("r = inv p & p != 0 ==> odd(r)", {~invc, C.eq(p()), s.odd(r())}, true); + c.add_clause("r = inv p & p != 0 ==> odd(r)", {~invc, C.eq(p), s.odd(r)}, true); - pdd prod = p() * r(); + pdd prod = p * r; rational prodv = (pv * rv).val(); // if (prodv != rational::power_of_two(parity_pv)) // verbose_stream() << prodv << " " << rational::power_of_two(parity_pv) << " " << parity_pv << " " << pv << " " << rv << "\n"; @@ -519,12 +538,12 @@ namespace polysat { // parity(p) >= k ==> p * r >= 2^k if (prodv < rational::power_of_two(middle)) c.add_clause("r = inv p & parity(p) >= k ==> p*r >= 2^k", - {~invc, ~s.parity_at_least(p(), middle), s.uge(prod, rational::power_of_two(middle))}, false); + {~invc, ~s.parity_at_least(p, middle), s.uge(prod, rational::power_of_two(middle))}, false); // parity(p) >= k ==> r <= 2^(N - k) - 1 (because r is the smallest pseudo-inverse) rational const max_rv = rational::power_of_two(m.power_of_2() - middle) - 1; if (rv.val() > max_rv) c.add_clause("r = inv p & parity(p) >= k ==> r <= 2^(N - k) - 1", - {~invc, ~s.parity_at_least(p(), middle), C.ule(r(), max_rv)}, false); + {~invc, ~s.parity_at_least(p, middle), C.ule(r, max_rv)}, false); } else { // parity less than middle SASSERT(parity_pv < middle); @@ -533,7 +552,7 @@ namespace polysat { // parity(p) < k ==> p * r <= 2^k - 1 if (prodv > rational::power_of_two(middle)) c.add_clause("r = inv p & parity(p) < k ==> p*r <= 2^k - 1", - {~invc, s.parity_at_least(p(), middle), C.ule(prod, rational::power_of_two(middle) - 1)}, false); + {~invc, s.parity_at_least(p, middle), C.ule(prod, rational::power_of_two(middle) - 1)}, false); } } // Why did it evaluate to false in this case? diff --git a/src/sat/smt/polysat/op_constraint.h b/src/sat/smt/polysat/op_constraint.h index 1aec1c486..c5400fbab 100644 --- a/src/sat/smt/polysat/op_constraint.h +++ b/src/sat/smt/polysat/op_constraint.h @@ -44,14 +44,15 @@ namespace polysat { friend class constraints; code m_op; - pdd m_p; // operand1 - pdd m_q; // operand2 - pdd m_r; // result + pdd p; // operand1 + pdd q; // operand2 + pdd r; // result op_constraint(code c, pdd const& r, pdd const& p, pdd const& q); lbool eval(pdd const& r, pdd const& p, pdd const& q) const; static lbool eval_lshr(pdd const& p, pdd const& q, pdd const& r); + static lbool eval_ashr(pdd const& p, pdd const& q, pdd const& r); static lbool eval_shl(pdd const& p, pdd const& q, pdd const& r); static lbool eval_and(pdd const& p, pdd const& q, pdd const& r); static lbool eval_inv(pdd const& p, pdd const& r); @@ -70,9 +71,6 @@ namespace polysat { public: ~op_constraint() override {} - pdd const& p() const { return m_p; } - pdd const& q() const { return m_q; } - pdd const& r() const { return m_r; } code get_op() const { return m_op; } std::ostream& display(std::ostream& out, lbool status) const override; std::ostream& display(std::ostream& out) const override; diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index b2302a22d..3cbf4b0ed 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -316,13 +316,24 @@ namespace polysat { return true; } - void solver::axiomatize_int2bv(app* e, unsigned & sz, expr* x) { - NOT_IMPLEMENTED_YET(); - + void solver::axiomatize_int2bv(app* e, unsigned sz, expr* x) { + // e = int2bv(x) + // bv2int(int2bv(x)) = x mod N + rational N = rational::power_of_two(sz); + add_unit(eq_internalize(bv.mk_bv2int(e), m_autil.mk_mod(x, m_autil.mk_int(N)))); } void solver::axiomatize_bv2int(app* e, expr* x) { - NOT_IMPLEMENTED_YET(); + // e := bv2int(x) + // e = sum_bits(x) + unsigned sz = bv.get_bv_size(x); + expr* one = m_autil.mk_int(1); + expr* zero = m_autil.mk_int(0); + expr* r = zero; + pdd p = expr2pdd(x); + for (unsigned i = 0; i < sz; ++i) + r = m_autil.mk_add(r, m.mk_ite(constraint2expr(m_core.bit(p, i)), one, zero)); + add_unit(eq_internalize(e, r)); } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 8038cc4bb..d7032046d 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -140,7 +140,7 @@ namespace polysat { void axiomatize_rotate_right(app* e, unsigned n, expr* x); void axiomatize_ext_rotate_left(app* e, expr* x, expr* y); void axiomatize_ext_rotate_right(app* e, expr* x, expr* y); - void axiomatize_int2bv(app* e, unsigned & sz, expr* x); + void axiomatize_int2bv(app* e, unsigned sz, expr* x); void axiomatize_bv2int(app* e, expr* x); expr* rotate_left(app* e, unsigned n, expr* x); unsigned m_delayed_axioms_qhead = 0; From 922358b9ba6149167e987f89ed1e1030debb5dc0 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 15 Dec 2023 08:59:05 -0800 Subject: [PATCH 81/89] import pdd updates from polysat Signed-off-by: Nikolaj Bjorner --- src/math/dd/dd_pdd.cpp | 162 +++++++++++++++++-------- src/math/dd/dd_pdd.h | 180 +++++++++++++++++----------- src/sat/smt/polysat_internalize.cpp | 2 - 3 files changed, 224 insertions(+), 120 deletions(-) diff --git a/src/math/dd/dd_pdd.cpp b/src/math/dd/dd_pdd.cpp index 970eb991f..3ad64acfd 100644 --- a/src/math/dd/dd_pdd.cpp +++ b/src/math/dd/dd_pdd.cpp @@ -113,11 +113,34 @@ namespace dd { pdd pdd_manager::add(rational const& r, pdd const& b) { pdd c(mk_val(r)); return pdd(apply(c.root, b.root, pdd_add_op), this); } pdd pdd_manager::zero() { return pdd(zero_pdd, this); } pdd pdd_manager::one() { return pdd(one_pdd, this); } - - pdd pdd_manager::mk_or(pdd const& p, pdd const& q) { return p + q - (p*q); } - pdd pdd_manager::mk_xor(pdd const& p, pdd const& q) { if (m_semantics == mod2_e) return p + q; return (p*q*2) - p - q; } - pdd pdd_manager::mk_xor(pdd const& p, unsigned x) { pdd q(mk_val(x)); if (m_semantics == mod2_e) return p + q; return (p*q*2) - p - q; } - pdd pdd_manager::mk_not(pdd const& p) { return 1 - p; } + + // NOTE: bit-wise AND cannot be expressed in mod2N_e semantics with the existing operations. + pdd pdd_manager::mk_and(pdd const& p, pdd const& q) { + VERIFY(m_semantics == mod2_e || m_semantics == zero_one_vars_e); + return p * q; + } + + pdd pdd_manager::mk_or(pdd const& p, pdd const& q) { + return p + q - mk_and(p, q); + } + + pdd pdd_manager::mk_xor(pdd const& p, pdd const& q) { + if (m_semantics == mod2_e) + return p + q; + return p + q - 2*mk_and(p, q); + } + + pdd pdd_manager::mk_xor(pdd const& p, unsigned x) { + pdd q(mk_val(x)); + return mk_xor(p, q); + } + + pdd pdd_manager::mk_not(pdd const& p) { + if (m_semantics == mod2N_e) + return -p - 1; + VERIFY(m_semantics == mod2_e || m_semantics == zero_one_vars_e); + return 1 - p; + } pdd pdd_manager::subst_val(pdd const& p, unsigned v, rational const& val) { pdd r = mk_var(v) + val; @@ -173,15 +196,8 @@ namespace dd { if (m_semantics != mod2N_e) return 0; - if (is_val(p)) { - rational v = val(p); - if (v.is_zero()) - return m_power_of_2 + 1; - unsigned r = 0; - while (v.is_even() && v > 0) - r++, v /= 2; - return r; - } + if (is_val(p)) + return val(p).parity(m_power_of_2); init_mark(); PDD q = p; m_todo.push_back(hi(q)); @@ -189,9 +205,9 @@ namespace dd { q = lo(q); m_todo.push_back(hi(q)); } - unsigned p2 = val(q).trailing_zeros(); + unsigned parity = val(q).parity(m_power_of_2); init_mark(); - while (p2 != 0 && !m_todo.empty()) { + while (parity != 0 && !m_todo.empty()) { PDD r = m_todo.back(); m_todo.pop_back(); if (is_marked(r)) @@ -203,11 +219,11 @@ namespace dd { } else if (val(r).is_zero()) continue; - else if (val(r).trailing_zeros() < p2) - p2 = val(r).trailing_zeros(); + else + parity = std::min(parity, val(r).trailing_zeros()); } m_todo.reset(); - return p2; + return parity; } pdd pdd_manager::subst_val(pdd const& p, pdd const& s) { @@ -246,7 +262,7 @@ namespace dd { } pdd_manager::PDD pdd_manager::apply(PDD arg1, PDD arg2, pdd_op op) { - bool first = true; + unsigned count = 0; SASSERT(well_formed()); scoped_push _sp(*this); while (true) { @@ -255,8 +271,9 @@ namespace dd { } catch (const mem_out &) { try_gc(); - if (!first) throw; - first = false; + if (count > 0) + m_max_num_nodes *= 2; + count++; } } SASSERT(well_formed()); @@ -507,7 +524,7 @@ namespace dd { if (m_semantics == mod2_e) { return a; } - bool first = true; + unsigned count = 0; SASSERT(well_formed()); scoped_push _sp(*this); while (true) { @@ -516,8 +533,9 @@ namespace dd { } catch (const mem_out &) { try_gc(); - if (!first) throw; - first = false; + if (count > 0) + m_max_num_nodes *= 2; + ++count; } } SASSERT(well_formed()); @@ -565,7 +583,7 @@ namespace dd { return true; } SASSERT(c.is_int()); - bool first = true; + unsigned count = 0; SASSERT(well_formed()); scoped_push _sp(*this); while (true) { @@ -578,8 +596,9 @@ namespace dd { } catch (const mem_out &) { try_gc(); - if (!first) throw; - first = false; + if (count > 0) + m_max_num_nodes *= 2; + ++count; } } } @@ -1138,6 +1157,7 @@ namespace dd { unsigned pdd_manager::max_pow2_divisor(PDD p) { init_mark(); unsigned min_j = UINT_MAX; + SASSERT(m_todo.empty()); m_todo.push_back(p); while (!m_todo.empty()) { PDD r = m_todo.back(); @@ -1785,27 +1805,44 @@ namespace dd { } pdd& pdd::operator=(pdd const& other) { + if (m != other.m) { + verbose_stream() << "pdd manager confusion: " << *this << " (mod 2^" << power_of_2() << ") := " << other << " (mod 2^" << other.power_of_2() << ")\n"; + UNREACHABLE(); + // TODO: in the end, this operator should probably be changed to also update the manager. But for now I want to detect such confusions. + reset(*other.m); + } + SASSERT_EQ(power_of_2(), other.power_of_2()); + VERIFY_EQ(power_of_2(), other.power_of_2()); + VERIFY_EQ(m, other.m); unsigned r1 = root; root = other.root; - m.inc_ref(root); - m.dec_ref(r1); + m->inc_ref(root); + m->dec_ref(r1); return *this; } pdd& pdd::operator=(unsigned k) { - m.dec_ref(root); - root = m.mk_val(k).root; - m.inc_ref(root); + m->dec_ref(root); + root = m->mk_val(k).root; + m->inc_ref(root); return *this; } pdd& pdd::operator=(rational const& k) { - m.dec_ref(root); - root = m.mk_val(k).root; - m.inc_ref(root); + m->dec_ref(root); + root = m->mk_val(k).root; + m->inc_ref(root); return *this; } + /* Reset pdd to 0. Allows re-assigning the pdd manager. */ + void pdd::reset(pdd_manager& new_m) { + m->dec_ref(root); + root = 0; + m = &new_m; + SASSERT(is_zero()); + } + rational const& pdd::leading_coefficient() const { pdd p = *this; while (!p.is_val()) @@ -1813,11 +1850,10 @@ namespace dd { return p.val(); } - rational const& pdd::offset() const { - pdd p = *this; - while (!p.is_val()) - p = p.lo(); - return p.val(); + rational const& pdd_manager::offset(PDD p) const { + while (!is_val(p)) + p = lo(p); + return val(p); } pdd pdd::shl(unsigned n) const { @@ -1831,7 +1867,7 @@ namespace dd { pdd pdd::subst_pdd(unsigned v, pdd const& r) const { if (is_val()) return *this; - if (m.m_var2level[var()] < m.m_var2level[v]) + if (m->m_var2level[var()] < m->m_var2level[v]) return *this; pdd l = lo().subst_pdd(v, r); pdd h = hi().subst_pdd(v, r); @@ -1840,7 +1876,7 @@ namespace dd { else if (l == lo() && h == hi()) return *this; else - return m.mk_var(var())*h + l; + return m->mk_var(var())*h + l; } std::pair pdd::var_factors() const { @@ -1871,7 +1907,7 @@ namespace dd { ++i; ++j; } - else if (m.m_var2level[lo_vars[i]] > m.m_var2level[hi_vars[j]]) + else if (m->m_var2level[lo_vars[i]] > m->m_var2level[hi_vars[j]]) hi_vars[jr++] = hi_vars[j++]; else lo_vars[ir++] = lo_vars[i++]; @@ -1882,7 +1918,7 @@ namespace dd { auto mul = [&](unsigned_vector const& vars, pdd p) { for (auto v : vars) - p *= m.mk_var(v); + p *= m->mk_var(v); return p; }; @@ -1908,7 +1944,7 @@ namespace dd { std::ostream& operator<<(std::ostream& out, pdd const& b) { return b.display(out); } void pdd_iterator::next() { - auto& m = m_pdd.m; + auto& m = m_pdd.manager(); while (!m_nodes.empty()) { auto& p = m_nodes.back(); if (p.first && !m.is_val(p.second)) { @@ -1935,13 +1971,16 @@ namespace dd { void pdd_iterator::first() { unsigned n = m_pdd.root; - auto& m = m_pdd.m; + auto& m = m_pdd.manager(); while (!m.is_val(n)) { m_nodes.push_back(std::make_pair(true, n)); m_mono.vars.push_back(m.var(n)); n = m.hi(n); } m_mono.coeff = m.val(n); + // if m_pdd is constant and non-zero, the iterator should return a single monomial + if (m_nodes.empty() && !m_mono.coeff.is_zero()) + m_nodes.push_back(std::make_pair(false, n)); } pdd_iterator pdd::begin() const { return pdd_iterator(*this, true); } @@ -1960,5 +1999,32 @@ namespace dd { return out; } + void pdd_linear_iterator::first() { + m_next = m_pdd.root; + next(); + } -} + void pdd_linear_iterator::next() { + SASSERT(m_next != pdd_manager::null_pdd); + auto& m = m_pdd.manager(); + while (!m.is_val(m_next)) { + unsigned const var = m.var(m_next); + rational const val = m.offset(m.hi(m_next)); + m_next = m.lo(m_next); + if (!val.is_zero()) { + m_mono = {val, var}; + return; + } + } + m_next = pdd_manager::null_pdd; + } + + pdd_linear_iterator pdd::pdd_linear_monomials::begin() const { + return pdd_linear_iterator(m_pdd, true); + } + + pdd_linear_iterator pdd::pdd_linear_monomials::end() const { + return pdd_linear_iterator(m_pdd, false); + } + +} // namespace dd diff --git a/src/math/dd/dd_pdd.h b/src/math/dd/dd_pdd.h index f2547e962..2dc9a9480 100644 --- a/src/math/dd/dd_pdd.h +++ b/src/math/dd/dd_pdd.h @@ -45,6 +45,7 @@ namespace dd { class pdd; class pdd_manager; class pdd_iterator; + class pdd_linear_iterator; class pdd_manager { public: @@ -53,13 +54,14 @@ namespace dd { friend test; friend pdd; friend pdd_iterator; + friend pdd_linear_iterator; typedef unsigned PDD; typedef vector> monomials_t; - const PDD null_pdd = UINT_MAX; - const PDD zero_pdd = 0; - const PDD one_pdd = 1; + static constexpr PDD null_pdd = UINT_MAX; + static constexpr PDD zero_pdd = 0; + static constexpr PDD one_pdd = 1; enum pdd_op { pdd_add_op = 2, @@ -261,6 +263,7 @@ namespace dd { inline PDD lo(PDD p) const { return m_nodes[p].m_lo; } inline PDD hi(PDD p) const { return m_nodes[p].m_hi; } inline rational const& val(PDD p) const { SASSERT(is_val(p)); return m_values[lo(p)]; } + inline rational get_signed_val(PDD p) const { SASSERT(m_semantics == mod2_e || m_semantics == mod2N_e); rational const& a = val(p); return a.get_bit(power_of_2() - 1) ? a - two_to_N() : a; } inline void inc_ref(PDD p) { if (m_nodes[p].m_refcount != max_rc) m_nodes[p].m_refcount++; SASSERT(!m_free_nodes.contains(p)); } inline void dec_ref(PDD p) { if (m_nodes[p].m_refcount != max_rc) m_nodes[p].m_refcount--; SASSERT(!m_free_nodes.contains(p)); } inline PDD level2pdd(unsigned l) const { return m_var2pdd[m_level2var[l]]; } @@ -341,9 +344,10 @@ namespace dd { pdd mul(rational const& c, pdd const& b); pdd div(pdd const& a, rational const& c); bool try_div(pdd const& a, rational const& c, pdd& out_result); + pdd mk_and(pdd const& p, pdd const& q); pdd mk_or(pdd const& p, pdd const& q); pdd mk_xor(pdd const& p, pdd const& q); - pdd mk_xor(pdd const& p, unsigned q); + pdd mk_xor(pdd const& p, unsigned x); pdd mk_not(pdd const& p); pdd reduce(pdd const& a, pdd const& b); pdd subst_val0(pdd const& a, vector> const& s); @@ -368,6 +372,8 @@ namespace dd { bool is_univariate_in(PDD p, unsigned v); void get_univariate_coefficients(PDD p, vector& coeff); + rational const& offset(PDD p) const; + // create an spoly r if leading monomials of a and b overlap bool try_spoly(pdd const& a, pdd const& b, pdd& r); @@ -399,106 +405,120 @@ namespace dd { friend test; friend class pdd_manager; friend class pdd_iterator; + friend class pdd_linear_iterator; unsigned root; - pdd_manager& m; - pdd(unsigned root, pdd_manager& m): root(root), m(m) { m.inc_ref(root); } - pdd(unsigned root, pdd_manager* _m): root(root), m(*_m) { m.inc_ref(root); } + pdd_manager* m; + pdd(unsigned root, pdd_manager& pm): root(root), m(&pm) { m->inc_ref(root); } + pdd(unsigned root, pdd_manager* pm): root(root), m(pm) { m->inc_ref(root); } public: - pdd(pdd_manager& pm): root(0), m(pm) { SASSERT(is_zero()); } - pdd(pdd const& other): root(other.root), m(other.m) { m.inc_ref(root); } - pdd(pdd && other) noexcept : root(0), m(other.m) { std::swap(root, other.root); } + pdd(pdd_manager& m): pdd(0, m) { SASSERT(is_zero()); } + pdd(pdd const& other): pdd(other.root, other.m) { m->inc_ref(root); } + pdd(pdd && other) noexcept : pdd(0, other.m) { std::swap(root, other.root); } pdd& operator=(pdd const& other); pdd& operator=(unsigned k); pdd& operator=(rational const& k); - ~pdd() { m.dec_ref(root); } - pdd lo() const { return pdd(m.lo(root), m); } - pdd hi() const { return pdd(m.hi(root), m); } + // TODO: pdd& operator=(pdd&& other); (just swap like move constructor?) + ~pdd() { m->dec_ref(root); } + void reset(pdd_manager& new_m); + pdd lo() const { return pdd(m->lo(root), m); } + pdd hi() const { return pdd(m->hi(root), m); } unsigned index() const { return root; } - unsigned var() const { return m.var(root); } - rational const& val() const { SASSERT(is_val()); return m.val(root); } + unsigned var() const { return m->var(root); } + rational const& val() const { return m->val(root); } + rational get_signed_val() const { return m->get_signed_val(root); } rational const& leading_coefficient() const; - rational const& offset() const; - bool is_val() const { return m.is_val(root); } - bool is_one() const { return m.is_one(root); } - bool is_zero() const { return m.is_zero(root); } - bool is_linear() const { return m.is_linear(root); } - bool is_var() const { return m.is_var(root); } - bool is_max() const { return m.is_max(root); } + rational const& offset() const { return m->offset(root); } + bool is_val() const { return m->is_val(root); } + bool is_one() const { return m->is_one(root); } + bool is_zero() const { return m->is_zero(root); } + bool is_linear() const { return m->is_linear(root); } + bool is_var() const { return m->is_var(root); } + bool is_max() const { return m->is_max(root); } /** Polynomial is of the form a * x + b for some numerals a, b. */ bool is_unilinear() const { return !is_val() && lo().is_val() && hi().is_val(); } /** Polynomial is of the form a * x for some numeral a. */ bool is_unary() const { return !is_val() && lo().is_zero() && hi().is_val(); } bool is_offset() const { return !is_val() && lo().is_val() && hi().is_one(); } - bool is_binary() const { return m.is_binary(root); } - bool is_monomial() const { return m.is_monomial(root); } - bool is_univariate() const { return m.is_univariate(root); } - bool is_univariate_in(unsigned v) const { return m.is_univariate_in(root, v); } - void get_univariate_coefficients(vector& coeff) const { m.get_univariate_coefficients(root, coeff); } - vector get_univariate_coefficients() const { vector coeff; m.get_univariate_coefficients(root, coeff); return coeff; } - bool is_never_zero() const { return m.is_never_zero(root); } - unsigned min_parity() const { return m.min_parity(root); } - bool var_is_leaf(unsigned v) const { return m.var_is_leaf(root, v); } + bool is_binary() const { return m->is_binary(root); } + bool is_monomial() const { return m->is_monomial(root); } + bool is_univariate() const { return m->is_univariate(root); } + bool is_univariate_in(unsigned v) const { return m->is_univariate_in(root, v); } + void get_univariate_coefficients(vector& coeff) const { m->get_univariate_coefficients(root, coeff); } + vector get_univariate_coefficients() const { vector coeff; m->get_univariate_coefficients(root, coeff); return coeff; } + bool is_never_zero() const { return m->is_never_zero(root); } + unsigned min_parity() const { return m->min_parity(root); } + bool var_is_leaf(unsigned v) const { return m->var_is_leaf(root, v); } - pdd operator-() const { return m.minus(*this); } - pdd operator+(pdd const& other) const { return m.add(*this, other); } - pdd operator-(pdd const& other) const { return m.sub(*this, other); } - pdd operator*(pdd const& other) const { return m.mul(*this, other); } - pdd operator&(pdd const& other) const { return m.mul(*this, other); } - pdd operator|(pdd const& other) const { return m.mk_or(*this, other); } - pdd operator^(pdd const& other) const { return m.mk_xor(*this, other); } - pdd operator^(unsigned other) const { return m.mk_xor(*this, other); } + pdd operator-() const { return m->minus(*this); } + pdd operator+(pdd const& other) const { VERIFY_EQ(m, other.m); return m->add(*this, other); } + pdd operator-(pdd const& other) const { VERIFY_EQ(m, other.m); return m->sub(*this, other); } + pdd operator*(pdd const& other) const { VERIFY_EQ(m, other.m); return m->mul(*this, other); } + pdd operator&(pdd const& other) const { VERIFY_EQ(m, other.m); return m->mk_and(*this, other); } + pdd operator|(pdd const& other) const { VERIFY_EQ(m, other.m); return m->mk_or(*this, other); } + pdd operator^(pdd const& other) const { VERIFY_EQ(m, other.m); return m->mk_xor(*this, other); } + pdd operator^(unsigned other) const { return m->mk_xor(*this, m->mk_val(other)); } - pdd operator*(rational const& other) const { return m.mul(other, *this); } - pdd operator+(rational const& other) const { return m.add(other, *this); } - pdd operator~() const { return m.mk_not(*this); } + pdd operator*(rational const& other) const { return m->mul(other, *this); } + pdd operator+(rational const& other) const { return m->add(other, *this); } + pdd operator~() const { return m->mk_not(*this); } pdd shl(unsigned n) const; - pdd rev_sub(rational const& r) const { return m.sub(m.mk_val(r), *this); } - pdd div(rational const& other) const { return m.div(*this, other); } - bool try_div(rational const& other, pdd& out_result) const { return m.try_div(*this, other, out_result); } - pdd pow(unsigned j) const { return m.pow(*this, j); } - pdd reduce(pdd const& other) const { return m.reduce(*this, other); } - bool different_leading_term(pdd const& other) const { return m.different_leading_term(*this, other); } - void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) const { m.factor(*this, v, degree, lc, rest); } - bool factor(unsigned v, unsigned degree, pdd& lc) const { return m.factor(*this, v, degree, lc); } - bool resolve(unsigned v, pdd const& other, pdd& result) { return m.resolve(v, *this, other, result); } - pdd reduce(unsigned v, pdd const& other) const { return m.reduce(v, *this, other); } + pdd rev_sub(rational const& r) const { return m->sub(m->mk_val(r), *this); } + pdd div(rational const& other) const { return m->div(*this, other); } + bool try_div(rational const& other, pdd& out_result) const { VERIFY_EQ(m, out_result.m); return m->try_div(*this, other, out_result); } + pdd pow(unsigned j) const { return m->pow(*this, j); } + pdd reduce(pdd const& other) const { VERIFY_EQ(m, other.m); return m->reduce(*this, other); } + bool different_leading_term(pdd const& other) const { VERIFY_EQ(m, other.m); return m->different_leading_term(*this, other); } + void factor(unsigned v, unsigned degree, pdd& lc, pdd& rest) const { VERIFY_EQ(m, lc.m); VERIFY_EQ(m, rest.m); m->factor(*this, v, degree, lc, rest); } + bool factor(unsigned v, unsigned degree, pdd& lc) const { VERIFY_EQ(m, lc.m); return m->factor(*this, v, degree, lc); } + bool resolve(unsigned v, pdd const& other, pdd& result) { VERIFY_EQ(m, other.m); VERIFY_EQ(m, result.m); return m->resolve(v, *this, other, result); } + pdd reduce(unsigned v, pdd const& other) const { VERIFY_EQ(m, other.m); return m->reduce(v, *this, other); } /** * \brief factor out variables */ std::pair var_factors() const; - pdd subst_val0(vector> const& s) const { return m.subst_val0(*this, s); } - pdd subst_val(pdd const& s) const { return m.subst_val(*this, s); } - pdd subst_val(unsigned v, rational const& val) const { return m.subst_val(*this, v, val); } - pdd subst_add(unsigned var, rational const& val) const { return m.subst_add(*this, var, val); } - bool subst_get(unsigned var, rational& out_val) const { return m.subst_get(*this, var, out_val); } + pdd subst_val0(vector> const& s) const { return m->subst_val0(*this, s); } + pdd subst_val(pdd const& s) const { VERIFY_EQ(m, s.m); return m->subst_val(*this, s); } + pdd subst_val(unsigned v, rational const& val) const { return m->subst_val(*this, v, val); } + pdd subst_add(unsigned var, rational const& val) const { return m->subst_add(*this, var, val); } + bool subst_get(unsigned var, rational& out_val) const { return m->subst_get(*this, var, out_val); } /** * \brief substitute variable v by r. */ pdd subst_pdd(unsigned v, pdd const& r) const; - std::ostream& display(std::ostream& out) const { return m.display(out, *this); } - bool operator==(pdd const& other) const { return root == other.root; } - bool operator!=(pdd const& other) const { return root != other.root; } + std::ostream& display(std::ostream& out) const { return m->display(out, *this); } + bool operator==(pdd const& other) const { return root == other.root && m == other.m; } + bool operator!=(pdd const& other) const { return !operator==(other); } unsigned hash() const { return root; } - unsigned power_of_2() const { return m.power_of_2(); } + unsigned power_of_2() const { return m->power_of_2(); } - unsigned dag_size() const { return m.dag_size(*this); } - double tree_size() const { return m.tree_size(*this); } - unsigned degree() const { return m.degree(*this); } - unsigned degree(unsigned v) const { return m.degree(root, v); } - unsigned max_pow2_divisor() const { return m.max_pow2_divisor(root); } - unsigned_vector const& free_vars() const { return m.free_vars(*this); } + unsigned dag_size() const { return m->dag_size(*this); } + double tree_size() const { return m->tree_size(*this); } + unsigned degree() const { return m->degree(*this); } + unsigned degree(unsigned v) const { return m->degree(root, v); } + unsigned max_pow2_divisor() const { return m->max_pow2_divisor(root); } + unsigned_vector const& free_vars() const { return m->free_vars(*this); } - void swap(pdd& other) { std::swap(root, other.root); } + void swap(pdd& other) { VERIFY_EQ(m, other.m); std::swap(root, other.root); } pdd_iterator begin() const; pdd_iterator end() const; - pdd_manager& manager() const { return m; } + class pdd_linear_monomials { + friend class pdd; + pdd const& m_pdd; + pdd_linear_monomials(pdd const& p): m_pdd(p) {} + public: + pdd_linear_iterator begin() const; + pdd_linear_iterator end() const; + }; + pdd_linear_monomials linear_monomials() const { return pdd_linear_monomials(*this); } + + pdd_manager& manager() const { return *m; } }; inline pdd operator*(rational const& r, pdd const& b) { return b * r; } @@ -552,7 +572,27 @@ namespace dd { pdd_iterator& operator++() { next(); return *this; } pdd_iterator operator++(int) { auto tmp = *this; next(); return tmp; } bool operator==(pdd_iterator const& other) const { return m_nodes == other.m_nodes; } - bool operator!=(pdd_iterator const& other) const { return m_nodes != other.m_nodes; } + bool operator!=(pdd_iterator const& other) const { return !operator==(other); } + }; + + class pdd_linear_iterator { + friend class pdd::pdd_linear_monomials; + pdd m_pdd; + std::pair m_mono; + pdd_manager::PDD m_next = pdd_manager::null_pdd; + pdd_linear_iterator(pdd const& p, bool at_start): m_pdd(p) { if (at_start) first(); } + void first(); + void next(); + public: + using value_type = std::pair; // coefficient and variable + using reference = value_type const&; + using pointer = value_type const*; + reference operator*() const { return m_mono; } + pointer operator->() const { return &m_mono; } + pdd_linear_iterator& operator++() { next(); return *this; } + pdd_linear_iterator operator++(int) { auto tmp = *this; next(); return tmp; } + bool operator==(pdd_linear_iterator const& other) const { return m_next == other.m_next; } + bool operator!=(pdd_linear_iterator const& other) const { return m_next != other.m_next; } }; class val_pp { diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index 3cbf4b0ed..42336d478 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -698,9 +698,7 @@ namespace polysat { m_var2pdd.reserve(get_num_vars(), p); m_var2pdd_valid.reserve(get_num_vars(), false); ctx.push(set_bitvector_trail(m_var2pdd_valid, v)); -#if 0 m_var2pdd[v].reset(p.manager()); -#endif m_var2pdd[v] = p; } From 196409b3022e622b049badf9339315942fd1f6a3 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 15 Dec 2023 10:40:02 -0800 Subject: [PATCH 82/89] refactor polysat core / solver interface Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/core.cpp | 85 ++++++++++++-------- src/sat/smt/polysat/core.h | 14 ++-- src/sat/smt/polysat/types.h | 3 + src/sat/smt/polysat/umul_ovfl_constraint.cpp | 2 +- src/sat/smt/polysat_solver.cpp | 27 ++++--- src/sat/smt/polysat_solver.h | 4 +- 6 files changed, 76 insertions(+), 59 deletions(-) diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index b3f8474b7..0607f530d 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -69,19 +69,29 @@ namespace polysat { void undo() override { auto& [sc, lit, val] = c.m_constraint_index.back(); auto& vars = sc.vars(); + auto idx = c.m_constraint_index.size() - 1; IF_VERBOSE(10, verbose_stream() << "undo add watch " << sc << " "; - if (vars.size() > 0) verbose_stream() << "(" << c.m_constraint_index.size() -1 << ": " << c.m_watch[vars[0]] << ") "; - if (vars.size() > 1) verbose_stream() << "(" << c.m_constraint_index.size() -1 << ": " << c.m_watch[vars[1]] << ") "; + if (vars.size() > 0) verbose_stream() << "(" << idx << ": " << c.m_watch[vars[0]] << ") "; + if (vars.size() > 1) verbose_stream() << "(" << idx<< ": " << c.m_watch[vars[1]] << ") "; verbose_stream() << "\n"); unsigned n = sc.num_watch(); SASSERT(n <= vars.size()); - SASSERT(n <= 0 || c.m_watch[vars[0]].back() == c.m_constraint_index.size() - 1); - SASSERT(n <= 1 || c.m_watch[vars[1]].back() == c.m_constraint_index.size() - 1); - if (n > 0) - c.m_watch[vars[0]].pop_back(); + auto del_watch = [&](unsigned i) { + auto& w = c.m_watch[vars[i]]; + for (unsigned j = w.size(); j-- > 0;) { + if (w[j] == idx) { + std::swap(w[w.size() - 1], w[j]); + w.pop_back(); + return; + } + } + UNREACHABLE(); + }; + if (n > 0) + del_watch(0); if (n > 1) - c.m_watch[vars[1]].pop_back(); + del_watch(1); c.m_constraint_index.pop_back(); } }; @@ -132,7 +142,7 @@ namespace polysat { m_var_queue.del_var_eh(v); } - unsigned core::register_constraint(signed_constraint& sc, dependency d) { + constraint_id core::register_constraint(signed_constraint& sc, dependency d) { unsigned idx = m_constraint_index.size(); m_constraint_index.push_back({ sc, d, l_undef }); auto& vars = sc.vars(); @@ -150,7 +160,7 @@ namespace polysat { if (j > 1) verbose_stream() << "( " << idx << " : " << m_watch[vars[1]] << ") "; verbose_stream() << "\n"); s.trail().push(mk_add_watch(*this)); - return idx; + return { idx }; } // case split on unassigned variables until all are assigned values. @@ -202,9 +212,11 @@ namespace polysat { return sc; } - void core::propagate_assignment(prop_item& dc) { - auto [idx, sign, dep] = dc; - auto sc = get_constraint(idx, sign); + void core::propagate_assignment(constraint_id idx) { + auto [sc, dep, value] = m_constraint_index[idx.id]; + SASSERT(value != l_undef); + if (value == l_false) + sc = ~sc; if (sc.is_eq(m_var, m_value)) propagate_assignment(m_var, m_value, dep); else @@ -252,7 +264,9 @@ namespace polysat { // this can create fresh literals and update m_watch, but // will not update m_watch[v] (other than copy constructor for m_watch) // because v has been assigned a value. - sc.propagate(*this, value, dep); + propagate(sc, value, dep); + if (s.inconsistent()) + return; SASSERT(!swapped || vars.size() <= 1 || (!is_assigned(vars[0]) && !is_assigned(vars[1]))); if (swapped) @@ -272,30 +286,17 @@ namespace polysat { verbose_stream() << "new watch " << v << ": " << m_watch[v] << "\n"; } - void core::propagate_value(prop_item const& dc) { - auto [idx, sign, dep] = dc; - auto sc = get_constraint(idx, sign); - // check if sc evaluates to false - switch (eval(sc)) { - case l_true: - break; - case l_false: - m_unsat_core = explain_eval(sc); - m_unsat_core.push_back(dep); - propagate_unsat_core(); - return; - default: - break; - } + void core::propagate_value(constraint_id idx) { + auto [sc, d, value] = m_constraint_index[idx.id]; // propagate current assignment for sc - sc.propagate(*this, to_lbool(!sign), dep); + propagate(sc, value, d); if (s.inconsistent()) return; // if sc is v == value, then check the watch list for v to propagate truth assignments if (sc.is_eq(m_var, m_value)) { for (auto idx1 : m_watch[m_var]) { - if (idx == idx1) + if (idx.id == idx1) continue; auto [sc, d, value] = m_constraint_index[idx1]; switch (eval(sc)) { @@ -312,6 +313,19 @@ namespace polysat { } } + void core::propagate(signed_constraint& sc, lbool value, dependency const& d) { + lbool eval_value = eval(sc); + if (eval_value == l_undef) + sc.propagate(*this, value, d); + else if (value == l_undef) + s.propagate(d, eval_value != l_true, explain_eval(sc)); + else if (value != eval_value) { + m_unsat_core = explain_eval(sc); + m_unsat_core.push_back(value == l_false ? ~d : d); + propagate_unsat_core(); + } + } + void core::get_bitvector_prefixes(pvar v, pvar_vector& out) { s.get_bitvector_prefixes(v, out); } @@ -331,7 +345,7 @@ namespace polysat { s.set_conflict(m_unsat_core); } - void core::assign_eh(unsigned index, bool sign, dependency const& dep) { + void core::assign_eh(constraint_id index, bool sign, unsigned level) { struct unassign : public trail { core& c; unsigned m_index; @@ -341,9 +355,10 @@ namespace polysat { c.m_prop_queue.pop_back(); } }; - m_prop_queue.push_back({ index, sign, dep }); - m_constraint_index[index].value = to_lbool(!sign); - s.trail().push(unassign(*this, index)); + m_prop_queue.push_back(index); + m_constraint_index[index.id].value = to_lbool(!sign); + m_constraint_index[index.id].d.set_level(level); + s.trail().push(unassign(*this, index.id)); } dependency_vector core::explain_eval(signed_constraint const& sc) { @@ -392,7 +407,7 @@ namespace polysat { void core::add_axiom(signed_constraint sc) { auto idx = register_constraint(sc, dependency::axiom()); - assign_eh(idx, false, dependency::axiom()); + assign_eh(idx, false, 0); } void core::add_clause(char const* name, core_vector const& cs, bool is_redundant) { diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index 6297e567e..fb0875ec8 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -37,7 +37,6 @@ namespace polysat { class mk_assign_var; class mk_add_watch; typedef svector> activity; - typedef std::tuple prop_item; friend class viable; friend class constraints; friend class assignment; @@ -53,7 +52,7 @@ namespace polysat { constraints m_constraints; assignment m_assignment; unsigned m_qhead = 0, m_vqhead = 0; - svector m_prop_queue; + svector m_prop_queue; svector m_constraint_index; // index of constraints dependency_vector m_unsat_core; @@ -76,17 +75,16 @@ namespace polysat { void del_var(); bool is_assigned(pvar v) { return !m_justification[v].is_null(); } - void propagate_value(prop_item const& dc); - void propagate_assignment(prop_item& dc); + void propagate_value(constraint_id idx); + void propagate_assignment(constraint_id idx); void propagate_assignment(pvar v, rational const& value, dependency dep); void propagate_unsat_core(); + void propagate(signed_constraint& sc, lbool value, dependency const& d); void get_bitvector_prefixes(pvar v, pvar_vector& out); void get_fixed_bits(pvar v, svector& fixed_bits); bool inconsistent() const; - - void add_watch(unsigned idx, unsigned var); signed_constraint get_constraint(unsigned idx, bool sign); @@ -100,9 +98,9 @@ namespace polysat { core(solver_interface& s); sat::check_result check(); - unsigned register_constraint(signed_constraint& sc, dependency d); + constraint_id register_constraint(signed_constraint& sc, dependency d); bool propagate(); - void assign_eh(unsigned idx, bool sign, dependency const& d); + void assign_eh(constraint_id idx, bool sign, unsigned level); pdd value(rational const& v, unsigned sz); pdd subst(pdd const&); diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index d0b5f7bca..d9008392c 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -22,6 +22,7 @@ namespace polysat { using pdd = dd::pdd; using pvar = unsigned; using theory_var = unsigned; + struct constraint_id { unsigned id; }; using pvar_vector = unsigned_vector; inline const pvar null_var = UINT_MAX; @@ -44,6 +45,8 @@ namespace polysat { sat::literal literal() const { SASSERT(is_literal()); return *std::get_if(&m_data); } std::pair eq() const { SASSERT(!is_literal()); return *std::get_if>(&m_data); } unsigned level() const { return m_level; } + void set_level(unsigned level) { m_level = level; } + dependency operator~() const { SASSERT(is_literal()); return dependency(~literal(), level()); } }; inline const dependency null_dependency = dependency(sat::null_literal, UINT_MAX); diff --git a/src/sat/smt/polysat/umul_ovfl_constraint.cpp b/src/sat/smt/polysat/umul_ovfl_constraint.cpp index 445169c2f..596a94340 100644 --- a/src/sat/smt/polysat/umul_ovfl_constraint.cpp +++ b/src/sat/smt/polysat/umul_ovfl_constraint.cpp @@ -97,7 +97,7 @@ namespace polysat { if (!p.is_val()) return false; - VERIFY(!p.is_zero() && !p.is_one()); // evaluation should catch this case + SASSERT(!p.is_zero() && !p.is_one()); // evaluation should catch this case rational const& M = p.manager().two_to_N(); auto& C = c.cs(); diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 0fa4ab8e6..75dd09075 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -91,7 +91,7 @@ namespace polysat { if (!a) return; force_push(); - m_core.assign_eh(a->m_index, l.sign(), dependency(l, s().lvl(l))); + m_core.assign_eh(a->m_index, l.sign(), s().lvl(l)); } void solver::set_conflict(dependency_vector const& core) { @@ -115,17 +115,18 @@ namespace polysat { eqs.push_back(euf::enode_pair(n1, n2)); } } - DEBUG_CODE({ - for (auto lit : core) - VERIFY(s().value(lit) == l_true); - for (auto const& [n1, n2] : eqs) - VERIFY(n1->get_root() == n2->get_root()); - }); IF_VERBOSE(10, for (auto lit : core) - verbose_stream() << " " << lit << ": " << mk_ismt2_pp(literal2expr(lit), m) << "\n"; + verbose_stream() << " " << lit << ": " << mk_ismt2_pp(literal2expr(lit), m) << " " << s().value(lit) << "\n"; + for (auto const& [n1, n2] : eqs) + verbose_stream() << " " << ctx.bpp(n1) << " == " << ctx.bpp(n2) << "\n";); + DEBUG_CODE({ + for (auto lit : core) + SASSERT(s().value(lit) == l_true); for (auto const& [n1, n2] : eqs) - verbose_stream() << " " << ctx.bpp(n1) << " == " << ctx.bpp(n2) << "\n";); + SASSERT(n1->get_root() == n2->get_root()); + }); + return { core, eqs }; } @@ -203,8 +204,8 @@ namespace polysat { m_var_eqs.setx(m_var_eqs_head, {v1, v2}, {v1, v2}); ctx.push(value_trail(m_var_eqs_head)); auto d = dependency(v1, v2, s().scope_lvl()); - unsigned index = m_core.register_constraint(sc, d); - m_core.assign_eh(index, false, d); + constraint_id id = m_core.register_constraint(sc, d); + m_core.assign_eh(id, false, s().scope_lvl()); m_var_eqs_head++; } @@ -218,9 +219,9 @@ namespace polysat { auto sc = ~m_core.eq(p, q); sat::literal neq = ~expr2literal(ne.eq()); auto d = dependency(neq, s().lvl(neq)); - auto index = m_core.register_constraint(sc, d); + auto id = m_core.register_constraint(sc, d); TRACE("bv", tout << neq << " := " << s().value(neq) << " @" << s().scope_lvl() << "\n"); - m_core.assign_eh(index, false, d); + m_core.assign_eh(id, false, s().lvl(neq)); } // Core uses the propagate callback to add unit propagations to the trail. diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index d7032046d..0ecc8941b 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -43,8 +43,8 @@ namespace polysat { struct atom { bool_var m_bv; - unsigned m_index; - atom(bool_var b, unsigned index) :m_bv(b), m_index(index) {} + constraint_id m_index; + atom(bool_var b, constraint_id index) :m_bv(b), m_index(index) {} ~atom() { } }; From faa3a7ab4f64ade7bc1d4813512df361cf1c5b2f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 15 Dec 2023 13:50:26 -0800 Subject: [PATCH 83/89] updates to poly --- src/ast/arith_decl_plugin.cpp | 17 ++++++--- src/ast/arith_decl_plugin.h | 51 ++++++++++++++------------- src/sat/smt/arith_axioms.cpp | 28 +++++++-------- src/sat/smt/arith_solver.cpp | 3 -- src/sat/smt/intblast_solver.h | 4 ++- src/sat/smt/polysat/core.cpp | 5 ++- src/sat/smt/polysat/core.h | 9 ++--- src/sat/smt/polysat/op_constraint.cpp | 16 ++++++++- src/sat/smt/polysat_model.cpp | 10 ++---- src/sat/smt/polysat_solver.cpp | 47 ++++++++++++++---------- src/sat/smt/polysat_solver.h | 3 +- 11 files changed, 111 insertions(+), 82 deletions(-) diff --git a/src/ast/arith_decl_plugin.cpp b/src/ast/arith_decl_plugin.cpp index 2d830d510..8317b37c3 100644 --- a/src/ast/arith_decl_plugin.cpp +++ b/src/ast/arith_decl_plugin.cpp @@ -707,7 +707,16 @@ expr * arith_decl_plugin::get_some_value(sort * s) { return mk_numeral(rational(0), s == m_int_decl); } -bool arith_recognizers::is_numeral(expr const * n, rational & val, bool & is_int) const { +bool arith_util::is_numeral(expr const * n, rational & val, bool & is_int) const { + if (is_irrational_algebraic_numeral(n)) { + scoped_anum an(am()); + is_irrational_algebraic_numeral2(n, an); + if (am().is_rational(an)) { + am().to_rational(an, val); + is_int = val.is_int(); + return true; + } + } if (!is_app_of(n, arith_family_id, OP_NUM)) return false; func_decl * decl = to_app(n)->get_decl(); @@ -738,7 +747,7 @@ bool arith_recognizers::is_int_expr(expr const *e) const { if (is_to_real(e)) { // pass } - else if (is_numeral(e, r) && r.is_int()) { + else if (is_numeral(e) && is_int(e)) { // pass } else if (is_add(e) || is_mul(e)) { @@ -761,14 +770,14 @@ void arith_util::init_plugin() { m_plugin = static_cast(m_manager.get_plugin(arith_family_id)); } -bool arith_util::is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val) { +bool arith_util::is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val) const { if (!is_app_of(n, arith_family_id, OP_IRRATIONAL_ALGEBRAIC_NUM)) return false; am().set(val, to_irrational_algebraic_numeral(n)); return true; } -algebraic_numbers::anum const & arith_util::to_irrational_algebraic_numeral(expr const * n) { +algebraic_numbers::anum const & arith_util::to_irrational_algebraic_numeral(expr const * n) const { SASSERT(is_irrational_algebraic_numeral(n)); return plugin().aw().to_anum(to_app(n)->get_decl()); } diff --git a/src/ast/arith_decl_plugin.h b/src/ast/arith_decl_plugin.h index b073e205e..25c4977e9 100644 --- a/src/ast/arith_decl_plugin.h +++ b/src/ast/arith_decl_plugin.h @@ -237,26 +237,10 @@ public: family_id get_family_id() const { return arith_family_id; } bool is_arith_expr(expr const * n) const { return is_app(n) && to_app(n)->get_family_id() == arith_family_id; } - bool is_irrational_algebraic_numeral(expr const * n) const; - bool is_unsigned(expr const * n, unsigned& u) const { - rational val; - bool is_int = true; - return is_numeral(n, val, is_int) && is_int && val.is_unsigned() && (u = val.get_unsigned(), true); - } - bool is_numeral(expr const * n, rational & val, bool & is_int) const; - bool is_numeral(expr const * n, rational & val) const { bool is_int; return is_numeral(n, val, is_int); } - bool is_numeral(expr const * n) const { return is_app_of(n, arith_family_id, OP_NUM); } - bool is_zero(expr const * n) const { rational val; return is_numeral(n, val) && val.is_zero(); } - bool is_minus_one(expr * n) const { rational tmp; return is_numeral(n, tmp) && tmp.is_minus_one(); } - // return true if \c n is a term of the form (* -1 r) - bool is_times_minus_one(expr * n, expr * & r) const { - if (is_mul(n) && to_app(n)->get_num_args() == 2 && is_minus_one(to_app(n)->get_arg(0))) { - r = to_app(n)->get_arg(1); - return true; - } - return false; - } + bool is_irrational_algebraic_numeral(expr const* n) const; + + bool is_numeral(expr const* n) const { return is_app_of(n, arith_family_id, OP_NUM); } bool is_int_expr(expr const * e) const; bool is_le(expr const * n) const { return is_app_of(n, arith_family_id, OP_LE); } @@ -399,13 +383,32 @@ public: return *m_plugin; } - algebraic_numbers::manager & am() { + algebraic_numbers::manager & am() const { return plugin().am(); } + // return true if \c n is a term of the form (* -1 r) + bool is_zero(expr const* n) const { rational val; return is_numeral(n, val) && val.is_zero(); } + bool is_minus_one(expr* n) const { rational tmp; return is_numeral(n, tmp) && tmp.is_minus_one(); } + bool is_times_minus_one(expr* n, expr*& r) const { + if (is_mul(n) && to_app(n)->get_num_args() == 2 && is_minus_one(to_app(n)->get_arg(0))) { + r = to_app(n)->get_arg(1); + return true; + } + return false; + } + bool is_unsigned(expr const* n, unsigned& u) const { + rational val; + bool is_int = true; + return is_numeral(n, val, is_int) && is_int && val.is_unsigned() && (u = val.get_unsigned(), true); + } + bool is_numeral(expr const* n) const { return arith_recognizers::is_numeral(n); } + bool is_numeral(expr const* n, rational& val, bool& is_int) const; + bool is_numeral(expr const* n, rational& val) const { bool is_int; return is_numeral(n, val, is_int); } + bool convert_int_numerals_to_real() const { return plugin().convert_int_numerals_to_real(); } - bool is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val); - algebraic_numbers::anum const & to_irrational_algebraic_numeral(expr const * n); + bool is_irrational_algebraic_numeral2(expr const * n, algebraic_numbers::anum & val) const; + algebraic_numbers::anum const & to_irrational_algebraic_numeral(expr const * n) const; sort * mk_int() { return m_manager.mk_sort(arith_family_id, INT_SORT); } sort * mk_real() { return m_manager.mk_sort(arith_family_id, REAL_SORT); } @@ -512,11 +515,11 @@ public: if none of them are numerals, then the left-hand-side has a smaller id than the right hand side. */ app * mk_eq(expr * lhs, expr * rhs) { - if (is_numeral(lhs) || (!is_numeral(rhs) && lhs->get_id() > rhs->get_id())) + if (arith_recognizers::is_numeral(lhs) || (!arith_recognizers::is_numeral(rhs) && lhs->get_id() > rhs->get_id())) std::swap(lhs, rhs); if (lhs == rhs) return m_manager.mk_true(); - if (is_numeral(lhs) && is_numeral(rhs)) { + if (arith_recognizers::is_numeral(lhs) && arith_recognizers::is_numeral(rhs)) { SASSERT(lhs != rhs); return m_manager.mk_false(); } diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 09db74f75..f004422a6 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -211,25 +211,23 @@ namespace arith { if (!ctx.is_relevant(expr2enode(n))) return true; VERIFY(a.is_band(n, sz, x, y)); - if (use_nra_model()) { + expr_ref vx(m), vy(m),vn(m); + if (!get_value(expr2enode(x), vx) || !get_value(expr2enode(y), vy) || !get_value(expr2enode(n), vn)) { + IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n"); found_unsupported(n); return true; } - theory_var vx = expr2enode(x)->get_th_var(get_id()); - theory_var vy = expr2enode(y)->get_th_var(get_id()); - theory_var vn = expr2enode(n)->get_th_var(get_id()); - rational N = rational::power_of_two(sz); - if (!get_value(vx).is_int() || !get_value(vy).is_int()) { - - s().display(verbose_stream()); - verbose_stream() << vx << " " << vy << " " << mk_pp(n, m) << "\n"; + rational valn, valx, valy; + bool is_int; + if (!a.is_numeral(vn, valn, is_int) || !is_int || !a.is_numeral(vx, valx, is_int) || !is_int || !a.is_numeral(vy, valy, is_int) || !is_int) { + IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n"); + found_unsupported(n); + return true; } - SASSERT(get_value(vx).is_int()); - SASSERT(get_value(vy).is_int()); - SASSERT(get_value(vn).is_int()); - rational valx = mod(get_value(vx), N); - rational valy = mod(get_value(vy), N); - rational valn = get_value(vn); + // verbose_stream() << "band: " << mk_pp(n, m) << " " << valn << " := " << valx << "&" << valy << "\n"; + rational N = rational::power_of_two(sz); + valx = mod(valx, N); + valy = mod(valy, N); SASSERT(0 <= valn && valn < N); // x mod 2^{i + 1} >= 2^i means the i'th bit is 1. diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 37aef2bf8..eff25bc4a 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -628,9 +628,6 @@ namespace arith { } else if (use_nra_model() && lp().external_to_local(v) != lp::null_lpvar) { anum const& an = nl_value(v, m_nla->tmp1()); - - - if (a.is_int(o) && !m_nla->am().is_int(an)) value = a.mk_numeral(rational::zero(), a.is_int(o)); else diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index 493b1f3c5..d59dac935 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -64,7 +64,7 @@ namespace intblast { void translate(expr_ref_vector& es); void sorted_subterms(expr_ref_vector& es, ptr_vector& sorted); - rational get_value(expr* e) const; + bool is_translated(expr* e) const { return !!m_translate.get(e->get_id(), nullptr); } expr* translated(expr* e) const { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } @@ -136,6 +136,8 @@ namespace intblast { void eq_internalized(euf::enode* n) override; + rational get_value(expr* e) const; + }; } diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index 0607f530d..c9deb5726 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -177,10 +177,9 @@ namespace polysat { s.set_lemma(m_viable.get_core(), m_viable.explain()); // propagate_unsat_core(); return sat::check_result::CR_CONTINUE; - case find_t::singleton: { + case find_t::singleton: s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); - return sat::check_result::CR_CONTINUE; - } + return sat::check_result::CR_CONTINUE; case find_t::multiple: s.add_eq_literal(m_var, m_value); return sat::check_result::CR_CONTINUE; diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index fb0875ec8..46661dc84 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -60,10 +60,10 @@ namespace polysat { // attributes associated with variables vector m_vars; // for each variable a pdd vector m_values; // current value of assigned variable - svector m_justification; // justification for assignment - activity m_activity; // activity of variables - var_queue m_var_queue; // priority queue of variables to assign - vector m_watch; // watch lists for variables for constraints on m_prop_queue where they occur + svector m_justification; // justification for assignment + activity m_activity; // activity of variables + var_queue m_var_queue; // priority queue of variables to assign + vector m_watch; // watch lists for variables for constraints on m_prop_queue where they occur // values to split on rational m_value; @@ -101,6 +101,7 @@ namespace polysat { constraint_id register_constraint(signed_constraint& sc, dependency d); bool propagate(); void assign_eh(constraint_id idx, bool sign, unsigned level); + pvar next_var() { return m_var_queue.next_var(); } pdd value(rational const& v, unsigned sz); pdd subst(pdd const&); diff --git a/src/sat/smt/polysat/op_constraint.cpp b/src/sat/smt/polysat/op_constraint.cpp index b7d312d55..175d3a145 100644 --- a/src/sat/smt/polysat/op_constraint.cpp +++ b/src/sat/smt/polysat/op_constraint.cpp @@ -96,7 +96,21 @@ namespace polysat { } lbool op_constraint::eval_ashr(pdd const& p, pdd const& q, pdd const& r) { - NOT_IMPLEMENTED_YET(); + auto& m = p.manager(); + if (r.is_val() && p.is_val() && q.is_val()) { + auto M = m.max_value(); + auto N = M + 1; + if (p.val() >= N/2) { + if (q.val() >= m.power_of_2()) + return to_lbool(r.val() == M); + unsigned k = q.val().get_unsigned(); + return to_lbool(r.val() == p.val() - rational::power_of_two(k)); + } + else + return eval_lshr(p, q, r); + } + if (q.is_val() && q.is_zero() && p == r) + return l_true; return l_undef; } diff --git a/src/sat/smt/polysat_model.cpp b/src/sat/smt/polysat_model.cpp index 5bd8d4dc9..028aeed6b 100644 --- a/src/sat/smt/polysat_model.cpp +++ b/src/sat/smt/polysat_model.cpp @@ -23,12 +23,7 @@ Author: namespace polysat { - void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { - - if (m_use_intblast_model) { - m_intblast.add_value(n, mdl, values); - return; - } + void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { auto p = expr2pdd(n->get_expr()); rational val; if (!m_core.try_eval(p, val)) { @@ -82,8 +77,7 @@ namespace polysat { for (unsigned v = 0; v < get_num_vars(); ++v) if (m_var2pdd_valid.get(v, false)) out << ctx.bpp(var2enode(v)) << " := " << m_var2pdd[v] << "\n"; - if (m_use_intblast_model) - m_intblast.display(out); + m_intblast.display(out); return out; } } diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 75dd09075..219b9017a 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -61,25 +61,36 @@ namespace polysat { return sat::check_result::CR_DONE; case sat::check_result::CR_CONTINUE: return sat::check_result::CR_CONTINUE; - case sat::check_result::CR_GIVEUP: { - if (!m.inc()) - return sat::check_result::CR_GIVEUP; - switch (m_intblast.check_solver_state()) { - case l_true: - trail().push(value_trail(m_use_intblast_model)); - m_use_intblast_model = true; - return sat::check_result::CR_DONE; - case l_false: { - auto core = m_intblast.unsat_core(); - for (auto& lit : core) - lit.neg(); - s().add_clause(core.size(), core.data(), sat::status::th(true, get_id(), nullptr)); - return sat::check_result::CR_CONTINUE; - } - case l_undef: - return sat::check_result::CR_GIVEUP; - } + case sat::check_result::CR_GIVEUP: + return intblast(); } + UNREACHABLE(); + return sat::check_result::CR_GIVEUP; + } + + sat::check_result solver::intblast() { + if (!m.inc()) + return sat::check_result::CR_GIVEUP; + switch (m_intblast.check_solver_state()) { + case l_true: { + pvar pv = m_core.next_var(); + auto v = m_pddvar2var[pv]; + auto n = var2expr(v); + auto val = m_intblast.get_value(n); + sat::literal lit = eq_internalize(n, bv.mk_numeral(val, get_bv_size(v))); + s().set_phase(lit); + return sat::check_result::CR_CONTINUE; + } + case l_false: { + IF_VERBOSE(2, verbose_stream() << "unsat core: " << m_intblast.unsat_core() << "\n"); + auto core = m_intblast.unsat_core(); + for (auto& lit : core) + lit.neg(); + s().add_clause(core.size(), core.data(), sat::status::th(true, get_id(), nullptr)); + return sat::check_result::CR_CONTINUE; + } + case l_undef: + return sat::check_result::CR_GIVEUP; } UNREACHABLE(); return sat::check_result::CR_GIVEUP; diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 0ecc8941b..60535207b 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -59,7 +59,6 @@ namespace polysat { stats m_stats; core m_core; intblast::solver m_intblast; - bool m_use_intblast_model = false; vector m_var2pdd; // theory_var 2 pdd bool_vector m_var2pdd_valid; // valid flag @@ -73,6 +72,8 @@ namespace polysat { unsigned m_lemma_level = 0; expr_ref_vector m_lemma; + sat::check_result intblast(); + // internalize bool visit(expr* e) override; bool visited(expr* e) override; From a3f3abb8f246433bd44f4bf4d94172aff8798da8 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 15 Dec 2023 13:59:06 -0800 Subject: [PATCH 84/89] use suggestion from #7047 Signed-off-by: Nikolaj Bjorner --- src/api/python/setup.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/api/python/setup.py b/src/api/python/setup.py index 325fb4230..95f717c75 100644 --- a/src/api/python/setup.py +++ b/src/api/python/setup.py @@ -313,12 +313,11 @@ if 'bdist_wheel' in sys.argv and '--plat-name' not in sys.argv: osver = RELEASE_METADATA[3] if osver.count('.') > 1: osver = '.'.join(osver.split('.')[:2]) - if osver.startswith("11"): - osver = "11_0" + osver = osver.replace('.','_') if arch == 'x64': - plat_name ='macosx_%s_x86_64' % osver.replace('.', '_') + plat_name ='macosx_%s_x86_64' % re.sub(r'\A(1[1-9])(_[\d]+)*\Z', r'\1_0', osver) elif arch == 'arm64': - plat_name ='macosx_%s_arm64' % osver.replace('.', '_') + plat_name ='macosx_%s_arm64' % osver else: raise Exception(f"idk how os {distos} {osver} works. what goes here?") else: From d0b03a15269fb20b173ca110ad5c1d8735ec5c5d Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 15 Dec 2023 14:30:13 -0800 Subject: [PATCH 85/89] work on ashr Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/constraints.h | 4 +++ src/sat/smt/polysat/op_constraint.cpp | 39 +++++++++++++++++---------- src/sat/smt/polysat/op_constraint.h | 1 + 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/sat/smt/polysat/constraints.h b/src/sat/smt/polysat/constraints.h index 47c9beb49..fa2b62c11 100644 --- a/src/sat/smt/polysat/constraints.h +++ b/src/sat/smt/polysat/constraints.h @@ -119,6 +119,7 @@ namespace polysat { signed_constraint uge(pdd const& p, pdd const& q) { return ule(q, p); } signed_constraint uge(pdd const& p, rational const& q) { return ule(q, p); } + signed_constraint uge(pdd const& p, int q) { return ule(q, p); } signed_constraint ult(pdd const& p, rational const& q) { return ult(p, p.manager().mk_val(q)); } signed_constraint ult(rational const& p, pdd const& q) { return ult(q.manager().mk_val(p), q); } @@ -141,6 +142,9 @@ namespace polysat { signed_constraint sgt(int p, pdd const& q) { return slt(q, p); } signed_constraint sgt(unsigned p, pdd const& q) { return slt(q, p); } + signed_constraint sge(pdd const& p, pdd const& q) { return ~slt(q, p); } + signed_constraint sge(pdd const& p, int q) { return ~slt(q, p); } + signed_constraint umul_ovfl(pdd const& p, rational const& q) { return umul_ovfl(p, p.manager().mk_val(q)); } signed_constraint umul_ovfl(rational const& p, pdd const& q) { return umul_ovfl(q.manager().mk_val(p), q); } signed_constraint umul_ovfl(pdd const& p, int q) { return umul_ovfl(p, rational(q)); } diff --git a/src/sat/smt/polysat/op_constraint.cpp b/src/sat/smt/polysat/op_constraint.cpp index 175d3a145..666d950a9 100644 --- a/src/sat/smt/polysat/op_constraint.cpp +++ b/src/sat/smt/polysat/op_constraint.cpp @@ -203,6 +203,9 @@ namespace polysat { case code::and_op: activate_and(c, dep); break; + case code::ashr_op: + activate_ashr(c, dep); + break; default: break; } @@ -309,6 +312,28 @@ namespace polysat { } } + void op_constraint::activate_ashr(core& c, dependency const& d) { + // + // if q = k & p >= 0 -> r*2^k + + // if q = k & p < 0 -> (p / 2^k) - 1 + 2^{N-k} + // + + auto& m = p.manager(); + auto const pv = c.subst(p); + auto const qv = c.subst(q); + auto const rv = c.subst(r); + unsigned const N = m.power_of_2(); + + auto& C = c.cs(); + c.add_clause("q >= N & p < 0 -> p << q = -1", {~C.uge(q, N), ~C.slt(p, 0), C.eq(r, m.max_value())}, true); + c.add_clause("q >= N & p >= 0 -> p << q = 0", {~C.uge(q, N), ~C.sge(p, 0), C.eq(r)}, true); + c.add_clause("q = 0 -> p << q = p", { ~C.eq(q), C.eq(r, p) }, true); + for (unsigned k = 0; k < N; ++k) { +// c.add_clause("q = k & p >= 0 -> p << q = p / 2^k", {~C.eq(q, k), ~C.sge(p, 0), ... }, true); +// c.add_clause("q = k & p < 0 -> p << q = (p / 2^k) -1 + 2^{N-k}", {~C.eq(q, k), ~C.slt(p, 0), ... }, true); + } + } + void op_constraint::activate_and(core& c, dependency const& d) { auto x = p, y = q; @@ -336,21 +361,7 @@ namespace polysat { } void op_constraint::propagate_ashr(core& c, dependency const& dep) { - // - // ashr(x, y) - // if q >= N & p < 0 -> -1 - // if q >= N & p >= 0 -> 0 - // if q = k & p >= 0 -> p / 2^k - // if q = k & p < 0 -> (p / 2^k) - 1 + 2^{N-k} - // - auto& m = p.manager(); - auto const pv = c.subst(p); - auto const qv = c.subst(q); - auto const rv = c.subst(r); - unsigned const N = m.power_of_2(); - - NOT_IMPLEMENTED_YET(); } diff --git a/src/sat/smt/polysat/op_constraint.h b/src/sat/smt/polysat/op_constraint.h index c5400fbab..4b44b3009 100644 --- a/src/sat/smt/polysat/op_constraint.h +++ b/src/sat/smt/polysat/op_constraint.h @@ -68,6 +68,7 @@ namespace polysat { std::ostream& display(std::ostream& out, char const* eq) const; void activate_and(core& s, dependency const& d); + void activate_ashr(core& s, dependency const& d); public: ~op_constraint() override {} From a6e08b22f8d9d1ac28eb9c66f6b424314c665a77 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 15 Dec 2023 14:54:20 -0800 Subject: [PATCH 86/89] add rewrites for band Signed-off-by: Nikolaj Bjorner --- src/ast/rewriter/arith_rewriter.cpp | 29 +++++++++++++++++++++++++++ src/ast/rewriter/arith_rewriter.h | 1 + src/sat/smt/polysat/op_constraint.cpp | 14 ++++++------- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/ast/rewriter/arith_rewriter.cpp b/src/ast/rewriter/arith_rewriter.cpp index 44b91826f..ddfabed83 100644 --- a/src/ast/rewriter/arith_rewriter.cpp +++ b/src/ast/rewriter/arith_rewriter.cpp @@ -91,6 +91,7 @@ br_status arith_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * c case OP_SINH: SASSERT(num_args == 1); st = mk_sinh_core(args[0], result); break; case OP_COSH: SASSERT(num_args == 1); st = mk_cosh_core(args[0], result); break; case OP_TANH: SASSERT(num_args == 1); st = mk_tanh_core(args[0], result); break; + case OP_ARITH_BAND: SASSERT(num_args == 2); st = mk_band_core(f->get_parameter(0).get_int(), args[0], args[1], result); break; default: st = BR_FAILED; break; } CTRACE("arith_rewriter", st != BR_FAILED, tout << st << ": " << mk_pp(f, m); @@ -1349,6 +1350,34 @@ app* arith_rewriter_core::mk_power(expr* x, rational const& r, sort* s) { return y; } +br_status arith_rewriter::mk_band_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result) { + numeral x, y, N; + bool is_num_x = m_util.is_numeral(arg1, x); + bool is_num_y = m_util.is_numeral(arg2, y); + N = rational::power_of_two(sz); + if (is_num_x) + x = mod(x, N); + if (is_num_y) + y = mod(y, N); + if (is_num_x && x.is_zero()) { + result = m_util.mk_int(0); + return BR_DONE; + } + if (is_num_y && y.is_zero()) { + result = m_util.mk_int(0); + return BR_DONE; + } + if (is_num_x && is_num_y) { + rational r(0); + for (unsigned i = 0; i < sz; ++i) + if (x.get_bit(i) && y.get_bit(i)) + r += rational::power_of_two(i); + result = m_util.mk_int(r); + return BR_DONE; + } + return BR_FAILED; +} + br_status arith_rewriter::mk_power_core(expr * arg1, expr * arg2, expr_ref & result) { numeral x, y; bool is_num_x = m_util.is_numeral(arg1, x); diff --git a/src/ast/rewriter/arith_rewriter.h b/src/ast/rewriter/arith_rewriter.h index edc84b25a..548ab80db 100644 --- a/src/ast/rewriter/arith_rewriter.h +++ b/src/ast/rewriter/arith_rewriter.h @@ -159,6 +159,7 @@ public: br_status mk_mod_core(expr * arg1, expr * arg2, expr_ref & result); br_status mk_rem_core(expr * arg1, expr * arg2, expr_ref & result); br_status mk_power_core(expr* arg1, expr* arg2, expr_ref & result); + br_status mk_band_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result); void mk_div(expr * arg1, expr * arg2, expr_ref & result) { if (mk_div_core(arg1, arg2, result) == BR_FAILED) result = m.mk_app(get_fid(), OP_DIV, arg1, arg2); diff --git a/src/sat/smt/polysat/op_constraint.cpp b/src/sat/smt/polysat/op_constraint.cpp index 666d950a9..713b9d6ef 100644 --- a/src/sat/smt/polysat/op_constraint.cpp +++ b/src/sat/smt/polysat/op_constraint.cpp @@ -319,15 +319,12 @@ namespace polysat { // auto& m = p.manager(); - auto const pv = c.subst(p); - auto const qv = c.subst(q); - auto const rv = c.subst(r); unsigned const N = m.power_of_2(); auto& C = c.cs(); - c.add_clause("q >= N & p < 0 -> p << q = -1", {~C.uge(q, N), ~C.slt(p, 0), C.eq(r, m.max_value())}, true); - c.add_clause("q >= N & p >= 0 -> p << q = 0", {~C.uge(q, N), ~C.sge(p, 0), C.eq(r)}, true); - c.add_clause("q = 0 -> p << q = p", { ~C.eq(q), C.eq(r, p) }, true); + c.add_clause("q >= N & p < 0 -> p << q = -1", {~C.uge(q, N), ~C.slt(p, 0), C.eq(r, m.max_value())}, false); + c.add_clause("q >= N & p >= 0 -> p << q = 0", {~C.uge(q, N), ~C.sge(p, 0), C.eq(r)}, false); + c.add_clause("q = 0 -> p << q = p", { ~C.eq(q), C.eq(r, p) }, false); for (unsigned k = 0; k < N; ++k) { // c.add_clause("q = k & p >= 0 -> p << q = p / 2^k", {~C.eq(q, k), ~C.sge(p, 0), ... }, true); // c.add_clause("q = k & p < 0 -> p << q = (p / 2^k) -1 + 2^{N-k}", {~C.eq(q, k), ~C.slt(p, 0), ... }, true); @@ -338,6 +335,10 @@ namespace polysat { void op_constraint::activate_and(core& c, dependency const& d) { auto x = p, y = q; auto& C = c.cs(); + + c.add_clause("band-mask p&q <= p", { C.ule(r, p) }, false); + c.add_clause("band-mask p&q <= q", { C.ule(r, q) }, false); + if (x.is_val()) std::swap(x, y); if (!y.is_val()) @@ -356,7 +357,6 @@ namespace polysat { SASSERT(k < N); rational exp = rational::power_of_two(N - k); c.add_clause("band-mask 1", { C.eq(x * exp, r * exp) }, false); - c.add_clause("band-mask 2", { C.ule(r, y) }, false); // maybe always activate these constraints regardless? } } From 657dcdeb6135dcc27c380416edf4c82055a51396 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 15 Dec 2023 16:02:13 -0800 Subject: [PATCH 87/89] ps Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/CMakeLists.txt | 1 + src/sat/smt/polysat/core.cpp | 8 ---- src/sat/smt/polysat/core.h | 2 - src/sat/smt/polysat/op_constraint.cpp | 48 ++++++++++++++----- ...saturation.cpp.disabled => saturation.cpp} | 3 ++ src/sat/smt/polysat/saturation.h | 2 +- 6 files changed, 42 insertions(+), 22 deletions(-) rename src/sat/smt/polysat/{saturation.cpp.disabled => saturation.cpp} (99%) diff --git a/src/sat/smt/polysat/CMakeLists.txt b/src/sat/smt/polysat/CMakeLists.txt index c7f5e49d5..72d919b94 100644 --- a/src/sat/smt/polysat/CMakeLists.txt +++ b/src/sat/smt/polysat/CMakeLists.txt @@ -6,6 +6,7 @@ z3_add_component(polysat fixed_bits.cpp forbidden_intervals.cpp op_constraint.cpp + saturation.cpp ule_constraint.cpp umul_ovfl_constraint.cpp viable.cpp diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index c9deb5726..7c6977299 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -204,13 +204,6 @@ namespace polysat { return true; } - signed_constraint core::get_constraint(unsigned idx, bool sign) { - auto sc = m_constraint_index[idx].sc; - if (sign) - sc = ~sc; - return sc; - } - void core::propagate_assignment(constraint_id idx) { auto [sc, dep, value] = m_constraint_index[idx.id]; SASSERT(value != l_undef); @@ -252,7 +245,6 @@ namespace polysat { bool swapped = false; for (unsigned i = vars.size(); i-- > 2; ) { if (!is_assigned(vars[i])) { - verbose_stream() << "watch instead " << vars[i] << " instead of " << vars[0] << " for " << idx << "\n"; add_watch(idx, vars[i]); std::swap(vars[i], vars[0]); swapped = true; diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index 46661dc84..cfc6c1d82 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -87,8 +87,6 @@ namespace polysat { void add_watch(unsigned idx, unsigned var); - signed_constraint get_constraint(unsigned idx, bool sign); - lbool eval(signed_constraint const& sc); dependency_vector explain_eval(signed_constraint const& sc); diff --git a/src/sat/smt/polysat/op_constraint.cpp b/src/sat/smt/polysat/op_constraint.cpp index 713b9d6ef..d36e069f6 100644 --- a/src/sat/smt/polysat/op_constraint.cpp +++ b/src/sat/smt/polysat/op_constraint.cpp @@ -313,22 +313,12 @@ namespace polysat { } void op_constraint::activate_ashr(core& c, dependency const& d) { - // - // if q = k & p >= 0 -> r*2^k + - // if q = k & p < 0 -> (p / 2^k) - 1 + 2^{N-k} - // - auto& m = p.manager(); unsigned const N = m.power_of_2(); - auto& C = c.cs(); c.add_clause("q >= N & p < 0 -> p << q = -1", {~C.uge(q, N), ~C.slt(p, 0), C.eq(r, m.max_value())}, false); c.add_clause("q >= N & p >= 0 -> p << q = 0", {~C.uge(q, N), ~C.sge(p, 0), C.eq(r)}, false); c.add_clause("q = 0 -> p << q = p", { ~C.eq(q), C.eq(r, p) }, false); - for (unsigned k = 0; k < N; ++k) { -// c.add_clause("q = k & p >= 0 -> p << q = p / 2^k", {~C.eq(q, k), ~C.sge(p, 0), ... }, true); -// c.add_clause("q = k & p < 0 -> p << q = (p / 2^k) -1 + 2^{N-k}", {~C.eq(q, k), ~C.slt(p, 0), ... }, true); - } } @@ -361,7 +351,43 @@ namespace polysat { } void op_constraint::propagate_ashr(core& c, dependency const& dep) { - + // + // Suppose q = k, p >= 0: + // p = ab, where b has length k, a has length N - k + // r = 0a, where 0 has length k, a has length N - k + // r*2^k = a0 + // ab - a0 = 0b = p - r*2^k < 2^k + // r < 2^{N-k} + // + // Suppose q = k, p < 0 + // p = ab + // r = 111a where 111 has length k + // r*2^k = a0 + // ab - a0 = 0b = p - r*2^k < 2^k + // r >= 1110 + // example: + // 1100 = 12 = 16 - 4 = 2^4 - 2^2 = 2^N - 2^k + // + // Succinct: + // if q = k & p >= 0 -> r*2^k + p < 2^{N-k} && r < 2^k + // if q = k & p < 0 -> (p / 2^k) - 2^N + 2^{N-k} + // + auto& m = p.manager(); + auto N = m.power_of_2(); + auto qv = c.subst(q); + if (qv.is_val() && 1 <= qv.val() && qv.val() < N) { + auto pv = c.subst(p); + auto rv = c.subst(r); + auto& C = c.cs(); + unsigned k = qv.val().get_unsigned(); + rational twoN = rational::power_of_two(N); + rational twoK = rational::power_of_two(k); + rational twoNk = rational::power_of_two(N - k); + auto eqK = C.eq(q, k); + c.add_clause("q = k -> r*2^k + p < 2^k", { ~eqK, C.ult(p - r * twoK, twoK) }, true); + c.add_clause("q = k & p >= 0 -> r < 2^{N-k}", { ~eqK, ~C.ule(0, p), C.ult(r, twoNk) }, true); + c.add_clause("q = k & p < 0 -> r >= 2^N - 2^{N-k}", { ~eqK, ~C.slt(p, 0), C.uge(r, twoN - twoNk) }, true); + } } diff --git a/src/sat/smt/polysat/saturation.cpp.disabled b/src/sat/smt/polysat/saturation.cpp similarity index 99% rename from src/sat/smt/polysat/saturation.cpp.disabled rename to src/sat/smt/polysat/saturation.cpp index 81fd6f221..c2a961e7f 100644 --- a/src/sat/smt/polysat/saturation.cpp.disabled +++ b/src/sat/smt/polysat/saturation.cpp @@ -34,6 +34,7 @@ namespace polysat { saturation::saturation(core& c) : c(c), C(c.cs()) {} +#if 0 void saturation::perform(pvar v) { for (signed_constraint c : core) perform(v, sc, core); @@ -2187,4 +2188,6 @@ namespace polysat { } } +#endif + } diff --git a/src/sat/smt/polysat/saturation.h b/src/sat/smt/polysat/saturation.h index f0dcc56ce..2d81f52fa 100644 --- a/src/sat/smt/polysat/saturation.h +++ b/src/sat/smt/polysat/saturation.h @@ -13,7 +13,7 @@ Author: --*/ #pragma once -#include "math/polysat/constraints.h" +#include "sat/smt/polysat/constraints.h" namespace polysat { From 275e72a358168f0a7fed7d0533189f018ef9afb3 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 15 Dec 2023 16:28:59 -0800 Subject: [PATCH 88/89] refactor for handling cores Signed-off-by: Nikolaj Bjorner --- src/sat/smt/polysat/core.cpp | 44 +++++++++++++++++++++------------- src/sat/smt/polysat/core.h | 13 ++++++---- src/sat/smt/polysat/types.h | 7 ++++-- src/sat/smt/polysat/viable.cpp | 6 ++--- src/sat/smt/polysat/viable.h | 2 +- 5 files changed, 44 insertions(+), 28 deletions(-) diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index 7c6977299..6cf1db764 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -38,7 +38,7 @@ namespace polysat { public: mk_assign_var(pvar v, core& c) : m_var(v), c(c) {} void undo() { - c.m_justification[m_var] = null_dependency; + c.m_justification[m_var] = constraint_id::null(); c.m_assignment.pop(); } }; @@ -123,7 +123,7 @@ namespace polysat { unsigned v = m_vars.size(); m_vars.push_back(sz2pdd(sz).mk_var(v)); m_activity.push_back({ sz, 0 }); - m_justification.push_back(null_dependency); + m_justification.push_back(constraint_id::null()); m_watch.push_back({}); m_var_queue.mk_var_eh(v); m_viable.ensure_var(v); @@ -174,11 +174,11 @@ namespace polysat { s.trail().push(mk_dqueue_var(m_var, *this)); switch (m_viable.find_viable(m_var, m_value)) { case find_t::empty: - s.set_lemma(m_viable.get_core(), m_viable.explain()); + s.set_lemma(m_viable.get_core(), get_dependencies(m_viable.explain())); // propagate_unsat_core(); return sat::check_result::CR_CONTINUE; case find_t::singleton: - s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); + s.propagate(m_constraints.eq(var2pdd(m_var), m_value), get_dependencies(m_viable.explain())); return sat::check_result::CR_CONTINUE; case find_t::multiple: s.add_eq_literal(m_var, m_value); @@ -210,7 +210,7 @@ namespace polysat { if (value == l_false) sc = ~sc; if (sc.is_eq(m_var, m_value)) - propagate_assignment(m_var, m_value, dep); + propagate_assignment(m_var, m_value, idx); else sc.activate(*this, dep); } @@ -219,7 +219,7 @@ namespace polysat { m_watch[var].push_back(idx); } - void core::propagate_assignment(pvar v, rational const& value, dependency dep) { + void core::propagate_assignment(pvar v, rational const& value, constraint_id dep) { if (is_assigned(v)) return; if (m_var_queue.contains(v)) { @@ -255,7 +255,7 @@ namespace polysat { // this can create fresh literals and update m_watch, but // will not update m_watch[v] (other than copy constructor for m_watch) // because v has been assigned a value. - propagate(sc, value, dep); + propagate({ idx }, sc, value, dep); if (s.inconsistent()) return; @@ -280,7 +280,7 @@ namespace polysat { void core::propagate_value(constraint_id idx) { auto [sc, d, value] = m_constraint_index[idx.id]; // propagate current assignment for sc - propagate(sc, value, d); + propagate(idx, sc, value, d); if (s.inconsistent()) return; @@ -292,10 +292,10 @@ namespace polysat { auto [sc, d, value] = m_constraint_index[idx1]; switch (eval(sc)) { case l_false: - s.propagate(d, true, explain_eval(sc)); + s.propagate(d, true, get_dependencies(explain_eval(sc))); break; case l_true: - s.propagate(d, false, explain_eval(sc)); + s.propagate(d, false, get_dependencies(explain_eval(sc))); break; default: break; @@ -304,15 +304,25 @@ namespace polysat { } } - void core::propagate(signed_constraint& sc, lbool value, dependency const& d) { + dependency_vector core::get_dependencies(constraint_id_vector const& cc) { + dependency_vector result; + for (auto idx : cc) { + auto [sc, d, value] = m_constraint_index[idx.id]; + SASSERT(value != l_undef); + result.push_back(value == l_false ? ~d : d); + } + return result; + } + + void core::propagate(constraint_id id, signed_constraint& sc, lbool value, dependency const& d) { lbool eval_value = eval(sc); if (eval_value == l_undef) sc.propagate(*this, value, d); else if (value == l_undef) - s.propagate(d, eval_value != l_true, explain_eval(sc)); + s.propagate(d, eval_value != l_true, get_dependencies(explain_eval(sc))); else if (value != eval_value) { m_unsat_core = explain_eval(sc); - m_unsat_core.push_back(value == l_false ? ~d : d); + m_unsat_core.push_back(id); propagate_unsat_core(); } } @@ -333,7 +343,7 @@ namespace polysat { // default is to use unsat core: // if core is based on viable, use s.set_lemma(); - s.set_conflict(m_unsat_core); + s.set_conflict(get_dependencies(m_unsat_core)); } void core::assign_eh(constraint_id index, bool sign, unsigned level) { @@ -352,8 +362,8 @@ namespace polysat { s.trail().push(unassign(*this, index.id)); } - dependency_vector core::explain_eval(signed_constraint const& sc) { - dependency_vector deps; + constraint_id_vector core::explain_eval(signed_constraint const& sc) { + constraint_id_vector deps; for (auto v : sc.vars()) if (is_assigned(v)) deps.push_back(m_justification[v]); @@ -379,7 +389,7 @@ namespace polysat { for (auto const& [sc, d, value] : m_constraint_index) out << sc << " " << d << " := " << value << "\n"; for (unsigned i = 0; i < m_vars.size(); ++i) - out << m_vars[i] << " := " << m_values[i] << " " << m_justification[i] << "\n"; + out << m_vars[i] << " := " << m_values[i] << " " << m_constraint_index[m_justification[i].id].d << "\n"; m_var_queue.display(out << "vars ") << "\n"; return out; } diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index cfc6c1d82..109f0ac0e 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -31,6 +31,8 @@ namespace polysat { class core; class solver_interface; + + class core { class mk_add_var; class mk_dqueue_var; @@ -54,13 +56,13 @@ namespace polysat { unsigned m_qhead = 0, m_vqhead = 0; svector m_prop_queue; svector m_constraint_index; // index of constraints - dependency_vector m_unsat_core; + constraint_id_vector m_unsat_core; // attributes associated with variables vector m_vars; // for each variable a pdd vector m_values; // current value of assigned variable - svector m_justification; // justification for assignment + svector m_justification; // justification for assignment activity m_activity; // activity of variables var_queue m_var_queue; // priority queue of variables to assign vector m_watch; // watch lists for variables for constraints on m_prop_queue where they occur @@ -77,9 +79,9 @@ namespace polysat { bool is_assigned(pvar v) { return !m_justification[v].is_null(); } void propagate_value(constraint_id idx); void propagate_assignment(constraint_id idx); - void propagate_assignment(pvar v, rational const& value, dependency dep); + void propagate_assignment(pvar v, rational const& value, constraint_id dep); void propagate_unsat_core(); - void propagate(signed_constraint& sc, lbool value, dependency const& d); + void propagate(constraint_id id, signed_constraint& sc, lbool value, dependency const& d); void get_bitvector_prefixes(pvar v, pvar_vector& out); void get_fixed_bits(pvar v, svector& fixed_bits); @@ -88,7 +90,8 @@ namespace polysat { void add_watch(unsigned idx, unsigned var); lbool eval(signed_constraint const& sc); - dependency_vector explain_eval(signed_constraint const& sc); + constraint_id_vector explain_eval(signed_constraint const& sc); + dependency_vector get_dependencies(constraint_id_vector const& cc); void add_axiom(signed_constraint sc); diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index d9008392c..e0aefb6a9 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -22,7 +22,10 @@ namespace polysat { using pdd = dd::pdd; using pvar = unsigned; using theory_var = unsigned; - struct constraint_id { unsigned id; }; + struct constraint_id { + unsigned id; bool is_null() const { return id == UINT_MAX; } + static constraint_id null() { return constraint_id{ UINT_MAX }; } + }; using pvar_vector = unsigned_vector; inline const pvar null_var = UINT_MAX; @@ -80,7 +83,7 @@ namespace polysat { using dependency_vector = vector; using core_vector = std::initializer_list>; - + using constraint_id_vector = svector; // diff --git a/src/sat/smt/polysat/viable.cpp b/src/sat/smt/polysat/viable.cpp index 20a31b730..3f217d1ce 100644 --- a/src/sat/smt/polysat/viable.cpp +++ b/src/sat/smt/polysat/viable.cpp @@ -809,12 +809,12 @@ namespace polysat { /* * Explain why the current variable is not viable or signleton. */ - dependency_vector viable::explain() { - dependency_vector result; + constraint_id_vector viable::explain() { + constraint_id_vector result; for (auto e : m_explain) { auto index = e->constraint_index; auto const& [sc, d, value] = c.m_constraint_index[index]; - result.push_back(d); + result.push_back({ index }); result.append(c.explain_eval(sc)); } // TODO: explaination for fixed bits diff --git a/src/sat/smt/polysat/viable.h b/src/sat/smt/polysat/viable.h index 64a8d3194..03f94e698 100644 --- a/src/sat/smt/polysat/viable.h +++ b/src/sat/smt/polysat/viable.h @@ -253,7 +253,7 @@ namespace polysat { /* * Explain why the current variable is not viable or signleton. */ - dependency_vector explain(); + constraint_id_vector explain(); /* * flag whether there is a forbidden interval core From bdc40b1f5f83cca22dc1d6c5808e935a3b50176c Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 16 Dec 2023 16:10:06 -0800 Subject: [PATCH 89/89] na --- src/sat/smt/intblast_solver.cpp | 98 ++++----- src/sat/smt/polysat/CMakeLists.txt | 1 + src/sat/smt/polysat/constraints.cpp | 8 + src/sat/smt/polysat/constraints.h | 2 + src/sat/smt/polysat/core.cpp | 27 ++- src/sat/smt/polysat/core.h | 19 +- src/sat/smt/polysat/inequality.cpp | 131 ++++++++++++ src/sat/smt/polysat/inequality.h | 164 +++++++++++++++ src/sat/smt/polysat/saturation.cpp | 296 ++++++---------------------- src/sat/smt/polysat/saturation.h | 126 ++---------- src/sat/smt/polysat/types.h | 4 +- src/sat/smt/polysat/viable.cpp | 2 +- src/sat/smt/polysat_solver.cpp | 10 +- src/sat/smt/polysat_solver.h | 4 +- 14 files changed, 478 insertions(+), 414 deletions(-) create mode 100644 src/sat/smt/polysat/inequality.cpp create mode 100644 src/sat/smt/polysat/inequality.h diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 9d03d0ad0..27f85525d 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -104,10 +104,10 @@ namespace intblast { ctx.push(push_back_vector(m_preds)); } - void solver::set_translated(expr* e, expr* r) { + void solver::set_translated(expr* e, expr* r) { SASSERT(r); - SASSERT(!is_translated(e)); - m_translate.setx(e->get_id(), r); + SASSERT(!is_translated(e)); + m_translate.setx(e->get_id(), r); ctx.push(set_vector_idx_trail(m_translate, e->get_id())); } @@ -148,7 +148,7 @@ namespace intblast { auto a = expr2literal(e); auto b = mk_literal(r); ctx.mark_relevant(b); -// verbose_stream() << "add-predicate-axiom: " << mk_pp(e, m) << " == " << r << "\n"; + // verbose_stream() << "add-predicate-axiom: " << mk_pp(e, m) << " == " << r << "\n"; add_equiv(a, b); } return true; @@ -157,7 +157,7 @@ namespace intblast { bool solver::unit_propagate() { return add_bound_axioms() || add_predicate_axioms(); } - + void solver::ensure_translated(expr* e) { if (m_translate.get(e->get_id(), nullptr)) return; @@ -179,7 +179,7 @@ namespace intblast { } } std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); - for (expr* e : todo) + for (expr* e : todo) translate_expr(e); } @@ -305,28 +305,6 @@ namespace intblast { sorted.push_back(arg); } } - - // - // Add ground equalities to ensure the model is valid with respect to the current case splits. - // This may cause more conflicts than necessary. Instead could use intblast on the base level, but using literal - // assignment from complete level. - // E.g., force the solver to completely backtrack, check satisfiability using the assignment obtained under a complete assignment. - // If intblast is SAT, then force the model and literal assignment on the rest of the literals. - // - if (!is_ground(e)) - continue; - euf::enode* n = ctx.get_enode(e); - if (!n) - continue; - if (n == n->get_root()) - continue; - expr* r = n->get_root()->get_expr(); - es.push_back(m.mk_eq(e, r)); - r = es.back(); - if (!visited.is_marked(r) && !is_translated(r)) { - visited.mark(r); - sorted.push_back(r); - } } else if (is_quantifier(e)) { quantifier* q = to_quantifier(e); @@ -357,7 +335,7 @@ namespace intblast { es[i] = translated(es.get(i)); } - sat::check_result solver::check() { + sat::check_result solver::check() { // ensure that bv2int is injective for (auto e : m_bv2int) { euf::enode* n = expr2enode(e); @@ -369,10 +347,10 @@ namespace intblast { continue; if (sib->get_arg(0)->get_root() == r1) continue; - auto a = eq_internalize(n, sib); - auto b = eq_internalize(sib->get_arg(0), n->get_arg(0)); - ctx.mark_relevant(a); - ctx.mark_relevant(b); + auto a = eq_internalize(n, sib); + auto b = eq_internalize(sib->get_arg(0), n->get_arg(0)); + ctx.mark_relevant(a); + ctx.mark_relevant(b); add_clause(~a, b, nullptr); return sat::check_result::CR_CONTINUE; } @@ -390,13 +368,13 @@ namespace intblast { auto nBv2int = ctx.get_enode(bv2int); auto nxModN = ctx.get_enode(xModN); if (nBv2int->get_root() != nxModN->get_root()) { - auto a = eq_internalize(nBv2int, nxModN); - ctx.mark_relevant(a); + auto a = eq_internalize(nBv2int, nxModN); + ctx.mark_relevant(a); add_unit(a); return sat::check_result::CR_CONTINUE; } } - return sat::check_result::CR_DONE; + return sat::check_result::CR_DONE; } expr* solver::umod(expr* bv_expr, unsigned i) { @@ -504,8 +482,8 @@ namespace intblast { m_args[i] = bv.mk_int2bv(bv.get_bv_size(e->get_arg(i)), m_args.get(i)); if (has_bv_sort) - m_vars.push_back(e); - + m_vars.push_back(e); + if (m_is_plugin) { expr* r = m.mk_app(f, m_args); if (has_bv_sort) { @@ -526,7 +504,7 @@ namespace intblast { f = g; m_pinned.push_back(f); } - set_translated(e, m.mk_app(f, m_args)); + set_translated(e, m.mk_app(f, m_args)); } void solver::translate_bv(app* e) { @@ -558,7 +536,7 @@ namespace intblast { r = a.mk_add(hi, lo); } return r; - }; + }; expr* bv_expr = e; expr* r = nullptr; @@ -659,7 +637,7 @@ namespace intblast { expr* x = arg(0), * y = umod(e, 1); r = a.mk_int(0); for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); break; } case OP_BNOT: @@ -671,7 +649,7 @@ namespace intblast { for (unsigned i = 0; i < bv.get_bv_size(e); ++i) r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); break; - } + } case OP_BOR: { // p | q := (p + q) - band(p, q) r = arg(0); @@ -706,13 +684,13 @@ namespace intblast { // unsigned sz = bv.get_bv_size(e); rational N = bv_size(e); - expr* x = umod(e, 0), *y = umod(e, 1); + expr* x = umod(e, 0), * y = umod(e, 1); expr* signx = a.mk_ge(x, a.mk_int(N / 2)); - r = m.mk_ite(signx, a.mk_int(- 1), a.mk_int(0)); + r = m.mk_ite(signx, a.mk_int(-1), a.mk_int(0)); for (unsigned i = 0; i < sz; ++i) { - expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); + expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), - m.mk_ite(signx, a.mk_add(d, a.mk_int(- rational::power_of_two(sz-i))), d), + m.mk_ite(signx, a.mk_add(d, a.mk_int(-rational::power_of_two(sz - i))), d), r); } break; @@ -749,11 +727,11 @@ namespace intblast { r = m.mk_ite(m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(1), a.mk_int(0)); break; case OP_BSMOD_I: - case OP_BSMOD: { - expr* x = umod(e, 0), *y = umod(e, 1); - rational N = bv_size(e); - expr* signx = a.mk_ge(x, a.mk_int(N/2)); - expr* signy = a.mk_ge(y, a.mk_int(N/2)); + case OP_BSMOD: { + expr* x = umod(e, 0), * y = umod(e, 1); + rational N = bv_size(e); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); expr* u = a.mk_mod(x, y); // u = 0 -> 0 // y = 0 -> x @@ -761,14 +739,14 @@ namespace intblast { // x < 0, y >= 0 -> y - u // x >= 0, y < 0 -> y + u // x >= 0, y >= 0 -> u - r = a.mk_uminus(u); + r = a.mk_uminus(u); r = m.mk_ite(m.mk_and(m.mk_not(signx), signy), a.mk_add(u, y), r); r = m.mk_ite(m.mk_and(signx, m.mk_not(signy)), a.mk_sub(y, u), r); r = m.mk_ite(m.mk_and(m.mk_not(signx), m.mk_not(signy)), u, r); r = m.mk_ite(m.mk_eq(u, a.mk_int(0)), a.mk_int(0), r); r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); break; - } + } case OP_BSDIV_I: case OP_BSDIV: { // d = udiv(abs(x), abs(y)) @@ -804,7 +782,7 @@ namespace intblast { d = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); r = a.mk_sub(x, a.mk_mul(d, y)); r = m.mk_ite(m.mk_eq(y, a.mk_int(0)), x, r); - break; + break; } case OP_ROTATE_LEFT: { auto n = e->get_parameter(0).get_int(); @@ -817,11 +795,11 @@ namespace intblast { r = rotate_left(sz - n); break; } - case OP_EXT_ROTATE_LEFT: { + case OP_EXT_ROTATE_LEFT: { unsigned sz = bv.get_bv_size(e); expr* y = umod(e, 1); r = a.mk_int(0); - for (unsigned i = 0; i < sz; ++i) + for (unsigned i = 0; i < sz; ++i) r = m.mk_ite(m.mk_eq(a.mk_int(i), y), rotate_left(i), r); break; } @@ -829,7 +807,7 @@ namespace intblast { unsigned sz = bv.get_bv_size(e); expr* y = umod(e, 1); r = a.mk_int(0); - for (unsigned i = 0; i < sz; ++i) + for (unsigned i = 0; i < sz; ++i) r = m.mk_ite(m.mk_eq(a.mk_int(i), y), rotate_left(sz - i), r); break; } @@ -842,7 +820,7 @@ namespace intblast { for (unsigned i = 1; i < n; ++i) r = a.mk_add(a.mk_mul(a.mk_int(N), x), r), N *= N0; break; - } + } case OP_BREDOR: { r = umod(e->get_arg(0), 0); r = m.mk_not(m.mk_eq(r, a.mk_int(0))); @@ -902,7 +880,7 @@ namespace intblast { } bool solver::add_dep(euf::enode* n, top_sort& dep) { - if (!is_app(n->get_expr())) + if (!is_app(n->get_expr())) return false; app* e = to_app(n->get_expr()); if (n->num_args() == 0) { @@ -921,7 +899,7 @@ namespace intblast { void solver::add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values) { expr* e = n->get_expr(); SASSERT(bv.is_bv(e)); - + if (bv.is_numeral(e)) { values.setx(n->get_root_id(), e); return; diff --git a/src/sat/smt/polysat/CMakeLists.txt b/src/sat/smt/polysat/CMakeLists.txt index 72d919b94..70e0f9592 100644 --- a/src/sat/smt/polysat/CMakeLists.txt +++ b/src/sat/smt/polysat/CMakeLists.txt @@ -5,6 +5,7 @@ z3_add_component(polysat core.cpp fixed_bits.cpp forbidden_intervals.cpp + inequality.cpp op_constraint.cpp saturation.cpp ule_constraint.cpp diff --git a/src/sat/smt/polysat/constraints.cpp b/src/sat/smt/polysat/constraints.cpp index 83476160a..c96b43d6a 100644 --- a/src/sat/smt/polysat/constraints.cpp +++ b/src/sat/smt/polysat/constraints.cpp @@ -89,4 +89,12 @@ namespace polysat { return out << *m_constraint; } + bool signed_constraint::is_currently_true(core& c) const { + return eval(c.get_assignment()) == l_true; + } + + bool signed_constraint::is_currently_false(core& c) const { + return eval(c.get_assignment()) == l_false; + } + } diff --git a/src/sat/smt/polysat/constraints.h b/src/sat/smt/polysat/constraints.h index fa2b62c11..28d4b0529 100644 --- a/src/sat/smt/polysat/constraints.h +++ b/src/sat/smt/polysat/constraints.h @@ -72,6 +72,8 @@ namespace polysat { void propagate(core& c, lbool value, dependency const& d) { m_constraint->propagate(c, value, d); } bool is_always_true() const { return eval() == l_true; } bool is_always_false() const { return eval() == l_false; } + bool is_currently_true(core& c) const; + bool is_currently_false(core& c) const; lbool eval(assignment& a) const; lbool eval() const { return m_sign ? ~m_constraint->eval() : m_constraint->eval();} ckind_t op() const { return m_op; } diff --git a/src/sat/smt/polysat/core.cpp b/src/sat/smt/polysat/core.cpp index 6cf1db764..79f3f6bcb 100644 --- a/src/sat/smt/polysat/core.cpp +++ b/src/sat/smt/polysat/core.cpp @@ -314,6 +314,16 @@ namespace polysat { return result; } + dependency_vector core::get_dependencies(std::initializer_list const& cc) { + dependency_vector result; + for (auto idx : cc) { + auto [sc, d, value] = m_constraint_index[idx.id]; + SASSERT(value != l_undef); + result.push_back(value == l_false ? ~d : d); + } + return result; + } + void core::propagate(constraint_id id, signed_constraint& sc, lbool value, dependency const& d) { lbool eval_value = eval(sc); if (eval_value == l_undef) @@ -327,8 +337,8 @@ namespace polysat { } } - void core::get_bitvector_prefixes(pvar v, pvar_vector& out) { - s.get_bitvector_prefixes(v, out); + void core::get_bitvector_suffixes(pvar v, pvar_vector& out) { + s.get_bitvector_suffixes(v, out); } void core::get_fixed_bits(pvar v, svector& fixed_bits) { @@ -415,4 +425,17 @@ namespace polysat { s.add_polysat_clause(name, cs, is_redundant); } + signed_constraint core::get_constraint(constraint_id idx) { + auto [sc, d, value] = m_constraint_index[idx.id]; + SASSERT(value != l_undef); + if (value == l_false) + sc = ~sc; + return sc; + } + + lbool core::eval(constraint_id id) { + auto sc = get_constraint(id); + return sc.eval(m_assignment); + } + } diff --git a/src/sat/smt/polysat/core.h b/src/sat/smt/polysat/core.h index 109f0ac0e..37cc348f7 100644 --- a/src/sat/smt/polysat/core.h +++ b/src/sat/smt/polysat/core.h @@ -83,7 +83,7 @@ namespace polysat { void propagate_unsat_core(); void propagate(constraint_id id, signed_constraint& sc, lbool value, dependency const& d); - void get_bitvector_prefixes(pvar v, pvar_vector& out); + void get_bitvector_suffixes(pvar v, pvar_vector& out); void get_fixed_bits(pvar v, svector& fixed_bits); bool inconsistent() const; @@ -92,6 +92,9 @@ namespace polysat { lbool eval(signed_constraint const& sc); constraint_id_vector explain_eval(signed_constraint const& sc); dependency_vector get_dependencies(constraint_id_vector const& cc); + dependency_vector get_dependencies(std::initializer_list const& cc); + + void add_axiom(signed_constraint sc); @@ -143,6 +146,20 @@ namespace polysat { trail_stack& trail(); std::ostream& display(std::ostream& out) const; + + /* + * Saturation + */ + signed_constraint get_constraint(constraint_id id); + constraint_id_vector const& unsat_core() const { return m_unsat_core; } + lbool eval(constraint_id id); + bool propagate(signed_constraint const& sc, constraint_id_vector const& ids) { return s.propagate(sc, get_dependencies(ids)); } + bool propagate(signed_constraint const& sc, std::initializer_list const& ids) { return s.propagate(sc, get_dependencies(ids)); } + + /* + * Constraints + */ + assignment& get_assignment() { return m_assignment; } }; } diff --git a/src/sat/smt/polysat/inequality.cpp b/src/sat/smt/polysat/inequality.cpp new file mode 100644 index 000000000..e2233e5e1 --- /dev/null +++ b/src/sat/smt/polysat/inequality.cpp @@ -0,0 +1,131 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Polysat inequalities + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-6 + + + +--*/ +#include "sat/smt/polysat/core.h" +#include "sat/smt/polysat/inequality.h" +#include "sat/smt/polysat/ule_constraint.h" + + +namespace polysat { + + + inequality inequality::from_ule(core& c, constraint_id id) { + auto src = c.get_constraint(id); + ule_constraint const& ule = src.to_ule(); + if (src.is_positive()) + return inequality(c, id, ule.lhs(), ule.rhs(), src); + else + return inequality(c, id, ule.rhs(), ule.lhs(), src); + } + + +#if 0 + + + bool saturation::verify_Y_l_AxB(pvar x, inequality const& i, pdd const& y, pdd const& a, pdd& b) { + return i.lhs() == y && i.rhs() == a * c.var(x) + b; + } + + + /** + * Match [x] a*x + b <= y, val(y) = 0 + */ + bool saturation::is_AxB_eq_0(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { + y.reset(i.rhs().manager()); + y = i.rhs(); + rational y_val; + if (!c.try_eval(y, y_val) || y_val != 0) + return false; + return i.lhs().degree(x) == 1 && (i.lhs().factor(x, 1, a, b), true); + } + + bool saturation::verify_AxB_eq_0(pvar x, inequality const& i, pdd const& a, pdd const& b, pdd const& y) { + return y.is_val() && y.val() == 0 && i.rhs() == y && i.lhs() == a * c.var(x) + b; + } + + bool saturation::is_AxB_diseq_0(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { + if (!i.is_strict()) + return false; + y.reset(i.lhs().manager()); + y = i.lhs(); + if (i.rhs().is_val() && i.rhs().val() == 1) + return false; + rational y_val; + if (!c.try_eval(y, y_val) || y_val != 0) + return false; + a.reset(i.rhs().manager()); + b.reset(i.rhs().manager()); + return i.rhs().degree(x) == 1 && (i.rhs().factor(x, 1, a, b), true); + } + + /** + * Match [coeff*x] coeff*x*Y where x is a variable + */ + bool saturation::is_coeffxY(pdd const& x, pdd const& p, pdd& y) { + pdd xy = x.manager().zero(); + return x.is_unary() && p.try_div(x.hi().val(), xy) && xy.factor(x.var(), 1, y); + } + + /** + * Match [v] v*x <= z*x with x a variable + */ + bool saturation::is_Xy_l_XZ(pvar v, inequality const& i, pdd& x, pdd& z) { + return is_xY(v, i.lhs(), x) && is_coeffxY(x, i.rhs(), z); + } + + bool saturation::verify_Xy_l_XZ(pvar v, inequality const& i, pdd const& x, pdd const& z) { + return i.lhs() == c.var(v) * x && i.rhs() == z * x; + } + + + /** + * Determine whether values of x * y is non-overflowing. + */ + bool saturation::is_non_overflow(pdd const& x, pdd const& y) { + rational x_val, y_val; + rational bound = x.manager().two_to_N(); + return c.try_eval(x, x_val) && c.try_eval(y, y_val) && x_val * y_val < bound; + } + + + /** + * Match [z] yx <= zx with x a variable + */ + bool saturation::is_YX_l_zX(pvar z, inequality const& c, pdd& x, pdd& y) { + return is_xY(z, c.rhs(), x) && is_coeffxY(x, c.lhs(), y); + } + + bool saturation::verify_YX_l_zX(pvar z, inequality const& i, pdd const& x, pdd const& y) { + return i.lhs() == y * x && i.rhs() == c.var(z) * x; + } + + /** + * Match [x] xY <= xZ + */ + bool saturation::is_xY_l_xZ(pvar x, inequality const& c, pdd& y, pdd& z) { + return is_xY(x, c.lhs(), y) && is_xY(x, c.rhs(), z); + } + + /** + * Match xy = x * Y + */ + bool saturation::is_xY(pvar x, pdd const& xy, pdd& y) { + return xy.degree(x) == 1 && xy.factor(x, 1, y); + } + +#endif + + +} diff --git a/src/sat/smt/polysat/inequality.h b/src/sat/smt/polysat/inequality.h new file mode 100644 index 000000000..03a45bfc1 --- /dev/null +++ b/src/sat/smt/polysat/inequality.h @@ -0,0 +1,164 @@ +/*++ +Copyright (c) 2021 Microsoft Corporation + +Module Name: + + Polysat core saturation + +Author: + + Nikolaj Bjorner (nbjorner) 2021-03-19 + Jakob Rath 2021-04-6 + +--*/ +#pragma once + +#include "sat/smt/polysat/constraints.h" + +namespace polysat { + + /// Normalized inequality: + /// lhs <= rhs, if !is_strict + /// lhs < rhs, otherwise + class inequality { + core& c; + constraint_id m_id; + pdd m_lhs; + pdd m_rhs; + signed_constraint m_src; + + inequality(core& c, constraint_id id, pdd lhs, pdd rhs, signed_constraint src) : + c(c), m_id(id), m_lhs(std::move(lhs)), m_rhs(std::move(rhs)), m_src(std::move(src)) {} + + void set(pdd& dst, pdd const& src) const { + dst.reset(src.manager()); + dst = src; + } + + public: + static inequality from_ule(core& c, constraint_id id); + pdd const& lhs() const { return m_lhs; } + pdd const& rhs() const { return m_rhs; } + bool is_strict() const { return m_src.is_negative(); } + constraint_id id() const { return m_id; } + signed_constraint as_signed_constraint() const { return m_src; } + operator signed_constraint() const { return m_src; } + + // c := lhs ~ v + // where ~ is < or <= + bool is_l_v(pvar v) const { return rhs() == c.var(v); } + + // c := v ~ rhs + bool is_g_v(pvar v) const { return lhs() == c.var(v); } + + // c := x ~ Y + bool is_x_l_Y(pvar x, pdd& y) const { return is_g_v(x) && (set(y, rhs()), true); } + + // c := Y ~ x + bool is_Y_l_x(pvar x, pdd& y) const { return is_l_v(x) && (set(y, lhs()), true); } + + // c := Y ~ Ax + bool is_Y_l_Ax(pvar x, pdd& a, pdd& y) const { return is_xY(x, rhs(), a) && (set(y, lhs()), true); } + bool verify_Y_l_Ax(pvar x, pdd const& a, pdd const& y) const { return lhs() == y && rhs() == a * c.var(x); } + + // c := X*y ~ X*Z + bool is_Xy_l_XZ(pvar y, pdd& x, pdd& z) const { return is_xY(y, lhs(), x) && (false); } + bool verify_Xy_l_XZ(pvar y, pdd const& x, pdd const& z) const { lhs() == c.var(y) * x && rhs() == z * x; } + + // c := Ax ~ Y + bool is_Ax_l_Y(pvar x, pdd& a, pdd& y) const; + bool verify_Ax_l_Y(pvar x, pdd const& a, pdd const& y) const; + + // c := Ax + B ~ Y + bool is_AxB_l_Y(pvar x, pdd& a, pdd& b, pdd& y) const { + return lhs().degree(x) == 1 && (set(y, rhs()), lhs().factor(x, 1, a, b), true); + } + bool verify_AxB_l_Y(pvar x, pdd const& a, pdd const& b, pdd const& y) const { return rhs() == y && lhs() == a * c.var(x) + b; } + + // c := Y ~ Ax + B + bool is_Y_l_AxB(pvar x, pdd& y, pdd& a, pdd& b) const { return rhs().degree(x) == 1 && (set(y, lhs()), rhs().factor(x, 1, a, b), true); } + bool verify_Y_l_AxB(pvar x, pdd const& y, pdd const& a, pdd& b) const; + + // c := Ax + B ~ Y, val(Y) = 0 + bool is_AxB_eq_0(pvar x, pdd& a, pdd& b, pdd& y) const; + bool verify_AxB_eq_0(pvar x, pdd const& a, pdd const& b, pdd const& y) const; + + // c := Ax + B != Y, val(Y) = 0 + bool is_AxB_diseq_0(pvar x, pdd& a, pdd& b, pdd& y) const; + + // c := Y*X ~ z*X + bool is_YX_l_zX(pvar z, pdd& x, pdd& y) const; + bool verify_YX_l_zX(pvar z, pdd const& x, pdd const& y) const; + + // c := xY <= xZ + bool is_xY_l_xZ(pvar x, pdd& y, pdd& z) const; + + /** + * Match xy = x * Y + */ + static bool is_xY(pvar x, pdd const& xy, pdd& y) { return xy.degree(x) == 1 && xy.factor(x, 1, y); } + + /** + * Rewrite to one of six equivalent forms: + * + * i=0 p <= q (unchanged) + * i=1 p <= p - q - 1 + * i=2 q - p <= q + * i=3 q - p <= -p - 1 + * i=4 -q - 1 <= -p - 1 + * i=5 -q - 1 <= p - q - 1 + */ + //inequality rewrite_equiv(int i) const; + + //std::ostream& display(std::ostream& out) const; + }; + + + struct bilinear { + rational a, b, c, d; + + rational eval(rational const& x, rational const& y) const { + return a*x*y + b*x + c*y + d; + } + + bilinear operator-() const { + bilinear r(*this); + r.a = -r.a; + r.b = -r.b; + r.c = -r.c; + r.d = -r.d; + return r; + } + + bilinear operator-(bilinear const& other) const { + bilinear r(*this); + r.a -= other.a; + r.b -= other.b; + r.c -= other.c; + r.d -= other.d; + return r; + } + + bilinear operator+(rational const& d) const { + bilinear r(*this); + r.d += d; + return r; + } + + bilinear operator-(rational const& d) const { + bilinear r(*this); + r.d -= d; + return r; + } + + bilinear operator-(int d) const { + bilinear r(*this); + r.d -= d; + return r; + } +}; + + inline std::ostream& operator<<(std::ostream& out, bilinear const& b) { + return out << b.a << "*x*y + " << b.b << "*x + " << b.c << "*y + " << b.d; + } +} diff --git a/src/sat/smt/polysat/saturation.cpp b/src/sat/smt/polysat/saturation.cpp index c2a961e7f..204a44de3 100644 --- a/src/sat/smt/polysat/saturation.cpp +++ b/src/sat/smt/polysat/saturation.cpp @@ -34,29 +34,70 @@ namespace polysat { saturation::saturation(core& c) : c(c), C(c.cs()) {} -#if 0 - void saturation::perform(pvar v) { - for (signed_constraint c : core) - perform(v, sc, core); + void saturation::propagate(pvar v) { + for (auto id : c.unsat_core()) + propagate(v, id); } - bool saturation::perform(pvar v, signed_constraint sc) { - if (sc.is_currently_true(c)) + bool saturation::propagate(pvar v, constraint_id id) { + if (c.eval(id) == l_true) return false; + auto sc = c.get_constraint(id); + m_propagated = false; + if (sc.is_ule()) + propagate(v, inequality::from_ule(c, id)); + else + ; - if (sc.is_ule()) { - auto i = inequality::from_ule(sc); - return try_inequality(v, i); - } - if (sc.is_umul_ovfl()) - return try_umul_ovfl(v, sc); - - if (sc.is_op()) - return try_op(v, sc); - - return false; + return m_propagated; } + void saturation::propagate(pvar v, inequality const& i) { + if (c.size(v) != i.lhs().power_of_2()) + return; + propagate_infer_equality(v, i); + return; + + } + + void saturation::propagate(signed_constraint const& sc, std::initializer_list const& premises) { + if (c.propagate(sc, premises)) + m_propagated = true; + } + + /** + * p <= q, q <= p => p - q = 0 + */ + void saturation::propagate_infer_equality(pvar x, inequality const& a_l_b) { + set_rule("[x] p <= q, q <= p => p - q = 0"); + if (a_l_b.is_strict()) + return; + if (a_l_b.lhs().degree(x) == 0 && a_l_b.rhs().degree(x) == 0) + return; + for (auto id : c.unsat_core()) { + auto sc = c.get_constraint(id); + if (!sc.is_ule()) + continue; + auto i = inequality::from_ule(c, id); + if (i.lhs() == a_l_b.rhs() && i.rhs() == a_l_b.lhs() && !i.is_strict()) { + c.propagate(c.eq(i.lhs(), i.rhs()), { id, a_l_b.id() }); + return; + } + } + } + + /** + * Determine whether values of x * y is non-overflowing. + */ + bool saturation::is_non_overflow(pdd const& x, pdd const& y) { + rational x_val, y_val; + rational bound = x.manager().two_to_N(); + return c.try_eval(x, x_val) && c.try_eval(y, y_val) && x_val * y_val < bound; + } + +#if 0 + + bool saturation::try_inequality(pvar v, inequality const& i, conflict& core) { bool prop = false; if (s.size(v) != i.lhs().power_of_2()) @@ -98,35 +139,6 @@ namespace polysat { return prop; } - bool saturation::try_congruence(pvar x, conflict& core, inequality const& i) { - set_rule("egraph(x == y) & C(x,y) ==> C(y,y)"); - // TODO: generalize to other constraint types? - // if (!i.is_strict()) - // return false; - // if (!core.vars().contains(x)) - // return false; - if (!i.as_signed_constraint().contains_var(x)) - return false; - for (pvar y : s.m_slicing.equivalent_vars(x)) { - if (x == y) - continue; - if (!s.is_assigned(y)) - continue; - if (!core.vars().contains(y)) - continue; - if (!i.as_signed_constraint().contains_var(y)) - continue; - SASSERT(s.m_search.get_pvar_index(y) < s.m_search.get_pvar_index(x)); // y was the earlier one since we are currently resolving x - pdd const lhs = i.lhs().subst_pdd(x, s.var(y)); - pdd const rhs = i.rhs().subst_pdd(x, s.var(y)); - signed_constraint c = ineq(true, lhs, rhs); - m_lemma.reset(); - s.m_slicing.explain_equal(x, y, [&](sat::literal lit) { m_lemma.insert(~lit); }); - if (propagate(x, core, i, c)) - return true; - } - return false; - } bool saturation::try_nonzero_upper_extract(pvar y, conflict& core, inequality const& i) { set_rule("y = x[h:l] & y != 0 ==> x >= 2^l"); @@ -402,172 +414,7 @@ namespace polysat { return false; } - /* - * Match [v] .. <= v - */ - bool saturation::is_l_v(pvar v, inequality const& i) { - return i.rhs() == s.var(v); - } - - /* - * Match [v] v <= ... - */ - bool saturation::is_g_v(pvar v, inequality const& i) { - return i.lhs() == s.var(v); - } - - /* - * Match [x] x <= y - */ - bool saturation::is_x_l_Y(pvar x, inequality const& i, pdd& y) { - y.reset(i.rhs().manager()); - y = i.rhs(); - return is_g_v(x, i); - } - - /* - * Match [x] y <= x - */ - bool saturation::is_Y_l_x(pvar x, inequality const& i, pdd& y) { - y.reset(i.lhs().manager()); - y = i.lhs(); - return is_l_v(x, i); - } - - /* - * Match [x] y <= a*x - */ - bool saturation::is_Y_l_Ax(pvar x, inequality const& i, pdd& a, pdd& y) { - y.reset(i.lhs().manager()); - y = i.lhs(); - return is_xY(x, i.rhs(), a); - } - - bool saturation::verify_Y_l_Ax(pvar x, inequality const& i, pdd const& a, pdd const& y) { - return i.lhs() == y && i.rhs() == a * s.var(x); - } - - /** - * Match [x] a*x <= y - */ - bool saturation::is_Ax_l_Y(pvar x, inequality const& i, pdd& a, pdd& y) { - y.reset(i.rhs().manager()); - y = i.rhs(); - return is_xY(x, i.lhs(), a); - } - - bool saturation::verify_Ax_l_Y(pvar x, inequality const& i, pdd const& a, pdd const& y) { - return i.rhs() == y && i.lhs() == a * s.var(x); - } - - /** - * Match [x] a*x + b <= y - */ - bool saturation::is_AxB_l_Y(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { - y.reset(i.rhs().manager()); - y = i.rhs(); - return i.lhs().degree(x) == 1 && (i.lhs().factor(x, 1, a, b), true); - } - - bool saturation::verify_AxB_l_Y(pvar x, inequality const& i, pdd const& a, pdd const& b, pdd const& y) { - return i.rhs() == y && i.lhs() == a * s.var(x) + b; - } - - - bool saturation::is_Y_l_AxB(pvar x, inequality const& i, pdd& y, pdd& a, pdd& b) { - y.reset(i.lhs().manager()); - y = i.lhs(); - return i.rhs().degree(x) == 1 && (i.rhs().factor(x, 1, a, b), true); - } - - bool saturation::verify_Y_l_AxB(pvar x, inequality const& i, pdd const& y, pdd const& a, pdd& b) { - return i.lhs() == y && i.rhs() == a * s.var(x) + b; - } - - - /** - * Match [x] a*x + b <= y, val(y) = 0 - */ - bool saturation::is_AxB_eq_0(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { - y.reset(i.rhs().manager()); - y = i.rhs(); - rational y_val; - if (!s.try_eval(y, y_val) || y_val != 0) - return false; - return i.lhs().degree(x) == 1 && (i.lhs().factor(x, 1, a, b), true); - } - - bool saturation::verify_AxB_eq_0(pvar x, inequality const& i, pdd const& a, pdd const& b, pdd const& y) { - return y.is_val() && y.val() == 0 && i.rhs() == y && i.lhs() == a * s.var(x) + b; - } - - bool saturation::is_AxB_diseq_0(pvar x, inequality const& i, pdd& a, pdd& b, pdd& y) { - if (!i.is_strict()) - return false; - y.reset(i.lhs().manager()); - y = i.lhs(); - if (i.rhs().is_val() && i.rhs().val() == 1) - return false; - rational y_val; - if (!s.try_eval(y, y_val) || y_val != 0) - return false; - a.reset(i.rhs().manager()); - b.reset(i.rhs().manager()); - return i.rhs().degree(x) == 1 && (i.rhs().factor(x, 1, a, b), true); - } - - /** - * Match [coeff*x] coeff*x*Y where x is a variable - */ - bool saturation::is_coeffxY(pdd const& x, pdd const& p, pdd& y) { - pdd xy = x.manager().zero(); - return x.is_unary() && p.try_div(x.hi().val(), xy) && xy.factor(x.var(), 1, y); - } - - /** - * Determine whether values of x * y is non-overflowing. - */ - bool saturation::is_non_overflow(pdd const& x, pdd const& y) { - rational x_val, y_val; - rational bound = x.manager().two_to_N(); - return s.try_eval(x, x_val) && s.try_eval(y, y_val) && x_val * y_val < bound; - } - - /** - * Match [v] v*x <= z*x with x a variable - */ - bool saturation::is_Xy_l_XZ(pvar v, inequality const& i, pdd& x, pdd& z) { - return is_xY(v, i.lhs(), x) && is_coeffxY(x, i.rhs(), z); - } - - bool saturation::verify_Xy_l_XZ(pvar v, inequality const& i, pdd const& x, pdd const& z) { - return i.lhs() == s.var(v) * x && i.rhs() == z * x; - } - - /** - * Match [z] yx <= zx with x a variable - */ - bool saturation::is_YX_l_zX(pvar z, inequality const& c, pdd& x, pdd& y) { - return is_xY(z, c.rhs(), x) && is_coeffxY(x, c.lhs(), y); - } - - bool saturation::verify_YX_l_zX(pvar z, inequality const& c, pdd const& x, pdd const& y) { - return c.lhs() == y * x && c.rhs() == s.var(z) * x; - } - - /** - * Match [x] xY <= xZ - */ - bool saturation::is_xY_l_xZ(pvar x, inequality const& c, pdd& y, pdd& z) { - return is_xY(x, c.lhs(), y) && is_xY(x, c.rhs(), z); - } - - /** - * Match xy = x * Y - */ - bool saturation::is_xY(pvar x, pdd const& xy, pdd& y) { - return xy.degree(x) == 1 && xy.factor(x, 1, y); - } + // // overall comment: we use value propagation to check if p is val @@ -1336,30 +1183,7 @@ namespace polysat { return false; } - /** - * p <= q, q <= p => p - q = 0 - */ - bool saturation::try_infer_equality(pvar x, conflict& core, inequality const& a_l_b) { - set_rule("[x] p <= q, q <= p => p - q = 0"); - if (a_l_b.is_strict()) - return false; - if (a_l_b.lhs().degree(x) == 0 && a_l_b.rhs().degree(x) == 0) - return false; - for (auto c : core) { - if (!c->is_ule()) - continue; - auto i = inequality::from_ule(c); - if (i.lhs() == a_l_b.rhs() && i.rhs() == a_l_b.lhs() && !i.is_strict()) { - m_lemma.reset(); - m_lemma.insert(~c); - if (propagate(x, core, a_l_b, s.eq(i.lhs() - i.rhs()))) { - IF_VERBOSE(1, verbose_stream() << "infer equality " << s.eq(i.lhs() - i.rhs()) << "\n"); - return true; - } - } - } - return false; - } + lbool saturation::get_multiple(const pdd& p1, const pdd& p2, pdd& out) { LOG("Check if " << p2 << " can be multiplied with something to get " << p1); diff --git a/src/sat/smt/polysat/saturation.h b/src/sat/smt/polysat/saturation.h index 2d81f52fa..58c3aca62 100644 --- a/src/sat/smt/polysat/saturation.h +++ b/src/sat/smt/polysat/saturation.h @@ -14,57 +14,10 @@ Author: #pragma once #include "sat/smt/polysat/constraints.h" +#include "sat/smt/polysat/inequality.h" namespace polysat { - struct bilinear { - rational a, b, c, d; - - rational eval(rational const& x, rational const& y) const { - return a*x*y + b*x + c*y + d; - } - - bilinear operator-() const { - bilinear r(*this); - r.a = -r.a; - r.b = -r.b; - r.c = -r.c; - r.d = -r.d; - return r; - } - - bilinear operator-(bilinear const& other) const { - bilinear r(*this); - r.a -= other.a; - r.b -= other.b; - r.c -= other.c; - r.d -= other.d; - return r; - } - - bilinear operator+(rational const& d) const { - bilinear r(*this); - r.d += d; - return r; - } - - bilinear operator-(rational const& d) const { - bilinear r(*this); - r.d -= d; - return r; - } - - bilinear operator-(int d) const { - bilinear r(*this); - r.d -= d; - return r; - } -}; - - inline std::ostream& operator<<(std::ostream& out, bilinear const& b) { - return out << b.a << "*x*y + " << b.b << "*x + " << b.c << "*y + " << b.d; - } - /** * Introduce lemmas that derive new (simpler) constraints from the current conflict and partial model. */ @@ -73,13 +26,26 @@ namespace polysat { core& c; constraints& C; char const* m_rule = nullptr; + bool m_propagated = false; + void set_rule(char const* r) { m_rule = r; } + + void propagate(signed_constraint const& sc, std::initializer_list const& premises); + + + // a * b does not overflow + bool is_non_overflow(pdd const& a, pdd const& b); + + // p := coeff*x*y where coeff_x = coeff*x, x a variable + bool is_coeffxY(pdd const& coeff_x, pdd const& p, pdd& y); + + void propagate_infer_equality(pvar x, inequality const& a_l_b); #if 0 parity_tracker m_parity_tracker; unsigned_vector m_occ; unsigned_vector m_occ_cnt; - void set_rule(char const* r) { m_rule = r; } + bool is_non_overflow(pdd const& x, pdd const& y, signed_constraint& c); signed_constraint ineq(bool strict, pdd const& lhs, pdd const& rhs); @@ -137,61 +103,7 @@ namespace polysat { void fix_values(pvar x, pvar y, pdd const& p); void fix_values(pvar y, pdd const& p); - // c := lhs ~ v - // where ~ is < or <= - bool is_l_v(pvar v, inequality const& c); - // c := v ~ rhs - bool is_g_v(pvar v, inequality const& c); - - // c := x ~ Y - bool is_x_l_Y(pvar x, inequality const& i, pdd& y); - - // c := Y ~ x - bool is_Y_l_x(pvar x, inequality const& i, pdd& y); - - // c := X*y ~ X*Z - bool is_Xy_l_XZ(pvar y, inequality const& c, pdd& x, pdd& z); - bool verify_Xy_l_XZ(pvar y, inequality const& c, pdd const& x, pdd const& z); - - // c := Y ~ Ax - bool is_Y_l_Ax(pvar x, inequality const& c, pdd& a, pdd& y); - bool verify_Y_l_Ax(pvar x, inequality const& c, pdd const& a, pdd const& y); - - // c := Ax ~ Y - bool is_Ax_l_Y(pvar x, inequality const& c, pdd& a, pdd& y); - bool verify_Ax_l_Y(pvar x, inequality const& c, pdd const& a, pdd const& y); - - // c := Ax + B ~ Y - bool is_AxB_l_Y(pvar x, inequality const& c, pdd& a, pdd& b, pdd& y); - bool verify_AxB_l_Y(pvar x, inequality const& c, pdd const& a, pdd const& b, pdd const& y); - - // c := Y ~ Ax + B - bool is_Y_l_AxB(pvar x, inequality const& c, pdd& y, pdd& a, pdd& b); - bool verify_Y_l_AxB(pvar x, inequality const& c, pdd const& y, pdd const& a, pdd& b); - - // c := Ax + B ~ Y, val(Y) = 0 - bool is_AxB_eq_0(pvar x, inequality const& c, pdd& a, pdd& b, pdd& y); - bool verify_AxB_eq_0(pvar x, inequality const& c, pdd const& a, pdd const& b, pdd const& y); - - // c := Ax + B != Y, val(Y) = 0 - bool is_AxB_diseq_0(pvar x, inequality const& c, pdd& a, pdd& b, pdd& y); - - // c := Y*X ~ z*X - bool is_YX_l_zX(pvar z, inequality const& c, pdd& x, pdd& y); - bool verify_YX_l_zX(pvar z, inequality const& c, pdd const& x, pdd const& y); - - // c := xY <= xZ - bool is_xY_l_xZ(pvar x, inequality const& c, pdd& y, pdd& z); - - // xy := x * Y - bool is_xY(pvar x, pdd const& xy, pdd& y); - - // a * b does not overflow - bool is_non_overflow(pdd const& a, pdd const& b); - - // p := coeff*x*y where coeff_x = coeff*x, x a variable - bool is_coeffxY(pdd const& coeff_x, pdd const& p, pdd& y); bool is_add_overflow(pvar x, inequality const& i, pdd& y, bool& is_minus); @@ -222,7 +134,7 @@ namespace polysat { bool is_forced_true(signed_constraint const& sc); - bool try_inequality(pvar v, inequality const& i); + bool try_umul_ovfl(pvar v, signed_constraint c); bool try_umul_ovfl_noovfl(pvar v, signed_constraint c); @@ -233,9 +145,11 @@ namespace polysat { bool try_op(pvar v, signed_constraint c); #endif + void propagate(pvar v); + bool propagate(pvar v, constraint_id cid); + void propagate(pvar v, inequality const& i); + public: saturation(core& c); - void perform(pvar v); - bool perform(pvar v, signed_constraint sc); }; } diff --git a/src/sat/smt/polysat/types.h b/src/sat/smt/polysat/types.h index e0aefb6a9..1c88620fe 100644 --- a/src/sat/smt/polysat/types.h +++ b/src/sat/smt/polysat/types.h @@ -97,11 +97,11 @@ namespace polysat { virtual void set_conflict(dependency_vector const& core) = 0; virtual void set_lemma(core_vector const& aux_core, dependency_vector const& core) = 0; virtual void add_polysat_clause(char const* name, core_vector cs, bool redundant) = 0; - virtual dependency propagate(signed_constraint sc, dependency_vector const& deps) = 0; + virtual bool propagate(signed_constraint sc, dependency_vector const& deps) = 0; virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; virtual trail_stack& trail() = 0; virtual bool inconsistent() const = 0; - virtual void get_bitvector_prefixes(pvar v, pvar_vector& out) = 0; + virtual void get_bitvector_suffixes(pvar v, pvar_vector& out) = 0; virtual void get_fixed_bits(pvar v, svector& fixed_bits) = 0; }; diff --git a/src/sat/smt/polysat/viable.cpp b/src/sat/smt/polysat/viable.cpp index 3f217d1ce..6442f6739 100644 --- a/src/sat/smt/polysat/viable.cpp +++ b/src/sat/smt/polysat/viable.cpp @@ -104,7 +104,7 @@ namespace polysat { #endif pvar_vector overlaps; - c.get_bitvector_prefixes(v, overlaps); + c.get_bitvector_suffixes(v, overlaps); std::sort(overlaps.begin(), overlaps.end(), [&](pvar x, pvar y) { return c.size(x) > c.size(y); }); uint_set widths_set; diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 219b9017a..28a040869 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -238,12 +238,14 @@ namespace polysat { // Core uses the propagate callback to add unit propagations to the trail. // The polysat::solver takes care of translating signed constraints into expressions, which translate into literals. // Everything goes over expressions/literals. polysat::core is not responsible for replaying expressions. - dependency solver::propagate(signed_constraint sc, dependency_vector const& deps) { + bool solver::propagate(signed_constraint sc, dependency_vector const& deps) { sat::literal lit = ctx.mk_literal(constraint2expr(sc)); + if (s().value(lit) == l_true) + return false; auto [core, eqs] = explain_deps(deps); auto ex = euf::th_explain::propagate(*this, core, eqs, lit, nullptr); ctx.propagate(lit, ex); - return dependency(lit, s().lvl(lit)); + return true; } void solver::propagate(dependency const& d, bool sign, dependency_vector const& deps) { @@ -338,7 +340,7 @@ namespace polysat { } // walk the egraph starting with pvar for overlaps. - void solver::get_bitvector_prefixes(pvar pv, pvar_vector& out) { + void solver::get_bitvector_suffixes(pvar pv, pvar_vector& out) { theory_var v = m_pddvar2var[pv]; euf::enode_vector todo; uint_set seen; @@ -355,7 +357,7 @@ namespace polysat { // identify prefixes if (bv.is_concat(sib->get_expr())) - todo.push_back(sib->get_arg(0)); + todo.push_back(sib->get_arg(sib->num_args() - 1)); if (w == euf::null_theory_var) continue; if (seen.contains(w)) diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 60535207b..74989c1ac 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -162,11 +162,11 @@ namespace polysat { void add_eq_literal(pvar v, rational const& val) override; void set_conflict(dependency_vector const& core) override; void set_lemma(core_vector const& aux_core, dependency_vector const& core) override; - dependency propagate(signed_constraint sc, dependency_vector const& deps) override; + bool propagate(signed_constraint sc, dependency_vector const& deps) override; void propagate(dependency const& d, bool sign, dependency_vector const& deps) override; trail_stack& trail() override; bool inconsistent() const override; - void get_bitvector_prefixes(pvar v, pvar_vector& out) override; + void get_bitvector_suffixes(pvar v, pvar_vector& out) override; void get_fixed_bits(pvar v, svector& fixed_bits) override; void add_polysat_clause(char const* name, core_vector cs, bool redundant) override;