From 67f22d8d659e4ef287dcdc8e9123aa2fb3bee932 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 29 Nov 2018 11:32:52 -0800 Subject: [PATCH] improving performance for length constraints Signed-off-by: Nikolaj Bjorner --- src/ast/seq_decl_plugin.cpp | 4 +- src/smt/smt_arith_value.cpp | 85 ++++++++++++--------- src/smt/smt_arith_value.h | 17 ++++- src/smt/theory_jobscheduler.cpp | 9 ++- src/smt/theory_seq.cpp | 130 +++++++++++++++----------------- src/smt/theory_seq.h | 11 ++- src/smt/theory_str.cpp | 10 ++- 7 files changed, 142 insertions(+), 124 deletions(-) diff --git a/src/ast/seq_decl_plugin.cpp b/src/ast/seq_decl_plugin.cpp index c4813f9f7..0bfc94e11 100644 --- a/src/ast/seq_decl_plugin.cpp +++ b/src/ast/seq_decl_plugin.cpp @@ -377,8 +377,8 @@ bool seq_decl_plugin::match(ptr_vector& binding, sort* s, sort* sP) { if (s->get_family_id() == sP->get_family_id() && s->get_decl_kind() == sP->get_decl_kind() && s->get_num_parameters() == sP->get_num_parameters()) { - for (unsigned i = 0, sz = s->get_num_parameters(); i < sz; ++i) { - parameter const& p = s->get_parameter(i); + for (unsigned i = 0, sz = s->get_num_parameters(); i < sz; ++i) { + parameter const& p = s->get_parameter(i); if (p.is_ast() && is_sort(p.get_ast())) { parameter const& p2 = sP->get_parameter(i); if (!match(binding, to_sort(p.get_ast()), to_sort(p2.get_ast()))) return false; diff --git a/src/smt/smt_arith_value.cpp b/src/smt/smt_arith_value.cpp index 443112ecc..50fe6340a 100644 --- a/src/smt/smt_arith_value.cpp +++ b/src/smt/smt_arith_value.cpp @@ -18,30 +18,32 @@ Revision History: --*/ #include "smt/smt_arith_value.h" -#include "smt/theory_lra.h" -#include "smt/theory_arith.h" namespace smt { - arith_value::arith_value(context& ctx): - m_ctx(ctx), m(ctx.get_manager()), a(m) {} + arith_value::arith_value(ast_manager& m): + m_ctx(nullptr), m(m), a(m) {} - bool arith_value::get_lo(expr* e, rational& lo, bool& is_strict) { - if (!m_ctx.e_internalized(e)) return false; + void arith_value::init(context* ctx) { + m_ctx = ctx; family_id afid = a.get_family_id(); + theory* th = m_ctx->get_theory(afid); + m_tha = dynamic_cast(th); + m_thi = dynamic_cast(th); + m_thr = dynamic_cast(th); + } + + bool arith_value::get_lo_equiv(expr* e, rational& lo, bool& is_strict) { + if (!m_ctx->e_internalized(e)) return false; is_strict = false; - enode* next = m_ctx.get_enode(e), *n = next; + enode* next = m_ctx->get_enode(e), *n = next; bool found = false; bool is_strict1; rational lo1; - theory* th = m_ctx.get_theory(afid); - theory_mi_arith* tha = dynamic_cast(th); - theory_i_arith* thi = dynamic_cast(th); - theory_lra* thr = dynamic_cast(th); do { - if ((tha && tha->get_lower(next, lo1, is_strict1)) || - (thi && thi->get_lower(next, lo1, is_strict1)) || - (thr && thr->get_lower(next, lo1, is_strict1))) { + if ((m_tha && m_tha->get_lower(next, lo1, is_strict1)) || + (m_thi && m_thi->get_lower(next, lo1, is_strict1)) || + (m_thr && m_thr->get_lower(next, lo1, is_strict1))) { if (!found || lo1 > lo || (lo == lo1 && is_strict1)) lo = lo1, is_strict = is_strict1; found = true; } @@ -51,21 +53,16 @@ namespace smt { return found; } - bool arith_value::get_up(expr* e, rational& up, bool& is_strict) { - if (!m_ctx.e_internalized(e)) return false; - family_id afid = a.get_family_id(); + bool arith_value::get_up_equiv(expr* e, rational& up, bool& is_strict) { + if (!m_ctx->e_internalized(e)) return false; is_strict = false; - enode* next = m_ctx.get_enode(e), *n = next; + enode* next = m_ctx->get_enode(e), *n = next; bool found = false, is_strict1; rational up1; - theory* th = m_ctx.get_theory(afid); - theory_mi_arith* tha = dynamic_cast(th); - theory_i_arith* thi = dynamic_cast(th); - theory_lra* thr = dynamic_cast(th); do { - if ((tha && tha->get_upper(next, up1, is_strict1)) || - (thi && thi->get_upper(next, up1, is_strict1)) || - (thr && thr->get_upper(next, up1, is_strict1))) { + if ((m_tha && m_tha->get_upper(next, up1, is_strict1)) || + (m_thi && m_thi->get_upper(next, up1, is_strict1)) || + (m_thr && m_thr->get_upper(next, up1, is_strict1))) { if (!found || up1 < up || (up1 == up && is_strict1)) up = up1, is_strict = is_strict1; found = true; } @@ -75,20 +72,36 @@ namespace smt { return found; } + bool arith_value::get_up(expr* e, rational& up, bool& is_strict) const { + if (!m_ctx->e_internalized(e)) return false; + is_strict = false; + enode* n = m_ctx->get_enode(e); + if (m_tha) return m_tha->get_upper(n, up, is_strict); + if (m_thi) return m_thi->get_upper(n, up, is_strict); + if (m_thr) return m_thr->get_upper(n, up, is_strict); + return false; + } + + bool arith_value::get_lo(expr* e, rational& up, bool& is_strict) const { + if (!m_ctx->e_internalized(e)) return false; + is_strict = false; + enode* n = m_ctx->get_enode(e); + if (m_tha) return m_tha->get_lower(n, up, is_strict); + if (m_thi) return m_thi->get_lower(n, up, is_strict); + if (m_thr) return m_thr->get_lower(n, up, is_strict); + return false; + } + + bool arith_value::get_value(expr* e, rational& val) { - if (!m_ctx.e_internalized(e)) return false; + if (!m_ctx->e_internalized(e)) return false; expr_ref _val(m); - enode* next = m_ctx.get_enode(e), *n = next; - family_id afid = a.get_family_id(); - theory* th = m_ctx.get_theory(afid); - theory_mi_arith* tha = dynamic_cast(th); - theory_i_arith* thi = dynamic_cast(th); - theory_lra* thr = dynamic_cast(th); + enode* next = m_ctx->get_enode(e), *n = next; do { e = next->get_owner(); - if (tha && tha->get_value(next, _val) && a.is_numeral(_val, val)) return true; - if (thi && thi->get_value(next, _val) && a.is_numeral(_val, val)) return true; - if (thr && thr->get_value(next, val)) return true; + if (m_tha && m_tha->get_value(next, _val) && a.is_numeral(_val, val)) return true; + if (m_thi && m_thi->get_value(next, _val) && a.is_numeral(_val, val)) return true; + if (m_thr && m_thr->get_value(next, val)) return true; next = next->get_next(); } while (next != n); @@ -97,7 +110,7 @@ namespace smt { final_check_status arith_value::final_check() { family_id afid = a.get_family_id(); - theory * th = m_ctx.get_theory(afid); + theory * th = m_ctx->get_theory(afid); return th->final_check_eh(); } }; diff --git a/src/smt/smt_arith_value.h b/src/smt/smt_arith_value.h index b819b2b9a..ddaa113ea 100644 --- a/src/smt/smt_arith_value.h +++ b/src/smt/smt_arith_value.h @@ -21,18 +21,27 @@ Revision History: #include "ast/arith_decl_plugin.h" #include "smt/smt_context.h" +#include "smt/theory_lra.h" +#include "smt/theory_arith.h" namespace smt { class arith_value { - context& m_ctx; + context* m_ctx; ast_manager& m; arith_util a; + theory_mi_arith* m_tha; + theory_i_arith* m_thi; + theory_lra* m_thr; public: - arith_value(context& ctx); - bool get_lo(expr* e, rational& lo, bool& strict); - bool get_up(expr* e, rational& up, bool& strict); + arith_value(ast_manager& m); + void init(context* ctx); + bool get_lo_equiv(expr* e, rational& lo, bool& strict); + bool get_up_equiv(expr* e, rational& up, bool& strict); bool get_value(expr* e, rational& value); + bool get_lo(expr* e, rational& lo, bool& strict) const; + bool get_up(expr* e, rational& up, bool& strict) const; + bool get_fixed(expr* e, rational& value) const; final_check_status final_check(); }; }; diff --git a/src/smt/theory_jobscheduler.cpp b/src/smt/theory_jobscheduler.cpp index 3b218f56d..152eb715b 100644 --- a/src/smt/theory_jobscheduler.cpp +++ b/src/smt/theory_jobscheduler.cpp @@ -551,7 +551,8 @@ namespace smt { } time_t theory_jobscheduler::get_lo(expr* e) { - arith_value av(get_context()); + arith_value av(m); + av.init(&get_context()); rational val; bool is_strict; if (av.get_lo(e, val, is_strict) && !is_strict && val.is_uint64()) { @@ -561,7 +562,8 @@ namespace smt { } time_t theory_jobscheduler::get_up(expr* e) { - arith_value av(get_context()); + arith_value av(m); + av.init(&get_context()); rational val; bool is_strict; if (av.get_up(e, val, is_strict) && !is_strict && val.is_uint64()) { @@ -571,7 +573,8 @@ namespace smt { } time_t theory_jobscheduler::get_value(expr* e) { - arith_value av(get_context()); + arith_value av(get_manager()); + av.init(&get_context()); rational val; if (av.get_value(e, val) && val.is_uint64()) { return val.get_uint64(); diff --git a/src/smt/theory_seq.cpp b/src/smt/theory_seq.cpp index ccd85c42a..40ce87e13 100644 --- a/src/smt/theory_seq.cpp +++ b/src/smt/theory_seq.cpp @@ -209,10 +209,12 @@ theory_seq::theory_seq(ast_manager& m, theory_seq_params const & params): m_axioms_head(0), m_int_string(m), m_mg(nullptr), + m_length(m), m_rewrite(m), m_seq_rewrite(m), m_util(m), m_autil(m), + m_arith_value(m), m_trail_stack(*this), m_ls(m), m_rs(m), m_lhs(m), m_rhs(m), @@ -245,6 +247,7 @@ theory_seq::~theory_seq() { void theory_seq::init(context* ctx) { theory::init(ctx); + m_arith_value.init(ctx); } final_check_status theory_seq::final_check_eh() { @@ -991,8 +994,9 @@ void theory_seq::find_max_eq_len(expr_ref_vector const& ls, expr_ref_vector cons hi = 1; } else { - lower_bound(ls.get(j), lo); - upper_bound(ls.get(j), hi); + expr_ref len_s = mk_len(ls.get(j)); + lower_bound(len_s, lo); + upper_bound(len_s, hi); } if (!lo.is_minus_one()) { if (lo1.is_minus_one()) @@ -1024,8 +1028,9 @@ void theory_seq::find_max_eq_len(expr_ref_vector const& ls, expr_ref_vector cons hi = 1; } else { - lower_bound(rs.get(j), lo); - upper_bound(rs.get(j), hi); + expr_ref len_s = mk_len(rs.get(j)); + lower_bound(len_s, lo); + upper_bound(len_s, hi); } if (!lo.is_minus_one()) { if (lo2.is_minus_one()) @@ -1729,7 +1734,7 @@ bool theory_seq::propagate_length_coherence(expr* e) { } TRACE("seq", tout << "Unsolved " << mk_pp(e, m); if (!lower_bound2(e, lo)) lo = -rational::one(); - if (!upper_bound(e, hi)) hi = -rational::one(); + if (!upper_bound(mk_len(e), hi)) hi = -rational::one(); tout << " lo: " << lo << " hi: " << hi << "\n"; ); @@ -1747,9 +1752,10 @@ bool theory_seq::propagate_length_coherence(expr* e) { // len(e) >= low => e = tail; literal low(mk_literal(m_autil.mk_ge(mk_len(e), m_autil.mk_numeral(lo, true)))); add_axiom(~low, mk_seq_eq(e, tail)); - if (upper_bound(e, hi)) { + expr_ref len_e = mk_len(e); + if (upper_bound(len_e, hi)) { // len(e) <= hi => len(tail) <= hi - lo - expr_ref high1(m_autil.mk_le(mk_len(e), m_autil.mk_numeral(hi, true)), m); + expr_ref high1(m_autil.mk_le(len_e, m_autil.mk_numeral(hi, true)), m); if (hi == lo) { add_axiom(~mk_literal(high1), mk_seq_eq(seq, emp)); } @@ -1799,13 +1805,17 @@ bool theory_seq::check_length_coherence0(expr* e) { bool theory_seq::check_length_coherence() { #if 1 - for (auto e : m_length) { + for (expr* l : m_length) { + expr* e = nullptr; + VERIFY(m_util.str.is_length(l, e)); if (check_length_coherence0(e)) { return true; } } #endif - for (auto e : m_length) { + for (expr* l : m_length) { + expr* e = nullptr; + VERIFY(m_util.str.is_length(l, e)); if (check_length_coherence(e)) { return true; } @@ -1823,9 +1833,11 @@ bool theory_seq::fixed_length(bool is_zero) { return found; } -bool theory_seq::fixed_length(expr* e, bool is_zero) { +bool theory_seq::fixed_length(expr* len_e, bool is_zero) { rational lo, hi; - if (!(is_var(e) && lower_bound(e, lo) && upper_bound(e, hi) && lo == hi + expr* e = nullptr; + VERIFY(m_util.str.is_length(len_e, e)); + if (!(is_var(e) && lower_bound(len_e, lo) && upper_bound(len_e, hi) && lo == hi && ((is_zero && lo.is_zero()) || (!is_zero && lo.is_unsigned())))) { return false; } @@ -1858,9 +1870,9 @@ bool theory_seq::fixed_length(expr* e, bool is_zero) { seq = mk_concat(elems.size(), elems.c_ptr()); } TRACE("seq", tout << "Fixed: " << mk_pp(e, m) << " " << lo << "\n";); - add_axiom(~mk_eq(mk_len(e), m_autil.mk_numeral(lo, true), false), mk_seq_eq(seq, e)); + add_axiom(~mk_eq(len_e, m_autil.mk_numeral(lo, true), false), mk_seq_eq(seq, e)); if (!ctx.at_base_level()) { - m_trail_stack.push(push_replay(alloc(replay_fixed_length, m, e))); + m_trail_stack.push(push_replay(alloc(replay_fixed_length, m, len_e))); } return true; } @@ -3354,10 +3366,14 @@ bool theory_seq::internalize_term(app* term) { return true; } -void theory_seq::add_length(expr* e) { - SASSERT(!has_length(e)); - m_length.insert(e); - m_trail_stack.push(insert_obj_trail(m_length, e)); +void theory_seq::add_length(expr* l) { + expr* e = nullptr; + VERIFY(m_util.str.is_length(l, e)); + SASSERT(!m_length.contains(l)); + m_length.push_back(l); + m_has_length.insert(e); + m_trail_stack.push(insert_obj_trail(m_has_length, e)); + m_trail_stack.push(push_back_vector(m_length)); } @@ -3372,7 +3388,7 @@ void theory_seq::enforce_length(expr* e) { if (!has_length(o)) { expr_ref len = mk_len(o); enque_axiom(len); - add_length(o); + add_length(len); } n = n->get_next(); } @@ -3609,14 +3625,12 @@ void theory_seq::display(std::ostream & out) const { m_exclude.display(out); } - if (!m_length.empty()) { - for (auto e : m_length) { - rational lo(-1), hi(-1); - lower_bound(e, lo); - upper_bound(e, hi); - if (lo.is_pos() || !hi.is_minus_one()) { - out << mk_pp(e, m) << " [" << lo << ":" << hi << "]\n"; - } + for (auto e : m_length) { + rational lo(-1), hi(-1); + lower_bound(e, lo); + upper_bound(e, hi); + if (lo.is_pos() || !hi.is_minus_one()) { + out << mk_pp(e, m) << " [" << lo << ":" << hi << "]\n"; } } @@ -4215,7 +4229,7 @@ void theory_seq::deque_axiom(expr* n) { if (m_util.str.is_length(n)) { add_length_axiom(n); } - else if (m_util.str.is_empty(n) && !has_length(n) && !m_length.empty()) { + else if (m_util.str.is_empty(n) && !has_length(n) && !m_has_length.empty()) { enforce_length(n); } else if (m_util.str.is_index(n)) { @@ -4648,24 +4662,21 @@ bool theory_seq::get_num_value(expr* e, rational& val) const { return false; } -bool theory_seq::lower_bound(expr* _e, rational& lo) const { - context& ctx = get_context(); - expr_ref e = mk_len(_e); - expr_ref _lo(m); - family_id afid = m_autil.get_family_id(); - do { - theory_mi_arith* tha = get_th_arith(ctx, afid, e); - if (tha && tha->get_lower(ctx.get_enode(e), _lo)) break; - theory_i_arith* thi = get_th_arith(ctx, afid, e); - if (thi && thi->get_lower(ctx.get_enode(e), _lo)) break; - theory_lra* thr = get_th_arith(ctx, afid, e); - if (thr && thr->get_lower(ctx.get_enode(e), _lo)) break; - return false; - } - while (false); - return m_autil.is_numeral(_lo, lo) && lo.is_int(); +bool theory_seq::lower_bound(expr* e, rational& lo) const { + VERIFY(m_autil.is_int(e)); + bool is_strict = true; + return m_arith_value.get_lo(e, lo, is_strict) && !is_strict && lo.is_int(); + } +bool theory_seq::upper_bound(expr* e, rational& hi) const { + VERIFY(m_autil.is_int(e)); + bool is_strict = true; + return m_arith_value.get_up(e, hi, is_strict) && !is_strict && hi.is_int(); +} + + + // The difference with lower_bound function is that since in some cases, // the lower bound is not updated for all the enodes in the same eqc, // we have to traverse the eqc to query for the better lower bound. @@ -4705,23 +4716,6 @@ bool theory_seq::lower_bound2(expr* _e, rational& lo) { return true; } -bool theory_seq::upper_bound(expr* _e, rational& hi) const { - context& ctx = get_context(); - expr_ref e = mk_len(_e); - family_id afid = m_autil.get_family_id(); - expr_ref _hi(m); - do { - theory_mi_arith* tha = get_th_arith(ctx, afid, e); - if (tha && tha->get_upper(ctx.get_enode(e), _hi)) break; - theory_i_arith* thi = get_th_arith(ctx, afid, e); - if (thi && thi->get_upper(ctx.get_enode(e), _hi)) break; - theory_lra* thr = get_th_arith(ctx, afid, e); - if (thr && thr->get_upper(ctx.get_enode(e), _hi)) break; - return false; - } - while (false); - return m_autil.is_numeral(_hi, hi) && hi.is_int(); -} bool theory_seq::get_length(expr* e, rational& val) const { context& ctx = get_context(); @@ -5485,14 +5479,15 @@ void theory_seq::propagate_step(literal lit, expr* step) { TRACE("seq", tout << mk_pp(step, m) << " -> " << mk_pp(t, m) << "\n";); propagate_lit(nullptr, 1, &lit, mk_literal(t)); + expr_ref len_s = mk_len(s); rational lo; rational _idx; VERIFY(m_autil.is_numeral(idx, _idx)); - if (lower_bound(s, lo) && lo.is_unsigned() && lo >= _idx) { + if (lower_bound(len_s, lo) && lo.is_unsigned() && lo >= _idx) { // skip } else { - propagate_lit(nullptr, 1, &lit, ~mk_literal(m_autil.mk_le(mk_len(s), idx))); + propagate_lit(nullptr, 1, &lit, ~mk_literal(m_autil.mk_le(len_s, idx))); } ensure_nth(lit, s, idx); @@ -5554,15 +5549,8 @@ void theory_seq::add_theory_assumptions(expr_ref_vector & assumptions) { TRACE("seq", tout << "add_theory_assumption " << m_util.has_re() << "\n";); if (m_util.has_re()) { expr_ref dlimit(m); - if (m_max_unfolding_lit != null_literal && - m_max_unfolding_depth == 1) { - dlimit = mk_max_unfolding_depth(); - m_max_unfolding_lit = mk_literal(dlimit); - } - else { - dlimit = get_context().bool_var2expr(m_max_unfolding_lit.var()); - } - TRACE("seq", tout << "add_theory_assumption " << dlimit << " " << assumptions << "\n";); + dlimit = mk_max_unfolding_depth(); + m_max_unfolding_lit = mk_literal(dlimit); assumptions.push_back(dlimit); } } diff --git a/src/smt/theory_seq.h b/src/smt/theory_seq.h index 75ba54381..9df28acf5 100644 --- a/src/smt/theory_seq.h +++ b/src/smt/theory_seq.h @@ -19,9 +19,7 @@ Revision History: #ifndef THEORY_SEQ_H_ #define THEORY_SEQ_H_ -#include "smt/smt_theory.h" #include "ast/seq_decl_plugin.h" -#include "smt/theory_seq_empty.h" #include "ast/rewriter/th_rewriter.h" #include "ast/ast_trail.h" #include "util/scoped_vector.h" @@ -30,6 +28,9 @@ Revision History: #include "ast/rewriter/seq_rewriter.h" #include "util/union_find.h" #include "util/obj_ref_hashtable.h" +#include "smt/smt_theory.h" +#include "smt/smt_arith_value.h" +#include "smt/theory_seq_empty.h" namespace smt { @@ -344,13 +345,15 @@ namespace smt { bool m_incomplete; // is the solver (clearly) incomplete for the fragment. expr_ref_vector m_int_string; obj_map m_si_axioms; - obj_hashtable m_length; // is length applied + obj_hashtable m_has_length; // is length applied + expr_ref_vector m_length; // length applications themselves scoped_ptr_vector m_replay; // set of actions to replay model_generator* m_mg; th_rewriter m_rewrite; seq_rewriter m_seq_rewrite; seq_util m_util; arith_util m_autil; + arith_value m_arith_value; th_trail_stack m_trail_stack; stats m_stats; symbol m_prefix, m_suffix, m_accept, m_reject; @@ -557,7 +560,7 @@ namespace smt { bool is_extract_suffix(expr* s, expr* i, expr* l); - bool has_length(expr *e) const { return m_length.contains(e); } + bool has_length(expr *e) const { return m_has_length.contains(e); } void add_length(expr* e); void enforce_length(expr* n); bool enforce_length(expr_ref_vector const& es, vector& len); diff --git a/src/smt/theory_str.cpp b/src/smt/theory_str.cpp index 18267c4fc..0e5393ac7 100644 --- a/src/smt/theory_str.cpp +++ b/src/smt/theory_str.cpp @@ -4874,9 +4874,10 @@ namespace smt { return false; } - arith_value v(get_context()); + arith_value v(get_manager()); + v.init(&get_context()); bool strict; - return v.get_lo(_e, lo, strict); + return v.get_lo_equiv(_e, lo, strict); } bool theory_str::upper_bound(expr* _e, rational& hi) { @@ -4885,9 +4886,10 @@ namespace smt { return false; } - arith_value v(get_context()); + arith_value v(get_manager()); + v.init(&get_context()); bool strict; - return v.get_up(_e, hi, strict); + return v.get_up_equiv(_e, hi, strict); } bool theory_str::get_len_value(expr* e, rational& val) {