diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 023dedf97..16daed6d4 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -967,6 +967,20 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + Z3_ast Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast a) { + Z3_TRY; + LOG_Z3_solver_solve_for(c, s, a); + RESET_ERROR_CODE(); + init_solver(c, s); + ast_manager& m = mk_c(c)->m(); + expr_ref term(m); + if (!to_solver_ref(s)->solve_for(to_expr(a), term)) + term = to_expr(a); + mk_c(c)->save_ast_trail(term.get()); + RETURN_Z3(of_expr(term.get())); + Z3_CATCH_RETURN(nullptr); + } + class api_context_obj : public user_propagator::context_obj { api::context* c; public: diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index bd59c9917..8231c2036 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -7351,6 +7351,12 @@ class Solver(Z3PPObject): """ return _to_expr_ref(Z3_solver_congruence_next(self.ctx.ref(), self.solver, t.ast), self.ctx) + def solve_for(self, t): + t = _py2expr(t, self.ctx) + """Retrieve a solution for t relative to linear equations maintained in the current state. + The function primarily works for SimpleSolver and when there is a solution using linear arithmetic.""" + return _to_expr_ref(Z3_solver_solve_for(self.ctx.ref(), self.solver, t.ast), self.ctx) + def proof(self): """Return a proof for the last `check()`. Proof construction must be enabled.""" return _to_expr_ref(Z3_solver_get_proof(self.ctx.ref(), self.solver), self.ctx) diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 6c3efe7fc..791d36b2f 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -7077,6 +7077,14 @@ extern "C" { Z3_ast Z3_API Z3_solver_congruence_next(Z3_context c, Z3_solver s, Z3_ast a); + /** + \brief retrieve a 'solution' for \c t as defined by equalities in maintained by solvers. + At this point, only linear solution are supported. + + def_API('Z3_solver_solve_for', AST, (_in(CONTEXT), _in(SOLVER), _in(AST))) + */ + Z3_ast Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast t); + /** \brief register a callback to that retrieves assumed, inferred and deleted clauses during search. diff --git a/src/math/lp/lar_solver.cpp b/src/math/lp/lar_solver.cpp index 3e3d98548..20c6d9f30 100644 --- a/src/math/lp/lar_solver.cpp +++ b/src/math/lp/lar_solver.cpp @@ -617,6 +617,34 @@ namespace lp { m_touched_rows.insert(rid); } + bool lar_solver::solve_for(unsigned j, lar_term& t, mpq& coeff) { + t.clear(); + if (column_is_fixed(j)) { + coeff = get_value(j); + return true; + } + if (!is_base(j)) { + for (const auto & c : A_r().m_columns[j]) { + lpvar basic_in_row = r_basis()[c.var()]; + pivot(j, basic_in_row); + break; + } + } + if (!is_base(j)) + return false; + auto const& r = basic2row(j); + for (auto const& c : r) { + if (c.var() == j) + continue; + if (column_is_fixed(c.var())) + coeff -= get_value(c.var()); + else + t.add_monomial(-c.coeff(), c.var()); + } + return true; + } + + void lar_solver::remove_fixed_vars_from_base() { // this will allow to disable and restore the tracking of the touched rows flet f(m_mpq_lar_core_solver.m_r_solver.m_touched_rows, nullptr); diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index c58fe7917..6d7b4f08f 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -346,6 +346,12 @@ public: void set_value_for_nbasic_column(unsigned j, const impq& new_val); void remove_fixed_vars_from_base(); + /** + * \brief set j to basic (if not already basic) + * return the rest of the row as t comprising of non-fixed variables and coeff as sum of fixed variables. + * return false if j has no rows. + */ + bool solve_for(unsigned j, lar_term& t, mpq& coeff); inline unsigned get_base_column_in_row(unsigned row_index) const { return m_mpq_lar_core_solver.m_r_solver.get_base_column_in_row(row_index); diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index e23c8949c..b165e54a6 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -4654,6 +4654,15 @@ namespace smt { return false; return th->get_value(n, value); } + + bool context::solve_for(enode * n, expr_ref & term) { + sort * s = n->get_sort(); + family_id fid = s->get_family_id(); + theory * th = get_theory(fid); + if (th == nullptr) + return false; + return th->solve_for(n, term); + } bool context::update_model(bool refinalize) { final_check_status fcs = FC_DONE; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index eb9f2d30e..5daffb959 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -1375,11 +1375,6 @@ namespace smt { bool can_propagate() const; - // Retrieve arithmetic values. - bool get_arith_lo(expr* e, rational& lo, bool& strict); - bool get_arith_up(expr* e, rational& up, bool& strict); - bool get_arith_value(expr* e, rational& value); - // ----------------------------------- // // Model checking... (must be improved) @@ -1388,6 +1383,8 @@ namespace smt { public: bool get_value(enode * n, expr_ref & value); + bool solve_for(enode* n, expr_ref& term); + // ----------------------------------- // // Pretty Printing diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 2d6c29532..b789fc499 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -213,8 +213,15 @@ namespace smt { return out; } + bool kernel::solve_for(expr* e, expr_ref& term) { + smt::enode* n = m_imp->m_kernel.find_enode(e); + if (!n) + return false; + return m_imp->m_kernel.solve_for(n, term); + } + expr* kernel::congruence_root(expr * e) { - smt::enode* n = m_imp->m_kernel.find_enode(e); + smt::enode* n = m_imp->m_kernel.find_enode(e); if (!n) return e; return n->get_root()->get_expr(); diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 539a32750..fb152db8f 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -246,6 +246,7 @@ namespace smt { expr* congruence_root(expr* e); + bool solve_for(expr* e, expr_ref& term); /** \brief retrieve depth of variables from decision stack. diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 7b9d416f3..65feaaa9d 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -337,6 +337,7 @@ namespace { expr* congruence_next(expr* e) override { return m_context.congruence_next(e); } expr* congruence_root(expr* e) override { return m_context.congruence_root(e); } + bool solve_for(expr* e, expr_ref& term) override { return m_context.solve_for(e, term); } expr_ref_vector cube(expr_ref_vector& vars, unsigned cutoff) override { diff --git a/src/smt/smt_theory.h b/src/smt/smt_theory.h index 74dfe8aa2..2aed3f449 100644 --- a/src/smt/smt_theory.h +++ b/src/smt/smt_theory.h @@ -605,6 +605,8 @@ namespace smt { virtual char const * get_name() const { return "unknown"; } + virtual bool solve_for(enode* n, expr_ref& r) { return false; } + // ----------------------------------- // // Return a fresh new instance of the given theory. diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 6b51e2f69..d7df4ea22 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -3627,7 +3627,21 @@ public: lpvar vi = get_lpvar(v); u_dependency* dep = nullptr; return lp().has_upper_bound(vi, dep, val, is_strict); + } + bool solve_for(enode* n, expr_ref& term) { + theory_var v = n->get_th_var(get_id()); + if (!is_registered_var(v)) + return false; + lpvar vi = get_lpvar(v); + lp::lar_term t; + rational coeff; + if (!lp().solve_for(vi, t, coeff)) + return false; + term = mk_term(t, is_int(v)); + if (coeff != 0) + term = a.mk_add(a.mk_numeral(coeff, is_int(v)), term); + return true; } bool get_upper(enode* n, expr_ref& r) { @@ -4140,6 +4154,10 @@ bool theory_lra::get_lower(enode* n, rational& r, bool& is_strict) { bool theory_lra::get_upper(enode* n, rational& r, bool& is_strict) { return m_imp->get_upper(n, r, is_strict); } +bool theory_lra::solve_for(enode* n, expr_ref& r) { + return m_imp->solve_for(n, r); +} + void theory_lra::display(std::ostream & out) const { m_imp->display(out); } diff --git a/src/smt/theory_lra.h b/src/smt/theory_lra.h index 96988f957..6267ae68d 100644 --- a/src/smt/theory_lra.h +++ b/src/smt/theory_lra.h @@ -93,6 +93,7 @@ namespace smt { bool get_upper(enode* n, expr_ref& r); bool get_lower(enode* n, rational& r, bool& is_strict); bool get_upper(enode* n, rational& r, bool& is_strict); + bool solve_for(enode* n, expr_ref& r) override; void display(std::ostream & out) const override; diff --git a/src/solver/solver.h b/src/solver/solver.h index 7d7a3eec2..99acc4a1e 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -248,6 +248,11 @@ public: */ virtual expr* congruence_next(expr* e) = 0; + /** + \brief try to solve for term e (when e is arithmetical). + */ + virtual bool solve_for(expr* e, expr_ref& term) { return false; } + /** \brief Display the content of this solver. */