3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

adding basic plugin

This commit is contained in:
Nikolaj Bjorner 2024-07-09 19:47:58 -07:00
parent ef54feec3d
commit 8357ac1cfc
19 changed files with 819 additions and 230 deletions

View file

@ -7,7 +7,8 @@ z3_add_component(ast_sls
sat_ddfw.cpp
sls_arith_base.cpp
sls_arith_plugin.cpp
sls_bv.cpp
sls_basic_plugin.cpp
sls_bv_plugin.cpp
sls_cc.cpp
sls_engine.cpp
sls_smt.cpp

View file

@ -26,7 +26,7 @@ namespace bv {
{}
void sls_eval::init_eval(std::function<bool(expr*, unsigned)> const& eval) {
for (expr* e : terms.subterms()) {
for (expr* e : ctx.subterms()) {
if (!is_app(e))
continue;
app* a = to_app(e);
@ -68,6 +68,7 @@ namespace bv {
m_tmp2.push_back(0);
m_tmp3.push_back(0);
m_tmp4.push_back(0);
m_mul_tmp.push_back(0);
m_zero.push_back(0);
m_one.push_back(0);
m_a.push_back(0);
@ -272,31 +273,46 @@ namespace bv {
break;
}
case OP_BAND: {
SASSERT(e->get_num_args() == 2);
SASSERT(e->get_num_args() >= 2);
auto const& a = wval(e->get_arg(0));
auto const& b = wval(e->get_arg(1));
for (unsigned i = 0; i < a.nw; ++i)
val.eval[i] = a.bits()[i] & b.bits()[i];
for (unsigned j = 2; j < e->get_num_args(); ++j) {
auto const& c = wval(e->get_arg(j));
for (unsigned i = 0; i < a.nw; ++i)
val.eval[i] &= c.bits()[i];
}
break;
}
case OP_BOR: {
SASSERT(e->get_num_args() == 2);
SASSERT(e->get_num_args() >= 2);
auto const& a = wval(e->get_arg(0));
auto const& b = wval(e->get_arg(1));
for (unsigned i = 0; i < a.nw; ++i)
val.eval[i] = a.bits()[i] | b.bits()[i];
for (unsigned j = 2; j < e->get_num_args(); ++j) {
auto const& c = wval(e->get_arg(j));
for (unsigned i = 0; i < a.nw; ++i)
val.eval[i] |= c.bits()[i];
}
break;
}
case OP_BXOR: {
SASSERT(e->get_num_args() == 2);
SASSERT(e->get_num_args() >= 2);
auto const& a = wval(e->get_arg(0));
auto const& b = wval(e->get_arg(1));
for (unsigned i = 0; i < a.nw; ++i)
val.eval[i] = a.bits()[i] ^ b.bits()[i];
for (unsigned j = 2; j < e->get_num_args(); ++j) {
auto const& c = wval(e->get_arg(j));
for (unsigned i = 0; i < a.nw; ++i)
val.eval[i] ^= c.bits()[i];
}
break;
}
case OP_BNAND: {
SASSERT(e->get_num_args() == 2);
VERIFY(e->get_num_args() == 2);
auto const& a = wval(e->get_arg(0));
auto const& b = wval(e->get_arg(1));
for (unsigned i = 0; i < a.nw; ++i)
@ -304,10 +320,15 @@ namespace bv {
break;
}
case OP_BADD: {
SASSERT(e->get_num_args() == 2);
SASSERT(e->get_num_args() >= 2);
auto const& a = wval(e->get_arg(0));
auto const& b = wval(e->get_arg(1));
val.set_add(val.eval, a.bits(), b.bits());
for (unsigned i = 0; i < a.nw; ++i)
val.set_add(val.eval, a.bits(), b.bits());
for (unsigned j = 2; j < e->get_num_args(); ++j) {
auto const& c = wval(e->get_arg(j));
val.set_add(val.eval, val.eval, c.bits());
}
break;
}
case OP_BSUB: {
@ -318,11 +339,14 @@ namespace bv {
break;
}
case OP_BMUL: {
SASSERT(e->get_num_args() == 2);
auto const& a = wval(e->get_arg(0));
auto const& b = wval(e->get_arg(1));
val.set_mul(m_tmp2, a.bits(), b.bits());
val.set(m_tmp2);
for (unsigned i = 0; i < a.nw; ++i)
val.set_mul(val.eval, a.bits(), b.bits());
for (unsigned j = 2; j < e->get_num_args(); ++j) {
auto const& c = wval(e->get_arg(j));
val.set_mul(val.eval, val.eval, c.bits());
}
break;
}
case OP_CONCAT: {
@ -600,17 +624,43 @@ namespace bv {
bool sls_eval::try_repair_bv(app* e, unsigned i) {
switch (e->get_decl_kind()) {
case OP_BAND:
return try_repair_band(eval_value(e), wval(e, i), wval(e, 1 - i));
SASSERT(e->get_num_args() >= 2);
if (e->get_num_args() == 2)
return try_repair_band(eval_value(e), wval(e, i), wval(e, 1 - i));
else
return try_repair_band(e, i);
case OP_BOR:
return try_repair_bor(eval_value(e), wval(e, i), wval(e, 1 - i));
SASSERT(e->get_num_args() >= 2);
if (e->get_num_args() == 2)
return try_repair_bor(eval_value(e), wval(e, i), wval(e, 1 - i));
else
return try_repair_bor(e, i);
case OP_BXOR:
return try_repair_bxor(eval_value(e), wval(e, i), wval(e, 1 - i));
SASSERT(e->get_num_args() >= 2);
if (e->get_num_args() == 2)
return try_repair_bxor(eval_value(e), wval(e, i), wval(e, 1 - i));
else
return try_repair_bxor(e, i);
case OP_BADD:
return try_repair_add(eval_value(e), wval(e, i), wval(e, 1 - i));
SASSERT(e->get_num_args() >= 2);
if (e->get_num_args() == 2)
return try_repair_add(eval_value(e), wval(e, i), wval(e, 1 - i));
else
return try_repair_add(e, i);
case OP_BSUB:
return try_repair_sub(eval_value(e), wval(e, 0), wval(e, 1), i);
case OP_BMUL:
return try_repair_mul(eval_value(e), wval(e, i), wval(e, 1 - i));
SASSERT(e->get_num_args() >= 2);
if (e->get_num_args() == 2)
return try_repair_mul(eval_value(e), wval(e, i), eval_value(to_app(e->get_arg(1 - i))));
else {
auto const& a = wval(e, 0);
auto f = [&](bvect& out, bvval const& c) {
a.set_mul(out, out, c.bits());
};
fold_oper(m_mul_tmp, e, i, f);
return try_repair_mul(eval_value(e), wval(e, i), m_mul_tmp);
}
case OP_BNOT:
return try_repair_bnot(eval_value(e), wval(e, i));
case OP_BNEG:
@ -734,8 +784,9 @@ namespace bv {
case OP_BSDIV_I:
case OP_BSDIV0:
// these are currently compiled to udiv and urem.
UNREACHABLE();
return false;
// there is an equation that enforces equality between the semantics
// of these operators.
return true;
default:
return false;
}
@ -787,6 +838,19 @@ namespace bv {
}
}
void sls_eval::fold_oper(bvect& out, app* t, unsigned i, std::function<void(bvect&, bvval const&)> const& f) {
auto i2 = i == 0 ? 1 : 0;
auto const& c = wval(t->get_arg(i2));
for (unsigned j = 0; j < c.nw; ++j)
out[j] = c.bits()[j];
for (unsigned k = 1; k < t->get_num_args(); ++k) {
if (k == i || k == i2)
continue;
bvval const& c = wval(t->get_arg(k));
f(out, c);
}
}
//
// e = a & b
// e[i] = 1 -> a[i] = 1
@ -800,6 +864,21 @@ namespace bv {
return a.set_repair(random_bool(), m_tmp);
}
bool sls_eval::try_repair_band(app* t, unsigned i) {
bvect const& e = eval_value(t);
auto f = [&](bvect& out, bvval const& c) {
for (unsigned j = 0; j < c.nw; ++j)
out[j] &= c.bits()[j];
};
fold_oper(m_tmp2, t, i, f);
bvval& a = wval(t, i);
for (unsigned j = 0; j < a.nw; ++j)
m_tmp[j] = ~a.fixed[j] & (e[j] | (~m_tmp2[j] & random_bits()));
return a.set_repair(random_bool(), m_tmp);
}
//
// e = a | b
// set a[i] to 1 where b[i] = 0, e[i] = 1
@ -811,6 +890,20 @@ namespace bv {
return a.set_repair(random_bool(), m_tmp);
}
bool sls_eval::try_repair_bor(app* t, unsigned i) {
bvect const& e = eval_value(t);
auto f = [&](bvect& out, bvval const& c) {
for (unsigned j = 0; j < c.nw; ++j)
out[j] |= c.bits()[j];
};
fold_oper(m_tmp2, t, i, f);
bvval& a = wval(t, i);
for (unsigned j = 0; j < a.nw; ++j)
m_tmp[j] = e[i] & (~m_tmp2[i] | random_bits());
return a.set_repair(random_bool(), m_tmp);
}
bool sls_eval::try_repair_bxor(bvect const& e, bvval& a, bvval const& b) {
for (unsigned i = 0; i < a.nw; ++i)
m_tmp[i] = e[i] ^ b.bits()[i];
@ -818,6 +911,23 @@ namespace bv {
}
bool sls_eval::try_repair_bxor(app* t, unsigned i) {
bvect const& e = eval_value(t);
auto f = [&](bvect& out, bvval const& c) {
for (unsigned j = 0; j < c.nw; ++j)
out[j] ^= c.bits()[j];
};
fold_oper(m_tmp2, t, i, f);
bvval& a = wval(t, i);
for (unsigned j = 0; j < a.nw; ++j)
m_tmp[j] = e[i] ^ m_tmp2[i];
return a.set_repair(random_bool(), m_tmp);
}
//
// first try to set a := e - b
// If this fails, set a to a random value
@ -831,6 +941,22 @@ namespace bv {
return a.set_random(m_rand);
}
bool sls_eval::try_repair_add(app* t, unsigned i) {
bvval& a = wval(t, i);
bvect const& e = eval_value(t);
if (m_rand(20) != 0) {
auto f = [&](bvect& out, bvval const& c) {
a.set_add(m_tmp2, m_tmp2, c.bits());
};
fold_oper(m_tmp2, t, i, f);
a.set_sub(m_tmp, e, m_tmp2);
if (a.try_set(m_tmp))
return true;
}
return a.set_random(m_rand);
}
bool sls_eval::try_repair_sub(bvect const& e, bvval& a, bvval & b, unsigned i) {
if (m_rand(20) != 0) {
if (i == 0)
@ -850,11 +976,11 @@ namespace bv {
* e = a*b, then a = e * b^-1
* 8*e = a*(2b), then a = 4e*b^-1
*/
bool sls_eval::try_repair_mul(bvect const& e, bvval& a, bvval const& b) {
unsigned parity_e = b.parity(e);
unsigned parity_b = b.parity(b.bits());
bool sls_eval::try_repair_mul(bvect const& e, bvval& a, bvect const& b) {
unsigned parity_e = a.parity(e);
unsigned parity_b = a.parity(b);
if (b.is_zero(e)) {
if (a.is_zero(e)) {
a.get_variant(m_tmp, m_rand);
if (m_rand(10) != 0)
for (unsigned i = 0; i < b.bw - parity_b; ++i)
@ -862,7 +988,7 @@ namespace bv {
return a.set_repair(random_bool(), m_tmp);
}
if (b.is_zero() || m_rand(20) == 0) {
if (m_rand(20) == 0) {
a.get_variant(m_tmp, m_rand);
return a.set_repair(random_bool(), m_tmp);
}
@ -890,9 +1016,9 @@ namespace bv {
// x*ta + y*tb = x
b.get(y);
b.copy_to(a.nw, y);
if (parity_b > 0) {
b.shift_right(y, parity_b);
a.shift_right(y, parity_b);
#if 0
for (unsigned i = parity_b; i < b.bw; ++i)
y.set(i, m_rand(2) == 0);
@ -937,15 +1063,15 @@ namespace bv {
tb.set_bw(0);
#if Z3DEBUG
b.get(y);
b.copy_to(a.nw, y);
if (parity_b > 0)
b.shift_right(y, parity_b);
a.shift_right(y, parity_b);
a.set_mul(m_tmp, tb, y);
SASSERT(a.is_one(m_tmp));
#endif
e.copy_to(b.nw, m_tmp2);
if (parity_e > 0 && parity_b > 0)
b.shift_right(m_tmp2, std::min(parity_b, parity_e));
a.shift_right(m_tmp2, std::min(parity_b, parity_e));
a.set_mul(m_tmp, tb, m_tmp2);
if (a.set_repair(random_bool(), m_tmp))
return true;
@ -1773,17 +1899,16 @@ namespace bv {
return expr_ref(m);
}
std::ostream& sls_eval::display(std::ostream& out, expr_ref_vector const& es) {
#if 0
auto& terms = sort_assertions(es);
std::ostream& sls_eval::display(std::ostream& out) {
auto& terms = ctx.subterms();
for (expr* e : terms) {
if (!bv.is_bv(e))
continue;
out << e->get_id() << ": " << mk_bounded_pp(e, m, 1) << " ";
if (is_fixed0(e))
out << "f ";
display_value(out, e) << "\n";
}
terms.reset();
#endif
return out;
}

View file

@ -47,7 +47,7 @@ namespace bv {
scoped_ptr_vector<sls_valuation> m_values; // expr-id -> bv valuation
mutable bvect m_tmp, m_tmp2, m_tmp3, m_tmp4, m_zero, m_one, m_minus_one;
mutable bvect m_tmp, m_tmp2, m_tmp3, m_tmp4, m_mul_tmp, m_zero, m_one, m_minus_one;
bvect m_a, m_b, m_nextb, m_nexta, m_aux;
using bvval = sls_valuation;
@ -64,16 +64,21 @@ namespace bv {
//bool bval1_basic(app* e) const;
bool bval1_bv(app* e) const;
void fold_oper(bvect& out, app* e, unsigned i, std::function<void(bvect&, bvval const&)> const& f);
/**
* Repair operations
*/
bool try_repair_bv(app * e, unsigned i);
bool try_repair_band(bvect const& e, bvval& a, bvval const& b);
bool try_repair_band(app* t, unsigned i);
bool try_repair_bor(bvect const& e, bvval& a, bvval const& b);
bool try_repair_bor(app* t, unsigned i);
bool try_repair_add(bvect const& e, bvval& a, bvval const& b);
bool try_repair_add(app* t, unsigned i);
bool try_repair_sub(bvect const& e, bvval& a, bvval& b, unsigned i);
bool try_repair_mul(bvect const& e, bvval& a, bvval const& b);
bool try_repair_mul(bvect const& e, bvval& a, bvect const& b);
bool try_repair_bxor(bvect const& e, bvval& a, bvval const& b);
bool try_repair_bxor(app* t, unsigned i);
bool try_repair_bnot(bvect const& e, bvval& a);
bool try_repair_bneg(bvect const& e, bvval& a);
bool try_repair_ule(bool e, bvval& a, bvval const& b);
@ -125,7 +130,7 @@ namespace bv {
/**
* Retrieve evaluation based on immediate children.
*/
bool bval1(app* e) const;
bool can_eval1(app* e) const;
public:
@ -158,6 +163,8 @@ namespace bv {
bool re_eval_is_correct(app* e);
expr_ref get_value(app* e);
bool bval1(app* e) const;
/*
* Try to invert value of child to repair value assignment of parent.
@ -171,7 +178,7 @@ namespace bv {
bool repair_up(expr* e);
std::ostream& display(std::ostream& out, expr_ref_vector const& es);
std::ostream& display(std::ostream& out);
std::ostream& display_value(std::ostream& out, expr* e);
};

View file

@ -28,7 +28,7 @@ namespace bv {
{}
void sls_fixed::init() {
for (auto e : terms.subterms())
for (auto e : ctx.subterms())
set_fixed(e);
for (auto const& c : ctx.clauses()) {
@ -37,13 +37,12 @@ namespace bv {
auto a = ctx.atom(lit.var());
if (!a)
continue;
a = terms.translated(a);
if (is_app(a))
init_range(to_app(a), lit.sign());
ev.m_fixed.setx(a->get_id(), true, false);
}
}
for (auto e : terms.subterms())
for (auto e : ctx.subterms())
propagate_range_up(e);
}

View file

@ -3,14 +3,11 @@ Copyright (c) 2024 Microsoft Corporation
Module Name:
bv_sls.cpp
bv_sls_terms.cpp
Abstract:
A Stochastic Local Search (SLS) engine
Uses invertibility conditions,
interval annotations
don't care annotations
normalize bit-vector expressions to use only binary operators.
Author:
@ -19,7 +16,7 @@ Author:
--*/
#include "ast/ast_ll_pp.h"
#include "ast/sls/bv_sls.h"
#include "ast/sls/bv_sls_terms.h"
#include "ast/rewriter/bool_rewriter.h"
#include "ast/rewriter/bv_rewriter.h"
@ -29,38 +26,16 @@ namespace bv {
ctx(ctx),
m(ctx.get_manager()),
bv(m),
m_translated(m) {}
m_axioms(m) {}
void sls_terms::init() {
for (auto t : ctx.subterms())
ensure_binary(t);
m_subterms.reset();
expr_fast_mark1 visited;
for (auto t : ctx.subterms())
m_subterms.push_back(translated(t));
for (auto t : m_subterms)
visited.mark(t, true);
for (unsigned i = 0; i < m_subterms.size(); ++i) {
auto t = m_subterms[i];
if (!is_app(t))
continue;
app* a = to_app(t);
for (expr* arg : *a) {
if (visited.is_marked(arg))
continue;
visited.mark(arg, true);
m_subterms.push_back(arg);
}
}
std::stable_sort(m_subterms.begin(), m_subterms.end(),
[](expr* a, expr* b) { return a->get_id() < b->get_id(); });
void sls_terms::register_term(expr* e) {
auto r = ensure_binary(e);
if (r != e)
m_axioms.push_back(m.mk_eq(e, r));
}
void sls_terms::ensure_binary(expr* e) {
if (m_translated.get(e->get_id(), nullptr))
return;
expr_ref sls_terms::ensure_binary(expr* e) {
app* a = to_app(e);
auto arg = [&](unsigned i) {
return a->get_arg(i);
@ -72,22 +47,7 @@ namespace bv {
for (unsigned i = 1; i < num_args; ++i)\
r = oper(r, arg(i)); \
if (bv.is_bv_and(e)) {
FOLD_OP(bv.mk_bv_and);
}
else if (bv.is_bv_or(e)) {
FOLD_OP(bv.mk_bv_or);
}
else if (bv.is_bv_xor(e)) {
FOLD_OP(bv.mk_bv_xor);
}
else if (bv.is_bv_add(e)) {
FOLD_OP(bv.mk_bv_add);
}
else if (bv.is_bv_mul(e)) {
FOLD_OP(bv.mk_bv_mul);
}
else if (bv.is_concat(e)) {
if (bv.is_concat(e)) {
FOLD_OP(bv.mk_concat);
}
else if (bv.is_bv_sdiv(e) || bv.is_bv_sdiv0(e) || bv.is_bv_sdivi(e)) {
@ -101,7 +61,7 @@ namespace bv {
}
else
r = e;
m_translated.setx(e->get_id(), r);
return r;
}
expr_ref sls_terms::mk_sdiv(expr* x, expr* y) {
@ -118,14 +78,16 @@ namespace bv {
unsigned sz = bv.get_bv_size(x);
rational N = rational::power_of_two(sz);
expr_ref z(bv.mk_zero(sz), m);
expr* signx = bvr.mk_ule(bv.mk_numeral(N / 2, sz), x);
expr* signy = bvr.mk_ule(bv.mk_numeral(N / 2, sz), y);
expr* absx = br.mk_ite(signx, bvr.mk_bv_neg(x), x);
expr* absy = br.mk_ite(signy, bvr.mk_bv_neg(y), y);
expr* d = bv.mk_bv_udiv(absx, absy);
expr_ref r(br.mk_ite(br.mk_eq(signx, signy), d, bvr.mk_bv_neg(d)), m);
expr_ref o(bv.mk_one(sz), m);
expr_ref n1(bv.mk_numeral(N - 1, sz), m);
expr_ref signx = bvr.mk_ule(bv.mk_numeral(N / 2, sz), x);
expr_ref signy = bvr.mk_ule(bv.mk_numeral(N / 2, sz), y);
expr_ref absx = br.mk_ite(signx, bvr.mk_bv_neg(x), x);
expr_ref absy = br.mk_ite(signy, bvr.mk_bv_neg(y), y);
expr_ref d = expr_ref(bv.mk_bv_udiv(absx, absy), m);
expr_ref r = br.mk_ite(br.mk_eq(signx, signy), d, bvr.mk_bv_neg(d));
r = br.mk_ite(br.mk_eq(z, y),
br.mk_ite(signx, bv.mk_one(sz), bv.mk_numeral(N - 1, sz)),
br.mk_ite(signx, o, n1),
br.mk_ite(br.mk_eq(x, z), z, r));
return r;
}
@ -142,9 +104,9 @@ namespace bv {
bv_rewriter bvr(m);
unsigned sz = bv.get_bv_size(x);
expr_ref z(bv.mk_zero(sz), m);
expr_ref abs_x(br.mk_ite(bvr.mk_sle(z, x), x, bvr.mk_bv_neg(x)), m);
expr_ref abs_y(br.mk_ite(bvr.mk_sle(z, y), y, bvr.mk_bv_neg(y)), m);
expr_ref u(bvr.mk_bv_urem(abs_x, abs_y), m);
expr_ref abs_x = br.mk_ite(bvr.mk_sle(z, x), x, bvr.mk_bv_neg(x));
expr_ref abs_y = br.mk_ite(bvr.mk_sle(z, y), y, bvr.mk_bv_neg(y));
expr_ref u = bvr.mk_bv_urem(abs_x, abs_y);
expr_ref r(m);
r = br.mk_ite(br.mk_eq(u, z), z,
br.mk_ite(br.mk_eq(y, z), x,

View file

@ -32,10 +32,9 @@ namespace bv {
sls::context& ctx;
ast_manager& m;
bv_util bv;
expr_ref_vector m_translated;
ptr_vector<expr> m_subterms;
expr_ref_vector m_axioms;
void ensure_binary(expr* e);
expr_ref ensure_binary(expr* e);
expr_ref mk_sdiv(expr* x, expr* y);
expr_ref mk_smod(expr* x, expr* y);
@ -44,14 +43,8 @@ namespace bv {
public:
sls_terms(sls::context& ctx);
/**
* Initialize structures: assertions, parents, terms
*/
void init();
expr* translated(expr* e) const { return m_translated.get(e->get_id(), nullptr); }
ptr_vector<expr> const& subterms() const { return m_subterms; }
void register_term(expr* e);
expr_ref_vector& axioms() { return m_axioms; }
};
}

View file

@ -27,13 +27,6 @@ namespace sls {
m_fid = a.get_family_id();
}
template<typename num_t>
void arith_base<num_t>::reset() {
m_bool_vars.reset();
m_vars.reset();
m_expr2var.reset();
}
template<typename num_t>
void arith_base<num_t>::save_best_values() {
for (auto& v : m_vars)
@ -1070,11 +1063,6 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::mk_model(model& mdl) {
for (auto const& v : m_vars) {
expr* e = v.m_expr;
if (is_uninterp_const(e))
mdl.register_decl(to_app(e)->get_decl(), get_value(e));
}
}
}

View file

@ -189,7 +189,6 @@ namespace sls {
expr_ref get_value(expr* e) override;
lbool check() override;
bool is_sat() override;
void reset() override;
void on_rescale() override;
void on_restart() override;
std::ostream& display(std::ostream& out) const override;

View file

@ -86,13 +86,6 @@ namespace sls {
return m_arith64->is_sat();
return m_arith->is_sat();
}
void arith_plugin::reset() {
if (m_arith)
m_arith->reset();
else
m_arith64->reset();
m_shared.reset();
}
void arith_plugin::on_rescale() {
if (m_arith)

View file

@ -36,7 +36,6 @@ namespace sls {
expr_ref get_value(expr* e) override;
lbool check() override;
bool is_sat() override;
void reset() override;
void on_rescale() override;
void on_restart() override;

View file

@ -0,0 +1,313 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_basic_plugin.cpp
Abstract:
Local search dispatch for Boolean connectives
Author:
Nikolaj Bjorner (nbjorner) 2024-07-07
--*/
#include "ast/sls/sls_basic_plugin.h"
#include "ast/ast_ll_pp.h"
namespace sls {
expr_ref basic_plugin::get_value(expr* e) {
return expr_ref(m.mk_bool_val(bval0(e)), m);
}
lbool basic_plugin::check() {
init();
for (sat::literal lit : ctx.root_literals())
repair_literal(lit);
repair_defs_and_updates();
return ctx.unsat().empty() ? l_true : l_undef;
}
void basic_plugin::init() {
m_repair_down = UINT_MAX;
m_repair_roots.reset();
m_repair_up.reset();
if (m_initialized)
return;
m_initialized = true;
for (auto t : ctx.subterms())
if (is_app(t) && m.is_bool(t) && to_app(t)->get_family_id() == basic_family_id)
m_values.setx(t->get_id(), bval1(to_app(t)), false);
}
bool basic_plugin::is_sat() {
for (auto t : ctx.subterms())
if (is_app(t) &&
m.is_bool(t) &&
to_app(t)->get_family_id() == basic_family_id &&
bval0(t) != bval1(to_app(t)))
return false;
return true;
}
std::ostream& basic_plugin::display(std::ostream& out) const {
for (auto t : ctx.subterms())
if (is_app(t) && m.is_bool(t) && to_app(t)->get_family_id() == basic_family_id)
out << mk_bounded_pp(t, m) << " " << bval0(t) << " ~ " << bval1(to_app(t)) << "\n";
return out;
}
void basic_plugin::set_value(expr* e, expr* v) {
if (!m.is_bool(e))
return;
SASSERT(m.is_bool(v));
SASSERT(m.is_true(v) || m.is_false(v));
if (bval0(e) != m.is_true(v))
return;
set_value(e, m.is_true(v));
m_repair_roots.insert(e->get_id());
}
bool basic_plugin::bval1(app* e) const {
SASSERT(m.is_bool(e));
SASSERT(e->get_family_id() == basic_family_id);
auto id = e->get_id();
switch (e->get_decl_kind()) {
case OP_TRUE:
return true;
case OP_FALSE:
return false;
case OP_AND:
return all_of(*to_app(e), [&](expr* arg) { return bval0(arg); });
case OP_OR:
return any_of(*to_app(e), [&](expr* arg) { return bval0(arg); });
case OP_NOT:
return !bval0(e->get_arg(0));
case OP_XOR: {
bool r = false;
for (auto* arg : *to_app(e))
r ^= bval0(arg);
return r;
}
case OP_IMPLIES: {
auto a = e->get_arg(0);
auto b = e->get_arg(1);
return !bval0(a) || bval0(b);
}
case OP_ITE: {
auto c = bval0(e->get_arg(0));
return bval0(c ? e->get_arg(1) : e->get_arg(2));
}
case OP_EQ: {
auto a = e->get_arg(0);
auto b = e->get_arg(1);
if (m.is_bool(a))
return bval0(a) == bval0(b);
return ctx.get_value(a) == ctx.get_value(b);
}
case OP_DISTINCT: {
for (unsigned i = 0; i < e->get_num_args(); ++i)
for (unsigned j = i + 1; j < e->get_num_args(); ++j)
if (ctx.get_value(e->get_arg(i)) == ctx.get_value(e->get_arg(j)))
return false;
return true;
}
default:
verbose_stream() << mk_bounded_pp(e, m) << "\n";
UNREACHABLE();
break;
}
UNREACHABLE();
return false;
}
bool basic_plugin::bval0(expr* e) const {
SASSERT(m.is_bool(e));
sat::bool_var v = ctx.atom2bool_var(e);
if (v == sat::null_bool_var)
return m_values.get(e->get_id(), false);
else
return ctx.is_true(sat::literal(v, false));
}
bool basic_plugin::try_repair(app* e, unsigned i) {
switch (e->get_decl_kind()) {
case OP_AND:
return try_repair_and_or(e, i);
case OP_OR:
return try_repair_and_or(e, i);
case OP_NOT:
return try_repair_not(e);
case OP_FALSE:
return false;
case OP_TRUE:
return false;
case OP_EQ:
return try_repair_eq(e, i);
case OP_IMPLIES:
return try_repair_implies(e, i);
case OP_XOR:
return try_repair_xor(e, i);
case OP_ITE:
return try_repair_ite(e, i);
case OP_DISTINCT:
NOT_IMPLEMENTED_YET();
return false;
default:
UNREACHABLE();
return false;
}
}
bool basic_plugin::try_repair_and_or(app* e, unsigned i) {
auto b = bval0(e);
auto child = e->get_arg(i);
if (b == bval0(child))
return false;
set_value(child, b);
return true;
}
bool basic_plugin::try_repair_not(app* e) {
auto child = e->get_arg(0);
set_value(child, !bval0(e));
return true;
}
bool basic_plugin::try_repair_eq(app* e, unsigned i) {
auto child = e->get_arg(i);
auto sibling = e->get_arg(1 - i);
if (!m.is_bool(child))
return false;
set_value(child, bval0(e) == bval0(sibling));
return true;
}
bool basic_plugin::try_repair_xor(app* e, unsigned i) {
bool ev = bval0(e);
bool bv = bval0(e->get_arg(1 - i));
auto child = e->get_arg(i);
set_value(child, ev != bv);
return true;
}
bool basic_plugin::try_repair_ite(app* e, unsigned i) {
auto child = e->get_arg(i);
bool c = bval0(e->get_arg(0));
if (i == 0) {
set_value(child, !c);
return true;
}
if (c != (i == 1))
return false;
if (m.is_bool(e)) {
set_value(child, bval0(e));
return true;
}
return false;
}
bool basic_plugin::try_repair_implies(app* e, unsigned i) {
auto child = e->get_arg(i);
bool ev = bval0(e);
bool av = bval0(child);
bool bv = bval0(e->get_arg(1 - i));
if (i == 0) {
if (ev == (!av || bv))
return false;
}
else if (ev != (!bv || av))
return false;
set_value(child, ev);
return true;
}
bool basic_plugin::repair_up(expr* e) {
if (!m.is_bool(e))
return false;
auto b = bval1(to_app(e));
set_value(e, b);
return true;
}
void basic_plugin::repair_down(app* e) {
SASSERT(m.is_bool(e));
unsigned n = e->get_num_args();
if (n == 0 || e->get_family_id() != m.get_basic_family_id()) {
for (auto p : ctx.parents(e))
m_repair_up.insert(p->get_id());
ctx.set_value(e, m.mk_bool_val(bval0(e)));
return;
}
if (bval0(e) == bval1(e))
return;
unsigned s = ctx.rand(n);
for (unsigned i = 0; i < n; ++i) {
auto j = (i + s) % n;
if (try_repair(e, j)) {
m_repair_down = e->get_arg(j)->get_id();
return;
}
}
m_repair_up.insert(e->get_id());
}
void basic_plugin::repair_defs_and_updates() {
if (!m_repair_roots.empty() ||
!m_repair_up.empty() ||
m_repair_down != UINT_MAX) {
while (m_repair_down != UINT_MAX) {
auto e = ctx.term(m_repair_down);
repair_down(to_app(e));
}
while (!m_repair_up.empty()) {
auto id = m_repair_up.elem_at(rand() % m_repair_up.size());
auto e = ctx.term(id);
m_repair_up.remove(id);
repair_up(to_app(e));
}
if (!m_repair_roots.empty()) {
auto id = m_repair_roots.elem_at(rand() % m_repair_roots.size());
m_repair_roots.remove(id);
m_repair_down = id;
}
}
}
void basic_plugin::set_value(expr* e, bool b) {
sat::bool_var v = ctx.atom2bool_var(e);
if (v == sat::null_bool_var) {
if (m_values.get(e->get_id(), b) != b) {
m_values.set(e->get_id(), b);
ctx.set_value(e, m.mk_bool_val(b));
}
}
else if (ctx.is_true(sat::literal(v, false)) != b) {
ctx.flip(v);
ctx.set_value(e, m.mk_bool_val(b));
}
}
void basic_plugin::repair_literal(sat::literal lit) {
if (!ctx.is_true(lit))
return;
auto a = ctx.atom(lit.var());
if (!a || !is_app(a))
return;
if (to_app(a)->get_family_id() != basic_family_id)
return;
if (bval1(to_app(a)) != bval0(to_app(a)))
m_repair_roots.insert(a->get_id());
}
}

View file

@ -0,0 +1,61 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
sls_basic_plugin.h
Author:
Nikolaj Bjorner (nbjorner) 2024-07-05
--*/
#pragma once
#include "ast/sls/sls_smt.h"
namespace sls {
class basic_plugin : public plugin {
bool_vector m_values;
indexed_uint_set m_repair_up, m_repair_roots;
unsigned m_repair_down = UINT_MAX;
bool m_initialized = false;
void init();
bool bval1(app* e) const;
bool bval0(expr* e) const;
bool repair_up(expr* e);
bool try_repair(app* e, unsigned i);
bool try_repair_and_or(app* e, unsigned i);
bool try_repair_not(app* e);
bool try_repair_eq(app* e, unsigned i);
bool try_repair_xor(app* e, unsigned i);
bool try_repair_ite(app* e, unsigned i);
bool try_repair_implies(app* e, unsigned i);
void set_value(expr* e, bool b);
void repair_down(app* e);
void repair_defs_and_updates();
void repair_literal(sat::literal lit);
public:
basic_plugin(context& ctx) :
plugin(ctx) {
}
~basic_plugin() override {}
void init_bool_var(sat::bool_var v) override {}
void register_term(expr* e) override {}
expr_ref get_value(expr* e) override;
lbool check() override;
bool is_sat() override;
void on_rescale() override {}
void on_restart() override {}
std::ostream& display(std::ostream& out) const override;
void mk_model(model& mdl) override {}
void set_shared(expr* e) override {}
void set_value(expr* e, expr* v) override;
};
}

View file

@ -1,63 +0,0 @@
#include "ast/sls/sls_bv.h"
namespace sls {
bv_plugin::bv_plugin(context& ctx):
plugin(ctx),
bv(m),
m_terms(ctx),
m_eval(m_terms, ctx)
{}
void bv_plugin::init_bool_var(sat::bool_var v) {
}
void bv_plugin::register_term(expr* e) {
}
expr_ref bv_plugin::get_value(expr* e) {
return expr_ref(m);
}
lbool bv_plugin::check() {
return l_undef;
}
bool bv_plugin::is_sat() {
return false;
}
void bv_plugin::reset() {
}
void bv_plugin::on_rescale() {
}
void bv_plugin::on_restart() {
}
std::ostream& bv_plugin::display(std::ostream& out) const {
return out;
}
void bv_plugin::mk_model(model& mdl) {
}
void bv_plugin::set_shared(expr* e) {
}
void bv_plugin::set_value(expr* e, expr* v) {
}
std::pair<bool, app*> bv_plugin::next_to_repair() {
return { false, nullptr };
}
}

View file

@ -0,0 +1,210 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_bv_plugin.cpp
Abstract:
Theory plugin for bit-vector local search
Author:
Nikolaj Bjorner (nbjorner) 2024-07-06
--*/
#include "ast/sls/sls_bv_plugin.h"
#include "ast/ast_ll_pp.h"
namespace sls {
bv_plugin::bv_plugin(context& ctx):
plugin(ctx),
bv(m),
m_terms(ctx),
m_eval(m_terms, ctx)
{}
void bv_plugin::register_term(expr* e) {
m_terms.register_term(e);
}
expr_ref bv_plugin::get_value(expr* e) {
return expr_ref(m);
}
lbool bv_plugin::check() {
if (!m_initialized) {
auto eval = [&](expr* e, unsigned idx) { return false; };
m_eval.init_eval(eval);
m_initialized = true;
}
auto& axioms = m_terms.axioms();
if (!axioms.empty()) {
for (auto* e : axioms)
ctx.add_constraint(e);
axioms.reset();
return l_undef;
}
// repair each root literal
for (sat::literal lit : ctx.root_literals())
repair_literal(lit);
repair_defs_and_updates();
// update literal assignment based on current model
for (unsigned v = 0; v < ctx.num_bool_vars(); ++v)
init_bool_var_assignment(v);
return ctx.unsat().empty() ? l_true : l_undef;
}
void bv_plugin::repair_literal(sat::literal lit) {
if (!ctx.is_true(lit))
return;
auto a = ctx.atom(lit.var());
if (!a || !is_app(a))
return;
if (to_app(a)->get_family_id() != bv.get_family_id())
return;
if (!m_eval.eval_is_correct(to_app(a)))
m_repair_roots.insert(a->get_id());
}
void bv_plugin::repair_defs_and_updates() {
if (!m_repair_roots.empty() ||
!m_repair_up.empty() ||
m_repair_down != UINT_MAX) {
while (m_repair_down != UINT_MAX) {
auto e = ctx.term(m_repair_down);
try_repair_down(to_app(e));
}
while (!m_repair_up.empty()) {
auto id = m_repair_up.elem_at(rand() % m_repair_up.size());
auto e = ctx.term(id);
m_repair_up.remove(id);
try_repair_up(to_app(e));
}
if (!m_repair_roots.empty()) {
auto id = m_repair_roots.elem_at(rand() % m_repair_roots.size());
m_repair_roots.remove(id);
m_repair_down = id;
}
}
}
void bv_plugin::init_bool_var_assignment(sat::bool_var v) {
auto a = ctx.atom(v);
if (!a || !is_app(a))
return;
if (to_app(a)->get_family_id() != bv.get_family_id())
return;
bool is_true = m_eval.bval1(to_app(a));
if (is_true != ctx.is_true(sat::literal(v, false)))
ctx.flip(v);
}
bool bv_plugin::is_sat() {
return false;
}
std::ostream& bv_plugin::display(std::ostream& out) const {
// m_eval.display(out);
return out;
}
void bv_plugin::set_shared(expr* e) {
}
void bv_plugin::set_value(expr* e, expr* v) {
}
void bv_plugin::try_repair_down(app* e) {
unsigned n = e->get_num_args();
if (n == 0 || m_eval.eval_is_correct(e)) {
m_eval.commit_eval(e);
if (!m.is_bool(e))
for (auto p : ctx.parents(e))
m_repair_up.insert(p->get_id());
return;
}
if (m.is_bool(e)) {
NOT_IMPLEMENTED_YET();
return;
}
if (n == 2) {
auto d1 = get_depth(e->get_arg(0));
auto d2 = get_depth(e->get_arg(1));
unsigned s = ctx.rand(d1 + d2 + 2);
if (s <= d1 && m_eval.try_repair(e, 0)) {
set_repair_down(e->get_arg(0));
return;
}
if (m_eval.try_repair(e, 1)) {
set_repair_down(e->get_arg(1));
return;
}
if (m_eval.try_repair(e, 0)) {
set_repair_down(e->get_arg(0));
return;
}
}
else {
unsigned s = ctx.rand(n);
for (unsigned i = 0; i < n; ++i) {
auto j = (i + s) % n;
if (m_eval.try_repair(e, j)) {
set_repair_down(e->get_arg(j));
return;
}
}
}
IF_VERBOSE(3, verbose_stream() << "init-repair " << mk_bounded_pp(e, m) << "\n");
// repair was not successful, so reset the state to find a different way to repair
m_repair_down = UINT_MAX;
}
void bv_plugin::try_repair_up(app* e) {
if (m.is_bool(e))
;
else if (m_eval.repair_up(e)) {
if (!m_eval.eval_is_correct(e)) {
verbose_stream() << "incorrect eval #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n";
}
SASSERT(m_eval.eval_is_correct(e));
for (auto p : ctx.parents(e))
m_repair_up.insert(p->get_id());
}
else if (ctx.rand(10) != 0) {
IF_VERBOSE(2, verbose_stream() << "repair-up "; trace_repair(true, e));
m_eval.set_random(e);
m_repair_roots.insert(e->get_id());
}
}
std::ostream& bv_plugin::trace_repair(bool down, expr* e) {
verbose_stream() << (down ? "d #" : "u #")
<< e->get_id() << ": "
<< mk_bounded_pp(e, m, 1) << " ";
return m_eval.display_value(verbose_stream(), e) << "\n";
}
void bv_plugin::trace() {
IF_VERBOSE(2, verbose_stream()
<< "(bvsls :restarts " << m_stats.m_restarts
<< " :repair-up " << m_repair_up.size()
<< " :repair-roots " << m_repair_roots.size() << ")\n");
}
}

View file

@ -3,7 +3,7 @@ Copyright (c) 2020 Microsoft Corporation
Module Name:
sls_bv.h
sls_bv_plugin.h
Abstract:
@ -31,23 +31,34 @@ namespace sls {
indexed_uint_set m_repair_up, m_repair_roots;
unsigned m_repair_down = UINT_MAX;
bool m_initialized = false;
std::pair<bool, app*> next_to_repair();
void repair_literal(sat::literal lit);
void repair_defs_and_updates();
void init_bool_var_assignment(sat::bool_var v);
void try_repair_down(app* e);
void set_repair_down(expr* e) { m_repair_down = e->get_id(); }
void try_repair_up(app* e);
std::ostream& bv_plugin::trace_repair(bool down, expr* e);
void trace();
public:
bv_plugin(context& ctx);
~bv_plugin() override {}
void init_bool_var(sat::bool_var v) override;
void init_bool_var(sat::bool_var v) override {}
void register_term(expr* e) override;
expr_ref get_value(expr* e) override;
lbool check() override;
bool is_sat() override;
void reset() override;
void on_rescale() override;
void on_restart() override;
void on_rescale() override {}
void on_restart() override {}
std::ostream& display(std::ostream& out) const override;
void mk_model(model& mdl) override;
void mk_model(model& mdl) override {}
void set_shared(expr* e) override;
void set_value(expr* e, expr* v) override;
};

View file

@ -34,10 +34,6 @@ namespace sls {
UNREACHABLE();
return expr_ref(m);
}
void cc_plugin::reset() {
m_app.reset();
}
void cc_plugin::register_term(expr* e) {
if (!is_app(e))

View file

@ -41,7 +41,6 @@ namespace sls {
expr_ref get_value(expr* e) override;
lbool check() override;
bool is_sat() override;
void reset() override;
void register_term(expr* e) override;
void init_bool_var(sat::bool_var v) override {}
std::ostream& display(std::ostream& out) const override;

View file

@ -19,18 +19,22 @@ Author:
#include "ast/sls/sls_smt.h"
#include "ast/sls/sls_cc.h"
#include "ast/sls/sls_arith_plugin.h"
#include "ast/sls/sls_bv_plugin.h"
#include "ast/sls/sls_basic_plugin.h"
namespace sls {
plugin::plugin(context& c):
ctx(c),
m(c.get_manager()) {
reset();
}
context::context(ast_manager& m, sat_solver_context& s) :
m(m), s(s), m_atoms(m), m_allterms(m) {
reset();
register_plugin(alloc(cc_plugin, *this));
register_plugin(alloc(arith_plugin, *this));
register_plugin(alloc(bv_plugin, *this));
register_plugin(alloc(basic_plugin, *this));
}
void context::register_plugin(plugin* p) {
@ -43,19 +47,6 @@ namespace sls {
m_atom2bool_var.setx(e->get_id(), v, sat::null_bool_var);
}
void context::reset() {
m_plugins.reset();
m_atoms.reset();
m_atom2bool_var.reset();
m_initialized = false;
m_parents.reset();
m_relevant.reset();
m_visited.reset();
m_allterms.reset();
register_plugin(alloc(cc_plugin, *this));
register_plugin(alloc(arith_plugin, *this));
}
lbool context::check() {
//
// initialize data-structures if not done before.
@ -75,6 +66,9 @@ namespace sls {
return l_undef;
if (all_of(m_plugins, [&](auto* p) { return !p || p->is_sat(); })) {
model_ref mdl = alloc(model, m);
for (expr* e : subterms())
if (is_uninterp_const(e))
mdl->register_decl(to_app(e)->get_decl(), get_value(e));
for (auto p : m_plugins)
if (p)
p->mk_model(*mdl);
@ -99,10 +93,6 @@ namespace sls {
}
expr_ref context::get_value(expr* e) {
if (m.is_bool(e)) {
auto v = m_atom2bool_var[e->get_id()];
return expr_ref(is_true(sat::literal(v, false)) ? m.mk_true() : m.mk_false(), m);
}
sort* s = e->get_sort();
auto fid = s->get_family_id();
auto p = m_plugins.get(fid, nullptr);

View file

@ -41,7 +41,6 @@ namespace sls {
virtual void init_bool_var(sat::bool_var v) = 0;
virtual lbool check() = 0;
virtual bool is_sat() = 0;
virtual void reset() {};
virtual void on_rescale() {};
virtual void on_restart() {};
virtual std::ostream& display(std::ostream& out) const = 0;
@ -98,7 +97,7 @@ namespace sls {
// Between SAT/SMT solver and context.
void register_atom(sat::bool_var v, expr* e);
void reset();
// void reset();
lbool check();
// expose sat_solver to plugins
@ -109,6 +108,8 @@ namespace sls {
unsigned num_bool_vars() const { return s.num_vars(); }
bool is_true(sat::literal lit) { return s.is_true(lit); }
expr* atom(sat::bool_var v) { return m_atoms.get(v, nullptr); }
expr* term(unsigned id) const { return m_allterms.get(id); }
sat::bool_var atom2bool_var(expr* e) const { return m_atom2bool_var.get(e->get_id(), sat::null_bool_var); }
void flip(sat::bool_var v) { s.flip(v); }
double reward(sat::bool_var v) { return s.reward(v); }
indexed_uint_set const& unsat() const { return s.unsat(); }
@ -118,6 +119,11 @@ namespace sls {
void reinit_relevant();
ptr_vector<expr> const& parents(expr* e) {
m_parents.reserve(e->get_id() + 1);
return m_parents[e->get_id()];
}
// Between plugin solvers
expr_ref get_value(expr* e);
bool is_true(expr* e);