mirror of
https://github.com/Z3Prover/z3
synced 2025-04-22 00:26:38 +00:00
updates to viable
This commit is contained in:
parent
683a5dda37
commit
94ba85bb12
9 changed files with 2539 additions and 30 deletions
|
@ -278,6 +278,14 @@ namespace polysat {
|
|||
}
|
||||
}
|
||||
|
||||
void core::get_bitvector_prefixes(pvar v, pvar_vector& out) {
|
||||
s.get_bitvector_prefixes(v, out);
|
||||
}
|
||||
|
||||
bool core::inconsistent() const {
|
||||
return s.inconsistent();
|
||||
}
|
||||
|
||||
void core::propagate_unsat_core() {
|
||||
// default is to use unsat core:
|
||||
// if core is based on viable, use s.set_lemma();
|
||||
|
|
|
@ -79,6 +79,9 @@ namespace polysat {
|
|||
void propagate_assignment(pvar v, rational const& value, dependency dep);
|
||||
void propagate_unsat_core();
|
||||
|
||||
void get_bitvector_prefixes(pvar v, pvar_vector& out);
|
||||
bool inconsistent() const;
|
||||
|
||||
void add_watch(unsigned idx, unsigned var);
|
||||
|
||||
signed_constraint get_constraint(unsigned idx, bool sign);
|
||||
|
@ -89,8 +92,7 @@ namespace polysat {
|
|||
public:
|
||||
core(solver_interface& s);
|
||||
|
||||
sat::check_result check();
|
||||
|
||||
sat::check_result check();
|
||||
unsigned register_constraint(signed_constraint& sc, dependency d);
|
||||
bool propagate();
|
||||
void assign_eh(unsigned idx, bool sign, dependency const& d);
|
||||
|
|
|
@ -10,6 +10,9 @@ Author:
|
|||
#pragma once
|
||||
|
||||
#include <variant>
|
||||
|
||||
|
||||
|
||||
#include "math/dd/dd_pdd.h"
|
||||
#include "util/trail.h"
|
||||
#include "util/sat_literal.h"
|
||||
|
@ -62,6 +65,7 @@ namespace polysat {
|
|||
virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0;
|
||||
virtual trail_stack& trail() = 0;
|
||||
virtual bool inconsistent() const = 0;
|
||||
virtual void get_bitvector_prefixes(pvar v, pvar_vector& out) = 0;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -98,42 +98,31 @@ namespace polysat {
|
|||
#endif
|
||||
|
||||
pvar_vector overlaps;
|
||||
#if 0
|
||||
// TODO s.m_slicing.collect_simple_overlaps(v, overlaps);
|
||||
#endif
|
||||
c.get_bitvector_prefixes(v, overlaps);
|
||||
std::sort(overlaps.begin(), overlaps.end(), [&](pvar x, pvar y) { return c.size(x) > c.size(y); });
|
||||
|
||||
uint_set widths_set;
|
||||
// max size should always be present, regardless of whether we have intervals there (to make sure all fixed bits are considered)
|
||||
widths_set.insert(c.size(v));
|
||||
|
||||
#if 0
|
||||
LOG("Overlaps with v" << v << ":");
|
||||
for (pvar x : overlaps) {
|
||||
unsigned hi, lo;
|
||||
if (s.m_slicing.is_extract(x, v, hi, lo))
|
||||
LOG(" v" << x << " = v" << v << "[" << hi << ":" << lo << "]");
|
||||
else
|
||||
LOG(" v" << x << " not extracted from v" << v << "; size " << s.size(x));
|
||||
for (layer const& l : m_units[x].get_layers()) {
|
||||
widths_set.insert(l.bit_width);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
for (pvar v : overlaps)
|
||||
ensure_var(v);
|
||||
|
||||
for (pvar v : overlaps)
|
||||
for (layer const& l : m_units[v].get_layers())
|
||||
widths_set.insert(l.bit_width);
|
||||
|
||||
unsigned_vector widths;
|
||||
for (unsigned w : widths_set)
|
||||
widths.push_back(w);
|
||||
widths.push_back(w);
|
||||
LOG("Overlaps with v" << v << ":" << overlaps);
|
||||
LOG("widths: " << widths);
|
||||
|
||||
rational const& max_value = c.var2pdd(v).max_value();
|
||||
|
||||
#if 0
|
||||
lbool result_lo = find_on_layers(v, widths, overlaps, fbi, rational::zero(), max_value, lo);
|
||||
if (result_lo == l_false)
|
||||
return l_false; // conflict
|
||||
if (result_lo == l_undef)
|
||||
return find_viable_fallback(v, overlaps, lo, hi);
|
||||
if (result_lo != l_true)
|
||||
return result_lo;
|
||||
|
||||
if (lo == max_value) {
|
||||
hi = lo;
|
||||
|
@ -141,12 +130,294 @@ namespace polysat {
|
|||
}
|
||||
|
||||
lbool result_hi = find_on_layers(v, widths, overlaps, fbi, lo + 1, max_value, hi);
|
||||
if (result_hi == l_false)
|
||||
hi = lo; // no other viable value
|
||||
if (result_hi == l_undef)
|
||||
return find_viable_fallback(v, overlaps, lo, hi);
|
||||
#endif
|
||||
return l_true;
|
||||
|
||||
switch (result_hi) {
|
||||
case l_false:
|
||||
hi = lo;
|
||||
return l_true;
|
||||
case l_undef:
|
||||
return l_undef;
|
||||
default:
|
||||
return l_true;
|
||||
}
|
||||
}
|
||||
|
||||
// l_true ... found viable value
|
||||
// l_false ... no viable value in [to_cover_lo;to_cover_hi[
|
||||
// l_undef ... out of resources
|
||||
lbool viable::find_on_layers(
|
||||
pvar const v,
|
||||
unsigned_vector const& widths,
|
||||
pvar_vector const& overlaps,
|
||||
fixed_bits_info const& fbi,
|
||||
rational const& to_cover_lo,
|
||||
rational const& to_cover_hi,
|
||||
rational& val
|
||||
) {
|
||||
ptr_vector<entry> refine_todo;
|
||||
ptr_vector<entry> relevant_entries;
|
||||
|
||||
// max number of interval refinements before falling back to the univariate solver
|
||||
unsigned const refinement_budget = 100;
|
||||
unsigned refinements = refinement_budget;
|
||||
|
||||
while (refinements--) {
|
||||
relevant_entries.clear();
|
||||
lbool result = find_on_layer(v, widths.size() - 1, widths, overlaps, fbi, to_cover_lo, to_cover_hi, val, refine_todo, relevant_entries);
|
||||
|
||||
// store bit-intervals we have used
|
||||
for (entry* e : refine_todo)
|
||||
intersect(v, e);
|
||||
refine_todo.clear();
|
||||
|
||||
if (result != l_true)
|
||||
return l_false;
|
||||
|
||||
// overlaps are sorted by variable size in descending order
|
||||
// start refinement on smallest variable
|
||||
// however, we probably should rotate to avoid getting stuck in refinement loop on a 'bad' constraint
|
||||
bool refined = false;
|
||||
for (unsigned i = overlaps.size(); i-- > 0; ) {
|
||||
pvar x = overlaps[i];
|
||||
rational const& mod_value = c.var2pdd(x).two_to_N();
|
||||
rational x_val = mod(val, mod_value);
|
||||
if (!refine_viable(x, x_val)) {
|
||||
refined = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!refined)
|
||||
return l_true;
|
||||
}
|
||||
|
||||
LOG("Refinement budget exhausted! Fall back to univariate solver.");
|
||||
return l_undef;
|
||||
}
|
||||
|
||||
// find viable values in half-open interval [to_cover_lo;to_cover_hi[ w.r.t. unit intervals on the given layer
|
||||
//
|
||||
// Returns:
|
||||
// - l_true ... found value that is viable w.r.t. units and fixed bits
|
||||
// - l_false ... found conflict
|
||||
// - l_undef ... found no viable value in target interval [to_cover_lo;to_cover_hi[
|
||||
lbool viable::find_on_layer(
|
||||
pvar const v,
|
||||
unsigned const w_idx,
|
||||
unsigned_vector const& widths,
|
||||
pvar_vector const& overlaps,
|
||||
fixed_bits_info const& fbi,
|
||||
rational const& to_cover_lo,
|
||||
rational const& to_cover_hi,
|
||||
rational& val,
|
||||
ptr_vector<entry>& refine_todo,
|
||||
ptr_vector<entry>& relevant_entries
|
||||
) {
|
||||
unsigned const w = widths[w_idx];
|
||||
rational const& mod_value = rational::power_of_two(w);
|
||||
unsigned const first_relevant_for_conflict = relevant_entries.size();
|
||||
|
||||
LOG("layer " << w << " bits, to_cover: [" << to_cover_lo << "; " << to_cover_hi << "[");
|
||||
SASSERT(0 <= to_cover_lo);
|
||||
SASSERT(0 <= to_cover_hi);
|
||||
SASSERT(to_cover_lo < mod_value);
|
||||
SASSERT(to_cover_hi <= mod_value); // full interval if to_cover_hi == mod_value
|
||||
SASSERT(to_cover_lo != to_cover_hi); // non-empty search domain (but it may wrap)
|
||||
|
||||
// TODO: refinement of eq/diseq should happen only on the correct layer: where (one of) the coefficient(s) are odd
|
||||
// for refinement, we have to choose an entry that is violated, but if there are multiple, we can choose the one on smallest domain (lowest bit-width).
|
||||
// (by maintaining descending order by bit-width; and refine first that fails).
|
||||
// but fixed-bit-refinement is cheap and could be done during the search.
|
||||
|
||||
// when we arrive at a hole the possibilities are:
|
||||
// 1) go to lower bitwidth
|
||||
// 2) refinement of some eq/diseq constraint (if one is violated at that point) -- defer this until point is viable for all layers and fixed bits.
|
||||
// 3) refinement by using bit constraints? -- TODO: do this during search (another interval check after/before the entry_cursors)
|
||||
// 4) (point is actually feasible)
|
||||
|
||||
// a complication is that we have to iterate over multiple lists of intervals.
|
||||
// might be useful to merge them upfront to simplify the remainder of the algorithm?
|
||||
// (non-trivial since prev/next pointers are embedded into entries and lists are updated by refinement)
|
||||
struct entry_cursor {
|
||||
entry* cur;
|
||||
// entry* first;
|
||||
// entry* last;
|
||||
};
|
||||
|
||||
// find relevant interval lists
|
||||
svector<entry_cursor> ecs;
|
||||
for (pvar x : overlaps) {
|
||||
if (c.size(x) < w) // note that overlaps are sorted by variable size descending
|
||||
break;
|
||||
if (entry* e = m_units[x].get_entries(w)) {
|
||||
display_all(std::cerr << "units for width " << w << ":\n", 0, e, "\n");
|
||||
entry_cursor ec;
|
||||
ec.cur = e; // TODO: e->prev() probably makes it faster when querying 0 (can often save going around the interval list once)
|
||||
// ec.first = nullptr;
|
||||
// ec.last = nullptr;
|
||||
ecs.push_back(ec);
|
||||
}
|
||||
}
|
||||
|
||||
rational const to_cover_len = r_interval::len(to_cover_lo, to_cover_hi, mod_value);
|
||||
val = to_cover_lo;
|
||||
|
||||
rational progress; // = 0
|
||||
SASSERT(progress.is_zero());
|
||||
while (true) {
|
||||
while (true) {
|
||||
entry* e = nullptr;
|
||||
|
||||
// try to make progress using any of the relevant interval lists
|
||||
for (entry_cursor& ec : ecs) {
|
||||
// advance until current value 'val'
|
||||
auto const [n, n_contains_val] = find_value(val, ec.cur);
|
||||
// display_one(std::cerr << "found entry n: ", 0, n) << "\n";
|
||||
// LOG("n_contains_val: " << n_contains_val << " val = " << val);
|
||||
ec.cur = n;
|
||||
if (n_contains_val) {
|
||||
e = n;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// when we cannot make progress by existing intervals any more, try interval from fixed bits
|
||||
if (!e) {
|
||||
e = refine_bits<true>(v, val, w, fbi);
|
||||
if (e) {
|
||||
refine_todo.push_back(e);
|
||||
display_one(std::cerr << "found entry by bits: ", 0, e) << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
// no more progress on current layer
|
||||
if (!e)
|
||||
break;
|
||||
|
||||
relevant_entries.push_back(e);
|
||||
|
||||
if (e->interval.is_full()) {
|
||||
relevant_entries.clear();
|
||||
relevant_entries.push_back(e); // full interval e -> all other intervals are subsumed/irrelevant
|
||||
set_conflict_by_interval(v, w, relevant_entries, 0);
|
||||
return l_false;
|
||||
}
|
||||
|
||||
SASSERT(e->interval.currently_contains(val));
|
||||
rational const& new_val = e->interval.hi_val();
|
||||
rational const dist = distance(val, new_val, mod_value);
|
||||
SASSERT(dist > 0);
|
||||
val = new_val;
|
||||
progress += dist;
|
||||
LOG("val: " << val << " progress: " << progress);
|
||||
|
||||
if (progress >= mod_value) {
|
||||
// covered the whole domain => conflict
|
||||
set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict);
|
||||
return l_false;
|
||||
}
|
||||
if (progress >= to_cover_len) {
|
||||
// we covered the hole left at larger bit-width
|
||||
// TODO: maybe we want to keep trying a bit longer to see if we can cover the whole domain. or maybe only if we enter this layer multiple times.
|
||||
return l_undef;
|
||||
}
|
||||
|
||||
// (another way to compute 'progress')
|
||||
SASSERT_EQ(progress, distance(to_cover_lo, val, mod_value));
|
||||
}
|
||||
|
||||
// no more progress
|
||||
// => 'val' is viable w.r.t. unit intervals until current layer
|
||||
|
||||
if (!w_idx) {
|
||||
// we are at the lowest layer
|
||||
// => found viable value w.r.t. unit intervals and fixed bits
|
||||
return l_true;
|
||||
}
|
||||
|
||||
// find next covered value
|
||||
rational next_val = to_cover_hi;
|
||||
for (entry_cursor& ec : ecs) {
|
||||
// each ec.cur is now the next interval after 'lo'
|
||||
rational const& n = ec.cur->interval.lo_val();
|
||||
SASSERT(r_interval::contains(ec.cur->prev()->interval.hi_val(), n, val));
|
||||
if (distance(val, n, mod_value) < distance(val, next_val, mod_value))
|
||||
next_val = n;
|
||||
}
|
||||
if (entry* e = refine_bits<false>(v, next_val, w, fbi)) {
|
||||
refine_todo.push_back(e);
|
||||
rational const& n = e->interval.lo_val();
|
||||
SASSERT(distance(val, n, mod_value) < distance(val, next_val, mod_value));
|
||||
next_val = n;
|
||||
}
|
||||
SASSERT(!refine_bits<true>(v, val, w, fbi));
|
||||
SASSERT(val != next_val);
|
||||
|
||||
unsigned const lower_w = widths[w_idx - 1];
|
||||
rational const lower_mod_value = rational::power_of_two(lower_w);
|
||||
|
||||
rational lower_cover_lo, lower_cover_hi;
|
||||
if (distance(val, next_val, mod_value) >= lower_mod_value) {
|
||||
// NOTE: in this case we do not get the first viable value, but the one with smallest value in the lower bits.
|
||||
// this is because we start the search in the recursive case at 0.
|
||||
// if this is a problem, adapt to lower_cover_lo = mod(val, lower_mod_value), lower_cover_hi = ...
|
||||
lower_cover_lo = 0;
|
||||
lower_cover_hi = lower_mod_value;
|
||||
rational a;
|
||||
lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries);
|
||||
VERIFY(result != l_undef);
|
||||
if (result == l_false) {
|
||||
SASSERT(c.inconsistent());
|
||||
return l_false; // conflict
|
||||
}
|
||||
SASSERT(result == l_true);
|
||||
// replace lower bits of 'val' by 'a'
|
||||
rational const val_lower = mod(val, lower_mod_value);
|
||||
val = val - val_lower + a;
|
||||
if (a < val_lower)
|
||||
a += lower_mod_value;
|
||||
LOG("distance(val, cover_hi) = " << distance(val, to_cover_hi, mod_value));
|
||||
LOG("distance(next_val, cover_hi) = " << distance(next_val, to_cover_hi, mod_value));
|
||||
SASSERT(distance(val, to_cover_hi, mod_value) >= distance(next_val, to_cover_hi, mod_value));
|
||||
return l_true;
|
||||
}
|
||||
|
||||
lower_cover_lo = mod(val, lower_mod_value);
|
||||
lower_cover_hi = mod(next_val, lower_mod_value);
|
||||
|
||||
rational a;
|
||||
lbool result = find_on_layer(v, w_idx - 1, widths, overlaps, fbi, lower_cover_lo, lower_cover_hi, a, refine_todo, relevant_entries);
|
||||
if (result == l_false) {
|
||||
SASSERT(c.inconsistent());
|
||||
return l_false; // conflict
|
||||
}
|
||||
|
||||
// replace lower bits of 'val' by 'a'
|
||||
rational const dist = distance(lower_cover_lo, a, lower_mod_value);
|
||||
val += dist;
|
||||
progress += dist;
|
||||
LOG("distance(val, cover_hi) = " << distance(val, to_cover_hi, mod_value));
|
||||
LOG("distance(next_val, cover_hi) = " << distance(next_val, to_cover_hi, mod_value));
|
||||
SASSERT(distance(val, to_cover_hi, mod_value) >= distance(next_val, to_cover_hi, mod_value));
|
||||
|
||||
if (result == l_true)
|
||||
return l_true; // done
|
||||
|
||||
SASSERT(result == l_undef);
|
||||
|
||||
if (progress >= mod_value) {
|
||||
// covered the whole domain => conflict
|
||||
set_conflict_by_interval(v, w, relevant_entries, first_relevant_for_conflict);
|
||||
return l_false;
|
||||
}
|
||||
|
||||
if (progress >= to_cover_len) {
|
||||
// we covered the hole left at larger bit-width
|
||||
return l_undef;
|
||||
}
|
||||
}
|
||||
UNREACHABLE();
|
||||
return l_undef;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -162,6 +162,63 @@ namespace polysat {
|
|||
|
||||
lbool find_viable(pvar v, rational& lo, rational& hi);
|
||||
|
||||
lbool find_on_layers(
|
||||
pvar v,
|
||||
unsigned_vector const& widths,
|
||||
pvar_vector const& overlaps,
|
||||
fixed_bits_info const& fbi,
|
||||
rational const& to_cover_lo,
|
||||
rational const& to_cover_hi,
|
||||
rational& out_val);
|
||||
|
||||
lbool find_on_layer(
|
||||
pvar v,
|
||||
unsigned w_idx,
|
||||
unsigned_vector const& widths,
|
||||
pvar_vector const& overlaps,
|
||||
fixed_bits_info const& fbi,
|
||||
rational const& to_cover_lo,
|
||||
rational const& to_cover_hi,
|
||||
rational& out_val,
|
||||
ptr_vector<entry>& refine_todo,
|
||||
ptr_vector<entry>& relevant_entries);
|
||||
|
||||
|
||||
template <bool FORWARD>
|
||||
bool refine_viable(pvar v, rational const& val, fixed_bits_info const& fbi) {
|
||||
throw default_exception("nyi");
|
||||
}
|
||||
|
||||
bool refine_viable(pvar v, rational const& val) {
|
||||
throw default_exception("nyi");
|
||||
}
|
||||
|
||||
template <bool FORWARD>
|
||||
bool refine_bits(pvar v, rational const& val, fixed_bits_info const& fbi) {
|
||||
throw default_exception("nyi");
|
||||
}
|
||||
|
||||
template <bool FORWARD>
|
||||
entry* refine_bits(pvar v, rational const& val, unsigned num_bits, fixed_bits_info const& fbi) {
|
||||
throw default_exception("nyi");
|
||||
}
|
||||
|
||||
bool refine_equal_lin(pvar v, rational const& val) {
|
||||
throw default_exception("nyi");
|
||||
}
|
||||
|
||||
bool refine_disequal_lin(pvar v, rational const& val) {
|
||||
throw default_exception("nyi");
|
||||
}
|
||||
|
||||
bool set_conflict_by_interval(pvar v, unsigned w, ptr_vector<entry>& intervals, unsigned first_interval) {
|
||||
throw default_exception("nyi");
|
||||
}
|
||||
|
||||
std::pair<entry*, bool> find_value(rational const& val, entry* entries) {
|
||||
throw default_exception("nyi");
|
||||
}
|
||||
|
||||
public:
|
||||
viable(core& c);
|
||||
|
||||
|
|
1727
src/sat/smt/polysat/slicing.cpp
Normal file
1727
src/sat/smt/polysat/slicing.cpp
Normal file
File diff suppressed because it is too large
Load diff
397
src/sat/smt/polysat/slicing.h
Normal file
397
src/sat/smt/polysat/slicing.h
Normal file
|
@ -0,0 +1,397 @@
|
|||
/*++
|
||||
Copyright (c) 2023 Microsoft Corporation
|
||||
|
||||
Module Name:
|
||||
|
||||
polysat slicing (relating variables of different bit-widths by extraction)
|
||||
|
||||
Author:
|
||||
|
||||
Jakob Rath 2023-06-01
|
||||
|
||||
Notation:
|
||||
|
||||
Let x be a bit-vector of width w.
|
||||
Let l, h indices such that 0 <= l <= h < w.
|
||||
Then x[h:l] extracts h - l + 1 bits of x.
|
||||
Shorthands:
|
||||
- x[h:] stands for x[h:0], and
|
||||
- x[:l] stands for x[w-1:l].
|
||||
|
||||
Example:
|
||||
0001[0:] = 1
|
||||
0001[2:0] = 001
|
||||
|
||||
--*/
|
||||
#pragma once
|
||||
#include "ast/euf/euf_egraph.h"
|
||||
#include "ast/bv_decl_plugin.h"
|
||||
#include "math/polysat/types.h"
|
||||
#include "math/polysat/constraint.h"
|
||||
#include "math/polysat/fixed_bits.h"
|
||||
#include <variant>
|
||||
|
||||
namespace polysat {
|
||||
|
||||
class solver;
|
||||
|
||||
class slicing final {
|
||||
|
||||
friend class test_slicing;
|
||||
|
||||
public:
|
||||
using enode = euf::enode;
|
||||
using enode_vector = euf::enode_vector;
|
||||
using enode_pair = euf::enode_pair;
|
||||
using enode_pair_vector = euf::enode_pair_vector;
|
||||
|
||||
private:
|
||||
class dep_t {
|
||||
std::variant<std::monostate, sat::literal, unsigned> m_data;
|
||||
public:
|
||||
dep_t() { SASSERT(is_null()); }
|
||||
dep_t(sat::literal l): m_data(l) { SASSERT(l != sat::null_literal); SASSERT_EQ(l, lit()); }
|
||||
explicit dep_t(unsigned idx): m_data(idx) { SASSERT_EQ(idx, value_idx()); }
|
||||
bool is_null() const { return std::holds_alternative<std::monostate>(m_data); }
|
||||
bool is_lit() const { return std::holds_alternative<sat::literal>(m_data); }
|
||||
bool is_value() const { return std::holds_alternative<unsigned>(m_data); }
|
||||
sat::literal lit() const { SASSERT(is_lit()); return *std::get_if<sat::literal>(&m_data); }
|
||||
unsigned value_idx() const { SASSERT(is_value()); return *std::get_if<unsigned>(&m_data); }
|
||||
bool operator==(dep_t other) const { return m_data == other.m_data; }
|
||||
bool operator!=(dep_t other) const { return !operator==(other); }
|
||||
void* encode() const;
|
||||
static dep_t decode(void* p);
|
||||
};
|
||||
|
||||
using dep_vector = svector<dep_t>;
|
||||
|
||||
std::ostream& display(std::ostream& out, dep_t d) const;
|
||||
|
||||
dep_t mk_var_dep(pvar v, enode* s, sat::literal lit);
|
||||
|
||||
pvar_vector m_dep_var;
|
||||
ptr_vector<enode> m_dep_slice;
|
||||
sat::literal_vector m_dep_lit; // optional, value assignment comes from a literal "x == val" rather than a solver assignment
|
||||
unsigned_vector m_dep_size_trail;
|
||||
|
||||
pvar get_dep_var(dep_t d) const { return m_dep_var[d.value_idx()]; }
|
||||
sat::literal get_dep_lit(dep_t d) const { return m_dep_lit[d.value_idx()]; }
|
||||
enode* get_dep_slice(dep_t d) const { return m_dep_slice[d.value_idx()]; }
|
||||
|
||||
static constexpr unsigned null_cut = std::numeric_limits<unsigned>::max();
|
||||
|
||||
// We use the following kinds of enodes:
|
||||
// - proper slices (of variables)
|
||||
// - value slices
|
||||
// - interpreted value nodes ... these are short-lived, and only created to immediately trigger a conflict inside the egraph
|
||||
// - virtual concat(...) expressions
|
||||
// - equalities between enodes (to track disequalities; currently not represented in slice_info)
|
||||
struct slice_info {
|
||||
// Cut point: if not null_cut, the slice s has been subdivided into s[|s|-1:cut+1] and s[cut:0].
|
||||
// The cut point is relative to the parent slice (rather than a root variable, which might not be unique)
|
||||
unsigned cut = null_cut; // cut point, or null_cut if no subslices
|
||||
pvar var = null_var; // slice is equivalent to this variable, if any (without dependencies)
|
||||
enode* parent = nullptr; // parent slice, only for proper slices (if not null: s == sub_hi(parent(s)) || s == sub_lo(parent(s)))
|
||||
enode* slice = nullptr; // if enode corresponds to a concat(...) expression, this field links to the represented slice.
|
||||
enode* sub_hi = nullptr; // upper subslice s[|s|-1:cut+1]
|
||||
enode* sub_lo = nullptr; // lower subslice s[cut:0]
|
||||
enode* value_node = nullptr; // the root of an equivalence class stores the value slice here, if any
|
||||
|
||||
void reset() { *this = slice_info(); }
|
||||
bool has_sub() const { return !!sub_hi; }
|
||||
void set_cut(unsigned cut, enode* sub_hi, enode* sub_lo) { this->cut = cut; this->sub_hi = sub_hi; this->sub_lo = sub_lo; }
|
||||
};
|
||||
using slice_info_vector = svector<slice_info>;
|
||||
|
||||
// Return true iff n is either a proper slice or a value slice
|
||||
bool is_slice(enode* n) const;
|
||||
|
||||
bool is_proper_slice(enode* n) const { return !is_value(n) && is_slice(n); }
|
||||
bool is_value(enode* n) const;
|
||||
bool is_concat(enode* n) const;
|
||||
bool is_equality(enode* n) const { return n->is_equality(); }
|
||||
|
||||
solver& m_solver;
|
||||
|
||||
ast_manager m_ast;
|
||||
scoped_ptr<bv_util> m_bv;
|
||||
|
||||
euf::egraph m_egraph;
|
||||
slice_info_vector m_info; // indexed by enode::get_id()
|
||||
enode_vector m_var2slice; // pvar -> slice
|
||||
tracked_uint_set m_needs_congruence; // set of pvars that need updated concat(...) expressions
|
||||
enode* m_disequality_conflict = nullptr;
|
||||
|
||||
// Add an equation v = concat(s1, ..., sn)
|
||||
// for each variable v with base slices s1, ..., sn
|
||||
void update_var_congruences();
|
||||
void add_var_congruence(pvar v);
|
||||
void add_var_congruence_if_needed(pvar v);
|
||||
bool use_var_congruences() const;
|
||||
|
||||
func_decl* mk_concat_decl(ptr_vector<expr> const& args);
|
||||
enode* mk_concat_node(enode_vector const& slices);
|
||||
enode* mk_concat_node(std::initializer_list<enode*> slices) { return mk_concat_node(slices.size(), slices.begin()); }
|
||||
enode* mk_concat_node(unsigned num_slices, enode* const* slices);
|
||||
// Add s = concat(s1, ..., sn)
|
||||
void add_concat_node(enode* s, enode* concat);
|
||||
|
||||
slice_info& info(euf::enode* n);
|
||||
slice_info const& info(euf::enode* n) const;
|
||||
|
||||
enode* alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var);
|
||||
enode* find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, pvar var);
|
||||
enode* alloc_slice(unsigned width, pvar var = null_var);
|
||||
enode* find_or_alloc_disequality(enode* x, enode* y, sat::literal lit);
|
||||
|
||||
// Find hi, lo such that s = a[hi:lo]
|
||||
bool find_range_in_ancestor(enode* s, enode* a, unsigned& out_hi, unsigned& out_lo);
|
||||
|
||||
enode* var2slice(pvar v) const { return m_var2slice[v]; }
|
||||
pvar slice2var(enode* s) const { return info(s).var; }
|
||||
|
||||
unsigned width(enode* s) const;
|
||||
|
||||
enode* parent(enode* s) const { return info(s).parent; }
|
||||
|
||||
enode* get_value_node(enode* s) const { return info(s).value_node; }
|
||||
void set_value_node(enode* s, enode* value_node);
|
||||
|
||||
unsigned get_cut(enode* s) const { return info(s).cut; }
|
||||
|
||||
bool has_sub(enode* s) const { return info(s).has_sub(); }
|
||||
|
||||
/// Upper subslice (direct child, not necessarily the representative)
|
||||
enode* sub_hi(enode* s) const { return info(s).sub_hi; }
|
||||
|
||||
/// Lower subslice (direct child, not necessarily the representative)
|
||||
enode* sub_lo(enode* s) const { return info(s).sub_lo; }
|
||||
|
||||
/// sub_lo(parent(s)) or sub_hi(parent(s)), whichever is different from s.
|
||||
enode* sibling(enode* s) const;
|
||||
|
||||
// Retrieve (or create) a slice representing the given value.
|
||||
enode* mk_value_slice(rational const& val, unsigned bit_width);
|
||||
|
||||
// Turn value node into unwrapped BV constant node
|
||||
enode* mk_interpreted_value_node(enode* value_slice);
|
||||
|
||||
rational get_value(enode* s) const;
|
||||
bool try_get_value(enode* s, rational& val) const;
|
||||
|
||||
/// Split slice s into s[|s|-1:cut+1] and s[cut:0]
|
||||
void split(enode* s, unsigned cut);
|
||||
void split_core(enode* s, unsigned cut);
|
||||
|
||||
/// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... ++ s_n (actual descendant subslices)
|
||||
void get_base(enode* src, enode_vector& out_base) const;
|
||||
enode_vector get_base(enode* src) const;
|
||||
|
||||
/// Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n.
|
||||
/// If output_full_src is true, return the new base for src, i.e., src == s_1 ++ ... ++ s_n.
|
||||
/// If output_base is false, return coarsest intermediate slices instead of only base slices.
|
||||
void mk_slice(enode* src, unsigned hi, unsigned lo, enode_vector& out, bool output_full_src = false, bool output_base = true);
|
||||
|
||||
// Extract reason why slices x and y are in the same equivalence class
|
||||
void explain_class(enode* x, enode* y, ptr_vector<void>& out_deps);
|
||||
|
||||
// Extract reason why slices x and y are equal
|
||||
// (i.e., x and y have the same base, but are not necessarily in the same equivalence class)
|
||||
void explain_equal(enode* x, enode* y, ptr_vector<void>& out_deps);
|
||||
|
||||
/** Explain why slice is equivalent to a value */
|
||||
void explain_value(enode* s, std::function<void(sat::literal)> const& on_lit, std::function<void(pvar)> const& on_var);
|
||||
|
||||
/** Extract reason for conflict */
|
||||
void explain(ptr_vector<void>& out_deps);
|
||||
|
||||
/** Extract reason for x == y */
|
||||
void explain_equal(pvar x, pvar y, ptr_vector<void>& out_deps);
|
||||
|
||||
void egraph_on_make(enode* n);
|
||||
void egraph_on_merge(enode* root, enode* other);
|
||||
void egraph_on_propagate(enode* lit, enode* ante);
|
||||
|
||||
// Merge slices in the e-graph.
|
||||
bool egraph_merge(enode* s1, enode* s2, dep_t dep);
|
||||
|
||||
// Merge equivalence classes of two base slices.
|
||||
// Returns true if merge succeeded without conflict.
|
||||
[[nodiscard]] bool merge_base(enode* s1, enode* s2, dep_t dep);
|
||||
|
||||
// Merge equality x_1 ++ ... ++ x_n == y_1 ++ ... ++ y_k
|
||||
//
|
||||
// Precondition:
|
||||
// - sequence of slices with equal total width
|
||||
// - ordered from msb to lsb
|
||||
//
|
||||
// The argument vectors will be cleared.
|
||||
//
|
||||
// Returns true if merge succeeded without conflict.
|
||||
[[nodiscard]] bool merge(enode_vector& xs, enode_vector& ys, dep_t dep);
|
||||
[[nodiscard]] bool merge(enode_vector& xs, enode* y, dep_t dep);
|
||||
[[nodiscard]] bool merge(enode* x, enode* y, dep_t dep);
|
||||
|
||||
// Check whether two slices are known to be equal
|
||||
bool is_equal(enode* x, enode* y);
|
||||
|
||||
// deduplication of extract terms
|
||||
struct extract_args {
|
||||
pvar src = null_var;
|
||||
unsigned hi = 0;
|
||||
unsigned lo = 0;
|
||||
bool operator==(extract_args const& other) const { return src == other.src && hi == other.hi && lo == other.lo; }
|
||||
unsigned hash() const { return mk_mix(src, hi, lo); }
|
||||
};
|
||||
using extract_args_eq = default_eq<extract_args>;
|
||||
using extract_args_hash = obj_hash<extract_args>;
|
||||
using extract_map = map<extract_args, pvar, extract_args_hash, extract_args_eq>;
|
||||
extract_map m_extract_dedup;
|
||||
// svector<extract_args> m_extract_origin; // pvar -> extract_args
|
||||
// TODO: add 'm_extract_origin' (pvar -> extract_args)? 1. for dependency tracking when sharing subslice trees; 2. for easily checking if a variable is an extraction of another; 3. also makes the replay easier
|
||||
// bool is_extract(pvar v) const { return m_extract_origin[v].src != null_var; }
|
||||
|
||||
enum class trail_item : std::uint8_t {
|
||||
add_var,
|
||||
split_core,
|
||||
mk_extract,
|
||||
mk_concat,
|
||||
set_value_node,
|
||||
};
|
||||
svector<trail_item> m_trail;
|
||||
enode_vector m_enode_trail;
|
||||
svector<extract_args> m_extract_trail;
|
||||
unsigned_vector m_scopes;
|
||||
|
||||
struct concat_info {
|
||||
pvar v;
|
||||
unsigned num_args;
|
||||
unsigned args_idx;
|
||||
unsigned next_args_idx() const { return args_idx + num_args; }
|
||||
};
|
||||
svector<concat_info> m_concat_trail;
|
||||
svector<pvar> m_concat_args;
|
||||
|
||||
void undo_add_var();
|
||||
void undo_split_core();
|
||||
void undo_mk_extract();
|
||||
void undo_set_value_node();
|
||||
|
||||
mutable enode_vector m_tmp1;
|
||||
mutable enode_vector m_tmp2;
|
||||
mutable enode_vector m_tmp3;
|
||||
mutable enode_vector m_tmp4;
|
||||
mutable enode_vector m_tmp5;
|
||||
ptr_vector<void> m_tmp_deps;
|
||||
sat::literal_set m_marked_lits;
|
||||
|
||||
/** Get variable representing src[hi:lo] */
|
||||
pvar mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var = null_var);
|
||||
/** Restore r = src[hi:lo] */
|
||||
void replay_extract(extract_args const& args, pvar r);
|
||||
|
||||
pvar mk_concat(unsigned num_args, pvar const* args, pvar replay_var);
|
||||
void replay_concat(unsigned num_args, pvar const* args, pvar r);
|
||||
|
||||
bool add_constraint_eq(pdd const& p, sat::literal lit);
|
||||
bool add_equation(pvar x, pdd const& body, sat::literal lit);
|
||||
bool add_value(pvar v, rational const& value, sat::literal lit);
|
||||
bool add_fixed_bits(signed_constraint c);
|
||||
bool add_fixed_bits(pvar x, unsigned hi, unsigned lo, rational const& value, sat::literal lit);
|
||||
|
||||
bool invariant() const;
|
||||
bool invariant_needs_congruence() const;
|
||||
|
||||
std::ostream& display(std::ostream& out, enode* s) const;
|
||||
std::ostream& display_tree(std::ostream& out, enode* s, unsigned indent, unsigned hi, unsigned lo) const;
|
||||
|
||||
class slice_pp {
|
||||
slicing const& s;
|
||||
enode* n;
|
||||
public:
|
||||
slice_pp(slicing const& s, enode* n): s(s), n(n) {}
|
||||
std::ostream& display(std::ostream& out) const { return s.display(out, n); }
|
||||
};
|
||||
friend std::ostream& operator<<(std::ostream& out, slice_pp const& s) { return s.display(out); }
|
||||
|
||||
class dep_pp {
|
||||
slicing const& s;
|
||||
dep_t d;
|
||||
public:
|
||||
dep_pp(slicing const& s, dep_t d): s(s), d(d) {}
|
||||
std::ostream& display(std::ostream& out) const { return s.display(out, d); }
|
||||
};
|
||||
friend std::ostream& operator<<(std::ostream& out, dep_pp const& d) { return d.display(out); }
|
||||
|
||||
euf::egraph::e_pp e_pp(enode* n) const { return euf::egraph::e_pp(m_egraph, n); }
|
||||
|
||||
public:
|
||||
slicing(solver& s);
|
||||
|
||||
void push_scope();
|
||||
void pop_scope(unsigned num_scopes = 1);
|
||||
|
||||
void add_var(unsigned bit_width);
|
||||
|
||||
/** Get or create variable representing x[hi:lo] */
|
||||
pvar mk_extract(pvar x, unsigned hi, unsigned lo);
|
||||
|
||||
/** Get or create variable representing x1 ++ x2 ++ ... ++ xn */
|
||||
pvar mk_concat(unsigned num_args, pvar const* args) { return mk_concat(num_args, args, null_var); }
|
||||
pvar mk_concat(std::initializer_list<pvar> args);
|
||||
|
||||
// Find hi, lo such that x = src[hi:lo].
|
||||
bool is_extract(pvar x, pvar src, unsigned& out_hi, unsigned& out_lo);
|
||||
|
||||
// Track value assignments to variables (and propagate to subslices)
|
||||
void add_value(pvar v, rational const& value);
|
||||
void add_value(pvar v, unsigned value) { add_value(v, rational(value)); }
|
||||
void add_value(pvar v, int value) { add_value(v, rational(value)); }
|
||||
void add_constraint(signed_constraint c);
|
||||
|
||||
bool can_propagate() const;
|
||||
|
||||
// update congruences, egraph
|
||||
void propagate();
|
||||
|
||||
bool is_conflict() const { return m_disequality_conflict || m_egraph.inconsistent(); }
|
||||
|
||||
/** Extract conflict clause */
|
||||
clause_ref build_conflict_clause();
|
||||
|
||||
/** Explain why slicing has propagated the value assignment for v */
|
||||
void explain_value(pvar v, std::function<void(sat::literal)> const& on_lit, std::function<void(pvar)> const& on_var);
|
||||
|
||||
/** For a given variable v, find the set of variables w such that w = v[|w|:0]. */
|
||||
void collect_simple_overlaps(pvar v, pvar_vector& out);
|
||||
void explain_simple_overlap(pvar v, pvar x, std::function<void(sat::literal)> const& on_lit);
|
||||
|
||||
struct justified_fixed_bits : public fixed_bits {
|
||||
enode* just;
|
||||
|
||||
justified_fixed_bits(unsigned hi, unsigned lo, rational value, enode* just): fixed_bits(hi, lo, value), just(just) {}
|
||||
};
|
||||
|
||||
using justified_fixed_bits_vector = vector<justified_fixed_bits>;
|
||||
|
||||
/** Collect fixed portions of the variable v */
|
||||
void collect_fixed(pvar v, justified_fixed_bits_vector& out);
|
||||
void explain_fixed(enode* just, std::function<void(sat::literal)> const& on_lit, std::function<void(pvar)> const& on_var);
|
||||
|
||||
/**
|
||||
* Collect variables that are equivalent to v (including v itself)
|
||||
*
|
||||
* NOTE: this might miss some variables that are equal due to equivalent base slices. With 'polysat.slicing.congruence=true' and after propagate(), it should return all equal variables.
|
||||
*/
|
||||
pvar_vector equivalent_vars(pvar v) const;
|
||||
|
||||
/** Explain why variables x and y are equivalent */
|
||||
void explain_equal(pvar x, pvar y, std::function<void(sat::literal)> const& on_lit);
|
||||
|
||||
std::ostream& display(std::ostream& out) const;
|
||||
std::ostream& display_tree(std::ostream& out) const;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, slicing const& s) { return s.display(out); }
|
||||
|
||||
}
|
|
@ -254,4 +254,46 @@ namespace polysat {
|
|||
return expr_ref(bv.mk_bv_add(lo, hi), m);
|
||||
}
|
||||
|
||||
// walk the egraph starting with pvar for overlaps.
|
||||
void solver::get_bitvector_prefixes(pvar pv, pvar_vector& out) {
|
||||
theory_var v = m_pddvar2var[pv];
|
||||
euf::enode_vector todo;
|
||||
uint_set seen;
|
||||
unsigned lo, hi;
|
||||
expr* e = nullptr;
|
||||
todo.push_back(var2enode(v));
|
||||
for (unsigned i = 0; i < todo.size(); ++i) {
|
||||
auto n = todo[i]->get_root();
|
||||
if (n->is_marked1())
|
||||
continue;
|
||||
n->mark1();
|
||||
for (auto sib : euf::enode_class(n)) {
|
||||
theory_var w = sib->get_th_var(get_id());
|
||||
|
||||
// identify prefixes
|
||||
if (bv.is_concat(sib->get_expr()))
|
||||
todo.push_back(sib->get_arg(0));
|
||||
if (w == euf::null_theory_var)
|
||||
continue;
|
||||
if (seen.contains(w))
|
||||
continue;
|
||||
seen.insert(w);
|
||||
auto const& p = m_var2pdd[w];
|
||||
if (p.is_var())
|
||||
out.push_back(p.var());
|
||||
}
|
||||
for (auto p : euf::enode_parents(n)) {
|
||||
if (p->is_marked1())
|
||||
continue;
|
||||
// find prefixes: e[hi:0] a parent of n
|
||||
if (bv.is_extract(p->get_expr(), lo, hi, e) && lo == 0) {
|
||||
SASSERT(n == expr2enode(e)->get_root());
|
||||
todo.push_back(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto n : todo)
|
||||
n->get_root()->unmark1();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -133,6 +133,7 @@ namespace polysat {
|
|||
void propagate(dependency const& d, bool sign, dependency_vector const& deps) override;
|
||||
trail_stack& trail() override;
|
||||
bool inconsistent() const override;
|
||||
void get_bitvector_prefixes(pvar v, pvar_vector& out) override;
|
||||
|
||||
void add_lemma(vector<signed_constraint> const& lemma);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue