From e137aaa24988b315874fedbfa0e512fcf89ef463 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 2 Oct 2025 19:44:22 -0700 Subject: [PATCH] add user propagators to opt_solver --- src/cmd_context/cmd_context.cpp | 8 +++--- src/cmd_context/cmd_context.h | 5 +++- src/opt/opt_context.cpp | 5 ++++ src/opt/opt_context.h | 3 ++- src/opt/opt_solver.h | 43 +++++++++++++++++++++++++++++++++ 5 files changed, 58 insertions(+), 6 deletions(-) diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index 4af0782e1..5513a86ef 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -657,6 +657,8 @@ void cmd_context::set_opt(opt_wrapper* opt) { for (auto const& [var, value] : m_var2values) m_opt->initialize_value(var, value); m_opt->set_logic(m_logic); + if (m_preferred) + m_opt->set_preferred(m_preferred.get()); } void cmd_context::global_params_updated() { @@ -1896,11 +1898,9 @@ void cmd_context::set_preferred(expr* fmla) { get_solver()->user_propagate_register_decide(p->decide_eh); } } + if (get_opt()) + get_opt()->set_preferred(m_preferred.get()); m_preferred->set_preferred(fmla); - if (get_opt()) { - throw default_exception("setting preferred on optimization context is not supported yet"); - return; - } } void cmd_context::reset_preferred() { diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index ddc8b461a..b08944616 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -164,6 +164,9 @@ struct builtin_decl { }; class opt_wrapper : public check_sat_result { +protected: + preferred_value_propagator *m_preferred = nullptr; + public: opt_wrapper(ast_manager& m): check_sat_result(m) {} virtual bool empty() = 0; @@ -177,7 +180,7 @@ public: virtual void get_box_model(model_ref& mdl, unsigned index) = 0; virtual void updt_params(params_ref const& p) = 0; virtual void initialize_value(expr* var, expr* value) = 0; - + void set_preferred(preferred_value_propagator *p) { m_preferred = p; } }; class ast_context_params : public context_params { diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 6244533f0..388befe93 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -316,6 +316,11 @@ namespace opt { m_model_converter->convert_initialize_value(m_scoped_state.m_values); for (auto & [var, value] : m_scoped_state.m_values) s.user_propagate_initialize_value(var, value); + if (m_preferred) { + auto p = m_preferred; + s.user_propagate_init(p, p->push_eh, p->pop_eh, p->fresh_eh); + s.user_propagate_register_decide(p->decide_eh); + } opt_params optp(m_params); symbol pri = optp.priority(); diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index 123d3a44b..ed2377bab 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -22,6 +22,7 @@ Notes: #include "ast/bv_decl_plugin.h" #include "ast/converters/model_converter.h" #include "tactic/tactic.h" +#include "solver/preferred_value_propagator.h" #include "qe/qsat.h" #include "opt/opt_solver.h" #include "opt/opt_pareto.h" @@ -231,7 +232,7 @@ namespace opt { void get_labels(svector & r) override; void get_unsat_core(expr_ref_vector & r) override; std::string reason_unknown() const override; - void set_reason_unknown(char const* msg) override { m_unknown = msg; } + void set_reason_unknown(char const* msg) override { m_unknown = msg; } void display_assignment(std::ostream& out) override; bool is_pareto() override { return m_pareto.get() != nullptr; } diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index e60bbfae6..a409e573a 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -117,6 +117,49 @@ namespace opt { void set_phase(phase* p) override { m_context.set_phase(p); } void move_to_front(expr* e) override { m_context.move_to_front(e); } void user_propagate_initialize_value(expr* var, expr* value) override { m_context.user_propagate_initialize_value(var, value); } + void user_propagate_init(void *ctx, user_propagator::push_eh_t &push_eh, user_propagator::pop_eh_t &pop_eh, user_propagator::fresh_eh_t &fresh_eh) override { + m_context.user_propagate_init(ctx, push_eh, pop_eh, fresh_eh); + m_first = false; + } + + void user_propagate_register_fixed(user_propagator::fixed_eh_t &fixed_eh) override { + m_context.user_propagate_register_fixed(fixed_eh); + } + + void user_propagate_register_final(user_propagator::final_eh_t &final_eh) override { + m_context.user_propagate_register_final(final_eh); + } + + void user_propagate_register_eq(user_propagator::eq_eh_t &eq_eh) override { + m_context.user_propagate_register_eq(eq_eh); + } + + void user_propagate_register_diseq(user_propagator::eq_eh_t &diseq_eh) override { + m_context.user_propagate_register_diseq(diseq_eh); + } + + void user_propagate_register_expr(expr *e) override { + m_context.user_propagate_register_expr(e); + } + + void user_propagate_register_created(user_propagator::created_eh_t &r) override { + m_context.user_propagate_register_created(r); + } + + void user_propagate_register_decide(user_propagator::decide_eh_t &r) override { + m_context.user_propagate_register_decide(r); + } + + void user_propagate_register_on_binding(user_propagator::binding_eh_t &r) override { + m_context.user_propagate_register_on_binding(r); + } + + void user_propagate_clear() override { + } + + void register_on_clause(void *, user_propagator::on_clause_eh_t &r) override { + m_context.register_on_clause(nullptr, r); + } void set_logic(symbol const& logic);