From fb1509d011d8509c277ce0b51ccd98d29586ff84 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 2 Feb 2021 14:29:06 -0800 Subject: [PATCH] expose internal API for set_phase --- src/muz/spacer/spacer_iuc_solver.h | 1 + src/opt/opt_solver.h | 1 + src/sat/sat_solver.h | 3 ++- src/sat/sat_solver/inc_sat_solver.cpp | 7 +++++++ src/smt/smt_solver.cpp | 3 +++ src/solver/combined_solver.cpp | 2 ++ src/solver/solver.h | 2 +- src/solver/solver_pool.cpp | 1 + src/solver/tactic2solver.cpp | 1 + src/tactic/fd_solver/bounded_int2bv_solver.cpp | 1 + src/tactic/fd_solver/enum2bv_solver.cpp | 1 + src/tactic/fd_solver/pb2bv_solver.cpp | 1 + src/tactic/fd_solver/smtfd_solver.cpp | 2 ++ 13 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/muz/spacer/spacer_iuc_solver.h b/src/muz/spacer/spacer_iuc_solver.h index a153e6af6..fc55d31ac 100644 --- a/src/muz/spacer/spacer_iuc_solver.h +++ b/src/muz/spacer/spacer_iuc_solver.h @@ -120,6 +120,7 @@ public: void set_produce_models(bool f) override { m_solver.set_produce_models(f); } void assert_expr_core(expr *t) override { m_solver.assert_expr(t); } void assert_expr_core2(expr *t, expr *a) override { NOT_IMPLEMENTED_YET(); } + void set_phase(expr* e) override { m_solver.set_phase(e); } expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } void get_levels(ptr_vector const& vars, unsigned_vector& depth) override { m_solver.get_levels(vars, depth); } expr_ref_vector get_trail() override { return m_solver.get_trail(); } diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index bdef80765..396de5aea 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -110,6 +110,7 @@ namespace opt { void get_levels(ptr_vector const& vars, unsigned_vector& depth) override; expr_ref_vector get_trail() override { return m_context.get_trail(); } expr_ref_vector cube(expr_ref_vector&, unsigned) override { return expr_ref_vector(m); } + void set_phase(expr* e) override { NOT_IMPLEMENTED_YET(); } void set_logic(symbol const& logic); diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 6725057c3..08cd63aaf 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -352,7 +352,8 @@ namespace sat { bool was_eliminated(bool_var v) const { return m_eliminated[v]; } void set_eliminated(bool_var v, bool f) override; bool was_eliminated(literal l) const { return was_eliminated(l.var()); } - void set_phase(literal l) override { m_phase[l.var()] = !l.sign(); } + void set_phase(literal l) override { m_best_phase[l.var()] = m_phase[l.var()] = !l.sign(); } + void set_phase(bool_var b, bool sign) { set_phase(literal(b, sign)); } unsigned scope_lvl() const { return m_scope_lvl; } unsigned search_lvl() const { return m_search_lvl; } bool at_search_lvl() const { return m_scope_lvl == m_search_lvl; } diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 61b74169d..1dbdc3582 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -294,6 +294,13 @@ public: } } + void set_phase(expr* e) override { + bool is_not = m.is_not(e, e); + sat::bool_var b = m_map.to_bool_var(e); + if (b != sat::null_bool_var) + m_solver.set_phase(b, is_not); + } + unsigned get_scope_level() const override { return m_num_scopes; } diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index c13b9ebda..58a3768c4 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -158,6 +158,9 @@ namespace { void assert_expr_core(expr * t) override { m_context.assert_expr(t); } + void set_phase(expr* e) override { + NOT_IMPLEMENTED_YET(); + } void assert_expr_core2(expr * t, expr * a) override { if (m_name2assertion.contains(a)) { diff --git a/src/solver/combined_solver.cpp b/src/solver/combined_solver.cpp index 4252d45c2..9369d5882 100644 --- a/src/solver/combined_solver.cpp +++ b/src/solver/combined_solver.cpp @@ -135,6 +135,8 @@ public: return r; } + void set_phase(expr* e) override { m_solver1->set_phase(e); m_solver2->set_phase(e); } + void updt_params(params_ref const & p) override { solver::updt_params(p); m_solver1->updt_params(p); diff --git a/src/solver/solver.h b/src/solver/solver.h index 9db168281..078d402a2 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -111,7 +111,7 @@ public: for (expr* e : ts) assert_expr(e); } -// void set_phase(expr* e) = 0; + virtual void set_phase(expr* e) = 0; void assert_expr(ptr_vector const& ts) { for (expr* e : ts) assert_expr(e); diff --git a/src/solver/solver_pool.cpp b/src/solver/solver_pool.cpp index 16764ba56..997c3343d 100644 --- a/src/solver/solver_pool.cpp +++ b/src/solver/solver_pool.cpp @@ -68,6 +68,7 @@ public: } solver* base_solver() { return m_base.get(); } + void set_phase(expr* e) override { m_base->set_phase(e); } solver* translate(ast_manager& m, params_ref const& p) override { UNREACHABLE(); return nullptr; } void updt_params(params_ref const& p) override { diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index f68131bab..0985478d3 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -79,6 +79,7 @@ public: unsigned get_num_assertions() const override; expr * get_assertion(unsigned idx) const override; + void set_phase(expr* e) override { } expr_ref_vector cube(expr_ref_vector& vars, unsigned ) override { diff --git a/src/tactic/fd_solver/bounded_int2bv_solver.cpp b/src/tactic/fd_solver/bounded_int2bv_solver.cpp index 6c749f694..2df5d0003 100644 --- a/src/tactic/fd_solver/bounded_int2bv_solver.cpp +++ b/src/tactic/fd_solver/bounded_int2bv_solver.cpp @@ -151,6 +151,7 @@ public: void set_progress_callback(progress_callback * callback) override { m_solver->set_progress_callback(callback); } void collect_statistics(statistics & st) const override { m_solver->collect_statistics(st); } void get_unsat_core(expr_ref_vector & r) override { m_solver->get_unsat_core(r); } + void set_phase(expr* e) override { m_solver->set_phase(e); } void get_model_core(model_ref & mdl) override { m_solver->get_model(mdl); if (mdl) { diff --git a/src/tactic/fd_solver/enum2bv_solver.cpp b/src/tactic/fd_solver/enum2bv_solver.cpp index 3efe3e511..b922a4bf9 100644 --- a/src/tactic/fd_solver/enum2bv_solver.cpp +++ b/src/tactic/fd_solver/enum2bv_solver.cpp @@ -89,6 +89,7 @@ public: void set_progress_callback(progress_callback * callback) override { m_solver->set_progress_callback(callback); } void collect_statistics(statistics & st) const override { m_solver->collect_statistics(st); } void get_unsat_core(expr_ref_vector & r) override { m_solver->get_unsat_core(r); } + void set_phase(expr* e) override { m_solver->set_phase(e); } void get_model_core(model_ref & mdl) override { m_solver->get_model(mdl); if (mdl) { diff --git a/src/tactic/fd_solver/pb2bv_solver.cpp b/src/tactic/fd_solver/pb2bv_solver.cpp index 9e17dd711..b574e4789 100644 --- a/src/tactic/fd_solver/pb2bv_solver.cpp +++ b/src/tactic/fd_solver/pb2bv_solver.cpp @@ -96,6 +96,7 @@ public: if (mc) (*mc)(mdl); } } + void set_phase(expr* e) override { m_solver->set_phase(e); } void get_levels(ptr_vector const& vars, unsigned_vector& depth) override { m_solver->get_levels(vars, depth); diff --git a/src/tactic/fd_solver/smtfd_solver.cpp b/src/tactic/fd_solver/smtfd_solver.cpp index 04a4a4da7..7cd27cd62 100644 --- a/src/tactic/fd_solver/smtfd_solver.cpp +++ b/src/tactic/fd_solver/smtfd_solver.cpp @@ -2031,6 +2031,8 @@ namespace smtfd { return r; } + void set_phase(expr* e) override {} + void updt_params(params_ref const & p) override { ::solver::updt_params(p); if (m_fd_sat_solver) {