3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 08:35:31 +00:00

updates to viable

This commit is contained in:
Nikolaj Bjorner 2023-12-09 12:08:02 -08:00
parent e9c86bf3a3
commit 30c874d301
9 changed files with 2539 additions and 30 deletions

View file

@ -278,6 +278,14 @@ namespace polysat {
}
}
void core::get_bitvector_prefixes(pvar v, pvar_vector& out) {
s.get_bitvector_prefixes(v, out);
}
bool core::inconsistent() const {
return s.inconsistent();
}
void core::propagate_unsat_core() {
// default is to use unsat core:
// if core is based on viable, use s.set_lemma();

View file

@ -79,6 +79,9 @@ namespace polysat {
void propagate_assignment(pvar v, rational const& value, dependency dep);
void propagate_unsat_core();
void get_bitvector_prefixes(pvar v, pvar_vector& out);
bool inconsistent() const;
void add_watch(unsigned idx, unsigned var);
signed_constraint get_constraint(unsigned idx, bool sign);
@ -89,8 +92,7 @@ namespace polysat {
public:
core(solver_interface& s);
sat::check_result check();
sat::check_result check();
unsigned register_constraint(signed_constraint& sc, dependency d);
bool propagate();
void assign_eh(unsigned idx, bool sign, dependency const& d);

View file

@ -10,6 +10,9 @@ Author:
#pragma once
#include <variant>
#include "math/dd/dd_pdd.h"
#include "util/trail.h"
#include "util/sat_literal.h"
@ -62,6 +65,7 @@ namespace polysat {
virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0;
virtual trail_stack& trail() = 0;
virtual bool inconsistent() const = 0;
virtual void get_bitvector_prefixes(pvar v, pvar_vector& out) = 0;
};
}

View file

@ -98,42 +98,31 @@ namespace polysat {
#endif
pvar_vector overlaps;
#if 0
// TODO s.m_slicing.collect_simple_overlaps(v, overlaps);
#endif
c.get_bitvector_prefixes(v, overlaps);
std::sort(overlaps.begin(), overlaps.end(), [&](pvar x, pvar y) { return c.size(x) > c.size(y); });
uint_set widths_set;
// max size should always be present, regardless of whether we have intervals there (to make sure all fixed bits are considered)
widths_set.insert(c.size(v));
#if 0
LOG("Overlaps with v" << v << ":");
for (pvar x : overlaps) {
unsigned hi, lo;
if (s.m_slicing.is_extract(x, v, hi, lo))
LOG(" v" << x << " = v" << v << "[" << hi << ":" << lo << "]");
else
LOG(" v" << x << " not extracted from v" << v << "; size " << s.size(x));
for (layer const& l : m_units[x].get_layers()) {
widths_set.insert(l.bit_width);
}
}
#endif
for (pvar v : overlaps)
ensure_var(v);
for (pvar v : overlaps)
for (layer const& l : m_units[v].get_layers())
widths_set.insert(l.bit_width);
unsigned_vector widths;
for (unsigned w : widths_set)
widths.push_back(w);
widths.push_back(w);
LOG("Overlaps with v" << v << ":" << overlaps);
LOG("widths: " << widths);
rational const& max_value = c.var2pdd(v).max_value();
#if 0
lbool result_lo = find_on_layers(v, widths, overlaps, fbi, rational::zero(), max_value, lo);
if (result_lo == l_false)
return l_false; // conflict
if (result_lo == l_undef)
return find_viable_fallback(v, overlaps, lo, hi);
if (result_lo != l_true)
return result_lo;
if (lo == max_value) {
hi = lo;
@ -141,12 +130,294 @@ namespace polysat {
}
lbool result_hi = find_on_layers(v, widths, overlaps, fbi, lo + 1, max_value, hi);
if (result_hi == l_false)
hi = lo; // no other viable value
if (result_hi == l_undef)
return find_viable_fallback(v, overlaps, lo, hi);
#endif
return l_true;
switch (result_hi) {
case l_false:
hi = lo;
return l_true;
case l_undef:
return l_undef;
default:
return l_true;
}
}
// l_true ... found viable value
// l_false ... no viable value in [to_cover_lo;to_cover_hi[
// l_undef ... out of resources
lbool viable::find_on_layers(
pvar const v,
unsigned_vector const& widths,
pvar_vector const& overlaps,
fixed_bits_info const& fbi,
rational const& to_cover_lo,
rational const& to_cover_hi,
rational& val
) {
ptr_vector<entry> refine_todo;
ptr_vector<entry> relevant_entries;
// max number of interval refinements before falling back to the univariate solver
unsigned const refinement_budget = 100;
unsigned refinements = refinement_budget;
while (refinements--) {
relevant_entries.clear();
lbool result = find_on_layer(v, widths.size() - 1, widths, overlaps, fbi, to_cover_lo, to_cover_hi, val, refine_todo, relevant_entries);
// store bit-intervals we have used
for (entry* e : refine_todo)
intersect(v, e);
refine_todo.clear();
if (result != l_true)
return l_false;
// overlaps are sorted by variable size in descending order
// start refinement on smallest variable
// however, we probably should rotate to avoid getting stuck in refinement loop on a 'bad' constraint
bool refined = false;
for (unsigned i = overlaps.size(); i-- > 0; ) {
pvar x = overlaps[i];
rational const& mod_value = c.var2pdd(x).two_to_N();
rational x_val = mod(val, mod_value);
if (!refine_viable(x, x_val)) {
refined = true;
break;
}
}
if (!refined)
return l_true;
}
LOG("Refinement budget exhausted! Fall back to univariate solver.");
return l_undef;
}
// find viable values in half-open interval [to_cover_lo;to_cover_hi[ w.r.t. unit intervals on the given layer
//
// Returns:
// - l_true ... found value that is viable w.r.t. units and fixed bits
// - l_false ... found conflict
// - l_undef ... found no viable value in target interval [to_cover_lo;to_cover_hi[
lbool viable::find_on_layer(
pvar const v,
unsigned const w_idx,
unsigned_vector const& widths,
pvar_vector const& overlaps,
fixed_bits_info const& fbi,
rational const& to_cover_lo,
rational const& to_cover_hi,
rational& val,
ptr_vector<entry>& refine_todo,
ptr_vector<entry>& relevant_entries
) {
unsigned const w = widths[w_idx];
rational const& mod_value = rational::power_of_two(w);
unsigned const first_relevant_for_conflict = relevant_entries.size();
LOG("layer " << w << " bits, to_cover: [" << to_cover_lo << "; " << to_cover_hi << "[");
SASSERT(0 <= to_cover_lo);
SASSERT(0 <= to_cover_hi);
SASSERT(to_cover_lo < mod_value);
SASSERT(to_cover_hi <= mod_value); // full interval if to_cover_hi == mod_value
SASSERT(to_cover_lo != to_cover_hi); // non-empty search domain (but it may wrap)
// TODO: refinement of eq/diseq should happen only on the correct layer: where (one of) the coefficient(s) are odd
// for refinement, we have to choose an entry that is violated, but if there are multiple, we can choose the one on smallest domain (lowest bit-width).
// (by maintaining descending order by bit-width; and refine first that fails).
// but fixed-bit-refinement is cheap and could be done during the search.
// when we arrive at a hole the possibilities are:
// 1) go to lower bitwidth
// 2) refinement of some eq/diseq constraint (if one is violated at that point) -- defer this until point is viable for all layers and fixed bits.
// 3) refinement by using bit constraints? -- TODO: do this during search (another interval check after/before the entry_cursors)
// 4) (point is actually feasible)
// a complication is that we have to iterate over multiple lists of intervals.
// might be useful to merge them upfront to simplify the remainder of the algorithm?
// (non-trivial since prev/next pointers are embedded into entries and lists are updated by refinement)
struct entry_cursor {
entry* cur;
// entry* first;
// entry* last;
};
// find relevant interval lists
svector<entry_cursor> ecs;
for (pvar x : overlaps) {
if (c.size(x) < w) // note that overlaps are sorted by variable size descending
break;
if (entry* e = m_units[x].get_entries(w)) {
display_all(std::cerr << "units for width " << w << ":\n", 0, e, "\n");
entry_cursor ec;
ec.cur = e; // TODO: e->prev() probably makes it faster when querying 0 (can often save going around the interval list once)
// ec.first = nullptr;
// ec.last = nullptr;
ecs.push_back(ec);
}
}
rational const to_cover_len = r_interval::len(to_cover_lo, to_cover_hi, mod_value);
val = to_cover_lo;
rational progress; // = 0
SASSERT(progress.is_zero());
while (true) {
while (true) {
entry* e = nullptr;
// try to make progress using any of the relevant interval lists
for (entry_cursor& ec : ecs) {
// advance until current value 'val'
auto const [n, n_contains_val] = find_value(val, ec.cur);
// display_one(std::cerr << "found entry n: ", 0, n) << "\n";
// LOG("n_contains_val: " << n_contains_val << " val = " << val);
ec.cur = n;
if (n_contains_val) {
e = n;
break;
}
}
// when we cannot make progress by existing intervals any more, try interval from fixed bits
if (!e) {
e = refine_bits<true>(v, val, w, fbi);
if (e) {
refine_todo.push_back(e);
display_one(std::cerr << "found entry by bits: ", 0, e) << "\n";
}
}
// no more progress on current layer
if (!e)
break;
relevant_entries.push_back(e);
if (e->interval.is_full()) {
relevant_entries.clear();
relevant_entries.push_back(e); // full interval e -> all other intervals are subsumed/irrelevant
set_conflict_by_interval(v, w, relevant_entries, 0);
return l_false;
}
SASSERT(e->interval.currently_contains(val));
rational const& new_val = e->interval.hi_val();
rational const dist = distance(val, new_val, mod_value);
SASSERT(dist > 0);
val = new_val;
progress += dist;
LOG("val: " << val << " progress: " << progress);
if (progress >= mod_value) {
// covered the whole domain => conflict
set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict);
return l_false;
}
if (progress >= to_cover_len) {
// we covered the hole left at larger bit-width
// TODO: maybe we want to keep trying a bit longer to see if we can cover the whole domain. or maybe only if we enter this layer multiple times.
return l_undef;
}
// (another way to compute 'progress')
SASSERT_EQ(progress, distance(to_cover_lo, val, mod_value));
}
// no more progress
// => 'val' is viable w.r.t. unit intervals until current layer
if (!w_idx) {
// we are at the lowest layer
// => found viable value w.r.t. unit intervals and fixed bits
return l_true;
}
// find next covered value
rational next_val = to_cover_hi;
for (entry_cursor& ec : ecs) {
// each ec.cur is now the next interval after 'lo'
rational const& n = ec.cur->interval.lo_val();
SASSERT(r_interval::contains(ec.cur->prev()->interval.hi_val(), n, val));
if (distance(val, n, mod_value) < distance(val, next_val, mod_value))
next_val = n;
}
if (entry* e = refine_bits<false>(v, next_val, w, fbi)) {
refine_todo.push_back(e);
rational const& n = e->interval.lo_val();
SASSERT(distance(val, n, mod_value) < distance(val, next_val, mod_value));
next_val = n;
}
SASSERT(!refine_bits<true>(v, val, w, fbi));
SASSERT(val != next_val);
unsigned const lower_w = widths[w_idx - 1];
rational const lower_mod_value = rational::power_of_two(lower_w);
rational lower_cover_lo, lower_cover_hi;
if (distance(val, next_val, mod_value) >= lower_mod_value) {
// NOTE: in this case we do not get the first viable value, but the one with smallest value in the lower bits.
// this is because we start the search in the recursive case at 0.
// if this is a problem, adapt to lower_cover_lo = mod(val, lower_mod_value), lower_cover_hi = ...
lower_cover_lo = 0;
lower_cover_hi = lower_mod_value;
rational a;
lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries);
VERIFY(result != l_undef);
if (result == l_false) {
SASSERT(c.inconsistent());
return l_false; // conflict
}
SASSERT(result == l_true);
// replace lower bits of 'val' by 'a'
rational const val_lower = mod(val, lower_mod_value);
val = val - val_lower + a;
if (a < val_lower)
a += lower_mod_value;
LOG("distance(val, cover_hi) = " << distance(val, to_cover_hi, mod_value));
LOG("distance(next_val, cover_hi) = " << distance(next_val, to_cover_hi, mod_value));
SASSERT(distance(val, to_cover_hi, mod_value) >= distance(next_val, to_cover_hi, mod_value));
return l_true;
}
lower_cover_lo = mod(val, lower_mod_value);
lower_cover_hi = mod(next_val, lower_mod_value);
rational a;
lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries);
if (result == l_false) {
SASSERT(c.inconsistent());
return l_false; // conflict
}
// replace lower bits of 'val' by 'a'
rational const dist = distance(lower_cover_lo, a, lower_mod_value);
val += dist;
progress += dist;
LOG("distance(val, cover_hi) = " << distance(val, to_cover_hi, mod_value));
LOG("distance(next_val, cover_hi) = " << distance(next_val, to_cover_hi, mod_value));
SASSERT(distance(val, to_cover_hi, mod_value) >= distance(next_val, to_cover_hi, mod_value));
if (result == l_true)
return l_true; // done
SASSERT(result == l_undef);
if (progress >= mod_value) {
// covered the whole domain => conflict
set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict);
return l_false;
}
if (progress >= to_cover_len) {
// we covered the hole left at larger bit-width
return l_undef;
}
}
UNREACHABLE();
return l_undef;
}

