3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 17:45:32 +00:00

slicing: track disequalities

This commit is contained in:
Jakob Rath 2023-07-19 12:04:45 +02:00
parent 970e68c70e
commit af73f26941
4 changed files with 157 additions and 17 deletions

View file

@ -213,6 +213,7 @@ namespace euf {
enode* get_target() const { return m_target; }
justification get_justification() const { return m_justification; }
justification get_lit_justification() const { return m_lit_justification; }
bool has_lbl_hash() const { return m_lbl_hash >= 0; }
unsigned char get_lbl_hash() const {

View file

@ -28,7 +28,6 @@ Example:
TODO:
- track disequalities
- track fixed bits along with enodes
- notify solver about equalities discovered by congruence
- implement query functions
@ -107,6 +106,27 @@ namespace polysat {
reg_decl_plugins(m_ast);
m_bv = alloc(bv_util, m_ast);
m_egraph.set_display_justification(display_dep);
std::function<void(enode* lit, enode* ante)> propagate_negation = [&](enode* lit, enode* ante) {
// LOG("lit: " << lit->get_id() << " value=" << lit->value());
// if (ante)
// LOG("ante: " << ante->get_id() << " value=" << ante->value());
// else
// LOG("ante: <null>");
// LOG(m_egraph);
// ante may be set when symmetric equality is added by congruence
if (ante)
return;
// on_propagate may be called before set_value
if (lit->value() == l_undef)
return;
SASSERT(lit->is_equality());
SASSERT_EQ(lit->value(), l_false);
SASSERT(lit->get_lit_justification().is_external());
// LOG("lit: id=" << lit->get_id() << " value=" << lit->value() << " dep=" << decode_dep(lit->get_lit_justification().ext<void>()));
m_disequality_conflict = lit;
};
m_egraph.set_on_propagate(propagate_negation);
}
slicing::slice_info& slicing::info(euf::enode* n) {
@ -114,6 +134,7 @@ namespace polysat {
}
slicing::slice_info const& slicing::info(euf::enode* n) const {
SASSERT(!n->is_equality());
slice_info const& i = m_info[n->get_id()];
return i.is_slice() ? i : info(i.slice);
}
@ -131,6 +152,7 @@ namespace polysat {
}
void slicing::push_scope() {
SASSERT(!is_conflict());
if (can_propagate())
propagate();
m_scopes.push_back(m_trail.size());
@ -156,6 +178,7 @@ namespace polysat {
}
m_egraph.pop(num_scopes);
m_needs_congruence.reset();
m_disequality_conflict = nullptr;
}
void slicing::add_var(unsigned bit_width) {
@ -168,6 +191,21 @@ namespace polysat {
m_var2slice.pop_back();
}
slicing::enode* slicing::find_or_alloc_disequality(enode* x, enode* y, sat::literal lit) {
expr_ref eq(m_ast.mk_eq(x->get_expr(), y->get_expr()), m_ast);
enode* eqn = m_egraph.find(eq);
if (eqn)
return eqn;
auto args = {x, y};
eqn = m_egraph.mk(eq, 0, args.size(), args.begin());
auto j = euf::justification::external(encode_dep(lit));
LOG("calling set_value");
m_egraph.set_value(eqn, l_false, j);
SASSERT(eqn->is_equality());
SASSERT_EQ(eqn->value(), l_false);
return eqn;
}
slicing::enode* slicing::alloc_enode(expr* e, unsigned num_args, enode* const* args, unsigned width, pvar var) {
SASSERT(width > 0);
SASSERT(!m_egraph.find(e));
@ -460,7 +498,20 @@ namespace polysat {
begin_explain();
SASSERT(m_tmp_justifications.empty());
m_egraph.begin_explain();
m_egraph.explain(m_tmp_justifications, nullptr);
if (m_disequality_conflict) {
enode* eqn = m_disequality_conflict;
SASSERT(eqn->is_equality());
SASSERT_EQ(eqn->value(), l_false);
SASSERT(eqn->get_lit_justification().is_external());
SASSERT(m_ast.is_eq(eqn->get_expr()));
SASSERT_EQ(eqn->get_arg(0)->get_root(), eqn->get_arg(1)->get_root());
m_egraph.explain_eq(m_tmp_justifications, nullptr, eqn->get_arg(0), eqn->get_arg(1));
push_dep(eqn->get_lit_justification().ext<void>(), out_lits, out_vars);
}
else {
SASSERT(m_egraph.inconsistent());
m_egraph.explain(m_tmp_justifications, nullptr);
}
m_egraph.end_explain();
for (void* dp : m_tmp_justifications)
push_dep(dp, out_lits, out_vars);
@ -485,7 +536,7 @@ namespace polysat {
SASSERT(!has_sub(s1));
SASSERT(!has_sub(s2));
m_egraph.merge(s1, s2, encode_dep(dep));
return !m_egraph.inconsistent();
return !is_conflict();
}
bool slicing::merge(enode_vector& xs, enode_vector& ys, dep_t dep) {
@ -662,6 +713,7 @@ namespace polysat {
}
void slicing::add_constraint(signed_constraint c) {
SASSERT(!is_conflict());
if (!c->is_eq())
return;
dep_t const d = c.blit();
@ -672,9 +724,10 @@ namespace polysat {
continue;
pdd body = a.is_one() ? (m.mk_var(x) - p) : (m.mk_var(x) + p);
// c is either x = body or x != body, depending on polarity
LOG("Equation from constraint " << c << ": v" << x << " = " << body);
LOG("Equation from lit(" << c.blit() << ") " << c << ": v" << x << " = " << body);
enode* const sx = var2slice(x);
if (body.is_val()) {
if (c.is_positive() && body.is_val()) {
LOG(" simple assignment");
// Simple assignment x = value
enode* const sval = mk_value_slice(body.val(), body.power_of_2());
if (!merge(sx, sval, d)) {
@ -685,7 +738,10 @@ namespace polysat {
}
pvar const y = m_solver.m_names.get_name(body);
if (y == null_var) {
LOG(" skip for now (unnamed body)");
// TODO: register name trigger (if a name for value 'body' is created later, then merge x=y at that time)
// could also count how often 'body' was registered and introduce name when more than once.
// maybe better: register x as an existing name for 'body'? question is how to track the dependency on c.
continue;
}
enode* const sy = var2slice(y);
@ -697,17 +753,17 @@ namespace polysat {
}
else {
SASSERT(c.is_negative());
enode* n = find_or_alloc_disequality(sy, sx, c.blit());
if (is_equal(sx, sy)) {
// TODO: conflict
NOT_IMPLEMENTED_YET();
SASSERT(is_conflict());
return;
SASSERT_EQ(m_disequality_conflict, n); // already discovered by egraph in simple examples... TODO: probably not when we need the slice congruences
// m_disequality_conflict = n;
}
}
}
}
void slicing::add_value(pvar v, rational const& val) {
SASSERT(!is_conflict());
enode* const sv = var2slice(v);
enode* const sval = mk_value_slice(val, width(sv));
(void)merge(sv, sval, v);
@ -766,6 +822,9 @@ namespace polysat {
VERIFY(m_tmp2.empty());
VERIFY(m_tmp3.empty());
for (enode* s : m_egraph.nodes()) {
// we use equality enodes only to track disequalities
if (s->is_equality())
continue;
// if the slice is equivalent to a variable, then the variable's slice is in the equivalence class
pvar const v = slice2var(s);
if (v != null_var) {
@ -779,6 +838,8 @@ namespace polysat {
VERIFY(has_value(sub_lo(s)));
}
}
// we don't need to store the width separately anymore
VERIFY_EQ(width(s), m_bv->get_bv_size(s->get_expr()));
// properties below only matter for representatives
if (!s->is_root())
continue;

View file

@ -63,10 +63,11 @@ namespace polysat {
static constexpr unsigned null_cut = std::numeric_limits<unsigned>::max();
// Kinds of slices:
// - proper (from variables)
// We use the following kinds of enodes:
// - proper slices (of variables)
// - values
// - virtual concat(...) expressions
// - equalities between enodes (to track disequalities; currently not represented in slice_info)
struct slice_info {
unsigned width = 0; // number of bits in the slice
// Cut point: if not null_cut, the slice s has been subdivided into s[|s|-1:cut+1] and s[cut:0].
@ -95,6 +96,7 @@ namespace polysat {
slice_info_vector m_info; // indexed by enode::get_id()
enode_vector m_var2slice; // pvar -> slice
tracked_uint_set m_needs_congruence; // set of pvars that need updated concat(...) expressions
enode* m_disequality_conflict = nullptr;
// Add an equation v = concat(s1, ..., sn)
// for each variable v with base slices s1, ..., sn
@ -113,6 +115,7 @@ namespace polysat {
enode* alloc_enode(expr* e, unsigned num_args, enode* const* args, unsigned width, pvar var);
enode* find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, unsigned width, pvar var);
enode* alloc_slice(unsigned width, pvar var = null_var);
enode* find_or_alloc_disequality(enode* x, enode* y, sat::literal lit);
enode* var2slice(pvar v) const { return m_var2slice[v]; }
pvar slice2var(enode* s) const { return info(s).var; }
@ -245,7 +248,7 @@ namespace polysat {
// update congruences, egraph
void propagate();
bool is_conflict() const { return m_egraph.inconsistent(); }
bool is_conflict() const { return m_disequality_conflict || m_egraph.inconsistent(); }
/** Extract reason for conflict */
void explain(sat::literal_vector& out_lits, unsigned_vector& out_vars);

View file

@ -1,6 +1,24 @@
#include "math/polysat/slicing.h"
#include "math/polysat/solver.h"
namespace {
template <typename T>
void permute_args(unsigned k, T& a, T& b, T& c) {
using std::swap;
SASSERT(k < 6);
unsigned i = k % 3;
unsigned j = k % 2;
if (i == 1)
swap(a, b);
else if (i == 2)
swap(a, c);
if (j == 1)
swap(b, c);
}
}
namespace polysat {
struct solver_scope_slicing {
@ -200,6 +218,61 @@ namespace polysat {
VERIFY(sl.invariant());
}
static void test6() {
std::cout << __func__ << "\n";
scoped_solver_slicing s;
slicing& sl = s.sl();
pdd x = s.var(s.add_var(8));
pdd y = s.var(s.add_var(8));
pdd z = s.var(s.add_var(8));
sl.add_constraint(s.eq(x, z));
sl.add_constraint(s.eq(y, z));
sl.add_constraint(s.eq(x, rational(5)));
sl.add_value(x.var(), rational(5));
sl.add_value(y.var(), rational(7));
SASSERT(sl.is_conflict());
sat::literal_vector reason_lits;
unsigned_vector reason_vars;
sl.explain(reason_lits, reason_vars);
std::cout << "Conflict: " << reason_lits << " vars " << reason_vars << "\n";
sl.display_tree(std::cout);
VERIFY(sl.invariant());
}
// x != z
// x = y
// y = z
// in various permutations
static void test7() {
std::cout << __func__ << "\n";
scoped_set_log_enabled _logging(false);
scoped_solver_slicing s;
slicing& sl = s.sl();
pdd x = s.var(s.add_var(8));
pdd y = s.var(s.add_var(8));
pdd z = s.var(s.add_var(8));
for (unsigned k = 0; k < 6; ++k) {
s.push();
signed_constraint c1 = s.diseq(x, z);
signed_constraint c2 = s.eq(x, y);
signed_constraint c3 = s.eq(y, z);
permute_args(k, c1, c2, c3);
sl.add_constraint(c1);
sl.add_constraint(c2);
sl.add_constraint(c3);
SASSERT(sl.is_conflict());
sat::literal_vector reason_lits;
unsigned_vector reason_vars;
sl.explain(reason_lits, reason_vars);
std::cout << "Conflict: " << reason_lits << " vars " << reason_vars << "\n";
// sl.display_tree(std::cout);
VERIFY(sl.invariant());
s.pop();
}
}
};
}
@ -207,10 +280,12 @@ namespace polysat {
void tst_slicing() {
using namespace polysat;
test_slicing::test1();
test_slicing::test2();
test_slicing::test3();
test_slicing::test4();
test_slicing::test5();
// test_slicing::test1();
// test_slicing::test2();
// test_slicing::test3();
// test_slicing::test4();
// test_slicing::test5();
// test_slicing::test6();
test_slicing::test7();
std::cout << "ok\n";
}