diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 0af855baf..9cfbe941a 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -100,6 +100,7 @@ namespace euf { for (enode* child : enode_args(n)) child->get_root()->add_parent(n); n->set_update_children(); + m_updates.push_back(update_record(n, update_record::update_children())); } enode* egraph::mk(expr* f, unsigned generation, unsigned num_args, enode *const* args) { @@ -264,23 +265,21 @@ namespace euf { void egraph::set_merge_enabled(enode* n, bool enable_merge) { if (enable_merge != n->merge_enabled()) { - toggle_merge_enabled(n); + toggle_merge_enabled(n, false); m_updates.push_back(update_record(n, update_record::toggle_merge())); - if (enable_merge && n->num_args() > 0) { - auto [n2, comm] = insert_table(n); - if (n2 != n) - merge(n, n2, justification::congruence(comm)); - } } } - void egraph::toggle_merge_enabled(enode* n) { + void egraph::toggle_merge_enabled(enode* n, bool backtracking) { bool enable_merge = !n->merge_enabled(); n->set_merge_enabled(enable_merge); if (n->num_args() > 0) { - if (enable_merge) - insert_table(n); - else if (m_table.contains_ptr(n)) + if (enable_merge) { + auto [n2, comm] = insert_table(n); + if (n2 != n && !backtracking) + m_to_merge.push_back(to_merge(n, n2, comm)); + } + else if (n->is_cgr()) erase_from_table(n); } VERIFY(n->num_args() == 0 || !n->merge_enabled() || m_table.contains(n)); @@ -337,14 +336,15 @@ namespace euf { m_nodes.pop_back(); m_exprs.pop_back(); }; - for (unsigned i = m_updates.size(); i-- > num_updates; ) { + unsigned sz = m_updates.size(); + for (unsigned i = sz; i-- > num_updates; ) { auto const& p = m_updates[i]; switch (p.tag) { case update_record::tag_t::is_add_node: undo_node(); break; case update_record::tag_t::is_toggle_merge: - toggle_merge_enabled(p.r1); + toggle_merge_enabled(p.r1, true); break; case update_record::tag_t::is_set_parent: undo_eq(p.r1, p.n1, p.r2_num_parents); @@ -381,12 +381,18 @@ namespace euf { case update_record::tag_t::is_lbl_set: p.r1->m_lbls.set(p.m_lbls); break; + case update_record::tag_t::is_update_children: + for (unsigned i = 0; i < p.r1->num_args(); ++i) { + SASSERT(p.r1->m_args[i]->get_root()->m_parents.back() == p.r1); + p.r1->m_args[i]->get_root()->m_parents.pop_back(); + } + break; default: UNREACHABLE(); break; - } + } } - + SASSERT(m_updates.size() == sz); m_updates.shrink(num_updates); m_scopes.shrink(old_lim); m_region.pop_scope(num_scopes); diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index ac12f6cef..5c828678e 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -104,7 +104,8 @@ namespace euf { struct value_assignment {}; struct lbl_hash {}; struct lbl_set {}; - enum class tag_t { is_set_parent, is_add_node, is_toggle_merge, + struct update_children {}; + enum class tag_t { is_set_parent, is_add_node, is_toggle_merge, is_update_children, is_add_th_var, is_replace_th_var, is_new_lit, is_new_th_eq, is_lbl_hash, is_new_th_eq_qhead, is_new_lits_qhead, is_inconsistent, is_value_assignment, is_lbl_set }; @@ -148,6 +149,8 @@ namespace euf { tag(tag_t::is_lbl_hash), r1(n), n1(nullptr), m_lbl_hash(n->m_lbl_hash) {} update_record(enode* n, lbl_set): tag(tag_t::is_lbl_set), r1(n), n1(nullptr), m_lbls(n->m_lbls.get()) {} + update_record(enode* n, update_children) : + tag(tag_t::is_update_children), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} }; ast_manager& m; svector m_to_merge; @@ -211,7 +214,7 @@ namespace euf { void push_to_lca(enode* a, enode* lca); void push_congruence(enode* n1, enode* n2, bool commutative); void push_todo(enode* n); - void toggle_merge_enabled(enode* n); + void toggle_merge_enabled(enode* n, bool backtracking); enode_bool_pair insert_table(enode* p); void erase_from_table(enode* p); diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index d2b4cf02a..73405b77e 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -46,7 +46,6 @@ namespace euf { bool m_mark1 = false; bool m_mark2 = false; bool m_commutative = false; - bool m_update_children = false; bool m_interpreted = false; bool m_merge_enabled = true; bool m_is_equality = false; // Does the expression represent an equality @@ -124,10 +123,7 @@ namespace euf { n->m_args[i] = nullptr; return n; } - - void set_update_children() { m_update_children = true; } - - + friend class add_th_var_trail; friend class replace_th_var_trail; void add_th_var(theory_var v, theory_id id, region & r) { m_th_vars.add_var(v, id, r); } @@ -142,12 +138,6 @@ namespace euf { ~enode() { SASSERT(m_root == this); SASSERT(class_size() == 1); - if (m_update_children) { - for (unsigned i = 0; i < num_args(); ++i) { - SASSERT(m_args[i]->get_root()->m_parents.back() == this); - m_args[i]->get_root()->m_parents.pop_back(); - } - } } enode* const* args() const { return m_args; }