diff --git a/src/sat/smt/polysat/polysat_core.cpp b/src/sat/smt/polysat/polysat_core.cpp index a6ac25eda..c41938a1a 100644 --- a/src/sat/smt/polysat/polysat_core.cpp +++ b/src/sat/smt/polysat/polysat_core.cpp @@ -278,6 +278,14 @@ namespace polysat { } } + void core::get_bitvector_prefixes(pvar v, pvar_vector& out) { + s.get_bitvector_prefixes(v, out); + } + + bool core::inconsistent() const { + return s.inconsistent(); + } + void core::propagate_unsat_core() { // default is to use unsat core: // if core is based on viable, use s.set_lemma(); diff --git a/src/sat/smt/polysat/polysat_core.h b/src/sat/smt/polysat/polysat_core.h index bb21ee641..766c3a9bc 100644 --- a/src/sat/smt/polysat/polysat_core.h +++ b/src/sat/smt/polysat/polysat_core.h @@ -79,6 +79,9 @@ namespace polysat { void propagate_assignment(pvar v, rational const& value, dependency dep); void propagate_unsat_core(); + void get_bitvector_prefixes(pvar v, pvar_vector& out); + bool inconsistent() const; + void add_watch(unsigned idx, unsigned var); signed_constraint get_constraint(unsigned idx, bool sign); @@ -89,8 +92,7 @@ namespace polysat { public: core(solver_interface& s); - sat::check_result check(); - + sat::check_result check(); unsigned register_constraint(signed_constraint& sc, dependency d); bool propagate(); void assign_eh(unsigned idx, bool sign, dependency const& d); diff --git a/src/sat/smt/polysat/polysat_types.h b/src/sat/smt/polysat/polysat_types.h index e77e755bf..5b63b5ee5 100644 --- a/src/sat/smt/polysat/polysat_types.h +++ b/src/sat/smt/polysat/polysat_types.h @@ -10,6 +10,9 @@ Author: #pragma once #include + + + #include "math/dd/dd_pdd.h" #include "util/trail.h" #include "util/sat_literal.h" @@ -62,6 +65,7 @@ namespace polysat { virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0; virtual trail_stack& trail() = 0; virtual bool inconsistent() const = 0; + virtual void get_bitvector_prefixes(pvar v, pvar_vector& out) = 0; }; } diff --git a/src/sat/smt/polysat/polysat_viable.cpp b/src/sat/smt/polysat/polysat_viable.cpp index 27b36c582..d69f40180 100644 --- a/src/sat/smt/polysat/polysat_viable.cpp +++ b/src/sat/smt/polysat/polysat_viable.cpp @@ -98,42 +98,31 @@ namespace polysat { #endif pvar_vector overlaps; -#if 0 - // TODO s.m_slicing.collect_simple_overlaps(v, overlaps); -#endif + c.get_bitvector_prefixes(v, overlaps); std::sort(overlaps.begin(), overlaps.end(), [&](pvar x, pvar y) { return c.size(x) > c.size(y); }); uint_set widths_set; // max size should always be present, regardless of whether we have intervals there (to make sure all fixed bits are considered) widths_set.insert(c.size(v)); -#if 0 - LOG("Overlaps with v" << v << ":"); - for (pvar x : overlaps) { - unsigned hi, lo; - if (s.m_slicing.is_extract(x, v, hi, lo)) - LOG(" v" << x << " = v" << v << "[" << hi << ":" << lo << "]"); - else - LOG(" v" << x << " not extracted from v" << v << "; size " << s.size(x)); - for (layer const& l : m_units[x].get_layers()) { - widths_set.insert(l.bit_width); - } - } -#endif + for (pvar v : overlaps) + ensure_var(v); + for (pvar v : overlaps) + for (layer const& l : m_units[v].get_layers()) + widths_set.insert(l.bit_width); + unsigned_vector widths; for (unsigned w : widths_set) - widths.push_back(w); + widths.push_back(w); + LOG("Overlaps with v" << v << ":" << overlaps); LOG("widths: " << widths); rational const& max_value = c.var2pdd(v).max_value(); -#if 0 lbool result_lo = find_on_layers(v, widths, overlaps, fbi, rational::zero(), max_value, lo); - if (result_lo == l_false) - return l_false; // conflict - if (result_lo == l_undef) - return find_viable_fallback(v, overlaps, lo, hi); + if (result_lo != l_true) + return result_lo; if (lo == max_value) { hi = lo; @@ -141,12 +130,294 @@ namespace polysat { } lbool result_hi = find_on_layers(v, widths, overlaps, fbi, lo + 1, max_value, hi); - if (result_hi == l_false) - hi = lo; // no other viable value - if (result_hi == l_undef) - return find_viable_fallback(v, overlaps, lo, hi); -#endif - return l_true; + + switch (result_hi) { + case l_false: + hi = lo; + return l_true; + case l_undef: + return l_undef; + default: + return l_true; + } + } + + // l_true ... found viable value + // l_false ... no viable value in [to_cover_lo;to_cover_hi[ + // l_undef ... out of resources + lbool viable::find_on_layers( + pvar const v, + unsigned_vector const& widths, + pvar_vector const& overlaps, + fixed_bits_info const& fbi, + rational const& to_cover_lo, + rational const& to_cover_hi, + rational& val + ) { + ptr_vector refine_todo; + ptr_vector relevant_entries; + + // max number of interval refinements before falling back to the univariate solver + unsigned const refinement_budget = 100; + unsigned refinements = refinement_budget; + + while (refinements--) { + relevant_entries.clear(); + lbool result = find_on_layer(v, widths.size() - 1, widths, overlaps, fbi, to_cover_lo, to_cover_hi, val, refine_todo, relevant_entries); + + // store bit-intervals we have used + for (entry* e : refine_todo) + intersect(v, e); + refine_todo.clear(); + + if (result != l_true) + return l_false; + + // overlaps are sorted by variable size in descending order + // start refinement on smallest variable + // however, we probably should rotate to avoid getting stuck in refinement loop on a 'bad' constraint + bool refined = false; + for (unsigned i = overlaps.size(); i-- > 0; ) { + pvar x = overlaps[i]; + rational const& mod_value = c.var2pdd(x).two_to_N(); + rational x_val = mod(val, mod_value); + if (!refine_viable(x, x_val)) { + refined = true; + break; + } + } + + if (!refined) + return l_true; + } + + LOG("Refinement budget exhausted! Fall back to univariate solver."); + return l_undef; + } + + // find viable values in half-open interval [to_cover_lo;to_cover_hi[ w.r.t. unit intervals on the given layer + // + // Returns: + // - l_true ... found value that is viable w.r.t. units and fixed bits + // - l_false ... found conflict + // - l_undef ... found no viable value in target interval [to_cover_lo;to_cover_hi[ + lbool viable::find_on_layer( + pvar const v, + unsigned const w_idx, + unsigned_vector const& widths, + pvar_vector const& overlaps, + fixed_bits_info const& fbi, + rational const& to_cover_lo, + rational const& to_cover_hi, + rational& val, + ptr_vector& refine_todo, + ptr_vector& relevant_entries + ) { + unsigned const w = widths[w_idx]; + rational const& mod_value = rational::power_of_two(w); + unsigned const first_relevant_for_conflict = relevant_entries.size(); + + LOG("layer " << w << " bits, to_cover: [" << to_cover_lo << "; " << to_cover_hi << "["); + SASSERT(0 <= to_cover_lo); + SASSERT(0 <= to_cover_hi); + SASSERT(to_cover_lo < mod_value); + SASSERT(to_cover_hi <= mod_value); // full interval if to_cover_hi == mod_value + SASSERT(to_cover_lo != to_cover_hi); // non-empty search domain (but it may wrap) + + // TODO: refinement of eq/diseq should happen only on the correct layer: where (one of) the coefficient(s) are odd + // for refinement, we have to choose an entry that is violated, but if there are multiple, we can choose the one on smallest domain (lowest bit-width). + // (by maintaining descending order by bit-width; and refine first that fails). + // but fixed-bit-refinement is cheap and could be done during the search. + + // when we arrive at a hole the possibilities are: + // 1) go to lower bitwidth + // 2) refinement of some eq/diseq constraint (if one is violated at that point) -- defer this until point is viable for all layers and fixed bits. + // 3) refinement by using bit constraints? -- TODO: do this during search (another interval check after/before the entry_cursors) + // 4) (point is actually feasible) + + // a complication is that we have to iterate over multiple lists of intervals. + // might be useful to merge them upfront to simplify the remainder of the algorithm? + // (non-trivial since prev/next pointers are embedded into entries and lists are updated by refinement) + struct entry_cursor { + entry* cur; + // entry* first; + // entry* last; + }; + + // find relevant interval lists + svector ecs; + for (pvar x : overlaps) { + if (c.size(x) < w) // note that overlaps are sorted by variable size descending + break; + if (entry* e = m_units[x].get_entries(w)) { + display_all(std::cerr << "units for width " << w << ":\n", 0, e, "\n"); + entry_cursor ec; + ec.cur = e; // TODO: e->prev() probably makes it faster when querying 0 (can often save going around the interval list once) + // ec.first = nullptr; + // ec.last = nullptr; + ecs.push_back(ec); + } + } + + rational const to_cover_len = r_interval::len(to_cover_lo, to_cover_hi, mod_value); + val = to_cover_lo; + + rational progress; // = 0 + SASSERT(progress.is_zero()); + while (true) { + while (true) { + entry* e = nullptr; + + // try to make progress using any of the relevant interval lists + for (entry_cursor& ec : ecs) { + // advance until current value 'val' + auto const [n, n_contains_val] = find_value(val, ec.cur); + // display_one(std::cerr << "found entry n: ", 0, n) << "\n"; + // LOG("n_contains_val: " << n_contains_val << " val = " << val); + ec.cur = n; + if (n_contains_val) { + e = n; + break; + } + } + + // when we cannot make progress by existing intervals any more, try interval from fixed bits + if (!e) { + e = refine_bits(v, val, w, fbi); + if (e) { + refine_todo.push_back(e); + display_one(std::cerr << "found entry by bits: ", 0, e) << "\n"; + } + } + + // no more progress on current layer + if (!e) + break; + + relevant_entries.push_back(e); + + if (e->interval.is_full()) { + relevant_entries.clear(); + relevant_entries.push_back(e); // full interval e -> all other intervals are subsumed/irrelevant + set_conflict_by_interval(v, w, relevant_entries, 0); + return l_false; + } + + SASSERT(e->interval.currently_contains(val)); + rational const& new_val = e->interval.hi_val(); + rational const dist = distance(val, new_val, mod_value); + SASSERT(dist > 0); + val = new_val; + progress += dist; + LOG("val: " << val << " progress: " << progress); + + if (progress >= mod_value) { + // covered the whole domain => conflict + set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict); + return l_false; + } + if (progress >= to_cover_len) { + // we covered the hole left at larger bit-width + // TODO: maybe we want to keep trying a bit longer to see if we can cover the whole domain. or maybe only if we enter this layer multiple times. + return l_undef; + } + + // (another way to compute 'progress') + SASSERT_EQ(progress, distance(to_cover_lo, val, mod_value)); + } + + // no more progress + // => 'val' is viable w.r.t. unit intervals until current layer + + if (!w_idx) { + // we are at the lowest layer + // => found viable value w.r.t. unit intervals and fixed bits + return l_true; + } + + // find next covered value + rational next_val = to_cover_hi; + for (entry_cursor& ec : ecs) { + // each ec.cur is now the next interval after 'lo' + rational const& n = ec.cur->interval.lo_val(); + SASSERT(r_interval::contains(ec.cur->prev()->interval.hi_val(), n, val)); + if (distance(val, n, mod_value) < distance(val, next_val, mod_value)) + next_val = n; + } + if (entry* e = refine_bits(v, next_val, w, fbi)) { + refine_todo.push_back(e); + rational const& n = e->interval.lo_val(); + SASSERT(distance(val, n, mod_value) < distance(val, next_val, mod_value)); + next_val = n; + } + SASSERT(!refine_bits(v, val, w, fbi)); + SASSERT(val != next_val); + + unsigned const lower_w = widths[w_idx - 1]; + rational const lower_mod_value = rational::power_of_two(lower_w); + + rational lower_cover_lo, lower_cover_hi; + if (distance(val, next_val, mod_value) >= lower_mod_value) { + // NOTE: in this case we do not get the first viable value, but the one with smallest value in the lower bits. + // this is because we start the search in the recursive case at 0. + // if this is a problem, adapt to lower_cover_lo = mod(val, lower_mod_value), lower_cover_hi = ... + lower_cover_lo = 0; + lower_cover_hi = lower_mod_value; + rational a; + lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries); + VERIFY(result != l_undef); + if (result == l_false) { + SASSERT(c.inconsistent()); + return l_false; // conflict + } + SASSERT(result == l_true); + // replace lower bits of 'val' by 'a' + rational const val_lower = mod(val, lower_mod_value); + val = val - val_lower + a; + if (a < val_lower) + a += lower_mod_value; + LOG("distance(val, cover_hi) = " << distance(val, to_cover_hi, mod_value)); + LOG("distance(next_val, cover_hi) = " << distance(next_val, to_cover_hi, mod_value)); + SASSERT(distance(val, to_cover_hi, mod_value) >= distance(next_val, to_cover_hi, mod_value)); + return l_true; + } + + lower_cover_lo = mod(val, lower_mod_value); + lower_cover_hi = mod(next_val, lower_mod_value); + + rational a; + lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries); + if (result == l_false) { + SASSERT(c.inconsistent()); + return l_false; // conflict + } + + // replace lower bits of 'val' by 'a' + rational const dist = distance(lower_cover_lo, a, lower_mod_value); + val += dist; + progress += dist; + LOG("distance(val, cover_hi) = " << distance(val, to_cover_hi, mod_value)); + LOG("distance(next_val, cover_hi) = " << distance(next_val, to_cover_hi, mod_value)); + SASSERT(distance(val, to_cover_hi, mod_value) >= distance(next_val, to_cover_hi, mod_value)); + + if (result == l_true) + return l_true; // done + + SASSERT(result == l_undef); + + if (progress >= mod_value) { + // covered the whole domain => conflict + set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict); + return l_false; + } + + if (progress >= to_cover_len) { + // we covered the hole left at larger bit-width + return l_undef; + } + } + UNREACHABLE(); + return l_undef; } diff --git a/src/sat/smt/polysat/polysat_viable.h b/src/sat/smt/polysat/polysat_viable.h index 7c9c019dc..37b1d7b0c 100644 --- a/src/sat/smt/polysat/polysat_viable.h +++ b/src/sat/smt/polysat/polysat_viable.h @@ -162,6 +162,63 @@ namespace polysat { lbool find_viable(pvar v, rational& lo, rational& hi); + lbool find_on_layers( + pvar v, + unsigned_vector const& widths, + pvar_vector const& overlaps, + fixed_bits_info const& fbi, + rational const& to_cover_lo, + rational const& to_cover_hi, + rational& out_val); + + lbool find_on_layer( + pvar v, + unsigned w_idx, + unsigned_vector const& widths, + pvar_vector const& overlaps, + fixed_bits_info const& fbi, + rational const& to_cover_lo, + rational const& to_cover_hi, + rational& out_val, + ptr_vector& refine_todo, + ptr_vector& relevant_entries); + + + template + bool refine_viable(pvar v, rational const& val, fixed_bits_info const& fbi) { + throw default_exception("nyi"); + } + + bool refine_viable(pvar v, rational const& val) { + throw default_exception("nyi"); + } + + template + bool refine_bits(pvar v, rational const& val, fixed_bits_info const& fbi) { + throw default_exception("nyi"); + } + + template + entry* refine_bits(pvar v, rational const& val, unsigned num_bits, fixed_bits_info const& fbi) { + throw default_exception("nyi"); + } + + bool refine_equal_lin(pvar v, rational const& val) { + throw default_exception("nyi"); + } + + bool refine_disequal_lin(pvar v, rational const& val) { + throw default_exception("nyi"); + } + + bool set_conflict_by_interval(pvar v, unsigned w, ptr_vector& intervals, unsigned first_interval) { + throw default_exception("nyi"); + } + + std::pair find_value(rational const& val, entry* entries) { + throw default_exception("nyi"); + } + public: viable(core& c); diff --git a/src/sat/smt/polysat/slicing.cpp b/src/sat/smt/polysat/slicing.cpp new file mode 100644 index 000000000..04fe8c4fc --- /dev/null +++ b/src/sat/smt/polysat/slicing.cpp @@ -0,0 +1,1727 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polysat slicing + +Author: + + Jakob Rath 2023-06-01 + +--*/ + + + + +/* + +Example: +(1) x = y +(2) z = y[3:0] +(3) explain(x[3:0] == z)? should be { (1), (2) } + + (1) + x ========================> y + / \ / \ (2) + x[7:4] x[3:0] y[7:4] y[3:0] ===========> z + + +TODO: +- About the sub-slice sharing among equivalent nodes: + - When extracting a variable y := x[h:l], we always need to create a new slice for y. + - Merge slices for x[h:l] with y; store as dependency 'x[h:l]' (rather than 'null_dep' as we do now). + - when merging, we must avoid that the new variable becomes the root of the equivalence class, + because when finding dependencies for 'y := x[h:l]', such extraction-dependencies would be false/unnecessary. + (alternatively, just ignore them. but we never *have* to have them as root, so just don't do it. but add assertions for 1. new var is not root, 2. no extraction-dependency when walking from 'x' to 'x[h:l]'.) + - When encountering this dependency, need to start at slice for 'x' and walk towards 'x[h:l]', + collecting dependencies whenever we move to a representative. +- when solver assigns value of a variable v, add equations with v substituted by its value? + - since we only track equations over variables/names, this might not lead to many further additions + - a question is how to track/handle the dependency on the assignment +- check Algorithm 2 of "Solving Bitvectors with MCSAT"; what is the difference to what we are doing now? +- track equalities such as x = -y ? +- on_merge could propagate values upwards: + if slice has value but parent has no value, then check if sub_other(parent(s)) [sibling(s)?] has a value. + if yes, can propagate value upwards. (add a congruence term to track deps properly?). + we have to check the whole equivalence class, because the parents may be in different classes. + it is enough to propagate values to variables. we could count (in the variable slice) the number of its base slices that are still unassigned. + +*/ + + +#include "ast/reg_decl_plugins.h" +#include "math/polysat/slicing.h" +#include "math/polysat/solver.h" +#include "math/polysat/log.h" +#include "util/tptr.h" + + +namespace { + + template + [[maybe_unused]] + inline constexpr bool always_false_v = false; + +} + +namespace polysat { + + void* slicing::dep_t::encode() const { + void* p = std::visit([](auto arg) -> void* { + using T = std::decay_t; + if constexpr (std::is_same_v) + return nullptr; + else if constexpr (std::is_same_v) + return box(arg.to_uint(), 1); + else if constexpr (std::is_same_v) + return box(arg, 2); + else + static_assert(always_false_v, "non-exhaustive visitor!"); + }, m_data); + SASSERT(*this == decode(p)); + return p; + } + + slicing::dep_t slicing::dep_t::decode(void* p) { + if (!p) + return {}; + unsigned tag = get_tag(p); + SASSERT(tag == 1 || tag == 2); + if (tag == 1) + return dep_t(sat::to_literal(unbox(p))); + else + return dep_t(unbox(p)); + } + + std::ostream& slicing::display(std::ostream& out, dep_t d) const { + if (d.is_null()) + out << "null"; + else if (d.is_value()) { + pvar x = get_dep_var(d); + enode* n = get_dep_slice(d); + sat::literal lit = get_dep_lit(d); + out << "value(v" << x << " on slice "; + if (n) + out << n->get_id(); + else + out << ""; + if (lit != sat::null_literal) + out << " by literal " << lit; + out << ")"; + } + else if (d.is_lit()) + out << "lit(" << d.lit() << ")"; + return out; + } + + slicing::dep_t slicing::mk_var_dep(pvar v, enode* s, sat::literal lit) { + SASSERT_EQ(m_dep_var.size(), m_dep_slice.size()); + SASSERT_EQ(m_dep_var.size(), m_dep_lit.size()); + unsigned const idx = m_dep_var.size(); + m_dep_var.push_back(v); + m_dep_lit.push_back(lit); + m_dep_slice.push_back(s); + return dep_t(idx); + } + + slicing::slicing(solver& s): + m_solver(s), + m_egraph(m_ast) + { + reg_decl_plugins(m_ast); + m_bv = alloc(bv_util, m_ast); + m_egraph.set_display_justification([&](std::ostream& out, void* dp) { display(out, dep_t::decode(dp)); }); + m_egraph.set_on_merge([&](enode* root, enode* other) { egraph_on_merge(root, other); }); + m_egraph.set_on_propagate([&](enode* lit, enode* ante) { egraph_on_propagate(lit, ante); }); + // m_egraph.set_on_make([&](enode* n) { egraph_on_make(n); }); + } + + slicing::slice_info& slicing::info(enode* n) { + return const_cast(std::as_const(*this).info(n)); + } + + slicing::slice_info const& slicing::info(enode* n) const { + SASSERT(n); + SASSERT(!n->is_equality()); + SASSERT(m_bv->is_bv_sort(n->get_sort())); + slice_info const& i = m_info[n->get_id()]; + return i.slice ? info(i.slice) : i; + } + + bool slicing::is_slice(enode* n) const { + if (n->is_equality()) + return false; + SASSERT(m_bv->is_bv_sort(n->get_sort())); + slice_info const& i = m_info[n->get_id()]; + return !i.slice; + } + + bool slicing::is_concat(enode* n) const { + if (n->is_equality()) + return false; + return !is_slice(n); + } + + unsigned slicing::width(enode* s) const { + SASSERT(!s->is_equality()); + return m_bv->get_bv_size(s->get_expr()); + } + + slicing::enode* slicing::sibling(enode* s) const { + enode* p = parent(s); + SASSERT(p); + SASSERT(sub_lo(p) == s || sub_hi(p) == s); + if (s != sub_hi(p)) + return sub_hi(p); + else + return sub_lo(p); + } + + func_decl* slicing::mk_concat_decl(ptr_vector const& args) { + SASSERT(args.size() >= 2); + ptr_vector domain; + unsigned sz = 0; + for (expr* e : args) { + domain.push_back(e->get_sort()); + sz += m_bv->get_bv_size(e); + } + sort* range = m_bv->mk_sort(sz); + return m_ast.mk_func_decl(symbol("slice-concat"), domain.size(), domain.data(), range); + } + + void slicing::push_scope() { + LOG("push_scope"); + if (can_propagate()) + propagate(); + m_scopes.push_back(m_trail.size()); + m_egraph.push(); + m_dep_size_trail.push_back(m_dep_var.size()); + SASSERT(!use_var_congruences() || m_needs_congruence.empty()); + } + + void slicing::pop_scope(unsigned num_scopes) { + LOG("pop_scope(" << num_scopes << ")"); + if (num_scopes == 0) + return; + unsigned const lvl = m_scopes.size(); + SASSERT(num_scopes <= lvl); + unsigned const target_lvl = lvl - num_scopes; + unsigned const target_size = m_scopes[target_lvl]; + m_scopes.shrink(target_lvl); + svector replay_trail; + unsigned_vector replay_add_var_trail; + svector> replay_extract_trail; + svector replay_concat_trail; + unsigned num_replay_concat = 0; + for (unsigned i = m_trail.size(); i-- > target_size; ) { + switch (m_trail[i]) { + case trail_item::add_var: + replay_trail.push_back(trail_item::add_var); + replay_add_var_trail.push_back(width(m_var2slice.back())); + undo_add_var(); + break; + case trail_item::split_core: + undo_split_core(); + break; + case trail_item::mk_extract: { + replay_trail.push_back(trail_item::mk_extract); + extract_args const& args = m_extract_trail.back(); + replay_extract_trail.push_back({args, m_extract_dedup[args]}); + undo_mk_extract(); + break; + } + case trail_item::mk_concat: + replay_trail.push_back(trail_item::mk_concat); + num_replay_concat++; + break; + case trail_item::set_value_node: + undo_set_value_node(); + break; + default: + UNREACHABLE(); + } + } + m_egraph.pop(num_scopes); + m_needs_congruence.reset(); + m_disequality_conflict = nullptr; + m_dep_var.shrink(m_dep_size_trail[target_lvl]); + m_dep_lit.shrink(m_dep_size_trail[target_lvl]); + m_dep_slice.shrink(m_dep_size_trail[target_lvl]); + m_dep_size_trail.shrink(target_lvl); + m_trail.shrink(target_size); + // replay add_var/mk_extract/mk_concat in the same order + // (only until polysat::solver supports proper garbage collection of variables) + unsigned add_var_idx = replay_add_var_trail.size(); + unsigned extract_idx = replay_extract_trail.size(); + unsigned concat_idx = m_concat_trail.size() - num_replay_concat; + for (auto it = replay_trail.rbegin(); it != replay_trail.rend(); ++it) { + switch (*it) { + case trail_item::add_var: { + unsigned const sz = replay_add_var_trail[--add_var_idx]; + add_var(sz); + break; + } + case trail_item::mk_extract: { + auto const [args, v] = replay_extract_trail[--extract_idx]; + replay_extract(args, v); + break; + } + case trail_item::mk_concat: { + NOT_IMPLEMENTED_YET(); + auto const ci = m_concat_trail[concat_idx++]; + num_replay_concat++; + replay_concat(ci.num_args, &m_concat_args[ci.args_idx], ci.v); + break; + } + default: + UNREACHABLE(); + } + } + SASSERT(invariant()); + } + + void slicing::add_var(unsigned bit_width) { + pvar const v = m_var2slice.size(); + enode* s = alloc_slice(bit_width, v); + m_var2slice.push_back(s); + m_trail.push_back(trail_item::add_var); + LOG_V(10, "add_var: v" << v << " -> " << slice_pp(*this, s)); + } + + void slicing::undo_add_var() { + m_var2slice.pop_back(); + } + + slicing::enode* slicing::find_or_alloc_disequality(enode* x, enode* y, sat::literal lit) { + expr_ref eq(m_ast.mk_eq(x->get_expr(), y->get_expr()), m_ast); + enode* eqn = m_egraph.find(eq); + if (eqn) + return eqn; + auto args = {x, y}; + eqn = m_egraph.mk(eq, 0, args.size(), args.begin()); + auto j = euf::justification::external(dep_t(lit).encode()); + m_egraph.set_value(eqn, l_false, j); + SASSERT(eqn->is_equality()); + SASSERT_EQ(eqn->value(), l_false); + return eqn; + } + + void slicing::egraph_on_make(enode* n) { + LOG("on_make: " << e_pp(n)); + } + + slicing::enode* slicing::alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var) { + SASSERT(!m_egraph.find(e)); + // NOTE: sometimes egraph::mk already triggers a merge due to congruence. + // in this case we have to make sure to allocate m_info early enough. + unsigned const id = e->get_id(); + m_info.reserve(id + 1); + slice_info& i = m_info[id]; + i.reset(); + i.var = var; + enode* n = m_egraph.mk(e, 0, num_args, args); // NOTE: the egraph keeps a strong reference to 'e' + LOG_V(10, "alloc_enode: " << slice_pp(*this, n) << " " << e_pp(n)); + return n; + } + + slicing::enode* slicing::find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var) { + enode* n = m_egraph.find(e); + if (n) { + SASSERT_EQ(info(n).var, var); + return n; + } + return alloc_enode(e, num_args, args, var); + } + + slicing::enode* slicing::alloc_slice(unsigned width, pvar var) { + SASSERT(width > 0); + app_ref a(m_ast.mk_fresh_const("s", m_bv->mk_sort(width), false), m_ast); + return alloc_enode(a, 0, nullptr, var); + } + + slicing::enode* slicing::mk_concat_node(enode_vector const& slices) { + return mk_concat_node(slices.size(), slices.data()); + } + + slicing::enode* slicing::mk_concat_node(unsigned num_slices, enode* const* slices) { + ptr_vector args; + for (unsigned i = 0; i < num_slices; ++i) + args.push_back(slices[i]->get_expr()); + app_ref a(m_ast.mk_app(mk_concat_decl(args), args), m_ast); + return find_or_alloc_enode(a, num_slices, slices, null_var); + } + + void slicing::add_concat_node(enode* s, enode* concat) { + SASSERT(slice2var(s) != null_var); // all concat nodes should point to a variable node + SASSERT(is_app(concat->get_expr())); + slice_info& concat_info = m_info[concat->get_id()]; + if (s->get_root() == concat->get_root()) { + SASSERT_EQ(s, concat_info.slice); + return; + } + SASSERT(!concat_info.slice); // not yet set + concat_info.slice = s; + egraph_merge(s, concat, dep_t()); + } + + void slicing::add_var_congruence(pvar v) { + enode_vector& base = m_tmp2; + SASSERT(base.empty()); + enode* sv = var2slice(v); + get_base(sv, base); + // Add equation v == concat(s1, ..., sn) + add_concat_node(sv, mk_concat_node(base)); + base.clear(); + } + + void slicing::add_var_congruence_if_needed(pvar v) { + if (!m_needs_congruence.contains(v)) + return; + m_needs_congruence.remove(v); + add_var_congruence(v); + } + + void slicing::update_var_congruences() { + if (!use_var_congruences()) + return; + // TODO: this is only needed once per equivalence class + // (mark root of var2slice to detect duplicates?) + for (pvar v : m_needs_congruence) + add_var_congruence(v); + m_needs_congruence.reset(); + } + + bool slicing::use_var_congruences() const { + return m_solver.config().m_slicing_congruence; + } + + // split a single slice without updating any equivalences + void slicing::split_core(enode* s, unsigned cut) { + SASSERT(is_slice(s)); // this action only makes sense for slices + SASSERT(!has_sub(s)); + SASSERT(info(s).sub_hi == nullptr && info(s).sub_lo == nullptr); + SASSERT(width(s) > cut + 1); + unsigned const width_hi = width(s) - cut - 1; + unsigned const width_lo = cut + 1; + enode* sub_hi; + enode* sub_lo; + if (is_value(s)) { + rational const val = get_value(s); + sub_hi = mk_value_slice(machine_div2k(val, width_lo), width_hi); + sub_lo = mk_value_slice(mod2k(val, width_lo), width_lo); + } + else { + sub_hi = alloc_slice(width_hi); + sub_lo = alloc_slice(width_lo); + } + SASSERT(!parent(sub_hi)); + SASSERT(!parent(sub_lo)); + info(sub_hi).parent = s; + info(sub_lo).parent = s; + info(s).set_cut(cut, sub_hi, sub_lo); + m_trail.push_back(trail_item::split_core); + m_enode_trail.push_back(s); + for (enode* n = s; n != nullptr; n = parent(n)) { + pvar const v = slice2var(n); + if (v == null_var) + continue; + if (m_needs_congruence.contains(v)) { + SASSERT(invariant_needs_congruence()); + break; // added parents already previously + } + m_needs_congruence.insert(v); + } + } + + bool slicing::invariant_needs_congruence() const { + for (pvar v : m_needs_congruence) + for (enode* s = var2slice(v); s != nullptr; s = parent(s)) + if (slice2var(s) != null_var) { + VERIFY(m_needs_congruence.contains(slice2var(s))); + } + return true; + } + + void slicing::undo_split_core() { + enode* s = m_enode_trail.back(); + m_enode_trail.pop_back(); + info(s).set_cut(null_cut, nullptr, nullptr); + } + + void slicing::split(enode* s, unsigned cut) { + // this action only makes sense for base slices. + // a base slice is never equivalent to a congruence node. + SASSERT(is_slice(s)); + SASSERT(!has_sub(s)); + SASSERT(cut != null_cut); + // split all slices in the equivalence class + for (enode* n : euf::enode_class(s)) + split_core(n, cut); + // propagate equivalences to subslices + for (enode* n : euf::enode_class(s)) { + enode* target = n->get_target(); + if (!target) + continue; + euf::justification const j = n->get_justification(); + SASSERT(j.is_external()); // cannot be a congruence since the slice wasn't split before. + dep_t const d = dep_t::decode(j.ext()); + egraph_merge(sub_hi(n), sub_hi(target), d); + egraph_merge(sub_lo(n), sub_lo(target), d); + } + } + + void slicing::mk_slice(enode* src, unsigned const hi, unsigned const lo, enode_vector& out, bool output_full_src, bool output_base) { + SASSERT(hi >= lo); + SASSERT(width(src) > hi); // extracted range must be fully contained inside the src slice + auto output_slice = [this, output_base, &out](enode* s) { + if (output_base) + get_base(s, out); + else + out.push_back(s); + }; + if (lo == 0 && width(src) - 1 == hi) { + output_slice(src); + return; + } + if (has_sub(src)) { + // src is split into [src.width-1, cut+1] and [cut, 0] + unsigned const cut = info(src).cut; + if (lo >= cut + 1) { + // target slice falls into upper subslice + mk_slice(sub_hi(src), hi - cut - 1, lo - cut - 1, out, output_full_src, output_base); + if (output_full_src) + output_slice(sub_lo(src)); + return; + } + else if (cut >= hi) { + // target slice falls into lower subslice + if (output_full_src) + output_slice(sub_hi(src)); + mk_slice(sub_lo(src), hi, lo, out, output_full_src, output_base); + return; + } + else { + SASSERT(hi > cut && cut >= lo); + // desired range spans over the cutpoint, so we get multiple slices in the result + mk_slice(sub_hi(src), hi - cut - 1, 0, out, output_full_src, output_base); + mk_slice(sub_lo(src), cut, lo, out, output_full_src, output_base); + return; + } + } + else { + // [src.width-1, 0] has no subdivision yet + if (width(src) - 1 > hi) { + split(src, hi); + SASSERT(!has_sub(sub_hi(src))); + if (output_full_src) + out.push_back(sub_hi(src)); + mk_slice(sub_lo(src), hi, lo, out, output_full_src, output_base); // recursive call to take care of case lo > 0 + return; + } + else { + SASSERT(lo > 0); + split(src, lo - 1); + out.push_back(sub_hi(src)); + SASSERT(!has_sub(sub_lo(src))); + if (output_full_src) + out.push_back(sub_lo(src)); + return; + } + } + UNREACHABLE(); + } + + slicing::enode* slicing::mk_value_slice(rational const& val, unsigned bit_width) { + SASSERT(bit_width > 0); + SASSERT(0 <= val && val < rational::power_of_two(bit_width)); + sort* bv_sort = m_bv->mk_sort(bit_width); + func_decl_ref f(m_ast.mk_fresh_func_decl("val", nullptr, 1, &bv_sort, bv_sort, false), m_ast); + app_ref a(m_ast.mk_app(f, m_bv->mk_numeral(val, bit_width)), m_ast); + enode* s = alloc_enode(a, 0, nullptr, null_var); + set_value_node(s, s); + SASSERT_EQ(get_value(s), val); + return s; + } + + slicing::enode* slicing::mk_interpreted_value_node(enode* s) { + SASSERT(is_value(s)); + // NOTE: how this is used now, the node will not yet be contained in the egraph. + enode* n = alloc_enode(s->get_app()->get_arg(0), 0, nullptr, null_var); + info(n).value_node = s; + n->mark_interpreted(); + SASSERT(n->interpreted()); + SASSERT_EQ(get_value_node(n), s); + return n; + } + + bool slicing::is_value(enode* n) const { + SASSERT(n); + SASSERT(is_app(n->get_expr())); // we only create app nodes at the moment; if this fails just return false here. + app* a = n->get_app(); + return a->get_num_args() == 1 && m_bv->is_numeral(a->get_arg(0)); + } + + rational slicing::get_value(enode* s) const { + SASSERT(is_value(s)); + rational val; + VERIFY(try_get_value(s, val)); + return val; + } + + bool slicing::try_get_value(enode* s, rational& val) const { + if (!s) + return false; + app* a = s->get_app(); + if (a->get_num_args() != 1) + return false; + bool const ok = m_bv->is_numeral(a->get_arg(0), val); + SASSERT_EQ(ok, is_value(s)); + return ok; + } + + void slicing::explain_class(enode* x, enode* y, ptr_vector& out_deps) { + SASSERT_EQ(x->get_root(), y->get_root()); + m_egraph.begin_explain(); + m_egraph.explain_eq(out_deps, nullptr, x, y); + m_egraph.end_explain(); + } + + void slicing::explain_equal(enode* x, enode* y, ptr_vector& out_deps) { + SASSERT(is_equal(x, y)); + SASSERT_EQ(width(x), width(y)); + enode_vector& xs = m_tmp2; + enode_vector& ys = m_tmp3; + SASSERT(xs.empty()); + SASSERT(ys.empty()); + xs.push_back(x); + ys.push_back(y); + while (!xs.empty()) { + SASSERT(!ys.empty()); + enode* const x = xs.back(); xs.pop_back(); + enode* const y = ys.back(); ys.pop_back(); + if (x == y) + continue; + if (width(x) == width(y)) { + enode* const rx = x->get_root(); + enode* const ry = y->get_root(); + if (rx == ry) + explain_class(x, y, out_deps); + else { + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + ys.push_back(sub_hi(y)); + ys.push_back(sub_lo(y)); + } + } + else if (width(x) > width(y)) { + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + ys.push_back(y); + } + else { + SASSERT(width(x) < width(y)); + xs.push_back(x); + ys.push_back(sub_hi(y)); + ys.push_back(sub_lo(y)); + } + } + SASSERT(ys.empty()); + } + + void slicing::explain_equal(pvar x, pvar y, ptr_vector& out_deps) { + explain_equal(var2slice(x), var2slice(y), out_deps); + } + + void slicing::explain_equal(pvar x, pvar y, std::function const& on_lit) { + SASSERT(m_marked_lits.empty()); + SASSERT(m_tmp_deps.empty()); + explain_equal(x, y, m_tmp_deps); + for (void* dp : m_tmp_deps) { + dep_t const d = dep_t::decode(dp); + if (d.is_null()) + continue; + if (d.is_lit()) { + sat::literal lit = d.lit(); + if (m_marked_lits.contains(lit)) + continue; + m_marked_lits.insert(lit); + on_lit(d.lit()); + } + else { + // equivalence between to variables cannot be due to value assignment + UNREACHABLE(); + } + } + m_marked_lits.reset(); + m_tmp_deps.reset(); + } + + void slicing::explain(ptr_vector& out_deps) { + SASSERT(is_conflict()); + m_egraph.begin_explain(); + if (m_disequality_conflict) { + LOG("Disequality conflict: " << m_disequality_conflict); + enode* eqn = m_disequality_conflict; + SASSERT(eqn->is_equality()); + SASSERT_EQ(eqn->value(), l_false); + SASSERT(eqn->get_lit_justification().is_external()); + SASSERT(m_ast.is_eq(eqn->get_expr())); + SASSERT_EQ(eqn->get_arg(0)->get_root(), eqn->get_arg(1)->get_root()); + m_egraph.explain_eq(out_deps, nullptr, eqn->get_arg(0), eqn->get_arg(1)); + out_deps.push_back(eqn->get_lit_justification().ext()); + } + else { + SASSERT(m_egraph.inconsistent()); + m_egraph.explain(out_deps, nullptr); + } + m_egraph.end_explain(); + } + + clause_ref slicing::build_conflict_clause() { + LOG_H1("slicing: build_conflict_clause"); + // display_tree(verbose_stream()); + + SASSERT(invariant()); + SASSERT(is_conflict()); + SASSERT(m_marked_lits.empty()); + SASSERT(m_tmp_deps.empty()); + explain(m_tmp_deps); + clause_builder cb(m_solver, "slicing"); + + auto add_premise = [this, &cb](sat::literal lit) { + LOG("Premise: " << lit_pp(m_solver, lit)); + cb.insert(~lit); + }; + + auto add_conclusion = [this, &cb](signed_constraint c) { + LOG("Conclusion: " << lit_pp(m_solver, c)); + cb.insert_eval(c); + }; + + pvar x = null_var; enode* sx = nullptr; sat::literal xlit = sat::null_literal; + pvar y = null_var; enode* sy = nullptr; sat::literal ylit = sat::null_literal; + for (void* dp : m_tmp_deps) { + dep_t const d = dep_t::decode(dp); + // LOG("dep: " << dep_pp(*this, d)); + if (d.is_null()) + continue; + if (d.is_lit()) { + sat::literal const lit = d.lit(); + if (m_marked_lits.contains(lit)) + continue; + m_marked_lits.insert(lit); + add_premise(lit); + } + else { + SASSERT(d.is_value()); + pvar const v = get_dep_var(d); + enode* const sv = get_dep_slice(d); + sat::literal const lit = get_dep_lit(d); + if (x == null_var) + x = v, sx = sv, xlit = lit; + else if (y == null_var) + y = v, sy = sv, ylit = lit; + else { + // pvar justifications are only introduced by add_value, i.e., when a variable is assigned in the solver. + // thus there can be at most two pvar justifications in a single conflict. + UNREACHABLE(); + } + } + } + m_marked_lits.reset(); + m_tmp_deps.reset(); + + if (x != null_var && y != null_var && xlit == sat::null_literal && ylit != sat::null_literal) { + using std::swap; + swap(x, y); + swap(sx, sy); + swap(xlit, ylit); + } + + if (x != null_var) { + LOG("Variable v" << x << " with slice " << slice_pp(*this, sx) << " by literal " << lit_pp(m_solver, xlit)); + if (m_solver.is_assigned(x)) + LOG("solver-value " << assignment_pp(m_solver, x, m_solver.get_value(x))); + } + if (y != null_var) { + LOG("Variable v" << y << " with slice " << slice_pp(*this, sy) << " by literal " << lit_pp(m_solver, ylit)); + if (m_solver.is_assigned(y)) + LOG("solver-value " << assignment_pp(m_solver, y, m_solver.get_value(y))); + } + + // conflict has either 0 or 2 vars + VERIFY(x != null_var || y == null_var); + VERIFY(y != null_var || x == null_var); + + if (xlit != sat::null_literal && ylit != sat::null_literal) { + verbose_stream() << "build_conflict_clause (2)" << std::endl; + add_premise(xlit); + add_premise(ylit); + } + else if (xlit != sat::null_literal && ylit == sat::null_literal) { + verbose_stream() << "build_conflict_clause (1)" << std::endl; + add_premise(xlit); + + // rational const x_slice_value = get_value(get_value_node(var2slice(x))); + // LOG("v" << x << " slice_value: " << x_slice_value); + rational const y_slice_value = get_value(get_value_node(var2slice(y))); + LOG("v" << y << " slice_value: " << y_slice_value); + // SASSERT(x_slice_value != y_slice_value); + add_conclusion(~m_solver.eq(m_solver.var(y), y_slice_value)); + +/* + unsigned x_hi, x_lo; + VERIFY(find_range_in_ancestor(sx, var2slice(x), x_hi, x_lo)); + pvar const xx = mk_extract(x, x_hi, x_lo); + LOG("find_range_in_ancestor: v" << x << "[" << x_hi << ":" << x_lo << "] = " << slice_pp(*this, sx) << " --> represented by variable v" << xx); + unsigned y_hi, y_lo; + VERIFY(find_range_in_ancestor(sy, var2slice(y), y_hi, y_lo)); + pvar const yy = mk_extract(y, y_hi, y_lo); + LOG("find_range_in_ancestor: v" << y << "[" << y_hi << ":" << y_lo << "] = " << slice_pp(*this, sy) << " --> represented by variable v" << yy); + // LOG("v" << x << " has solver-value? " << m_solver.is_assigned(x)); + if (m_solver.is_assigned(x)) LOG("v" << x << " has solver-value " << m_solver.get_value(x)); + // LOG("v" << y << " has solver-value? " << m_solver.is_assigned(y)); + if (m_solver.is_assigned(y)) LOG("v" << y << " has solver-value " << m_solver.get_value(y)); + LOG("v" << x << " is slice " << slice_pp(*this, var2slice(x))); + LOG("v" << y << " is slice " << slice_pp(*this, var2slice(y))); + SASSERT_EQ(sy->get_root(), var2slice(yy)->get_root()); + rational const sy_slice_value = get_value(get_value_node(sy)); + // rational const sy_solver_value = mod2k(machine_div2k(m_solver.get_value(y), lo), hi - lo + 1); + // c = m_solver.eq(m_solver.var(yy), sy_slice_value); +*/ + } + else { + verbose_stream() << "build_conflict_clause (0)" << std::endl; + SASSERT(xlit == sat::null_literal); + SASSERT(ylit == sat::null_literal); + + // unsigned x_hi, x_lo, y_hi, y_lo; + // VERIFY(find_range_in_ancestor(sx, var2slice(x), x_hi, x_lo)); + // VERIFY(find_range_in_ancestor(sy, var2slice(y), y_hi, y_lo)); + // pvar const xx = mk_extract(x, x_hi, x_lo); + // pvar const yy = mk_extract(y, y_hi, y_lo); + // SASSERT_EQ(sx->get_root(), var2slice(xx)->get_root()); + // SASSERT_EQ(sy->get_root(), var2slice(yy)->get_root()); + // rational sval = mod2k(machine_div2k(m_solver.get_value(x), x_lo), x_hi - x_lo + 1); + // LOG("find_range_in_ancestor: v" << x << "[" << x_hi << ":" << x_lo << "] = " << slice_pp(*this, sx) << " --> represented by variable v" << xx); + // LOG("find_range_in_ancestor: v" << y << "[" << y_hi << ":" << y_lo << "] = " << slice_pp(*this, sy) << " --> represented by variable v" << yy); + LOG("v" << x << " is slice " << slice_pp(*this, var2slice(x))); + LOG("v" << y << " is slice " << slice_pp(*this, var2slice(y))); + if (m_solver.is_assigned(x)) LOG("v" << x << " has solver-value " << m_solver.get_value(x)); + if (m_solver.is_assigned(y)) LOG("v" << y << " has solver-value " << m_solver.get_value(y)); + // SASSERT(xx != yy); + // c = m_solver.eq(m_solver.var(xx), m_solver.var(yy)); // similar to what Algorithm 1 in BitvectorsMCSAT is doing + // LOG("c: " << lit_pp(m_solver, c)); + + rational const x_slice_value = get_value(get_value_node(var2slice(x))); + LOG("v" << x << " slice-value: " << x_slice_value); + add_conclusion(~m_solver.eq(m_solver.var(x), x_slice_value)); + + rational const y_slice_value = get_value(get_value_node(var2slice(y))); + LOG("v" << y << " slice-value: " << y_slice_value); + add_conclusion(~m_solver.eq(m_solver.var(y), y_slice_value)); + } + + // TODO: we don't need clauses like this ... rather set up the conflict core from it + + return cb.build(); + } + + void slicing::explain_value(enode* s, std::function const& on_lit, std::function const& on_var) { + SASSERT(invariant()); + SASSERT(m_marked_lits.empty()); + + enode* n = get_value_node(s); + SASSERT(is_value(n)); + + SASSERT(m_tmp_deps.empty()); + explain_equal(s, n, m_tmp_deps); + + for (void* dp : m_tmp_deps) { + dep_t const d = dep_t::decode(dp); + if (d.is_null()) + continue; + if (d.is_lit()) { + sat::literal const lit = d.lit(); + if (!m_marked_lits.contains(lit)) { + on_lit(lit); + m_marked_lits.insert(lit); + } + } + else { + SASSERT(d.is_value()); + sat::literal const lit = get_dep_lit(d); + if (lit == sat::null_literal) + on_var(get_dep_var(d)); + else if (!m_marked_lits.contains(lit)) { + on_lit(lit); + m_marked_lits.insert(lit); + } + } + } + m_tmp_deps.reset(); + m_marked_lits.reset(); + } + + void slicing::explain_value(pvar v, std::function const& on_lit, std::function const& on_var) { + explain_value(var2slice(v), on_lit, on_var); + } + + bool slicing::find_range_in_ancestor(enode* s, enode* a, unsigned& out_hi, unsigned& out_lo) { + out_hi = width(s) - 1; + out_lo = 0; + while (true) { + if (s == a) + return true; + enode* p = parent(s); + if (!p) + return false; + if (s == sub_hi(p)) { + unsigned offset = 1 + info(p).cut; + out_hi += offset; + out_lo += offset; + } + else { + SASSERT_EQ(s, sub_lo(p)); + /* range stays unchanged */ + } + s = p; + } + } + + bool slicing::is_extract(pvar x, pvar src, unsigned& out_hi, unsigned& out_lo) { + return find_range_in_ancestor(var2slice(x), var2slice(src), out_hi, out_lo); + } + + void slicing::egraph_on_merge(enode* root, enode* other) { + LOG("on_merge: root " << slice_pp(*this, root) << " other " << slice_pp(*this, other)); + if (root->interpreted()) + return; + if (root->is_equality()) { + SASSERT(other->is_equality()); + return; + } + SASSERT(!other->interpreted()); // by convention, interpreted nodes are always chosen as root + SASSERT(root != other); + SASSERT_EQ(root, root->get_root()); + SASSERT_EQ(root, other->get_root()); + + enode* const v1 = info(root).value_node; // root is the root + enode* const v2 = info(other).value_node; // 'other' was its own root before the merge + if (v1 && v2 && get_value(v1) != get_value(v2)) { + // we have a conflict. add interpreted enodes to make the egraph realize it. + enode* const i1 = mk_interpreted_value_node(v1); + enode* const i2 = mk_interpreted_value_node(v2); + m_egraph.merge(i1, v1, dep_t().encode()); + m_egraph.merge(i2, v2, dep_t().encode()); + SASSERT(is_conflict()); + return; + } + + enode* const v = v1 ? v1 : v2; + if (v && !(v1 && v2)) { + // exactly one node has a value + rational const val = get_value(v); + for (enode* n : euf::enode_class(other)) { + enode* const vn = get_value_node(n); + if (!vn) + set_value_node(n, v); + + pvar const var = slice2var(n); + if (var == null_var) + continue; + if (m_solver.is_assigned(var)) + continue; + LOG("on_merge: v" << var << " := " << val); + m_solver.assign_propagate_by_slicing(var, val); + } + } + } + + void slicing::set_value_node(enode* s, enode* value_node) { + SASSERT(!info(s).value_node); + SASSERT(is_value(value_node)); + info(s).value_node = value_node; + if (s != value_node) { + m_trail.push_back(trail_item::set_value_node); + m_enode_trail.push_back(s); + } + } + + void slicing::undo_set_value_node() { + enode* s = m_enode_trail.back(); + m_enode_trail.pop_back(); + info(s).value_node = nullptr; + } + + void slicing::egraph_on_propagate(enode* lit, enode* ante) { + // ante may be set when symmetric equality is added by congruence + if (ante) + return; + // on_propagate may be called before set_value + if (lit->value() == l_undef) + return; + SASSERT(lit->is_equality()); + SASSERT_EQ(lit->value(), l_false); + SASSERT(lit->get_lit_justification().is_external()); + m_disequality_conflict = lit; + } + + bool slicing::can_propagate() const { + if (use_var_congruences() && !m_needs_congruence.empty()) + return true; + return m_egraph.can_propagate(); + } + + void slicing::propagate() { + // m_egraph.propagate(); + if (is_conflict()) + return; + update_var_congruences(); + m_egraph.propagate(); + } + + bool slicing::egraph_merge(enode* s1, enode* s2, dep_t dep) { + LOG("egraph_merge: " << slice_pp(*this, s1) << " and " << slice_pp(*this, s2) << " by " << dep_pp(*this, dep)); + SASSERT_EQ(width(s1), width(s2)); + if (dep.is_value()) { + if (is_value(s1)) + std::swap(s1, s2); + SASSERT(is_value(s2)); + SASSERT(!is_value(s1)); // we never merge two value slices directly + if (get_dep_slice(dep) != s1) + dep = mk_var_dep(get_dep_var(dep), s1, get_dep_lit(dep)); + } + m_egraph.merge(s1, s2, dep.encode()); + return !is_conflict(); + } + + bool slicing::merge_base(enode* s1, enode* s2, dep_t dep) { + SASSERT(!has_sub(s1)); + SASSERT(!has_sub(s2)); + return egraph_merge(s1, s2, dep); + } + + bool slicing::merge(enode_vector& xs, enode_vector& ys, dep_t dep) { + while (!xs.empty()) { + SASSERT(!ys.empty()); + enode* const x = xs.back(); + enode* const y = ys.back(); + xs.pop_back(); + ys.pop_back(); + if (x == y) + continue; + if (x->get_root() == y->get_root()) { + DEBUG_CODE({ + // invariant: parents merged => base slices merged + enode_vector const x_base = get_base(x); + enode_vector const y_base = get_base(y); + SASSERT_EQ(x_base.size(), y_base.size()); + for (unsigned i = x_base.size(); i-- > 0; ) { + SASSERT_EQ(x_base[i]->get_root(), y_base[i]->get_root()); + } + }); + continue; + } +#if 0 + if (has_sub(x)) { + get_base(x, xs); + x = xs.back(); + xs.pop_back(); + } + if (has_sub(y)) { + get_base(y, ys); + y = ys.back(); + ys.pop_back(); + } + SASSERT(!has_sub(x)); + SASSERT(!has_sub(y)); + if (width(x) == width(y)) { + if (!merge_base(x, y, dep)) { + xs.clear(); + ys.clear(); + return false; + } + } + else if (width(x) > width(y)) { + // need to split x according to y + mk_slice(x, width(y) - 1, 0, xs, true); + ys.push_back(y); + } + else { + SASSERT(width(y) > width(x)); + // need to split y according to x + mk_slice(y, width(x) - 1, 0, ys, true); + xs.push_back(x); + } +#else + if (width(x) == width(y)) { + // We may merge slices if at least one of them doesn't have a subslice yet, + // because in that case all intermediate cut points will be aligned. + // NOTE: it is necessary to merge intermediate slices for value nodes, to ensure downward propagation of assignments. + bool const should_merge = (!has_sub(x) || !has_sub(y)); + // If either slice has a subdivision, we have to cut the other and advance to subslices + if (has_sub(x) || has_sub(y)) { + if (!has_sub(x)) + split(x, get_cut(y)); + if (!has_sub(y)) + split(y, get_cut(x)); + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + ys.push_back(sub_hi(y)); + ys.push_back(sub_lo(y)); + } + // We may only merge intermediate nodes after we're done with splitting (since we currently split the whole equivalence class at once) + if (should_merge) { + if (!egraph_merge(x, y, dep)) { + xs.clear(); + ys.clear(); + return false; + } + } + } + else if (width(x) > width(y)) { + if (!has_sub(x)) + split(x, width(y) - 1); + // split(x, has_sub(y) ? get_cut(y) : (width(y) - 1)); + xs.push_back(sub_hi(x)); + xs.push_back(sub_lo(x)); + ys.push_back(y); + } + else { + SASSERT(width(y) > width(x)); + if (!has_sub(y)) + split(y, width(x) - 1); + // split(y, has_sub(x) ? get_cut(x) : (width(x) - 1)); + ys.push_back(sub_hi(y)); + ys.push_back(sub_lo(y)); + xs.push_back(x); + } +#endif + } + SASSERT(ys.empty()); + return true; + } + + bool slicing::merge(enode_vector& xs, enode* y, dep_t dep) { + enode_vector& ys = m_tmp2; + SASSERT(ys.empty()); + ys.push_back(y); + return merge(xs, ys, dep); // will clear xs and ys + } + + bool slicing::merge(enode* x, enode* y, dep_t dep) { + LOG("merge: " << slice_pp(*this, x) << " and " << slice_pp(*this, y)); + SASSERT_EQ(width(x), width(y)); + if (!has_sub(x) && !has_sub(y)) + return merge_base(x, y, dep); + enode_vector& xs = m_tmp2; + enode_vector& ys = m_tmp3; + SASSERT(xs.empty()); + SASSERT(ys.empty()); + xs.push_back(x); + ys.push_back(y); + return merge(xs, ys, dep); // will clear xs and ys + } + + bool slicing::is_equal(enode* x, enode* y) { + SASSERT_EQ(width(x), width(y)); + x = x->get_root(); + y = y->get_root(); + if (x == y) + return true; + enode_vector& xs = m_tmp2; + enode_vector& ys = m_tmp3; + SASSERT(xs.empty()); + SASSERT(ys.empty()); + on_scope_exit clear_vectors([&xs, &ys](){ + xs.clear(); + ys.clear(); + }); + // TODO: we don't always have to collect the full base if intermediate nodes are already equal + get_base(x, xs); + get_base(y, ys); + if (xs.size() != ys.size()) + return false; + for (unsigned i = xs.size(); i-- > 0; ) + if (xs[i]->get_root() != ys[i]->get_root()) + return false; + return true; + } + + void slicing::get_base(enode* src, enode_vector& out_base) const { + enode_vector& todo = m_tmp1; + SASSERT(todo.empty()); + todo.push_back(src); + while (!todo.empty()) { + enode* s = todo.back(); + todo.pop_back(); + if (!has_sub(s)) + out_base.push_back(s); + else { + todo.push_back(sub_lo(s)); + todo.push_back(sub_hi(s)); + } + } + SASSERT(todo.empty()); + } + + slicing::enode_vector slicing::get_base(enode* src) const { + enode_vector out; + get_base(src, out); + return out; + } + + pvar slicing::mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var) { + LOG("mk_extract: src=" << slice_pp(*this, src) << " hi=" << hi << " lo=" << lo); + enode_vector& slices = m_tmp3; + SASSERT(slices.empty()); + mk_slice(src, hi, lo, slices, false, false); + pvar v = null_var; + // try to re-use variable of an existing slice + if (slices.size() == 1) + v = slice2var(slices[0]); + if (replay_var != null_var && v != replay_var) { + // replayed variable should be 'fresh', unless it was a re-used variable + enode* s = var2slice(replay_var); + SASSERT(s->is_root()); + SASSERT_EQ(s->class_size(), 1); + SASSERT(!has_sub(s)); + SASSERT_EQ(width(s), hi - lo + 1); + v = replay_var; + } + // allocate new variable if we cannot reuse it + if (v == null_var) { + v = m_solver.add_var(hi - lo + 1, pvar_kind::internal); +#if 1 + // slice didn't have a variable yet; so we can re-use it for the new variable + // (we end up with a "phantom" enode that was first created for the variable) + if (slices.size() == 1) { + enode* s = slices[0]; + LOG("re-using slice " << slice_pp(*this, s) << " for new variable v" << v); + // display_tree(std::cerr, s, 0, hi, lo); + SASSERT_EQ(info(s).var, null_var); + info(m_var2slice[v]).var = null_var; // disconnect the "phantom" enode + info(s).var = v; + m_var2slice[v] = s; + } +#endif + } + // connect new variable + VERIFY(merge(slices, var2slice(v), dep_t())); + slices.reset(); + return v; + } + + void slicing::replay_extract(extract_args const& args, pvar r) { + LOG("replay_extract"); + SASSERT(r != null_var); + SASSERT(!m_extract_dedup.contains(args)); + VERIFY_EQ(mk_extract(var2slice(args.src), args.hi, args.lo, r), r); + m_extract_dedup.insert(args, r); + m_extract_trail.push_back(args); + m_trail.push_back(trail_item::mk_extract); + } + + pvar slicing::mk_extract(pvar src, unsigned hi, unsigned lo) { + LOG_H3("mk_extract: v" << src << "[" << hi << ":" << lo << "] size(v" << src << ") = " << m_solver.size(src)); + if (m_solver.size(src) == hi - lo + 1) + return src; + extract_args args{src, hi, lo}; + auto it = m_extract_dedup.find_iterator(args); + if (it != m_extract_dedup.end()) + return it->m_value; + pvar const v = mk_extract(var2slice(src), hi, lo); + m_extract_dedup.insert(args, v); + m_extract_trail.push_back(args); + m_trail.push_back(trail_item::mk_extract); + LOG("mk_extract: v" << src << "[" << hi << ":" << lo << "] = v" << v); + return v; + } + + void slicing::undo_mk_extract() { + extract_args args = m_extract_trail.back(); + m_extract_trail.pop_back(); + m_extract_dedup.remove(args); + } + + pvar slicing::mk_concat(unsigned num_args, pvar const* args, pvar replay_var) { + enode_vector& slices = m_tmp3; + SASSERT(slices.empty()); + unsigned total_width = 0; + for (unsigned i = 0; i < num_args; ++i) { + enode* s = var2slice(args[i]); + slices.push_back(s); + total_width += width(s); + } + // NOTE: we use concat nodes to deduplicate (syntactically equal) concat expressions. + // we might end up reusing variables that are not introduced by mk_concat (if we enable the variable re-use optimization in mk_extract), + // but because such congruence nodes are only added over direct descendants, we do not get unwanted dependencies from this re-use. + // (but note that the nodes from mk_concat are not only over direct descendants) + enode* concat = mk_concat_node(slices); + pvar v = slice2var(concat); + if (v != null_var) + return v; + if (replay_var != null_var) { + // replayed variable should be 'fresh' + enode* s = var2slice(replay_var); + SASSERT(s->is_root()); + SASSERT_EQ(s->class_size(), 1); + SASSERT(!has_sub(s)); + SASSERT_EQ(width(s), total_width); + v = replay_var; + } + else + v = m_solver.add_var(total_width, pvar_kind::internal); + enode* sv = var2slice(v); + VERIFY(merge(slices, sv, dep_t())); + // NOTE: add_concat_node must be done after merge to preserve the invariant: "a base slice is never equivalent to a congruence node". + add_concat_node(sv, concat); + slices.reset(); + + // don't mess with the concat_trail during replay + if (replay_var == null_var) { + concat_info ci; + ci.v = v; + ci.num_args = num_args; + ci.args_idx = m_concat_args.size(); + m_concat_trail.push_back(ci); + for (unsigned i = 0; i < num_args; ++i) + m_concat_args.push_back(args[i]); + } + m_trail.push_back(trail_item::mk_concat); + + return v; + } + + void slicing::replay_concat(unsigned num_args, pvar const* args, pvar r) { + SASSERT(r != null_var); + VERIFY_EQ(mk_concat(num_args, args, r), r); + } + + pvar slicing::mk_concat(std::initializer_list args) { + return mk_concat(args.size(), args.begin()); + } + + void slicing::add_constraint(signed_constraint c) { + LOG(c); + SASSERT(!is_conflict()); + if (!add_fixed_bits(c)) + return; + if (c->is_eq()) + add_constraint_eq(c->to_eq(), c.blit()); + } + + bool slicing::add_fixed_bits(signed_constraint c) { + // TODO: what is missing here: + // - we don't prioritize constraints that set larger bit ranges + // e.g., c1 sets 3 lower bits, and c2 sets 5 lower bits. + // slicing may have both {c1,c2} in justifications while previously we always prefer c2. + // - instead of prioritizing constraints (which is annoying to do incrementally), let subsumption take care of this issue. + // if constraint C subsumes constraint D, then we might even want to completely deactivate D in the solver? (not easy if D is on higher level than C). + // - (we could wait until propagate() to add fixed bits to the egraph. but that would only work on a single decision level.) + if (c->vars().size() != 1) + return true; + fixed_bits fb; + if (!get_fixed_bits(c, fb)) + return true; + pvar const x = c->vars()[0]; + return add_fixed_bits(x, fb.hi, fb.lo, fb.value, c.blit()); + } + + bool slicing::add_fixed_bits(pvar x, unsigned hi, unsigned lo, rational const& value, sat::literal lit) { + LOG("add_fixed_bits: v" << x << "[" << hi << ":" << lo << "] = " << value << " by " << lit_pp(m_solver, lit)); + enode_vector& xs = m_tmp3; + SASSERT(xs.empty()); + mk_slice(var2slice(x), hi, lo, xs, false, false); + enode* const sval = mk_value_slice(value, hi - lo + 1); + // 'xs' will be cleared by 'merge'. + // NOTE: the 'nullptr' argument will be fixed by 'egraph_merge' + return merge(xs, sval, mk_var_dep(x, nullptr, lit)); + } + + bool slicing::add_constraint_eq(pdd const& p, sat::literal lit) { + auto& m = p.manager(); + for (auto& [a, x] : p.linear_monomials()) { + if (a != 1 && a != m.max_value()) + continue; + pdd const body = a.is_one() ? (m.mk_var(x) - p) : (m.mk_var(x) + p); + // c is either x = body or x != body, depending on polarity + if (!add_equation(x, body, lit)) { + SASSERT(is_conflict()); + return false; + } + // without this check, when p = x - y we would handle both x = y and y = x separately + if (body.is_unary()) + break; + } + return true; + } + + // TODO: handle equations 2^k x = 2^k y? (lower n-k bits of x and y are equal) + bool slicing::add_equation(pvar x, pdd const& body, sat::literal lit) { + LOG("Equation from lit(" << lit << "): v" << x << (lit.sign() ? " != " : " = ") << body); + if (!lit.sign() && body.is_val()) { + LOG(" simple assignment"); + // Simple assignment x = value + return add_value(x, body.val(), lit); + } + enode* const sx = var2slice(x); + pvar const y = m_solver.m_names.get_name(body); + if (y == null_var) { + if (!body.is_val()) { + // TODO: register name trigger (if a name for value 'body' is created later, then merge x=y at that time) + // could also count how often 'body' was registered and introduce name when more than once. + // maybe better: register x as an existing name for 'body'? question is how to track the dependency on c. + LOG(" skip for now (unnamed body)"); + } else + LOG(" skip for now (disequality with constant)"); + return true; + } + enode* const sy = var2slice(y); + if (!lit.sign()) { + LOG(" merge v" << x << " and v" << y); + return merge(sx, sy, lit); + } + else { + LOG(" store disequality v" << x << " != v" << y); + enode* n = find_or_alloc_disequality(sx, sy, lit); + if (!m_disequality_conflict && is_equal(sx, sy)) { + add_var_congruence_if_needed(x); + add_var_congruence_if_needed(y); + m_disequality_conflict = n; + return false; + } + } + return true; + } + + bool slicing::add_value(pvar v, rational const& value, sat::literal lit) { + enode* const sv = var2slice(v); + if (get_value_node(sv) && get_value(get_value_node(sv)) == value) + return true; + enode* const sval = mk_value_slice(value, width(sv)); + return merge(sv, sval, mk_var_dep(v, sv, lit)); + } + + void slicing::add_value(pvar v, rational const& value) { + LOG("v" << v << " := " << value); + SASSERT(!is_conflict()); + (void)add_value(v, value, sat::null_literal); + } + + void slicing::collect_simple_overlaps(pvar v, pvar_vector& out) { + unsigned const first_out = out.size(); + enode* const sv = var2slice(v); + unsigned const v_width = width(sv); + enode_vector& v_base = m_tmp2; + SASSERT(v_base.empty()); + get_base(var2slice(v), v_base); + + SASSERT(all_of(m_egraph.nodes(), [](enode* n) { return !n->is_marked1(); })); + + // Collect direct sub-slices of v and their equivalences + // (these don't need any extra checks) + for (enode* s = sv; s != nullptr; s = has_sub(s) ? sub_lo(s) : nullptr) { + for (enode* n : euf::enode_class(s)) { + if (!is_proper_slice(n)) + continue; + pvar const w = slice2var(n); + if (w == null_var) + continue; + SASSERT(!n->is_marked1()); + n->mark1(); + out.push_back(w); + } + } + + // lowermost base slice of v + enode* const v_base_lo = v_base.back(); + + svector> candidates; + // Collect all other candidate variables, + // i.e., those who share the lowermost base slice with v. + for (enode* n : euf::enode_class(v_base_lo)) { + if (!is_proper_slice(n)) + continue; + if (n == v_base_lo) + continue; + enode* const n0 = n; + pvar w2 = null_var; // the highest variable we care about from this equivalence class + // examine parents to find variables + SASSERT(!has_sub(n)); + while (true) { + pvar const w = slice2var(n); + if (w != null_var && !n->is_marked1()) + w2 = w; + enode* p = parent(n); + if (!p) + break; + if (sub_lo(p) != n) // we only care about lowermost slices of variables + break; + if (width(p) > v_width) + break; + n = p; + } + if (w2 != null_var) + candidates.push_back({n0, w2}); + } + + // Check candidates + for (auto const& [n0, w2] : candidates) { + // unsigned v_next = v_base.size(); + auto v_it = v_base.rbegin(); + enode* n = n0; + SASSERT_EQ(n->get_root(), (*v_it)->get_root()); + ++v_it; + while (true) { + // here: base of n is equivalent to lower portion of base of v + pvar const w = slice2var(n); + if (w != null_var && !n->is_marked1()) { + n->mark1(); + out.push_back(w); + } + if (w == w2) + break; + // + enode* const p = parent(n); + SASSERT(p); + SASSERT_EQ(sub_lo(p), n); // otherwise not a candidate + // check if base of sub_hi(p) matches the base of v + enode_vector& p_hi_base = m_tmp3; + SASSERT(p_hi_base.empty()); + get_base(sub_hi(p), p_hi_base); + auto p_it = p_hi_base.rbegin(); + bool ok = true; + while (ok && p_it != p_hi_base.rend()) { + if (v_it == v_base.rend()) + ok = false; + else if ((*p_it)->get_root() != (*v_it)->get_root()) + ok = false; + else { + ++p_it; + ++v_it; + } + } + p_hi_base.reset(); + if (!ok) + break; + n = p; + } + } + + v_base.reset(); + for (unsigned i = out.size(); i-- > first_out; ) { + enode* n = var2slice(out[i]); + SASSERT(n->is_marked1()); + n->unmark1(); + } + SASSERT(all_of(m_egraph.nodes(), [](enode* n) { return !n->is_marked1(); })); + } + + void slicing::explain_simple_overlap(pvar v, pvar x, std::function const& on_lit) { + SASSERT(width(var2slice(x)) <= width(var2slice(v))); + SASSERT(m_marked_lits.empty()); + SASSERT(m_tmp_deps.empty()); + + if (v == x) + return; + + enode_vector& v_base = m_tmp4; + SASSERT(v_base.empty()); + get_base(var2slice(v), v_base); + enode_vector& x_base = m_tmp5; + SASSERT(x_base.empty()); + get_base(var2slice(x), x_base); + + auto v_it = v_base.rbegin(); + auto x_it = x_base.rbegin(); + while (x_it != x_base.rend()) { + SASSERT(v_it != v_base.rend()); + enode* nv = *v_it; ++v_it; + enode* nx = *x_it; ++x_it; + SASSERT_EQ(nv->get_root(), nx->get_root()); + explain_equal(nv, nx, m_tmp_deps); + } + + for (void* dp : m_tmp_deps) { + dep_t const d = dep_t::decode(dp); + if (d.is_null()) + continue; + if (d.is_lit()) { + sat::literal lit = d.lit(); + if (m_marked_lits.contains(lit)) + continue; + m_marked_lits.insert(lit); + on_lit(d.lit()); + } + else { + // equivalence between to variables cannot be due to value assignment + UNREACHABLE(); + } + } + m_marked_lits.reset(); + m_tmp_deps.reset(); + } + + void slicing::collect_fixed(pvar v, justified_fixed_bits_vector& out) { + enode_vector& base = m_tmp2; + SASSERT(base.empty()); + get_base(var2slice(v), base); + rational a; + unsigned lo = 0; + for (auto it = base.rbegin(); it != base.rend(); ++it) { + enode* const n = *it; + enode* const nv = get_value_node(n); + unsigned const w = width(n); + unsigned const hi = lo + w - 1; + if (try_get_value(nv, a)) + out.push_back({hi, lo, a, n}); + lo += w; + } + base.reset(); + } + + void slicing::explain_fixed(euf::enode* const n, std::function const& on_lit, std::function const& on_var) { + explain_value(n, on_lit, on_var); + } + + pvar_vector slicing::equivalent_vars(pvar v) const { + pvar_vector xs; + for (enode* n : euf::enode_class(var2slice(v))) { + pvar const x = slice2var(n); + if (x != null_var) + xs.push_back(x); + } + return xs; + } + + std::ostream& slicing::display(std::ostream& out) const { + enode_vector base; + for (pvar v = 0; v < m_var2slice.size(); ++v) { + out << "v" << v << ":"; + base.reset(); + enode* const vs = var2slice(v); + get_base(vs, base); + for (enode* s : base) + display(out << " ", s); + if (enode* vnode = get_value_node(vs)) + out << " [root_value: " << get_value(vnode) << "]"; + out << "\n"; + } + return out; + } + + std::ostream& slicing::display_tree(std::ostream& out) const { + for (pvar v = 0; v < m_var2slice.size(); ++v) { + out << "v" << v << ":\n"; + enode* const s = var2slice(v); + display_tree(out, s, 4, width(s) - 1, 0); + } + out << m_egraph << "\n"; + return out; + } + + std::ostream& slicing::display_tree(std::ostream& out, enode* s, unsigned indent, unsigned hi, unsigned lo) const { + out << std::string(indent, ' ') << "[" << hi << ":" << lo << "]"; + out << " id=" << s->get_id(); + out << " w=" << width(s); + if (slice2var(s) != null_var) + out << " var=v" << slice2var(s); + if (parent(s)) + out << " parent=" << parent(s)->get_id(); + if (!s->is_root()) + out << " root=" << s->get_root_id(); + if (enode* n = get_value_node(s)) + out << " value=" << get_value(n); + // if (is_proper_slice(s)) + // out << " "; + if (is_value(s)) + out << " "; + if (is_concat(s)) + out << " "; + if (is_equality(s)) + out << " "; + out << "\n"; + if (has_sub(s)) { + unsigned cut = info(s).cut; + display_tree(out, sub_hi(s), indent + 4, hi, cut + 1 + lo); + display_tree(out, sub_lo(s), indent + 4, cut + lo, lo); + } + return out; + } + + std::ostream& slicing::display(std::ostream& out, enode* s) const { + out << "{id:" << s->get_id(); + if (is_equality(s)) + return out << ",}"; + out << ",w:" << width(s); + out << ",root:" << s->get_root_id(); + if (slice2var(s) != null_var) + out << ",var:v" << slice2var(s); + if (enode* n = get_value_node(s)) + out << ",value:" << get_value(n); + if (s->interpreted()) + out << ","; + if (is_concat(s)) + out << ","; + out << "}"; + return out; + } + + bool slicing::invariant() const { + VERIFY(m_tmp1.empty()); + VERIFY(m_tmp2.empty()); + VERIFY(m_tmp3.empty()); + if (is_conflict()) // if we break a loop early on conflict, we can't guarantee that all properties are satisfied + return true; + for (enode* s : m_egraph.nodes()) { + // we use equality enodes only to track disequalities + if (s->is_equality()) + continue; + // if the slice is equivalent to a variable, then the variable's slice is in the equivalence class + pvar const v = slice2var(s); + if (v != null_var) { + VERIFY_EQ(var2slice(v)->get_root(), s->get_root()); + } + // if slice has a value, it should be propagated to its sub-slices + if (get_value_node(s) && has_sub(s)) { + VERIFY(get_value_node(sub_hi(s))); + VERIFY(get_value_node(sub_lo(s))); + } + // a base slice is never equivalent to a congruence node + if (is_slice(s) && !has_sub(s)) { + VERIFY(all_of(euf::enode_class(s), [&](enode* n) { return is_slice(n); })); + } + if (is_concat(s)) { + // all concat nodes point to a variable slice + VERIFY(slice2var(s) != null_var); + enode* sv = var2slice(slice2var(s)); // the corresponding variable slice + VERIFY(s != sv); + VERIFY(is_slice(sv)); + VERIFY(s->num_args() >= 2); + } + ///////////////////////////////////////////////////////////////// + // class properties (i.e., skip non-representatives) + if (!s->is_root()) + continue; + bool const sub = has_sub(s); + enode_vector const s_base = get_base(s); + for (enode* n : euf::enode_class(s)) { + // equivalence class only contains slices of equal length + VERIFY_EQ(width(s), width(n)); + // either all nodes in the class have subslices or none do + SASSERT_EQ(sub, has_sub(n)); + // bases of equivalent nodes are equivalent + enode_vector const n_base = get_base(n); + VERIFY_EQ(s_base.size(), n_base.size()); + for (unsigned i = s_base.size(); i-- > 0; ) { + VERIFY_EQ(s_base[i]->get_root(), n_base[i]->get_root()); + } + } + } + return true; + } + +} diff --git a/src/sat/smt/polysat/slicing.h b/src/sat/smt/polysat/slicing.h new file mode 100644 index 000000000..f9f90610b --- /dev/null +++ b/src/sat/smt/polysat/slicing.h @@ -0,0 +1,397 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + polysat slicing (relating variables of different bit-widths by extraction) + +Author: + + Jakob Rath 2023-06-01 + +Notation: + + Let x be a bit-vector of width w. + Let l, h indices such that 0 <= l <= h < w. + Then x[h:l] extracts h - l + 1 bits of x. + Shorthands: + - x[h:] stands for x[h:0], and + - x[:l] stands for x[w-1:l]. + + Example: + 0001[0:] = 1 + 0001[2:0] = 001 + +--*/ +#pragma once +#include "ast/euf/euf_egraph.h" +#include "ast/bv_decl_plugin.h" +#include "math/polysat/types.h" +#include "math/polysat/constraint.h" +#include "math/polysat/fixed_bits.h" +#include + +namespace polysat { + + class solver; + + class slicing final { + + friend class test_slicing; + + public: + using enode = euf::enode; + using enode_vector = euf::enode_vector; + using enode_pair = euf::enode_pair; + using enode_pair_vector = euf::enode_pair_vector; + + private: + class dep_t { + std::variant m_data; + public: + dep_t() { SASSERT(is_null()); } + dep_t(sat::literal l): m_data(l) { SASSERT(l != sat::null_literal); SASSERT_EQ(l, lit()); } + explicit dep_t(unsigned idx): m_data(idx) { SASSERT_EQ(idx, value_idx()); } + bool is_null() const { return std::holds_alternative(m_data); } + bool is_lit() const { return std::holds_alternative(m_data); } + bool is_value() const { return std::holds_alternative(m_data); } + sat::literal lit() const { SASSERT(is_lit()); return *std::get_if(&m_data); } + unsigned value_idx() const { SASSERT(is_value()); return *std::get_if(&m_data); } + bool operator==(dep_t other) const { return m_data == other.m_data; } + bool operator!=(dep_t other) const { return !operator==(other); } + void* encode() const; + static dep_t decode(void* p); + }; + + using dep_vector = svector; + + std::ostream& display(std::ostream& out, dep_t d) const; + + dep_t mk_var_dep(pvar v, enode* s, sat::literal lit); + + pvar_vector m_dep_var; + ptr_vector m_dep_slice; + sat::literal_vector m_dep_lit; // optional, value assignment comes from a literal "x == val" rather than a solver assignment + unsigned_vector m_dep_size_trail; + + pvar get_dep_var(dep_t d) const { return m_dep_var[d.value_idx()]; } + sat::literal get_dep_lit(dep_t d) const { return m_dep_lit[d.value_idx()]; } + enode* get_dep_slice(dep_t d) const { return m_dep_slice[d.value_idx()]; } + + static constexpr unsigned null_cut = std::numeric_limits::max(); + + // We use the following kinds of enodes: + // - proper slices (of variables) + // - value slices + // - interpreted value nodes ... these are short-lived, and only created to immediately trigger a conflict inside the egraph + // - virtual concat(...) expressions + // - equalities between enodes (to track disequalities; currently not represented in slice_info) + struct slice_info { + // Cut point: if not null_cut, the slice s has been subdivided into s[|s|-1:cut+1] and s[cut:0]. + // The cut point is relative to the parent slice (rather than a root variable, which might not be unique) + unsigned cut = null_cut; // cut point, or null_cut if no subslices + pvar var = null_var; // slice is equivalent to this variable, if any (without dependencies) + enode* parent = nullptr; // parent slice, only for proper slices (if not null: s == sub_hi(parent(s)) || s == sub_lo(parent(s))) + enode* slice = nullptr; // if enode corresponds to a concat(...) expression, this field links to the represented slice. + enode* sub_hi = nullptr; // upper subslice s[|s|-1:cut+1] + enode* sub_lo = nullptr; // lower subslice s[cut:0] + enode* value_node = nullptr; // the root of an equivalence class stores the value slice here, if any + + void reset() { *this = slice_info(); } + bool has_sub() const { return !!sub_hi; } + void set_cut(unsigned cut, enode* sub_hi, enode* sub_lo) { this->cut = cut; this->sub_hi = sub_hi; this->sub_lo = sub_lo; } + }; + using slice_info_vector = svector; + + // Return true iff n is either a proper slice or a value slice + bool is_slice(enode* n) const; + + bool is_proper_slice(enode* n) const { return !is_value(n) && is_slice(n); } + bool is_value(enode* n) const; + bool is_concat(enode* n) const; + bool is_equality(enode* n) const { return n->is_equality(); } + + solver& m_solver; + + ast_manager m_ast; + scoped_ptr m_bv; + + euf::egraph m_egraph; + slice_info_vector m_info; // indexed by enode::get_id() + enode_vector m_var2slice; // pvar -> slice + tracked_uint_set m_needs_congruence; // set of pvars that need updated concat(...) expressions + enode* m_disequality_conflict = nullptr; + + // Add an equation v = concat(s1, ..., sn) + // for each variable v with base slices s1, ..., sn + void update_var_congruences(); + void add_var_congruence(pvar v); + void add_var_congruence_if_needed(pvar v); + bool use_var_congruences() const; + + func_decl* mk_concat_decl(ptr_vector const& args); + enode* mk_concat_node(enode_vector const& slices); + enode* mk_concat_node(std::initializer_list slices) { return mk_concat_node(slices.size(), slices.begin()); } + enode* mk_concat_node(unsigned num_slices, enode* const* slices); + // Add s = concat(s1, ..., sn) + void add_concat_node(enode* s, enode* concat); + + slice_info& info(euf::enode* n); + slice_info const& info(euf::enode* n) const; + + enode* alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var); + enode* find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var); + enode* alloc_slice(unsigned width, pvar var = null_var); + enode* find_or_alloc_disequality(enode* x, enode* y, sat::literal lit); + + // Find hi, lo such that s = a[hi:lo] + bool find_range_in_ancestor(enode* s, enode* a, unsigned& out_hi, unsigned& out_lo); + + enode* var2slice(pvar v) const { return m_var2slice[v]; } + pvar slice2var(enode* s) const { return info(s).var; } + + unsigned width(enode* s) const; + + enode* parent(enode* s) const { return info(s).parent; } + + enode* get_value_node(enode* s) const { return info(s).value_node; } + void set_value_node(enode* s, enode* value_node); + + unsigned get_cut(enode* s) const { return info(s).cut; } + + bool has_sub(enode* s) const { return info(s).has_sub(); } + + /// Upper subslice (direct child, not necessarily the representative) + enode* sub_hi(enode* s) const { return info(s).sub_hi; } + + /// Lower subslice (direct child, not necessarily the representative) + enode* sub_lo(enode* s) const { return info(s).sub_lo; } + + /// sub_lo(parent(s)) or sub_hi(parent(s)), whichever is different from s. + enode* sibling(enode* s) const; + + // Retrieve (or create) a slice representing the given value. + enode* mk_value_slice(rational const& val, unsigned bit_width); + + // Turn value node into unwrapped BV constant node + enode* mk_interpreted_value_node(enode* value_slice); + + rational get_value(enode* s) const; + bool try_get_value(enode* s, rational& val) const; + + /// Split slice s into s[|s|-1:cut+1] and s[cut:0] + void split(enode* s, unsigned cut); + void split_core(enode* s, unsigned cut); + + /// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... ++ s_n (actual descendant subslices) + void get_base(enode* src, enode_vector& out_base) const; + enode_vector get_base(enode* src) const; + + /// Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n. + /// If output_full_src is true, return the new base for src, i.e., src == s_1 ++ ... ++ s_n. + /// If output_base is false, return coarsest intermediate slices instead of only base slices. + void mk_slice(enode* src, unsigned hi, unsigned lo, enode_vector& out, bool output_full_src = false, bool output_base = true); + + // Extract reason why slices x and y are in the same equivalence class + void explain_class(enode* x, enode* y, ptr_vector& out_deps); + + // Extract reason why slices x and y are equal + // (i.e., x and y have the same base, but are not necessarily in the same equivalence class) + void explain_equal(enode* x, enode* y, ptr_vector& out_deps); + + /** Explain why slice is equivalent to a value */ + void explain_value(enode* s, std::function const& on_lit, std::function const& on_var); + + /** Extract reason for conflict */ + void explain(ptr_vector& out_deps); + + /** Extract reason for x == y */ + void explain_equal(pvar x, pvar y, ptr_vector& out_deps); + + void egraph_on_make(enode* n); + void egraph_on_merge(enode* root, enode* other); + void egraph_on_propagate(enode* lit, enode* ante); + + // Merge slices in the e-graph. + bool egraph_merge(enode* s1, enode* s2, dep_t dep); + + // Merge equivalence classes of two base slices. + // Returns true if merge succeeded without conflict. + [[nodiscard]] bool merge_base(enode* s1, enode* s2, dep_t dep); + + // Merge equality x_1 ++ ... ++ x_n == y_1 ++ ... ++ y_k + // + // Precondition: + // - sequence of slices with equal total width + // - ordered from msb to lsb + // + // The argument vectors will be cleared. + // + // Returns true if merge succeeded without conflict. + [[nodiscard]] bool merge(enode_vector& xs, enode_vector& ys, dep_t dep); + [[nodiscard]] bool merge(enode_vector& xs, enode* y, dep_t dep); + [[nodiscard]] bool merge(enode* x, enode* y, dep_t dep); + + // Check whether two slices are known to be equal + bool is_equal(enode* x, enode* y); + + // deduplication of extract terms + struct extract_args { + pvar src = null_var; + unsigned hi = 0; + unsigned lo = 0; + bool operator==(extract_args const& other) const { return src == other.src && hi == other.hi && lo == other.lo; } + unsigned hash() const { return mk_mix(src, hi, lo); } + }; + using extract_args_eq = default_eq; + using extract_args_hash = obj_hash; + using extract_map = map; + extract_map m_extract_dedup; + // svector m_extract_origin; // pvar -> extract_args + // TODO: add 'm_extract_origin' (pvar -> extract_args)? 1. for dependency tracking when sharing subslice trees; 2. for easily checking if a variable is an extraction of another; 3. also makes the replay easier + // bool is_extract(pvar v) const { return m_extract_origin[v].src != null_var; } + + enum class trail_item : std::uint8_t { + add_var, + split_core, + mk_extract, + mk_concat, + set_value_node, + }; + svector m_trail; + enode_vector m_enode_trail; + svector m_extract_trail; + unsigned_vector m_scopes; + + struct concat_info { + pvar v; + unsigned num_args; + unsigned args_idx; + unsigned next_args_idx() const { return args_idx + num_args; } + }; + svector m_concat_trail; + svector m_concat_args; + + void undo_add_var(); + void undo_split_core(); + void undo_mk_extract(); + void undo_set_value_node(); + + mutable enode_vector m_tmp1; + mutable enode_vector m_tmp2; + mutable enode_vector m_tmp3; + mutable enode_vector m_tmp4; + mutable enode_vector m_tmp5; + ptr_vector m_tmp_deps; + sat::literal_set m_marked_lits; + + /** Get variable representing src[hi:lo] */ + pvar mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var = null_var); + /** Restore r = src[hi:lo] */ + void replay_extract(extract_args const& args, pvar r); + + pvar mk_concat(unsigned num_args, pvar const* args, pvar replay_var); + void replay_concat(unsigned num_args, pvar const* args, pvar r); + + bool add_constraint_eq(pdd const& p, sat::literal lit); + bool add_equation(pvar x, pdd const& body, sat::literal lit); + bool add_value(pvar v, rational const& value, sat::literal lit); + bool add_fixed_bits(signed_constraint c); + bool add_fixed_bits(pvar x, unsigned hi, unsigned lo, rational const& value, sat::literal lit); + + bool invariant() const; + bool invariant_needs_congruence() const; + + std::ostream& display(std::ostream& out, enode* s) const; + std::ostream& display_tree(std::ostream& out, enode* s, unsigned indent, unsigned hi, unsigned lo) const; + + class slice_pp { + slicing const& s; + enode* n; + public: + slice_pp(slicing const& s, enode* n): s(s), n(n) {} + std::ostream& display(std::ostream& out) const { return s.display(out, n); } + }; + friend std::ostream& operator<<(std::ostream& out, slice_pp const& s) { return s.display(out); } + + class dep_pp { + slicing const& s; + dep_t d; + public: + dep_pp(slicing const& s, dep_t d): s(s), d(d) {} + std::ostream& display(std::ostream& out) const { return s.display(out, d); } + }; + friend std::ostream& operator<<(std::ostream& out, dep_pp const& d) { return d.display(out); } + + euf::egraph::e_pp e_pp(enode* n) const { return euf::egraph::e_pp(m_egraph, n); } + + public: + slicing(solver& s); + + void push_scope(); + void pop_scope(unsigned num_scopes = 1); + + void add_var(unsigned bit_width); + + /** Get or create variable representing x[hi:lo] */ + pvar mk_extract(pvar x, unsigned hi, unsigned lo); + + /** Get or create variable representing x1 ++ x2 ++ ... ++ xn */ + pvar mk_concat(unsigned num_args, pvar const* args) { return mk_concat(num_args, args, null_var); } + pvar mk_concat(std::initializer_list args); + + // Find hi, lo such that x = src[hi:lo]. + bool is_extract(pvar x, pvar src, unsigned& out_hi, unsigned& out_lo); + + // Track value assignments to variables (and propagate to subslices) + void add_value(pvar v, rational const& value); + void add_value(pvar v, unsigned value) { add_value(v, rational(value)); } + void add_value(pvar v, int value) { add_value(v, rational(value)); } + void add_constraint(signed_constraint c); + + bool can_propagate() const; + + // update congruences, egraph + void propagate(); + + bool is_conflict() const { return m_disequality_conflict || m_egraph.inconsistent(); } + + /** Extract conflict clause */ + clause_ref build_conflict_clause(); + + /** Explain why slicing has propagated the value assignment for v */ + void explain_value(pvar v, std::function const& on_lit, std::function const& on_var); + + /** For a given variable v, find the set of variables w such that w = v[|w|:0]. */ + void collect_simple_overlaps(pvar v, pvar_vector& out); + void explain_simple_overlap(pvar v, pvar x, std::function const& on_lit); + + struct justified_fixed_bits : public fixed_bits { + enode* just; + + justified_fixed_bits(unsigned hi, unsigned lo, rational value, enode* just): fixed_bits(hi, lo, value), just(just) {} + }; + + using justified_fixed_bits_vector = vector; + + /** Collect fixed portions of the variable v */ + void collect_fixed(pvar v, justified_fixed_bits_vector& out); + void explain_fixed(enode* just, std::function const& on_lit, std::function const& on_var); + + /** + * Collect variables that are equivalent to v (including v itself) + * + * NOTE: this might miss some variables that are equal due to equivalent base slices. With 'polysat.slicing.congruence=true' and after propagate(), it should return all equal variables. + */ + pvar_vector equivalent_vars(pvar v) const; + + /** Explain why variables x and y are equivalent */ + void explain_equal(pvar x, pvar y, std::function const& on_lit); + + std::ostream& display(std::ostream& out) const; + std::ostream& display_tree(std::ostream& out) const; + }; + + inline std::ostream& operator<<(std::ostream& out, slicing const& s) { return s.display(out); } + +} diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 9e57b8cae..999db42ea 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -254,4 +254,46 @@ namespace polysat { return expr_ref(bv.mk_bv_add(lo, hi), m); } + // walk the egraph starting with pvar for overlaps. + void solver::get_bitvector_prefixes(pvar pv, pvar_vector& out) { + theory_var v = m_pddvar2var[pv]; + euf::enode_vector todo; + uint_set seen; + unsigned lo, hi; + expr* e = nullptr; + todo.push_back(var2enode(v)); + for (unsigned i = 0; i < todo.size(); ++i) { + auto n = todo[i]->get_root(); + if (n->is_marked1()) + continue; + n->mark1(); + for (auto sib : euf::enode_class(n)) { + theory_var w = sib->get_th_var(get_id()); + + // identify prefixes + if (bv.is_concat(sib->get_expr())) + todo.push_back(sib->get_arg(0)); + if (w == euf::null_theory_var) + continue; + if (seen.contains(w)) + continue; + seen.insert(w); + auto const& p = m_var2pdd[w]; + if (p.is_var()) + out.push_back(p.var()); + } + for (auto p : euf::enode_parents(n)) { + if (p->is_marked1()) + continue; + // find prefixes: e[hi:0] a parent of n + if (bv.is_extract(p->get_expr(), lo, hi, e) && lo == 0) { + SASSERT(n == expr2enode(e)->get_root()); + todo.push_back(p); + } + } + } + for (auto n : todo) + n->get_root()->unmark1(); + } + } diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 8489619da..56ab615dc 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -133,6 +133,7 @@ namespace polysat { void propagate(dependency const& d, bool sign, dependency_vector const& deps) override; trail_stack& trail() override; bool inconsistent() const override; + void get_bitvector_prefixes(pvar v, pvar_vector& out) override; void add_lemma(vector const& lemma);