diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index 8a9ede35f..d7ac3ff67 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -156,17 +156,17 @@ br_status seq_rewriter::mk_seq_length(expr* a, expr_ref& result) { unsigned len = 0; unsigned j = 0; for (unsigned i = 0; i < m_es.size(); ++i) { - if (m_util.str.is_string(m_es[i], b)) { + if (m_util.str.is_string(m_es[i].get(), b)) { len += b.length(); } - else if (m_util.str.is_unit(m_es[i])) { + else if (m_util.str.is_unit(m_es[i].get())) { len += 1; } - else if (m_util.str.is_empty(m_es[i])) { + else if (m_util.str.is_empty(m_es[i].get())) { // skip } else { - m_es[j] = m_es[i]; + m_es[j] = m_es[i].get(); ++j; } } @@ -177,7 +177,7 @@ br_status seq_rewriter::mk_seq_length(expr* a, expr_ref& result) { if (j != m_es.size() || j != 1) { expr_ref_vector es(m()); for (unsigned i = 0; i < j; ++i) { - es.push_back(m_util.str.mk_length(m_es[i])); + es.push_back(m_util.str.mk_length(m_es[i].get())); } if (len != 0) { es.push_back(m_autil.mk_numeral(rational(len, rational::ui64()), true)); @@ -207,14 +207,14 @@ br_status seq_rewriter::mk_seq_contains(expr* a, expr* b, expr_ref& result) { return BR_DONE; } // check if subsequence of b is in a. - ptr_vector as, bs; + expr_ref_vector as(m()), bs(m()); m_util.str.get_concat(a, as); m_util.str.get_concat(b, bs); bool found = false; for (unsigned i = 0; !found && i < as.size(); ++i) { if (bs.size() > as.size() - i) break; unsigned j = 0; - for (; j < bs.size() && as[j+i] == bs[j]; ++j) {}; + for (; j < bs.size() && as[j+i].get() == bs[j].get(); ++j) {}; found = j == bs.size(); } if (found) { @@ -292,7 +292,7 @@ br_status seq_rewriter::mk_seq_prefix(expr* a, expr* b, expr_ref& result) { expr* b1 = m_util.str.get_leftmost_concat(b); isc1 = m_util.str.is_string(a1, s1); isc2 = m_util.str.is_string(b1, s2); - ptr_vector as, bs; + expr_ref_vector as(m()), bs(m()); if (a1 != b1 && isc1 && isc2) { if (s1.length() <= s2.length()) { @@ -342,7 +342,7 @@ br_status seq_rewriter::mk_seq_prefix(expr* a, expr* b, expr_ref& result) { m_util.str.get_concat(a, as); m_util.str.get_concat(b, bs); unsigned i = 0; - for (; i < as.size() && i < bs.size() && as[i] == bs[i]; ++i) {}; + for (; i < as.size() && i < bs.size() && as[i].get() == bs[i].get(); ++i) {}; if (i == as.size()) { result = m().mk_true(); return BR_DONE; @@ -350,7 +350,7 @@ br_status seq_rewriter::mk_seq_prefix(expr* a, expr* b, expr_ref& result) { if (i == bs.size()) { expr_ref_vector es(m()); for (unsigned j = i; j < as.size(); ++j) { - es.push_back(m().mk_eq(m_util.str.mk_empty(m().get_sort(a)), as[j])); + es.push_back(m().mk_eq(m_util.str.mk_empty(m().get_sort(a)), as[j].get())); } result = mk_and(es); return BR_REWRITE3; @@ -522,7 +522,6 @@ bool seq_rewriter::reduce_eq(expr* l, expr* r, expr_ref_vector& lhs, expr_ref_ve expr* a, *b; zstring s; bool change = false; - expr_ref_vector trail(m()); m_lhs.reset(); m_rhs.reset(); m_util.str.get_concat(l, m_lhs); @@ -545,7 +544,7 @@ bool seq_rewriter::reduce_eq(expr* l, expr* r, expr_ref_vector& lhs, expr_ref_ve expr* r = m_rhs.back(); if (m_util.str.is_unit(r) && m_util.str.is_string(l)) { std::swap(l, r); - std::swap(m_lhs, m_rhs); + m_lhs.swap(m_rhs); } if (l == r) { m_lhs.pop_back(); @@ -575,7 +574,6 @@ bool seq_rewriter::reduce_eq(expr* l, expr* r, expr_ref_vector& lhs, expr_ref_ve else { expr_ref s2(m_util.str.mk_string(s.extract(0, s.length()-2)), m()); m_rhs[m_rhs.size()-1] = s2; - trail.push_back(s2); } } else { @@ -587,10 +585,10 @@ bool seq_rewriter::reduce_eq(expr* l, expr* r, expr_ref_vector& lhs, expr_ref_ve // solve from front unsigned head1 = 0, head2 = 0; while (true) { - while (head1 < m_lhs.size() && m_util.str.is_empty(m_lhs[head1])) { + while (head1 < m_lhs.size() && m_util.str.is_empty(m_lhs[head1].get())) { ++head1; } - while (head2 < m_rhs.size() && m_util.str.is_empty(m_rhs[head2])) { + while (head2 < m_rhs.size() && m_util.str.is_empty(m_rhs[head2].get())) { ++head2; } if (head1 == m_lhs.size() || head2 == m_rhs.size()) { @@ -598,11 +596,11 @@ bool seq_rewriter::reduce_eq(expr* l, expr* r, expr_ref_vector& lhs, expr_ref_ve } SASSERT(head1 < m_lhs.size() && head2 < m_rhs.size()); - expr* l = m_lhs[head1]; - expr* r = m_rhs[head2]; + expr* l = m_lhs[head1].get(); + expr* r = m_rhs[head2].get(); if (m_util.str.is_unit(r) && m_util.str.is_string(l)) { std::swap(l, r); - std::swap(m_lhs, m_rhs); + m_lhs.swap(m_rhs); } if (l == r) { ++head1; @@ -631,7 +629,6 @@ bool seq_rewriter::reduce_eq(expr* l, expr* r, expr_ref_vector& lhs, expr_ref_ve else { expr_ref s2(m_util.str.mk_string(s.extract(1, s.length()-1)), m()); m_rhs[m_rhs.size()-1] = s2; - trail.push_back(s2); } } else { @@ -643,8 +640,8 @@ bool seq_rewriter::reduce_eq(expr* l, expr* r, expr_ref_vector& lhs, expr_ref_ve zstring s1, s2; while (head1 < m_lhs.size() && head2 < m_rhs.size() && - m_util.str.is_string(m_lhs[head1], s1) && - m_util.str.is_string(m_rhs[head2], s2)) { + m_util.str.is_string(m_lhs[head1].get(), s1) && + m_util.str.is_string(m_rhs[head2].get(), s2)) { unsigned l = std::min(s1.length(), s2.length()); for (unsigned i = 0; i < l; ++i) { if (s1[i] != s2[i]) { @@ -656,14 +653,12 @@ bool seq_rewriter::reduce_eq(expr* l, expr* r, expr_ref_vector& lhs, expr_ref_ve } else { m_lhs[head1] = m_util.str.mk_string(s1.extract(l, s1.length()-l)); - trail.push_back(m_lhs[head1]); } if (l == s2.length()) { ++head2; } else { m_rhs[head2] = m_util.str.mk_string(s2.extract(l, s2.length()-l)); - trail.push_back(m_rhs[head2]); } change = true; } @@ -681,11 +676,9 @@ bool seq_rewriter::reduce_eq(expr* l, expr* r, expr_ref_vector& lhs, expr_ref_ve m_rhs.pop_back(); if (l < s1.length()) { m_lhs.push_back(m_util.str.mk_string(s1.extract(0, s1.length()-l))); - trail.push_back(m_lhs.back()); } if (l < s2.length()) { m_rhs.push_back(m_util.str.mk_string(s2.extract(0, s2.length()-l))); - trail.push_back(m_rhs.back()); } change = true; } diff --git a/src/ast/rewriter/seq_rewriter.h b/src/ast/rewriter/seq_rewriter.h index de3634a51..c3e466585 100644 --- a/src/ast/rewriter/seq_rewriter.h +++ b/src/ast/rewriter/seq_rewriter.h @@ -32,7 +32,7 @@ Notes: class seq_rewriter { seq_util m_util; arith_util m_autil; - ptr_vector m_es, m_lhs, m_rhs; + expr_ref_vector m_es, m_lhs, m_rhs; br_status mk_seq_concat(expr* a, expr* b, expr_ref& result); br_status mk_seq_length(expr* a, expr_ref& result); @@ -63,7 +63,7 @@ class seq_rewriter { public: seq_rewriter(ast_manager & m, params_ref const & p = params_ref()): - m_util(m), m_autil(m) { + m_util(m), m_autil(m), m_es(m), m_lhs(m), m_rhs(m) { } ast_manager & m() const { return m_util.get_manager(); } family_id get_fid() const { return m_util.get_family_id(); } diff --git a/src/ast/seq_decl_plugin.cpp b/src/ast/seq_decl_plugin.cpp index cda154050..75e27c081 100644 --- a/src/ast/seq_decl_plugin.cpp +++ b/src/ast/seq_decl_plugin.cpp @@ -626,7 +626,7 @@ bool seq_util::str::is_string(expr const* n, zstring& s) const { } -void seq_util::str::get_concat(expr* e, ptr_vector& es) const { +void seq_util::str::get_concat(expr* e, expr_ref_vector& es) const { expr* e1, *e2; while (is_concat(e, e1, e2)) { get_concat(e1, es); diff --git a/src/ast/seq_decl_plugin.h b/src/ast/seq_decl_plugin.h index 04161b08d..33d4de378 100644 --- a/src/ast/seq_decl_plugin.h +++ b/src/ast/seq_decl_plugin.h @@ -271,7 +271,7 @@ public: MATCH_BINARY(is_in_re); MATCH_UNARY(is_unit); - void get_concat(expr* e, ptr_vector& es) const; + void get_concat(expr* e, expr_ref_vector& es) const; expr* get_leftmost_concat(expr* e) const { expr* e1, *e2; while (is_concat(e, e1, e2)) e = e1; return e; } }; diff --git a/src/smt/smt_setup.cpp b/src/smt/smt_setup.cpp index 8ebfa2d71..8a40f9d7a 100644 --- a/src/smt/smt_setup.cpp +++ b/src/smt/smt_setup.cpp @@ -815,7 +815,7 @@ namespace smt { } void setup::setup_seq() { - m_context.register_plugin(alloc(theory_seq_empty, m_manager)); + m_context.register_plugin(alloc(theory_seq, m_manager)); } void setup::setup_card() { diff --git a/src/smt/theory_seq.cpp b/src/smt/theory_seq.cpp index 641d5e444..753ddabd2 100644 --- a/src/smt/theory_seq.cpp +++ b/src/smt/theory_seq.cpp @@ -145,6 +145,7 @@ theory_seq::theory_seq(ast_manager& m): } theory_seq::~theory_seq() { + m_trail_stack.reset(); } @@ -157,6 +158,9 @@ final_check_status theory_seq::final_check_eh() { if (simplify_and_solve_eqs()) { return FC_CONTINUE; } + if (solve_nqs()) { + return FC_CONTINUE; + } if (ctx.inconsistent()) { return FC_CONTINUE; } @@ -209,7 +213,7 @@ bool theory_seq::check_ineqs() { bool theory_seq::branch_variable() { context& ctx = get_context(); unsigned sz = m_eqs.size(); - ptr_vector ls, rs; + expr_ref_vector ls(m), rs(m); for (unsigned i = 0; i < sz; ++i) { unsigned k = (i + m_branch_variable_head) % sz; eq e = m_eqs[k]; @@ -218,11 +222,11 @@ bool theory_seq::branch_variable() { m_util.str.get_concat(e.m_lhs, ls); m_util.str.get_concat(e.m_rhs, rs); - if (!ls.empty() && find_branch_candidate(ls[0], rs)) { + if (!ls.empty() && find_branch_candidate(ls[0].get(), rs)) { m_branch_variable_head = k; return true; } - if (!rs.empty() && find_branch_candidate(rs[0], ls)) { + if (!rs.empty() && find_branch_candidate(rs[0].get(), ls)) { m_branch_variable_head = k; return true; } @@ -230,7 +234,7 @@ bool theory_seq::branch_variable() { return false; } -bool theory_seq::find_branch_candidate(expr* l, ptr_vector const& rs) { +bool theory_seq::find_branch_candidate(expr* l, expr_ref_vector const& rs) { TRACE("seq", tout << mk_pp(l, m) << " " << (is_var(l)?"var":"not var") << "\n";); @@ -434,8 +438,7 @@ bool theory_seq::simplify_eq(expr* l, expr* r, enode_pair_dependency* deps) { set_conflict(deps); return true; } - if (lhs.size() == 1 && l == lhs[0].get() && - rhs.size() == 1 && r == rhs[0].get()) { + if (unchanged(l, lhs) && unchanged(r, rhs)) { return false; } SASSERT(lhs.size() == rhs.size()); @@ -558,8 +561,115 @@ bool theory_seq::pre_process_eqs(bool simplify_or_solve) { return change; } +bool theory_seq::solve_nqs() { + bool change = false; + context & ctx = get_context(); + for (unsigned i = 0; !ctx.inconsistent() && i < m_nqs.size(); ++i) { + change = solve_ne(i) || change; + if (m_nqs[i].is_solved()) { + m_nqs.erase_and_swap(i); + --i; + } + } + return change; +} + +bool theory_seq::solve_ne(unsigned idx) { + context& ctx = get_context(); + seq_rewriter rw(m); + bool change = false; + ne const& n = m_nqs[idx]; + TRACE("seq", display_disequation(tout, n);); + + SASSERT(!n.is_solved()); + for (unsigned i = 0; i < n.m_lits.size(); ++i) { + switch (ctx.get_assignment(n.m_lits[i])) { + case l_true: + erase_lit(idx, i); + --i; + break; + case l_false: + // mark as solved in + mark_solved(idx); + return false; + case l_undef: + break; + } + } + for (unsigned i = 0; i < n.m_lhs.size(); ++i) { + expr_ref_vector lhs(m), rhs(m); + enode_pair_dependency* deps = 0; + expr* l = n.m_lhs[i]; + expr* r = n.m_rhs[i]; + expr_ref lh = canonize(l, deps); + expr_ref rh = canonize(r, deps); + if (!rw.reduce_eq(lh, rh, lhs, rhs)) { + mark_solved(idx); + return change; + } + else if (unchanged(l, lhs) && unchanged(r, rhs)) { + // continue + } + else if (unchanged(r, lhs) && unchanged(l, rhs)) { + // continue + } + else { + TRACE("seq", tout << lhs.size() << "\n"; + for (unsigned j = 0; j < lhs.size(); ++j) { + tout << mk_pp(lhs[j].get(), m) << " "; + } + tout << "\n"; + tout << mk_pp(l, m) << " != " << mk_pp(r, m) << "\n";); + + for (unsigned j = 0; j < lhs.size(); ++j) { + expr_ref nl(lhs[j].get(), m); + expr_ref nr(rhs[j].get(), m); + if (m_util.is_seq(nl) || m_util.is_re(nl)) { + //std::cout << "push_ne " << nl << " != " << nr << "\n"; + m_trail_stack.push(push_ne(*this, idx, nl, nr)); + } + else { + //std::cout << "push_lit\n"; + literal lit(mk_eq(nl, nr, false)); + m_trail_stack.push(push_lit(*this, idx, ~lit)); + ctx.mark_as_relevant(lit); + } + } + m_trail_stack.push(push_dep(*this, idx, deps)); + erase_index(idx, i); + --i; + } + } + if (n.m_lits.empty() && n.m_lhs.empty()) { + set_conflict(n.m_dep); + return true; + } + return change; +} + +void theory_seq::erase_lit(unsigned idx, unsigned i) { + ne const& n = m_nqs[idx]; + if (n.m_lits.size() < i + 1) { + m_trail_stack.push(set_lit(*this, idx, i, n.m_lits.back())); + } + m_trail_stack.push(pop_lit(*this, idx)); +} + +void theory_seq::mark_solved(unsigned idx) { + m_trail_stack.push(solved_ne(*this, idx)); +} + +void theory_seq::erase_index(unsigned idx, unsigned i) { + ne const& n = m_nqs[idx]; + unsigned sz = n.m_lhs.size(); + if (i + 1 != sz) { + m_trail_stack.push(set_ne(*this, idx, i, n.m_lhs[sz-1], n.m_rhs[sz-1])); + } + m_trail_stack.push(pop_ne(*this, idx)); +} + bool theory_seq::simplify_and_solve_eqs() { - context & ctx = get_context(); + context & ctx = get_context(); bool change = simplify_eqs(); while (!ctx.inconsistent() && solve_basic_eqs()) { simplify_eqs(); @@ -620,6 +730,7 @@ void theory_seq::apply_sort_cnstr(enode* n, sort* s) { void theory_seq::display(std::ostream & out) const { if (m_eqs.size() == 0 && + m_nqs.size() == 0 && m_ineqs.empty() && m_rep.empty() && m_exclude.empty()) { @@ -630,6 +741,10 @@ void theory_seq::display(std::ostream & out) const { out << "Equations:\n"; display_equations(out); } + if (m_nqs.size() > 0) { + out << "Disequations:\n"; + display_disequations(out); + } if (!m_ineqs.empty()) { out << "Negative constraints:\n"; for (unsigned i = 0; i < m_ineqs.size(); ++i) { @@ -654,6 +769,25 @@ void theory_seq::display_equations(std::ostream& out) const { } } +void theory_seq::display_disequations(std::ostream& out) const { + for (unsigned i = 0; i < m_nqs.size(); ++i) { + display_disequation(out, m_nqs[i]); + } +} + +void theory_seq::display_disequation(std::ostream& out, ne const& e) const { + for (unsigned j = 0; j < e.m_lits.size(); ++j) { + out << e.m_lits[j] << " "; + } + if (e.m_lits.size() > 0) { + out << "\n"; + } + for (unsigned j = 0; j < e.m_lhs.size(); ++j) { + out << mk_pp(e.m_lhs[j], m) << " != " << mk_pp(e.m_rhs[j], m) << "\n"; + } + display_deps(out, e.m_dep); +} + void theory_seq::display_deps(std::ostream& out, enode_pair_dependency* dep) const { vector _eqs; const_cast(m_dm).linearize(dep, _eqs); @@ -735,9 +869,6 @@ expr_ref theory_seq::expand(expr* e, enode_pair_dependency*& eqs) { else if (m_util.str.is_empty(e) || m_util.str.is_string(e)) { result = e; } - else if (m.is_eq(e, e1, e2)) { - result = m.mk_eq(expand(e1, deps), expand(e2, deps)); - } else if (m_util.str.is_prefix(e, e1, e2)) { result = m_util.str.mk_prefix(expand(e1, deps), expand(e2, deps)); } @@ -762,6 +893,9 @@ expr_ref theory_seq::expand(expr* e, enode_pair_dependency*& eqs) { else { result = e; } + if (result == e) { + deps = 0; + } expr_dep edr(result, deps); m_rep.add_cache(e, edr); eqs = m_dm.mk_join(eqs, deps); @@ -1164,6 +1298,13 @@ void theory_seq::assign_eq(bool_var v, bool is_true) { } } else { + //if (m_util.str.is_prefix(e, e1, e2)) { + // could add negative prefix axioms: + // len(e1) <= len(e2) => e2 = seq.prefix.left(e2)*seq.prefix.right(e2) + // & len(seq.prefix.left(e2)) = len(e1) + // & seq.prefix.left(e2) != e1 + // or could solve prefix/suffix disunification constraints. + //} m_trail_stack.push(push_back_vector(m_ineqs)); m_ineqs.push_back(e); } @@ -1181,15 +1322,15 @@ void theory_seq::new_eq_eh(theory_var v1, theory_var v2) { } void theory_seq::new_diseq_eh(theory_var v1, theory_var v2) { - expr* e1 = get_enode(v1)->get_owner(); - expr* e2 = get_enode(v2)->get_owner(); - m_trail_stack.push(push_back_vector(m_ineqs)); - m_ineqs.push_back(mk_eq_atom(e1, e2)); + enode* n1 = get_enode(v1); + enode* n2 = get_enode(v2); + expr_ref e1(n1->get_owner(), m); + expr_ref e2(n2->get_owner(), m); + m_nqs.push_back(ne(e1, e2, m_dm.mk_leaf(enode_pair(n1, n2)))); m_exclude.update(e1, e2); } void theory_seq::push_scope_eh() { - TRACE("seq", tout << "push " << m_eqs.size() << "\n";); theory::push_scope_eh(); m_rep.push_scope(); m_exclude.push_scope(); @@ -1197,16 +1338,17 @@ void theory_seq::push_scope_eh() { m_trail_stack.push_scope(); m_trail_stack.push(value_trail(m_axioms_head)); m_eqs.push_scope(); + m_nqs.push_scope(); } void theory_seq::pop_scope_eh(unsigned num_scopes) { - TRACE("seq", tout << "pop " << m_eqs.size() << "\n";); m_trail_stack.pop_scope(num_scopes); theory::pop_scope_eh(num_scopes); m_dm.pop_scope(num_scopes); m_rep.pop_scope(num_scopes); m_exclude.pop_scope(num_scopes); - m_eqs.pop_scopes(num_scopes); + m_eqs.pop_scope(num_scopes); + m_nqs.pop_scope(num_scopes); } void theory_seq::restart_eh() { diff --git a/src/smt/theory_seq.h b/src/smt/theory_seq.h index b2b45e77e..59ae5095a 100644 --- a/src/smt/theory_seq.h +++ b/src/smt/theory_seq.h @@ -104,6 +104,136 @@ namespace smt { eq& operator=(eq const& other) { m_lhs = other.m_lhs; m_rhs = other.m_rhs; m_dep = other.m_dep; return *this; } }; + + // asserted or derived disqequality with dependencies + struct ne { + bool m_solved; + expr_ref_vector m_lhs; + expr_ref_vector m_rhs; + literal_vector m_lits; + enode_pair_dependency* m_dep; + ne(expr_ref& l, expr_ref& r, enode_pair_dependency* d): + m_solved(false), m_lhs(l.get_manager()), m_rhs(r.get_manager()), m_dep(d) { + m_lhs.push_back(l); + m_rhs.push_back(r); + } + ne(ne const& other): + m_solved(other.m_solved), m_lhs(other.m_lhs), m_rhs(other.m_rhs), m_lits(other.m_lits), m_dep(other.m_dep) {} + ne& operator=(ne const& other) { + m_solved = other.m_solved; + m_lhs.reset(); m_lhs.append(other.m_lhs); + m_rhs.reset(); m_rhs.append(other.m_rhs); + m_lits.reset(); m_lits.append(other.m_lits); + m_dep = other.m_dep; + return *this; + } + bool is_solved() const { return m_solved; } + }; + + class pop_lit : public trail { + unsigned m_idx; + literal m_lit; + public: + pop_lit(theory_seq& th, unsigned idx): m_idx(idx), m_lit(th.m_nqs[idx].m_lits.back()) { + th.m_nqs.ref(m_idx).m_lits.pop_back(); + } + virtual void undo(theory_seq & th) { th.m_nqs.ref(m_idx).m_lits.push_back(m_lit); } + }; + class push_lit : public trail { + unsigned m_idx; + public: + push_lit(theory_seq& th, unsigned idx, literal lit): m_idx(idx) { + th.m_nqs.ref(m_idx).m_lits.push_back(lit); + } + virtual void undo(theory_seq & th) { th.m_nqs.ref(m_idx).m_lits.pop_back(); } + }; + class set_lit : public trail { + unsigned m_idx; + unsigned m_i; + literal m_lit; + public: + set_lit(theory_seq& th, unsigned idx, unsigned i, literal lit): + m_idx(idx), m_i(i), m_lit(th.m_nqs[idx].m_lits[i]) { + th.m_nqs.ref(m_idx).m_lits[i] = lit; + } + virtual void undo(theory_seq & th) { th.m_nqs.ref(m_idx).m_lits[m_i] = m_lit; } + }; + void erase_lit(unsigned idx, unsigned i); + + class solved_ne : public trail { + unsigned m_idx; + public: + solved_ne(theory_seq& th, unsigned idx) : m_idx(idx) { th.m_nqs.ref(idx).m_solved = true; } + virtual void undo(theory_seq& th) { th.m_nqs.ref(m_idx).m_solved = false; } + }; + void mark_solved(unsigned idx); + + class push_ne : public trail { + unsigned m_idx; + public: + push_ne(theory_seq& th, unsigned idx, expr* l, expr* r) : m_idx(idx) { + th.m_nqs.ref(m_idx).m_lhs.push_back(l); + th.m_nqs.ref(m_idx).m_rhs.push_back(r); + } + virtual void undo(theory_seq& th) { th.m_nqs.ref(m_idx).m_lhs.pop_back(); th.m_nqs.ref(m_idx).m_rhs.pop_back(); } + }; + + class pop_ne : public trail { + expr_ref m_lhs; + expr_ref m_rhs; + unsigned m_idx; + public: + pop_ne(theory_seq& th, unsigned idx): + m_lhs(th.m_nqs[idx].m_lhs.back(), th.m), + m_rhs(th.m_nqs[idx].m_rhs.back(), th.m), + m_idx(idx) { + th.m_nqs.ref(idx).m_lhs.pop_back(); + th.m_nqs.ref(idx).m_rhs.pop_back(); + } + virtual void undo(theory_seq& th) { + th.m_nqs.ref(m_idx).m_lhs.push_back(m_lhs); + th.m_nqs.ref(m_idx).m_rhs.push_back(m_rhs); + m_lhs.reset(); + m_rhs.reset(); + } + }; + + class set_ne : public trail { + expr_ref m_lhs; + expr_ref m_rhs; + unsigned m_idx; + unsigned m_i; + public: + set_ne(theory_seq& th, unsigned idx, unsigned i, expr* l, expr* r): + m_lhs(th.m_nqs[idx].m_lhs[i], th.m), + m_rhs(th.m_nqs[idx].m_rhs[i], th.m), + m_idx(idx), + m_i(i) { + th.m_nqs.ref(idx).m_lhs[i] = l; + th.m_nqs.ref(idx).m_rhs[i] = r; + } + virtual void undo(theory_seq& th) { + th.m_nqs.ref(m_idx).m_lhs[m_i] = m_lhs; + th.m_nqs.ref(m_idx).m_rhs[m_i] = m_rhs; + m_lhs.reset(); + m_rhs.reset(); + } + }; + + class push_dep : public trail { + enode_pair_dependency* m_dep; + unsigned m_idx; + public: + push_dep(theory_seq& th, unsigned idx, enode_pair_dependency* d): m_dep(th.m_nqs[idx].m_dep), m_idx(idx) { + th.m_nqs.ref(idx).m_dep = d; + } + virtual void undo(theory_seq& th) { + th.m_nqs.ref(m_idx).m_dep = m_dep; + } + }; + + void erase_index(unsigned idx, unsigned i); + struct stats { stats() { reset(); } void reset() { memset(this, 0, sizeof(stats)); } @@ -114,6 +244,7 @@ namespace smt { enode_pair_dependency_manager m_dm; solution_map m_rep; // unification representative. scoped_vector m_eqs; // set of current equations. + scoped_vector m_nqs; // set of current disequalities. seq_factory* m_factory; // value factory expr_ref_vector m_ineqs; // inequalities to check solution against @@ -174,13 +305,17 @@ namespace smt { bool solve_unit_eq(expr* l, expr* r, enode_pair_dependency* dep); bool solve_basic_eqs(); + bool solve_nqs(); + bool solve_ne(unsigned i); + bool unchanged(expr* e, expr_ref_vector& es) const { return es.size() == 1 && es[0] == e; } + // asserting consequences void propagate_lit(enode_pair_dependency* dep, literal lit); void propagate_eq(enode_pair_dependency* dep, enode* n1, enode* n2); void propagate_eq(bool_var v, expr* e1, expr* e2); void set_conflict(enode_pair_dependency* dep); - bool find_branch_candidate(expr* l, ptr_vector const& rs); + bool find_branch_candidate(expr* l, expr_ref_vector const& rs); bool assume_equality(expr* l, expr* r); // variable solving utilities @@ -219,6 +354,8 @@ namespace smt { // diagnostics void display_equations(std::ostream& out) const; + void display_disequations(std::ostream& out) const; + void display_disequation(std::ostream& out, ne const& e) const; void display_deps(std::ostream& out, enode_pair_dependency* deps) const; public: theory_seq(ast_manager& m); diff --git a/src/util/scoped_vector.h b/src/util/scoped_vector.h index 917ecf2ab..a05b19487 100644 --- a/src/util/scoped_vector.h +++ b/src/util/scoped_vector.h @@ -46,7 +46,7 @@ public: m_elems_lim.push_back(m_elems_start); } - void pop_scopes(unsigned num_scopes) { + void pop_scope(unsigned num_scopes) { if (num_scopes == 0) return; unsigned new_size = m_sizes.size() - num_scopes; unsigned src_lim = m_src_lim[new_size]; @@ -72,6 +72,12 @@ public: return m_elems[m_index[idx]]; } + // breaks abstraction, caller must ensure backtracking. + T& ref(unsigned idx) { + SASSERT(idx < m_size); + return m_elems[m_index[idx]]; + } + void set(unsigned idx, T const& t) { SASSERT(idx < m_size); unsigned n = m_index[idx]; @@ -102,6 +108,13 @@ public: SASSERT(invariant()); } + void erase_and_swap(unsigned i) { + if (i + 1 < size()) { + set(i, m_elems[m_index[i]]); + } + pop_back(); + } + unsigned size() const { return m_size; } bool empty() const { return m_size == 0; }