From f7e49501af7c87f25109de4430214349c7ec479a Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 13 Apr 2018 16:22:36 -0700 Subject: [PATCH] updates Signed-off-by: Nikolaj Bjorner --- src/ast/rewriter/pb2bv_rewriter.cpp | 156 +++++++++++++++++--------- src/sat/sat_config.cpp | 4 +- src/sat/sat_config.h | 3 +- src/sat/sat_solver/inc_sat_solver.cpp | 3 + 4 files changed, 109 insertions(+), 57 deletions(-) diff --git a/src/ast/rewriter/pb2bv_rewriter.cpp b/src/ast/rewriter/pb2bv_rewriter.cpp index 179773c34..9fc6bf7b3 100644 --- a/src/ast/rewriter/pb2bv_rewriter.cpp +++ b/src/ast/rewriter/pb2bv_rewriter.cpp @@ -56,9 +56,7 @@ struct pb2bv_rewriter::imp { rational m_k; vector m_coeffs; bool m_keep_cardinality_constraints; - bool m_keep_pb_constraints; - bool m_pb_num_system; - bool m_pb_totalizer; + symbol m_pb_solver; unsigned m_min_arity; template @@ -85,7 +83,7 @@ struct pb2bv_rewriter::imp { struct compare_coeffs { bool operator()(ca const& x, ca const& y) const { - return x.first < y.first; + return x.first > y.first; } }; @@ -126,11 +124,12 @@ struct pb2bv_rewriter::imp { if (i + 1 < sz && !m_coeffs[i+1].is_neg()) tout << "+ "; } switch (is_le) { - case l_true: tout << "<= "; break; + case l_true: tout << "<= "; break; case l_undef: tout << "= "; break; case l_false: tout << ">= "; break; } tout << k << "\n";); + if (k.is_zero()) { if (is_le != l_false) { return expr_ref(m.mk_not(::mk_or(m_args)), m); @@ -143,7 +142,7 @@ struct pb2bv_rewriter::imp { return expr_ref((is_le == l_false)?m.mk_true():m.mk_false(), m); } - if (m_pb_totalizer) { + if (m_pb_solver == "totalizer") { expr_ref result(m); switch (is_le) { case l_true: if (mk_le_tot(sz, args, k, result)) return result; else break; @@ -152,7 +151,7 @@ struct pb2bv_rewriter::imp { } } - if (m_pb_num_system) { + if (m_pb_solver == "sorting") { expr_ref result(m); switch (is_le) { case l_true: if (mk_le(sz, args, k, result)) return result; else break; @@ -161,6 +160,15 @@ struct pb2bv_rewriter::imp { } } + if (m_pb_solver == "segmented") { + expr_ref result(m); + switch (is_le) { + case l_true: return mk_seg_le(k); + case l_false: return mk_seg_ge(k); + case l_undef: break; + } + } + // fall back to divide and conquer encoding. SASSERT(k.is_pos()); expr_ref zero(m), bound(m); @@ -486,14 +494,82 @@ struct pb2bv_rewriter::imp { return true; } - expr_ref mk_and(expr_ref& a, expr_ref& b) { + /** + \brief Segment based encoding. + The PB terms are partitoned into segments, such that each segment contains arguments with the same cofficient. + The segments are sorted, such that the segment with highest coefficient is first. + Then for each segment create circuits based on sorting networks the arguments of the segment. + */ + + expr_ref mk_seg_ge(rational const& k) { + rational bound(-k); + for (unsigned i = 0; i < m_args.size(); ++i) { + m_args[i] = mk_not(m_args[i].get()); + bound += m_coeffs[i]; + } + return mk_seg_le(bound); + } + + expr_ref mk_seg_le(rational const& k) { + sort_args(); + unsigned sz = m_args.size(); + expr* const* args = m_args.c_ptr(); + + // Create sorted entries. + vector> outs; + vector coeffs; + for (unsigned i = 0, seg_size = 0; i < sz; i += seg_size) { + seg_size = segment_size(i); + ptr_vector out; + m_sort.sorting(seg_size, args + i, out); + out.push_back(m.mk_false()); + outs.push_back(out); + coeffs.push_back(m_coeffs[i]); + } + return mk_seg_le_rec(outs, coeffs, 0, k); + } + + expr_ref mk_seg_le_rec(vector> const& outs, vector const& coeffs, unsigned i, rational const& k) { + rational const& c = coeffs[i]; + ptr_vector const& out = outs[i]; + if (k.is_neg()) { + return expr_ref(m.mk_false(), m); + } + if (i == outs.size()) { + return expr_ref(m.mk_true(), m); + } + if (i + 1 == outs.size() && k >= rational(out.size()-1)*c) { + return expr_ref(m.mk_true(), m); + } + expr_ref_vector fmls(m); + fmls.push_back(m.mk_implies(m.mk_not(out[0]), mk_seg_le_rec(outs, coeffs, i + 1, k))); + rational k1; + for (unsigned j = 0; j + 1 < out.size(); ++j) { + k1 = k - rational(j+1)*c; + if (k1.is_neg()) { + fmls.push_back(m.mk_not(out[j])); + break; + } + fmls.push_back(m.mk_implies(m.mk_and(out[j], m.mk_not(out[j+1])), mk_seg_le_rec(outs, coeffs, i + 1, k1))); + } + return ::mk_and(fmls); + } + + // The number of arguments with the same coefficient. + unsigned segment_size(unsigned start) const { + unsigned i = start; + while (i < m_args.size() && m_coeffs[i] == m_coeffs[start]) ++i; + return i - start; + } + + expr_ref mk_and(expr_ref& a, expr_ref& b) { if (m.is_true(a)) return b; if (m.is_true(b)) return a; if (m.is_false(a)) return a; if (m.is_false(b)) return b; return expr_ref(m.mk_and(a, b), m); } - + expr_ref mk_or(expr_ref& a, expr_ref& b) { if (m.is_true(a)) return a; if (m.is_true(b)) return b; @@ -607,12 +683,12 @@ struct pb2bv_rewriter::imp { m_trail(m), m_args(m), m_keep_cardinality_constraints(false), - m_keep_pb_constraints(false), - m_pb_num_system(false), - m_pb_totalizer(false), + m_pb_solver(symbol("solver")), m_min_arity(9) {} + void set_pb_solver(symbol const& s) { m_pb_solver = s; } + bool mk_app(bool full, func_decl * f, unsigned sz, expr * const* args, expr_ref & result) { if (f->get_family_id() == pb.get_family_id() && mk_pb(full, f, sz, args, result)) { // skip @@ -756,13 +832,13 @@ struct pb2bv_rewriter::imp { result = m_sort.ge(full, pb.get_k(f).get_unsigned(), sz, args); ++m_imp.m_compile_card; } - else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) { + else if (pb.is_eq(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_pb_solver == "solver") { return false; } - else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) { + else if (pb.is_le(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_pb_solver == "solver") { return false; } - else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_keep_pb_constraints) { + else if (pb.is_ge(f) && pb.get_k(f).is_unsigned() && has_small_coefficients(f) && m_pb_solver == "solver") { return false; } else { @@ -811,17 +887,6 @@ struct pb2bv_rewriter::imp { m_keep_cardinality_constraints = f; } - void keep_pb_constraints(bool f) { - m_keep_pb_constraints = f; - } - - void pb_num_system(bool f) { - m_pb_num_system = f; - } - - void pb_totalizer(bool f) { - m_pb_totalizer = f; - } void set_at_most1(sorting_network_encoding enc) { m_sort.cfg().m_encoding = enc; } }; @@ -836,9 +901,7 @@ struct pb2bv_rewriter::imp { } card2bv_rewriter_cfg(imp& i, ast_manager & m):m_r(i, m) {} void keep_cardinality_constraints(bool f) { m_r.keep_cardinality_constraints(f); } - void keep_pb_constraints(bool f) { m_r.keep_pb_constraints(f); } - void pb_num_system(bool f) { m_r.pb_num_system(f); } - void pb_totalizer(bool f) { m_r.pb_totalizer(f); } + void set_pb_solver(symbol const& s) { m_r.set_pb_solver(s); } void set_at_most1(sorting_network_encoding enc) { m_r.set_at_most1(enc); } }; @@ -850,9 +913,7 @@ struct pb2bv_rewriter::imp { rewriter_tpl(m, false, m_cfg), m_cfg(i, m) {} void keep_cardinality_constraints(bool f) { m_cfg.keep_cardinality_constraints(f); } - void keep_pb_constraints(bool f) { m_cfg.keep_pb_constraints(f); } - void pb_num_system(bool f) { m_cfg.pb_num_system(f); } - void pb_totalizer(bool f) { m_cfg.pb_totalizer(f); } + void set_pb_solver(symbol const& s) { m_cfg.set_pb_solver(s); } void set_at_most1(sorting_network_encoding e) { m_cfg.set_at_most1(e); } void rewrite(expr* e, expr_ref& r, proof_ref& p) { expr_ref ee(e, m()); @@ -875,26 +936,15 @@ struct pb2bv_rewriter::imp { gparams::get_module("sat").get_bool("cardinality.solver", false); } - bool keep_pb() const { + symbol pb_solver() const { params_ref const& p = m_params; - return - p.get_bool("keep_pb_constraints", false) || - p.get_bool("sat.pb.solver", false) || - p.get_bool("pb.solver", false) || - gparams::get_module("sat").get_sym("pb.solver", symbol()) == symbol("solver") ; + symbol s = p.get_sym("sat.pb.solver", symbol()); + if (s != symbol()) return s; + s = p.get_sym("pb.solver", symbol()); + if (s != symbol()) return s; + return gparams::get_module("sat").get_sym("pb.solver", symbol("solver")); } - bool pb_num_system() const { - return m_params.get_bool("pb_num_system", false) || - gparams::get_module("sat").get_sym("pb.solver", symbol()) == symbol("sorting"); - } - - bool pb_totalizer() const { - return m_params.get_bool("pb_totalizer", false) || - gparams::get_module("sat").get_sym("pb.solver", symbol()) == symbol("totalizer"); - } - - sorting_network_encoding atmost1_encoding() const { symbol enc = m_params.get_sym("atmost1_encoding", symbol()); if (enc == symbol()) { @@ -920,16 +970,12 @@ struct pb2bv_rewriter::imp { void updt_params(params_ref const & p) { m_params.append(p); m_rw.keep_cardinality_constraints(keep_cardinality()); - m_rw.keep_pb_constraints(keep_pb()); - m_rw.pb_num_system(pb_num_system()); - m_rw.pb_totalizer(pb_totalizer()); + m_rw.set_pb_solver(pb_solver()); m_rw.set_at_most1(atmost1_encoding()); } void collect_param_descrs(param_descrs& r) const { r.insert("keep_cardinality_constraints", CPK_BOOL, "(default: true) retain cardinality constraints (don't bit-blast them) and use built-in cardinality solver"); - r.insert("keep_pb_constraints", CPK_BOOL, "(default: true) retain pb constraints (don't bit-blast them) and use built-in pb solver"); - r.insert("pb_num_system", CPK_BOOL, "(default: false) use pb number system encoding"); - r.insert("pb_totalizer", CPK_BOOL, "(default: false) use pb totalizer encoding"); + r.insert("pb.solver", CPK_SYMBOL, "(default: solver) retain pb constraints (don't bit-blast them) and use built-in pb solver"); } unsigned get_num_steps() const { return m_rw.get_num_steps(); } diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index 2439be36f..812e65abf 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -178,8 +178,10 @@ namespace sat { m_pb_solver = PB_TOTALIZER; else if (s == symbol("solver")) m_pb_solver = PB_SOLVER; + else if (s == symbol("segmented")) + m_pb_solver = PB_SEGMENTED; else - throw sat_param_exception("invalid PB solver: solver, totalizer, circuit, sorting"); + throw sat_param_exception("invalid PB solver: solver, totalizer, circuit, sorting, segmented"); m_card_solver = p.cardinality_solver(); diff --git a/src/sat/sat_config.h b/src/sat/sat_config.h index c70d52a90..6a704ab44 100644 --- a/src/sat/sat_config.h +++ b/src/sat/sat_config.h @@ -54,7 +54,8 @@ namespace sat { PB_SOLVER, PB_CIRCUIT, PB_SORTING, - PB_TOTALIZER + PB_TOTALIZER, + PB_SEGMENTED }; enum reward_t { diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 4d7325ecb..a7499590b 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -281,9 +281,12 @@ public: m_params.append(p); sat_params p1(p); m_params.set_bool("keep_cardinality_constraints", p1.cardinality_solver()); + m_params.set_sym("pb.solver", p1.pb_solver()); + m_params.set_bool("keep_pb_constraints", m_solver.get_config().m_pb_solver == sat::PB_SOLVER); m_params.set_bool("pb_num_system", m_solver.get_config().m_pb_solver == sat::PB_SORTING); m_params.set_bool("pb_totalizer", m_solver.get_config().m_pb_solver == sat::PB_TOTALIZER); + m_params.set_bool("xor_solver", p1.xor_solver()); m_solver.updt_params(m_params); m_solver.set_incremental(is_incremental() && !override_incremental());