diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 16daed6d4..ac28be572 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -967,20 +967,41 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } - Z3_ast Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast a) { + Z3_ast Z3_API Z3_solver_congruence_explain(Z3_context c, Z3_solver s, Z3_ast a, Z3_ast b) { Z3_TRY; - LOG_Z3_solver_solve_for(c, s, a); + LOG_Z3_solver_congruence_explain(c, s, a, b); RESET_ERROR_CODE(); init_solver(c, s); - ast_manager& m = mk_c(c)->m(); - expr_ref term(m); - if (!to_solver_ref(s)->solve_for(to_expr(a), term)) - term = to_expr(a); - mk_c(c)->save_ast_trail(term.get()); - RETURN_Z3(of_expr(term.get())); + auto exp = to_solver_ref(s)->congruence_explain(to_expr(a), to_expr(b)); + mk_c(c)->save_ast_trail(exp.get()); + RETURN_Z3(of_expr(exp)); Z3_CATCH_RETURN(nullptr); } + void Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast_vector vars, Z3_ast_vector terms, Z3_ast_vector guards) { + Z3_TRY; + LOG_Z3_solver_solve_for(c, s, vars, terms, guards); + RESET_ERROR_CODE(); + init_solver(c, s); + ast_manager& m = mk_c(c)->m(); + auto& _vars = to_ast_vector_ref(vars); + auto& _terms = to_ast_vector_ref(terms); + auto& _guards = to_ast_vector_ref(guards); + vector solutions; + for (auto t : _vars) + solutions.push_back({ to_expr(t), expr_ref(m), expr_ref(m) }); + to_solver_ref(s)->solve_for(solutions); + _vars.reset(); + _terms.reset(); + _guards.reset(); + for (solver::solution const& s : solutions) { + _vars.push_back(s.var); + _terms.push_back(s.term); + _guards.push_back(s.guard); + } + Z3_CATCH; + } + class api_context_obj : public user_propagator::context_obj { api::context* c; public: diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 34e046520..ccdd719a5 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -7336,26 +7336,44 @@ class Solver(Z3PPObject): return self.cube_vs def root(self, t): - t = _py2expr(t, self.ctx) """Retrieve congruence closure root of the term t relative to the current search state The function primarily works for SimpleSolver. Terms and variables that are eliminated during pre-processing are not visible to the congruence closure. """ + t = _py2expr(t, self.ctx) return _to_expr_ref(Z3_solver_congruence_root(self.ctx.ref(), self.solver, t.ast), self.ctx) def next(self, t): - t = _py2expr(t, self.ctx) """Retrieve congruence closure sibling of the term t relative to the current search state The function primarily works for SimpleSolver. Terms and variables that are eliminated during pre-processing are not visible to the congruence closure. """ + t = _py2expr(t, self.ctx) return _to_expr_ref(Z3_solver_congruence_next(self.ctx.ref(), self.solver, t.ast), self.ctx) - def solve_for(self, t): - t = _py2expr(t, self.ctx) + def explain_congruent(self, a, b): + """Explain congruence of a and b relative to the current search state""" + a = _py2expr(a, self.ctx) + b = _py2expr(b, self.ctx) + return _to_expr_ref(Z3_solver_congruence_explain(self.ctx.ref(), self.solver, a.ast, b.ast), self.ctx) + + def solve_for1(self, t): """Retrieve a solution for t relative to linear equations maintained in the current state. The function primarily works for SimpleSolver and when there is a solution using linear arithmetic.""" - return _to_expr_ref(Z3_solver_solve_for(self.ctx.ref(), self.solver, t.ast), self.ctx) + t = _py2expr(t, self.ctx) + return _to_expr_ref(Z3_solver_solve_for1(self.ctx.ref(), self.solver, t.ast), self.ctx) + + def solve_for(self, ts): + """Retrieve a solution for t relative to linear equations maintained in the current state.""" + vars = AstVector(ctx=self.ctx); + terms = AstVector(ctx=self.ctx); + guards = AstVector(ctx=self.ctx); + for t in ts: + t = _py2expr(t, self.ctx) + vars.push(t) + Z3_solver_solve_for(self.ctx.ref(), self.solver, vars.vector, terms.vector, guards.vector) + return [(vars[i], terms[i], guards[i]) for i in range(len(vars))] + def proof(self): """Return a proof for the last `check()`. Proof construction must be enabled.""" diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 791d36b2f..3082e8dc3 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -7076,14 +7076,23 @@ extern "C" { */ Z3_ast Z3_API Z3_solver_congruence_next(Z3_context c, Z3_solver s, Z3_ast a); + /** + \brief retrieve explanation for congruence. + \pre root(a) = root(b) + + def_API('Z3_solver_congruence_explain', AST, (_in(CONTEXT), _in(SOLVER), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_solver_congruence_explain(Z3_context c, Z3_solver s, Z3_ast a, Z3_ast b); /** - \brief retrieve a 'solution' for \c t as defined by equalities in maintained by solvers. - At this point, only linear solution are supported. + \brief retrieve a 'solution' for \c variables as defined by equalities in maintained by solvers. + At this point, only linear solution are supported. + The solution to \c variables may be presented in triangular form, such that + variables used in solutions themselves have solutions. - def_API('Z3_solver_solve_for', AST, (_in(CONTEXT), _in(SOLVER), _in(AST))) + def_API('Z3_solver_solve_for', VOID, (_in(CONTEXT), _in(SOLVER), _in(AST_VECTOR), _in(AST_VECTOR), _in(AST_VECTOR))) */ - Z3_ast Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast t); + void Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast_vector variables, Z3_ast_vector terms, Z3_ast_vector guards); /** \brief register a callback to that retrieves assumed, inferred and deleted clauses during search. diff --git a/src/math/lp/lar_solver.cpp b/src/math/lp/lar_solver.cpp index eb4e16806..08dc81cdf 100644 --- a/src/math/lp/lar_solver.cpp +++ b/src/math/lp/lar_solver.cpp @@ -570,8 +570,6 @@ namespace lp { A_r().pop(k); } - - void lar_solver::set_upper_bound_witness(lpvar j, u_dependency* dep) { m_trail.push(vector_value_trail(m_columns, j)); m_columns[j].upper_bound_witness() = dep; @@ -617,35 +615,148 @@ namespace lp { m_touched_rows.insert(rid); } - bool lar_solver::solve_for(unsigned j, lar_term& t, mpq& coeff) { - t.clear(); - IF_VERBOSE(10, verbose_stream() << "j " << j << " is fixed " << column_is_fixed(j) << " is-base " << is_base(j) << "\n"); - if (column_is_fixed(j)) { - coeff = get_value(j); - return true; + void lar_solver::check_fixed(unsigned j) { + if (column_is_fixed(j)) + return; + + auto explain = [&](constraint_index ci, lp::explanation const& exp) { + u_dependency* d = nullptr; + for (auto const& e : exp) + if (e.ci() != ci) + d = join_deps(d, m_dependencies.mk_leaf(e.ci())); + return d; + }; + + if (!column_has_lower_bound(j) || column_lower_bound(j) != get_value(j)) { + push(); + mpq val = get_value(j); + auto ci = add_var_bound(j, lconstraint_kind::LT, val); + auto r = solve(); + lp::explanation exp; + if (r == lp_status::INFEASIBLE) + get_infeasibility_explanation(exp); + pop(); + if (r == lp_status::INFEASIBLE) { + auto d = explain(ci, exp); + update_column_type_and_bound(j, lconstraint_kind::GE, val, d); + } + solve(); } + + if (!column_has_upper_bound(j) || column_upper_bound(j) != get_value(j)) { + push(); + auto val = get_value(j); + auto ci = add_var_bound(j, lconstraint_kind::GT, val); + auto r = solve(); + lp::explanation exp; + if (r == lp_status::INFEASIBLE) + get_infeasibility_explanation(exp); + pop(); + if (r == lp_status::INFEASIBLE) { + auto d = explain(ci, exp); + update_column_type_and_bound(j, lconstraint_kind::LE, val, d); + } + solve(); + } + } + + void lar_solver::solve_for(unsigned_vector const& js, vector& sols) { + uint_set tabu, fixed_checked; + for (auto j : js) + solve_for(j, tabu, sols); + solve(); // clear updated columns. + auto check = [&](unsigned j) { + if (fixed_checked.contains(j)) + return; + fixed_checked.insert(j); + check_fixed(j); + }; + for (auto const& [j, t] : sols) { + check(j); + for (auto const& v : t) + check(v.j()); + } + } + + void lar_solver::solve_for(unsigned j, uint_set& tabu, vector& sols) { + if (tabu.contains(j)) + return; + tabu.insert(j); + IF_VERBOSE(10, verbose_stream() << "solve for " << j << " base " << is_base(j) << " " << column_is_fixed(j) << "\n"); + if (column_is_fixed(j)) + return; + if (!is_base(j)) { - for (const auto & c : A_r().m_columns[j]) { + for (const auto& c : A_r().m_columns[j]) { lpvar basic_in_row = r_basis()[c.var()]; + if (tabu.contains(basic_in_row)) + continue; pivot(j, basic_in_row); - IF_VERBOSE(10, verbose_stream() << "is base " << is_base(j) << " c.var() = " << c.var() << " basic_in_row = " << basic_in_row << "\n"); break; } } - if (!is_base(j)) - return false; + if (!is_base(j)) + return; + + lar_term t; + auto const& col = m_columns[j]; auto const& r = basic2row(j); for (auto const& c : r) { - if (c.var() == j) - continue; - if (column_is_fixed(c.var())) - coeff -= get_value(c.var()); - else - t.add_monomial(-c.coeff(), c.var()); + if (c.var() != j) + t.add_monomial(-c.coeff(), c.var()); + } + for (auto const& v : t) + solve_for(v.j(), tabu, sols); + lp::impq lo, hi; + bool lo_valid = true, hi_valid = true; + for (auto const& v : t) { + + if (v.coeff().is_pos()) { + if (lo_valid && column_has_lower_bound(v.j())) + lo += column_lower_bound(v.j()) * v.coeff(); + else + lo_valid = false; + if (hi_valid && column_has_upper_bound(v.j())) + hi += column_upper_bound(v.j()) * v.coeff(); + else + hi_valid = false; + } + else { + if (lo_valid && column_has_upper_bound(v.j())) + lo += column_upper_bound(v.j()) * v.coeff(); + else + lo_valid = false; + if (hi_valid && column_has_lower_bound(v.j())) + hi += column_lower_bound(v.j()) * v.coeff(); + else + hi_valid = false; + } + } + + if (lo_valid && (!column_has_lower_bound(j) || lo > column_lower_bound(j).x)) { + u_dependency* dep = nullptr; + for (auto const& v : t) { + if (v.coeff().is_pos()) + dep = join_deps(dep, m_columns[v.j()].lower_bound_witness()); + else + dep = join_deps(dep, m_columns[v.j()].upper_bound_witness()); + } + update_column_type_and_bound(j, lo.y == 0 ? lconstraint_kind::GE : lconstraint_kind::GT, lo.x, dep); } - IF_VERBOSE(10, verbose_stream() << "j = " << j << " t = "; - print_term(t, verbose_stream()) << " coeff = " << coeff << "\n"); - return true; + + if (hi_valid && (!column_has_upper_bound(j) || hi < column_upper_bound(j).x)) { + u_dependency* dep = nullptr; + for (auto const& v : t) { + if (v.coeff().is_pos()) + dep = join_deps(dep, m_columns[v.j()].upper_bound_witness()); + else + dep = join_deps(dep, m_columns[v.j()].lower_bound_witness()); + } + update_column_type_and_bound(j, hi.y == 0 ? lconstraint_kind::LE : lconstraint_kind::LT, hi.x, dep); + } + + if (!column_is_fixed(j)) + sols.push_back({j, t}); } diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index 6d7b4f08f..2c58bbd45 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -351,7 +351,15 @@ public: * return the rest of the row as t comprising of non-fixed variables and coeff as sum of fixed variables. * return false if j has no rows. */ - bool solve_for(unsigned j, lar_term& t, mpq& coeff); + + struct solution { + unsigned j; + lar_term t; + }; + + void solve_for(unsigned_vector const& js, vector& sol); + void check_fixed(unsigned j); + void solve_for(unsigned j, uint_set& tabu, vector& sol); inline unsigned get_base_column_in_row(unsigned row_index) const { return m_mpq_lar_core_solver.m_r_solver.get_base_column_in_row(row_index); @@ -591,21 +599,32 @@ public: return m_columns[j].lower_bound_witness(); } inline bool column_has_term(lpvar j) const { return m_columns[j].term() != nullptr; } - inline std::ostream& print_column_info(unsigned j, std::ostream& out) const { - m_mpq_lar_core_solver.m_r_solver.print_column_info(j, out); - if (column_has_term(j)) { - print_term_as_indices(get_term(j), out) << "\n"; - } else if (column_has_term(j)) { - const lar_term& t = get_term(m_var_register.local_to_external(j)); - print_term_as_indices(t, out) << "\n"; - } + std::ostream& print_column_info(unsigned j, std::ostream& out) const { + m_mpq_lar_core_solver.m_r_solver.print_column_info(j, out); + if (column_has_term(j)) + print_term_as_indices(get_term(j), out) << "\n"; + display_column_explanation(out, j); + return out; + } + + std::ostream& display_column_explanation(std::ostream& out, unsigned j) const { + const column& ul = m_columns[j]; + svector vs1, vs2; + m_dependencies.linearize(ul.lower_bound_witness(), vs1); + m_dependencies.linearize(ul.upper_bound_witness(), vs2); + if (!vs1.empty()) + out << "lo: " << vs1; + if (!vs2.empty()) + out << "hi: " << vs2; + if (!vs1.empty() || !vs2.empty()) + out << "\n"; return out; } void subst_known_terms(lar_term*); - inline std::ostream& print_column_bound_info(unsigned j, std::ostream& out) const { + std::ostream& print_column_bound_info(unsigned j, std::ostream& out) const { return m_mpq_lar_core_solver.m_r_solver.print_column_bound_info(j, out); } diff --git a/src/muz/spacer/spacer_iuc_solver.h b/src/muz/spacer/spacer_iuc_solver.h index e201a1fe1..cdf355f03 100644 --- a/src/muz/spacer/spacer_iuc_solver.h +++ b/src/muz/spacer/spacer_iuc_solver.h @@ -124,6 +124,7 @@ public: expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } expr* congruence_root(expr* e) override { return e; } expr* congruence_next(expr* e) override { return e; } + expr_ref congruence_explain(expr *a, expr *b) override { return expr_ref(m.mk_eq(a, b), m); } void get_levels(ptr_vector const& vars, unsigned_vector& depth) override { m_solver.get_levels(vars, depth); } expr_ref_vector get_trail(unsigned max_level) override { return m_solver.get_trail(max_level); } diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index e614a54fc..bacd8da7d 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -111,6 +111,7 @@ namespace opt { expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } expr* congruence_root(expr* e) override { return e; } expr* congruence_next(expr* e) override { return e; } + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(m.mk_eq(a, b), m); } void set_phase(expr* e) override { m_context.set_phase(e); } phase* get_phase() override { return m_context.get_phase(); } void set_phase(phase* p) override { m_context.set_phase(p); } diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index ba972c617..102f65c2f 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -480,6 +480,7 @@ public: expr* congruence_next(expr* e) override { return e; } expr* congruence_root(expr* e) override { return e; } + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(m.mk_eq(a, b), m); } lbool get_consequences_core(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq) override { diff --git a/src/sat/sat_solver/sat_smt_solver.cpp b/src/sat/sat_solver/sat_smt_solver.cpp index 44fe15f33..19ff978dc 100644 --- a/src/sat/sat_solver/sat_smt_solver.cpp +++ b/src/sat/sat_solver/sat_smt_solver.cpp @@ -426,6 +426,7 @@ public: expr* congruence_next(expr* e) override { return e; } expr* congruence_root(expr* e) override { return e; } + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(m.mk_eq(a, b), m); } lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) override { diff --git a/src/smt/qi_queue.cpp b/src/smt/qi_queue.cpp index 81d41eeee..1de60e69a 100644 --- a/src/smt/qi_queue.cpp +++ b/src/smt/qi_queue.cpp @@ -153,6 +153,11 @@ namespace smt { if (m_context.get_cancel_flag()) { break; } + if (m_stats.m_num_instances > m_params.m_qi_max_instances) { + m_context.set_reason_unknown("maximum number of quantifier instances was reached"); + m_context.set_internal_completed(); + break; + } fingerprint * f = curr.m_qb; quantifier * qa = static_cast(f->get_data()); diff --git a/src/smt/qi_queue.h b/src/smt/qi_queue.h index 961dd73e0..27589eee3 100644 --- a/src/smt/qi_queue.h +++ b/src/smt/qi_queue.h @@ -51,7 +51,7 @@ namespace smt { cost_evaluator m_evaluator; cached_var_subst m_subst; svector m_vals; - double m_eager_cost_threshold; + double m_eager_cost_threshold = 0; struct entry { fingerprint * m_qb; float m_cost; diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 4c1ba3da7..96b66149a 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -104,7 +104,7 @@ namespace smt { */ bool context::get_cancel_flag() { - if (l_true == m_sls_completed && !m.limit().suspended()) { + if (l_true == m_internal_completed && !m.limit().suspended()) { m_last_search_failure = CANCELED; return true; } @@ -3509,11 +3509,11 @@ namespace smt { display_profile(verbose_stream()); if (r == l_true && get_cancel_flag()) r = l_undef; - if (r == l_undef && m_sls_completed == l_true && has_sls_model()) { + if (r == l_undef && m_internal_completed == l_true && has_sls_model()) { m_last_search_failure = OK; r = l_true; } - m_sls_completed = l_false; + m_internal_completed = l_false; if (r == l_true && gparams::get_value("model_validate") == "true") { recfun::util u(m); if (u.get_rec_funs().empty() && m_proto_model) { @@ -3753,7 +3753,7 @@ namespace smt { m_phase_default = false; m_case_split_queue ->init_search_eh(); m_next_progress_sample = 0; - m_sls_completed = l_undef; + m_internal_completed = l_undef; if (m.has_type_vars() && !m_theories.get_plugin(poly_family_id)) register_plugin(alloc(theory_polymorphism, *this)); TRACE("literal_occ", display_literal_num_occs(tout);); @@ -4653,16 +4653,13 @@ namespace smt { if (th == nullptr) return false; return th->get_value(n, value); + } + + void context::solve_for(vector& sol) { + for (auto th : m_theories) + if (th) + th->solve_for(sol); } - - bool context::solve_for(enode * n, expr_ref & term) { - sort * s = n->get_sort(); - family_id fid = s->get_family_id(); - theory * th = get_theory(fid); - if (th == nullptr) - return false; - return th->solve_for(n, term); - } bool context::update_model(bool refinalize) { final_check_status fcs = FC_DONE; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 5daffb959..9ab1db60a 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -130,7 +130,7 @@ namespace smt { class parallel* m_par = nullptr; unsigned m_par_index = 0; bool m_internalizing_assertions = false; - lbool m_sls_completed = l_undef; + lbool m_internal_completed = l_undef; // ----------------------------------- @@ -291,9 +291,9 @@ namespace smt { bool get_cancel_flag(); - void set_sls_completed() { - if (m_sls_completed == l_undef) - m_sls_completed = l_true; + void set_internal_completed() { + if (m_internal_completed == l_undef) + m_internal_completed = l_true; } region & get_region() { @@ -1377,13 +1377,13 @@ namespace smt { // ----------------------------------- // - // Model checking... (must be improved) + // Value extraction and solving // // ----------------------------------- public: bool get_value(enode * n, expr_ref & value); - bool solve_for(enode* n, expr_ref& term); + void solve_for(vector& sol); // ----------------------------------- // diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index b789fc499..be4bc66d6 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -20,6 +20,7 @@ Revision History: #include "smt/smt_context.h" #include "smt/smt_lookahead.h" #include "ast/ast_smt2_pp.h" +#include "ast/ast_util.h" #include "smt/params/smt_params_helper.hpp" namespace smt { @@ -213,11 +214,14 @@ namespace smt { return out; } - bool kernel::solve_for(expr* e, expr_ref& term) { - smt::enode* n = m_imp->m_kernel.find_enode(e); - if (!n) - return false; - return m_imp->m_kernel.solve_for(n, term); + void kernel::solve_for(vector& sol) { + vector solution; + for (auto const& [v, t, g] : sol) + solution.push_back({ v, t, g }); + m_imp->m_kernel.solve_for(solution); + sol.reset(); + for (auto s : solution) + sol.push_back({ s.var, s.term, s.guard }); } expr* kernel::congruence_root(expr * e) { @@ -234,6 +238,21 @@ namespace smt { return n->get_next()->get_expr(); } + expr_ref kernel::congruence_explain(expr* a, expr* b) { + auto& ctx = m_imp->m_kernel; + ast_manager& m = ctx.get_manager(); + smt::enode* n1 = ctx.find_enode(a); + smt::enode* n2 = ctx.find_enode(b); + if (!n1 || !n2 || n1->get_root() != n2->get_root()) + return expr_ref(m.mk_eq(a, b), m); + literal_vector lits; + ctx.get_cr().eq2literals(n1, n2, lits); + expr_ref_vector es(m); + for (auto lit : lits) + es.push_back(ctx.literal2expr(lit)); + return mk_and(es); + } + void kernel::collect_statistics(::statistics & st) const { m_imp->m_kernel.collect_statistics(st); } diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index fb152db8f..92dac74d5 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -246,7 +246,9 @@ namespace smt { expr* congruence_root(expr* e); - bool solve_for(expr* e, expr_ref& term); + expr_ref congruence_explain(expr* a, expr* b); + + void solve_for(vector& s); /** \brief retrieve depth of variables from decision stack. diff --git a/src/smt/smt_quantifier.cpp b/src/smt/smt_quantifier.cpp index b0696049b..77d654bfb 100644 --- a/src/smt/smt_quantifier.cpp +++ b/src/smt/smt_quantifier.cpp @@ -133,7 +133,7 @@ namespace smt { q::quantifier_stat_gen m_qstat_gen; ptr_vector m_quantifiers; scoped_ptr m_plugin; - unsigned m_num_instances; + unsigned m_num_instances = 0; imp(quantifier_manager & wrapper, context & ctx, smt_params & p, quantifier_manager_plugin * plugin): m_wrapper(wrapper), @@ -142,7 +142,6 @@ namespace smt { m_qi_queue(m_wrapper, ctx, p), m_qstat_gen(ctx.get_manager(), ctx.get_region()), m_plugin(plugin) { - m_num_instances = 0; m_qi_queue.setup(); } @@ -297,9 +296,7 @@ namespace smt { vector> & used_enodes) { max_generation = std::max(max_generation, get_generation(q)); - if (m_num_instances > m_params.m_qi_max_instances) { - return false; - } + get_stat(q)->update_max_generation(max_generation); fingerprint * f = m_context.add_fingerprint(q, q->get_id(), num_bindings, bindings, def); if (f) { diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 65feaaa9d..847eb5077 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -337,7 +337,8 @@ namespace { expr* congruence_next(expr* e) override { return m_context.congruence_next(e); } expr* congruence_root(expr* e) override { return m_context.congruence_root(e); } - bool solve_for(expr* e, expr_ref& term) override { return m_context.solve_for(e, term); } + expr_ref congruence_explain(expr* a, expr* b) override { return m_context.congruence_explain(a, b); } + void solve_for(vector& s) override { m_context.solve_for(s); } expr_ref_vector cube(expr_ref_vector& vars, unsigned cutoff) override { diff --git a/src/smt/smt_theory.h b/src/smt/smt_theory.h index 2aed3f449..50f29606e 100644 --- a/src/smt/smt_theory.h +++ b/src/smt/smt_theory.h @@ -29,6 +29,12 @@ namespace smt { class model_generator; class model_value_proc; + struct solution { + expr* var; + expr_ref term; + expr_ref guard; + }; + class theory { protected: theory_id m_id; @@ -605,7 +611,7 @@ namespace smt { virtual char const * get_name() const { return "unknown"; } - virtual bool solve_for(enode* n, expr_ref& r) { return false; } + virtual void solve_for(vector& s) {} // ----------------------------------- // diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 7a73905b7..15eca3d88 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -3169,7 +3169,7 @@ public: typedef std::pair constraint_bound; vector m_lower_terms; vector m_upper_terms; - + void propagate_eqs(lp::lpvar t, lp::constraint_index ci1, lp::lconstraint_kind k, api_bound& b, rational const& value) { u_dependency* ci2 = nullptr; auto pair = [&]() { return lp().dep_manager().mk_join(lp().dep_manager().mk_leaf(ci1), ci2); }; @@ -3392,9 +3392,8 @@ public: } void set_evidence(lp::constraint_index idx, literal_vector& core, svector& eqs) { - if (idx == UINT_MAX) { - return; - } + if (idx == UINT_MAX) + return; switch (m_constraint_sources[idx]) { case inequality_source: { literal lit = m_inequalities[idx]; @@ -3629,33 +3628,116 @@ public: return lp().has_upper_bound(vi, dep, val, is_strict); } - bool solve_for(enode* n, expr_ref& term) { - theory_var v = n->get_th_var(get_id()); - if (!is_registered_var(v)) - return false; - lpvar vi = get_lpvar(v); - lp::lar_term t; - rational coeff; - if (!lp().solve_for(vi, t, coeff)) - return false; - rational lc(1); - if (is_int(v)) { - lc = denominator(coeff); - for (auto const& cv : t) - lc = lcm(denominator(cv.coeff()), lc); - if (lc != 1) { - coeff *= lc; - t *= lc; - } - } - term = mk_term(t, is_int(v)); - if (coeff != 0) - term = a.mk_add(a.mk_numeral(coeff, is_int(v)), term); - if (lc != 1) - term = a.mk_idiv(term, a.mk_numeral(lc, true)); - return true; + void solve_fixed(enode* n, lpvar j, expr_ref& term, expr_ref& guard) { + term = a.mk_numeral(lp().get_value(j), a.is_int(n->get_expr())); + reset_evidence(); + add_explain(j); + guard = extract_explain(); } + void add_explain(unsigned j) { + auto d = lp().get_bound_constraint_witnesses_for_column(j); + set_evidence(d, m_core, m_eqs); + } + + expr_ref extract_explain() { + expr_ref_vector es(m); + for (auto [l, r] : m_eqs) + es.push_back(a.mk_eq(l->get_expr(), r->get_expr())); + for (auto l : m_core) + es.push_back(ctx().literal2expr(l)); + // remove duplicats from es: + std::stable_sort(es.data(), es.data() + es.size()); + unsigned j = 0; + for (unsigned i = 0; i < es.size(); ++i) { + if (i > 0 && es.get(i) == es.get(i - 1)) + continue; + es[j++] = es.get(i); + } + es.shrink(j); + return mk_and(es); + } + + void solve_term(enode* n, lp::lar_term & lt, expr_ref& term, expr_ref& guard) { + bool is_int = a.is_int(n->get_expr()); + bool all_int = is_int; + lp::lar_term t; + rational coeff(0); + expr_ref_vector guards(m); + reset_evidence(); + for (auto const& cv : lt) { + if (lp().column_is_fixed(cv.j())) { + coeff += lp().get_value(cv.j()) * cv.coeff(); + add_explain(cv.j()); + } + else + t.add_monomial(cv.coeff(), cv.j()); + } + guards.push_back(extract_explain()); + rational lc = denominator(coeff); + for (auto const& cv : t) { + lc = lcm(denominator(cv.coeff()), lc); + all_int &= lp().column_is_int(cv.j()); + } + if (lc != 1) + t *= lc, coeff *= lc; + term = mk_term(t, is_int); + if (coeff != 0) + term = a.mk_add(term, a.mk_numeral(coeff, is_int)); + + if (lc == 1) { + guard = mk_and(guards); + return; + } + expr_ref lce(a.mk_numeral(lc, true), m); + if (all_int) + guards.push_back(m.mk_eq(a.mk_mod(term, lce), a.mk_int(0))); + else if (is_int) + guards.push_back(a.mk_is_int(a.mk_div(term, lce))); + term = a.mk_idiv(term, lce); + guard = mk_and(guards); + } + + void solve_for(vector& solutions) { + unsigned_vector vars; + unsigned j = 0; + for (auto [e, t, g] : solutions) { + auto n = get_enode(e); + if (!n) { + solutions[j++] = { e, t, g }; + continue; + } + + theory_var v = n->get_th_var(get_id()); + if (!is_registered_var(v)) + solutions[j++] = { e, t, g }; + else + vars.push_back(get_lpvar(v)); + } + solutions.shrink(j); + + expr_ref term(m), guard(m); + vector sols; + lp().solve_for(vars, sols); + uint_set seen; + for (auto& s : sols) { + auto n = get_enode(lp().local_to_external(s.j)); + if (lp().column_is_fixed(s.j)) + solve_fixed(n, s.j, term, guard); + else + solve_term(n, s.t, term, guard); + solutions.push_back({ n->get_expr(), term, guard}); + seen.insert(s.j); + } + for (auto j : vars) { + if (seen.contains(j) || !lp().column_is_fixed(j)) + continue; + auto n = get_enode(lp().local_to_external(j)); + solve_fixed(n, j, term, guard); + solutions.push_back({ n->get_expr(), term, guard }); + } + } + bool get_upper(enode* n, expr_ref& r) { bool is_strict; rational val; @@ -4166,8 +4248,9 @@ bool theory_lra::get_lower(enode* n, rational& r, bool& is_strict) { bool theory_lra::get_upper(enode* n, rational& r, bool& is_strict) { return m_imp->get_upper(n, r, is_strict); } -bool theory_lra::solve_for(enode* n, expr_ref& r) { - return m_imp->solve_for(n, r); + +void theory_lra::solve_for(vector& sol) { + m_imp->solve_for(sol); } void theory_lra::display(std::ostream & out) const { diff --git a/src/smt/theory_lra.h b/src/smt/theory_lra.h index 6267ae68d..1624bab0a 100644 --- a/src/smt/theory_lra.h +++ b/src/smt/theory_lra.h @@ -93,7 +93,7 @@ namespace smt { bool get_upper(enode* n, expr_ref& r); bool get_lower(enode* n, rational& r, bool& is_strict); bool get_upper(enode* n, rational& r, bool& is_strict); - bool solve_for(enode* n, expr_ref& r) override; + void solve_for(vector& s) override; void display(std::ostream & out) const override; diff --git a/src/smt/theory_sls.cpp b/src/smt/theory_sls.cpp index 9c1bd2bef..4ab93435f 100644 --- a/src/smt/theory_sls.cpp +++ b/src/smt/theory_sls.cpp @@ -58,7 +58,7 @@ namespace smt { } void theory_sls::set_finished() { - ctx.set_sls_completed(); + ctx.set_internal_completed(); } bool theory_sls::get_smt_value(expr* v, expr_ref& value) { diff --git a/src/solver/combined_solver.cpp b/src/solver/combined_solver.cpp index e2a2f2011..97ecc8e1c 100644 --- a/src/solver/combined_solver.cpp +++ b/src/solver/combined_solver.cpp @@ -277,6 +277,7 @@ public: expr* congruence_next(expr* e) override { switch_inc_mode(); return m_solver2->congruence_next(e); } expr* congruence_root(expr* e) override { switch_inc_mode(); return m_solver2->congruence_root(e); } + expr_ref congruence_explain(expr* a, expr* b) override { switch_inc_mode(); return m_solver2->congruence_explain(a, b); } expr * get_assumption(unsigned idx) const override { diff --git a/src/solver/simplifier_solver.cpp b/src/solver/simplifier_solver.cpp index b114f364f..23a802c49 100644 --- a/src/solver/simplifier_solver.cpp +++ b/src/solver/simplifier_solver.cpp @@ -365,6 +365,7 @@ public: expr* congruence_root(expr* e) override { return s->congruence_root(e); } expr* congruence_next(expr* e) override { return s->congruence_next(e); } + expr_ref congruence_explain(expr* a, expr* b) override { return s->congruence_explain(a, b); } std::ostream& display(std::ostream& out, unsigned n, expr* const* assumptions) const override { return s->display(out, n, assumptions); } diff --git a/src/solver/slice_solver.cpp b/src/solver/slice_solver.cpp index 28815b4d6..8310c47f4 100644 --- a/src/solver/slice_solver.cpp +++ b/src/solver/slice_solver.cpp @@ -393,6 +393,7 @@ public: expr* congruence_root(expr* e) override { return s->congruence_root(e); } expr* congruence_next(expr* e) override { return s->congruence_next(e); } + expr_ref congruence_explain(expr* a, expr* b) override { return s->congruence_explain(a, b); } std::ostream& display(std::ostream& out, unsigned n, expr* const* assumptions) const override { return s->display(out, n, assumptions); } diff --git a/src/solver/solver.h b/src/solver/solver.h index 99acc4a1e..d5b3ee4f8 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -27,6 +27,7 @@ class solver; class model_converter; + class solver_factory { public: virtual ~solver_factory() = default; @@ -249,9 +250,17 @@ public: virtual expr* congruence_next(expr* e) = 0; /** - \brief try to solve for term e (when e is arithmetical). + \brief expose explanation for congruence. */ - virtual bool solve_for(expr* e, expr_ref& term) { return false; } + virtual expr_ref congruence_explain(expr* a, expr* b) = 0; + + struct solution { + expr* var; + expr_ref term; + expr_ref guard; + }; + + virtual void solve_for(vector& s) {} /** \brief Display the content of this solver. diff --git a/src/solver/solver_pool.cpp b/src/solver/solver_pool.cpp index 411634162..5fbe8fa11 100644 --- a/src/solver/solver_pool.cpp +++ b/src/solver/solver_pool.cpp @@ -81,7 +81,7 @@ public: } void push_params() override {m_base->push_params();} void pop_params() override {m_base->pop_params();} - + void collect_param_descrs(param_descrs & r) override { m_base->collect_param_descrs(r); } void collect_statistics(statistics & st) const override { m_base->collect_statistics(st); } unsigned get_num_assertions() const override { return m_base->get_num_assertions(); } @@ -264,6 +264,7 @@ public: expr* congruence_next(expr* e) override { return e; } expr* congruence_root(expr* e) override { return e; } + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(m.mk_eq(a, b), m); } ast_manager& get_manager() const override { return m_base->get_manager(); } diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index dfdae5e5a..bf849d830 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -145,6 +145,7 @@ public: expr* congruence_next(expr* e) override { return e; } expr* congruence_root(expr* e) override { return e; } + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(get_manager().mk_eq(a, b), get_manager()); } model_converter_ref get_model_converter() const override { return m_mc; } diff --git a/src/tactic/fd_solver/bounded_int2bv_solver.cpp b/src/tactic/fd_solver/bounded_int2bv_solver.cpp index 45b444fe9..6f305372d 100644 --- a/src/tactic/fd_solver/bounded_int2bv_solver.cpp +++ b/src/tactic/fd_solver/bounded_int2bv_solver.cpp @@ -212,6 +212,7 @@ public: ast_manager& get_manager() const override { return m; } expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); } expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); } + expr_ref congruence_explain(expr* a, expr* b) override { return m_solver->congruence_explain(a, b); } expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override { flush_assertions(); return m_solver->cube(vars, backtrack_level); } lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) override { return m_solver->find_mutexes(vars, mutexes); } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { diff --git a/src/tactic/fd_solver/enum2bv_solver.cpp b/src/tactic/fd_solver/enum2bv_solver.cpp index 2690e7033..061f67010 100644 --- a/src/tactic/fd_solver/enum2bv_solver.cpp +++ b/src/tactic/fd_solver/enum2bv_solver.cpp @@ -133,6 +133,7 @@ public: } expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); } expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); } + expr_ref congruence_explain(expr* a, expr* b) override { return m_solver->congruence_explain(a, b); } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { diff --git a/src/tactic/fd_solver/pb2bv_solver.cpp b/src/tactic/fd_solver/pb2bv_solver.cpp index 19f2630f2..7bc0f27dd 100644 --- a/src/tactic/fd_solver/pb2bv_solver.cpp +++ b/src/tactic/fd_solver/pb2bv_solver.cpp @@ -124,6 +124,7 @@ public: expr_ref_vector cube(expr_ref_vector& vars, unsigned backtrack_level) override { flush_assertions(); return m_solver->cube(vars, backtrack_level); } expr* congruence_next(expr* e) override { return m_solver->congruence_next(e); } expr* congruence_root(expr* e) override { return m_solver->congruence_root(e); } + expr_ref congruence_explain(expr* a, expr* b) override { return m_solver->congruence_explain(a, b); } lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) override { return m_solver->find_mutexes(vars, mutexes); } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { flush_assertions(); diff --git a/src/tactic/fd_solver/smtfd_solver.cpp b/src/tactic/fd_solver/smtfd_solver.cpp index 4d0912fdc..46fb4054e 100644 --- a/src/tactic/fd_solver/smtfd_solver.cpp +++ b/src/tactic/fd_solver/smtfd_solver.cpp @@ -2070,6 +2070,8 @@ namespace smtfd { expr* congruence_root(expr* e) override { return e; } expr* congruence_next(expr* e) override { return e; } + + expr_ref congruence_explain(expr* a, expr* b) override { return expr_ref(m.mk_eq(a, b), m); } lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) override { return l_undef;