diff --git a/src/smt/nseq_context_solver.h b/src/smt/nseq_context_solver.h index 853b04d51..2c2b07616 100644 --- a/src/smt/nseq_context_solver.h +++ b/src/smt/nseq_context_solver.h @@ -137,14 +137,14 @@ namespace smt { seq::dep_tracker core() override { return m_last_core; } - bool lower_bound(expr* e, rational& lo) const override { + bool lower_bound(expr* e, rational& lo, literal_vector& lits, enode_pair_vector& eqs) const override { bool is_strict = true; - return m_arith_value.get_lo(e, lo, is_strict) && !is_strict && lo.is_int(); + return m_arith_value.get_lo(e, lo, is_strict, lits, eqs) && !is_strict && lo.is_int(); } - bool upper_bound(expr* e, rational& hi) const override { + bool upper_bound(expr* e, rational& hi, literal_vector& lits, enode_pair_vector& eqs) const override { bool is_strict = true; - return m_arith_value.get_up(e, hi, is_strict) && !is_strict && hi.is_int(); + return m_arith_value.get_up(e, hi, is_strict, lits, eqs) && !is_strict && hi.is_int(); } bool current_value(expr* e, rational& v) const override { diff --git a/src/smt/seq/seq_nielsen.cpp b/src/smt/seq/seq_nielsen.cpp index 420ce74e8..4c4430570 100644 --- a/src/smt/seq/seq_nielsen.cpp +++ b/src/smt/seq/seq_nielsen.cpp @@ -45,7 +45,7 @@ NSB review: namespace seq { - void deps_to_lits(dep_tracker deps, svector &eqs, svector &lits, vector& es) { + void deps_to_lits(dep_tracker deps, svector &eqs, svector &lits) { vector vs; dep_manager::s_linearize(deps, vs); for (dep_source const &d : vs) { @@ -54,7 +54,7 @@ namespace seq { else if (std::holds_alternative(d)) lits.push_back(std::get(d)); else - es.push_back(std::get(d)); + UNREACHABLE(); } } @@ -279,6 +279,39 @@ namespace seq { m_str_mem.push_back(mem); } + bool nielsen_node::lower_bound(expr *e, rational &lo, dep_tracker &dep) { + literal_vector lits; + enode_pair_vector eqs; + if (m_graph.a.is_numeral(e, lo)) + return true; + if (!m_graph.m_solver.lower_bound(e, lo, lits, eqs)) + return false; + for (auto lit : lits) + dep = m_graph.dep_mgr().mk_join(dep, m_graph.dep_mgr().mk_leaf(lit)); + for (auto eq : eqs) + dep = m_graph.dep_mgr().mk_join(dep, m_graph.dep_mgr().mk_leaf(eq)); + + expr_ref lo_expr(m_graph.a.mk_int(lo), m_graph.m); + m_graph.add_le_dependency(dep, this, lo_expr, e); + return true; + } + + bool nielsen_node::upper_bound(expr *e, rational &up, dep_tracker &dep) { + literal_vector lits; + enode_pair_vector eqs; + if (m_graph.a.is_numeral(e, up)) + return true; + if (!m_graph.m_solver.upper_bound(e, up, lits, eqs)) + return false; + for (auto lit : lits) + dep = m_graph.dep_mgr().mk_join(dep, m_graph.dep_mgr().mk_leaf(lit)); + for (auto eq : eqs) + dep = m_graph.dep_mgr().mk_join(dep, m_graph.dep_mgr().mk_leaf(eq)); + expr_ref up_expr(m_graph.a.mk_int(up), m_graph.m); + m_graph.add_le_dependency(dep, this, e, up_expr); + return true; + } + void nielsen_node::add_constraint(constraint const &c) { auto& m = graph().get_manager(); if (m.is_true(c.fml)) @@ -395,31 +428,6 @@ namespace seq { } add_constraint(constraint(m.mk_or(cases), dep, m)); } - - bool nielsen_node::lower_bound(expr* e, rational& lo, dep_tracker& dep) { - SASSERT(e); - if (!m_graph.m_solver.lower_bound(e, lo)) - return false; - expr_ref lo_expr(m_graph.a.mk_int(lo), m_graph.m); - m_graph.add_le_dependency(dep, this, lo_expr.get(), e); - return true; - } - - - bool nielsen_node::upper_bound(expr* e, rational& up, dep_tracker& dep) { - SASSERT(e); - rational v; - if (m_graph.a.is_numeral(e, v)) { - up = v; - return true; - } - if (!m_graph.m_solver.upper_bound(e, up)) - return false; - expr_ref up_expr(m_graph.a.mk_int(up), m_graph.m); - m_graph.add_le_dependency(dep, this, e, up_expr.get()); - return true; - } - // ----------------------------------------------- // nielsen_graph // ----------------------------------------------- @@ -549,16 +557,14 @@ namespace seq { SASSERT(m_sat_node == nullptr); } - void nielsen_graph::add_le_dependency(dep_tracker& dep, nielsen_node* n, expr* lhs, expr* rhs) { + void nielsen_graph::add_le_dependency(dep_tracker dep, nielsen_node* n, expr* lhs, expr* rhs) { SASSERT(lhs); SASSERT(rhs); expr_ref le(a.mk_le(lhs, rhs), m); // just assume it to be correct - dep_tracker d = m_dep_mgr.mk_leaf(le); // Just add the constraint - we do not have to recompute it // [also it is on the set of side-conditions if we assert a satisfied node] - n->add_constraint(constraint(le, d, m)); - dep = m_dep_mgr.mk_join(dep, d); + n->add_constraint(constraint(le, dep, m)); } // ----------------------------------------------------------------------- @@ -4259,8 +4265,7 @@ namespace seq { // NSB review: this is one of several methods exposed for testing void nielsen_graph::test_aux_explain_conflict(svector& eqs, - svector& mem_literals, - vector& es) const { + svector& mem_literals) const { SASSERT(m_root); auto deps = collect_conflict_deps(); vector vs; @@ -4270,8 +4275,8 @@ namespace seq { eqs.push_back(std::get(d)); else if (std::holds_alternative(d)) mem_literals.push_back(std::get(d)); - else if (std::holds_alternative(d)) - es.push_back(std::get(d)); + else + UNREACHABLE(); } } @@ -4592,32 +4597,28 @@ namespace seq { dep = nullptr; rational lhs_lo, rhs_up; - bool has_lhs_lo = false, has_rhs_up = false; - dep_tracker lhs_lo_dep = nullptr, rhs_up_dep = nullptr; - if (n->lower_bound(lhs, lhs_lo, lhs_lo_dep)) - has_lhs_lo = true; - if (has_lhs_lo && n->upper_bound(rhs, rhs_up, rhs_up_dep)) - has_rhs_up = true; - if (has_lhs_lo && has_rhs_up) { - if (lhs_lo > rhs_up) - // NB: we only justify if we return true - return false; // definitely infeasible + literal_vector lits; + enode_pair_vector eqs; + if (m_solver.lower_bound(lhs, lhs_lo, lits, eqs) && + m_solver.upper_bound(rhs, rhs_up, lits, eqs) && lhs_lo > rhs_up) + return false; + + // lhs <= lhs_up <= rhs_lo <= rhs + // => lhs <= rhs is entailed + + lits.reset(); + eqs.reset(); + rational rhs_lo, lhs_up; + if (m_solver.upper_bound(lhs, lhs_up, lits, eqs) && + m_solver.lower_bound(rhs, rhs_lo, lits, eqs) && + lhs_up <= rhs_lo) { + for (auto lit : lits) + dep = m_dep_mgr.mk_join(dep, m_dep_mgr.mk_leaf(lit)); + for (enode_pair eq : eqs) + dep = m_dep_mgr.mk_join(dep, m_dep_mgr.mk_leaf(eq)); + return true; } - rational rhs_lo, lhs_up; - bool has_rhs_lo = false, has_lhs_up = false; - dep_tracker rhs_lo_dep = nullptr, lhs_up_dep = nullptr; - if (n->upper_bound(lhs, lhs_up, lhs_up_dep)) - has_lhs_up = true; - if (has_lhs_up && n->lower_bound(rhs, rhs_lo, rhs_lo_dep)) - has_rhs_lo = true; - if (has_lhs_up && has_rhs_lo) { - if (lhs_up <= rhs_lo) { - dep = m_dep_mgr.mk_join(dep, lhs_up_dep); - dep = m_dep_mgr.mk_join(dep, rhs_lo_dep); - return true; // definitely feasible - } - } // fall through - ask the solver [expensive] // TODO: Maybe cache the result? @@ -4631,9 +4632,11 @@ namespace seq { m_solver.push(); assert_to_subsolver(a.mk_ge(lhs, rhs_plus_one)); lbool result = m_solver.check(); + if (result == l_false) + dep = m_solver.core(); m_solver.pop(1); if (result == l_false) { - add_le_dependency(dep, n, lhs, rhs); + n->add_constraint(constraint(a.mk_le(lhs, rhs), dep, m)); return true; } return false; diff --git a/src/smt/seq/seq_nielsen.h b/src/smt/seq/seq_nielsen.h index 287453a07..e48fe9c79 100644 --- a/src/smt/seq/seq_nielsen.h +++ b/src/smt/seq/seq_nielsen.h @@ -314,9 +314,11 @@ namespace seq { // index is the 0-based position in the input eq or mem list respectively. using enode_pair = std::pair; + using literal_vector = svector; + using enode_pair_vector = svector; + using dep_source = std::variant; - using dep_source = std::variant; // Arena-based dependency manager: builds an immutable tree of dep_source @@ -347,8 +349,8 @@ namespace seq { virtual dep_tracker core() { return nullptr; } // Optional bound queries on arithmetic expressions (non-strict integer bounds). // Default implementation reports "unsupported". - virtual bool lower_bound(expr* e, rational& lo) const { return false; } - virtual bool upper_bound(expr* e, rational& hi) const { return false; } + virtual bool lower_bound(expr* e, rational& l, literal_vector& lits, enode_pair_vector& eqs) const { return false; } + virtual bool upper_bound(expr* e, rational& hi, literal_vector& lits, enode_pair_vector& eqs) const { return false; } virtual bool current_value(expr* e, rational& v) const { return false; } virtual void reset() = 0; }; @@ -356,9 +358,8 @@ namespace seq { // partition dep_source leaves from deps into enode pairs, sat literals, // and arithmetic <= dependencies. void deps_to_lits(dep_tracker deps, - svector& eqs, - svector& lits, - vector& es); + enode_pair_vector& eqs, + literal_vector& lits); // string equality constraint: lhs = rhs // mirrors ZIPT's StrEq (both sides are regex-free snode trees) @@ -994,8 +995,7 @@ namespace seq { // (kind::eq) and str_mem indices (kind::mem). // Must be called after solve() returns unsat. void test_aux_explain_conflict(svector &eqs, - svector &mem_literals, - vector& es) const; + svector &mem_literals) const; // accumulated search statistics @@ -1028,7 +1028,7 @@ namespace seq { dep_manager const& dep_mgr() const { return m_dep_mgr; } // Add a dependency leaf for lhs <= rhs and join it to dep. - void add_le_dependency(dep_tracker& dep, nielsen_node* n, expr* lhs, expr* rhs); + void add_le_dependency(dep_tracker dep, nielsen_node* n, expr* lhs, expr* rhs); void assert_to_subsolver(const constraint& c); diff --git a/src/smt/smt_arith_value.cpp b/src/smt/smt_arith_value.cpp index 806598e76..8badd3cb4 100644 --- a/src/smt/smt_arith_value.cpp +++ b/src/smt/smt_arith_value.cpp @@ -101,6 +101,28 @@ namespace smt { return false; } + bool arith_value::get_up(expr *e, rational &up, bool &is_strict, literal_vector& core, enode_pair_vector& eqs) const { + if (!m_ctx->e_internalized(e)) + return false; + is_strict = false; + enode *n = m_ctx->get_enode(e); + if (m_thr) + return m_thr->get_upper(n, up, is_strict, core, eqs); + TRACE(arith_value, tout << "value not found for " << mk_pp(e, m_ctx->get_manager()) << "\n";); + return false; + } + + bool arith_value::get_lo(expr *e, rational &up, bool &is_strict, literal_vector& core, enode_pair_vector& eqs) const { + if (!m_ctx->e_internalized(e)) + return false; + is_strict = false; + enode *n = m_ctx->get_enode(e); + if (m_thr) + return m_thr->get_lower(n, up, is_strict, core, eqs); + TRACE(arith_value, tout << "value not found for " << mk_pp(e, m_ctx->get_manager()) << "\n";); + return false; + } + bool arith_value::get_value(expr* e, rational& val) const { if (!m_ctx->e_internalized(e)) return false; expr_ref _val(m); diff --git a/src/smt/smt_arith_value.h b/src/smt/smt_arith_value.h index 7e351e43d..0b58784ce 100644 --- a/src/smt/smt_arith_value.h +++ b/src/smt/smt_arith_value.h @@ -43,6 +43,8 @@ namespace smt { bool get_value_equiv(expr* e, rational& value) const; bool get_lo(expr* e, rational& lo, bool& strict) const; bool get_up(expr* e, rational& up, bool& strict) const; + bool get_lo(expr *e, rational &lo, bool &strict, literal_vector& core, enode_pair_vector& eqs) const; + bool get_up(expr *e, rational &up, bool &strict, literal_vector& core, enode_pair_vector& eqs) const; bool get_value(expr* e, rational& value) const; expr_ref get_lo(expr* e) const; expr_ref get_up(expr* e) const; diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 901785378..fa35c2372 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -3748,13 +3748,20 @@ public: return include_func_interp(n->get_decl()); } - bool get_lower(enode* n, rational& val, bool& is_strict) { + bool get_lower(enode *n, rational &val, bool &is_strict, literal_vector* lits = nullptr, enode_pair_vector* eqs = nullptr) { + if (a.is_numeral(n->get_expr(), val)) { + is_strict = false; + return true; + } theory_var v = n->get_th_var(get_id()); - if (!is_registered_var(v)) - return false; + if (!is_registered_var(v)) + return false; lpvar vi = get_lpvar(v); - u_dependency* ci; - return lp().has_lower_bound(vi, ci, val, is_strict); + u_dependency *ci = nullptr; + bool r = lp().has_lower_bound(vi, ci, val, is_strict); + if (r && lits && eqs) + set_evidence(ci, *lits, *eqs); + return r; } bool get_lower(enode* n, expr_ref& r) { @@ -3767,13 +3774,21 @@ public: return false; } - bool get_upper(enode* n, rational& val, bool& is_strict) { + bool get_upper(enode *n, rational &val, bool &is_strict, literal_vector *lits = nullptr, + enode_pair_vector *eqs = nullptr) { + if (a.is_numeral(n->get_expr(), val)) { + is_strict = false; + return true; + } theory_var v = n->get_th_var(get_id()); if (!is_registered_var(v)) return false; lpvar vi = get_lpvar(v); u_dependency* dep = nullptr; - return lp().has_upper_bound(vi, dep, val, is_strict); + bool r = lp().has_upper_bound(vi, dep, val, is_strict); + if (r && lits && eqs) + set_evidence(dep, *lits, *eqs); + return r; } void solve_fixed(enode* n, lpvar j, expr_ref& term, expr_ref& guard) { @@ -4483,6 +4498,13 @@ bool theory_lra::get_upper(enode* n, rational& r, bool& is_strict) { return m_imp->get_upper(n, r, is_strict); } +bool theory_lra::get_lower(enode* n, rational& r, bool& is_strict, literal_vector& core, enode_pair_vector& eqs) { + return m_imp->get_lower(n, r, is_strict, &core, &eqs); +} +bool theory_lra::get_upper(enode* n, rational& r, bool& is_strict, literal_vector& core, enode_pair_vector& eqs) { + return m_imp->get_upper(n, r, is_strict, &core, &eqs); +} + void theory_lra::solve_for(vector& sol) { m_imp->solve_for(sol); } diff --git a/src/smt/theory_lra.h b/src/smt/theory_lra.h index fb1a16b15..a6bfcb570 100644 --- a/src/smt/theory_lra.h +++ b/src/smt/theory_lra.h @@ -97,6 +97,8 @@ namespace smt { bool get_upper(enode* n, expr_ref& r); bool get_lower(enode* n, rational& r, bool& is_strict); bool get_upper(enode* n, rational& r, bool& is_strict); + bool get_lower(enode *n, rational &r, bool &is_strict, literal_vector& core, enode_pair_vector& eqs); + bool get_upper(enode *n, rational &r, bool &is_strict, literal_vector &core, enode_pair_vector &eqs); void solve_for(vector& s) override; diff --git a/src/smt/theory_nseq.cpp b/src/smt/theory_nseq.cpp index d058b9fbc..9c395fbdc 100644 --- a/src/smt/theory_nseq.cpp +++ b/src/smt/theory_nseq.cpp @@ -875,7 +875,7 @@ namespace smt { else if (std::holds_alternative(d)) lits.push_back(std::get(d)); else - lits.push_back(mk_literal(std::get(d))); + UNREACHABLE(); } ++m_num_conflicts; set_conflict(eqs, lits); @@ -929,8 +929,7 @@ namespace smt { else if (std::holds_alternative(d)) kernel.assert_expr(ctx.literal2expr(std::get(d))); else { - auto const& e = std::get(d); - kernel.assert_expr(e); + UNREACHABLE(); } } auto res = kernel.check(); @@ -1348,10 +1347,7 @@ namespace smt { // conditional constraints: propagate with justification from dep_tracker enode_pair_vector eqs; literal_vector lits; - vector es; - seq::deps_to_lits(lc.m_dep, eqs, lits, es); - for (auto const& e : es) - lits.push_back(mk_literal(e)); + seq::deps_to_lits(lc.m_dep, eqs, lits); set_propagate(eqs, lits, lit); @@ -1743,12 +1739,10 @@ namespace smt { enode_pair_vector eqs; literal_vector dep_lits; - vector dep_exprs; + for (unsigned idx : mem_indices) - seq::deps_to_lits(mems[idx].m_dep, eqs, dep_lits, dep_exprs); + seq::deps_to_lits(mems[idx].m_dep, eqs, dep_lits); - for (auto const &e : dep_exprs) - dep_lits.push_back(mk_literal(e)); set_propagate(eqs, dep_lits, lit_prop); diff --git a/src/test/nseq_basic.cpp b/src/test/nseq_basic.cpp index bb1f6d995..760e35fb4 100644 --- a/src/test/nseq_basic.cpp +++ b/src/test/nseq_basic.cpp @@ -161,8 +161,7 @@ static void test_nseq_symbol_clash() { // verify conflict explanation returns the equality index smt::enode_pair_vector eqs; sat::literal_vector mem_idx; - vector es; - ng.test_aux_explain_conflict(eqs, mem_idx, es); + ng.test_aux_explain_conflict(eqs, mem_idx); SASSERT(eqs.size() == 1); SASSERT(eqs[0].first == nullptr); SASSERT(mem_idx.empty()); diff --git a/src/test/seq_nielsen.cpp b/src/test/seq_nielsen.cpp index 87be1f51d..7a676ef36 100644 --- a/src/test/seq_nielsen.cpp +++ b/src/test/seq_nielsen.cpp @@ -1494,8 +1494,7 @@ static void test_explain_conflict_single_eq() { // but the conflict should still be detected svector eqs; svector mem_literals; - vector es; - ng.test_aux_explain_conflict(eqs, mem_literals, es); + ng.test_aux_explain_conflict(eqs, mem_literals); // with test-friendly overload (null deps), eqs will be empty // the important check is that the conflict was detected } @@ -1525,8 +1524,7 @@ static void test_explain_conflict_multi_eq() { // the important check is that the conflict was detected svector eqs; svector mem_literals; - vector es; - ng.test_aux_explain_conflict(eqs, mem_literals, es); + ng.test_aux_explain_conflict(eqs, mem_literals); } // test that is_extended is set after solve generates extensions @@ -2118,8 +2116,7 @@ static void test_explain_conflict_mem_only() { // with test-friendly overload (null deps), explain_conflict won't return deps svector eqs; svector mem_literals; - vector es; - ng.test_aux_explain_conflict(eqs, mem_literals, es); + ng.test_aux_explain_conflict(eqs, mem_literals); } // test explain_conflict: mixed eq + mem conflict @@ -2153,8 +2150,7 @@ static void test_explain_conflict_mixed_eq_mem() { // with test-friendly overload (null deps), explain_conflict won't return deps svector eqs; svector mem_literals; - vector es; - ng.test_aux_explain_conflict(eqs, mem_literals, es); + ng.test_aux_explain_conflict(eqs, mem_literals); } // test subsumption pruning during solve: a node whose constraint set