From c124cbae975cf52b4d301f2214c514b876628107 Mon Sep 17 00:00:00 2001 From: Jakob Rath Date: Tue, 18 Jul 2023 14:47:44 +0200 Subject: [PATCH] Add virtual concat terms on demand during propagation --- src/math/polysat/slicing.cpp | 70 +++++++++++++++++++++++------------- src/math/polysat/slicing.h | 16 +++++---- src/test/slicing.cpp | 5 +++ 3 files changed, 60 insertions(+), 31 deletions(-) diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index 86ab11d39..f64940add 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -171,7 +171,6 @@ namespace polysat { for (unsigned i = arity; i-- > 0; ) domain.push_back(m_slice_sort); SASSERT_EQ(arity, domain.size()); - // TODO: mk_fresh_func_decl("concat", ...) if overload doesn't work decl = m_ast.mk_func_decl(symbol("slice-concat"), arity, domain.data(), m_slice_sort); m_concat_decls.setx(arity, decl); } @@ -212,10 +211,10 @@ namespace polysat { m_var2slice.pop_back(); } - slicing::enode* slicing::alloc_enode(expr* e, unsigned width, pvar var) { + slicing::enode* slicing::alloc_enode(expr* e, unsigned num_args, enode* const* args, unsigned width, pvar var) { SASSERT(width > 0); SASSERT(!m_egraph.find(e)); - euf::enode* n = m_egraph.mk(e, 0, 0, nullptr); // NOTE: the egraph keeps a strong reference to 'e' + euf::enode* n = m_egraph.mk(e, 0, num_args, args); // NOTE: the egraph keeps a strong reference to 'e' m_info.reserve(n->get_id() + 1); slice_info& i = info(n); i.reset(); @@ -224,20 +223,41 @@ namespace polysat { return n; } - slicing::enode* slicing::find_or_alloc_enode(expr* e, unsigned width, pvar var) { + slicing::enode* slicing::find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, unsigned width, pvar var) { enode* n = m_egraph.find(e); if (n) { SASSERT_EQ(info(n).width, width); SASSERT_EQ(info(n).var, var); return n; } - return alloc_enode(e, width, var); + return alloc_enode(e, num_args, args, width, var); } slicing::enode* slicing::alloc_slice(unsigned width, pvar var) { // app* a = m_ast.mk_fresh_const("s", m_bv->mk_sort(width), false); app* a = m_ast.mk_fresh_const("s", m_slice_sort, false); - return alloc_enode(a, width, var); + return alloc_enode(a, 0, nullptr, width, var); + } + + void slicing::add_congruence(pvar v) { + enode_vector& base = m_tmp2; + SASSERT(base.empty()); + enode* sv = var2slice(v); + get_base(sv, base); + // Add equation v == concat(s1, ..., sn) + ptr_vector args; + for (enode* n : base) + args.push_back(n->get_expr()); + app* a = m_ast.mk_app(get_concat_decl(args.size()), args); + enode* concat = find_or_alloc_enode(a, base.size(), base.data(), width(sv), null_var); + base.clear(); + m_egraph.merge(sv, concat, encode_dep(dep_t())); + } + + void slicing::update_var_congruences() { + for (pvar v : m_needs_congruence) + add_congruence(v); + m_needs_congruence.reset(); } // split a single slice without updating any equivalences @@ -257,22 +277,15 @@ namespace polysat { else { sub_hi = alloc_slice(width_hi); sub_lo = alloc_slice(width_lo); - // info(sub_hi).parent = s; - // info(sub_lo).parent = s; + info(sub_hi).parent = s; + info(sub_lo).parent = s; } info(s).set_cut(cut, sub_hi, sub_lo); m_trail.push_back(trail_item::split_core); m_split_trail.push_back(s); - - // // s = hi ++ lo ... TODO: necessary??? probably not - // euf::enode* s_n = slice2enode(s); - // euf::enode* hi_n = slice2enode(sub_hi); - // euf::enode* lo_n = slice2enode(sub_lo); - // app* a = m_ast.mk_app(get_concat_decl(2), hi_n->get_expr(), lo_n->get_expr()); - // auto args = {hi_n, lo_n}; - // euf::enode* concat_n = m_egraph.mk(a, 0, args.size(), blup.begin()); - // m_egraph.merge(s_n, concat_n, encode_dep(null_dep)); - // SASSERT(!concat_n->is_root()); // else we have to register it in enode2slice + for (enode* n = s; n != nullptr; n = parent(n)) + if (slice2var(n) != null_var) + m_needs_congruence.insert(slice2var(n)); } void slicing::undo_split_core() { @@ -295,14 +308,14 @@ namespace polysat { m_egraph.merge(sub_hi(n), sub_hi(target), j.ext()); m_egraph.merge(sub_lo(n), sub_lo(target), j.ext()); } - m_egraph.propagate(); // TODO: could do this later + // m_egraph.propagate(); } slicing::enode* slicing::mk_value_slice(rational const& val, unsigned bit_width) { SASSERT(0 <= val && val < rational::power_of_two(bit_width)); app* a = m_bv->mk_numeral(val, bit_width); a = m_ast.mk_app(get_embed_decl(bit_width), a); // adjust sort - enode* s = find_or_alloc_enode(a, bit_width, null_var); + enode* s = find_or_alloc_enode(a, 0, nullptr, bit_width, null_var); s->mark_interpreted(); SASSERT(s->interpreted()); SASSERT_EQ(get_value(s), val); @@ -424,6 +437,14 @@ namespace polysat { end_explain(); } + void slicing::propagate() { + // m_egraph.propagate(); + if (is_conflict()) + return; + update_var_congruences(); + m_egraph.propagate(); + } + bool slicing::merge_base(enode* s1, enode* s2, dep_t dep) { SASSERT_EQ(width(s1), width(s2)); SASSERT(!has_sub(s1)); @@ -673,7 +694,7 @@ namespace polysat { } void slicing::add_constraint(signed_constraint c) { - // TODO: evaluate under current assignment? + // TODO: evaluate under current assignment? (no, do that externally) if (!c->is_eq()) return; dep_t const d = c.blit(); @@ -690,8 +711,7 @@ namespace polysat { // Simple assignment x = value enode* const sval = mk_value_slice(body.val(), body.power_of_2()); if (!merge(sx, sval, d)) { - // TODO: conflict - NOT_IMPLEMENTED_YET(); + SASSERT(is_conflict()); return; } continue; @@ -704,8 +724,7 @@ namespace polysat { enode* const sy = var2slice(y); if (c.is_positive()) { if (!merge(sx, sy, d)) { - // TODO: conflict - NOT_IMPLEMENTED_YET(); + SASSERT(is_conflict()); return; } } @@ -714,6 +733,7 @@ namespace polysat { if (is_equal(sx, sy)) { // TODO: conflict NOT_IMPLEMENTED_YET(); + SASSERT(is_conflict()); return; } } diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index 7147f6d34..61f7f438c 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -73,7 +73,7 @@ namespace polysat { // The cut point is relative to the parent slice (rather than a root variable, which might not be unique) unsigned cut = null_cut; // cut point, or null_cut if no subslices pvar var = null_var; // slice is equivalent to this variable, if any (without dependencies) - // enode* parent = nullptr; // parent slice, only for proper slices (if not null: s == sub_hi(parent(s)) || s == sub_lo(parent(s))) + enode* parent = nullptr; // parent slice, only for proper slices (if not null: s == sub_hi(parent(s)) || s == sub_lo(parent(s))) enode* slice = nullptr; // if enode corresponds to a concat(...) expression, this field links to the represented slice. enode* sub_hi = nullptr; // upper subslice s[|s|-1:cut+1] enode* sub_lo = nullptr; // lower subslice s[cut:0] @@ -97,10 +97,12 @@ namespace polysat { euf::egraph m_egraph; slice_info_vector m_info; // indexed by enode::get_id() enode_vector m_var2slice; // pvar -> slice + tracked_uint_set m_needs_congruence; // set of pvars that need updated concat(...) expressions - - - + // Add an equation v = concat(s1, ..., sn) + // for each variable v with base slices s1, ..., sn + void update_var_congruences(); + void add_congruence(pvar v); func_decl* get_embed_decl(unsigned bit_width); func_decl* get_concat_decl(unsigned arity); @@ -112,8 +114,8 @@ namespace polysat { slice_info& info(euf::enode* n); slice_info const& info(euf::enode* n) const; - enode* alloc_enode(expr* e, unsigned width, pvar var); - enode* find_or_alloc_enode(expr* e, unsigned width, pvar var); + enode* alloc_enode(expr* e, unsigned num_args, enode* const* args, unsigned width, pvar var); + enode* find_or_alloc_enode(expr* e, unsigned num_args, enode* const* args, unsigned width, pvar var); enode* alloc_slice(unsigned width, pvar var = null_var); enode* var2slice(pvar v) const { return m_var2slice[v]; } @@ -121,6 +123,8 @@ namespace polysat { unsigned width(enode* s) const { return info(s).width; } + enode* parent(enode* s) const { return info(s).parent; } + bool has_sub(enode* s) const { return info(s).has_sub(); } /// Upper subslice (direct child, not necessarily the representative) diff --git a/src/test/slicing.cpp b/src/test/slicing.cpp index 67324ed2e..0864abfc8 100644 --- a/src/test/slicing.cpp +++ b/src/test/slicing.cpp @@ -128,6 +128,11 @@ namespace polysat { sl.explain_equal(sl.var2slice(b), sl.pdd2slice(d), reason_lits, reason_vars); std::cout << " Reason: " << reason_lits << " vars " << reason_vars << "\n"; + sl.display_tree(std::cout); + VERIFY(sl.invariant()); + + sl.propagate(); + sl.display_tree(std::cout); VERIFY(sl.invariant()); }