View file

@ -162,6 +162,63 @@ namespace polysat {
lbool find_viable(pvar v, rational& lo, rational& hi);
lbool find_on_layers(
pvar v,
unsigned_vector const& widths,
pvar_vector const& overlaps,
fixed_bits_info const& fbi,
rational const& to_cover_lo,
rational const& to_cover_hi,
rational& out_val);
lbool find_on_layer(
pvar v,
unsigned w_idx,
unsigned_vector const& widths,
pvar_vector const& overlaps,
fixed_bits_info const& fbi,
rational const& to_cover_lo,
rational const& to_cover_hi,
rational& out_val,
ptr_vector<entry>& refine_todo,
ptr_vector<entry>& relevant_entries);
template <bool FORWARD>
bool refine_viable(pvar v, rational const& val, fixed_bits_info const& fbi) {
throw default_exception("nyi");
}
bool refine_viable(pvar v, rational const& val) {
throw default_exception("nyi");
}
template <bool FORWARD>
bool refine_bits(pvar v, rational const& val, fixed_bits_info const& fbi) {
throw default_exception("nyi");
}
template <bool FORWARD>
entry* refine_bits(pvar v, rational const& val, unsigned num_bits, fixed_bits_info const& fbi) {
throw default_exception("nyi");
}
bool refine_equal_lin(pvar v, rational const& val) {
throw default_exception("nyi");
}
bool refine_disequal_lin(pvar v, rational const& val) {
throw default_exception("nyi");
}
bool set_conflict_by_interval(pvar v, unsigned w, ptr_vector<entry>& intervals, unsigned first_interval) {
throw default_exception("nyi");
}
std::pair<entry*, bool> find_value(rational const& val, entry* entries) {
throw default_exception("nyi");
}
public:
viable(core& c);

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,397 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
polysat slicing (relating variables of different bit-widths by extraction)
Author:
Jakob Rath 2023-06-01
Notation:
Let x be a bit-vector of width w.
Let l, h indices such that 0 <= l <= h < w.
Then x[h:l] extracts h - l + 1 bits of x.
Shorthands:
- x[h:] stands for x[h:0], and
- x[:l] stands for x[w-1:l].
Example:
0001[0:] = 1
0001[2:0] = 001
--*/
#pragma once
#include "ast/euf/euf_egraph.h"
#include "ast/bv_decl_plugin.h"
#include "math/polysat/types.h"
#include "math/polysat/constraint.h"
#include "math/polysat/fixed_bits.h"
#include <variant>
namespace polysat {
class solver;
class slicing final {
friend class test_slicing;
public:
using enode = euf::enode;
using enode_vector = euf::enode_vector;
using enode_pair = euf::enode_pair;
using enode_pair_vector = euf::enode_pair_vector;
private:
class dep_t {
std::variant<std::monostate, sat::literal, unsigned> m_data;
public:
dep_t() { SASSERT(is_null()); }
dep_t(sat::literal l): m_data(l) { SASSERT(l != sat::null_literal); SASSERT_EQ(l, lit()); }
explicit dep_t(unsigned idx): m_data(idx) { SASSERT_EQ(idx, value_idx()); }
bool is_null() const { return std::holds_alternative<std::monostate>(m_data); }
bool is_lit() const { return std::holds_alternative<sat::literal>(m_data); }
bool is_value() const { return std::holds_alternative<unsigned>(m_data); }
sat::literal lit() const { SASSERT(is_lit()); return *std::get_if<sat::literal>(&m_data); }
unsigned value_idx() const { SASSERT(is_value()); return *std::get_if<unsigned>(&m_data); }
bool operator==(dep_t other) const { return m_data == other.m_data; }
bool operator!=(dep_t other) const { return !operator==(other); }
void* encode() const;
static dep_t decode(void* p);
};
using dep_vector = svector<dep_t>;
std::ostream& display(std::ostream& out, dep_t d) const;
dep_t mk_var_dep(pvar v, enode* s, sat::literal lit);
pvar_vector m_dep_var;
ptr_vector<enode> m_dep_slice;
sat::literal_vector m_dep_lit; // optional, value assignment comes from a literal "x == val" rather than a solver assignment
unsigned_vector m_dep_size_trail;
pvar get_dep_var(dep_t d) const { return m_dep_var[d.value_idx()]; }
sat::literal get_dep_lit(dep_t d) const { return m_dep_lit[d.value_idx()]; }
enode* get_dep_slice(dep_t d) const { return m_dep_slice[d.value_idx()]; }
static constexpr unsigned null_cut = std::numeric_limits<unsigned>::max();
// We use the following kinds of enodes:
// - proper slices (of variables)
// - value slices
// - interpreted value nodes ... these are short-lived, and only created to immediately trigger a conflict inside the egraph
// - virtual concat(...) expressions
// - equalities between enodes (to track disequalities; currently not represented in slice_info)
struct slice_info {
// Cut point: if not null_cut, the slice s has been subdivided into s[|s|-1:cut+1] and s[cut:0].
// The cut point is relative to the parent slice (rather than a root variable, which might not be unique)
unsigned cut = null_cut; // cut point, or null_cut if no subslices
pvar var = null_var; // slice is equivalent to this variable, if any (without dependencies)
enode* parent = nullptr; // parent slice, only for proper slices (if not null: s == sub_hi(parent(s)) || s == sub_lo(parent(s)))
enode* slice = nullptr; // if enode corresponds to a concat(...) expression, this field links to the represented slice.
enode* sub_hi = nullptr; // upper subslice s[|s|-1:cut+1]
enode* sub_lo = nullptr; // lower subslice s[cut:0]
enode* value_node = nullptr; // the root of an equivalence class stores the value slice here, if any
void reset() { *this = slice_info(); }
bool has_sub() const { return !!sub_hi; }
void set_cut(unsigned cut, enode* sub_hi, enode* sub_lo) { this->cut = cut; this->sub_hi = sub_hi; this->sub_lo = sub_lo; }
};
using slice_info_vector = svector<slice_info>;
// Return true iff n is either a proper slice or a value slice
bool is_slice(enode* n) const;
bool is_proper_slice(enode* n) const { return !is_value(n) && is_slice(n); }
bool is_value(enode* n) const;
bool is_concat(enode* n) const;
bool is_equality(enode* n) const { return n->is_equality(); }
solver& m_solver;
ast_manager m_ast;
scoped_ptr<bv_util> m_bv;
euf::egraph m_egraph;
slice_info_vector m_info; // indexed by enode::get_id()
enode_vector m_var2slice; // pvar -> slice
tracked_uint_set m_needs_congruence; // set of pvars that need updated concat(...) expressions
enode* m_disequality_conflict = nullptr;
// Add an equation v = concat(s1, ..., sn)
// for each variable v with base slices s1, ..., sn
void update_var_congruences();
void add_var_congruence(pvar v);
void add_var_congruence_if_needed(pvar v);
bool use_var_congruences() const;
func_decl* mk_concat_decl(ptr_vector<expr> const& args);
enode* mk_concat_node(enode_vector const& slices);
enode* mk_concat_node(std::initializer_list<enode*> slices) { return mk_concat_node(slices.size(), slices.begin()); }
enode* mk_concat_node(unsigned num_slices, enode* const* slices);
// Add s = concat(s1, ..., sn)
void add_concat_node(enode* s, enode* concat);
slice_info& info(euf::enode* n);
slice_info const& info(euf::enode* n) const;
enode* alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var);
enode* find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var);
enode* alloc_slice(unsigned width, pvar var = null_var);
enode* find_or_alloc_disequality(enode* x, enode* y, sat::literal lit);
// Find hi, lo such that s = a[hi:lo]
bool find_range_in_ancestor(enode* s, enode* a, unsigned& out_hi, unsigned& out_lo);
enode* var2slice(pvar v) const { return m_var2slice[v]; }
pvar slice2var(enode* s) const { return info(s).var; }
unsigned width(enode* s) const;
enode* parent(enode* s) const { return info(s).parent; }
enode* get_value_node(enode* s) const { return info(s).value_node; }
void set_value_node(enode* s, enode* value_node);
unsigned get_cut(enode* s) const { return info(s).cut; }
bool has_sub(enode* s) const { return info(s).has_sub(); }
/// Upper subslice (direct child, not necessarily the representative)
enode* sub_hi(enode* s) const { return info(s).sub_hi; }
/// Lower subslice (direct child, not necessarily the representative)
enode* sub_lo(enode* s) const { return info(s).sub_lo; }
/// sub_lo(parent(s)) or sub_hi(parent(s)), whichever is different from s.
enode* sibling(enode* s) const;
// Retrieve (or create) a slice representing the given value.
enode* mk_value_slice(rational const& val, unsigned bit_width);
// Turn value node into unwrapped BV constant node
enode* mk_interpreted_value_node(enode* value_slice);
rational get_value(enode* s) const;
bool try_get_value(enode* s, rational& val) const;
/// Split slice s into s[|s|-1:cut+1] and s[cut:0]
void split(enode* s, unsigned cut);
void split_core(enode* s, unsigned cut);
/// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... ++ s_n (actual descendant subslices)
void get_base(enode* src, enode_vector& out_base) const;
enode_vector get_base(enode* src) const;
/// Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n.
/// If output_full_src is true, return the new base for src, i.e., src == s_1 ++ ... ++ s_n.
/// If output_base is false, return coarsest intermediate slices instead of only base slices.
void mk_slice(enode* src, unsigned hi, unsigned lo, enode_vector& out, bool output_full_src = false, bool output_base = true);
// Extract reason why slices x and y are in the same equivalence class
void explain_class(enode* x, enode* y, ptr_vector<void>& out_deps);
// Extract reason why slices x and y are equal
// (i.e., x and y have the same base, but are not necessarily in the same equivalence class)
void explain_equal(enode* x, enode* y, ptr_vector<void>& out_deps);
/** Explain why slice is equivalent to a value */
void explain_value(enode* s, std::function<void(sat::literal)> const& on_lit, std::function<void(pvar)> const& on_var);
/** Extract reason for conflict */
void explain(ptr_vector<void>& out_deps);
/** Extract reason for x == y */
void explain_equal(pvar x, pvar y, ptr_vector<void>& out_deps);
void egraph_on_make(enode* n);
void egraph_on_merge(enode* root, enode* other);
void egraph_on_propagate(enode* lit, enode* ante);
// Merge slices in the e-graph.
bool egraph_merge(enode* s1, enode* s2, dep_t dep);
// Merge equivalence classes of two base slices.
// Returns true if merge succeeded without conflict.
[[nodiscard]] bool merge_base(enode* s1, enode* s2, dep_t dep);
// Merge equality x_1 ++ ... ++ x_n == y_1 ++ ... ++ y_k
//
// Precondition:
// - sequence of slices with equal total width
// - ordered from msb to lsb
//
// The argument vectors will be cleared.
//
// Returns true if merge succeeded without conflict.
[[nodiscard]] bool merge(enode_vector& xs, enode_vector& ys, dep_t dep);
[[nodiscard]] bool merge(enode_vector& xs, enode* y, dep_t dep);
[[nodiscard]] bool merge(enode* x, enode* y, dep_t dep);
// Check whether two slices are known to be equal
bool is_equal(enode* x, enode* y);
// deduplication of extract terms
struct extract_args {
pvar src = null_var;
unsigned hi = 0;
unsigned lo = 0;
bool operator==(extract_args const& other) const { return src == other.src && hi == other.hi && lo == other.lo; }
unsigned hash() const { return mk_mix(src, hi, lo); }
};
using extract_args_eq = default_eq<extract_args>;
using extract_args_hash = obj_hash<extract_args>;
using extract_map = map<extract_args, pvar, extract_args_hash, extract_args_eq>;
extract_map m_extract_dedup;
// svector<extract_args> m_extract_origin; // pvar -> extract_args
// TODO: add 'm_extract_origin' (pvar -> extract_args)? 1. for dependency tracking when sharing subslice trees; 2. for easily checking if a variable is an extraction of another; 3. also makes the replay easier
// bool is_extract(pvar v) const { return m_extract_origin[v].src != null_var; }
enum class trail_item : std::uint8_t {
add_var,
split_core,
mk_extract,
mk_concat,
set_value_node,
};
svector<trail_item> m_trail;
enode_vector m_enode_trail;
svector<extract_args> m_extract_trail;
unsigned_vector m_scopes;
struct concat_info {
pvar v;
unsigned num_args;
unsigned args_idx;
unsigned next_args_idx() const { return args_idx + num_args; }
};
svector<concat_info> m_concat_trail;
svector<pvar> m_concat_args;
void undo_add_var();
void undo_split_core();
void undo_mk_extract();
void undo_set_value_node();
mutable enode_vector m_tmp1;
mutable enode_vector m_tmp2;
mutable enode_vector m_tmp3;
mutable enode_vector m_tmp4;
mutable enode_vector m_tmp5;
ptr_vector<void> m_tmp_deps;
sat::literal_set m_marked_lits;
/** Get variable representing src[hi:lo] */
pvar mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var = null_var);
/** Restore r = src[hi:lo] */
void replay_extract(extract_args const& args, pvar r);
pvar mk_concat(unsigned num_args, pvar const* args, pvar replay_var);
void replay_concat(unsigned num_args, pvar const* args, pvar r);
bool add_constraint_eq(pdd const& p, sat::literal lit);
bool add_equation(pvar x, pdd const& body, sat::literal lit);
bool add_value(pvar v, rational const& value, sat::literal lit);
bool add_fixed_bits(signed_constraint c);
bool add_fixed_bits(pvar x, unsigned hi, unsigned lo, rational const& value, sat::literal lit);
bool invariant() const;
bool invariant_needs_congruence() const;
std::ostream& display(std::ostream& out, enode* s) const;
std::ostream& display_tree(std::ostream& out, enode* s, unsigned indent, unsigned hi, unsigned lo) const;
class slice_pp {
slicing const& s;
enode* n;
public:
slice_pp(slicing const& s, enode* n): s(s), n(n) {}
std::ostream& display(std::ostream& out) const { return s.display(out, n); }
};
friend std::ostream& operator<<(std::ostream& out, slice_pp const& s) { return s.display(out); }
class dep_pp {
slicing const& s;
dep_t d;
public:
dep_pp(slicing const& s, dep_t d): s(s), d(d) {}
std::ostream& display(std::ostream& out) const { return s.display(out, d); }
};
friend std::ostream& operator<<(std::ostream& out, dep_pp const& d) { return d.display(out); }
euf::egraph::e_pp e_pp(enode* n) const { return euf::egraph::e_pp(m_egraph, n); }
public:
slicing(solver& s);
void push_scope();
void pop_scope(unsigned num_scopes = 1);
void add_var(unsigned bit_width);
/** Get or create variable representing x[hi:lo] */
pvar mk_extract(pvar x, unsigned hi, unsigned lo);
/** Get or create variable representing x1 ++ x2 ++ ... ++ xn */
pvar mk_concat(unsigned num_args, pvar const* args) { return mk_concat(num_args, args, null_var); }
pvar mk_concat(std::initializer_list<pvar> args);
// Find hi, lo such that x = src[hi:lo].
bool is_extract(pvar x, pvar src, unsigned& out_hi, unsigned& out_lo);
// Track value assignments to variables (and propagate to subslices)
void add_value(pvar v, rational const& value);
void add_value(pvar v, unsigned value) { add_value(v, rational(value)); }
void add_value(pvar v, int value) { add_value(v, rational(value)); }
void add_constraint(signed_constraint c);
bool can_propagate() const;
// update congruences, egraph
void propagate();
bool is_conflict() const { return m_disequality_conflict || m_egraph.inconsistent(); }
/** Extract conflict clause */
clause_ref build_conflict_clause();
/** Explain why slicing has propagated the value assignment for v */
void explain_value(pvar v, std::function<void(sat::literal)> const& on_lit, std::function<void(pvar)> const& on_var);
/** For a given variable v, find the set of variables w such that w = v[|w|:0]. */
void collect_simple_overlaps(pvar v, pvar_vector& out);
void explain_simple_overlap(pvar v, pvar x, std::function<void(sat::literal)> const& on_lit);
struct justified_fixed_bits : public fixed_bits {
enode* just;
justified_fixed_bits(unsigned hi, unsigned lo, rational value, enode* just): fixed_bits(hi, lo, value), just(just) {}
};
using justified_fixed_bits_vector = vector<justified_fixed_bits>;
/** Collect fixed portions of the variable v */
void collect_fixed(pvar v, justified_fixed_bits_vector& out);
void explain_fixed(enode* just, std::function<void(sat::literal)> const& on_lit, std::function<void(pvar)> const& on_var);
/**
* Collect variables that are equivalent to v (including v itself)
*
* NOTE: this might miss some variables that are equal due to equivalent base slices. With 'polysat.slicing.congruence=true' and after propagate(), it should return all equal variables.
*/
pvar_vector equivalent_vars(pvar v) const;
/** Explain why variables x and y are equivalent */
void explain_equal(pvar x, pvar y, std::function<void(sat::literal)> const& on_lit);
std::ostream& display(std::ostream& out) const;
std::ostream& display_tree(std::ostream& out) const;
};
inline std::ostream& operator<<(std::ostream& out, slicing const& s) { return s.display(out); }
}

