From 89bf2d43687f09afcec4ccf7e0ae5b226bfdd7cc Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 15 Feb 2019 12:05:24 -0800 Subject: [PATCH] add API for setting variable activity Signed-off-by: Nikolaj Bjorner --- src/api/api_solver.cpp | 9 +++++++ src/api/c++/z3++.h | 1 + src/api/python/z3/z3.py | 6 +++++ src/api/z3_api.h | 6 +++++ src/muz/spacer/spacer_iuc_solver.h | 1 + src/opt/opt_solver.h | 1 + src/sat/sat_solver.cpp | 6 +++++ src/sat/sat_solver.h | 1 + src/sat/sat_solver/inc_sat_solver.cpp | 9 +++++++ src/smt/smt_case_split_queue.cpp | 24 +++++++++++++++++++ src/smt/smt_case_split_queue.h | 1 + src/smt/smt_context.h | 11 ++++++++- src/smt/smt_kernel.cpp | 11 +++++++++ src/smt/smt_kernel.h | 5 ++++ src/smt/smt_solver.cpp | 4 ++++ src/solver/combined_solver.cpp | 5 ++++ src/solver/solver.h | 2 ++ src/solver/solver_pool.cpp | 4 ++++ src/solver/tactic2solver.cpp | 5 ++++ .../fd_solver/bounded_int2bv_solver.cpp | 4 ++++ src/tactic/fd_solver/enum2bv_solver.cpp | 4 ++++ src/tactic/fd_solver/pb2bv_solver.cpp | 6 +++++ 22 files changed, 125 insertions(+), 1 deletion(-) diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 1c0664bb2..cafbfb9ff 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -431,6 +431,15 @@ extern "C" { Z3_CATCH; } + void Z3_API Z3_solver_set_activity(Z3_context c, Z3_solver s, Z3_ast a, double activity) { + Z3_TRY; + LOG_Z3_solver_set_activity(c, s, a, activity); + RESET_ERROR_CODE(); + init_solver(c, s); + to_solver_ref(s)->set_activity(to_expr(a), activity); + Z3_CATCH; + } + Z3_ast_vector Z3_API Z3_solver_get_trail(Z3_context c, Z3_solver s) { Z3_TRY; LOG_Z3_solver_get_trail(c, s); diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index e711457a3..6113fdbe9 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -2245,6 +2245,7 @@ namespace z3 { check_error(); return result; } + void set_activity(expr const& lit, double act) { Z3_solver_set_activity(ctx(), m_solver, lit, act); } expr proof() const { Z3_ast r = Z3_solver_get_proof(ctx(), m_solver); check_error(); return expr(ctx(), r); } friend std::ostream & operator<<(std::ostream & out, solver const & s); diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 7185f520b..f8651ab90 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -6744,6 +6744,12 @@ class Solver(Z3PPObject): """ return AstVector(Z3_solver_get_trail(self.ctx.ref(), self.solver), self.ctx) + def set_activity(self, lit, act): + """Set activity of literal on solver object. + This influences the case split order of the variable. + """ + Z3_solver_set_activity(self.ctx.ref(), self.solver, lit.ast, act) + def statistics(self): """Return statistics for the last `check()`. diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 6f65e4a37..e1710b499 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -6236,6 +6236,12 @@ extern "C" { */ void Z3_API Z3_solver_get_levels(Z3_context c, Z3_solver s, Z3_ast_vector literals, unsigned sz, unsigned levels[]); + /** + \brief set activity score associated with literal. + + def_API('Z3_solver_set_activity', VOID, (_in(CONTEXT), _in(SOLVER), _in(AST), _in(DOUBLE))) + */ + void Z3_API Z3_solver_set_activity(Z3_context c, Z3_solver s, Z3_ast l, double activity); /** \brief Check whether the assertions in a given solver are consistent or not. diff --git a/src/muz/spacer/spacer_iuc_solver.h b/src/muz/spacer/spacer_iuc_solver.h index c3561a6d4..9b50b4c4e 100644 --- a/src/muz/spacer/spacer_iuc_solver.h +++ b/src/muz/spacer/spacer_iuc_solver.h @@ -124,6 +124,7 @@ public: expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } void get_levels(ptr_vector const& vars, unsigned_vector& depth) override { m_solver.get_levels(vars, depth); } expr_ref_vector get_trail() override { return m_solver.get_trail(); } + void set_activity(expr* lit, double act) override { m_solver.set_activity(lit, act); } void push() override; void pop(unsigned n) override; diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index 9eda063e9..be71376ac 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -110,6 +110,7 @@ namespace opt { lbool preferred_sat(expr_ref_vector const& asms, vector& cores) override; void get_levels(ptr_vector const& vars, unsigned_vector& depth) override; expr_ref_vector get_trail() override { return m_context.get_trail(); } + void set_activity(expr* lit, double act) override { m_context.set_activity(lit, act); } expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } void set_logic(symbol const& logic); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 9af8774e0..d917d1ae6 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -1619,6 +1619,12 @@ namespace sat { return tracking_assumptions() && m_assumption_set.contains(l); } + void solver::set_activity(bool_var v, unsigned act) { + unsigned old_act = m_activity[v]; + m_activity[v] = act; + m_case_split_queue.activity_changed_eh(v, act > old_act); + } + bool solver::is_assumption(bool_var v) const { return is_assumption(literal(v, false)) || is_assumption(literal(v, true)); } diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index a977bbbaa..89b161c9f 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -387,6 +387,7 @@ namespace sat { char const* get_reason_unknown() const { return m_reason_unknown.c_str(); } bool check_clauses(model const& m) const; bool is_assumption(bool_var v) const; + void set_activity(bool_var v, unsigned act); lbool cube(bool_var_vector& vars, literal_vector& lits, unsigned backtrack_level); diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index f0712da38..8a0747d15 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -341,6 +341,15 @@ public: return result; } + void set_activity(expr* var, double activity) override { + m.is_not(var, var); + sat::bool_var v = m_map.to_bool_var(var); + if (v == sat::null_bool_var) { + throw default_exception("literal does not correspond to a Boolean variable"); + } + m_solver.set_activity(v, activity); + } + proof * get_proof() override { UNREACHABLE(); return nullptr; diff --git a/src/smt/smt_case_split_queue.cpp b/src/smt/smt_case_split_queue.cpp index 6290ed14d..a4cb68506 100644 --- a/src/smt/smt_case_split_queue.cpp +++ b/src/smt/smt_case_split_queue.cpp @@ -80,6 +80,11 @@ namespace smt { m_queue.decreased(v); } + void activity_decreased_eh(bool_var v) override { + if (m_queue.contains(v)) + m_queue.increased(v); + } + void mk_var_eh(bool_var v) override { m_queue.reserve(v+1); SASSERT(!m_queue.contains(v)); @@ -167,6 +172,14 @@ namespace smt { m_delayed_queue.decreased(v); } + void activity_decreased_eh(bool_var v) override { + act_case_split_queue::activity_decreased_eh(v); + if (m_queue.contains(v)) + m_queue.increased(v); + if (m_delayed_queue.contains(v)) + m_delayed_queue.increased(v); + } + void mk_var_eh(bool_var v) override { m_queue.reserve(v+1); m_delayed_queue.reserve(v+1); @@ -324,6 +337,8 @@ namespace smt { void activity_increased_eh(bool_var v) override {} + void activity_decreased_eh(bool_var v) override {} + void mk_var_eh(bool_var v) override {} void del_var_eh(bool_var v) override {} @@ -509,6 +524,8 @@ namespace smt { void activity_increased_eh(bool_var v) override {} + void activity_decreased_eh(bool_var v) override {} + void mk_var_eh(bool_var v) override { if (m_context.is_searching()) { SASSERT(v >= m_bs_num_bool_vars); @@ -753,6 +770,8 @@ namespace smt { void activity_increased_eh(bool_var v) override {} + void activity_decreased_eh(bool_var v) override {} + void mk_var_eh(bool_var v) override {} void del_var_eh(bool_var v) override {} @@ -1133,6 +1152,11 @@ namespace smt { m_queue.decreased(v); } + void activity_decreased_eh(bool_var v) override { + if (m_queue.contains(v)) + m_queue.increased(v); + } + void mk_var_eh(bool_var v) override { m_queue.reserve(v+1); m_queue.insert(v); diff --git a/src/smt/smt_case_split_queue.h b/src/smt/smt_case_split_queue.h index cfa33bfe2..3bad083c8 100644 --- a/src/smt/smt_case_split_queue.h +++ b/src/smt/smt_case_split_queue.h @@ -32,6 +32,7 @@ namespace smt { class case_split_queue { public: virtual void activity_increased_eh(bool_var v) = 0; + virtual void activity_decreased_eh(bool_var v) = 0; virtual void mk_var_eh(bool_var v) = 0; virtual void del_var_eh(bool_var v) = 0; virtual void assign_lit_eh(literal l) {} diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index d935bf53e..dc9e43dca 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -412,10 +412,19 @@ namespace smt { return m_activity[v]; } - void set_activity(bool_var v, double & act) { + void set_activity(bool_var v, double const & act) { m_activity[v] = act; } + void activity_changed(bool_var v, bool increased) { + if (increased) { + m_case_split_queue->activity_increased_eh(v); + } + else { + m_case_split_queue->activity_decreased_eh(v); + } + } + bool is_assumption(bool_var v) const { return get_bdata(v).m_assumption; } diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index f2f17321c..571032039 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -154,6 +154,13 @@ namespace smt { expr_ref_vector get_trail() { return m_kernel.get_trail(); } + + void set_activity(expr* lit, double act) { + auto v = m_kernel.get_bool_var(lit); + double old_act = m_kernel.get_activity(v); + m_kernel.set_activity(v, act); + m_kernel.activity_changed(v, act > old_act); + } failure last_failure() const { return m_kernel.get_last_search_failure(); @@ -412,5 +419,9 @@ namespace smt { return m_imp->get_trail(); } + void kernel::set_activity(expr* lit, double activity) { + m_imp->set_activity(lit, activity); + } + }; diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index 6eeb8d728..a46195e02 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -229,6 +229,11 @@ namespace smt { */ expr_ref_vector get_trail(); + /** + \brief set activity of literal + */ + void set_activity(expr* lit, double activity); + /** \brief (For debubbing purposes) Prints the state of the kernel */ diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index e36858b06..0052dd316 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -203,6 +203,10 @@ namespace smt { return m_context.get_trail(); } + void set_activity(expr* lit, double activity) override { + m_context.set_activity(lit, activity); + } + struct scoped_minimize_core { smt_solver& s; expr_ref_vector m_assumptions; diff --git a/src/solver/combined_solver.cpp b/src/solver/combined_solver.cpp index b602fe6e1..b939efc6b 100644 --- a/src/solver/combined_solver.cpp +++ b/src/solver/combined_solver.cpp @@ -328,6 +328,11 @@ public: return m_solver2->get_trail(); } + void set_activity(expr* lit, double activity) override { + m_solver1->set_activity(lit, activity); + m_solver2->set_activity(lit, activity); + } + proof * get_proof() override { if (m_use_solver1_results) return m_solver1->get_proof(); diff --git a/src/solver/solver.h b/src/solver/solver.h index 0c509b8c7..be3751bcd 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -251,6 +251,8 @@ public: virtual void get_levels(ptr_vector const& vars, unsigned_vector& depth) = 0; + virtual void set_activity(expr* lit, double activity) = 0; + class scoped_push { solver& s; bool m_nopop; diff --git a/src/solver/solver_pool.cpp b/src/solver/solver_pool.cpp index 7f3882447..4d6724aa7 100644 --- a/src/solver/solver_pool.cpp +++ b/src/solver/solver_pool.cpp @@ -127,6 +127,10 @@ public: return m_base->get_trail(); } + void set_activity(expr* var, double activity) override { + m_base->set_activity(var, activity); + } + lbool check_sat_core2(unsigned num_assumptions, expr * const * assumptions) override { SASSERT(!m_pushed || get_scope_level() > 0); m_proof.reset(); diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index 8721e106b..c62e920e7 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -93,6 +93,11 @@ public: throw default_exception("cannot retrieve trail from solvers created using tactcis"); } + void set_activity(expr* var, double activity) override { + throw default_exception("cannot set activity for solvers created using tactcis"); + } + + }; ast_manager& tactic2solver::get_manager() const { return m_assertions.get_manager(); } diff --git a/src/tactic/fd_solver/bounded_int2bv_solver.cpp b/src/tactic/fd_solver/bounded_int2bv_solver.cpp index b19875737..0decaec82 100644 --- a/src/tactic/fd_solver/bounded_int2bv_solver.cpp +++ b/src/tactic/fd_solver/bounded_int2bv_solver.cpp @@ -161,6 +161,10 @@ public: expr_ref_vector get_trail() override { return m_solver->get_trail(); } + void set_activity(expr* var, double activity) override { + m_solver->set_activity(var, activity); + } + model_converter* external_model_converter() const { return concat(mc0(), local_model_converter()); } diff --git a/src/tactic/fd_solver/enum2bv_solver.cpp b/src/tactic/fd_solver/enum2bv_solver.cpp index d8119ddfd..b232d8ea3 100644 --- a/src/tactic/fd_solver/enum2bv_solver.cpp +++ b/src/tactic/fd_solver/enum2bv_solver.cpp @@ -186,6 +186,10 @@ public: return m_solver->get_trail(); } + void set_activity(expr* var, double activity) override { + m_solver->set_activity(var, activity); + } + unsigned get_num_assertions() const override { return m_solver->get_num_assertions(); } diff --git a/src/tactic/fd_solver/pb2bv_solver.cpp b/src/tactic/fd_solver/pb2bv_solver.cpp index 29aa573e3..c6866ecf4 100644 --- a/src/tactic/fd_solver/pb2bv_solver.cpp +++ b/src/tactic/fd_solver/pb2bv_solver.cpp @@ -96,6 +96,7 @@ public: if (mc) (*mc)(mdl); } } + void get_levels(ptr_vector const& vars, unsigned_vector& depth) override { m_solver->get_levels(vars, depth); } @@ -104,9 +105,14 @@ public: return m_solver->get_trail(); } + void set_activity(expr* var, double activity) override { + m_solver->set_activity(var, activity); + } + model_converter* external_model_converter() const{ return concat(mc0(), local_model_converter()); } + model_converter_ref get_model_converter() const override { model_converter_ref mc = external_model_converter(); mc = concat(mc.get(), m_solver->get_model_converter().get());