diff --git a/src/api/api_opt.cpp b/src/api/api_opt.cpp index 5854bdaca..a9a0b3230 100644 --- a/src/api/api_opt.cpp +++ b/src/api/api_opt.cpp @@ -459,6 +459,21 @@ extern "C" { Z3_CATCH; } - + void Z3_API Z3_optimize_set_initial_value(Z3_context c, Z3_optimize o, Z3_ast var, Z3_ast value) { + Z3_TRY; + LOG_Z3_optimize_set_initial_value(c, o, var, value); + RESET_ERROR_CODE(); + if (to_expr(var)->get_sort() != to_expr(value)->get_sort()) { + SET_ERROR_CODE(Z3_INVALID_USAGE, "variable and value should have same sort"); + return; + } + ast_manager& m = mk_c(c)->m(); + if (!m.is_value(to_expr(value))) { + SET_ERROR_CODE(Z3_INVALID_USAGE, "a proper value was not supplied"); + return; + } + to_optimize_ptr(o)->initialize_value(to_expr(var), to_expr(value)); + Z3_CATCH; + } }; diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index f18edd96b..f226529de 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -1143,5 +1143,23 @@ extern "C" { Z3_CATCH_RETURN(nullptr); } + void Z3_API Z3_solver_set_initial_value(Z3_context c, Z3_solver s, Z3_ast var, Z3_ast value) { + Z3_TRY; + LOG_Z3_solver_set_initial_value(c, s, var, value); + RESET_ERROR_CODE(); + if (to_expr(var)->get_sort() != to_expr(value)->get_sort()) { + SET_ERROR_CODE(Z3_INVALID_USAGE, "variable and value should have same sort"); + return; + } + ast_manager& m = mk_c(c)->m(); + if (!m.is_value(to_expr(value))) { + SET_ERROR_CODE(Z3_INVALID_USAGE, "a proper value was not supplied"); + return; + } + to_solver_ref(s)->user_propagate_initialize_value(to_expr(var), to_expr(value)); + Z3_CATCH; + } + + }; diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 81d5bcaa9..bb148d712 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -2865,6 +2865,17 @@ namespace z3 { check_error(); return result; } + void set_initial_value(expr const& var, expr const& value) { + Z3_solver_set_initial_value(ctx(), m_solver, var, value); + check_error(); + } + void set_initial_value(expr const& var, int i) { + set_initial_value(var, ctx().num_val(i, var.get_sort())); + } + void set_initial_value(expr const& var, bool b) { + set_initial_value(var, ctx().bool_val(b)); + } + expr proof() const { Z3_ast r = Z3_solver_get_proof(ctx(), m_solver); check_error(); return expr(ctx(), r); } friend std::ostream & operator<<(std::ostream & out, solver const & s); @@ -3330,6 +3341,17 @@ namespace z3 { handle add(expr const& e, unsigned weight) { return add_soft(e, weight); } + void set_initial_value(expr const& var, expr const& value) { + Z3_optimize_set_initial_value(ctx(), m_opt, var, value); + check_error(); + } + void set_initial_value(expr const& var, int i) { + set_initial_value(var, ctx().num_val(i, var.get_sort())); + } + void set_initial_value(expr const& var, bool b) { + set_initial_value(var, ctx().bool_val(b)); + } + handle maximize(expr const& e) { return handle(Z3_optimize_maximize(ctx(), m_opt, e)); } diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 25dc341b8..5c2a35995 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -7353,6 +7353,13 @@ class Solver(Z3PPObject): Z3_solver_get_levels(self.ctx.ref(), self.solver, trail.vector, len(trail), levels) return trail, levels + def set_initial_value(self, var, value): + """initialize the solver's state by setting the initial value of var to value + """ + s = var.sort() + value = s.cast(value) + Z3_solver_set_initial_value(self.ctx.ref(), self.solver, var.ast, value.ast) + def trail(self): """Return trail of the solver state after a check() call. """ @@ -8032,6 +8039,13 @@ class Optimize(Z3PPObject): return [asoft(a) for a in arg] return asoft(arg) + def set_initial_value(self, var, value): + """initialize the solver's state by setting the initial value of var to value + """ + s = var.sort() + value = s.cast(value) + Z3_optimize_set_initial_value(self.ctx.ref(), self.optimize, var.ast, value.ast) + def maximize(self, arg): """Add objective function to maximize.""" return OptimizeObjective( diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 1f0daf8b5..fdc25ef46 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -7241,6 +7241,18 @@ extern "C" { bool Z3_API Z3_solver_propagate_consequence(Z3_context c, Z3_solver_callback cb, unsigned num_fixed, Z3_ast const* fixed, unsigned num_eqs, Z3_ast const* eq_lhs, Z3_ast const* eq_rhs, Z3_ast conseq); + + /** + \brief provide an initialization hint to the solver. The initialization hint is used to calibrate an initial value of the expression that + represents a variable. If the variable is Boolean, the initial phase is set according to \c value. If the variable is an integer or real, + the initial Simplex tableau is recalibrated to attempt to follow the value assignment. + + def_API('Z3_solver_set_initial_value', VOID, (_in(CONTEXT), _in(SOLVER), _in(AST), _in(AST))) + */ + + void Z3_API Z3_solver_set_initial_value(Z3_context c, Z3_solver s, Z3_ast var, Z3_ast value); + + /** \brief Check whether the assertions in a given solver are consistent or not. diff --git a/src/api/z3_optimization.h b/src/api/z3_optimization.h index 8bf0e9da5..ad55cab1d 100644 --- a/src/api/z3_optimization.h +++ b/src/api/z3_optimization.h @@ -139,6 +139,18 @@ extern "C" { */ void Z3_API Z3_optimize_pop(Z3_context c, Z3_optimize d); + /** + \brief provide an initialization hint to the solver. + The initialization hint is used to calibrate an initial value of the expression that + represents a variable. If the variable is Boolean, the initial phase is set + according to \c value. If the variable is an integer or real, + the initial Simplex tableau is recalibrated to attempt to follow the value assignment. + + def_API('Z3_optimize_set_initial_value', VOID, (_in(CONTEXT), _in(OPTIMIZE), _in(AST), _in(AST))) + */ + + void Z3_API Z3_optimize_set_initial_value(Z3_context c, Z3_optimize o, Z3_ast var, Z3_ast value); + /** \brief Check consistency and produce optimal values. \param c - context diff --git a/src/math/lp/lar_solver.cpp b/src/math/lp/lar_solver.cpp index 9272e0298..b0f153e46 100644 --- a/src/math/lp/lar_solver.cpp +++ b/src/math/lp/lar_solver.cpp @@ -2081,6 +2081,24 @@ namespace lp { lpvar lar_solver::to_column(unsigned ext_j) const { return m_var_register.external_to_local(ext_j); } + + bool lar_solver::move_lpvar_to_value(lpvar j, mpq const& value) { + if (is_base(j)) + return false; + + impq ivalue(value); + auto& lcs = m_mpq_lar_core_solver; + auto& slv = m_mpq_lar_core_solver.m_r_solver; + + if (slv.column_has_upper_bound(j) && lcs.m_r_upper_bounds()[j] < ivalue) + return false; + if (slv.column_has_lower_bound(j) && lcs.m_r_lower_bounds()[j] > ivalue) + return false; + + set_value_for_nbasic_column(j, ivalue); + return true; + } + bool lar_solver::tighten_term_bounds_by_delta(lpvar j, const impq& delta) { SASSERT(column_has_term(j)); diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index f223b5cc5..c58fe7917 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -623,6 +623,7 @@ public: lp_status find_feasible_solution(); void move_non_basic_columns_to_bounds(); bool move_non_basic_column_to_bounds(unsigned j); + bool move_lpvar_to_value(lpvar j, mpq const& value); inline bool r_basis_has_inf_int() const { for (unsigned j : r_basis()) { if (column_is_int(j) && !column_value_is_int(j)) diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 1b57a7200..90bc0ddd6 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -58,8 +58,9 @@ namespace opt { } void context::scoped_state::pop() { - m_hard.resize(m_hard_lim.back()); - m_asms.resize(m_asms_lim.back()); + m_hard.shrink(m_hard_lim.back()); + m_asms.shrink(m_asms_lim.back()); + m_values.shrink(m_values_lim.back()); unsigned k = m_objectives_term_trail_lim.back(); while (m_objectives_term_trail.size() > k) { unsigned idx = m_objectives_term_trail.back(); @@ -79,6 +80,7 @@ namespace opt { m_objectives_lim.pop_back(); m_hard_lim.pop_back(); m_asms_lim.pop_back(); + m_values_lim.pop_back(); } void context::scoped_state::add(expr* hard) { @@ -306,13 +308,11 @@ namespace opt { if (contains_quantifiers()) { warning_msg("optimization with quantified constraints is not supported"); } -#if 0 - if (is_qsat_opt()) { - return run_qsat_opt(); - } -#endif solver& s = get_solver(); s.assert_expr(m_hard_constraints); + for (auto const& [var, value] : m_scoped_state.m_values) { + s.user_propagate_initialize_value(var, value); + } opt_params optp(m_params); symbol pri = optp.priority(); @@ -697,6 +697,11 @@ namespace opt { } } + void context::initialize_value(expr* var, expr* value) { + m_scoped_state.m_values.push_back({expr_ref(var, m), expr_ref(value, m)}); + } + + /** * Set the solver to the SAT core. * It requres: diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index 4e791531e..845fd3968 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -140,12 +140,14 @@ namespace opt { unsigned_vector m_objectives_lim; unsigned_vector m_objectives_term_trail; unsigned_vector m_objectives_term_trail_lim; + unsigned_vector m_values_lim; map_id m_indices; public: expr_ref_vector m_hard; expr_ref_vector m_asms; vector m_objectives; + vector> m_values; scoped_state(ast_manager& m): m(m), @@ -275,6 +277,8 @@ namespace opt { void add_offset(unsigned id, rational const& o) override; + void initialize_value(expr* var, expr* value); + void register_on_model(on_model_t& ctx, std::function& on_model) { m_on_model_ctx = ctx; m_on_model_eh = on_model; diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index 2682fca09..66835df48 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -116,6 +116,7 @@ namespace opt { phase* get_phase() override { return m_context.get_phase(); } void set_phase(phase* p) override { m_context.set_phase(p); } void move_to_front(expr* e) override { m_context.move_to_front(e); } + void user_propagate_initialize_value(expr* var, expr* value) override { m_context.user_propagate_initialize_value(var, value); } void set_logic(symbol const& logic); diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 4574d3da3..0129a026e 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -702,6 +702,10 @@ public: ensure_euf()->user_propagate_register_decide(r); } + void user_propagate_initialize_value(expr* var, expr* value) override { + ensure_euf()->user_propagate_initialize_value(var, value); + } + private: diff --git a/src/sat/sat_solver/sat_smt_solver.cpp b/src/sat/sat_solver/sat_smt_solver.cpp index 19b10eb3e..1c141a801 100644 --- a/src/sat/sat_solver/sat_smt_solver.cpp +++ b/src/sat/sat_solver/sat_smt_solver.cpp @@ -577,6 +577,11 @@ public: ensure_euf()->user_propagate_register_decide(r); } + void user_propagate_initialize_value(expr* var, expr* value) override { + ensure_euf()->user_propagate_initialize_value(var, value); + } + + private: void add_assumption(expr* a) { diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index b866990af..0a58d7d1d 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -1256,6 +1256,31 @@ namespace euf { add_solver(m_user_propagator); } + void solver::user_propagate_initialize_value(expr* var, expr* value) { + if (m.is_bool(var)) { + auto lit = expr2literal(var); + if (lit == sat::null_literal) { + IF_VERBOSE(5, verbose_stream() << "no literal associated with " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + return; + } + if (m.is_true(value)) + s().set_phase(lit); + else if (m.is_false(value)) + s().set_phase(~lit); + else + IF_VERBOSE(5, verbose_stream() << "malformed value " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + return; + } + auto* th = m_id2solver.get(var->get_sort()->get_family_id(), nullptr); + if (!th) { + IF_VERBOSE(5, verbose_stream() << "no default initialization associated with " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + return; + } + // th->initialize_value(var, value); + IF_VERBOSE(5, verbose_stream() << "no default initialization associated with " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + } + + bool solver::watches_fixed(enode* n) const { return m_user_propagator && m_user_propagator->has_fixed() && n->get_th_var(m_user_propagator->get_id()) != null_theory_var; } diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 1b13d5137..8c436b942 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -564,6 +564,8 @@ namespace euf { m_user_propagator->add_expr(e); } + void user_propagate_initialize_value(expr* var, expr* value); + // solver factory ::solver* mk_solver() { return m_mk_solver(); } void set_mk_solver(std::function<::solver*(void)>& mk) { m_mk_solver = mk; } diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index c11084e56..163da9b16 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -2914,6 +2914,44 @@ namespace smt { register_plugin(m_user_propagator); } + void context::user_propagate_initialize_value(expr* var, expr* value) { + m_values.push_back({expr_ref(var, m), expr_ref(value, m)}); + push_trail(push_back_vector(m_values)); + } + + void context::initialize_value(expr* var, expr* value) { + IF_VERBOSE(10, verbose_stream() << "context initialize " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + sort* s = var->get_sort(); + ensure_internalized(var); + + if (m.is_bool(s)) { + auto v = get_bool_var_of_id_option(var->get_id()); + if (v == null_bool_var) { + IF_VERBOSE(5, verbose_stream() << "Boolean variable has no literal " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + return; + } + m_bdata[v].m_phase_available = true; + if (m.is_true(value)) + m_bdata[v].m_phase = true; + else if (m.is_false(value)) + m_bdata[v].m_phase = false; + else + IF_VERBOSE(5, verbose_stream() << "Boolean value is not constant " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + return; + } + + if (!e_internalized(var)) + return; + enode* n = get_enode(var); + theory* th = m_theories.get_plugin(s->get_family_id()); + if (!th) { + IF_VERBOSE(5, verbose_stream() << "No theory is attached to variable " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + return; + } + th->initialize_value(var, value); + + } + bool context::watches_fixed(enode* n) const { return m_user_propagator && m_user_propagator->has_fixed() && n->get_th_var(m_user_propagator->get_family_id()) != null_theory_var; } @@ -3756,6 +3794,9 @@ namespace smt { TRACE("search", display(tout); display_enodes_lbls(tout);); TRACE("search_detail", m_asserted_formulas.display(tout);); init_search(); + for (auto const& [var, value] : m_values) + initialize_value(var, value); + flet l(m_searching, true); TRACE("after_init_search", display(tout);); IF_VERBOSE(2, verbose_stream() << "(smt.searching)\n";); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index da7cde7e7..afbfd0e85 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -123,6 +123,7 @@ namespace smt { unsigned m_par_index = 0; bool m_internalizing_assertions = false; + // ----------------------------------- // // Equality & Uninterpreted functions @@ -246,6 +247,16 @@ namespace smt { vector m_th_case_split_sets; u_map< vector > m_literal2casesplitsets; // returns the case split literal sets that a literal participates in + + // ---------------------------------- + // + // Value initialization + // + // ---------------------------------- + vector> m_values; + void initialize_value(expr* var, expr* value); + + // ----------------------------------- // // Accessors @@ -1777,6 +1788,8 @@ namespace smt { m_user_propagator->register_decide(r); } + void user_propagate_initialize_value(expr* var, expr* value); + bool watches_fixed(enode* n) const; bool has_split_candidate(bool_var& var, bool& is_pos); diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 74f0bded6..2d6c29532 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -305,5 +305,9 @@ namespace smt { void kernel::user_propagate_register_decide(user_propagator::decide_eh_t& r) { m_imp->m_kernel.user_propagate_register_decide(r); } + + void kernel::user_propagate_initialize_value(expr* var, expr* value) { + m_imp->m_kernel.user_propagate_initialize_value(var, value); + } }; diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index eec74f8b1..539a32750 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -322,6 +322,8 @@ namespace smt { void user_propagate_register_decide(user_propagator::decide_eh_t& r); + void user_propagate_initialize_value(expr* var, expr* value); + /** \brief Return a reference to smt::context. This breaks abstractions. diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index f91a31111..7b9d416f3 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -252,6 +252,10 @@ namespace { m_context.user_propagate_register_decide(c); } + void user_propagate_initialize_value(expr* var, expr* value) override { + m_context.user_propagate_initialize_value(var, value); + } + struct scoped_minimize_core { smt_solver& s; expr_ref_vector m_assumptions; diff --git a/src/smt/smt_theory.h b/src/smt/smt_theory.h index d0e73cc92..416d626f7 100644 --- a/src/smt/smt_theory.h +++ b/src/smt/smt_theory.h @@ -549,6 +549,10 @@ namespace smt { return get_manager().mk_eq(lhs, rhs); } + virtual void initialize_value(expr* var, expr* value) { + IF_VERBOSE(5, verbose_stream() << "no default initialization associated with " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + } + literal mk_eq(expr * a, expr * b, bool gate_ctx); literal mk_preferred_eq(expr* a, expr* b); diff --git a/src/smt/tactic/smt_tactic_core.cpp b/src/smt/tactic/smt_tactic_core.cpp index 2be2ace58..89b2b7135 100644 --- a/src/smt/tactic/smt_tactic_core.cpp +++ b/src/smt/tactic/smt_tactic_core.cpp @@ -41,6 +41,7 @@ class smt_tactic : public tactic { smt_params m_params; params_ref m_params_ref; expr_ref_vector m_vars; + vector> m_values; statistics m_stats; smt::kernel* m_ctx = nullptr; symbol m_logic; @@ -344,6 +345,8 @@ public: for (expr* v : m_vars) m_ctx->user_propagate_register_expr(v); + for (auto& [var, value] : m_values) + m_ctx->user_propagate_initialize_value(var, value); } void user_propagate_clear() override { @@ -403,6 +406,10 @@ public: void user_propagate_register_decide(user_propagator::decide_eh_t& decide_eh) override { m_decide_eh = decide_eh; } + + void user_propagate_initialize_value(expr* var, expr* value) override { + m_values.push_back({expr_ref(var, m), expr_ref(value, m)}); + } }; static tactic * mk_seq_smt_tactic(ast_manager& m, params_ref const & p) { diff --git a/src/smt/theory_arith.h b/src/smt/theory_arith.h index 86c05aec6..e68f0f53f 100644 --- a/src/smt/theory_arith.h +++ b/src/smt/theory_arith.h @@ -662,6 +662,7 @@ namespace smt { void restart_eh() override; void init_search_eh() override; + void initialize_value(expr* var, expr* value) override; /** \brief True if the assignment may be changed during final check. assume_eqs, check_int_feasibility, diff --git a/src/smt/theory_arith_aux.h b/src/smt/theory_arith_aux.h index 8141377c4..acb036d2a 100644 --- a/src/smt/theory_arith_aux.h +++ b/src/smt/theory_arith_aux.h @@ -2249,6 +2249,21 @@ namespace smt { return false; } + template + void theory_arith::initialize_value(expr* var, expr* value) { + theory_var v = expr2var(var); + rational r; + if (!m_util.is_numeral(value, r)) { + IF_VERBOSE(5, verbose_stream() << "numeric constant expected in initialization " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + return; + } + if (v == null_theory_var) + return; + if (is_base(v)) + return; + update_value(v, inf_numeral(r)); + } + #if 0 /** diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index f0a96ddd1..8f334276d 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -154,6 +154,7 @@ class theory_lra::imp { svector m_asserted_atoms; ptr_vector m_not_handled; ptr_vector m_underspecified; + vector> m_values; vector > m_use_list; // bounds where variables are used. // attributes for incremental version: @@ -991,6 +992,16 @@ public: return lp().compare_values(vi, k, b->get_value()) ? l_true : l_false; } + void initialize_value(expr* var, expr* value) { + rational r; + if (!a.is_numeral(value, r)) { + IF_VERBOSE(5, verbose_stream() << "numeric constant expected in initialization " << mk_pp(var, m) << " := " << mk_pp(value, m) << "\n"); + return; + } + ctx().push_trail(push_back_vector(m_values)); + m_values.push_back({get_lpvar(var), r}); + } + void new_eq_eh(theory_var v1, theory_var v2) { TRACE("arith", tout << "eq " << v1 << " == " << v2 << "\n";); if (!is_int(v1) && !is_real(v1)) @@ -1409,6 +1420,9 @@ public: void init_search_eh() { m_arith_eq_adapter.init_search_eh(); m_num_conflicts = 0; + for (auto const& [v, r] : m_values) + lp().move_lpvar_to_value(v, r); + display(verbose_stream() << "init search\n"); } bool can_get_value(theory_var v) const { @@ -3878,6 +3892,9 @@ void theory_lra::assign_eh(bool_var v, bool is_true) { lbool theory_lra::get_phase(bool_var v) { return m_imp->get_phase(v); } +void theory_lra::initialize_value(expr* var, expr* value) { + m_imp->initialize_value(var, value); +} void theory_lra::new_eq_eh(theory_var v1, theory_var v2) { m_imp->new_eq_eh(v1, v2); } @@ -3912,7 +3929,7 @@ final_check_status theory_lra::final_check_eh() { } bool theory_lra::is_shared(theory_var v) const { return m_imp->is_shared(v); -} +} bool theory_lra::can_propagate() { return m_imp->can_propagate(); } diff --git a/src/smt/theory_lra.h b/src/smt/theory_lra.h index 4c2351c85..96988f957 100644 --- a/src/smt/theory_lra.h +++ b/src/smt/theory_lra.h @@ -80,6 +80,8 @@ namespace smt { void apply_sort_cnstr(enode * n, sort * s) override; void init_model(model_generator & m) override; + + void initialize_value(expr* var, expr* value) override; model_value_proc * mk_value(enode * n, model_generator & mg) override; void validate_model(proto_model& mdl) override; diff --git a/src/solver/combined_solver.cpp b/src/solver/combined_solver.cpp index 53aa56753..051dfd4eb 100644 --- a/src/solver/combined_solver.cpp +++ b/src/solver/combined_solver.cpp @@ -394,6 +394,11 @@ public: void user_propagate_clear() override { m_solver2->user_propagate_clear(); } + + void user_propagate_initialize_value(expr* var, expr* value) override { + m_solver2->user_propagate_initialize_value(var, value); + } + }; diff --git a/src/solver/simplifier_solver.cpp b/src/solver/simplifier_solver.cpp index 8995049f0..ae1d6b8a2 100644 --- a/src/solver/simplifier_solver.cpp +++ b/src/solver/simplifier_solver.cpp @@ -390,6 +390,8 @@ public: void user_propagate_register_expr(expr* e) override { m_preprocess_state.freeze(e); s->user_propagate_register_expr(e); } void user_propagate_register_created(user_propagator::created_eh_t& r) override { s->user_propagate_register_created(r); } void user_propagate_register_decide(user_propagator::decide_eh_t& r) override { s->user_propagate_register_decide(r); } + void user_propagate_initialize_value(expr* var, expr* value) override { m_preprocess_state.freeze(var); s->user_propagate_initialize_value(var, value); } + }; diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index 861a83185..cdabecb27 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -119,6 +119,10 @@ public: void user_propagate_register_expr(expr* e) override { m_tactic->user_propagate_register_expr(e); } + + void user_propagate_initialize_value(expr* var, expr* value) override { + m_tactic->user_propagate_initialize_value(var, value); + } void user_propagate_register_created(user_propagator::created_eh_t& created_eh) override { m_tactic->user_propagate_register_created(created_eh); diff --git a/src/tactic/tactical.cpp b/src/tactic/tactical.cpp index 0b8189e8d..626380913 100644 --- a/src/tactic/tactical.cpp +++ b/src/tactic/tactical.cpp @@ -213,6 +213,10 @@ public: m_t2->user_propagate_register_decide(decide_eh); } + void user_propagate_initialize_value(expr* var, expr* value) override { + m_t2->user_propagate_initialize_value(var, value); + } + }; tactic * and_then(tactic * t1, tactic * t2) { @@ -884,6 +888,7 @@ public: void set_progress_callback(progress_callback * callback) override { m_t->set_progress_callback(callback); } void user_propagate_register_expr(expr* e) override { m_t->user_propagate_register_expr(e); } void user_propagate_clear() override { m_t->user_propagate_clear(); } + void user_propagate_initialize_value(expr* var, expr* value) override { m_t->user_propagate_initialize_value(var, value); } protected: diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index 58904a12d..968196f63 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -99,6 +99,10 @@ namespace user_propagator { throw default_exception("clause logging is only supported on the SMT solver"); } + virtual void user_propagate_initialize_value(expr* var, expr* value) { + throw default_exception("value initialization is only supported on the SMT solver"); + } + };