View file

@ -254,4 +254,46 @@ namespace polysat {
return expr_ref(bv.mk_bv_add(lo, hi), m);
}
// walk the egraph starting with pvar for overlaps.
void solver::get_bitvector_prefixes(pvar pv, pvar_vector& out) {
theory_var v = m_pddvar2var[pv];
euf::enode_vector todo;
uint_set seen;
unsigned lo, hi;
expr* e = nullptr;
todo.push_back(var2enode(v));
for (unsigned i = 0; i < todo.size(); ++i) {
auto n = todo[i]->get_root();
if (n->is_marked1())
continue;
n->mark1();
for (auto sib : euf::enode_class(n)) {
theory_var w = sib->get_th_var(get_id());
// identify prefixes
if (bv.is_concat(sib->get_expr()))
todo.push_back(sib->get_arg(0));
if (w == euf::null_theory_var)
continue;
if (seen.contains(w))
continue;
seen.insert(w);
auto const& p = m_var2pdd[w];
if (p.is_var())
out.push_back(p.var());
}
for (auto p : euf::enode_parents(n)) {
if (p->is_marked1())
continue;
// find prefixes: e[hi:0] a parent of n
if (bv.is_extract(p->get_expr(), lo, hi, e) && lo == 0) {
SASSERT(n == expr2enode(e)->get_root());
todo.push_back(p);
}
}
}
for (auto n : todo)
n->get_root()->unmark1();
}
}

View file

@ -133,6 +133,7 @@ namespace polysat {
void propagate(dependency const& d, bool sign, dependency_vector const& deps) override;
trail_stack& trail() override;
bool inconsistent() const override;
void get_bitvector_prefixes(pvar v, pvar_vector& out) override;
void add_lemma(vector<signed_constraint> const& lemma);