diff --git a/src/math/polysat/constraint_manager.cpp b/src/math/polysat/constraint_manager.cpp index d3ebad245..371b252aa 100644 --- a/src/math/polysat/constraint_manager.cpp +++ b/src/math/polysat/constraint_manager.cpp @@ -616,8 +616,7 @@ namespace polysat { unsigned const v_sz = p_sz + q_sz; if (p.is_val() || q.is_val()) return zero_ext(p, q_sz) * rational::power_of_two(q_sz) + zero_ext(q, p_sz); - auto const args = {s.m_names.mk_name(p), s.m_names.mk_name(q)}; - pvar const v = s.m_slicing.mk_concat(args.size(), args.begin()); + pvar const v = s.m_slicing.mk_concat({s.m_names.mk_name(p), s.m_names.mk_name(q)}); return s.var(v); } diff --git a/src/math/polysat/slicing.cpp b/src/math/polysat/slicing.cpp index 97190e9df..cf560cae9 100644 --- a/src/math/polysat/slicing.cpp +++ b/src/math/polysat/slicing.cpp @@ -186,18 +186,62 @@ namespace polysat { unsigned const target_lvl = lvl - num_scopes; unsigned const target_size = m_scopes[target_lvl]; m_scopes.shrink(target_lvl); - while (m_trail.size() > target_size) { - switch (m_trail.back()) { - case trail_item::add_var: undo_add_var(); break; - case trail_item::split_core: undo_split_core(); break; - case trail_item::mk_extract: undo_mk_extract(); break; + unsigned_vector replay_add_var; + svector> replay_extract; + unsigned num_replay_concat = 0; + for (unsigned i = m_trail.size(); i-- > target_size; ) { + switch (m_trail[i]) { + case trail_item::add_var: + replay_add_var.push_back(width(m_var2slice.back())); + undo_add_var(); + break; + case trail_item::split_core: + undo_split_core(); + break; + case trail_item::mk_extract: + extract_args const& args = m_extract_trail.back(); + replay_extract.push_back({args, m_extract_dedup[args]}); + undo_mk_extract(); + break; + case trail_item::mk_concat: + num_replay_concat++; + break; default: UNREACHABLE(); } - m_trail.pop_back(); } m_egraph.pop(num_scopes); m_needs_congruence.reset(); m_disequality_conflict = nullptr; + // replay add_var/mk_extract/mk_concat in the same order + // (only until polysat::solver supports proper garbage collection of variables) + unsigned add_var_idx = replay_add_var.size(); + unsigned extract_idx = replay_extract.size(); + unsigned concat_idx = m_concat_trail.size() - num_replay_concat; + for (unsigned i = target_size; i < m_trail.size(); ++i) { + switch (m_trail[i]) { + case trail_item::add_var: + unsigned const sz = replay_add_var[--add_var_idx]; + add_var(replay_add_var[i]); + break; + case trail_item::split_core: + /* do nothing */ + break; + case trail_item::mk_extract: + auto const [args, v] = replay_extract[--extract_idx]; + this->replay_extract(args, v); + break; + case trail_item::mk_concat: + auto ci = m_concat_trail[concat_idx++]; + num_replay_concat++; + replay_concat(ci.num_args, &m_concat_args[ci.args_idx], ci.v); + break; + default: UNREACHABLE(); + } + + } + m_concat_trail.shrink(m_concat_trail.size() - num_replay_concat); + m_concat_args.shrink(m_concat_trail.empty() ? 0 : m_concat_trail.back().next_args_idx()); + m_trail.shrink(target_size); } void slicing::add_var(unsigned bit_width) { @@ -724,7 +768,7 @@ namespace polysat { get_base_core(src, out_base); } - pvar slicing::mk_extract(enode* src, unsigned hi, unsigned lo) { + pvar slicing::mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var = null_var) { enode_vector& slices = m_tmp3; SASSERT(slices.empty()); mk_slice(src, hi, lo, slices, false, true); @@ -732,6 +776,15 @@ namespace polysat { // try to re-use variable of an existing slice if (slices.size() == 1) v = slice2var(slices[0]); + if (replay_var != null_var && v != replay_var) { + // replayed variable should be 'fresh', unless it was a re-used variable + enode* s = var2slice(replay_var); + SASSERT(s->is_root()); + SASSERT_EQ(s->class_size(), 1); + SASSERT(!has_sub(s)); + SASSERT_EQ(width(s), hi - lo + 1); + v = replay_var; + } // allocate new variable if we cannot reuse it if (v == null_var) v = m_solver.add_var(hi - lo + 1); @@ -752,6 +805,15 @@ namespace polysat { return v; } + void slicing::replay_extract(extract_args const& args, pvar r) { + SASSERT(r != null_var); + SASSERT(!m_extract_dedup.contains(args)); + VERIFY_EQ(mk_extract(var2slice(args.src), args.hi, args.lo, r), r); + m_extract_dedup.insert(args, r); + m_extract_trail.push_back(args); + m_trail.push_back(trail_item::mk_extract); + } + pvar slicing::mk_extract(pvar src, unsigned hi, unsigned lo) { extract_args args{src, hi, lo}; auto it = m_extract_dedup.find_iterator(args); @@ -770,7 +832,7 @@ namespace polysat { m_extract_dedup.remove(args); } - pvar slicing::mk_concat(unsigned num_args, pvar const* args) { + pvar slicing::mk_concat(unsigned num_args, pvar const* args, pvar replay_var) { enode_vector& slices = m_tmp3; SASSERT(slices.empty()); unsigned total_width = 0; @@ -780,22 +842,49 @@ namespace polysat { total_width += width(s); } // NOTE: we use concat nodes to deduplicate (syntactically equal) concat expressions. - // we might end up reusing variables that are not introduced by mk_concat (if we enable the variable re-use optimizatio in mk_extract), + // we might end up reusing variables that are not introduced by mk_concat (if we enable the variable re-use optimization in mk_extract), // but because such congruence nodes are only added over direct descendants, we do not get unwanted dependencies from this re-use. // (but note that the nodes from mk_concat are not only over direct descendants) enode* concat = mk_concat_node(slices); pvar v = slice2var(concat); if (v != null_var) return v; - v = m_solver.add_var(total_width); + if (replay_var != null_var) { + // replayed variable should be 'fresh' + enode* s = var2slice(replay_var); + SASSERT(s->is_root()); + SASSERT_EQ(s->class_size(), 1); + SASSERT(!has_sub(s)); + SASSERT_EQ(width(s), total_width); + v = replay_var; + } + else + v = m_solver.add_var(total_width); enode* sv = var2slice(v); VERIFY(merge(slices, sv, dep_t())); // NOTE: add_concat_node must be done after merge to preserve the invariant: "a base slice is never equivalent to a congruence node". add_concat_node(sv, concat); slices.reset(); + + // Note about the early return above: + // all such variables should have been introduced by mk_extract or mk_concat, so replay will properly restore them + concat_info ci; + ci.v = v; + ci.num_args = num_args; + ci.args_idx = m_concat_args.size(); + m_concat_trail.push_back(ci); + for (unsigned i = 0; i < num_args; ++i) + m_concat_args.push_back(args[i]); + m_trail.push_back(trail_item::mk_concat); + return v; } + void slicing::replay_concat(unsigned num_args, pvar const* args, pvar r) { + SASSERT(r != null_var); + VERIFY_EQ(mk_concat(num_args, args, r), r); + } + pvar slicing::mk_concat(std::initializer_list args) { return mk_concat(args.size(), args.begin()); } diff --git a/src/math/polysat/slicing.h b/src/math/polysat/slicing.h index cbd2f90d2..a1bc8e4f7 100644 --- a/src/math/polysat/slicing.h +++ b/src/math/polysat/slicing.h @@ -214,12 +214,22 @@ namespace polysat { add_var, split_core, mk_extract, + mk_concat, }; svector m_trail; enode_vector m_split_trail; svector m_extract_trail; unsigned_vector m_scopes; + struct concat_info { + pvar v; + unsigned num_args; + unsigned args_idx; + unsigned next_args_idx() const { args_idx + num_args; } + }; + svector m_concat_trail; + svector m_concat_args; + void undo_add_var(); void undo_split_core(); void undo_mk_extract(); @@ -232,7 +242,12 @@ namespace polysat { uint_set m_marked_vars; /** Get variable representing src[hi:lo] */ - pvar mk_extract(enode* src, unsigned hi, unsigned lo); + pvar mk_extract(enode* src, unsigned hi, unsigned lo, pvar replay_var = null_var); + /** Restore r = src[hi:lo] */ + void replay_extract(extract_args const& args, pvar r); + + pvar mk_concat(unsigned num_args, pvar const* args, pvar replay_var); + void replay_concat(unsigned num_args, pvar const* args, pvar r); bool add_equation(pvar x, pdd const& body, sat::literal lit); @@ -254,7 +269,7 @@ namespace polysat { pvar mk_extract(pvar x, unsigned hi, unsigned lo); /** Get or create variable representing x1 ++ x2 ++ ... ++ xn */ - pvar mk_concat(unsigned num_args, pvar const* args); + pvar mk_concat(unsigned num_args, pvar const* args) { return mk_concat(num_args, args, null_var); } pvar mk_concat(std::initializer_list args); // Track value assignments to variables (and propagate to subslices)