diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index be26d70f0..92094ba3e 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -5,7 +5,8 @@ z3_add_component(ast_sls bv_sls_eval.cpp bv_sls_fixed.cpp bv_sls_terms.cpp - sls_arith_int.cpp + sls_arith_base.cpp + sls_arith_plugin.cpp sls_cc.cpp sls_engine.cpp sls_smt.cpp diff --git a/src/ast/sls/sls_arith_int.cpp b/src/ast/sls/sls_arith_base.cpp similarity index 88% rename from src/ast/sls/sls_arith_int.cpp rename to src/ast/sls/sls_arith_base.cpp index 6dd0d946a..8e65bb8d7 100644 --- a/src/ast/sls/sls_arith_int.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -3,7 +3,7 @@ Copyright (c) 2023 Microsoft Corporation Module Name: - arith_sls_int.cpp + sls_arith_base.cpp Abstract: @@ -15,39 +15,39 @@ Author: --*/ -#include "ast/sls/sls_arith_int.h" +#include "ast/sls/sls_arith_base.h" #include "ast/ast_ll_pp.h" namespace sls { template - arith_plugin::arith_plugin(context& ctx) : + arith_base::arith_base(context& ctx) : plugin(ctx), a(m) { m_fid = a.get_family_id(); } template - void arith_plugin::reset() { + void arith_base::reset() { m_bool_vars.reset(); m_vars.reset(); m_expr2var.reset(); } template - void arith_plugin::save_best_values() { + void arith_base::save_best_values() { for (auto& v : m_vars) v.m_best_value = v.m_value; check_ineqs(); } template - void arith_plugin::store_best_values() { + void arith_base::store_best_values() { } // distance to true template - num_t arith_plugin::dtt(bool sign, num_t const& args, ineq const& ineq) const { + num_t arith_base::dtt(bool sign, num_t const& args, ineq const& ineq) const { num_t zero{ 0 }; switch (ineq.m_op) { case ineq_kind::LE: @@ -89,7 +89,7 @@ namespace sls { // different data-structures for storing coefficients // template - num_t arith_plugin::dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const { + num_t arith_base::dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const { for (auto const& [coeff, w] : ineq.m_args) if (w == v) return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq); @@ -97,12 +97,12 @@ namespace sls { } template - num_t arith_plugin::dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& old_value, num_t const& new_value) const { + num_t arith_base::dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& old_value, num_t const& new_value) const { return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq); } template - bool arith_plugin::cm(ineq const& ineq, var_t v, num_t& new_value) { + bool arith_base::cm(ineq const& ineq, var_t v, num_t& new_value) { for (auto const& [coeff, w] : ineq.m_args) if (w == v) return cm(ineq, v, coeff, new_value); @@ -110,14 +110,14 @@ namespace sls { } template - num_t arith_plugin::divide(var_t v, num_t const& delta, num_t const& coeff) { + num_t arith_base::divide(var_t v, num_t const& delta, num_t const& coeff) { if (m_vars[v].m_kind == var_kind::REAL) return delta / coeff; return div(delta + abs(coeff) - 1, coeff); } template - bool arith_plugin::cm(ineq const& ineq, var_t v, num_t const& coeff, num_t& new_value) { + bool arith_base::cm(ineq const& ineq, var_t v, num_t const& coeff, num_t& new_value) { auto bound = -ineq.m_coeff; auto argsv = ineq.m_args_value; bool solved = false; @@ -195,7 +195,7 @@ namespace sls { // or flip on maximal non-negative score // or flip on first non-negative score template - void arith_plugin::repair(sat::literal lit, ineq const& ineq) { + void arith_base::repair(sat::literal lit, ineq const& ineq) { num_t new_value; if (UINT_MAX == ineq.m_var_to_flip) dtt_reward(lit); @@ -218,7 +218,7 @@ namespace sls { // cached dts has to be updated when the score of literals are updated. // template - double arith_plugin::dscore(var_t v, num_t const& new_value) const { + double arith_base::dscore(var_t v, num_t const& new_value) const { double score = 0; auto const& vi = m_vars[v]; for (auto const& [coeff, bv] : vi.m_bool_vars) { @@ -238,7 +238,7 @@ namespace sls { // - dtt_old can be saved // template - int arith_plugin::cm_score(var_t v, num_t const& new_value) { + int arith_base::cm_score(var_t v, num_t const& new_value) { int score = 0; auto& vi = m_vars[v]; num_t old_value = vi.m_value; @@ -273,7 +273,7 @@ namespace sls { } template - num_t arith_plugin::compute_dts(unsigned cl) const { + num_t arith_base::compute_dts(unsigned cl) const { num_t d(1), d2; bool first = true; for (auto a : ctx.get_clause(cl)) { @@ -292,7 +292,7 @@ namespace sls { } template - num_t arith_plugin::dts(unsigned cl, var_t v, num_t const& new_value) const { + num_t arith_base::dts(unsigned cl, var_t v, num_t const& new_value) const { num_t d(1), d2; bool first = true; for (auto lit : ctx.get_clause(cl)) { @@ -311,7 +311,7 @@ namespace sls { } template - void arith_plugin::update(var_t v, num_t const& new_value) { + void arith_base::update(var_t v, num_t const& new_value) { auto& vi = m_vars[v]; auto old_value = vi.m_value; if (old_value == new_value) @@ -352,7 +352,7 @@ namespace sls { } template - typename arith_plugin::ineq& arith_plugin::new_ineq(ineq_kind op, num_t const& coeff) { + typename arith_base::ineq& arith_base::new_ineq(ineq_kind op, num_t const& coeff) { auto* i = alloc(ineq); i->m_coeff = coeff; i->m_op = op; @@ -360,12 +360,12 @@ namespace sls { } template - void arith_plugin::add_arg(linear_term& ineq, num_t const& c, var_t v) { + void arith_base::add_arg(linear_term& ineq, num_t const& c, var_t v) { ineq.m_args.push_back({ c, v }); } - bool arith_plugin>::is_num(expr* e, checked_int64& i) { + bool arith_base>::is_num(expr* e, checked_int64& i) { rational r; if (a.is_numeral(e, r)) { if (!r.is_int64()) @@ -376,17 +376,17 @@ namespace sls { return false; } - bool arith_plugin::is_num(expr* e, rational& i) { + bool arith_base::is_num(expr* e, rational& i) { return a.is_numeral(e, i); } template - bool arith_plugin::is_num(expr* e, num_t& i) { + bool arith_base::is_num(expr* e, num_t& i) { return false; } template - void arith_plugin::add_args(linear_term& term, expr* e, num_t const& coeff) { + void arith_base::add_args(linear_term& term, expr* e, num_t const& coeff) { auto v = m_expr2var.get(e->get_id(), UINT_MAX); expr* x, * y; num_t i; @@ -440,7 +440,7 @@ namespace sls { } template - typename arith_plugin::var_t arith_plugin::mk_term(expr* e) { + typename arith_base::var_t arith_base::mk_term(expr* e) { auto v = m_expr2var.get(e->get_id(), UINT_MAX); if (v != UINT_MAX) return v; @@ -460,7 +460,7 @@ namespace sls { } template - unsigned arith_plugin::mk_var(expr* e) { + unsigned arith_base::mk_var(expr* e) { unsigned v = m_expr2var.get(e->get_id(), UINT_MAX); if (v == UINT_MAX) { v = m_vars.size(); @@ -471,7 +471,7 @@ namespace sls { } template - void arith_plugin::init_bool_var(sat::bool_var bv) { + void arith_base::init_bool_var(sat::bool_var bv) { if (m_bool_vars.get(bv, nullptr)) return; expr* e = ctx.atom(bv); @@ -510,7 +510,7 @@ namespace sls { } template - void arith_plugin::init_ineq(sat::bool_var bv, ineq& i) { + void arith_base::init_ineq(sat::bool_var bv, ineq& i) { i.m_args_value = 0; for (auto const& [coeff, v] : i.m_args) { m_vars[v].m_bool_vars.push_back({ coeff, bv }); @@ -520,14 +520,14 @@ namespace sls { } template - void arith_plugin::init_bool_var_assignment(sat::bool_var v) { + void arith_base::init_bool_var_assignment(sat::bool_var v) { auto* ineq = m_bool_vars.get(v, nullptr); if (ineq && ctx.is_true(sat::literal(v, false)) != (dtt(false, *ineq) == 0)) ctx.flip(v); } template - void arith_plugin::repair(sat::literal lit) { + void arith_base::repair(sat::literal lit) { if (!ctx.is_true(lit)) return; auto const* ineq = atom(lit.var()); @@ -540,7 +540,7 @@ namespace sls { } template - void arith_plugin::repair_defs_and_updates() { + void arith_base::repair_defs_and_updates() { while (!m_defs_to_update.empty() || !m_vars_to_update.empty()) { repair_updates(); repair_defs(); @@ -548,7 +548,7 @@ namespace sls { } template - void arith_plugin::repair_updates() { + void arith_base::repair_updates() { while (!m_vars_to_update.empty()) { auto [w, new_value1] = m_vars_to_update.back(); m_vars_to_update.pop_back(); @@ -557,7 +557,7 @@ namespace sls { } template - void arith_plugin::repair_defs() { + void arith_base::repair_defs() { while (!m_defs_to_update.empty()) { auto v = m_defs_to_update.back(); m_defs_to_update.pop_back(); @@ -570,7 +570,7 @@ namespace sls { } template - void arith_plugin::repair_add(add_def const& ad) { + void arith_base::repair_add(add_def const& ad) { auto v = ad.m_var; auto const& coeffs = ad.m_args; num_t sum(ad.m_coeff); @@ -592,7 +592,7 @@ namespace sls { } template - void arith_plugin::repair_mul(mul_def const& md) { + void arith_base::repair_mul(mul_def const& md) { num_t product(1); num_t val = value(md.m_var); for (auto v : md.m_monomial) @@ -651,7 +651,7 @@ namespace sls { } template - double arith_plugin::reward(sat::literal lit) { + double arith_base::reward(sat::literal lit) { if (m_dscore_mode) return dscore_reward(lit.var()); else @@ -659,7 +659,7 @@ namespace sls { } template - double arith_plugin::dtt_reward(sat::literal lit) { + double arith_base::dtt_reward(sat::literal lit) { auto* ineq = atom(lit.var()); if (!ineq) return -1; @@ -690,7 +690,7 @@ namespace sls { } template - double arith_plugin::dscore_reward(sat::bool_var bv) { + double arith_base::dscore_reward(sat::bool_var bv) { m_dscore_mode = false; bool old_sign = sign(bv); sat::literal litv(bv, old_sign); @@ -715,19 +715,19 @@ namespace sls { // switch to dscore mode template - void arith_plugin::on_rescale() { + void arith_base::on_rescale() { m_dscore_mode = true; } template - void arith_plugin::on_restart() { + void arith_base::on_restart() { for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) init_bool_var_assignment(v); check_ineqs(); } template - void arith_plugin::check_ineqs() { + void arith_base::check_ineqs() { auto check_bool_var = [&](sat::bool_var bv) { auto const* ineq = atom(bv); if (!ineq) @@ -744,17 +744,17 @@ namespace sls { } template - void arith_plugin::register_term(expr* e) { + void arith_base::register_term(expr* e) { } template - expr_ref arith_plugin::get_value(expr* e) { + expr_ref arith_base::get_value(expr* e) { auto v = mk_var(e); return expr_ref(a.mk_numeral(rational(m_vars[v].m_value.get_int64(), rational::i64()), a.is_int(e)), m); } template - lbool arith_plugin::check() { + lbool arith_base::check() { // repair each root literal for (sat::literal lit : ctx.root_literals()) repair(lit); @@ -769,7 +769,7 @@ namespace sls { } template - bool arith_plugin::is_sat() { + bool arith_base::is_sat() { for (auto const& clause : ctx.clauses()) { bool sat = false; for (auto lit : clause.m_clause) { @@ -792,7 +792,7 @@ namespace sls { } template - std::ostream& arith_plugin::display(std::ostream& out) const { + std::ostream& arith_base::display(std::ostream& out) const { for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) { auto ineq = atom(v); if (ineq) @@ -823,7 +823,7 @@ namespace sls { } template - void arith_plugin::mk_model(model& mdl) { + void arith_base::mk_model(model& mdl) { for (auto const& v : m_vars) { expr* e = v.m_expr; if (is_uninterp_const(e)) @@ -832,5 +832,5 @@ namespace sls { } } -template class sls::arith_plugin>; -template class sls::arith_plugin; \ No newline at end of file +template class sls::arith_base>; +template class sls::arith_base; diff --git a/src/ast/sls/sls_arith_int.h b/src/ast/sls/sls_arith_base.h similarity index 96% rename from src/ast/sls/sls_arith_int.h rename to src/ast/sls/sls_arith_base.h index df8b5bfc4..2996354db 100644 --- a/src/ast/sls/sls_arith_int.h +++ b/src/ast/sls/sls_arith_base.h @@ -28,7 +28,7 @@ namespace sls { // local search portion for arithmetic template - class arith_plugin : public plugin { + class arith_base : public plugin { enum class ineq_kind { EQ, LE, LT}; enum class var_kind { INT, REAL }; typedef unsigned var_t; @@ -166,8 +166,8 @@ namespace sls { void check_ineqs(); public: - arith_plugin(context& ctx); - ~arith_plugin() override {} + arith_base(context& ctx); + ~arith_base() override {} void init_bool_var(sat::bool_var v) override; void register_term(expr* e) override; expr_ref get_value(expr* e) override; @@ -182,11 +182,11 @@ namespace sls { }; - inline std::ostream& operator<<(std::ostream& out, typename arith_plugin>::ineq const& ineq) { + inline std::ostream& operator<<(std::ostream& out, typename arith_base>::ineq const& ineq) { return ineq.display(out); } - inline std::ostream& operator<<(std::ostream& out, typename arith_plugin::ineq const& ineq) { + inline std::ostream& operator<<(std::ostream& out, typename arith_base::ineq const& ineq) { return ineq.display(out); } } diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp new file mode 100644 index 000000000..23b657192 --- /dev/null +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -0,0 +1,113 @@ + +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + sls_arith_plugin.cpp + +Abstract: + + Local search dispatch for NIA + +Author: + + Nikolaj Bjorner (nbjorner) 2023-02-07 + +--*/ + +#include "ast/sls/sls_arith_plugin.h" +#include "ast/ast_ll_pp.h" + +namespace sls { + + void arith_plugin::init_bool_var(sat::bool_var v) { + if (!m_arith) { + try { + m_arith64->init_bool_var(v); + return; + } + catch (overflow_exception&) { + m_arith = alloc(arith_base, ctx); + return; // initialization happens on check-sat calls + } + } + m_arith->init_bool_var(v); + + } + + void arith_plugin::register_term(expr* e) { + if (!m_arith) { + try { + m_arith64->register_term(e); + return; + } + catch (overflow_exception&) { + m_arith = alloc(arith_base, ctx); + } + } + m_arith->register_term(e); + } + + expr_ref arith_plugin::get_value(expr* e) { + if (!m_arith) { + try { + return m_arith64->get_value(e); + } + catch (overflow_exception&) { + m_arith = alloc(arith_base, ctx); + } + } + return m_arith->get_value(e); + } + + lbool arith_plugin::check() { + if (!m_arith) { + try { + return m_arith64->check(); + } + catch (overflow_exception&) { + m_arith = alloc(arith_base, ctx); + } + } + return m_arith->check(); + } + + bool arith_plugin::is_sat() { + if (!m_arith) + return m_arith64->is_sat(); + return m_arith->is_sat(); + } + void arith_plugin::reset() { + if (!m_arith) + m_arith64->reset(); + else + m_arith->reset(); + } + + void arith_plugin::on_rescale() { + if (!m_arith) + m_arith64->on_rescale(); + else + m_arith->on_rescale(); + } + void arith_plugin::on_restart() { + if (!m_arith) + m_arith64->on_restart(); + else + m_arith->on_restart(); + } + + std::ostream& arith_plugin::display(std::ostream& out) const { + if (!m_arith) + return m_arith64->display(out); + return m_arith->display(out); + } + + void arith_plugin::mk_model(model& mdl) { + if (!m_arith) + m_arith64->mk_model(mdl); + else + m_arith->mk_model(mdl); + } +} diff --git a/src/ast/sls/sls_arith_plugin.h b/src/ast/sls/sls_arith_plugin.h new file mode 100644 index 000000000..494a20b9b --- /dev/null +++ b/src/ast/sls/sls_arith_plugin.h @@ -0,0 +1,43 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_arith_plugin.h + +Abstract: + + Theory plugin for arithmetic local search + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-05 + +--*/ +#pragma once + +#include "ast/sls/sls_smt.h" +#include "ast/sls/sls_arith_base.h" + +namespace sls { + + class arith_plugin : public plugin { + scoped_ptr>> m_arith64; + scoped_ptr> m_arith; + public: + arith_plugin(context& ctx) : plugin(ctx) { m_arith64 = alloc(arith_base>,ctx); } + ~arith_plugin() override {} + void init_bool_var(sat::bool_var v) override; + void register_term(expr* e) override; + expr_ref get_value(expr* e) override; + lbool check() override; + bool is_sat() override; + void reset() override; + + void on_rescale() override; + void on_restart() override; + std::ostream& display(std::ostream& out) const override; + void mk_model(model& mdl) override; + }; + +} diff --git a/src/ast/sls/sls_smt.cpp b/src/ast/sls/sls_smt.cpp index aed245807..fa696b3f7 100644 --- a/src/ast/sls/sls_smt.cpp +++ b/src/ast/sls/sls_smt.cpp @@ -18,7 +18,7 @@ Author: #include "ast/sls/sls_smt.h" #include "ast/sls/sls_cc.h" -#include "ast/sls/sls_arith_int.h" +#include "ast/sls/sls_arith_plugin.h" namespace sls { @@ -42,8 +42,6 @@ namespace sls { m_atoms.setx(v, e); m_atom2bool_var.setx(e->get_id(), v, UINT_MAX); } - - typedef arith_plugin> arith64; void context::reset() { m_plugins.reset(); @@ -55,7 +53,7 @@ namespace sls { m_visited.reset(); m_subterms.reset(); register_plugin(alloc(cc_plugin, *this)); - register_plugin(alloc(arith64, *this)); + register_plugin(alloc(arith_plugin, *this)); } lbool context::check() {