3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

split into base and plugin

This commit is contained in:
Nikolaj Bjorner 2024-07-05 20:09:46 -07:00
parent 52533130f9
commit 3ff60a4af0
6 changed files with 214 additions and 59 deletions

View file

@ -5,7 +5,8 @@ z3_add_component(ast_sls
bv_sls_eval.cpp
bv_sls_fixed.cpp
bv_sls_terms.cpp
sls_arith_int.cpp
sls_arith_base.cpp
sls_arith_plugin.cpp
sls_cc.cpp
sls_engine.cpp
sls_smt.cpp

View file

@ -3,7 +3,7 @@ Copyright (c) 2023 Microsoft Corporation
Module Name:
arith_sls_int.cpp
sls_arith_base.cpp
Abstract:
@ -15,39 +15,39 @@ Author:
--*/
#include "ast/sls/sls_arith_int.h"
#include "ast/sls/sls_arith_base.h"
#include "ast/ast_ll_pp.h"
namespace sls {
template<typename num_t>
arith_plugin<num_t>::arith_plugin(context& ctx) :
arith_base<num_t>::arith_base(context& ctx) :
plugin(ctx),
a(m) {
m_fid = a.get_family_id();
}
template<typename num_t>
void arith_plugin<num_t>::reset() {
void arith_base<num_t>::reset() {
m_bool_vars.reset();
m_vars.reset();
m_expr2var.reset();
}
template<typename num_t>
void arith_plugin<num_t>::save_best_values() {
void arith_base<num_t>::save_best_values() {
for (auto& v : m_vars)
v.m_best_value = v.m_value;
check_ineqs();
}
template<typename num_t>
void arith_plugin<num_t>::store_best_values() {
void arith_base<num_t>::store_best_values() {
}
// distance to true
template<typename num_t>
num_t arith_plugin<num_t>::dtt(bool sign, num_t const& args, ineq const& ineq) const {
num_t arith_base<num_t>::dtt(bool sign, num_t const& args, ineq const& ineq) const {
num_t zero{ 0 };
switch (ineq.m_op) {
case ineq_kind::LE:
@ -89,7 +89,7 @@ namespace sls {
// different data-structures for storing coefficients
//
template<typename num_t>
num_t arith_plugin<num_t>::dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const {
num_t arith_base<num_t>::dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const {
for (auto const& [coeff, w] : ineq.m_args)
if (w == v)
return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq);
@ -97,12 +97,12 @@ namespace sls {
}
template<typename num_t>
num_t arith_plugin<num_t>::dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& old_value, num_t const& new_value) const {
num_t arith_base<num_t>::dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& old_value, num_t const& new_value) const {
return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq);
}
template<typename num_t>
bool arith_plugin<num_t>::cm(ineq const& ineq, var_t v, num_t& new_value) {
bool arith_base<num_t>::cm(ineq const& ineq, var_t v, num_t& new_value) {
for (auto const& [coeff, w] : ineq.m_args)
if (w == v)
return cm(ineq, v, coeff, new_value);
@ -110,14 +110,14 @@ namespace sls {
}
template<typename num_t>
num_t arith_plugin<num_t>::divide(var_t v, num_t const& delta, num_t const& coeff) {
num_t arith_base<num_t>::divide(var_t v, num_t const& delta, num_t const& coeff) {
if (m_vars[v].m_kind == var_kind::REAL)
return delta / coeff;
return div(delta + abs(coeff) - 1, coeff);
}
template<typename num_t>
bool arith_plugin<num_t>::cm(ineq const& ineq, var_t v, num_t const& coeff, num_t& new_value) {
bool arith_base<num_t>::cm(ineq const& ineq, var_t v, num_t const& coeff, num_t& new_value) {
auto bound = -ineq.m_coeff;
auto argsv = ineq.m_args_value;
bool solved = false;
@ -195,7 +195,7 @@ namespace sls {
// or flip on maximal non-negative score
// or flip on first non-negative score
template<typename num_t>
void arith_plugin<num_t>::repair(sat::literal lit, ineq const& ineq) {
void arith_base<num_t>::repair(sat::literal lit, ineq const& ineq) {
num_t new_value;
if (UINT_MAX == ineq.m_var_to_flip)
dtt_reward(lit);
@ -218,7 +218,7 @@ namespace sls {
// cached dts has to be updated when the score of literals are updated.
//
template<typename num_t>
double arith_plugin<num_t>::dscore(var_t v, num_t const& new_value) const {
double arith_base<num_t>::dscore(var_t v, num_t const& new_value) const {
double score = 0;
auto const& vi = m_vars[v];
for (auto const& [coeff, bv] : vi.m_bool_vars) {
@ -238,7 +238,7 @@ namespace sls {
// - dtt_old can be saved
//
template<typename num_t>
int arith_plugin<num_t>::cm_score(var_t v, num_t const& new_value) {
int arith_base<num_t>::cm_score(var_t v, num_t const& new_value) {
int score = 0;
auto& vi = m_vars[v];
num_t old_value = vi.m_value;
@ -273,7 +273,7 @@ namespace sls {
}
template<typename num_t>
num_t arith_plugin<num_t>::compute_dts(unsigned cl) const {
num_t arith_base<num_t>::compute_dts(unsigned cl) const {
num_t d(1), d2;
bool first = true;
for (auto a : ctx.get_clause(cl)) {
@ -292,7 +292,7 @@ namespace sls {
}
template<typename num_t>
num_t arith_plugin<num_t>::dts(unsigned cl, var_t v, num_t const& new_value) const {
num_t arith_base<num_t>::dts(unsigned cl, var_t v, num_t const& new_value) const {
num_t d(1), d2;
bool first = true;
for (auto lit : ctx.get_clause(cl)) {
@ -311,7 +311,7 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::update(var_t v, num_t const& new_value) {
void arith_base<num_t>::update(var_t v, num_t const& new_value) {
auto& vi = m_vars[v];
auto old_value = vi.m_value;
if (old_value == new_value)
@ -352,7 +352,7 @@ namespace sls {
}
template<typename num_t>
typename arith_plugin<num_t>::ineq& arith_plugin<num_t>::new_ineq(ineq_kind op, num_t const& coeff) {
typename arith_base<num_t>::ineq& arith_base<num_t>::new_ineq(ineq_kind op, num_t const& coeff) {
auto* i = alloc(ineq);
i->m_coeff = coeff;
i->m_op = op;
@ -360,12 +360,12 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::add_arg(linear_term& ineq, num_t const& c, var_t v) {
void arith_base<num_t>::add_arg(linear_term& ineq, num_t const& c, var_t v) {
ineq.m_args.push_back({ c, v });
}
bool arith_plugin<checked_int64<true>>::is_num(expr* e, checked_int64<true>& i) {
bool arith_base<checked_int64<true>>::is_num(expr* e, checked_int64<true>& i) {
rational r;
if (a.is_numeral(e, r)) {
if (!r.is_int64())
@ -376,17 +376,17 @@ namespace sls {
return false;
}
bool arith_plugin<rational>::is_num(expr* e, rational& i) {
bool arith_base<rational>::is_num(expr* e, rational& i) {
return a.is_numeral(e, i);
}
template<typename num_t>
bool arith_plugin<num_t>::is_num(expr* e, num_t& i) {
bool arith_base<num_t>::is_num(expr* e, num_t& i) {
return false;
}
template<typename num_t>
void arith_plugin<num_t>::add_args(linear_term& term, expr* e, num_t const& coeff) {
void arith_base<num_t>::add_args(linear_term& term, expr* e, num_t const& coeff) {
auto v = m_expr2var.get(e->get_id(), UINT_MAX);
expr* x, * y;
num_t i;
@ -440,7 +440,7 @@ namespace sls {
}
template<typename num_t>
typename arith_plugin<num_t>::var_t arith_plugin<num_t>::mk_term(expr* e) {
typename arith_base<num_t>::var_t arith_base<num_t>::mk_term(expr* e) {
auto v = m_expr2var.get(e->get_id(), UINT_MAX);
if (v != UINT_MAX)
return v;
@ -460,7 +460,7 @@ namespace sls {
}
template<typename num_t>
unsigned arith_plugin<num_t>::mk_var(expr* e) {
unsigned arith_base<num_t>::mk_var(expr* e) {
unsigned v = m_expr2var.get(e->get_id(), UINT_MAX);
if (v == UINT_MAX) {
v = m_vars.size();
@ -471,7 +471,7 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::init_bool_var(sat::bool_var bv) {
void arith_base<num_t>::init_bool_var(sat::bool_var bv) {
if (m_bool_vars.get(bv, nullptr))
return;
expr* e = ctx.atom(bv);
@ -510,7 +510,7 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::init_ineq(sat::bool_var bv, ineq& i) {
void arith_base<num_t>::init_ineq(sat::bool_var bv, ineq& i) {
i.m_args_value = 0;
for (auto const& [coeff, v] : i.m_args) {
m_vars[v].m_bool_vars.push_back({ coeff, bv });
@ -520,14 +520,14 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::init_bool_var_assignment(sat::bool_var v) {
void arith_base<num_t>::init_bool_var_assignment(sat::bool_var v) {
auto* ineq = m_bool_vars.get(v, nullptr);
if (ineq && ctx.is_true(sat::literal(v, false)) != (dtt(false, *ineq) == 0))
ctx.flip(v);
}
template<typename num_t>
void arith_plugin<num_t>::repair(sat::literal lit) {
void arith_base<num_t>::repair(sat::literal lit) {
if (!ctx.is_true(lit))
return;
auto const* ineq = atom(lit.var());
@ -540,7 +540,7 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::repair_defs_and_updates() {
void arith_base<num_t>::repair_defs_and_updates() {
while (!m_defs_to_update.empty() || !m_vars_to_update.empty()) {
repair_updates();
repair_defs();
@ -548,7 +548,7 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::repair_updates() {
void arith_base<num_t>::repair_updates() {
while (!m_vars_to_update.empty()) {
auto [w, new_value1] = m_vars_to_update.back();
m_vars_to_update.pop_back();
@ -557,7 +557,7 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::repair_defs() {
void arith_base<num_t>::repair_defs() {
while (!m_defs_to_update.empty()) {
auto v = m_defs_to_update.back();
m_defs_to_update.pop_back();
@ -570,7 +570,7 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::repair_add(add_def const& ad) {
void arith_base<num_t>::repair_add(add_def const& ad) {
auto v = ad.m_var;
auto const& coeffs = ad.m_args;
num_t sum(ad.m_coeff);
@ -592,7 +592,7 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::repair_mul(mul_def const& md) {
void arith_base<num_t>::repair_mul(mul_def const& md) {
num_t product(1);
num_t val = value(md.m_var);
for (auto v : md.m_monomial)
@ -651,7 +651,7 @@ namespace sls {
}
template<typename num_t>
double arith_plugin<num_t>::reward(sat::literal lit) {
double arith_base<num_t>::reward(sat::literal lit) {
if (m_dscore_mode)
return dscore_reward(lit.var());
else
@ -659,7 +659,7 @@ namespace sls {
}
template<typename num_t>
double arith_plugin<num_t>::dtt_reward(sat::literal lit) {
double arith_base<num_t>::dtt_reward(sat::literal lit) {
auto* ineq = atom(lit.var());
if (!ineq)
return -1;
@ -690,7 +690,7 @@ namespace sls {
}
template<typename num_t>
double arith_plugin<num_t>::dscore_reward(sat::bool_var bv) {
double arith_base<num_t>::dscore_reward(sat::bool_var bv) {
m_dscore_mode = false;
bool old_sign = sign(bv);
sat::literal litv(bv, old_sign);
@ -715,19 +715,19 @@ namespace sls {
// switch to dscore mode
template<typename num_t>
void arith_plugin<num_t>::on_rescale() {
void arith_base<num_t>::on_rescale() {
m_dscore_mode = true;
}
template<typename num_t>
void arith_plugin<num_t>::on_restart() {
void arith_base<num_t>::on_restart() {
for (unsigned v = 0; v < ctx.num_bool_vars(); ++v)
init_bool_var_assignment(v);
check_ineqs();
}
template<typename num_t>
void arith_plugin<num_t>::check_ineqs() {
void arith_base<num_t>::check_ineqs() {
auto check_bool_var = [&](sat::bool_var bv) {
auto const* ineq = atom(bv);
if (!ineq)
@ -744,17 +744,17 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::register_term(expr* e) {
void arith_base<num_t>::register_term(expr* e) {
}
template<typename num_t>
expr_ref arith_plugin<num_t>::get_value(expr* e) {
expr_ref arith_base<num_t>::get_value(expr* e) {
auto v = mk_var(e);
return expr_ref(a.mk_numeral(rational(m_vars[v].m_value.get_int64(), rational::i64()), a.is_int(e)), m);
}
template<typename num_t>
lbool arith_plugin<num_t>::check() {
lbool arith_base<num_t>::check() {
// repair each root literal
for (sat::literal lit : ctx.root_literals())
repair(lit);
@ -769,7 +769,7 @@ namespace sls {
}
template<typename num_t>
bool arith_plugin<num_t>::is_sat() {
bool arith_base<num_t>::is_sat() {
for (auto const& clause : ctx.clauses()) {
bool sat = false;
for (auto lit : clause.m_clause) {
@ -792,7 +792,7 @@ namespace sls {
}
template<typename num_t>
std::ostream& arith_plugin<num_t>::display(std::ostream& out) const {
std::ostream& arith_base<num_t>::display(std::ostream& out) const {
for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) {
auto ineq = atom(v);
if (ineq)
@ -823,7 +823,7 @@ namespace sls {
}
template<typename num_t>
void arith_plugin<num_t>::mk_model(model& mdl) {
void arith_base<num_t>::mk_model(model& mdl) {
for (auto const& v : m_vars) {
expr* e = v.m_expr;
if (is_uninterp_const(e))
@ -832,5 +832,5 @@ namespace sls {
}
}
template class sls::arith_plugin<checked_int64<true>>;
template class sls::arith_plugin<rational>;
template class sls::arith_base<checked_int64<true>>;
template class sls::arith_base<rational>;

View file

@ -28,7 +28,7 @@ namespace sls {
// local search portion for arithmetic
template<typename num_t>
class arith_plugin : public plugin {
class arith_base : public plugin {
enum class ineq_kind { EQ, LE, LT};
enum class var_kind { INT, REAL };
typedef unsigned var_t;
@ -166,8 +166,8 @@ namespace sls {
void check_ineqs();
public:
arith_plugin(context& ctx);
~arith_plugin() override {}
arith_base(context& ctx);
~arith_base() override {}
void init_bool_var(sat::bool_var v) override;
void register_term(expr* e) override;
expr_ref get_value(expr* e) override;
@ -182,11 +182,11 @@ namespace sls {
};
inline std::ostream& operator<<(std::ostream& out, typename arith_plugin<checked_int64<true>>::ineq const& ineq) {
inline std::ostream& operator<<(std::ostream& out, typename arith_base<checked_int64<true>>::ineq const& ineq) {
return ineq.display(out);
}
inline std::ostream& operator<<(std::ostream& out, typename arith_plugin<rational>::ineq const& ineq) {
inline std::ostream& operator<<(std::ostream& out, typename arith_base<rational>::ineq const& ineq) {
return ineq.display(out);
}
}

View file

@ -0,0 +1,113 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
sls_arith_plugin.cpp
Abstract:
Local search dispatch for NIA
Author:
Nikolaj Bjorner (nbjorner) 2023-02-07
--*/
#include "ast/sls/sls_arith_plugin.h"
#include "ast/ast_ll_pp.h"
namespace sls {
void arith_plugin::init_bool_var(sat::bool_var v) {
if (!m_arith) {
try {
m_arith64->init_bool_var(v);
return;
}
catch (overflow_exception&) {
m_arith = alloc(arith_base<rational>, ctx);
return; // initialization happens on check-sat calls
}
}
m_arith->init_bool_var(v);
}
void arith_plugin::register_term(expr* e) {
if (!m_arith) {
try {
m_arith64->register_term(e);
return;
}
catch (overflow_exception&) {
m_arith = alloc(arith_base<rational>, ctx);
}
}
m_arith->register_term(e);
}
expr_ref arith_plugin::get_value(expr* e) {
if (!m_arith) {
try {
return m_arith64->get_value(e);
}
catch (overflow_exception&) {
m_arith = alloc(arith_base<rational>, ctx);
}
}
return m_arith->get_value(e);
}
lbool arith_plugin::check() {
if (!m_arith) {
try {
return m_arith64->check();
}
catch (overflow_exception&) {
m_arith = alloc(arith_base<rational>, ctx);
}
}
return m_arith->check();
}
bool arith_plugin::is_sat() {
if (!m_arith)
return m_arith64->is_sat();
return m_arith->is_sat();
}
void arith_plugin::reset() {
if (!m_arith)
m_arith64->reset();
else
m_arith->reset();
}
void arith_plugin::on_rescale() {
if (!m_arith)
m_arith64->on_rescale();
else
m_arith->on_rescale();
}
void arith_plugin::on_restart() {
if (!m_arith)
m_arith64->on_restart();
else
m_arith->on_restart();
}
std::ostream& arith_plugin::display(std::ostream& out) const {
if (!m_arith)
return m_arith64->display(out);
return m_arith->display(out);
}
void arith_plugin::mk_model(model& mdl) {
if (!m_arith)
m_arith64->mk_model(mdl);
else
m_arith->mk_model(mdl);
}
}

View file

@ -0,0 +1,43 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
sls_arith_plugin.h
Abstract:
Theory plugin for arithmetic local search
Author:
Nikolaj Bjorner (nbjorner) 2024-07-05
--*/
#pragma once
#include "ast/sls/sls_smt.h"
#include "ast/sls/sls_arith_base.h"
namespace sls {
class arith_plugin : public plugin {
scoped_ptr<arith_base<checked_int64<true>>> m_arith64;
scoped_ptr<arith_base<rational>> m_arith;
public:
arith_plugin(context& ctx) : plugin(ctx) { m_arith64 = alloc(arith_base<checked_int64<true>>,ctx); }
~arith_plugin() override {}
void init_bool_var(sat::bool_var v) override;
void register_term(expr* e) override;
expr_ref get_value(expr* e) override;
lbool check() override;
bool is_sat() override;
void reset() override;
void on_rescale() override;
void on_restart() override;
std::ostream& display(std::ostream& out) const override;
void mk_model(model& mdl) override;
};
}

View file

@ -18,7 +18,7 @@ Author:
#include "ast/sls/sls_smt.h"
#include "ast/sls/sls_cc.h"
#include "ast/sls/sls_arith_int.h"
#include "ast/sls/sls_arith_plugin.h"
namespace sls {
@ -42,8 +42,6 @@ namespace sls {
m_atoms.setx(v, e);
m_atom2bool_var.setx(e->get_id(), v, UINT_MAX);
}
typedef arith_plugin<checked_int64<true>> arith64;
void context::reset() {
m_plugins.reset();
@ -55,7 +53,7 @@ namespace sls {
m_visited.reset();
m_subterms.reset();
register_plugin(alloc(cc_plugin, *this));
register_plugin(alloc(arith64, *this));
register_plugin(alloc(arith_plugin, *this));
}
lbool context::check() {