diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 4a447b44f..b3a1d8516 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -21,6 +21,28 @@ Author: namespace euf { + /** + \brief Trail for add_th_var + */ + class add_th_var_trail : public trail { + enode * m_enode; + theory_id m_th_id; + public: + add_th_var_trail(enode * n, theory_id th_id): + m_enode(n), + m_th_id(th_id) { + } + + void undo(egraph & ctx) override { + theory_var v = m_enode->get_th_var(m_th_id); + SASSERT(v != null_var); + m_enode->del_th_var(m_th_id); + enode * root = m_enode->get_root(); + if (root != m_enode && root->get_th_var(m_th_id) == v) + root->del_th_var(m_th_id); + } + }; + void egraph::undo_eq(enode* r1, enode* n1, unsigned r2_num_parents) { enode* r2 = r1->get_root(); r2->dec_class_size(r1->class_size()); @@ -89,6 +111,7 @@ namespace euf { s.m_inconsistent = m_inconsistent; s.m_num_eqs = m_eqs.size(); s.m_num_nodes = m_nodes.size(); + s.m_trail_sz = m_trail.size(); m_scopes.push_back(s); m_region.push_scope(); } @@ -135,6 +158,14 @@ namespace euf { n->m_parents.finalize(); } + void egraph::add_th_var(enode* n, theory_var v, theory_id id) { + force_push(); + SASSERT(null_var == n->get_th_var(id)); + SASSERT(n->class_size() == 1); + n->add_th_var(v, id, m_region); + m_trail.push_back(new (m_region) add_th_var_trail(n, id)); + } + void egraph::pop(unsigned num_scopes) { if (num_scopes <= m_num_scopes) { m_num_scopes -= num_scopes; @@ -154,6 +185,7 @@ namespace euf { m_expr2enode[n->get_owner_id()] = nullptr; n->~enode(); } + undo_trail_stack(*this, m_trail, s.m_trail_sz); m_inconsistent = s.m_inconsistent; m_eqs.shrink(s.m_num_eqs); m_nodes.shrink(s.m_num_nodes); @@ -194,12 +226,30 @@ namespace euf { std::swap(r1->m_next, r2->m_next); r2->inc_class_size(r1->class_size()); r2->m_parents.append(r1->m_parents); + merge_th_eq(r1, r2); m_worklist.push_back(r2); } + void egraph::merge_th_eq(enode* n, enode* root) { + SASSERT(n != root); + for (auto iv : enode_th_vars(n)) { + theory_id id = iv.get_id(); + theory_var v = root->get_th_var(id); + if (v == null_var) { + root->add_th_var(iv.get_var(), id, m_region); + m_trail.push_back(new (m_region) add_th_var_trail(root, id)); + } + else { + SASSERT(v != iv.get_var()); + m_new_th_eqs.push_back(th_eq(id, v, iv.get_var(), n, root)); + } + } + } + void egraph::propagate() { m_new_eqs.reset(); m_new_lits.reset(); + m_new_th_eqs.reset(); SASSERT(m_num_scopes == 0 || m_worklist.empty()); unsigned head = 0, tail = m_worklist.size(); while (head < tail && m.limit().inc() && !inconsistent()) { @@ -315,7 +365,7 @@ namespace euf { } template - void egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b, bool comm) { + void egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b) { SASSERT(m_todo.empty()); SASSERT(a->get_root() == b->get_root()); enode* lca = find_lca(a, b); @@ -394,6 +444,7 @@ namespace euf { for (unsigned i = 0; i < src.m_nodes.size(); ++i) { enode* n1 = src.m_nodes[i]; expr* e1 = src.m_exprs[i]; + SASSERT(!n1->has_th_vars()); args.reset(); for (unsigned j = 0; j < n1->num_args(); ++j) args.push_back(old_expr2new_enode[n1->get_arg(j)->get_owner_id()]); @@ -418,9 +469,9 @@ namespace euf { template void euf::egraph::explain(ptr_vector& justifications); template void euf::egraph::explain_todo(ptr_vector& justifications); -template void euf::egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b, bool comm); +template void euf::egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b); template void euf::egraph::explain(ptr_vector& justifications); template void euf::egraph::explain_todo(ptr_vector& justifications); -template void euf::egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b, bool comm); +template void euf::egraph::explain_eq(ptr_vector& justifications, enode* a, enode* b); diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index de7accda6..72313a88a 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -25,6 +25,7 @@ Notes: #pragma once #include "util/statistics.h" +#include "util/trail.h" #include "ast/euf/euf_enode.h" #include "ast/euf/euf_etable.h" @@ -37,12 +38,24 @@ namespace euf { add_eq_record(enode* r1, enode* n1, unsigned r2_num_parents): r1(r1), n1(n1), r2_num_parents(r2_num_parents) {} }; + + struct th_eq { + theory_id m_id; + theory_var m_v1; + theory_var m_v2; + enode* m_child; + enode* m_root; + th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) : + m_id(id), m_v1(v1), m_v2(v2), m_child(c), m_root(r) {} + }; class egraph { + typedef ptr_vector > trail_stack; struct scope { bool m_inconsistent; unsigned m_num_eqs; unsigned m_num_nodes; + unsigned m_trail_sz; }; struct stats { unsigned m_num_merge; @@ -53,6 +66,7 @@ namespace euf { void reset() { memset(this, 0, sizeof(*this)); } }; ast_manager& m; + trail_stack m_trail; region m_region; enode_vector m_worklist; etable m_table; @@ -68,10 +82,11 @@ namespace euf { justification m_justification; enode_vector m_new_eqs; enode_vector m_new_lits; + svector m_new_th_eqs; enode_vector m_todo; stats m_stats; std::function m_used_eq; - std::function m_used_cc; + std::function m_used_cc; void push_eq(enode* r1, enode* n1, unsigned r2_num_parents) { m_eqs.push_back(add_eq_record(r1, n1, r2_num_parents)); @@ -82,6 +97,7 @@ namespace euf { void force_push(); void set_conflict(enode* n1, enode* n2, justification j); void merge(enode* n1, enode* n2, justification j); + void merge_th_eq(enode* n, enode* root); void merge_justification(enode* n1, enode* n2, justification j); void unmerge_justification(enode* n1); void dedup_equalities(); @@ -132,6 +148,9 @@ namespace euf { bool inconsistent() const { return m_inconsistent; } enode_vector const& new_eqs() const { return m_new_eqs; } enode_vector const& new_lits() const { return m_new_lits; } + svector const& new_th_eqs() const { return m_new_th_eqs; } + + void add_th_var(enode* n, theory_var v, theory_id id); void set_used_eq(std::function& used_eq) { m_used_eq = used_eq; } void set_used_cc(std::function& used_cc) { m_used_cc = used_cc; } @@ -139,7 +158,7 @@ namespace euf { template void explain(ptr_vector& justifications); template - void explain_eq(ptr_vector& justifications, enode* a, enode* b, bool comm); + void explain_eq(ptr_vector& justifications, enode* a, enode* b); enode_vector const& nodes() const { return m_nodes; } void invariant(); void copy_from(egraph const& src, std::function& copy_justification); diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index 59166df31..ad330f0b5 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -16,6 +16,7 @@ Author: --*/ #include "util/vector.h" +#include "util/id_var_list.h" #include "ast/ast.h" #include "ast/euf/euf_justification.h" @@ -28,6 +29,11 @@ namespace euf { typedef ptr_vector enode_vector; typedef std::pair enode_pair; typedef svector enode_pair_vector; + typedef id_var_list<> th_var_list; + typedef int theory_var; + typedef int theory_id; + const theory_var null_var = -1; + const theory_id null_id = -1; class enode { expr* m_owner; @@ -42,6 +48,7 @@ namespace euf { enode* m_next; enode* m_root; enode* m_target { nullptr }; + th_var_list m_th_vars; justification m_justification; unsigned m_num_args; enode* m_args[0]; @@ -49,6 +56,7 @@ namespace euf { friend class enode_args; friend class enode_parents; friend class enode_class; + friend class enode_th_vars; friend class etable; friend class egraph; @@ -73,6 +81,12 @@ namespace euf { } void set_update_children() { m_update_children = true; } + + friend class add_th_var_trail; + void add_th_var(theory_var v, theory_id id, region & r) { m_th_vars.add_var(v, id, r); } + void replace_th_var(theory_var v, theory_id id) { m_th_vars.replace(v, id); } + void del_th_var(theory_id id) { m_th_vars.del_var(id); } + public: ~enode() { SASSERT(m_root == this); @@ -127,6 +141,9 @@ namespace euf { expr* get_owner() const { return m_owner; } unsigned get_owner_id() const { return m_owner->get_id(); } unsigned get_root_id() const { return m_root->m_owner->get_id(); } + theory_var get_th_var(theory_id id) const { return m_th_vars.find(id); } + bool has_th_vars() const { return !m_th_vars.empty(); } + void inc_class_size(unsigned n) { m_class_size += n; } void dec_class_size(unsigned n) { m_class_size -= n; } @@ -177,4 +194,24 @@ namespace euf { iterator begin() const { return iterator(&n, nullptr); } iterator end() const { return iterator(&n, &n); } }; + + class enode_th_vars { + enode& n; + public: + class iterator { + th_var_list* m_th_vars; + public: + iterator(th_var_list* n) : m_th_vars(n) {} + th_var_list operator*() { return *m_th_vars; } + iterator& operator++() { m_th_vars = m_th_vars->get_next(); return *this; } + iterator operator++(int) { iterator tmp = *this; ++* this; return tmp; } + bool operator==(iterator const& other) const { return m_th_vars == other.m_th_vars; } + bool operator!=(iterator const& other) const { return !(*this == other); } + }; + enode_th_vars(enode& _n) :n(_n) {} + enode_th_vars(enode* _n) :n(*_n) {} + iterator begin() const { return iterator(n.m_th_vars.empty() ? nullptr : &n.m_th_vars); } + iterator end() const { return iterator(nullptr); } + }; + } diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 3fbd08de5..7b4b9e8c4 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -127,13 +127,13 @@ namespace euf { SASSERT(n); SASSERT(m_egraph.is_equality(n)); SASSERT(!l.sign()); - m_egraph.explain_eq(m_explain, n->get_arg(0), n->get_arg(1), n->commutative()); + m_egraph.explain_eq(m_explain, n->get_arg(0), n->get_arg(1)); break; case constraint::kind_t::lit: n = m_var2node[l.var()]; SASSERT(n); SASSERT(m.is_bool(n->get_owner())); - m_egraph.explain_eq(m_explain, n, (l.sign() ? mk_false() : mk_true()), false); + m_egraph.explain_eq(m_explain, n, (l.sign() ? mk_false() : mk_true())); break; default: IF_VERBOSE(0, verbose_stream() << (unsigned)j.kind() << "\n"); diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 37ac66518..ae5524095 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -781,10 +781,10 @@ namespace smt { if (r2->m_th_var_list.get_next() == nullptr && r1->m_th_var_list.get_next() == nullptr) { // Common case: r2 and r1 have at most one theory var. - theory_id t2 = r2->m_th_var_list.get_th_id(); - theory_id t1 = r1->m_th_var_list.get_th_id(); - theory_var v2 = m_fparams.m_new_core2th_eq ? get_closest_var(n2, t2) : r2->m_th_var_list.get_th_var(); - theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t1) : r1->m_th_var_list.get_th_var(); + theory_id t2 = r2->m_th_var_list.get_id(); + theory_id t1 = r1->m_th_var_list.get_id(); + theory_var v2 = m_fparams.m_new_core2th_eq ? get_closest_var(n2, t2) : r2->m_th_var_list.get_var(); + theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t1) : r1->m_th_var_list.get_var(); TRACE("merge_theory_vars", tout << "v2: " << v2 << " #" << r2->get_owner_id() << ", v1: " << v1 << " #" << r1->get_owner_id() << ", t2: " << t2 << ", t1: " << t1 << "\n";); @@ -805,8 +805,8 @@ namespace smt { push_new_th_diseqs(r1, v2, get_theory(t2)); } else if (v1 != null_theory_var && v2 == null_theory_var) { - r2->m_th_var_list.set_th_var(v1); - r2->m_th_var_list.set_th_id(t1); + r2->m_th_var_list.set_var(v1); + r2->m_th_var_list.set_id(t1); TRACE("merge_theory_vars", tout << "push_new_th_diseqs v1: " << v1 << ", t1: " << t1 << "\n";); push_new_th_diseqs(r2, v1, get_theory(t1)); } @@ -819,8 +819,8 @@ namespace smt { theory_var_list * l2 = r2->get_th_var_list(); while (l2) { - theory_id t2 = l2->get_th_id(); - theory_var v2 = m_fparams.m_new_core2th_eq ? get_closest_var(n2, t2) : l2->get_th_var(); + theory_id t2 = l2->get_id(); + theory_var v2 = m_fparams.m_new_core2th_eq ? get_closest_var(n2, t2) : l2->get_var(); SASSERT(v2 != null_theory_var); SASSERT(t2 != null_theory_id); theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t2) : r1->get_th_var(t2); @@ -838,8 +838,8 @@ namespace smt { theory_var_list * l1 = r1->get_th_var_list(); while (l1) { - theory_id t1 = l1->get_th_id(); - theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t1) : l1->get_th_var(); + theory_id t1 = l1->get_id(); + theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t1) : l1->get_var(); SASSERT(v1 != null_theory_var); SASSERT(t1 != null_theory_id); theory_var v2 = r2->get_th_var(t1); @@ -973,13 +973,13 @@ namespace smt { // restore theory vars if (r2->m_th_var_list.get_next() == nullptr) { // common case: r2 has at most one variable - theory_var v2 = r2->m_th_var_list.get_th_var(); + theory_var v2 = r2->m_th_var_list.get_var(); if (v2 != null_theory_var) { - theory_id t2 = r2->m_th_var_list.get_th_id(); + theory_id t2 = r2->m_th_var_list.get_id(); if (get_theory(t2)->get_enode(v2)->get_root() != r2) { SASSERT(get_theory(t2)->get_enode(v2)->get_root() == r1); - r2->m_th_var_list.set_th_var(null_theory_var); //remove variable from r2 - r2->m_th_var_list.set_th_id(null_theory_id); + r2->m_th_var_list.set_var(null_theory_var); //remove variable from r2 + r2->m_th_var_list.set_id(null_theory_id); } } } @@ -1019,8 +1019,8 @@ namespace smt { theory_var_list * new_l2 = nullptr; theory_var_list * l2 = r2->get_th_var_list(); while (l2) { - theory_var v2 = l2->get_th_var(); - theory_id t2 = l2->get_th_id(); + theory_var v2 = l2->get_var(); + theory_id t2 = l2->get_id(); if (get_theory(t2)->get_enode(v2)->get_root() != r2) { SASSERT(get_theory(t2)->get_enode(v2)->get_root() == r1); @@ -1043,7 +1043,7 @@ namespace smt { new_l2->set_next(nullptr); } else { - r2->m_th_var_list.set_th_var(null_theory_var); + r2->m_th_var_list.set_var(null_theory_var); r2->m_th_var_list.set_next(nullptr); } } @@ -1070,7 +1070,7 @@ namespace smt { TRACE("add_diseq_inconsistent", tout << "add_diseq #" << n1->get_owner_id() << " #" << n2->get_owner_id() << " inconsistency, scope_lvl: " << m_scope_lvl << "\n";); //return false; - theory_id t1 = r1->m_th_var_list.get_th_id(); + theory_id t1 = r1->m_th_var_list.get_id(); if (t1 == null_theory_id) return false; return get_theory(t1)->use_diseqs(); } @@ -1078,15 +1078,15 @@ namespace smt { // Propagate disequalities to theories if (r1->m_th_var_list.get_next() == nullptr && r2->m_th_var_list.get_next() == nullptr) { // common case: r2 and r1 have at most one theory var. - theory_id t1 = r1->m_th_var_list.get_th_id(); - theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t1) : r1->m_th_var_list.get_th_var(); - theory_var v2 = m_fparams.m_new_core2th_eq ? get_closest_var(n2, t1) : r2->m_th_var_list.get_th_var(); + theory_id t1 = r1->m_th_var_list.get_id(); + theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t1) : r1->m_th_var_list.get_var(); + theory_var v2 = m_fparams.m_new_core2th_eq ? get_closest_var(n2, t1) : r2->m_th_var_list.get_var(); TRACE("add_diseq", tout << "one theory diseq\n"; tout << v1 << " != " << v2 << "\n"; - tout << "th1: " << t1 << " th2: " << r2->m_th_var_list.get_th_id() << "\n"; + tout << "th1: " << t1 << " th2: " << r2->m_th_var_list.get_id() << "\n"; ); if (t1 != null_theory_id && v1 != null_theory_var && v2 != null_theory_var && - t1 == r2->m_th_var_list.get_th_id()) { + t1 == r2->m_th_var_list.get_id()) { if (get_theory(t1)->use_diseqs()) push_new_th_diseq(t1, v1, v2); } @@ -1094,8 +1094,8 @@ namespace smt { else { theory_var_list * l1 = r1->get_th_var_list(); while (l1) { - theory_id t1 = l1->get_th_id(); - theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t1) : l1->get_th_var(); + theory_id t1 = l1->get_id(); + theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t1) : l1->get_var(); theory * th = get_theory(t1); TRACE("add_diseq", tout << m.get_family_name(t1) << "\n";); if (th->use_diseqs()) { @@ -1586,7 +1586,7 @@ namespace smt { enode * e = get_enode(n); theory_var_list * l = e->get_th_var_list(); while (l) { - theory_id th_id = l->get_th_id(); + theory_id th_id = l->get_id(); theory * th = get_theory(th_id); // I don't want to invoke relevant_eh twice for the same n. if (th != propagated_th) @@ -4509,7 +4509,7 @@ namespace smt { // contains a parent application. theory_var_list * l = n->get_th_var_list(); - theory_id th_id = l->get_th_id(); + theory_id th_id = l->get_id(); for (enode * parent : enode::parents(n)) { app* p = parent->get_owner(); @@ -4548,7 +4548,7 @@ namespace smt { // the theories of (array int int) and (array (array int int) int). // Remark: The inconsistency is not going to be detected if they are // not marked as shared. - return get_theory(th_id)->is_shared(l->get_th_var()); + return get_theory(th_id)->is_shared(l->get_var()); } default: return true; diff --git a/src/smt/smt_context_inv.cpp b/src/smt/smt_context_inv.cpp index 0fc8d5fff..b3c460925 100644 --- a/src/smt/smt_context_inv.cpp +++ b/src/smt/smt_context_inv.cpp @@ -282,7 +282,7 @@ namespace smt { << mk_pp(m.get_sort(lhs->get_owner()), m) << "\n";); theory_var_list * l = lhs->get_th_var_list(); while (l) { - theory_id th_id = l->get_th_id(); + theory_id th_id = l->get_id(); theory * th = get_theory(th_id); TRACE("check_th_diseq_propagation", tout << "checking theory: " << m.get_family_name(th_id) << "\n";); // if the theory doesn't use diseqs, then the diseqs are not propagated. diff --git a/src/smt/smt_enode.cpp b/src/smt/smt_enode.cpp index 4d98f0665..a5d84feae 100644 --- a/src/smt/smt_enode.cpp +++ b/src/smt/smt_enode.cpp @@ -93,13 +93,7 @@ namespace smt { } unsigned enode::get_num_th_vars() const { - unsigned r = 0; - theory_var_list const * l = get_th_var_list(); - while(l) { - r++; - l = l->get_next(); - } - return r; + return m_th_var_list.size(); } /** @@ -109,44 +103,14 @@ namespace smt { with a variable of theory th_id */ theory_var enode::get_th_var(theory_id th_id) const { - if (m_th_var_list.get_th_var() == null_theory_var) - return null_theory_var; - theory_var_list const * l = &m_th_var_list; - while (l) { - if (l->get_th_id() == th_id) { - return l->get_th_var(); - } - l = l->get_next(); - } - return null_theory_var; + return m_th_var_list.find(th_id); } /** \brief Add the entry (v, id) to the list of theory variables. */ void enode::add_th_var(theory_var v, theory_id id, region & r) { -#ifdef Z3DEBUG - unsigned old_size = get_num_th_vars(); -#endif - SASSERT(get_th_var(id) == null_theory_var); - if (m_th_var_list.get_th_var() == null_theory_var) { - m_th_var_list.set_th_var(v); - m_th_var_list.set_th_id(id); - m_th_var_list.set_next(nullptr); - } - else { - theory_var_list * l = &m_th_var_list; - while (l->get_next() != nullptr) { - SASSERT(l->get_th_id() != id); - l = l->get_next(); - } - SASSERT(l); - SASSERT(l->get_next() == 0); - theory_var_list * new_cell = new (r) theory_var_list(id, v); - l->set_next(new_cell); - } - SASSERT(get_num_th_vars() == old_size + 1); - SASSERT(get_th_var(id) == v); + m_th_var_list.add_var(v, id, r); } /** @@ -154,16 +118,7 @@ namespace smt { The enode must have an entry (v', id) */ void enode::replace_th_var(theory_var v, theory_id id) { - SASSERT(get_th_var(id) != null_theory_var); - theory_var_list * l = get_th_var_list(); - while (l) { - if (l->get_th_id() == id) { - l->set_th_var(v); - return; - } - l = l->get_next(); - } - UNREACHABLE(); + m_th_var_list.replace(v, id); } /** @@ -171,33 +126,7 @@ namespace smt { enode is associated with a variable of the given theory. */ void enode::del_th_var(theory_id id) { - SASSERT(get_th_var(id) != null_theory_var); - if (m_th_var_list.get_th_id() == id) { - theory_var_list * next = m_th_var_list.get_next(); - if (next == nullptr) { - // most common case - m_th_var_list.set_th_var(null_theory_var); - m_th_var_list.set_th_id(null_theory_id); - m_th_var_list.set_next(nullptr); - } - else { - m_th_var_list = *next; - } - } - else { - theory_var_list * prev = get_th_var_list(); - theory_var_list * l = prev->get_next(); - while (l) { - SASSERT(prev->get_next() == l); - if (l->get_th_id() == id) { - prev->set_next(l->get_next()); - return; - } - prev = l; - l = l->get_next(); - } - UNREACHABLE(); - } + m_th_var_list.del_var(id); } diff --git a/src/smt/smt_enode.h b/src/smt/smt_enode.h index 4626b996b..4233c51b1 100644 --- a/src/smt/smt_enode.h +++ b/src/smt/smt_enode.h @@ -18,11 +18,12 @@ Revision History: --*/ #pragma once +#include "util/id_var_list.h" +#include "util/approx_set.h" #include "ast/ast.h" #include "smt/smt_types.h" #include "smt/smt_eq_justification.h" -#include "smt/smt_theory_var_list.h" -#include "util/approx_set.h" + namespace smt { @@ -48,6 +49,7 @@ namespace smt { unknown performance penalty for this. */ typedef ptr_vector app2enode_t; // app -> enode + typedef id_var_list theory_var_list; class tmp_enode; @@ -91,7 +93,7 @@ namespace smt { then the congruent f(b) in m_parents will also be relevant. */ enode_vector m_parents; //!< Parent enodes of the equivalence class. - theory_var_list m_th_var_list; //!< List of theories that 'care' about this enode. + id_var_list<> m_th_var_list; //!< List of theories that 'care' about this enode. trans_justification m_trans; //!< A justification for the enode being equal to its root. bool m_proof_is_logged; //!< Indicates that the proof for the enode being equal to its root is in the log. signed char m_lbl_hash; //!< It is different from -1, if enode is used in a pattern @@ -105,7 +107,7 @@ namespace smt { theory_var_list * get_th_var_list() { - return m_th_var_list.get_th_var() == null_theory_var ? nullptr : &m_th_var_list; + return m_th_var_list.get_var() == null_theory_var ? nullptr : &m_th_var_list; } friend class set_merge_tf_trail; @@ -356,11 +358,11 @@ namespace smt { iterator end() { return iterator(this, this); } theory_var_list const * get_th_var_list() const { - return m_th_var_list.get_th_var() == null_theory_var ? nullptr : &m_th_var_list; + return m_th_var_list.get_var() == null_theory_var ? nullptr : &m_th_var_list; } bool has_th_vars() const { - return m_th_var_list.get_th_var() != null_theory_var; + return m_th_var_list.get_var() != null_theory_var; } unsigned get_num_th_vars() const; diff --git a/src/smt/smt_theory_var_list.h b/src/smt/smt_theory_var_list.h deleted file mode 100644 index 68b27624f..000000000 --- a/src/smt/smt_theory_var_list.h +++ /dev/null @@ -1,74 +0,0 @@ -/*++ -Copyright (c) 2006 Microsoft Corporation - -Module Name: - - smt_theory_var_list.h - -Abstract: - - - -Author: - - Leonardo de Moura (leonardo) 2008-02-19. - -Revision History: - ---*/ -#pragma once - -#include "smt/smt_types.h" - -namespace smt { - - class theory_var_list { - int m_th_id:8; - int m_th_var:24; - theory_var_list * m_next; - - public: - theory_var_list(): - m_th_id(null_theory_id), - m_th_var(null_theory_var), - m_next(nullptr) { - } - - theory_var_list(theory_id t, theory_var v, theory_var_list * n = nullptr): - m_th_id(t), - m_th_var(v), - m_next(n) { - } - - theory_id get_th_id() const { - return m_th_id; - } - - theory_var get_th_var() const { - return m_th_var; - } - - theory_var_list * get_next() const { - return m_next; - } - - void set_th_id(theory_id id) { - m_th_id = id; - } - - void set_th_var(theory_var v) { - m_th_var = v; - } - - void set_next(theory_var_list * next) { - m_next = next; - } - }; - - // 32 bit machine - static_assert(sizeof(expr*) != 4 || sizeof(theory_var_list) == sizeof(theory_var_list *) + sizeof(int), "32 bit"); - // 64 bit machine - static_assert(sizeof(expr*) != 8 || sizeof(theory_var_list) == sizeof(theory_var_list *) + sizeof(int) + /* a structure must be aligned */ sizeof(int), "64 bit"); -}; - - diff --git a/src/smt/theory_array_base.cpp b/src/smt/theory_array_base.cpp index 967b1f299..1c5fc4bb4 100644 --- a/src/smt/theory_array_base.cpp +++ b/src/smt/theory_array_base.cpp @@ -498,7 +498,7 @@ namespace smt { TRACE("array_shared", tout << "new shared var: #" << r->get_owner_id() << "\n";); r->set_mark(); to_unmark.push_back(r); - theory_var r_th_var = r->get_th_var(get_id()); + theory_var r_th_var = r->get_var(get_id()); SASSERT(r_th_var != null_theory_var); result.push_back(r_th_var); } diff --git a/src/util/id_var_list.h b/src/util/id_var_list.h new file mode 100644 index 000000000..6a795c616 --- /dev/null +++ b/src/util/id_var_list.h @@ -0,0 +1,175 @@ +/*++ +Copyright (c) 2006 Microsoft Corporation + +Module Name: + + id_var_list.h + +Abstract: + + Association list from theory id -> var + where id in [0..255] and var is 24 bit. + +Author: + + Leonardo de Moura (leonardo) 2008-02-19. + +Revision History: + + Extracted from smt_theory_var_list +--*/ +#pragma once + +#include "util/region.h" + +template +class id_var_list { + int m_id:8; + int m_var:24; + id_var_list * m_next; + +public: + id_var_list(): + m_id(null_id), + m_var(null_var), + m_next(nullptr) { + } + + id_var_list(int t, int v, id_var_list * n = nullptr): + m_id(t), + m_var(v), + m_next(n) { + } + + int get_id() const { + return m_id; + } + + int get_var() const { + return m_var; + } + + bool empty() const { + return get_var() == null_var; + } + + int find(int id) const { + if (empty()) + return null_var; + auto l = this; + do { + if (id == l->get_id()) + return l->get_var(); + l = l->get_next(); + } + while (l); + return null_var; + } + + unsigned size() const { + if (empty()) + return 0; + unsigned r = 0; + auto l = this; + while (l) { + ++r; + l = l->get_next(); + } + return r; + } + + void add_var(int v, int id, region& r) { + SASSERT(find(id) == null_var); + if (get_var() == null_var) { + m_var = v; + m_id = id; + m_next = nullptr; + } + else { + auto l = this; + while (l->get_next()) { + SASSERT(l->get_id() != id); + l = l->get_next(); + } + SASSERT(l); + SASSERT(!l->get_next()); + auto * new_cell = new (r) id_var_list(id, v); + l->set_next(new_cell); + } + SASSERT(find(id) == v); + } + + /** + \brief Replace the entry (v', id) with the entry (v, id). + The enode must have an entry (v', id) + */ + void replace(int v, int id) { + SASSERT(find(id) != null_var); + auto l = this; + while (l) { + if (l->get_id() == id) { + l->set_var(v); + return; + } + l = l->get_next(); + } + UNREACHABLE(); + } + + /** + \brief Delete theory variable. It assumes the + enode is associated with a variable of the given theory. + */ + void del_var(int id) { + SASSERT(find(id) != null_var); + if (get_id() == id) { + if (!m_next) { + // most common case + m_var = null_var; + m_id = null_id; + } + else { + m_var = m_next->get_var(); + m_id = m_next->get_id(); + m_next = m_next->get_next(); + } + } + else { + auto* prev = this; + auto* l = prev->get_next(); + while (l) { + SASSERT(prev->get_next() == l); + if (l->get_id() == id) { + prev->set_next(l->get_next()); + return; + } + prev = l; + l = l->get_next(); + } + UNREACHABLE(); + } + } + + id_var_list * get_next() const { + return m_next; + } + + void set_id(int id) { + m_id = id; + } + + void set_var(int v) { + m_var = v; + } + + void set_next(id_var_list * next) { + m_next = next; + } +}; + +// 32 bit machine +static_assert(sizeof(unsigned*) != 4 || sizeof(id_var_list<>) == sizeof(id_var_list<> *) + sizeof(int), "32 bit"); +// 64 bit machine +static_assert(sizeof(unsigned*) != 8 || sizeof(id_var_list<>) == sizeof(id_var_list<> *) + sizeof(int) + /* a structure must be aligned */ sizeof(int), "64 bit"); + +