3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 10:25:18 +00:00

add facility to solve for a linear term over API

This commit is contained in:
Nikolaj Bjorner 2024-11-30 09:34:27 -08:00
parent d2411567b5
commit 05e053247d
14 changed files with 109 additions and 6 deletions

View file

@ -967,6 +967,20 @@ extern "C" {
Z3_CATCH_RETURN(nullptr);
}
Z3_ast Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast a) {
Z3_TRY;
LOG_Z3_solver_solve_for(c, s, a);
RESET_ERROR_CODE();
init_solver(c, s);
ast_manager& m = mk_c(c)->m();
expr_ref term(m);
if (!to_solver_ref(s)->solve_for(to_expr(a), term))
term = to_expr(a);
mk_c(c)->save_ast_trail(term.get());
RETURN_Z3(of_expr(term.get()));
Z3_CATCH_RETURN(nullptr);
}
class api_context_obj : public user_propagator::context_obj {
api::context* c;
public:

View file

@ -7351,6 +7351,12 @@ class Solver(Z3PPObject):
"""
return _to_expr_ref(Z3_solver_congruence_next(self.ctx.ref(), self.solver, t.ast), self.ctx)
def solve_for(self, t):
t = _py2expr(t, self.ctx)
"""Retrieve a solution for t relative to linear equations maintained in the current state.
The function primarily works for SimpleSolver and when there is a solution using linear arithmetic."""
return _to_expr_ref(Z3_solver_solve_for(self.ctx.ref(), self.solver, t.ast), self.ctx)
def proof(self):
"""Return a proof for the last `check()`. Proof construction must be enabled."""
return _to_expr_ref(Z3_solver_get_proof(self.ctx.ref(), self.solver), self.ctx)

View file

@ -7077,6 +7077,14 @@ extern "C" {
Z3_ast Z3_API Z3_solver_congruence_next(Z3_context c, Z3_solver s, Z3_ast a);
/**
\brief retrieve a 'solution' for \c t as defined by equalities in maintained by solvers.
At this point, only linear solution are supported.
def_API('Z3_solver_solve_for', AST, (_in(CONTEXT), _in(SOLVER), _in(AST)))
*/
Z3_ast Z3_API Z3_solver_solve_for(Z3_context c, Z3_solver s, Z3_ast t);
/**
\brief register a callback to that retrieves assumed, inferred and deleted clauses during search.

View file

@ -617,6 +617,34 @@ namespace lp {
m_touched_rows.insert(rid);
}
bool lar_solver::solve_for(unsigned j, lar_term& t, mpq& coeff) {
t.clear();
if (column_is_fixed(j)) {
coeff = get_value(j);
return true;
}
if (!is_base(j)) {
for (const auto & c : A_r().m_columns[j]) {
lpvar basic_in_row = r_basis()[c.var()];
pivot(j, basic_in_row);
break;
}
}
if (!is_base(j))
return false;
auto const& r = basic2row(j);
for (auto const& c : r) {
if (c.var() == j)
continue;
if (column_is_fixed(c.var()))
coeff -= get_value(c.var());
else
t.add_monomial(-c.coeff(), c.var());
}
return true;
}
void lar_solver::remove_fixed_vars_from_base() {
// this will allow to disable and restore the tracking of the touched rows
flet<indexed_uint_set*> f(m_mpq_lar_core_solver.m_r_solver.m_touched_rows, nullptr);

View file

@ -346,6 +346,12 @@ public:
void set_value_for_nbasic_column(unsigned j, const impq& new_val);
void remove_fixed_vars_from_base();
/**
* \brief set j to basic (if not already basic)
* return the rest of the row as t comprising of non-fixed variables and coeff as sum of fixed variables.
* return false if j has no rows.
*/
bool solve_for(unsigned j, lar_term& t, mpq& coeff);
inline unsigned get_base_column_in_row(unsigned row_index) const {
return m_mpq_lar_core_solver.m_r_solver.get_base_column_in_row(row_index);

View file

@ -4654,6 +4654,15 @@ namespace smt {
return false;
return th->get_value(n, value);
}
bool context::solve_for(enode * n, expr_ref & term) {
sort * s = n->get_sort();
family_id fid = s->get_family_id();
theory * th = get_theory(fid);
if (th == nullptr)
return false;
return th->solve_for(n, term);
}
bool context::update_model(bool refinalize) {
final_check_status fcs = FC_DONE;

View file

@ -1375,11 +1375,6 @@ namespace smt {
bool can_propagate() const;
// Retrieve arithmetic values.
bool get_arith_lo(expr* e, rational& lo, bool& strict);
bool get_arith_up(expr* e, rational& up, bool& strict);
bool get_arith_value(expr* e, rational& value);
// -----------------------------------
//
// Model checking... (must be improved)
@ -1388,6 +1383,8 @@ namespace smt {
public:
bool get_value(enode * n, expr_ref & value);
bool solve_for(enode* n, expr_ref& term);
// -----------------------------------
//
// Pretty Printing

View file

@ -213,8 +213,15 @@ namespace smt {
return out;
}
bool kernel::solve_for(expr* e, expr_ref& term) {
smt::enode* n = m_imp->m_kernel.find_enode(e);
if (!n)
return false;
return m_imp->m_kernel.solve_for(n, term);
}
expr* kernel::congruence_root(expr * e) {
smt::enode* n = m_imp->m_kernel.find_enode(e);
smt::enode* n = m_imp->m_kernel.find_enode(e);
if (!n)
return e;
return n->get_root()->get_expr();

View file

@ -246,6 +246,7 @@ namespace smt {
expr* congruence_root(expr* e);
bool solve_for(expr* e, expr_ref& term);
/**
\brief retrieve depth of variables from decision stack.

View file

@ -337,6 +337,7 @@ namespace {
expr* congruence_next(expr* e) override { return m_context.congruence_next(e); }
expr* congruence_root(expr* e) override { return m_context.congruence_root(e); }
bool solve_for(expr* e, expr_ref& term) override { return m_context.solve_for(e, term); }
expr_ref_vector cube(expr_ref_vector& vars, unsigned cutoff) override {

View file

@ -605,6 +605,8 @@ namespace smt {
virtual char const * get_name() const { return "unknown"; }
virtual bool solve_for(enode* n, expr_ref& r) { return false; }
// -----------------------------------
//
// Return a fresh new instance of the given theory.

View file

@ -3627,7 +3627,21 @@ public:
lpvar vi = get_lpvar(v);
u_dependency* dep = nullptr;
return lp().has_upper_bound(vi, dep, val, is_strict);
}
bool solve_for(enode* n, expr_ref& term) {
theory_var v = n->get_th_var(get_id());
if (!is_registered_var(v))
return false;
lpvar vi = get_lpvar(v);
lp::lar_term t;
rational coeff;
if (!lp().solve_for(vi, t, coeff))
return false;
term = mk_term(t, is_int(v));
if (coeff != 0)
term = a.mk_add(a.mk_numeral(coeff, is_int(v)), term);
return true;
}
bool get_upper(enode* n, expr_ref& r) {
@ -4140,6 +4154,10 @@ bool theory_lra::get_lower(enode* n, rational& r, bool& is_strict) {
bool theory_lra::get_upper(enode* n, rational& r, bool& is_strict) {
return m_imp->get_upper(n, r, is_strict);
}
bool theory_lra::solve_for(enode* n, expr_ref& r) {
return m_imp->solve_for(n, r);
}
void theory_lra::display(std::ostream & out) const {
m_imp->display(out);
}

View file

@ -93,6 +93,7 @@ namespace smt {
bool get_upper(enode* n, expr_ref& r);
bool get_lower(enode* n, rational& r, bool& is_strict);
bool get_upper(enode* n, rational& r, bool& is_strict);
bool solve_for(enode* n, expr_ref& r) override;
void display(std::ostream & out) const override;

View file

@ -248,6 +248,11 @@ public:
*/
virtual expr* congruence_next(expr* e) = 0;
/**
\brief try to solve for term e (when e is arithmetical).
*/
virtual bool solve_for(expr* e, expr_ref& term) { return false; }
/**
\brief Display the content of this solver.
*/