3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-10-04 23:13:57 +00:00

add user propagators to opt_solver

This commit is contained in:
Nikolaj Bjorner 2025-10-02 19:44:22 -07:00
parent 0e6b3a922a
commit e137aaa249
5 changed files with 58 additions and 6 deletions

View file

@ -657,6 +657,8 @@ void cmd_context::set_opt(opt_wrapper* opt) {
for (auto const& [var, value] : m_var2values) for (auto const& [var, value] : m_var2values)
m_opt->initialize_value(var, value); m_opt->initialize_value(var, value);
m_opt->set_logic(m_logic); m_opt->set_logic(m_logic);
if (m_preferred)
m_opt->set_preferred(m_preferred.get());
} }
void cmd_context::global_params_updated() { void cmd_context::global_params_updated() {
@ -1896,11 +1898,9 @@ void cmd_context::set_preferred(expr* fmla) {
get_solver()->user_propagate_register_decide(p->decide_eh); get_solver()->user_propagate_register_decide(p->decide_eh);
} }
} }
if (get_opt())
get_opt()->set_preferred(m_preferred.get());
m_preferred->set_preferred(fmla); m_preferred->set_preferred(fmla);
if (get_opt()) {
throw default_exception("setting preferred on optimization context is not supported yet");
return;
}
} }
void cmd_context::reset_preferred() { void cmd_context::reset_preferred() {

View file

@ -164,6 +164,9 @@ struct builtin_decl {
}; };
class opt_wrapper : public check_sat_result { class opt_wrapper : public check_sat_result {
protected:
preferred_value_propagator *m_preferred = nullptr;
public: public:
opt_wrapper(ast_manager& m): check_sat_result(m) {} opt_wrapper(ast_manager& m): check_sat_result(m) {}
virtual bool empty() = 0; virtual bool empty() = 0;
@ -177,7 +180,7 @@ public:
virtual void get_box_model(model_ref& mdl, unsigned index) = 0; virtual void get_box_model(model_ref& mdl, unsigned index) = 0;
virtual void updt_params(params_ref const& p) = 0; virtual void updt_params(params_ref const& p) = 0;
virtual void initialize_value(expr* var, expr* value) = 0; virtual void initialize_value(expr* var, expr* value) = 0;
void set_preferred(preferred_value_propagator *p) { m_preferred = p; }
}; };
class ast_context_params : public context_params { class ast_context_params : public context_params {

View file

@ -316,6 +316,11 @@ namespace opt {
m_model_converter->convert_initialize_value(m_scoped_state.m_values); m_model_converter->convert_initialize_value(m_scoped_state.m_values);
for (auto & [var, value] : m_scoped_state.m_values) for (auto & [var, value] : m_scoped_state.m_values)
s.user_propagate_initialize_value(var, value); s.user_propagate_initialize_value(var, value);
if (m_preferred) {
auto p = m_preferred;
s.user_propagate_init(p, p->push_eh, p->pop_eh, p->fresh_eh);
s.user_propagate_register_decide(p->decide_eh);
}
opt_params optp(m_params); opt_params optp(m_params);
symbol pri = optp.priority(); symbol pri = optp.priority();

View file

@ -22,6 +22,7 @@ Notes:
#include "ast/bv_decl_plugin.h" #include "ast/bv_decl_plugin.h"
#include "ast/converters/model_converter.h" #include "ast/converters/model_converter.h"
#include "tactic/tactic.h" #include "tactic/tactic.h"
#include "solver/preferred_value_propagator.h"
#include "qe/qsat.h" #include "qe/qsat.h"
#include "opt/opt_solver.h" #include "opt/opt_solver.h"
#include "opt/opt_pareto.h" #include "opt/opt_pareto.h"
@ -231,7 +232,7 @@ namespace opt {
void get_labels(svector<symbol> & r) override; void get_labels(svector<symbol> & r) override;
void get_unsat_core(expr_ref_vector & r) override; void get_unsat_core(expr_ref_vector & r) override;
std::string reason_unknown() const override; std::string reason_unknown() const override;
void set_reason_unknown(char const* msg) override { m_unknown = msg; } void set_reason_unknown(char const* msg) override { m_unknown = msg; }
void display_assignment(std::ostream& out) override; void display_assignment(std::ostream& out) override;
bool is_pareto() override { return m_pareto.get() != nullptr; } bool is_pareto() override { return m_pareto.get() != nullptr; }

View file

@ -117,6 +117,49 @@ namespace opt {
void set_phase(phase* p) override { m_context.set_phase(p); } void set_phase(phase* p) override { m_context.set_phase(p); }
void move_to_front(expr* e) override { m_context.move_to_front(e); } void move_to_front(expr* e) override { m_context.move_to_front(e); }
void user_propagate_initialize_value(expr* var, expr* value) override { m_context.user_propagate_initialize_value(var, value); } void user_propagate_initialize_value(expr* var, expr* value) override { m_context.user_propagate_initialize_value(var, value); }
void user_propagate_init(void *ctx, user_propagator::push_eh_t &push_eh, user_propagator::pop_eh_t &pop_eh, user_propagator::fresh_eh_t &fresh_eh) override {
m_context.user_propagate_init(ctx, push_eh, pop_eh, fresh_eh);
m_first = false;
}
void user_propagate_register_fixed(user_propagator::fixed_eh_t &fixed_eh) override {
m_context.user_propagate_register_fixed(fixed_eh);
}
void user_propagate_register_final(user_propagator::final_eh_t &final_eh) override {
m_context.user_propagate_register_final(final_eh);
}
void user_propagate_register_eq(user_propagator::eq_eh_t &eq_eh) override {
m_context.user_propagate_register_eq(eq_eh);
}
void user_propagate_register_diseq(user_propagator::eq_eh_t &diseq_eh) override {
m_context.user_propagate_register_diseq(diseq_eh);
}
void user_propagate_register_expr(expr *e) override {
m_context.user_propagate_register_expr(e);
}
void user_propagate_register_created(user_propagator::created_eh_t &r) override {
m_context.user_propagate_register_created(r);
}
void user_propagate_register_decide(user_propagator::decide_eh_t &r) override {
m_context.user_propagate_register_decide(r);
}
void user_propagate_register_on_binding(user_propagator::binding_eh_t &r) override {
m_context.user_propagate_register_on_binding(r);
}
void user_propagate_clear() override {
}
void register_on_clause(void *, user_propagator::on_clause_eh_t &r) override {
m_context.register_on_clause(nullptr, r);
}
void set_logic(symbol const& logic); void set_logic(symbol const& logic);