From 87f7a20e14413eabc40bb9b0b7799136b1126daf Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 19 Dec 2024 23:26:42 +0100 Subject: [PATCH] Add (updated and general) solve_for functionality for arithmetic, add congruence_explain to API to retrieve explanation for why two terms are congruent Tweak handling of smt.qi.max_instantations Add API solve_for(vars). It takes a list of variables and returns a triangular solved form for the variables. Currently for arithmetic. The solved form is a list with elements of the form (var, term, guard). Variables solved in the tail of the list do not occur before in the list. For example it can return a solution [(x, z, True), (y, x + z, True)] because first x was solved to be z, then y was solved to be x + z which is the same as 2z. Add congruent_explain that retuns an explanation for congruent terms. Terms congruent in the final state after calling SimpleSolver().check() can be queried for an explanation, i.e., a list of literals that collectively entail the equality under congruence closure. The literals are asserted in the final state of search. Adjust smt_context cancellation for the smt.qi.max_instantiations parameter. It gets checked when qi-queue elements are consumed. Prior it was checked on insertion time, which didn't allow for processing as many instantations as there were in the queue. Moreover, it would not cancel the solver. So it would keep adding instantations to the queue when it was full / depleted the configuration limit. --- src/api/api_solver.cpp | 37 ++++- src/api/python/z3/z3.py | 28 +++- src/api/z3_api.h | 17 +- src/math/lp/lar_solver.cpp | 153 +++++++++++++++--- src/math/lp/lar_solver.h | 39 +++-- src/muz/spacer/spacer_iuc_solver.h | 1 + src/opt/opt_solver.h | 1 + src/sat/sat_solver/inc_sat_solver.cpp | 1 + src/sat/sat_solver/sat_smt_solver.cpp | 1 + src/smt/qi_queue.cpp | 5 + src/smt/qi_queue.h | 2 +- src/smt/smt_context.cpp | 23 ++- src/smt/smt_context.h | 12 +- src/smt/smt_kernel.cpp | 29 +++- src/smt/smt_kernel.h | 4 +- src/smt/smt_quantifier.cpp | 7 +- src/smt/smt_solver.cpp | 3 +- src/smt/smt_theory.h | 8 +- src/smt/theory_lra.cpp | 145 +++++++++++++---- src/smt/theory_lra.h | 2 +- src/smt/theory_sls.cpp | 2 +- src/solver/combined_solver.cpp | 1 + src/solver/simplifier_solver.cpp | 1 + src/solver/slice_solver.cpp | 1 + src/solver/solver.h | 13 +- src/solver/solver_pool.cpp | 3 +- src/solver/tactic2solver.cpp | 1 + .../fd_solver/bounded_int2bv_solver.cpp | 1 + src/tactic/fd_solver/enum2bv_solver.cpp | 1 + src/tactic/fd_solver/pb2bv_solver.cpp | 1 + src/tactic/fd_solver/smtfd_solver.cpp | 2 + 31 files changed, 428 insertions(+), 117 deletions(-) diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 16daed6d4..ac28be572 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -967,20 +967,41 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } - Z3_ast Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast a) { + Z3_ast Z3_API Z3_solver_congruence_explain(Z3_context c, Z3_solver s, Z3_ast a, Z3_ast b) { Z3_TRY; - LOG_Z3_solver_solve_for(c, s, a); + LOG_Z3_solver_congruence_explain(c, s, a, b); RESET_ERROR_CODE(); init_solver(c, s); - ast_manager& m = mk_c(c)->m(); - expr_ref term(m); - if (!to_solver_ref(s)->solve_for(to_expr(a), term)) - term = to_expr(a); - mk_c(c)->save_ast_trail(term.get()); - RETURN_Z3(of_expr(term.get())); + auto exp = to_solver_ref(s)->congruence_explain(to_expr(a), to_expr(b)); + mk_c(c)->save_ast_trail(exp.get()); + RETURN_Z3(of_expr(exp)); Z3_CATCH_RETURN(nullptr); } + void Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast_vector vars, Z3_ast_vector terms, Z3_ast_vector guards) { + Z3_TRY; + LOG_Z3_solver_solve_for(c, s, vars, terms, guards); + RESET_ERROR_CODE(); + init_solver(c, s); + ast_manager& m = mk_c(c)->m(); + auto& _vars = to_ast_vector_ref(vars); + auto& _terms = to_ast_vector_ref(terms); + auto& _guards = to_ast_vector_ref(guards); + vector solutions; + for (auto t : _vars) + solutions.push_back({ to_expr(t), expr_ref(m), expr_ref(m) }); + to_solver_ref(s)->solve_for(solutions); + _vars.reset(); + _terms.reset(); + _guards.reset(); + for (solver::solution const& s : solutions) { + _vars.push_back(s.var); + _terms.push_back(s.term); + _guards.push_back(s.guard); + } + Z3_CATCH; + } + class api_context_obj : public user_propagator::context_obj { api::context* c; public: diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 34e046520..ccdd719a5 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -7336,26 +7336,44 @@ class Solver(Z3PPObject): return self.cube_vs def root(self, t): - t = _py2expr(t, self.ctx) """Retrieve congruence closure root of the term t relative to the current search state The function primarily works for SimpleSolver. Terms and variables that are eliminated during pre-processing are not visible to the congruence closure. """ + t = _py2expr(t, self.ctx) return _to_expr_ref(Z3_solver_congruence_root(self.ctx.ref(), self.solver, t.ast), self.ctx) def next(self, t): - t = _py2expr(t, self.ctx) """Retrieve congruence closure sibling of the term t relative to the current search state The function primarily works for SimpleSolver. Terms and variables that are eliminated during pre-processing are not visible to the congruence closure. """ + t = _py2expr(t, self.ctx) return _to_expr_ref(Z3_solver_congruence_next(self.ctx.ref(), self.solver, t.ast), self.ctx) - def solve_for(self, t): - t = _py2expr(t, self.ctx) + def explain_congruent(self, a, b): + """Explain congruence of a and b relative to the current search state""" + a = _py2expr(a, self.ctx) + b = _py2expr(b, self.ctx) + return _to_expr_ref(Z3_solver_congruence_explain(self.ctx.ref(), self.solver, a.ast, b.ast), self.ctx) + + def solve_for1(self, t): """Retrieve a solution for t relative to linear equations maintained in the current state. The function primarily works for SimpleSolver and when there is a solution using linear arithmetic.""" - return _to_expr_ref(Z3_solver_solve_for(self.ctx.ref(), self.solver, t.ast), self.ctx) + t = _py2expr(t, self.ctx) + return _to_expr_ref(Z3_solver_solve_for1(self.ctx.ref(), self.solver, t.ast), self.ctx) + + def solve_for(self, ts): + """Retrieve a solution for t relative to linear equations maintained in the current state.""" + vars = AstVector(ctx=self.ctx); + terms = AstVector(ctx=self.ctx); + guards = AstVector(ctx=self.ctx); + for t in ts: + t = _py2expr(t, self.ctx) + vars.push(t) + Z3_solver_solve_for(self.ctx.ref(), self.solver, vars.vector, terms.vector, guards.vector) + return [(vars[i], terms[i], guards[i]) for i in range(len(vars))] + def proof(self): """Return a proof for the last `check()`. Proof construction must be enabled.""" diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 791d36b2f..3082e8dc3 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -7076,14 +7076,23 @@ extern "C" { */ Z3_ast Z3_API Z3_solver_congruence_next(Z3_context c, Z3_solver s, Z3_ast a); + /** + \brief retrieve explanation for congruence. + \pre root(a) = root(b) + + def_API('Z3_solver_congruence_explain', AST, (_in(CONTEXT), _in(SOLVER), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_solver_congruence_explain(Z3_context c, Z3_solver s, Z3_ast a, Z3_ast b); /** - \brief retrieve a 'solution' for \c t as defined by equalities in maintained by solvers. - At this point, only linear solution are supported. + \brief retrieve a 'solution' for \c variables as defined by equalities in maintained by solvers. + At this point, only linear solution are supported. + The solution to \c variables may be presented in triangular form, such that + variables used in solutions themselves have solutions. - def_API('Z3_solver_solve_for', AST, (_in(CONTEXT), _in(SOLVER), _in(AST))) + def_API('Z3_solver_solve_for', VOID, (_in(CONTEXT), _in(SOLVER), _in(AST_VECTOR), _in(AST_VECTOR), _in(AST_VECTOR))) */ - Z3_ast Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast t); + void Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast_vector variables, Z3_ast_vector terms, Z3_ast_vector guards); /** \brief register a callback to that retrieves assumed, inferred and deleted clauses during search. diff --git a/src/math/lp/lar_solver.cpp b/src/math/lp/lar_solver.cpp index eb4e16806..08dc81cdf 100644 --- a/src/math/lp/lar_solver.cpp +++ b/src/math/lp/lar_solver.cpp @@ -570,8 +570,6 @@ namespace lp { A_r().pop(k); } - - void lar_solver::set_upper_bound_witness(lpvar j, u_dependency* dep) { m_trail.push(vector_value_trail(m_columns, j)); m_columns[j].upper_bound_witness() = dep; @@ -617,35 +615,148 @@ namespace lp { m_touched_rows.insert(rid); } - bool lar_solver::solve_for(unsigned j, lar_term& t, mpq& coeff) { - t.clear(); - IF_VERBOSE(10, verbose_stream() << "j " << j << " is fixed " << column_is_fixed(j) << " is-base " << is_base(j) << "\n"); - if (column_is_fixed(j)) { - coeff = get_value(j); - return true; + void lar_solver::check_fixed(unsigned j) { + if (column_is_fixed(j)) + return; + + auto explain = [&](constraint_index ci, lp::explanation const& exp) { + u_dependency* d = nullptr; + for (auto const& e : exp) + if (e.ci() != ci) + d = join_deps(d, m_dependencies.mk_leaf(e.ci())); + return d; + }; + + if (!column_has_lower_bound(j) || column_lower_bound(j) != get_value(j)) { + push(); + mpq val = get_value(j); + auto ci = add_var_bound(j, lconstraint_kind::LT, val); + auto r = solve(); + lp::explanation exp; + if (r == lp_status::INFEASIBLE) + get_infeasibility_explanation(exp); + pop(); + if (r == lp_status::INFEASIBLE) { + auto d = explain(ci, exp); + update_column_type_and_bound(j, lconstraint_kind::GE, val, d); + } + solve(); } + + if (!column_has_upper_bound(j) || column_upper_bound(j) != get_value(j)) { + push(); + auto val = get_value(j); + auto ci = add_var_bound(j, lconstraint_kind::GT, val); + auto r = solve(); + lp::explanation exp; + if (r == lp_status::INFEASIBLE) + get_infeasibility_explanation(exp); + pop(); + if (r == lp_status::INFEASIBLE) { + auto d = explain(ci, exp); + update_column_type_and_bound(j, lconstraint_kind::LE, val, d); + } + solve(); + } + } + + void lar_solver::solve_for(unsigned_vector const& js, vector& sols) { + uint_set tabu, fixed_checked; + for (auto j : js) + solve_for(j, tabu, sols); + solve(); // clear updated columns. + auto check = [&](unsigned j) { + if (fixed_checked.contains(j)) + return; + fixed_checked.insert(j); + check_fixed(j); + }; + for (auto const& [j, t] : sols) { + check(j); + for (auto const& v : t) + check(v.j()); + } + } + + void lar_solver::solve_for(unsigned j, uint_set& tabu, vector& sols) { + if (tabu.contains(j)) + return; + tabu.insert(j); + IF_VERBOSE(10, verbose_stream() << "solve for " << j << " base " << is_base(j) << " " << column_is_fixed(j) << "\n"); + if (column_is_fixed(j)) + return; + if (!is_base(j)) { - for (const auto & c : A_r().m_columns[j]) { + for (const auto& c : A_r().m_columns[j]) { lpvar basic_in_row = r_basis()[c.var()]; + if (tabu.contains(basic_in_row)) + continue; pivot(j, basic_in_row); - IF_VERBOSE(10, verbose_stream() << "is base " << is_base(j) << " c.var() = " << c.var() << " basic_in_row = " << basic_in_row << "\n"); break; } } - if (!is_base(j)) - return false; + if (!is_base(j)) + return; + + lar_term t; + auto const& col = m_columns[j]; auto const& r = basic2row(j); for (auto const& c : r) { - if (c.var() == j) - continue; - if (column_is_fixed(c.var())) - coeff -= get_value(c.var()); - else - t.add_monomial(-c.coeff(), c.var()); + if (c.var() != j) + t.add_monomial(-c.coeff(), c.var()); + } + for (auto const& v : t) + solve_for(v.j(), tabu, sols); + lp::impq lo, hi; + bool lo_valid = true, hi_valid = true; + for (auto const& v : t) { + + if (v.coeff().is_pos()) { + if (lo_valid && column_has_lower_bound(v.j())) + lo += column_lower_bound(v.j()) * v.coeff(); + else + lo_valid = false; + if (hi_valid && column_has_upper_bound(v.j())) + hi += column_upper_bound(v.j()) * v.coeff(); + else + hi_valid = false; + } + else { + if (lo_valid && column_has_upper_bound(v.j())) + lo += column_upper_bound(v.j()) * v.coeff(); + else + lo_valid = false; + if (hi_valid && column_has_lower_bound(v.j())) + hi += column_lower_bound(v.j()) * v.coeff(); + else + hi_valid = false; + } + } + + if (lo_valid && (!column_has_lower_bound(j) || lo > column_lower_bound(j).x)) { + u_dependency* dep = nullptr; + for (auto const& v : t) { + if (v.coeff().is_pos()) + dep = join_deps(dep, m_columns[v.j()].lower_bound_witness()); + else + dep = join_deps(dep, m_columns[v.j()].upper_bound_witness()); + } + update_column_type_and_bound(j, lo.y == 0 ? lconstraint_kind::GE : lconstraint_kind::GT, lo.x, dep); } - IF_VERBOSE(10, verbose_stream() << "j = " << j << " t = "; - print_term(t, verbose_stream()) << " coeff = " << coeff << "\n"); - return true; + + if (hi_valid && (!column_has_upper_bound(j) || hi < column_upper_bound(j).x)) { + u_dependency* dep = nullptr; + for (auto const& v : t) { + if (v.coeff().is_pos()) + dep = join_deps(dep, m_columns[v.j()].upper_bound_witness()); + else + dep = join_deps(dep, m_columns[v.j()].lower_bound_witness()); + } + update_column_type_and_bound(j, hi.y == 0 ? lconstraint_kind::LE : lconstraint_kind::LT, hi.x, dep); + } + + if (!column_is_fixed(j)) + sols.push_back({j, t}); } diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index 6d7b4f08f..2c58bbd45 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -351,7 +351,15 @@ public: * return the rest of the row as t comprising of non-fixed variables and coeff as sum of fixed variables. * return false if j has no rows. */ - bool solve_for(unsigned j, lar_term& t, mpq& coeff); + + struct solution { + unsigned j; + lar_term t; + }; + + void solve_for(unsigned_vector const& js, vector& sol); + void check_fixed(unsigned j); + void solve_for(unsigned j, uint_set& tabu, vector& sol); inline unsigned get_base_column_in_row(unsigned row_index) const { return m_mpq_lar_core_solver.m_r_solver.get_base_column_in_row(row_index); @@ -591,21 +599,32 @@ public: return m_columns[j].lower_bound_witness(); } inline bool column_has_term(lpvar j) const { return m_columns[j].term() != nullptr; } - inline std::ostream& print_column_info(unsigned j, std::ostream& out) const { - m_mpq_lar_core_solver.m_r_solver.print_column_info(j, out); - if (column_has_term(j)) { - print_term_as_indices(get_term(j), out) << "\n"; - } else if (column_has_term(j)) { - const lar_term& t = get_term(m_var_register.local_to_external(j)); - print_term_as_indices(t, out) << "\n"; - } + std::ostream& print_column_info(unsigned j, std::ostream& out) const { + m_mpq_lar_core_solver.m_r_solver.print_column_info(j, out); + if (column_has_term(j)) + print_term_as_indices(get_term(j), out) << "\n"; + display_column_explanation(out, j); + return out; + } + + std::ostream& display_column_explanation(std::ostream& out, unsigned j) const { + const column& ul = m_columns[j]; + svector vs1, vs2; + m_dependencies.linearize(ul.lower_bound_witness(), vs1); + m_dependencies.linearize(ul.upper_bound_witness(), vs2); + if (!vs1.empty()) + out << "lo: " << vs1; + if (!vs2.empty()) + out << "hi: " << vs2; + if (!vs1.empty() || !vs2.empty()) + out << "\n"; return out; } void subst_known_terms(lar_term*); - inline std::ostream& print_column_bound_info(unsigned j, std::ostream& out) const { + std::ostream& print_column_bound_info(unsigned j, std::ostream& out) const { return m_mpq_lar_core_solver.m_r_solver.print_column_bound_info(j, out); } diff --git a/src/muz/spacer/spacer_iuc_solver.h b/src/muz/spacer/spacer_iuc_solver.h index e201a1fe1..cdf355f03 100644 --- a/src/muz/spacer/spacer_iuc_solver.h +++ b/src/muz/spacer/spacer_iuc_solver.h @@ -124,6 +124,7 @@ public: expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } expr* congruence_root(expr* e) override { return e; } expr* congruence_next(expr* e) override { return e; } + expr_ref congruence_explain(expr *a, expr *b) override { return expr_ref(m.mk_eq(a, b), m); } void get_levels(ptr_vector const& vars, unsigned_vector& depth) override { m_solver.get_levels(vars, depth); } expr_ref_vector get_trail(unsigned max_level) override { return m_solver.get_trail(max_level); } diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index e614a54fc..bacd8da7d 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -111,6 +111,7 @@ namespace opt { expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } expr* congruence_root(expr* e) override { return e; } expr* congruence_next(expr* e) override { return e; } + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(m.mk_eq(a, b), m); } void set_phase(expr* e) override { m_context.set_phase(e); } phase* get_phase() override { return m_context.get_phase(); } void set_phase(phase* p) override { m_context.set_phase(p); } diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index ba972c617..102f65c2f 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -480,6 +480,7 @@ public: expr* congruence_next(expr* e) override { return e; } expr* congruence_root(expr* e) override { return e; } + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(m.mk_eq(a, b), m); } lbool get_consequences_core(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq) override { diff --git a/src/sat/sat_solver/sat_smt_solver.cpp b/src/sat/sat_solver/sat_smt_solver.cpp index 44fe15f33..19ff978dc 100644 --- a/src/sat/sat_solver/sat_smt_solver.cpp +++ b/src/sat/sat_solver/sat_smt_solver.cpp @@ -426,6 +426,7 @@ public: expr* congruence_next(expr* e) override { return e; } expr* congruence_root(expr* e) override { return e; } + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(m.mk_eq(a, b), m); } lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) override { diff --git a/src/smt/qi_queue.cpp b/src/smt/qi_queue.cpp index 81d41eeee..1de60e69a 100644 --- a/src/smt/qi_queue.cpp +++ b/src/smt/qi_queue.cpp @@ -153,6 +153,11 @@ namespace smt { if (m_context.get_cancel_flag()) { break; } + if (m_stats.m_num_instances > m_params.m_qi_max_instances) { + m_context.set_reason_unknown("maximum number of quantifier instances was reached"); + m_context.set_internal_completed(); + break; + } fingerprint * f = curr.m_qb; quantifier * qa = static_cast(f->get_data()); diff --git a/src/smt/qi_queue.h b/src/smt/qi_queue.h index 961dd73e0..27589eee3 100644 --- a/src/smt/qi_queue.h +++ b/src/smt/qi_queue.h @@ -51,7 +51,7 @@ namespace smt { cost_evaluator m_evaluator; cached_var_subst m_subst; svector m_vals; - double m_eager_cost_threshold; + double m_eager_cost_threshold = 0; struct entry { fingerprint * m_qb; float m_cost; diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 4c1ba3da7..96b66149a 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -104,7 +104,7 @@ namespace smt { */ bool context::get_cancel_flag() { - if (l_true == m_sls_completed && !m.limit().suspended()) { + if (l_true == m_internal_completed && !m.limit().suspended()) { m_last_search_failure = CANCELED; return true; } @@ -3509,11 +3509,11 @@ namespace smt { display_profile(verbose_stream()); if (r == l_true && get_cancel_flag()) r = l_undef; - if (r == l_undef && m_sls_completed == l_true && has_sls_model()) { + if (r == l_undef && m_internal_completed == l_true && has_sls_model()) { m_last_search_failure = OK; r = l_true; } - m_sls_completed = l_false; + m_internal_completed = l_false; if (r == l_true && gparams::get_value("model_validate") == "true") { recfun::util u(m); if (u.get_rec_funs().empty() && m_proto_model) { @@ -3753,7 +3753,7 @@ namespace smt { m_phase_default = false; m_case_split_queue ->init_search_eh(); m_next_progress_sample = 0; - m_sls_completed = l_undef; + m_internal_completed = l_undef; if (m.has_type_vars() && !m_theories.get_plugin(poly_family_id)) register_plugin(alloc(theory_polymorphism, *this)); TRACE("literal_occ", display_literal_num_occs(tout);); @@ -4653,16 +4653,13 @@ namespace smt { if (th == nullptr) return false; return th->get_value(n, value); + } + + void context::solve_for(vector& sol) { + for (auto th : m_theories) + if (th) + th->solve_for(sol); } - - bool context::solve_for(enode * n, expr_ref & term) { - sort * s = n->get_sort(); - family_id fid = s->get_family_id(); - theory * th = get_theory(fid); - if (th == nullptr) - return false; - return th->solve_for(n, term); - } bool context::update_model(bool refinalize) { final_check_status fcs = FC_DONE; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 5daffb959..9ab1db60a 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -130,7 +130,7 @@ namespace smt { class parallel* m_par = nullptr; unsigned m_par_index = 0; bool m_internalizing_assertions = false; - lbool m_sls_completed = l_undef; + lbool m_internal_completed = l_undef; // ----------------------------------- @@ -291,9 +291,9 @@ namespace smt { bool get_cancel_flag(); - void set_sls_completed() { - if (m_sls_completed == l_undef) - m_sls_completed = l_true; + void set_internal_completed() { + if (m_internal_completed == l_undef) + m_internal_completed = l_true; } region & get_region() { @@ -1377,13 +1377,13 @@ namespace smt { // ----------------------------------- // - // Model checking... (must be improved) + // Value extraction and solving // // ----------------------------------- public: bool get_value(enode * n, expr_ref & value); - bool solve_for(enode* n, expr_ref& term); + void solve_for(vector& sol); // ----------------------------------- // diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index b789fc499..be4bc66d6 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -20,6 +20,7 @@ Revision History: #include "smt/smt_context.h" #include "smt/smt_lookahead.h" #include "ast/ast_smt2_pp.h" +#include "ast/ast_util.h" #include "smt/params/smt_params_helper.hpp" namespace smt { @@ -213,11 +214,14 @@ namespace smt { return out; } - bool kernel::solve_for(expr* e, expr_ref& term) { - smt::enode* n = m_imp->m_kernel.find_enode(e); - if (!n) - return false; - return m_imp->m_kernel.solve_for(n, term); + void kernel::solve_for(vector& sol) { + vector solution; + for (auto const& [v, t, g] : sol) + solution.push_back({ v, t, g }); + m_imp->m_kernel.solve_for(solution); + sol.reset(); + for (auto s : solution) + sol.push_back({ s.var, s.term, s.guard }); } expr* kernel::congruence_root(expr * e) { @@ -234,6 +238,21 @@ namespace smt { return n->get_next()->get_expr(); } + expr_ref kernel::congruence_explain(expr* a, expr* b) { + auto& ctx = m_imp->m_kernel; + ast_manager& m = ctx.get_manager(); + smt::enode* n1 = ctx.find_enode(a); + smt::enode* n2 = ctx.find_enode(b); + if (!n1 || !n2 || n1->get_root() != n2->get_root()) + return expr_ref(m.mk_eq(a, b), m); + literal_vector lits; + ctx.get_cr().eq2literals(n1, n2, lits); + expr_ref_vector es(m); + for (auto lit : lits) + es.push_back(ctx.literal2expr(lit)); + return mk_and(es); + } + void kernel::collect_statistics(::statistics & st) const { m_imp->m_kernel.collect_statistics(st); } diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index fb152db8f..92dac74d5 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -246,7 +246,9 @@ namespace smt { expr* congruence_root(expr* e); - bool solve_for(expr* e, expr_ref& term); + expr_ref congruence_explain(expr* a, expr* b); + + void solve_for(vector& s); /** \brief retrieve depth of variables from decision stack. diff --git a/src/smt/smt_quantifier.cpp b/src/smt/smt_quantifier.cpp index b0696049b..77d654bfb 100644 --- a/src/smt/smt_quantifier.cpp +++ b/src/smt/smt_quantifier.cpp @@ -133,7 +133,7 @@ namespace smt { q::quantifier_stat_gen m_qstat_gen; ptr_vector m_quantifiers; scoped_ptr m_plugin; - unsigned m_num_instances; + unsigned m_num_instances = 0; imp(quantifier_manager & wrapper, context & ctx, smt_params & p, quantifier_manager_plugin * plugin): m_wrapper(wrapper), @@ -142,7 +142,6 @@ namespace smt { m_qi_queue(m_wrapper, ctx, p), m_qstat_gen(ctx.get_manager(), ctx.get_region()), m_plugin(plugin) { - m_num_instances = 0; m_qi_queue.setup(); } @@ -297,9 +296,7 @@ namespace smt { vector> & used_enodes) { max_generation = std::max(max_generation, get_generation(q)); - if (m_num_instances > m_params.m_qi_max_instances) { - return false; - } + get_stat(q)->update_max_generation(max_generation); fingerprint * f = m_context.add_fingerprint(q, q->get_id(), num_bindings, bindings, def); if (f) { diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 65feaaa9d..847eb5077 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -337,7 +337,8 @@ namespace { expr* congruence_next(expr* e) override { return m_context.congruence_next(e); } expr* congruence_root(expr* e) override { return m_context.congruence_root(e); } - bool solve_for(expr* e, expr_ref& term) override { return m_context.solve_for(e, term); } + expr_ref congruence_explain(expr* a, expr* b) override { return m_context.congruence_explain(a, b); } + void solve_for(vector& s) override { m_context.solve_for(s); } expr_ref_vector cube(expr_ref_vector& vars, unsigned cutoff) override { diff --git a/src/smt/smt_theory.h b/src/smt/smt_theory.h index 2aed3f449..50f29606e 100644 --- a/src/smt/smt_theory.h +++ b/src/smt/smt_theory.h @@ -29,6 +29,12 @@ namespace smt { class model_generator; class model_value_proc; + struct solution { + expr* var; + expr_ref term; + expr_ref guard; + }; + class theory { protected: theory_id m_id; @@ -605,7 +611,7 @@ namespace smt { virtual char const * get_name() const { return "unknown"; } - virtual bool solve_for(enode* n, expr_ref& r) { return false; } + virtual void solve_for(vector& s) {} // ----------------------------------- // diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 7a73905b7..15eca3d88 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -3169,7 +3169,7 @@ public: typedef std::pair constraint_bound; vector m_lower_terms; vector m_upper_terms; - + void propagate_eqs(lp::lpvar t, lp::constraint_index ci1, lp::lconstraint_kind k, api_bound& b, rational const& value) { u_dependency* ci2 = nullptr; auto pair = [&]() { return lp().dep_manager().mk_join(lp().dep_manager().mk_leaf(ci1), ci2); }; @@ -3392,9 +3392,8 @@ public: } void set_evidence(lp::constraint_index idx, literal_vector& core, svector& eqs) { - if (idx == UINT_MAX) { - return; - } + if (idx == UINT_MAX) + return; switch (m_constraint_sources[idx]) { case inequality_source: { literal lit = m_inequalities[idx]; @@ -3629,33 +3628,116 @@ public: return lp().has_upper_bound(vi, dep, val, is_strict); } - bool solve_for(enode* n, expr_ref& term) { - theory_var v = n->get_th_var(get_id()); - if (!is_registered_var(v)) - return false; - lpvar vi = get_lpvar(v); - lp::lar_term t; - rational coeff; - if (!lp().solve_for(vi, t, coeff)) - return false; - rational lc(1); - if (is_int(v)) { - lc = denominator(coeff); - for (auto const& cv : t) - lc = lcm(denominator(cv.coeff()), lc); - if (lc != 1) { - coeff *= lc; - t *= lc; - } - } - term = mk_term(t, is_int(v)); - if (coeff != 0) - term = a.mk_add(a.mk_numeral(coeff, is_int(v)), term); - if (lc != 1) - term = a.mk_idiv(term, a.mk_numeral(lc, true)); - return true; + void solve_fixed(enode* n, lpvar j, expr_ref& term, expr_ref& guard) { + term = a.mk_numeral(lp().get_value(j), a.is_int(n->get_expr())); + reset_evidence(); + add_explain(j); + guard = extract_explain(); } + void add_explain(unsigned j) { + auto d = lp().get_bound_constraint_witnesses_for_column(j); + set_evidence(d, m_core, m_eqs); + } + + expr_ref extract_explain() { + expr_ref_vector es(m); + for (auto [l, r] : m_eqs) + es.push_back(a.mk_eq(l->get_expr(), r->get_expr())); + for (auto l : m_core) + es.push_back(ctx().literal2expr(l)); + // remove duplicats from es: + std::stable_sort(es.data(), es.data() + es.size()); + unsigned j = 0; + for (unsigned i = 0; i < es.size(); ++i) { + if (i > 0 && es.get(i) == es.get(i - 1)) + continue; + es[j++] = es.get(i); + } + es.shrink(j); + return mk_and(es); + } + + void solve_term(enode* n, lp::lar_term & lt, expr_ref& term, expr_ref& guard) { + bool is_int = a.is_int(n->get_expr()); + bool all_int = is_int; + lp::lar_term t; + rational coeff(0); + expr_ref_vector guards(m); + reset_evidence(); + for (auto const& cv : lt) { + if (lp().column_is_fixed(cv.j())) { + coeff += lp().get_value(cv.j()) * cv.coeff(); + add_explain(cv.j()); + } + else + t.add_monomial(cv.coeff(), cv.j()); + } + guards.push_back(extract_explain()); + rational lc = denominator(coeff); + for (auto const& cv : t) { + lc = lcm(denominator(cv.coeff()), lc); + all_int &= lp().column_is_int(cv.j()); + } + if (lc != 1) + t *= lc, coeff *= lc; + term = mk_term(t, is_int); + if (coeff != 0) + term = a.mk_add(term, a.mk_numeral(coeff, is_int)); + + if (lc == 1) { + guard = mk_and(guards); + return; + } + expr_ref lce(a.mk_numeral(lc, true), m); + if (all_int) + guards.push_back(m.mk_eq(a.mk_mod(term, lce), a.mk_int(0))); + else if (is_int) + guards.push_back(a.mk_is_int(a.mk_div(term, lce))); + term = a.mk_idiv(term, lce); + guard = mk_and(guards); + } + + void solve_for(vector& solutions) { + unsigned_vector vars; + unsigned j = 0; + for (auto [e, t, g] : solutions) { + auto n = get_enode(e); + if (!n) { + solutions[j++] = { e, t, g }; + continue; + } + + theory_var v = n->get_th_var(get_id()); + if (!is_registered_var(v)) + solutions[j++] = { e, t, g }; + else + vars.push_back(get_lpvar(v)); + } + solutions.shrink(j); + + expr_ref term(m), guard(m); + vector sols; + lp().solve_for(vars, sols); + uint_set seen; + for (auto& s : sols) { + auto n = get_enode(lp().local_to_external(s.j)); + if (lp().column_is_fixed(s.j)) + solve_fixed(n, s.j, term, guard); + else + solve_term(n, s.t, term, guard); + solutions.push_back({ n->get_expr(), term, guard}); + seen.insert(s.j); + } + for (auto j : vars) { + if (seen.contains(j) || !lp().column_is_fixed(j)) + continue; + auto n = get_enode(lp().local_to_external(j)); + solve_fixed(n, j, term, guard); + solutions.push_back({ n->get_expr(), term, guard }); + } + } + bool get_upper(enode* n, expr_ref& r) { bool is_strict; rational val; @@ -4166,8 +4248,9 @@ bool theory_lra::get_lower(enode* n, rational& r, bool& is_strict) { bool theory_lra::get_upper(enode* n, rational& r, bool& is_strict) { return m_imp->get_upper(n, r, is_strict); } -bool theory_lra::solve_for(enode* n, expr_ref& r) { - return m_imp->solve_for(n, r); + +void theory_lra::solve_for(vector& sol) { + m_imp->solve_for(sol); } void theory_lra::display(std::ostream & out) const { diff --git a/src/smt/theory_lra.h b/src/smt/theory_lra.h index 6267ae68d..1624bab0a 100644 --- a/src/smt/theory_lra.h +++ b/src/smt/theory_lra.h @@ -93,7 +93,7 @@ namespace smt { bool get_upper(enode* n, expr_ref& r); bool get_lower(enode* n, rational& r, bool& is_strict); bool get_upper(enode* n, rational& r, bool& is_strict); - bool solve_for(enode* n, expr_ref& r) override; + void solve_for(vector& s) override; void display(std::ostream & out) const override; diff --git a/src/smt/theory_sls.cpp b/src/smt/theory_sls.cpp index 9c1bd2bef..4ab93435f 100644 --- a/src/smt/theory_sls.cpp +++ b/src/smt/theory_sls.cpp @@ -58,7 +58,7 @@ namespace smt { } void theory_sls::set_finished() { - ctx.set_sls_completed(); + ctx.set_internal_completed(); } bool theory_sls::get_smt_value(expr* v, expr_ref& value) { diff --git a/src/solver/combined_solver.cpp b/src/solver/combined_solver.cpp index e2a2f2011..97ecc8e1c 100644 --- a/src/solver/combined_solver.cpp +++ b/src/solver/combined_solver.cpp @@ -277,6 +277,7 @@ public: expr* congruence_next(expr* e) override { switch_inc_mode(); return m_solver2->congruence_next(e); } expr* congruence_root(expr* e) override { switch_inc_mode(); return m_solver2->congruence_root(e); } + expr_ref congruence_explain(expr* a, expr* b) override { switch_inc_mode(); return m_solver2->congruence_explain(a, b); } expr * get_assumption(unsigned idx) const override { diff --git a/src/solver/simplifier_solver.cpp b/src/solver/simplifier_solver.cpp index b114f364f..23a802c49 100644 --- a/src/solver/simplifier_solver.cpp +++ b/src/solver/simplifier_solver.cpp @@ -365,6 +365,7 @@ public: expr* congruence_root(expr* e) override { return s->congruence_root(e); } expr* congruence_next(expr* e) override { return s->congruence_next(e); } + expr_ref congruence_explain(expr* a, expr* b) override { return s->congruence_explain(a, b); } std::ostream& display(std::ostream& out, unsigned n, expr* const* assumptions) const override { return s->display(out, n, assumptions); } diff --git a/src/solver/slice_solver.cpp b/src/solver/slice_solver.cpp index 28815b4d6..8310c47f4 100644 --- a/src/solver/slice_solver.cpp +++ b/src/solver/slice_solver.cpp @@ -393,6 +393,7 @@ public: expr* congruence_root(expr* e) override { return s->congruence_root(e); } expr* congruence_next(expr* e) override { return s->congruence_next(e); } + expr_ref congruence_explain(expr* a, expr* b) override { return s->congruence_explain(a, b); } std::ostream& display(std::ostream& out, unsigned n, expr* const* assumptions) const override { return s->display(out, n, assumptions); } diff --git a/src/solver/solver.h b/src/solver/solver.h index 99acc4a1e..d5b3ee4f8 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -27,6 +27,7 @@ class solver; class model_converter; + class solver_factory { public: virtual ~solver_factory() = default; @@ -249,9 +250,17 @@ public: virtual expr* congruence_next(expr* e) = 0; /** - \brief try to solve for term e (when e is arithmetical). + \brief expose explanation for congruence. */ - virtual bool solve_for(expr* e, expr_ref& term) { return false; } + virtual expr_ref congruence_explain(expr* a, expr* b) = 0; + + struct solution { + expr* var; + expr_ref term; + expr_ref guard; + }; + + virtual void solve_for(vector& s) {} /** \brief Display the content of this solver. diff --git a/src/solver/solver_pool.cpp b/src/solver/solver_pool.cpp index 411634162..5fbe8fa11 100644 --- a/src/solver/solver_pool.cpp +++ b/src/solver/solver_pool.cpp @@ -81,7 +81,7 @@ public: } void push_params() override {m_base->push_params();} void pop_params() override {m_base->pop_params();} - + void collect_param_descrs(param_descrs & r) override { m_base->collect_param_descrs(r); } void collect_statistics(statistics & st) const override { m_base->collect_statistics(st); } unsigned get_num_assertions() const override { return m_base->get_num_assertions(); } @@ -264,6 +264,7 @@ public: expr* congruence_next(expr* e) override { return e; } expr* congruence_root(expr* e) override { return e; } + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(m.mk_eq(a, b), m); } ast_manager& get_manager() const override { return m_base->get_manager(); } diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index dfdae5e5a..bf849d830 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -145,6 +145,7 @@ public: expr* congruence_next(expr* e) override { return e; } expr* congruence_root(expr* e) override { return e; } + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(get_manager().mk_eq(a, b), get_manager()); } model_converter_ref get_model_converter() const override { return m_mc; } diff --git a/src/tactic/fd_solver/bounded_int2bv_solver.cpp b/src/tactic/fd_solver/bounded_int2bv_solver.cpp index 45b444fe9..6f305372d 100644 --- a/src/tactic/fd_solver/bounded_int2bv_solver.cpp +++ b/src/tactic/fd_solver/bounded_int2bv_solver.cpp @@ -212,6 +212,7 @@ public: ast_manager& get_manager() const override { return m; } expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); } expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); } + expr_ref congruence_explain(expr* a, expr* b) override { return m_solver->congruence_explain(a, b); } expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override { flush_assertions(); return m_solver->cube(vars, backtrack_level); } lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) override { return m_solver->find_mutexes(vars, mutexes); } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { diff --git a/src/tactic/fd_solver/enum2bv_solver.cpp b/src/tactic/fd_solver/enum2bv_solver.cpp index 2690e7033..061f67010 100644 --- a/src/tactic/fd_solver/enum2bv_solver.cpp +++ b/src/tactic/fd_solver/enum2bv_solver.cpp @@ -133,6 +133,7 @@ public: } expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); } expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); } + expr_ref congruence_explain(expr* a, expr* b) override { return m_solver->congruence_explain(a, b); } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { diff --git a/src/tactic/fd_solver/pb2bv_solver.cpp b/src/tactic/fd_solver/pb2bv_solver.cpp index 19f2630f2..7bc0f27dd 100644 --- a/src/tactic/fd_solver/pb2bv_solver.cpp +++ b/src/tactic/fd_solver/pb2bv_solver.cpp @@ -124,6 +124,7 @@ public: expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override { flush_assertions(); return m_solver->cube(vars, backtrack_level); } expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); } expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); } + expr_ref congruence_explain(expr* a, expr* b) override { return m_solver->congruence_explain(a, b); } lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) override { return m_solver->find_mutexes(vars, mutexes); } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { flush_assertions(); diff --git a/src/tactic/fd_solver/smtfd_solver.cpp b/src/tactic/fd_solver/smtfd_solver.cpp index 4d0912fdc..46fb4054e 100644 --- a/src/tactic/fd_solver/smtfd_solver.cpp +++ b/src/tactic/fd_solver/smtfd_solver.cpp @@ -2070,6 +2070,8 @@ namespace smtfd { expr* congruence_root(expr* e) override { return e; } expr* congruence_next(expr* e) override { return e; } + + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(m.mk_eq(a, b), m); } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { return l_undef;