diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 9ff43367a..76c3611eb 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -213,6 +213,7 @@ namespace euf { enode* get_target() const { return m_target; } justification get_justification() const { return m_justification; } + justification get_lit_justification() const { return m_lit_justification; } bool has_lbl_hash() const { return m_lbl_hash >= 0; } unsigned char get_lbl_hash() const { diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index 34f5dd614..2d027d4ac 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -28,7 +28,6 @@ Example: TODO: -- track disequalities - track fixed bits along with enodes - notify solver about equalities discovered by congruence - implement query functions @@ -107,6 +106,27 @@ namespace polysat { reg_decl_plugins(m_ast); m_bv = alloc(bv_util, m_ast); m_egraph.set_display_justification(display_dep); + std::function propagate_negation = [&](enode* lit, enode* ante) { + // LOG("lit: " << lit->get_id() << " value=" << lit->value()); + // if (ante) + // LOG("ante: " << ante->get_id() << " value=" << ante->value()); + // else + // LOG("ante: "); + // LOG(m_egraph); + // ante may be set when symmetric equality is added by congruence + if (ante) + return; + // on_propagate may be called before set_value + if (lit->value() == l_undef) + return; + SASSERT(lit->is_equality()); + SASSERT_EQ(lit->value(), l_false); + SASSERT(lit->get_lit_justification().is_external()); + // LOG("lit: id=" << lit->get_id() << " value=" << lit->value() << " dep=" << decode_dep(lit->get_lit_justification().ext())); + m_disequality_conflict = lit; + }; + m_egraph.set_on_propagate(propagate_negation); + } slicing::slice_info& slicing::info(euf::enode* n) { @@ -114,6 +134,7 @@ namespace polysat { } slicing::slice_info const& slicing::info(euf::enode* n) const { + SASSERT(!n->is_equality()); slice_info const& i = m_info[n->get_id()]; return i.is_slice() ? i : info(i.slice); } @@ -131,6 +152,7 @@ namespace polysat { } void slicing::push_scope() { + SASSERT(!is_conflict()); if (can_propagate()) propagate(); m_scopes.push_back(m_trail.size()); @@ -156,6 +178,7 @@ namespace polysat { } m_egraph.pop(num_scopes); m_needs_congruence.reset(); + m_disequality_conflict = nullptr; } void slicing::add_var(unsigned bit_width) { @@ -168,6 +191,21 @@ namespace polysat { m_var2slice.pop_back(); } + slicing::enode* slicing::find_or_alloc_disequality(enode* x, enode* y, sat::literal lit) { + expr_ref eq(m_ast.mk_eq(x->get_expr(), y->get_expr()), m_ast); + enode* eqn = m_egraph.find(eq); + if (eqn) + return eqn; + auto args = {x, y}; + eqn = m_egraph.mk(eq, 0, args.size(), args.begin()); + auto j = euf::justification::external(encode_dep(lit)); + LOG("calling set_value"); + m_egraph.set_value(eqn, l_false, j); + SASSERT(eqn->is_equality()); + SASSERT_EQ(eqn->value(), l_false); + return eqn; + } + slicing::enode* slicing::alloc_enode(expr* e, unsigned num_args, enode* const* args, unsigned width, pvar var) { SASSERT(width > 0); SASSERT(!m_egraph.find(e)); @@ -460,7 +498,20 @@ namespace polysat { begin_explain(); SASSERT(m_tmp_justifications.empty()); m_egraph.begin_explain(); - m_egraph.explain(m_tmp_justifications, nullptr); + if (m_disequality_conflict) { + enode* eqn = m_disequality_conflict; + SASSERT(eqn->is_equality()); + SASSERT_EQ(eqn->value(), l_false); + SASSERT(eqn->get_lit_justification().is_external()); + SASSERT(m_ast.is_eq(eqn->get_expr())); + SASSERT_EQ(eqn->get_arg(0)->get_root(), eqn->get_arg(1)->get_root()); + m_egraph.explain_eq(m_tmp_justifications, nullptr, eqn->get_arg(0), eqn->get_arg(1)); + push_dep(eqn->get_lit_justification().ext(), out_lits, out_vars); + } + else { + SASSERT(m_egraph.inconsistent()); + m_egraph.explain(m_tmp_justifications, nullptr); + } m_egraph.end_explain(); for (void* dp : m_tmp_justifications) push_dep(dp, out_lits, out_vars); @@ -485,7 +536,7 @@ namespace polysat { SASSERT(!has_sub(s1)); SASSERT(!has_sub(s2)); m_egraph.merge(s1, s2, encode_dep(dep)); - return !m_egraph.inconsistent(); + return !is_conflict(); } bool slicing::merge(enode_vector& xs, enode_vector& ys, dep_t dep) { @@ -662,6 +713,7 @@ namespace polysat { } void slicing::add_constraint(signed_constraint c) { + SASSERT(!is_conflict()); if (!c->is_eq()) return; dep_t const d = c.blit(); @@ -672,9 +724,10 @@ namespace polysat { continue; pdd body = a.is_one() ? (m.mk_var(x) - p) : (m.mk_var(x) + p); // c is either x = body or x != body, depending on polarity - LOG("Equation from constraint " << c << ": v" << x << " = " << body); + LOG("Equation from lit(" << c.blit() << ") " << c << ": v" << x << " = " << body); enode* const sx = var2slice(x); - if (body.is_val()) { + if (c.is_positive() && body.is_val()) { + LOG(" simple assignment"); // Simple assignment x = value enode* const sval = mk_value_slice(body.val(), body.power_of_2()); if (!merge(sx, sval, d)) { @@ -685,7 +738,10 @@ namespace polysat { } pvar const y = m_solver.m_names.get_name(body); if (y == null_var) { + LOG(" skip for now (unnamed body)"); // TODO: register name trigger (if a name for value 'body' is created later, then merge x=y at that time) + // could also count how often 'body' was registered and introduce name when more than once. + // maybe better: register x as an existing name for 'body'? question is how to track the dependency on c. continue; } enode* const sy = var2slice(y); @@ -697,17 +753,17 @@ namespace polysat { } else { SASSERT(c.is_negative()); + enode* n = find_or_alloc_disequality(sy, sx, c.blit()); if (is_equal(sx, sy)) { - // TODO: conflict - NOT_IMPLEMENTED_YET(); - SASSERT(is_conflict()); - return; + SASSERT_EQ(m_disequality_conflict, n); // already discovered by egraph in simple examples... TODO: probably not when we need the slice congruences + // m_disequality_conflict = n; } } } } void slicing::add_value(pvar v, rational const& val) { + SASSERT(!is_conflict()); enode* const sv = var2slice(v); enode* const sval = mk_value_slice(val, width(sv)); (void)merge(sv, sval, v); @@ -766,6 +822,9 @@ namespace polysat { VERIFY(m_tmp2.empty()); VERIFY(m_tmp3.empty()); for (enode* s : m_egraph.nodes()) { + // we use equality enodes only to track disequalities + if (s->is_equality()) + continue; // if the slice is equivalent to a variable, then the variable's slice is in the equivalence class pvar const v = slice2var(s); if (v != null_var) { @@ -779,6 +838,8 @@ namespace polysat { VERIFY(has_value(sub_lo(s))); } } + // we don't need to store the width separately anymore + VERIFY_EQ(width(s), m_bv->get_bv_size(s->get_expr())); // properties below only matter for representatives if (!s->is_root()) continue; diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index a9914c469..0efd66b5a 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -63,10 +63,11 @@ namespace polysat { static constexpr unsigned null_cut = std::numeric_limits::max(); - // Kinds of slices: - // - proper (from variables) + // We use the following kinds of enodes: + // - proper slices (of variables) // - values // - virtual concat(...) expressions + // - equalities between enodes (to track disequalities; currently not represented in slice_info) struct slice_info { unsigned width = 0; // number of bits in the slice // Cut point: if not null_cut, the slice s has been subdivided into s[|s|-1:cut+1] and s[cut:0]. @@ -95,6 +96,7 @@ namespace polysat { slice_info_vector m_info; // indexed by enode::get_id() enode_vector m_var2slice; // pvar -> slice tracked_uint_set m_needs_congruence; // set of pvars that need updated concat(...) expressions + enode* m_disequality_conflict = nullptr; // Add an equation v = concat(s1, ..., sn) // for each variable v with base slices s1, ..., sn @@ -113,6 +115,7 @@ namespace polysat { enode* alloc_enode(expr* e, unsigned num_args, enode* const* args, unsigned width, pvar var); enode* find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, unsigned width, pvar var); enode* alloc_slice(unsigned width, pvar var = null_var); + enode* find_or_alloc_disequality(enode* x, enode* y, sat::literal lit); enode* var2slice(pvar v) const { return m_var2slice[v]; } pvar slice2var(enode* s) const { return info(s).var; } @@ -245,7 +248,7 @@ namespace polysat { // update congruences, egraph void propagate(); - bool is_conflict() const { return m_egraph.inconsistent(); } + bool is_conflict() const { return m_disequality_conflict || m_egraph.inconsistent(); } /** Extract reason for conflict */ void explain(sat::literal_vector& out_lits, unsigned_vector& out_vars); diff --git a/src/test/slicing.cpp b/src/test/slicing.cpp index 0864abfc8..8e88b2f7a 100644 --- a/src/test/slicing.cpp +++ b/src/test/slicing.cpp @@ -1,6 +1,24 @@ #include "math/polysat/slicing.h" #include "math/polysat/solver.h" +namespace { + + template + void permute_args(unsigned k, T& a, T& b, T& c) { + using std::swap; + SASSERT(k < 6); + unsigned i = k % 3; + unsigned j = k % 2; + if (i == 1) + swap(a, b); + else if (i == 2) + swap(a, c); + if (j == 1) + swap(b, c); + } + +} + namespace polysat { struct solver_scope_slicing { @@ -200,6 +218,61 @@ namespace polysat { VERIFY(sl.invariant()); } + static void test6() { + std::cout << __func__ << "\n"; + scoped_solver_slicing s; + slicing& sl = s.sl(); + pdd x = s.var(s.add_var(8)); + pdd y = s.var(s.add_var(8)); + pdd z = s.var(s.add_var(8)); + sl.add_constraint(s.eq(x, z)); + sl.add_constraint(s.eq(y, z)); + sl.add_constraint(s.eq(x, rational(5))); + sl.add_value(x.var(), rational(5)); + sl.add_value(y.var(), rational(7)); + + SASSERT(sl.is_conflict()); + sat::literal_vector reason_lits; + unsigned_vector reason_vars; + sl.explain(reason_lits, reason_vars); + std::cout << "Conflict: " << reason_lits << " vars " << reason_vars << "\n"; + + sl.display_tree(std::cout); + VERIFY(sl.invariant()); + } + + // x != z + // x = y + // y = z + // in various permutations + static void test7() { + std::cout << __func__ << "\n"; + scoped_set_log_enabled _logging(false); + scoped_solver_slicing s; + slicing& sl = s.sl(); + pdd x = s.var(s.add_var(8)); + pdd y = s.var(s.add_var(8)); + pdd z = s.var(s.add_var(8)); + + for (unsigned k = 0; k < 6; ++k) { + s.push(); + signed_constraint c1 = s.diseq(x, z); + signed_constraint c2 = s.eq(x, y); + signed_constraint c3 = s.eq(y, z); + permute_args(k, c1, c2, c3); + sl.add_constraint(c1); + sl.add_constraint(c2); + sl.add_constraint(c3); + SASSERT(sl.is_conflict()); + sat::literal_vector reason_lits; + unsigned_vector reason_vars; + sl.explain(reason_lits, reason_vars); + std::cout << "Conflict: " << reason_lits << " vars " << reason_vars << "\n"; + // sl.display_tree(std::cout); + VERIFY(sl.invariant()); + s.pop(); + } + } }; } @@ -207,10 +280,12 @@ namespace polysat { void tst_slicing() { using namespace polysat; - test_slicing::test1(); - test_slicing::test2(); - test_slicing::test3(); - test_slicing::test4(); - test_slicing::test5(); + // test_slicing::test1(); + // test_slicing::test2(); + // test_slicing::test3(); + // test_slicing::test4(); + // test_slicing::test5(); + // test_slicing::test6(); + test_slicing::test7(); std::cout << "ok\n"; }