diff --git a/scripts/mk_project.py b/scripts/mk_project.py index 88a7e6414..6fb377a80 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -49,7 +49,7 @@ def init_project_def(): add_lib('core_tactics', ['tactic', 'macros', 'normal_forms', 'rewriter', 'pattern'], 'tactic/core') add_lib('arith_tactics', ['core_tactics', 'sat'], 'tactic/arith') - add_lib('sat_smt', ['sat', 'euf', 'tactic', 'smt_params', 'bit_blaster'], 'sat/smt') + add_lib('sat_smt', ['sat', 'euf', 'tactic', 'solver', 'smt_params', 'bit_blaster'], 'sat/smt') add_lib('sat_tactic', ['tactic', 'sat', 'solver', 'sat_smt'], 'sat/tactic') add_lib('nlsat_tactic', ['nlsat', 'sat_tactic', 'arith_tactics'], 'nlsat/tactic') add_lib('subpaving_tactic', ['core_tactics', 'subpaving'], 'math/subpaving/tactic') diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 5c711019f..049ed9f59 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -974,7 +974,7 @@ extern "C" { Z3_TRY; LOG_Z3_solver_propagate_consequence(c, s, num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, conseq); RESET_ERROR_CODE(); - reinterpret_cast(s)->propagate(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, to_expr(conseq)); + reinterpret_cast(s)->propagate_cb(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, to_expr(conseq)); Z3_CATCH; } diff --git a/src/math/lp/lar_solver.h b/src/math/lp/lar_solver.h index df0c9a1b6..e36483d8c 100644 --- a/src/math/lp/lar_solver.h +++ b/src/math/lp/lar_solver.h @@ -134,9 +134,6 @@ class lar_solver : public column_namer { void add_row_from_term_no_constraint(const lar_term * term, unsigned term_ext_index); void add_basic_var_to_core_fields(); bool compare_values(impq const& lhs, lconstraint_kind k, const mpq & rhs); - // columns - bool column_is_int(column_index const& j) const { return column_is_int((unsigned)j); } - const impq& get_value(column_index const& j) const { return get_column_value(j); } void update_column_type_and_bound_check_on_equal(unsigned j, lconstraint_kind kind, const mpq & right_side, constraint_index constr_index, unsigned&); @@ -626,6 +623,9 @@ public: inline bool column_value_is_int(unsigned j) const { return m_mpq_lar_core_solver.m_r_x[j].is_int(); } inline static_matrix & A_r() { return m_mpq_lar_core_solver.m_r_A; } inline const static_matrix & A_r() const { return m_mpq_lar_core_solver.m_r_A; } + // columns + bool column_is_int(column_index const& j) const { return column_is_int((unsigned)j); } + const impq& get_value(column_index const& j) const { return get_column_value(j); } const impq& get_column_value(unsigned j) const { return m_mpq_lar_core_solver.m_r_x[j]; } inline var_index external_to_local(unsigned j) const { diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index 71214055d..8b5ef3fe3 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -57,11 +57,15 @@ namespace sat { protected: bool m_drating { false }; int m_id { 0 }; + solver* m_solver { nullptr }; public: extension(int id): m_id(id) {} virtual ~extension() {} - virtual int get_id() const { return m_id; } - virtual void set_solver(solver* s) = 0; + int get_id() const { return m_id; } + void set_solver(solver* s) { m_solver = s; } + solver& s() { return *m_solver; } + solver const& s() const { return *m_solver; } + virtual void set_lookahead(lookahead* s) {}; class scoped_drating { extension& ext; @@ -70,13 +74,13 @@ namespace sat { ~scoped_drating() { ext.m_drating = false; } }; virtual void init_search() {} - virtual bool propagate(literal l, ext_constraint_idx idx) = 0; - virtual bool unit_propagate() = 0; - virtual bool is_external(bool_var v) = 0; + virtual bool propagate(sat::literal l, sat::ext_constraint_idx idx) { UNREACHABLE(); return false; } + virtual bool unit_propagate() = 0; + virtual bool is_external(bool_var v) { return false; } virtual double get_reward(literal l, ext_constraint_idx idx, literal_occs_fun& occs) const { return 0; } virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r, bool probing) = 0; virtual bool is_extended_binary(ext_justification_idx idx, literal_vector & r) { return false; } - virtual void asserted(literal l) = 0; + virtual void asserted(literal l) {}; virtual check_result check() = 0; virtual lbool resolve_conflict() { return l_undef; } // stores result in sat::solver::m_lemma virtual void push() = 0; diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index a7dda80a4..e507c9a5f 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -624,6 +624,39 @@ public: m_preprocess->reset(); } + euf::solver* ensure_euf() { + auto* ext = dynamic_cast(m_solver.get_extension()); + return ext; + } + + void user_propagate_init( + void* ctx, + solver::push_eh_t& push_eh, + solver::pop_eh_t& pop_eh, + solver::fresh_eh_t& fresh_eh) override { + ensure_euf()->user_propagate_init(ctx, push_eh, pop_eh, fresh_eh); + } + + void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) override { + ensure_euf()->user_propagate_register_fixed(fixed_eh); + } + + void user_propagate_register_final(solver::final_eh_t& final_eh) override { + ensure_euf()->user_propagate_register_final(final_eh); + } + + void user_propagate_register_eq(solver::eq_eh_t& eq_eh) override { + ensure_euf()->user_propagate_register_eq(eq_eh); + } + + void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) override { + ensure_euf()->user_propagate_register_diseq(diseq_eh); + } + + unsigned user_propagate_register(expr* e) override { + return ensure_euf()->user_propagate_register(e); + } + private: lbool internalize_goal(goal_ref& g) { diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 7f1f54059..5556333fc 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -25,6 +25,7 @@ z3_add_component(sat_smt euf_solver.cpp sat_dual_solver.cpp sat_th.cpp + user_solver.cpp xor_solver.cpp COMPONENT_DEPENDENCIES sat diff --git a/src/sat/smt/array_solver.h b/src/sat/smt/array_solver.h index d7bd88d9d..afff61219 100644 --- a/src/sat/smt/array_solver.h +++ b/src/sat/smt/array_solver.h @@ -60,14 +60,12 @@ namespace array { array_util a; stats m_stats; - sat::solver* m_solver{ nullptr }; scoped_ptr_vector m_var_data; ast2ast_trailmap m_sort2epsilon; ast2ast_trailmap m_sort2diag; obj_map m_sort2diff; array_union_find m_find; - sat::solver& s() { return *m_solver; } theory_var find(theory_var v) { return m_find.find(v); } // internalize @@ -187,7 +185,6 @@ namespace array { public: solver(euf::solver& ctx, theory_id id); ~solver() override {} - void set_solver(sat::solver* s) override { m_solver = s; } bool is_external(bool_var v) override { return false; } bool propagate(literal l, sat::ext_constraint_idx idx) override { UNREACHABLE(); return false; } void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override {} diff --git a/src/sat/smt/ba_solver.cpp b/src/sat/smt/ba_solver.cpp index d27726f5f..2155a7c39 100644 --- a/src/sat/smt/ba_solver.cpp +++ b/src/sat/smt/ba_solver.cpp @@ -1365,7 +1365,7 @@ namespace sat { ba_solver::ba_solver(ast_manager& m, sat::sat_internalizer& si, euf::theory_id id) : euf::th_solver(m, id), si(si), m_pb(m), - m_solver(nullptr), m_lookahead(nullptr), + m_lookahead(nullptr), m_constraint_id(0), m_ba(*this), m_sort(m_ba) { TRACE("ba", tout << this << "\n";); m_num_propagations_since_pop = 0; @@ -3579,14 +3579,14 @@ namespace sat { switch (cnstr.tag()) { case ba::tag_t::card_t: { card& c = cnstr.to_card(); - ineq.reset(offset*c.k()); + ineq.reset(static_cast(offset)*c.k()); for (literal l : c) ineq.push(l, offset); if (c.lit() != null_literal) ineq.push(~c.lit(), offset*c.k()); break; } case ba::tag_t::pb_t: { pb& p = cnstr.to_pb(); - ineq.reset(offset * p.k()); + ineq.reset(static_cast(offset) * p.k()); for (wliteral wl : p) ineq.push(wl.second, offset * wl.first); if (p.lit() != null_literal) ineq.push(~p.lit(), offset * p.k()); break; diff --git a/src/sat/smt/ba_solver.h b/src/sat/smt/ba_solver.h index 855b16410..81132ecec 100644 --- a/src/sat/smt/ba_solver.h +++ b/src/sat/smt/ba_solver.h @@ -86,7 +86,6 @@ namespace sat { sat_internalizer& si; pb_util m_pb; - solver* m_solver{ nullptr }; lookahead* m_lookahead{ nullptr }; stats m_stats; small_object_allocator m_allocator; @@ -140,9 +139,6 @@ namespace sat { void inc_parity(bool_var v); void reset_parity(bool_var v); - solver& s() const { return *m_solver; } - - // simplification routines vector> m_cnstr_use_list; @@ -400,7 +396,6 @@ namespace sat { ba_solver(euf::solver& ctx, euf::theory_id id); ba_solver(ast_manager& m, sat::sat_internalizer& si, euf::theory_id id); ~ba_solver() override; - void set_solver(solver* s) override { m_solver = s; } void set_lookahead(lookahead* l) override { m_lookahead = l; } void add_at_least(bool_var v, literal_vector const& lits, unsigned k); void add_pb_ge(bool_var v, svector const& wlits, unsigned k); diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index bd82410bc..e292e52f3 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -62,9 +62,14 @@ namespace bv { void solver::fixed_var_eh(theory_var v1) { numeral val1, val2; VERIFY(get_fixed_value(v1, val1)); + euf::enode* n1 = var2enode(v1); unsigned sz = m_bits[v1].size(); value_sort_pair key(val1, sz); theory_var v2; + if (ctx.watches_fixed(n1)) { + expr_ref value(bv.mk_numeral(val1, sz), m); + ctx.assign_fixed(n1, value, m_bits[v1]); + } bool is_current = m_fixed_var_table.find(key, v2) && v2 < static_cast(get_num_vars()) && @@ -74,12 +79,12 @@ namespace bv { if (!is_current) m_fixed_var_table.insert(key, v1); - else if (var2enode(v1)->get_root() != var2enode(v2)->get_root()) { + else if (n1->get_root() != var2enode(v2)->get_root()) { SASSERT(get_bv_size(v1) == get_bv_size(v2)); TRACE("bv", tout << "detected equality: v" << v1 << " = v" << v2 << "\n" << pp(v1) << pp(v2);); m_stats.m_num_bit2eq++; add_fixed_eq(v1, v2); - ctx.propagate(var2enode(v1), var2enode(v2), mk_bit2eq_justification(v1, v2)); + ctx.propagate(n1, var2enode(v2), mk_bit2eq_justification(v1, v2)); } } diff --git a/src/sat/smt/bv_solver.h b/src/sat/smt/bv_solver.h index e29ecffdf..82c72ca1d 100644 --- a/src/sat/smt/bv_solver.h +++ b/src/sat/smt/bv_solver.h @@ -233,11 +233,6 @@ namespace bv { unsigned_vector m_prop_queue_lim; unsigned m_prop_queue_head { 0 }; - - sat::solver* m_solver; - sat::solver& s() { return *m_solver; } - sat::solver const& s() const { return *m_solver; } - // internalize void insert_bv2a(bool_var bv, atom * a) { m_bool_var2atom.setx(bv, a, 0); } void erase_bv2a(bool_var bv) { m_bool_var2atom[bv] = 0; } @@ -327,7 +322,6 @@ namespace bv { public: solver(euf::solver& ctx, theory_id id); ~solver() override {} - void set_solver(sat::solver* s) override { m_solver = s; } void set_lookahead(sat::lookahead* s) override { } void init_search() override {} double get_reward(literal l, sat::ext_constraint_idx idx, sat::literal_occs_fun& occs) const override; diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index f80a1d25a..d7c08a14d 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -33,7 +33,6 @@ namespace euf { m_trail(*this), m_rewriter(m), m_unhandled_functions(m), - m_solver(nullptr), m_lookahead(nullptr), m_to_m(&m), m_to_si(&si), @@ -670,4 +669,27 @@ namespace euf { return true; } + void solver::user_propagate_init( + void* ctx, + ::solver::push_eh_t& push_eh, + ::solver::pop_eh_t& pop_eh, + ::solver::fresh_eh_t& fresh_eh) { + m_user_propagator = alloc(user::solver, *this); + m_user_propagator->add(ctx, push_eh, pop_eh, fresh_eh); + for (unsigned i = m_scopes.size(); i-- > 0; ) + m_user_propagator->push(); + m_solvers.push_back(m_user_propagator); + m_id2solver.setx(m_user_propagator->get_id(), m_user_propagator, nullptr); + } + + bool solver::watches_fixed(enode* n) const { + return m_user_propagator && m_user_propagator->has_fixed() && n->get_th_var(m_user_propagator->get_id()) != null_theory_var; + } + + void solver::assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain) { + theory_var v = n->get_th_var(m_user_propagator->get_id()); + m_user_propagator->new_fixed_eh(v, val, sz, explain); + } + + } diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 3fc3b6d56..ace5b383e 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -27,6 +27,7 @@ Author: #include "sat/smt/sat_th.h" #include "sat/smt/sat_dual_solver.h" #include "sat/smt/euf_ackerman.h" +#include "sat/smt/user_solver.h" #include "smt/params/smt_params.h" namespace euf { @@ -85,13 +86,12 @@ namespace euf { stats m_stats; th_rewriter m_rewriter; func_decl_ref_vector m_unhandled_functions; - - sat::solver* m_solver{ nullptr }; - sat::lookahead* m_lookahead{ nullptr }; - ast_manager* m_to_m; + sat::lookahead* m_lookahead{ nullptr }; + ast_manager* m_to_m; sat::sat_internalizer* m_to_si; scoped_ptr m_ackerman; scoped_ptr m_dual_solver; + user::solver* m_user_propagator{ nullptr }; ptr_vector m_var2expr; ptr_vector m_explain; @@ -174,6 +174,12 @@ namespace euf { constraint& eq_constraint() { return mk_constraint(m_eq, constraint::kind_t::eq); } constraint& lit_constraint() { return mk_constraint(m_lit, constraint::kind_t::lit); } + // user propagator + void check_for_user_propagator() { + if (!m_user_propagator) + throw default_exception("user propagator must be initialized"); + } + public: solver(ast_manager& m, sat::sat_internalizer& si, params_ref const& p = params_ref()); @@ -197,8 +203,7 @@ namespace euf { }; // accessors - sat::solver& s() { return *m_solver; } - sat::solver const& s() const { return *m_solver; } + sat::sat_internalizer& get_si() { return si; } ast_manager& get_manager() { return m; } enode* get_enode(expr* e) { return m_egraph.find(e); } @@ -212,7 +217,6 @@ namespace euf { euf_trail_stack& get_trail_stack() { return m_trail; } void updt_params(params_ref const& p); - void set_solver(sat::solver* s) override { m_solver = s; } void set_lookahead(sat::lookahead* s) override { m_lookahead = s; } void init_search() override; double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override; @@ -285,6 +289,40 @@ namespace euf { // diagnostics func_decl_ref_vector const& unhandled_functions() { return m_unhandled_functions; } + + // user propagator + void user_propagate_init( + void* ctx, + ::solver::push_eh_t& push_eh, + ::solver::pop_eh_t& pop_eh, + ::solver::fresh_eh_t& fresh_eh); + bool watches_fixed(enode* n) const; + void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain); + void assign_fixed(enode* n, expr* val, literal_vector const& explain) { assign_fixed(n, val, explain.size(), explain.c_ptr()); } + void assign_fixed(enode* n, expr* val, literal explain) { assign_fixed(n, val, 1, &explain); } + + void user_propagate_register_final(::solver::final_eh_t& final_eh) { + check_for_user_propagator(); + m_user_propagator->register_final(final_eh); + } + void user_propagate_register_fixed(::solver::fixed_eh_t& fixed_eh) { + check_for_user_propagator(); + m_user_propagator->register_fixed(fixed_eh); + } + void user_propagate_register_eq(::solver::eq_eh_t& eq_eh) { + check_for_user_propagator(); + m_user_propagator->register_eq(eq_eh); + } + void user_propagate_register_diseq(::solver::eq_eh_t& diseq_eh) { + check_for_user_propagator(); + m_user_propagator->register_diseq(diseq_eh); + } + unsigned user_propagate_register(expr* e) { + check_for_user_propagator(); + return m_user_propagator->add_expr(e); + } + + }; }; diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp new file mode 100644 index 000000000..e12e57a6d --- /dev/null +++ b/src/sat/smt/user_solver.cpp @@ -0,0 +1,159 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + user_solver.cpp + +Abstract: + + User propagator plugin. + +Author: + + Nikolaj Bjorner (nbjorner) 2020-09-23 + +--*/ + +#include "sat/smt/user_solver.h" +#include "sat/smt/euf_solver.h" + +namespace user { + + solver::solver(euf::solver& ctx) : + th_euf_solver(ctx, ctx.get_manager().mk_family_id("user")) + {} + + solver::~solver() { + dealloc(m_api_context); + } + + unsigned solver::add_expr(expr* e) { + force_push(); + ctx.internalize(e, false); + euf::enode* n = expr2enode(e); + if (is_attached_to_var(n)) + return n->get_th_var(get_id()); + euf::theory_var v = mk_var(n); + ctx.attach_th_var(n, this, v); + return v; + } + + void solver::propagate_cb( + unsigned num_fixed, unsigned const* fixed_ids, + unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, + expr* conseq) { + m_prop.push_back(prop_info(num_fixed, fixed_ids, num_eqs, eq_lhs, eq_rhs, expr_ref(conseq, m))); + } + + sat::check_result solver::check() { + if (!(bool)m_final_eh) + return sat::check_result::CR_DONE; + unsigned sz = m_prop.size(); + m_final_eh(m_user_context, this); + return sz == m_prop.size() ? sat::check_result::CR_DONE : sat::check_result::CR_CONTINUE; + } + + void solver::new_fixed_eh(euf::theory_var v, expr* value, unsigned num_lits, sat::literal const* jlits) { + if (!m_fixed_eh) + return; + force_push(); + m_id2justification.setx(v, sat::literal_vector(num_lits, jlits), sat::literal_vector()); + m_fixed_eh(m_user_context, this, v, value); + } + + void solver::asserted(sat::literal lit) { + if (!m_fixed_eh) + return; + force_push(); + auto* n = bool_var2enode(lit.var()); + euf::theory_var v = n->get_th_var(get_id()); + sat::literal_vector lits; + lits.push_back(lit); + m_id2justification.setx(v, lits, sat::literal_vector()); + m_fixed_eh(m_user_context, this, v, lit.sign() ? m.mk_false() : m.mk_true()); + } + + void solver::push_core() { + th_euf_solver::push_core(); + m_prop_lim.push_back(m_prop.size()); + m_push_eh(m_user_context); + } + + void solver::pop_core(unsigned num_scopes) { + th_euf_solver::pop_core(num_scopes); + unsigned old_sz = m_prop_lim.size() - num_scopes; + m_prop.shrink(m_prop_lim[old_sz]); + m_prop_lim.shrink(old_sz); + m_pop_eh(m_user_context, num_scopes); + } + + bool solver::unit_propagate() { + if (m_qhead == m_prop.size()) + return false; + force_push(); + ctx.push(value_trail(m_qhead)); + unsigned np = m_stats.m_num_propagations; + for (; m_qhead < m_prop.size() && !s().inconsistent(); ++m_qhead) { + auto const& prop = m_prop[m_qhead]; + sat::literal lit = ctx.internalize(prop.m_conseq, false, false, true); + if (s().value(lit) != l_true) { + s().assign(lit, mk_justification(m_qhead)); + ++m_stats.m_num_propagations; + } + } + return np < m_stats.m_num_propagations; + } + + void solver::collect_statistics(::statistics& st) const { + st.update("user-propagations", m_stats.m_num_propagations); + st.update("user-watched", get_num_vars()); + } + + sat::justification solver::mk_justification(unsigned prop_idx) { + void* mem = get_region().allocate(justification::get_obj_size()); + sat::constraint_base::initialize(mem, this); + auto* constraint = new (sat::constraint_base::ptr2mem(mem)) justification(prop_idx); + return sat::justification::mk_ext_justification(s().scope_lvl(), constraint->to_index()); + } + + void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) { + auto& j = justification::from_index(idx); + auto const& prop = m_prop[j.m_propagation_index]; + for (unsigned id : prop.m_ids) + r.append(m_id2justification[id]); + for (auto const& p : prop.m_eqs) + ctx.add_antecedent(var2enode(p.first), var2enode(p.second)); + } + + std::ostream& solver::display(std::ostream& out) const { + for (unsigned i = 0; i < get_num_vars(); ++i) + out << i << ": " << mk_pp(var2expr(i), m) << "\n"; + return out; + } + + std::ostream& solver::display_justification(std::ostream& out, sat::ext_justification_idx idx) const { + auto& j = justification::from_index(idx); + auto const& prop = m_prop[j.m_propagation_index]; + for (unsigned id : prop.m_ids) + out << id << ": " << m_id2justification[id]; + for (auto const& p : prop.m_eqs) + out << "v" << p.first << " == v" << p.second << " "; + return out; + } + + std::ostream& solver::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const { + return display_justification(out, idx); + } + + euf::th_solver* solver::fresh(sat::solver* dst_s, euf::solver& dst_ctx) { + auto* result = alloc(solver, dst_ctx); + result->set_solver(dst_s); + ast_translation tr(m, dst_ctx.get_manager(), false); + for (unsigned i = 0; i < get_num_vars(); ++i) + result->add_expr(tr(var2expr(i))); + return result; + } + +} + diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h new file mode 100644 index 000000000..c7df05ad0 --- /dev/null +++ b/src/sat/smt/user_solver.h @@ -0,0 +1,130 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + user_solver.h + +Abstract: + + User-propagator plugin. + Adds user plugins to propagate based on + terms receiving fixed values or equalities. + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-17 + +--*/ + +#pragma once + +#include "sat/smt/sat_th.h" +#include "solver/solver.h" + + +namespace user { + + class solver : public euf::th_euf_solver, public ::solver::propagate_callback { + + struct prop_info { + unsigned_vector m_ids; + expr_ref m_conseq; + svector> m_eqs; + prop_info(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr_ref const& c): + m_ids(num_fixed, fixed_ids), + m_conseq(c) + { + for (unsigned i = 0; i < num_eqs; ++i) + m_eqs.push_back(std::make_pair(eq_lhs[i], eq_rhs[i])); + } + }; + + struct stats { + unsigned m_num_propagations; + stats() { reset(); } + void reset() { memset(this, 0, sizeof(*this)); } + }; + + void* m_user_context; + ::solver::push_eh_t m_push_eh; + ::solver::pop_eh_t m_pop_eh; + ::solver::fresh_eh_t m_fresh_eh; + ::solver::final_eh_t m_final_eh; + ::solver::fixed_eh_t m_fixed_eh; + ::solver::eq_eh_t m_eq_eh; + ::solver::eq_eh_t m_diseq_eh; + ::solver::context_obj* m_api_context { nullptr }; + unsigned m_qhead { 0 }; + vector m_prop; + unsigned_vector m_prop_lim; + vector m_id2justification; + unsigned m_num_scopes { 0 }; + sat::literal_vector m_lits; + euf::enode_pair_vector m_eqs; + stats m_stats; + + struct justification { + unsigned m_propagation_index { 0 }; + + justification(unsigned prop_index): m_propagation_index(prop_index) {} + + sat::ext_constraint_idx to_index() const { + return sat::constraint_base::mem2base(this); + } + static justification& from_index(size_t idx) { + return *reinterpret_cast(sat::constraint_base::from_index(idx)->mem()); + } + static size_t get_obj_size() { return sat::constraint_base::obj_size(sizeof(justification)); } + }; + + sat::justification mk_justification(unsigned propagation_index); + + public: + solver(euf::solver& ctx); + + ~solver() override; + + /* + * \brief initial setup for user propagator. + */ + void add( + void* ctx, + ::solver::push_eh_t& push_eh, + ::solver::pop_eh_t& pop_eh, + ::solver::fresh_eh_t& fresh_eh) { + m_user_context = ctx; + m_push_eh = push_eh; + m_pop_eh = pop_eh; + m_fresh_eh = fresh_eh; + } + + unsigned add_expr(expr* e); + + void register_final(::solver::final_eh_t& final_eh) { m_final_eh = final_eh; } + void register_fixed(::solver::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; } + void register_eq(::solver::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; } + void register_diseq(::solver::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } + + bool has_fixed() const { return (bool)m_fixed_eh; } + + void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* lhs, unsigned const* rhs, expr* conseq) override; + + void new_fixed_eh(euf::theory_var v, expr* value, unsigned num_lits, sat::literal const* jlits); + + void asserted(sat::literal lit) override; + sat::check_result check() override; + void push_core() override; + void pop_core(unsigned n) override; + bool unit_propagate() override; + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override; + void collect_statistics(statistics& st) const override; + sat::literal internalize(expr* e, bool sign, bool root, bool learned) override { UNREACHABLE(); return sat::null_literal; } + void internalize(expr* e, bool redundant) override { UNREACHABLE(); } + std::ostream& display(std::ostream& out) const override; + std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override; + std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override; + euf::th_solver* fresh(sat::solver* s, euf::solver& ctx) override; + + }; +}; diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 16e4cd4f0..534964344 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -2385,6 +2385,7 @@ public: TRACE("arith", tout << "v" << v << " " << be.kind() << " " << be.m_bound << "\n";); ensure_bounds(v); + if (m_unassigned_bounds[v] == 0 && !should_refine_bounds()) { TRACE("arith", tout << "return\n";); diff --git a/src/smt/user_propagator.cpp b/src/smt/user_propagator.cpp index 42450056e..22fa4e279 100644 --- a/src/smt/user_propagator.cpp +++ b/src/smt/user_propagator.cpp @@ -48,7 +48,7 @@ unsigned user_propagator::add_expr(expr* e) { return v; } -void user_propagator::propagate( +void user_propagator::propagate_cb( unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr* conseq) { diff --git a/src/smt/user_propagator.h b/src/smt/user_propagator.h index 9aa6d87d7..544a9ed0e 100644 --- a/src/smt/user_propagator.h +++ b/src/smt/user_propagator.h @@ -95,7 +95,7 @@ namespace smt { bool has_fixed() const { return (bool)m_fixed_eh; } - void propagate(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* lhs, unsigned const* rhs, expr* conseq) override; + void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* lhs, unsigned const* rhs, expr* conseq) override; void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); diff --git a/src/solver/solver.h b/src/solver/solver.h index a72620097..513572ab2 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -241,7 +241,7 @@ public: class propagate_callback { public: - virtual void propagate(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr* conseq) = 0; + virtual void propagate_cb(unsigned num_fixed, unsigned const* fixed_ids, unsigned num_eqs, unsigned const* eq_lhs, unsigned const* eq_rhs, expr* conseq) = 0; }; class context_obj { public: