3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 18:31:49 +00:00

remove xor solver, tune dt_solver for enumeration case

This commit is contained in:
Nikolaj Bjorner 2021-02-27 17:17:39 -08:00
parent 830f314a3f
commit fb8e2e444e
10 changed files with 67 additions and 332 deletions

View file

@ -14,7 +14,6 @@ z3_add_component(sat_smt
ba_internalize.cpp
ba_pb.cpp
ba_solver.cpp
ba_xor.cpp
bv_ackerman.cpp
bv_delay_internalize.cpp
bv_internalize.cpp
@ -41,7 +40,6 @@ z3_add_component(sat_smt
sat_dual_solver.cpp
sat_th.cpp
user_solver.cpp
xor_solver.cpp
COMPONENT_DEPENDENCIES
sat
ast

View file

@ -24,13 +24,11 @@ namespace ba {
enum class tag_t {
card_t,
pb_t,
xr_t
pb_t
};
class card;
class pb;
class xr;
class pb_base;
inline lbool value(sat::model const& m, literal l) { return l.sign() ? ~m[l.var()] : m[l.var()]; }
@ -82,14 +80,11 @@ namespace ba {
size_t obj_size() const { return m_obj_size; }
card& to_card();
pb& to_pb();
xr& to_xr();
card const& to_card() const;
pb const& to_pb() const;
xr const& to_xr() const;
pb_base const& to_pb_base() const;
bool is_card() const { return m_tag == tag_t::card_t; }
bool is_pb() const { return m_tag == tag_t::pb_t; }
bool is_xr() const { return m_tag == tag_t::xr_t; }
bool is_watched(solver_interface const& s, literal lit) const;
void unwatch_literal(solver_interface& s, literal lit);

View file

@ -29,36 +29,10 @@ namespace sat {
flet<bool> _redundant(m_is_redundant, redundant);
if (m_pb.is_pb(e))
return internalize_pb(e, sign, root);
if (m.is_xor(e))
return internalize_xor(e, sign, root);
UNREACHABLE();
return null_literal;
}
literal ba_solver::internalize_xor(expr* e, bool sign, bool root) {
sat::literal_vector lits;
sat::bool_var v = s().add_var(true);
lits.push_back(literal(v, true));
auto add_expr = [&](expr* a) {
literal lit = si.internalize(a, m_is_redundant);
s().set_external(lit.var());
lits.push_back(lit);
};
expr* e1 = nullptr;
while (m.is_iff(e, e1, e))
add_expr(e1);
add_expr(e);
// ensure that = is converted to xor
for (unsigned i = 1; i + 1 < lits.size(); ++i) {
lits[i].neg();
}
add_xr(lits, m_is_redundant);
auto* aig = s().get_cut_simplifier();
if (aig) aig->add_xor(~lits.back(), lits.size() - 1, lits.c_ptr() + 1);
sat::literal lit(v, sign);
return literal(v, sign);
}
literal ba_solver::internalize_pb(expr* e, bool sign, bool root) {
SASSERT(m_pb.is_pb(e));
app* t = to_app(e);
@ -313,19 +287,6 @@ namespace sat {
return fml;
}
expr_ref ba_solver::get_xor(std::function<expr_ref(sat::literal)>& lit2expr, xr const& x) {
ptr_buffer<expr> lits;
for (sat::literal l : x) {
lits.push_back(lit2expr(l));
}
expr_ref fml(m.mk_xor(x.size(), lits.c_ptr()), m);
if (x.lit() != sat::null_literal) {
fml = m.mk_eq(lit2expr(x.lit()), fml);
}
return fml;
}
bool ba_solver::to_formulas(std::function<expr_ref(sat::literal)>& l2e, expr_ref_vector& fmls) {
for (auto* c : constraints()) {
switch (c->tag()) {
@ -335,9 +296,6 @@ namespace sat {
case ba::tag_t::pb_t:
fmls.push_back(get_pb(l2e, c->to_pb()));
break;
case ba::tag_t::xr_t:
fmls.push_back(get_xor(l2e, c->to_xr()));
break;
}
}
return true;

View file

@ -7,7 +7,7 @@ Module Name:
Abstract:
Extension for cardinality and xor reasoning.
Extension for cardinality reasoning.
Author:
@ -21,7 +21,6 @@ Author:
#include "sat/smt/ba_solver.h"
#include "sat/smt/euf_solver.h"
#include "sat/sat_simplifier_params.hpp"
#include "sat/sat_xor_finder.h"
namespace sat {
@ -40,7 +39,6 @@ namespace sat {
UNREACHABLE();
}
SASSERT(validate_conflict(c));
if (c.is_xr() && value(lit) == l_true) lit.neg();
SASSERT(value(lit) == l_false);
set_conflict(justification::mk_ext_justification(s().scope_lvl(), c.cindex()), ~lit);
SASSERT(inconsistent());
@ -749,16 +747,6 @@ namespace sat {
for (literal l : m_lemma) process_antecedent(~l, offset);
break;
}
case ba::tag_t::xr_t: {
// jus.push_back(js);
m_lemma.reset();
inc_bound(offset);
inc_coeff(consequent, offset);
get_xr_antecedents(consequent, idx, js, m_lemma);
for (literal l : m_lemma)
process_antecedent(~l, offset);
break;
}
default:
UNREACHABLE();
break;
@ -1444,7 +1432,6 @@ namespace sat {
switch (c.tag()) {
case ba::tag_t::card_t: return add_assign(c.to_card(), l);
case ba::tag_t::pb_t: return add_assign(c.to_pb(), l);
case ba::tag_t::xr_t: return add_assign(c.to_xr(), l);
}
UNREACHABLE();
return l_undef;
@ -1683,7 +1670,6 @@ namespace sat {
switch (c.tag()) {
case ba::tag_t::card_t: get_antecedents(l, c.to_card(), r); break;
case ba::tag_t::pb_t: get_antecedents(l, c.to_pb(), r); break;
case ba::tag_t::xr_t: get_antecedents(l, c.to_xr(), r); break;
default: UNREACHABLE(); break;
}
if (get_config().m_drat && m_solver && !probing) {
@ -2036,9 +2022,6 @@ namespace sat {
case ba::tag_t::pb_t:
simplify(c.to_pb());
break;
case ba::tag_t::xr_t:
simplify(c.to_xr());
break;
default:
UNREACHABLE();
}
@ -2064,10 +2047,6 @@ namespace sat {
for (unsigned sz = m_constraints.size(), i = 0; i < sz; ++i) subsumption(*m_constraints[i]);
for (unsigned sz = m_learned.size(), i = 0; i < sz; ++i) subsumption(*m_learned[i]);
unit_strengthen();
if (s().get_config().m_xor_solver) {
extract_xor();
merge_xor();
}
cleanup_clauses();
cleanup_constraints();
update_pure();
@ -2290,10 +2269,6 @@ namespace sat {
case ba::tag_t::pb_t:
recompile(c.to_pb());
break;
case ba::tag_t::xr_t:
add_xr(c.to_xr().literals(), c.learned());
remove_constraint(c, "recompile xor");
break;
default:
UNREACHABLE();
}
@ -2495,7 +2470,6 @@ namespace sat {
switch (c.tag()) {
case ba::tag_t::card_t: split_root(c.to_card()); break;
case ba::tag_t::pb_t: split_root(c.to_pb()); break;
case ba::tag_t::xr_t: NOT_IMPLEMENTED_YET(); break;
}
}
@ -2595,12 +2569,6 @@ namespace sat {
if (lit != null_literal) m_cnstr_use_list[(~l).index()].push_back(cp);
}
break;
case ba::tag_t::xr_t:
for (literal l : cp->to_xr()) {
m_cnstr_use_list[l.index()].push_back(cp);
m_cnstr_use_list[(~l).index()].push_back(cp);
}
break;
}
}
}
@ -3188,13 +3156,6 @@ namespace sat {
result->add_pb_ge(p.lit(), wlits, p.k(), p.learned());
break;
}
case ba::tag_t::xr_t: {
xr const& x = cp->to_xr();
lits.reset();
for (literal l : x) lits.push_back(l);
result->add_xr(lits, x.learned());
break;
}
default:
UNREACHABLE();
}
@ -3607,17 +3568,6 @@ namespace sat {
if (p.lit() != null_literal) ineq.push(~p.lit(), offset * p.k());
break;
}
case ba::tag_t::xr_t: {
xr& x = cnstr.to_xr();
literal_vector ls;
SASSERT(lit != null_literal);
get_antecedents(lit, x, ls);
ineq.reset(offset);
for (literal l : ls) ineq.push(~l, offset);
literal lxr = x.lit();
if (lxr != null_literal) ineq.push(~lxr, offset);
break;
}
default:
UNREACHABLE();
break;
@ -3860,7 +3810,7 @@ namespace sat {
card const& c = cp->to_card();
unsigned n = c.size();
unsigned k = c.k();
if (c.lit() == null_literal) {
// c.lits() >= k
// <=>
@ -3885,7 +3835,7 @@ namespace sat {
for (literal l : c) lits.push_back(l), coeffs.push_back(1);
lits.push_back(~c.lit()); coeffs.push_back(n - k + 1);
add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), n);
lits.reset();
coeffs.reset();
for (literal l : c) lits.push_back(~l), coeffs.push_back(1);
@ -3900,7 +3850,7 @@ namespace sat {
coeffs.reset();
unsigned sum = 0;
for (wliteral wl : p) sum += wl.first;
if (p.lit() == null_literal) {
// w1 + .. + w_n >= k
// <=>
@ -3919,7 +3869,7 @@ namespace sat {
lits.push_back(p.lit()), coeffs.push_back(p.k());
for (wliteral wl : p) lits.push_back(~(wl.second)), coeffs.push_back(wl.first);
add_pb(lits.size(), lits.c_ptr(), coeffs.c_ptr(), sum);
lits.reset();
coeffs.reset();
lits.push_back(~p.lit()), coeffs.push_back(sum + 1 - p.k());
@ -3928,8 +3878,6 @@ namespace sat {
}
break;
}
case ba::tag_t::xr_t:
return false;
}
}
return true;

View file

@ -9,7 +9,6 @@ Abstract:
Cardinality extensions,
Pseudo Booleans,
Xors
Author:
@ -29,7 +28,6 @@ Revision History:
#include "sat/smt/ba_constraint.h"
#include "sat/smt/ba_card.h"
#include "sat/smt/ba_pb.h"
#include "sat/smt/ba_xor.h"
#include "util/small_object_allocator.h"
#include "util/scoped_ptr_vector.h"
#include "util/sorting_network.h"
@ -40,11 +38,8 @@ namespace sat {
typedef ba::constraint constraint;
typedef ba::wliteral wliteral;
typedef ba::card card;
typedef ba::xr xr;
typedef ba::pb_base pb_base;
typedef ba::pb pb;
class xor_finder;
class ba_solver : public euf::th_solver, public ba::solver_interface {
@ -200,7 +195,6 @@ namespace sat {
lbool add_assign(constraint& c, literal l);
bool incremental_mode() const;
void simplify(constraint& c);
void pre_simplify(xor_finder& xu, constraint& c);
void set_conflict(constraint& c, literal lit) override;
void assign(constraint& c, literal lit) override;
bool assigned_above(literal above, literal below);
@ -234,18 +228,6 @@ namespace sat {
lbool eval(card const& c) const;
lbool eval(model const& m, card const& c) const;
// xr specific functionality
lbool add_assign(xr& x, literal alit);
void get_xr_antecedents(literal l, unsigned index, justification js, literal_vector& r);
void get_antecedents(literal l, xr const& x, literal_vector & r);
void simplify(xr& x);
void extract_xor();
void merge_xor();
bool clausify(xr& x);
void flush_roots(xr& x);
lbool eval(xr const& x) const;
lbool eval(model const& m, xr const& x) const;
// pb functionality
unsigned m_a_max{ 0 };
@ -329,7 +311,6 @@ namespace sat {
// validation utilities
bool validate_conflict(card const& c) const;
bool validate_conflict(xr const& x) const;
bool validate_conflict(pb const& p) const;
bool validate_assign(literal_vector const& lits, literal lit);
bool validate_lemma();
@ -364,11 +345,8 @@ namespace sat {
constraint* add_at_least(literal l, literal_vector const& lits, unsigned k, bool learned);
constraint* add_pb_ge(literal l, svector<wliteral> const& wlits, unsigned k, bool learned);
constraint* add_xr(literal_vector const& lits, bool learned);
literal add_xor_def(literal_vector& lits, bool learned = false);
bool all_distinct(literal_vector const& lits);
bool all_distinct(clause const& c);
bool all_distinct(xr const& x);
void copy_core(ba_solver* result, bool learned);
void copy_constraints(ba_solver* result, ptr_vector<constraint> const& constraints);
@ -386,12 +364,9 @@ namespace sat {
void convert_pb_args(app* t, literal_vector& lits);
bool m_is_redundant{ false };
literal internalize_pb(expr* e, bool sign, bool root);
literal internalize_xor(expr* e, bool sign, bool root);
// Decompile
expr_ref get_card(std::function<expr_ref(sat::literal)>& l2e, card const& c);
expr_ref get_pb(std::function<expr_ref(sat::literal)>& l2e, pb const& p);
expr_ref get_xor(std::function<expr_ref(sat::literal)>& l2e, xr const& x);
public:
ba_solver(euf::solver& ctx, euf::theory_id id);
@ -400,7 +375,6 @@ namespace sat {
void set_lookahead(lookahead* l) override { m_lookahead = l; }
void add_at_least(bool_var v, literal_vector const& lits, unsigned k);
void add_pb_ge(bool_var v, svector<wliteral> const& wlits, unsigned k);
void add_xr(literal_vector const& lits);
bool is_external(bool_var v) override;
bool propagated(literal l, ext_constraint_idx idx) override;
@ -411,7 +385,7 @@ namespace sat {
check_result check() override;
void push() override;
void pop(unsigned n) override;
void pre_simplify() override;
void pre_simplify() override {}
void simplify() override;
void clauses_modifed() override;
lbool get_phase(bool_var v) override;

View file

@ -1,192 +0,0 @@
/*++
Copyright (c) 2017 Microsoft Corporation
Module Name:
ba_xor.cpp
Abstract:
Interface for Xor constraints.
Author:
Nikolaj Bjorner (nbjorner) 2017-01-30
--*/
#include "sat/smt/ba_xor.h"
#include "sat/smt/ba_solver.h"
namespace ba {
xr& constraint::to_xr() {
SASSERT(is_xr());
return static_cast<xr&>(*this);
}
xr const& constraint::to_xr() const {
SASSERT(is_xr());
return static_cast<xr const&>(*this);
}
xr::xr(unsigned id, literal_vector const& lits) :
constraint(ba::tag_t::xr_t, id, sat::null_literal, lits.size(), get_obj_size(lits.size())) {
for (unsigned i = 0; i < size(); ++i) {
m_lits[i] = lits[i];
}
}
bool xr::is_watching(literal l) const {
return
l == (*this)[0] || l == (*this)[1] ||
~l == (*this)[0] || ~l == (*this)[1];
}
bool xr::well_formed() const {
uint_set vars;
if (lit() != sat::null_literal) vars.insert(lit().var());
for (literal l : *this) {
bool_var v = l.var();
if (vars.contains(v)) return false;
vars.insert(v);
}
return true;
}
std::ostream& xr::display(std::ostream& out) const {
for (unsigned i = 0; i < size(); ++i) {
out << (*this)[i] << " ";
if (i + 1 < size()) out << "x ";
}
return out;
}
void xr::clear_watch(solver_interface& s) {
auto& x = *this;
x.reset_watch();
x.unwatch_literal(s, x[0]);
x.unwatch_literal(s, x[1]);
x.unwatch_literal(s, ~x[0]);
x.unwatch_literal(s, ~x[1]);
}
bool xr::init_watch(solver_interface& s) {
auto& x = *this;
x.clear_watch(s);
VERIFY(x.lit() == sat::null_literal);
TRACE("ba", x.display(tout););
unsigned sz = x.size();
unsigned j = 0;
for (unsigned i = 0; i < sz && j < 2; ++i) {
if (s.value(x[i]) == l_undef) {
x.swap(i, j);
++j;
}
}
switch (j) {
case 0:
if (!parity(s, 0)) {
unsigned l = s.lvl(x[0]);
j = 1;
for (unsigned i = 1; i < sz; ++i) {
if (s.lvl(x[i]) > l) {
j = i;
l = s.lvl(x[i]);
}
}
s.set_conflict(x, x[j]);
}
return false;
case 1:
SASSERT(x.lit() == sat::null_literal || s.value(x.lit()) == l_true);
s.assign(x, parity(s, 1) ? ~x[0] : x[0]);
return false;
default:
SASSERT(j == 2);
x.watch_literal(s, x[0]);
x.watch_literal(s, x[1]);
x.watch_literal(s, ~x[0]);
x.watch_literal(s, ~x[1]);
return true;
}
}
bool xr::parity(solver_interface const& s, unsigned offset) const {
auto const& x = *this;
bool odd = false;
unsigned sz = x.size();
for (unsigned i = offset; i < sz; ++i) {
SASSERT(s.value(x[i]) != l_undef);
if (s.value(x[i]) == l_true) {
odd = !odd;
}
}
return odd;
}
std::ostream& xr::display(std::ostream& out, solver_interface const& s, bool values) const {
auto const& x = *this;
out << "xr: ";
for (literal l : x) {
out << l;
if (values) {
out << "@(" << s.value(l);
if (s.value(l) != l_undef) {
out << ":" << s.lvl(l);
}
out << ") ";
}
else {
out << " ";
}
}
return out << "\n";
}
bool xr::validate_unit_propagation(solver_interface const& s, literal alit) const {
if (s.value(lit()) != l_true) return false;
for (unsigned i = 1; i < size(); ++i) {
if (s.value((*this)[i]) == l_undef) return false;
}
return true;
}
lbool xr::eval(solver_interface const& s) const {
auto const& x = *this;
bool odd = false;
for (auto l : x) {
switch (s.value(l)) {
case l_true: odd = !odd; break;
case l_false: break;
default: return l_undef;
}
}
return odd ? l_true : l_false;
}
lbool xr::eval(sat::model const& m) const {
auto const& x = *this;
bool odd = false;
for (auto l : x) {
switch (ba::value(m, l)) {
case l_true: odd = !odd; break;
case l_false: break;
default: return l_undef;
}
}
return odd ? l_true : l_false;
}
void xr::init_use_list(sat::ext_use_list& ul) const {
auto idx = cindex();
for (auto l : *this) {
ul.insert(l, idx);
ul.insert(~l, idx);
}
}
}

View file

@ -236,10 +236,14 @@ namespace dt {
*/
void solver::mk_split(theory_var v, bool is_final) {
m_stats.m_splits++;
v = m_find.find(v);
enode* n = var2enode(v);
sort* srt = n->get_sort();
if (dt.is_enum_sort(srt)) {
mk_enum_split(v);
return;
}
func_decl* non_rec_c = dt.get_non_rec_constructor(srt);
unsigned non_rec_idx = dt.get_constructor_idx(non_rec_c);
var_data* d = m_var_data[v];
@ -289,6 +293,32 @@ namespace dt {
s().set_phase(lit);
}
void solver::mk_enum_split(theory_var v) {
enode* n = var2enode(v);
var_data* d = m_var_data[v];
sort* srt = n->get_sort();
auto const& constructors = *dt.get_datatype_constructors(srt);
unsigned sz = constructors.size();
int start = s().rand()();
m_lits.reset();
for (unsigned i = 0; i < sz; ++i) {
unsigned j = (i + start) % sz;
sat::literal lit = eq_internalize(n->get_expr(), m.mk_const(constructors[j]));
switch (s().value(lit)) {
case l_undef:
s().set_phase(lit);
return;
case l_true:
return;
case l_false:
m_lits.push_back(~lit);
break;
}
}
ctx.set_conflict(euf::th_propagation::conflict(*this, m_lits));
}
void solver::apply_sort_cnstr(enode* n, sort* s) {
force_push();
// Remark: If s is an infinite sort, then it is not necessary to create
@ -406,7 +436,7 @@ namespace dt {
CTRACE("dt", d->m_recognizers.empty(), ctx.display(tout););
SASSERT(!d->m_recognizers.empty());
literal_vector lits;
m_lits.reset();
enode_pair_vector eqs;
unsigned idx = 0;
for (enode* r : d->m_recognizers) {
@ -414,7 +444,7 @@ namespace dt {
return; // nothing to be propagated
if (r && ctx.value(r) == l_false) {
SASSERT(r->num_args() == 1);
lits.push_back(~ctx.enode2literal(r));
m_lits.push_back(~ctx.enode2literal(r));
if (n != r->get_arg(0)) {
// Argument of the current recognizer is not necessarily equal to n.
// This can happen when n and r->get_arg(0) are in the same equivalence class.
@ -432,10 +462,10 @@ namespace dt {
}
TRACE("dt", tout << "propagate " << num_unassigned << " eqs: " << eqs.size() << "\n";);
if (num_unassigned == 0)
ctx.set_conflict(euf::th_propagation::conflict(*this, lits, eqs));
ctx.set_conflict(euf::th_propagation::conflict(*this, m_lits, eqs));
else if (num_unassigned == 1) {
// propagate remaining recognizer
SASSERT(!lits.empty());
SASSERT(!m_lits.empty());
enode* r = d->m_recognizers[unassigned_idx];
literal consequent;
if (r)
@ -446,7 +476,7 @@ namespace dt {
app_ref rec_app(m.mk_app(rec, n->get_expr()), m);
consequent = mk_literal(rec_app);
}
ctx.propagate(consequent, euf::th_propagation::propagate(*this, lits, eqs, consequent));
ctx.propagate(consequent, euf::th_propagation::propagate(*this, m_lits, eqs, consequent));
}
else if (get_config().m_dt_lazy_splits == 0 || (!srt->is_infinite() && get_config().m_dt_lazy_splits == 1))
// there are more than 2 unassigned recognizers...

View file

@ -97,6 +97,7 @@ namespace dt {
enode_pair_vector m_used_eqs; // conflict, if any
parent_tbl m_parent; // parent explanation for occurs_check
svector<stack_entry> m_dfs; // stack for DFS for occurs_check
sat::literal_vector m_lits;
void clear_mark();
@ -119,6 +120,7 @@ namespace dt {
void explain_is_child(enode* parent, enode* child);
void mk_split(theory_var v, bool is_final);
void mk_enum_split(theory_var v);
void display_var(std::ostream & out, theory_var v) const;

View file

@ -205,6 +205,7 @@ namespace euf {
public:
static th_propagation* conflict(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs);
static th_propagation* conflict(th_euf_solver& th, sat::literal_vector const& lits) { return conflict(th, lits.size(), lits.c_ptr(), 0, nullptr); }
static th_propagation* conflict(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs);
static th_propagation* conflict(th_euf_solver& th, enode_pair_vector const& eqs);
static th_propagation* conflict(th_euf_solver& th, sat::literal lit);

View file

@ -416,3 +416,24 @@ namespace sat {
}
// xr specific functionality
lbool add_assign(xr& x, literal alit);
void get_xr_antecedents(literal l, unsigned index, justification js, literal_vector& r);
void get_antecedents(literal l, xr const& x, literal_vector & r);
void simplify(xr& x);
void extract_xor();
void merge_xor();
bool clausify(xr& x);
void flush_roots(xr& x);
lbool eval(xr const& x) const;
lbool eval(model const& m, xr const& x) const;
bool validate_conflict(xr const& x) const;
constraint* add_xr(literal_vector const& lits, bool learned);
literal add_xor_def(literal_vector& lits, bool learned = false);
bool all_distinct(xr const& x);
expr_ref get_xor(std::function<expr_ref(sat::literal)>& l2e, xr const& x);
void add_xr(literal_vector const& lits);
#include "sat/sat_xor_finder.h"