3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 02:15:19 +00:00
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-01-30 21:23:53 -08:00
parent 1a95c33775
commit 8b7bafbd9f
9 changed files with 221 additions and 29 deletions

View file

@ -49,6 +49,7 @@ struct pb2bv_rewriter::imp {
expr_ref_vector m_args;
rational m_k;
vector<rational> m_coeffs;
bool m_enable_card;
template<lbool is_le>
expr_ref mk_le_ge(expr_ref_vector& fmls, expr* a, expr* b, expr* bound) {
@ -238,12 +239,13 @@ struct pb2bv_rewriter::imp {
pb(m),
bv(m),
m_trail(m),
m_args(m)
m_args(m),
m_enable_card(false)
{}
bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
if (f->get_family_id() == pb.get_family_id()) {
mk_pb(full, f, sz, args, result);
if (f->get_family_id() == pb.get_family_id() && mk_pb(full, f, sz, args, result)) {
// skip
}
else if (au.is_le(f) && is_pb(args[0], args[1])) {
result = mk_le_ge<l_true>(m_args.size(), m_args.c_ptr(), m_k);
@ -349,29 +351,36 @@ struct pb2bv_rewriter::imp {
return false;
}
void mk_pb(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
bool mk_pb(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) {
SASSERT(f->get_family_id() == pb.get_family_id());
std::cout << "card: " << m_enable_card << "\n";
if (is_or(f)) {
result = m.mk_or(sz, args);
}
else if (pb.is_at_most_k(f) && pb.get_k(f).is_unsigned()) {
if (m_enable_card) return false;
result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args);
}
else if (pb.is_at_least_k(f) && pb.get_k(f).is_unsigned()) {
if (m_enable_card) return false;
result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args);
}
else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
if (m_enable_card) return false;
result = m_sort.eq(full, pb.get_k(f).get_unsigned(), sz, args);
}
else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
if (m_enable_card) return false;
result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args);
}
else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) {
if (m_enable_card) return false;
result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args);
}
else {
result = mk_bv(f, sz, args);
}
return true;
}
// definitions used for sorting network
@ -396,6 +405,12 @@ struct pb2bv_rewriter::imp {
void mk_clause(unsigned n, literal const* lits) {
m_imp.m_lemmas.push_back(mk_or(m, n, lits));
}
void enable_card(bool f) {
std::cout << "set " << f << "\n";
m_enable_card = f;
m_enable_card = true;
}
};
struct card2bv_rewriter_cfg : public default_rewriter_cfg {
@ -407,6 +422,7 @@ struct pb2bv_rewriter::imp {
return m_r.mk_app_core(f, num, args, result);
}
card2bv_rewriter_cfg(imp& i, ast_manager & m):m_r(i, m) {}
void enable_card(bool f) { m_r.enable_card(f); }
};
class card_pb_rewriter : public rewriter_tpl<card2bv_rewriter_cfg> {
@ -415,6 +431,7 @@ struct pb2bv_rewriter::imp {
card_pb_rewriter(imp& i, ast_manager & m):
rewriter_tpl<card2bv_rewriter_cfg>(m, false, m_cfg),
m_cfg(i, m) {}
void enable_card(bool f) { m_cfg.enable_card(f); }
};
card_pb_rewriter m_rw;
@ -424,9 +441,13 @@ struct pb2bv_rewriter::imp {
m_fresh(m),
m_num_translated(0),
m_rw(*this, m) {
m_rw.enable_card(p.get_bool("cardinality_solver", false));
}
void updt_params(params_ref const & p) {}
void updt_params(params_ref const & p) {
m_params.append(p);
m_rw.enable_card(m_params.get_bool("cardinality_solver", false));
}
unsigned get_num_steps() const { return m_rw.get_num_steps(); }
void cleanup() { m_rw.cleanup(); }
void operator()(expr * e, expr_ref & result, proof_ref & result_proof) {

View file

@ -144,12 +144,14 @@ namespace sat {
}
void card_extension::watch_literal(card& c, literal lit) {
TRACE("sat", tout << "watch: " << lit << "\n";);
init_watch(lit.var());
ptr_vector<card>* cards = m_var_infos[lit.var()].m_lit_watch[lit.sign()];
if (cards == 0) {
cards = alloc(ptr_vector<card>);
m_var_infos[lit.var()].m_lit_watch[lit.sign()] = cards;
}
TRACE("sat", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";);
cards->push_back(&c);
}
@ -436,7 +438,9 @@ namespace sat {
return p;
}
card_extension::card_extension(): m_solver(0) {}
card_extension::card_extension(): m_solver(0) {
TRACE("sat", tout << this << "\n";);
}
card_extension::~card_extension() {
for (unsigned i = 0; i < m_var_infos.size(); ++i) {
@ -537,7 +541,10 @@ namespace sat {
void card_extension::asserted(literal l) {
bool_var v = l.var();
if (v >= m_var_infos.size()) return;
ptr_vector<card>* cards = m_var_infos[v].m_lit_watch[!l.sign()];
TRACE("sat", tout << "retrieve: " << v << " " << !l.sign() << "\n";);
TRACE("sat", tout << "asserted: " << l << " " << (cards ? "non-empty" : "empty") << "\n";);
if (cards != 0 && !cards->empty() && !s().inconsistent()) {
ptr_vector<card>::iterator it = cards->begin(), it2 = it, end = cards->end();
for (; it != end; ++it) {
@ -545,7 +552,7 @@ namespace sat {
if (value(c.lit()) != l_true) {
continue;
}
switch (add_assign(c, l)) {
switch (add_assign(c, ~l)) {
case l_false: // conflict
for (; it != end; ++it, ++it2) {
*it2 = *it;
@ -579,6 +586,7 @@ namespace sat {
}
void card_extension::pop(unsigned n) {
TRACE("sat", tout << "pop:" << n << "\n";);
unsigned new_lim = m_var_lim.size() - n;
unsigned sz = m_var_lim[new_lim];
while (m_var_trail.size() > sz) {
@ -598,6 +606,21 @@ namespace sat {
void card_extension::clauses_modifed() {}
lbool card_extension::get_phase(bool_var v) { return l_undef; }
extension* card_extension::copy(solver* s) {
card_extension* result = alloc(card_extension);
result->set_solver(s);
for (unsigned i = 0; i < m_constraints.size(); ++i) {
literal_vector lits;
card& c = *m_constraints[i];
for (unsigned i = 0; i < c.size(); ++i) {
lits.push_back(c[i]);
}
result->add_at_least(c.lit().var(), lits, c.k());
}
return result;
}
void card_extension::display_watch(std::ostream& out, bool_var v, bool sign) const {
watch const* w = m_var_infos[v].m_lit_watch[sign];
if (w) {
@ -647,7 +670,7 @@ namespace sat {
for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) {
card* c = m_var_infos[vi].m_card;
if (c) {
display(out, *c, true);
display(out, *c, false);
}
}
return out;

View file

@ -139,7 +139,7 @@ namespace sat {
public:
card_extension();
virtual ~card_extension();
void set_solver(solver* s) { m_solver = s; }
virtual void set_solver(solver* s) { m_solver = s; }
void add_at_least(bool_var v, literal_vector const& lits, unsigned k);
virtual void propagate(literal l, ext_constraint_idx idx, bool & keep);
virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r);
@ -152,6 +152,7 @@ namespace sat {
virtual lbool get_phase(bool_var v);
virtual std::ostream& display(std::ostream& out) const;
virtual void collect_statistics(statistics& st) const;
virtual extension* copy(solver* s);
};
};

View file

@ -32,6 +32,7 @@ namespace sat {
class extension {
public:
virtual ~extension() {}
virtual void set_solver(solver* s) = 0;
virtual void propagate(literal l, ext_constraint_idx idx, bool & keep) = 0;
virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) = 0;
virtual void asserted(literal l) = 0;
@ -43,6 +44,7 @@ namespace sat {
virtual lbool get_phase(bool_var v) = 0;
virtual std::ostream& display(std::ostream& out) const = 0;
virtual void collect_statistics(statistics& st) const = 0;
virtual extension* copy(solver* s) = 0;
};
};

View file

@ -23,4 +23,5 @@ def_module_params('sat',
('core.minimize', BOOL, False, 'minimize computed core'),
('core.minimize_partial', BOOL, False, 'apply partial (cheap) core minimization'),
('parallel_threads', UINT, 1, 'number of parallel threads to use'),
('cardinality_solver', BOOL, False, 'enable cardinality based solver'),
('dimacs.core', BOOL, False, 'extract core from DIMACS benchmarks')))

View file

@ -54,6 +54,7 @@ namespace sat {
m_conflicts = 0;
m_next_simplify = 0;
m_num_checkpoints = 0;
if (m_ext) m_ext->set_solver(this);
}
solver::~solver() {
@ -84,13 +85,15 @@ namespace sat {
VERIFY(v == mk_var(ext, dvar));
}
}
unsigned sz = src.scope_lvl() == 0 ? src.m_trail.size() : src.m_scopes[0].m_trail_lim;
for (unsigned i = 0; i < sz; ++i) {
assign(src.m_trail[i], justification());
{
unsigned sz = src.scope_lvl() == 0 ? src.m_trail.size() : src.m_scopes[0].m_trail_lim;
for (unsigned i = 0; i < sz; ++i) {
assign(src.m_trail[i], justification());
}
}
// copy binary clauses
{
// copy binary clauses
unsigned sz = src.m_watches.size();
for (unsigned l_idx = 0; l_idx < sz; ++l_idx) {
literal l = ~to_literal(l_idx);
@ -107,6 +110,7 @@ namespace sat {
}
}
}
{
literal_vector buffer;
// copy clause
@ -120,6 +124,10 @@ namespace sat {
mk_clause_core(buffer);
}
}
if (src.get_extension()) {
m_ext = src.get_extension()->copy(this);
}
}
// -----------------------

View file

@ -74,7 +74,7 @@ namespace sat {
reslimit& m_rlimit;
config m_config;
stats m_stats;
extension * m_ext;
scoped_ptr<extension> m_ext;
par* m_par;
random_gen m_rand;
clause_allocator m_cls_allocator;
@ -251,6 +251,7 @@ namespace sat {
void set_par(par* p);
bool canceled() { return !m_rlimit.inc(); }
config const& get_config() { return m_config; }
extension* get_extension() const { return m_ext.get(); }
typedef std::pair<literal, literal> bin_clause;
protected:
watch_list & get_wlist(literal l) { return m_watches[l.index()]; }

View file

@ -20,6 +20,7 @@ Notes:
#include "solver.h"
#include "tactical.h"
#include "sat_solver.h"
#include "card_extension.h"
#include "tactic2solver.h"
#include "aig_tactic.h"
#include "propagate_values_tactic.h"
@ -35,6 +36,8 @@ Notes:
#include "ast_translation.h"
#include "ast_util.h"
#include "propagate_values_tactic.h"
#include "sat_params.hpp"
// incremental SAT solver.
class inc_sat_solver : public solver {
@ -68,7 +71,8 @@ class inc_sat_solver : public solver {
typedef obj_map<expr, sat::literal> dep2asm_t;
public:
inc_sat_solver(ast_manager& m, params_ref const& p):
m(m), m_solver(p, m.limit(), 0),
m(m),
m_solver(p, m.limit(), alloc(sat::card_extension)),
m_params(p), m_optimize_model(false),
m_fmls(m),
m_asmsf(m),
@ -79,6 +83,8 @@ public:
m_dep_core(m),
m_unknown("no reason given") {
m_params.set_bool("elim_vars", false);
sat_params p1(m_params);
m_params.set_bool("cardinality_solver", p1.cardinality_solver());
m_solver.updt_params(m_params);
init_preprocess();
}
@ -86,6 +92,7 @@ public:
virtual ~inc_sat_solver() {}
virtual solver* translate(ast_manager& dst_m, params_ref const& p) {
std::cout << "translate\n";
ast_translation tr(m, dst_m);
if (m_num_scopes > 0) {
throw default_exception("Cannot translate sat solver at non-base level");
@ -210,8 +217,11 @@ public:
sat::solver::collect_param_descrs(r);
}
virtual void updt_params(params_ref const & p) {
m_params = p;
m_params.append(p);
sat_params p1(p);
m_params.set_bool("cardinality_solver", p1.cardinality_solver());
m_params.set_bool("elim_vars", false);
std::cout << m_params << "\n";
m_solver.updt_params(m_params);
m_optimize_model = m_params.get_bool("optimize_model", false);
}

View file

@ -37,6 +37,7 @@ Notes:
#include"tactic.h"
#include"ast_pp.h"
#include"pb_decl_plugin.h"
#include"card_extension.h"
#include<sstream>
struct goal2sat::imp {
@ -50,6 +51,7 @@ struct goal2sat::imp {
};
ast_manager & m;
pb_util pb;
sat::card_extension* m_ext;
svector<frame> m_frame_stack;
svector<sat::literal> m_result_stack;
obj_map<app, sat::literal> m_cache;
@ -67,6 +69,7 @@ struct goal2sat::imp {
imp(ast_manager & _m, params_ref const & p, sat::solver & s, atom2bool_var & map, dep2asm_map& dep2asm, bool default_external):
m(_m),
pb(m),
m_ext(0),
m_solver(s),
m_map(map),
m_dep2asm(dep2asm),
@ -75,6 +78,11 @@ struct goal2sat::imp {
m_default_external(default_external) {
updt_params(p);
m_true = sat::null_bool_var;
sat::extension* e = m_solver.get_extension();
if (e) {
sat::card_extension* ce = dynamic_cast<sat::card_extension*>(e);
m_ext = ce;
}
}
void updt_params(params_ref const & p) {
@ -116,7 +124,7 @@ struct goal2sat::imp {
return m_true;
}
bool convert_atom(expr * t, bool root, bool sign) {
void convert_atom(expr * t, bool root, bool sign) {
SASSERT(m.is_bool(t));
sat::literal l;
sat::bool_var v = m_map.to_bool_var(t);
@ -147,15 +155,17 @@ struct goal2sat::imp {
mk_clause(l);
else
m_result_stack.push_back(l);
return true;
}
bool convert_app(app* t, bool root, bool sign) {
return convert_atom(t, root, sign);
}
bool convert_pb(app* t, bool root, bool sign) {
if (m_ext && t->get_family_id() == pb.get_family_id()) {
m_frame_stack.push_back(frame(to_app(t), root, sign, 0));
return false;
}
else {
convert_atom(t, root, sign);
return true;
}
}
bool process_cached(app * t, bool root, bool sign) {
@ -175,7 +185,8 @@ struct goal2sat::imp {
bool visit(expr * t, bool root, bool sign) {
if (!is_app(t)) {
return convert_atom(t, root, sign);
convert_atom(t, root, sign);
return true;
}
if (process_cached(to_app(t), root, sign))
return true;
@ -195,7 +206,10 @@ struct goal2sat::imp {
m_frame_stack.push_back(frame(to_app(t), root, sign, 0));
return false;
}
return convert_atom(t, root, sign);
else {
convert_atom(t, root, sign);
return true;
}
case OP_XOR:
case OP_IMPLIES:
case OP_DISTINCT: {
@ -205,7 +219,8 @@ struct goal2sat::imp {
throw_op_not_handled(strm.str());
}
default:
return convert_atom(t, root, sign);
convert_atom(t, root, sign);
return true;
}
}
@ -361,6 +376,95 @@ struct goal2sat::imp {
}
}
void convert_pb_args(app* t, sat::literal_vector& lits) {
unsigned num_args = t->get_num_args();
unsigned sz = m_result_stack.size();
for (unsigned i = 0; i < num_args; ++i) {
sat::literal lit(m_result_stack[sz - num_args + i]);
if (!m_solver.is_external(lit.var())) {
sat::bool_var w = m_solver.mk_var(true);
sat::literal lit2(w, false);
mk_clause(lit, ~lit2);
mk_clause(~lit, lit2);
lit = lit2;
}
lits.push_back(lit);
}
}
void convert_at_least_k(app* t, rational k, bool root, bool sign) {
SASSERT(k.is_unsigned());
sat::literal_vector lits;
unsigned sz = m_result_stack.size();
convert_pb_args(t, lits);
sat::bool_var v = m_solver.mk_var(true);
sat::literal lit(v, sign);
m_ext->add_at_least(v, lits, k.get_unsigned());
TRACE("sat", tout << "root: " << root << " lit: " << lit << "\n";);
if (root) {
m_result_stack.reset();
mk_clause(lit);
}
else {
m_result_stack.shrink(sz - t->get_num_args());
m_result_stack.push_back(lit);
}
}
void convert_at_most_k(app* t, rational k, bool root, bool sign) {
SASSERT(k.is_unsigned());
sat::literal_vector lits;
unsigned sz = m_result_stack.size();
convert_pb_args(t, lits);
for (unsigned i = 0; i < lits.size(); ++i) {
lits[i].neg();
}
sat::bool_var v = m_solver.mk_var(true);
m_ext->add_at_least(v, lits, lits.size() - k.get_unsigned() + 1);
if (root) {
m_result_stack.reset();
mk_clause(sat::literal(v, sign));
}
else {
m_result_stack.shrink(sz - t->get_num_args());
m_result_stack.push_back(sat::literal(v, sign));
}
}
void convert_eq_k(app* t, rational k, bool root, bool sign) {
SASSERT(k.is_unsigned());
sat::literal_vector lits;
convert_pb_args(t, lits);
sat::bool_var v1 = m_solver.mk_var(true);
sat::bool_var v2 = m_solver.mk_var(true);
sat::literal l1(v1, false), l2(v2, false);
m_ext->add_at_least(v1, lits, k.get_unsigned());
for (unsigned i = 0; i < lits.size(); ++i) {
lits[i].neg();
}
m_ext->add_at_least(v2, lits, lits.size() - k.get_unsigned() + 1);
if (root) {
m_result_stack.reset();
if (sign) {
mk_clause(~l1, ~l2);
}
else {
mk_clause(l1);
mk_clause(l2);
}
m_result_stack.reset();
}
else {
sat::bool_var v = m_solver.mk_var();
sat::literal l(v, false);
mk_clause(~l, l1);
mk_clause(~l, l2);
mk_clause(~l1, ~l2, l);
m_result_stack.shrink(m_result_stack.size() - t->get_num_args());
m_result_stack.push_back(l);
}
}
void convert(app * t, bool root, bool sign) {
if (t->get_family_id() == m.get_basic_family_id()) {
switch (to_app(t)->get_decl_kind()) {
@ -376,13 +480,34 @@ struct goal2sat::imp {
case OP_IFF:
case OP_EQ:
convert_iff(t, root, sign);
break;
break;
default:
UNREACHABLE();
}
}
else if (t->get_family_id() == pb.get_fid()) {
NOT_IMPLEMENTED_YET();
else if (m_ext && t->get_family_id() == pb.get_family_id()) {
switch (t->get_decl_kind()) {
case OP_AT_MOST_K:
convert_at_most_k(t, pb.get_k(t), root, sign);
break;
case OP_AT_LEAST_K:
convert_at_least_k(t, pb.get_k(t), root, sign);
break;
case OP_PB_LE:
SASSERT(pb.has_unit_coefficients(t));
convert_at_most_k(t, pb.get_k(t), root, sign);
break;
case OP_PB_GE:
SASSERT(pb.has_unit_coefficients(t));
convert_at_least_k(t, pb.get_k(t), root, sign);
break;
case OP_PB_EQ:
SASSERT(pb.has_unit_coefficients(t));
convert_eq_k(t, pb.get_k(t), root, sign);
break;
default:
UNREACHABLE();
}
}
else {
UNREACHABLE();