diff --git a/src/ast/rewriter/pb2bv_rewriter.cpp b/src/ast/rewriter/pb2bv_rewriter.cpp index 37c87cd5b..6ba55d12a 100644 --- a/src/ast/rewriter/pb2bv_rewriter.cpp +++ b/src/ast/rewriter/pb2bv_rewriter.cpp @@ -49,6 +49,7 @@ struct pb2bv_rewriter::imp { expr_ref_vector m_args; rational m_k; vector m_coeffs; + bool m_enable_card; template expr_ref mk_le_ge(expr_ref_vector& fmls, expr* a, expr* b, expr* bound) { @@ -238,12 +239,13 @@ struct pb2bv_rewriter::imp { pb(m), bv(m), m_trail(m), - m_args(m) + m_args(m), + m_enable_card(false) {} bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { - if (f->get_family_id() == pb.get_family_id()) { - mk_pb(full, f, sz, args, result); + if (f->get_family_id() == pb.get_family_id() && mk_pb(full, f, sz, args, result)) { + // skip } else if (au.is_le(f) && is_pb(args[0], args[1])) { result = mk_le_ge(m_args.size(), m_args.c_ptr(), m_k); @@ -349,29 +351,36 @@ struct pb2bv_rewriter::imp { return false; } - void mk_pb(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { + bool mk_pb(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { SASSERT(f->get_family_id() == pb.get_family_id()); + std::cout << "card: " << m_enable_card << "\n"; if (is_or(f)) { result = m.mk_or(sz, args); } else if (pb.is_at_most_k(f) && pb.get_k(f).is_unsigned()) { + if (m_enable_card) return false; result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args); } else if (pb.is_at_least_k(f) && pb.get_k(f).is_unsigned()) { + if (m_enable_card) return false; result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args); } else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) { + if (m_enable_card) return false; result = m_sort.eq(full, pb.get_k(f).get_unsigned(), sz, args); } else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) { + if (m_enable_card) return false; result = m_sort.le(full, pb.get_k(f).get_unsigned(), sz, args); } else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && pb.has_unit_coefficients(f)) { + if (m_enable_card) return false; result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args); } else { result = mk_bv(f, sz, args); } + return true; } // definitions used for sorting network @@ -396,6 +405,12 @@ struct pb2bv_rewriter::imp { void mk_clause(unsigned n, literal const* lits) { m_imp.m_lemmas.push_back(mk_or(m, n, lits)); } + + void enable_card(bool f) { + std::cout << "set " << f << "\n"; + m_enable_card = f; + m_enable_card = true; + } }; struct card2bv_rewriter_cfg : public default_rewriter_cfg { @@ -407,6 +422,7 @@ struct pb2bv_rewriter::imp { return m_r.mk_app_core(f, num, args, result); } card2bv_rewriter_cfg(imp& i, ast_manager & m):m_r(i, m) {} + void enable_card(bool f) { m_r.enable_card(f); } }; class card_pb_rewriter : public rewriter_tpl { @@ -415,6 +431,7 @@ struct pb2bv_rewriter::imp { card_pb_rewriter(imp& i, ast_manager & m): rewriter_tpl(m, false, m_cfg), m_cfg(i, m) {} + void enable_card(bool f) { m_cfg.enable_card(f); } }; card_pb_rewriter m_rw; @@ -424,9 +441,13 @@ struct pb2bv_rewriter::imp { m_fresh(m), m_num_translated(0), m_rw(*this, m) { + m_rw.enable_card(p.get_bool("cardinality_solver", false)); } - void updt_params(params_ref const & p) {} + void updt_params(params_ref const & p) { + m_params.append(p); + m_rw.enable_card(m_params.get_bool("cardinality_solver", false)); + } unsigned get_num_steps() const { return m_rw.get_num_steps(); } void cleanup() { m_rw.cleanup(); } void operator()(expr * e, expr_ref & result, proof_ref & result_proof) { diff --git a/src/sat/card_extension.cpp b/src/sat/card_extension.cpp index 2af1a5b5e..756e9642c 100644 --- a/src/sat/card_extension.cpp +++ b/src/sat/card_extension.cpp @@ -144,12 +144,14 @@ namespace sat { } void card_extension::watch_literal(card& c, literal lit) { + TRACE("sat", tout << "watch: " << lit << "\n";); init_watch(lit.var()); ptr_vector* cards = m_var_infos[lit.var()].m_lit_watch[lit.sign()]; if (cards == 0) { cards = alloc(ptr_vector); m_var_infos[lit.var()].m_lit_watch[lit.sign()] = cards; } + TRACE("sat", tout << "insert: " << lit.var() << " " << lit.sign() << "\n";); cards->push_back(&c); } @@ -436,7 +438,9 @@ namespace sat { return p; } - card_extension::card_extension(): m_solver(0) {} + card_extension::card_extension(): m_solver(0) { + TRACE("sat", tout << this << "\n";); + } card_extension::~card_extension() { for (unsigned i = 0; i < m_var_infos.size(); ++i) { @@ -537,7 +541,10 @@ namespace sat { void card_extension::asserted(literal l) { bool_var v = l.var(); + if (v >= m_var_infos.size()) return; ptr_vector* cards = m_var_infos[v].m_lit_watch[!l.sign()]; + TRACE("sat", tout << "retrieve: " << v << " " << !l.sign() << "\n";); + TRACE("sat", tout << "asserted: " << l << " " << (cards ? "non-empty" : "empty") << "\n";); if (cards != 0 && !cards->empty() && !s().inconsistent()) { ptr_vector::iterator it = cards->begin(), it2 = it, end = cards->end(); for (; it != end; ++it) { @@ -545,7 +552,7 @@ namespace sat { if (value(c.lit()) != l_true) { continue; } - switch (add_assign(c, l)) { + switch (add_assign(c, ~l)) { case l_false: // conflict for (; it != end; ++it, ++it2) { *it2 = *it; @@ -579,6 +586,7 @@ namespace sat { } void card_extension::pop(unsigned n) { + TRACE("sat", tout << "pop:" << n << "\n";); unsigned new_lim = m_var_lim.size() - n; unsigned sz = m_var_lim[new_lim]; while (m_var_trail.size() > sz) { @@ -598,6 +606,21 @@ namespace sat { void card_extension::clauses_modifed() {} lbool card_extension::get_phase(bool_var v) { return l_undef; } + extension* card_extension::copy(solver* s) { + card_extension* result = alloc(card_extension); + result->set_solver(s); + for (unsigned i = 0; i < m_constraints.size(); ++i) { + literal_vector lits; + card& c = *m_constraints[i]; + for (unsigned i = 0; i < c.size(); ++i) { + lits.push_back(c[i]); + } + result->add_at_least(c.lit().var(), lits, c.k()); + } + return result; + } + + void card_extension::display_watch(std::ostream& out, bool_var v, bool sign) const { watch const* w = m_var_infos[v].m_lit_watch[sign]; if (w) { @@ -647,7 +670,7 @@ namespace sat { for (unsigned vi = 0; vi < m_var_infos.size(); ++vi) { card* c = m_var_infos[vi].m_card; if (c) { - display(out, *c, true); + display(out, *c, false); } } return out; diff --git a/src/sat/card_extension.h b/src/sat/card_extension.h index 1593ef26f..5c3f1e293 100644 --- a/src/sat/card_extension.h +++ b/src/sat/card_extension.h @@ -139,7 +139,7 @@ namespace sat { public: card_extension(); virtual ~card_extension(); - void set_solver(solver* s) { m_solver = s; } + virtual void set_solver(solver* s) { m_solver = s; } void add_at_least(bool_var v, literal_vector const& lits, unsigned k); virtual void propagate(literal l, ext_constraint_idx idx, bool & keep); virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r); @@ -152,6 +152,7 @@ namespace sat { virtual lbool get_phase(bool_var v); virtual std::ostream& display(std::ostream& out) const; virtual void collect_statistics(statistics& st) const; + virtual extension* copy(solver* s); }; }; diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index f1a48c0a8..c065e132e 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -32,6 +32,7 @@ namespace sat { class extension { public: virtual ~extension() {} + virtual void set_solver(solver* s) = 0; virtual void propagate(literal l, ext_constraint_idx idx, bool & keep) = 0; virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) = 0; virtual void asserted(literal l) = 0; @@ -43,6 +44,7 @@ namespace sat { virtual lbool get_phase(bool_var v) = 0; virtual std::ostream& display(std::ostream& out) const = 0; virtual void collect_statistics(statistics& st) const = 0; + virtual extension* copy(solver* s) = 0; }; }; diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index 60708fd5c..940fa8c45 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -23,4 +23,5 @@ def_module_params('sat', ('core.minimize', BOOL, False, 'minimize computed core'), ('core.minimize_partial', BOOL, False, 'apply partial (cheap) core minimization'), ('parallel_threads', UINT, 1, 'number of parallel threads to use'), + ('cardinality_solver', BOOL, False, 'enable cardinality based solver'), ('dimacs.core', BOOL, False, 'extract core from DIMACS benchmarks'))) diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 647c5f9ed..d4349d2c2 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -54,6 +54,7 @@ namespace sat { m_conflicts = 0; m_next_simplify = 0; m_num_checkpoints = 0; + if (m_ext) m_ext->set_solver(this); } solver::~solver() { @@ -84,13 +85,15 @@ namespace sat { VERIFY(v == mk_var(ext, dvar)); } } - unsigned sz = src.scope_lvl() == 0 ? src.m_trail.size() : src.m_scopes[0].m_trail_lim; - for (unsigned i = 0; i < sz; ++i) { - assign(src.m_trail[i], justification()); + { + unsigned sz = src.scope_lvl() == 0 ? src.m_trail.size() : src.m_scopes[0].m_trail_lim; + for (unsigned i = 0; i < sz; ++i) { + assign(src.m_trail[i], justification()); + } } + // copy binary clauses { - // copy binary clauses unsigned sz = src.m_watches.size(); for (unsigned l_idx = 0; l_idx < sz; ++l_idx) { literal l = ~to_literal(l_idx); @@ -107,6 +110,7 @@ namespace sat { } } } + { literal_vector buffer; // copy clause @@ -120,6 +124,10 @@ namespace sat { mk_clause_core(buffer); } } + + if (src.get_extension()) { + m_ext = src.get_extension()->copy(this); + } } // ----------------------- diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index b42acc680..25c823446 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -74,7 +74,7 @@ namespace sat { reslimit& m_rlimit; config m_config; stats m_stats; - extension * m_ext; + scoped_ptr m_ext; par* m_par; random_gen m_rand; clause_allocator m_cls_allocator; @@ -251,6 +251,7 @@ namespace sat { void set_par(par* p); bool canceled() { return !m_rlimit.inc(); } config const& get_config() { return m_config; } + extension* get_extension() const { return m_ext.get(); } typedef std::pair bin_clause; protected: watch_list & get_wlist(literal l) { return m_watches[l.index()]; } diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 65a7b021c..e20ca9583 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -20,6 +20,7 @@ Notes: #include "solver.h" #include "tactical.h" #include "sat_solver.h" +#include "card_extension.h" #include "tactic2solver.h" #include "aig_tactic.h" #include "propagate_values_tactic.h" @@ -35,6 +36,8 @@ Notes: #include "ast_translation.h" #include "ast_util.h" #include "propagate_values_tactic.h" +#include "sat_params.hpp" + // incremental SAT solver. class inc_sat_solver : public solver { @@ -68,7 +71,8 @@ class inc_sat_solver : public solver { typedef obj_map dep2asm_t; public: inc_sat_solver(ast_manager& m, params_ref const& p): - m(m), m_solver(p, m.limit(), 0), + m(m), + m_solver(p, m.limit(), alloc(sat::card_extension)), m_params(p), m_optimize_model(false), m_fmls(m), m_asmsf(m), @@ -79,6 +83,8 @@ public: m_dep_core(m), m_unknown("no reason given") { m_params.set_bool("elim_vars", false); + sat_params p1(m_params); + m_params.set_bool("cardinality_solver", p1.cardinality_solver()); m_solver.updt_params(m_params); init_preprocess(); } @@ -86,6 +92,7 @@ public: virtual ~inc_sat_solver() {} virtual solver* translate(ast_manager& dst_m, params_ref const& p) { + std::cout << "translate\n"; ast_translation tr(m, dst_m); if (m_num_scopes > 0) { throw default_exception("Cannot translate sat solver at non-base level"); @@ -210,8 +217,11 @@ public: sat::solver::collect_param_descrs(r); } virtual void updt_params(params_ref const & p) { - m_params = p; + m_params.append(p); + sat_params p1(p); + m_params.set_bool("cardinality_solver", p1.cardinality_solver()); m_params.set_bool("elim_vars", false); + std::cout << m_params << "\n"; m_solver.updt_params(m_params); m_optimize_model = m_params.get_bool("optimize_model", false); } diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 72340bc1c..41fb4dfce 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -37,6 +37,7 @@ Notes: #include"tactic.h" #include"ast_pp.h" #include"pb_decl_plugin.h" +#include"card_extension.h" #include struct goal2sat::imp { @@ -50,6 +51,7 @@ struct goal2sat::imp { }; ast_manager & m; pb_util pb; + sat::card_extension* m_ext; svector m_frame_stack; svector m_result_stack; obj_map m_cache; @@ -67,6 +69,7 @@ struct goal2sat::imp { imp(ast_manager & _m, params_ref const & p, sat::solver & s, atom2bool_var & map, dep2asm_map& dep2asm, bool default_external): m(_m), pb(m), + m_ext(0), m_solver(s), m_map(map), m_dep2asm(dep2asm), @@ -75,6 +78,11 @@ struct goal2sat::imp { m_default_external(default_external) { updt_params(p); m_true = sat::null_bool_var; + sat::extension* e = m_solver.get_extension(); + if (e) { + sat::card_extension* ce = dynamic_cast(e); + m_ext = ce; + } } void updt_params(params_ref const & p) { @@ -116,7 +124,7 @@ struct goal2sat::imp { return m_true; } - bool convert_atom(expr * t, bool root, bool sign) { + void convert_atom(expr * t, bool root, bool sign) { SASSERT(m.is_bool(t)); sat::literal l; sat::bool_var v = m_map.to_bool_var(t); @@ -147,15 +155,17 @@ struct goal2sat::imp { mk_clause(l); else m_result_stack.push_back(l); - return true; } bool convert_app(app* t, bool root, bool sign) { - return convert_atom(t, root, sign); - } - - bool convert_pb(app* t, bool root, bool sign) { - + if (m_ext && t->get_family_id() == pb.get_family_id()) { + m_frame_stack.push_back(frame(to_app(t), root, sign, 0)); + return false; + } + else { + convert_atom(t, root, sign); + return true; + } } bool process_cached(app * t, bool root, bool sign) { @@ -175,7 +185,8 @@ struct goal2sat::imp { bool visit(expr * t, bool root, bool sign) { if (!is_app(t)) { - return convert_atom(t, root, sign); + convert_atom(t, root, sign); + return true; } if (process_cached(to_app(t), root, sign)) return true; @@ -195,7 +206,10 @@ struct goal2sat::imp { m_frame_stack.push_back(frame(to_app(t), root, sign, 0)); return false; } - return convert_atom(t, root, sign); + else { + convert_atom(t, root, sign); + return true; + } case OP_XOR: case OP_IMPLIES: case OP_DISTINCT: { @@ -205,7 +219,8 @@ struct goal2sat::imp { throw_op_not_handled(strm.str()); } default: - return convert_atom(t, root, sign); + convert_atom(t, root, sign); + return true; } } @@ -361,6 +376,95 @@ struct goal2sat::imp { } } + void convert_pb_args(app* t, sat::literal_vector& lits) { + unsigned num_args = t->get_num_args(); + unsigned sz = m_result_stack.size(); + for (unsigned i = 0; i < num_args; ++i) { + sat::literal lit(m_result_stack[sz - num_args + i]); + if (!m_solver.is_external(lit.var())) { + sat::bool_var w = m_solver.mk_var(true); + sat::literal lit2(w, false); + mk_clause(lit, ~lit2); + mk_clause(~lit, lit2); + lit = lit2; + } + lits.push_back(lit); + } + } + + void convert_at_least_k(app* t, rational k, bool root, bool sign) { + SASSERT(k.is_unsigned()); + sat::literal_vector lits; + unsigned sz = m_result_stack.size(); + convert_pb_args(t, lits); + sat::bool_var v = m_solver.mk_var(true); + sat::literal lit(v, sign); + m_ext->add_at_least(v, lits, k.get_unsigned()); + TRACE("sat", tout << "root: " << root << " lit: " << lit << "\n";); + if (root) { + m_result_stack.reset(); + mk_clause(lit); + } + else { + m_result_stack.shrink(sz - t->get_num_args()); + m_result_stack.push_back(lit); + } + } + + void convert_at_most_k(app* t, rational k, bool root, bool sign) { + SASSERT(k.is_unsigned()); + sat::literal_vector lits; + unsigned sz = m_result_stack.size(); + convert_pb_args(t, lits); + for (unsigned i = 0; i < lits.size(); ++i) { + lits[i].neg(); + } + sat::bool_var v = m_solver.mk_var(true); + m_ext->add_at_least(v, lits, lits.size() - k.get_unsigned() + 1); + if (root) { + m_result_stack.reset(); + mk_clause(sat::literal(v, sign)); + } + else { + m_result_stack.shrink(sz - t->get_num_args()); + m_result_stack.push_back(sat::literal(v, sign)); + } + } + + void convert_eq_k(app* t, rational k, bool root, bool sign) { + SASSERT(k.is_unsigned()); + sat::literal_vector lits; + convert_pb_args(t, lits); + sat::bool_var v1 = m_solver.mk_var(true); + sat::bool_var v2 = m_solver.mk_var(true); + sat::literal l1(v1, false), l2(v2, false); + m_ext->add_at_least(v1, lits, k.get_unsigned()); + for (unsigned i = 0; i < lits.size(); ++i) { + lits[i].neg(); + } + m_ext->add_at_least(v2, lits, lits.size() - k.get_unsigned() + 1); + if (root) { + m_result_stack.reset(); + if (sign) { + mk_clause(~l1, ~l2); + } + else { + mk_clause(l1); + mk_clause(l2); + } + m_result_stack.reset(); + } + else { + sat::bool_var v = m_solver.mk_var(); + sat::literal l(v, false); + mk_clause(~l, l1); + mk_clause(~l, l2); + mk_clause(~l1, ~l2, l); + m_result_stack.shrink(m_result_stack.size() - t->get_num_args()); + m_result_stack.push_back(l); + } + } + void convert(app * t, bool root, bool sign) { if (t->get_family_id() == m.get_basic_family_id()) { switch (to_app(t)->get_decl_kind()) { @@ -376,13 +480,34 @@ struct goal2sat::imp { case OP_IFF: case OP_EQ: convert_iff(t, root, sign); - break; + break; default: UNREACHABLE(); } } - else if (t->get_family_id() == pb.get_fid()) { - NOT_IMPLEMENTED_YET(); + else if (m_ext && t->get_family_id() == pb.get_family_id()) { + switch (t->get_decl_kind()) { + case OP_AT_MOST_K: + convert_at_most_k(t, pb.get_k(t), root, sign); + break; + case OP_AT_LEAST_K: + convert_at_least_k(t, pb.get_k(t), root, sign); + break; + case OP_PB_LE: + SASSERT(pb.has_unit_coefficients(t)); + convert_at_most_k(t, pb.get_k(t), root, sign); + break; + case OP_PB_GE: + SASSERT(pb.has_unit_coefficients(t)); + convert_at_least_k(t, pb.get_k(t), root, sign); + break; + case OP_PB_EQ: + SASSERT(pb.has_unit_coefficients(t)); + convert_eq_k(t, pb.get_k(t), root, sign); + break; + default: + UNREACHABLE(); + } } else { UNREACHABLE();