From 23b9d3ef5572e2cb0dede71974690876cc5d67ef Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 22 Oct 2016 18:50:16 -0700 Subject: [PATCH] fix at-most-1 constraint compiler bug Signed-off-by: Nikolaj Bjorner --- src/api/api_opt.cpp | 31 ++++++++++ src/api/python/z3/z3.py | 9 +++ src/api/z3_api.h | 2 +- src/api/z3_optimization.h | 27 ++++++++ src/ast/rewriter/arith_rewriter.cpp | 43 ++++++++++++- src/opt/maxres.cpp | 11 ++-- src/opt/opt_context.cpp | 79 ++++++++++++++++++------ src/opt/opt_context.h | 7 ++- src/smt/smt_internalizer.cpp | 2 +- src/smt/theory_arith_aux.h | 2 +- src/smt/theory_arith_int.h | 2 +- src/smt/theory_pb.cpp | 25 +++++--- src/test/sorting_network.cpp | 95 ++++++++++++++++++++++++++--- src/util/mpq.cpp | 4 +- src/util/mpq.h | 2 +- src/util/rational.h | 2 +- src/util/sorting_network.h | 88 ++++++++++++++++++++++---- 17 files changed, 369 insertions(+), 62 deletions(-) diff --git a/src/api/api_opt.cpp b/src/api/api_opt.cpp index 58d8902c3..20eb6d1d4 100644 --- a/src/api/api_opt.cpp +++ b/src/api/api_opt.cpp @@ -27,6 +27,7 @@ Revision History: #include"cancel_eh.h" #include"scoped_timer.h" #include"smt2parser.h" +#include"api_ast_vector.h" extern "C" { @@ -296,6 +297,36 @@ extern "C" { } + Z3_ast_vector Z3_API Z3_optimize_get_assertions(Z3_context c, Z3_optimize o) { + Z3_TRY; + LOG_Z3_optimize_get_assertions(c, o); + RESET_ERROR_CODE(); + Z3_ast_vector_ref * v = alloc(Z3_ast_vector_ref, *mk_c(c), mk_c(c)->m()); + mk_c(c)->save_object(v); + expr_ref_vector hard(mk_c(c)->m()); + to_optimize_ptr(o)->get_hard_constraints(hard); + for (unsigned i = 0; i < hard.size(); i++) { + v->m_ast_vector.push_back(hard[i].get()); + } + RETURN_Z3(of_ast_vector(v)); + Z3_CATCH_RETURN(0); + } + + unsigned Z3_API Z3_optimize_get_num_objectives(Z3_context c, Z3_optimize o) { + RESET_ERROR_CODE(); + return to_optimize_ptr(o)->num_objectives(); + } + + Z3_ast Z3_API Z3_optimize_get_objective(Z3_context c, Z3_optimize o, unsigned index) { + Z3_TRY; + LOG_Z3_optimize_get_objective(c, o, index); + RESET_ERROR_CODE(); + expr_ref result = to_optimize_ptr(o)->get_objective(index); + mk_c(c)->save_ast_trail(result); + RETURN_Z3(of_expr(result)); + Z3_CATCH_RETURN(0); + } + }; diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index c360a9bdc..bc5bd8153 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -6796,6 +6796,15 @@ class Optimize(Z3PPObject): """Parse assertions and objectives from a string""" Z3_optimize_from_string(self.ctx.ref(), self.optimize, s) + def assertions(self): + """Return an AST vector containing all added constraints.""" + return AstVector(Z3_optimize_get_assertions(self.ctx.ref(), self.optimize), self.ctx) + + def objectives(self): + """returns set of objective functions""" + num = Z3_optimize_get_num_objectives(self.ctx.ref(), self.optimize) + return [_to_expr_ref(Z3_optimize_get_objective(self.ctx.ref(), self.optimize, i), self.ctx) for i in range(num)] + def __repr__(self): """Return a formatted string with all added rules and constraints.""" return self.sexpr() diff --git a/src/api/z3_api.h b/src/api/z3_api.h index cd889b3be..9e9771884 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -5832,7 +5832,7 @@ extern "C" { void Z3_API Z3_solver_assert_and_track(Z3_context c, Z3_solver s, Z3_ast a, Z3_ast p); /** - \brief Return the set of asserted formulas as a goal object. + \brief Return the set of asserted formulas on the solver. def_API('Z3_solver_get_assertions', AST_VECTOR, (_in(CONTEXT), _in(SOLVER))) */ diff --git a/src/api/z3_optimization.h b/src/api/z3_optimization.h index 15a6dff16..e78be7881 100644 --- a/src/api/z3_optimization.h +++ b/src/api/z3_optimization.h @@ -239,6 +239,33 @@ extern "C" { def_API('Z3_optimize_get_statistics', STATS, (_in(CONTEXT), _in(OPTIMIZE))) */ Z3_stats Z3_API Z3_optimize_get_statistics(Z3_context c, Z3_optimize d); + + + /** + \brief Return the set of asserted formulas on the optimization context. + + def_API('Z3_optimize_get_assertions', AST_VECTOR, (_in(CONTEXT), _in(OPTIMIZE))) + */ + Z3_ast_vector Z3_API Z3_optimize_get_assertions(Z3_context c, Z3_optimize o); + + /** + \brief Return number of objectives on the optimization context. + + def_API('Z3_optimize_get_num_objectives', UINT, (_in(CONTEXT), _in(OPTIMIZE))) + */ + unsigned Z3_API Z3_optimize_get_num_objectives(Z3_context c, Z3_optimize o); + + /** + \brief Return i'th objective function. If the objective function is a max-sat objective it is returned + as a Pseudo-Boolean (minimization) sum of the form (+ (if f1 w1 0) (if f2 w2 0) ...) + If the objective function is entered as a maximization objective, then return the corresponding minimizaiton + objective. In this way the resulting objective function is always returned as a minimization objective. + + def_API('Z3_optimize_get_objective', AST, (_in(CONTEXT), _in(OPTIMIZE), _in(UINT))) + */ + Z3_ast Z3_API Z3_optimize_get_objective(Z3_context c, Z3_optimize o, unsigned index); + + /*@}*/ /*@}*/ diff --git a/src/ast/rewriter/arith_rewriter.cpp b/src/ast/rewriter/arith_rewriter.cpp index 368476b8e..81385c2af 100644 --- a/src/ast/rewriter/arith_rewriter.cpp +++ b/src/ast/rewriter/arith_rewriter.cpp @@ -162,8 +162,8 @@ bool arith_rewriter::div_polynomial(expr * t, numeral const & g, const_treatment } bool arith_rewriter::is_bound(expr * arg1, expr * arg2, op_kind kind, expr_ref & result) { - numeral c; - if (!is_add(arg1) && is_numeral(arg2, c)) { + numeral b, c; + if (!is_add(arg1) && !m_util.is_mod(arg1) && is_numeral(arg2, c)) { numeral a; bool r = false; expr * pp = get_power_product(arg1, a); @@ -193,6 +193,45 @@ bool arith_rewriter::is_bound(expr * arg1, expr * arg2, op_kind kind, expr_ref & case EQ: result = m_util.mk_eq(pp, k); return true; } } + expr* t1, *t2; + bool is_int; + if (m_util.is_mod(arg2)) { + std::swap(arg1, arg2); + switch (kind) { + case LE: kind = GE; break; + case GE: kind = LE; break; + case EQ: break; + } + } + + if (m_util.is_numeral(arg2, c, is_int) && is_int && + m_util.is_mod(arg1, t1, t2) && m_util.is_numeral(t2, b, is_int) && !b.is_zero()) { + // mod x b <= c = false if c < 0, b != 0, true if c >= b, b != 0 + if (c.is_neg()) { + switch (kind) { + case EQ: + case LE: result = m().mk_false(); return true; + case GE: result = m().mk_true(); return true; + } + } + if (c.is_zero() && kind == GE) { + result = m().mk_true(); + return true; + } + if (c.is_pos() && c >= abs(b)) { + switch (kind) { + case LE: result = m().mk_true(); return true; + case EQ: + case GE: result = m().mk_false(); return true; + } + } + // mod x b <= b - 1 + if (c + rational::one() == abs(b) && kind == LE) { + result = m().mk_true(); + return true; + } + } + return false; } diff --git a/src/opt/maxres.cpp b/src/opt/maxres.cpp index d9f060784..d8b166924 100644 --- a/src/opt/maxres.cpp +++ b/src/opt/maxres.cpp @@ -297,12 +297,16 @@ public: sort_assumptions(mutex); ptr_vector core(mutex.size(), mutex.c_ptr()); remove_soft(core, m_asms); - rational weight(0); + rational weight(0), sum1(0), sum2(0); + for (unsigned i = 0; i < mutex.size(); ++i) { + sum1 += get_weight(mutex[i].get()); + } while (!mutex.empty()) { expr_ref soft = mk_or(mutex); rational w = get_weight(mutex.back()); weight = w - weight; m_lower += weight*rational(mutex.size()-1); + sum2 += weight*rational(mutex.size()); add_soft(soft, weight); mutex.pop_back(); while (!mutex.empty() && get_weight(mutex.back()) == w) { @@ -310,6 +314,7 @@ public: } weight = w; } + SASSERT(sum1 == sum2); } lbool check_sat_hill_climb(expr_ref_vector& asms1) { @@ -398,7 +403,7 @@ public: while (is_sat == l_false) { core.reset(); s().get_unsat_core(core); - //verify_core(core); + // verify_core(core); model_ref mdl; get_mus_model(mdl); is_sat = minimize_core(core); @@ -772,8 +777,6 @@ public: for (unsigned i = 0; i < m_soft.size(); ++i) { m_assignment[i] = is_true(m_soft[i]); } - - DEBUG_CODE(verify_assignment();); diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index 07bb8385b..1335685ff 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -181,6 +181,43 @@ namespace opt { clear_state(); } + void context::get_hard_constraints(expr_ref_vector& hard) { + hard.append(m_scoped_state.m_hard); + } + + expr_ref context::get_objective(unsigned i) { + SASSERT(i < num_objectives()); + objective const& o = m_scoped_state.m_objectives[i]; + expr_ref result(m), zero(m); + expr_ref_vector args(m); + switch (o.m_type) { + case O_MAXSMT: + zero = m_arith.mk_numeral(rational(0), false); + for (unsigned i = 0; i < o.m_terms.size(); ++i) { + args.push_back(m.mk_ite(o.m_terms[i], zero, m_arith.mk_numeral(o.m_weights[i], false))); + } + result = m_arith.mk_add(args.size(), args.c_ptr()); + break; + case O_MAXIMIZE: + result = o.m_term; + if (m_arith.is_arith_expr(result)) { + result = m_arith.mk_uminus(result); + } + else if (m_bv.is_bv(result)) { + result = m_bv.mk_bv_neg(result); + } + else { + UNREACHABLE(); + } + break; + case O_MINIMIZE: + result = o.m_term; + break; + } + return result; + } + + unsigned context::add_soft_constraint(expr* f, rational const& w, symbol const& id) { clear_state(); return m_scoped_state.add(f, w, id); @@ -1328,14 +1365,21 @@ namespace opt { } std::string context::to_string() const { + return to_string(m_scoped_state.m_hard, m_scoped_state.m_objectives); + } + + std::string context::to_string_internal() const { + return to_string(m_hard_constraints, m_objectives); + } + + std::string context::to_string(expr_ref_vector const& hard, vector const& objectives) const { smt2_pp_environment_dbg env(m); ast_pp_util visitor(m); std::ostringstream out; -#define PP(_e_) ast_smt2_pp(out, _e_, env); - visitor.collect(m_scoped_state.m_hard); + visitor.collect(hard); - for (unsigned i = 0; i < m_scoped_state.m_objectives.size(); ++i) { - objective const& obj = m_scoped_state.m_objectives[i]; + for (unsigned i = 0; i < objectives.size(); ++i) { + objective const& obj = objectives[i]; switch(obj.m_type) { case O_MAXIMIZE: case O_MINIMIZE: @@ -1351,33 +1395,34 @@ namespace opt { } visitor.display_decls(out); - visitor.display_asserts(out, m_scoped_state.m_hard, m_pp_neat); - for (unsigned i = 0; i < m_scoped_state.m_objectives.size(); ++i) { - objective const& obj = m_scoped_state.m_objectives[i]; + visitor.display_asserts(out, hard, m_pp_neat); + for (unsigned i = 0; i < objectives.size(); ++i) { + objective const& obj = objectives[i]; switch(obj.m_type) { case O_MAXIMIZE: out << "(maximize "; - PP(obj.m_term); + ast_smt2_pp(out, obj.m_term, env); out << ")\n"; break; case O_MINIMIZE: out << "(minimize "; - PP(obj.m_term); + ast_smt2_pp(out, obj.m_term, env); out << ")\n"; break; case O_MAXSMT: for (unsigned j = 0; j < obj.m_terms.size(); ++j) { out << "(assert-soft "; - PP(obj.m_terms[j]); + ast_smt2_pp(out, obj.m_terms[j], env); rational w = obj.m_weights[j]; - if (w.is_int()) { - out << " :weight " << w; - } - else { - out << " :dweight " << w; - } + + w.display_decimal(out << " :weight ", 3, true); if (obj.m_id != symbol::null) { - out << " :id " << obj.m_id; + if (is_smt2_quoted_symbol(obj.m_id)) { + out << " :id " << mk_smt2_quoted_symbol(obj.m_id); + } + else { + out << " :id " << obj.m_id; + } } out << ")\n"; } diff --git a/src/opt/opt_context.h b/src/opt/opt_context.h index ac1fe8e7a..18af756bf 100644 --- a/src/opt/opt_context.h +++ b/src/opt/opt_context.h @@ -175,6 +175,8 @@ namespace opt { unsigned add_objective(app* t, bool is_max); void add_hard_constraint(expr* f); + void get_hard_constraints(expr_ref_vector& hard); + expr_ref get_objective(unsigned i); virtual void push(); virtual void pop(unsigned n); @@ -208,7 +210,7 @@ namespace opt { std::string to_string() const; - virtual unsigned num_objectives() { return m_objectives.size(); } + virtual unsigned num_objectives() { return m_scoped_state.m_objectives.size(); } virtual expr_ref mk_gt(unsigned i, model_ref& model); virtual expr_ref mk_ge(unsigned i, model_ref& model); virtual expr_ref mk_le(unsigned i, model_ref& model); @@ -284,6 +286,9 @@ namespace opt { void display_objective(std::ostream& out, objective const& obj) const; void display_bounds(std::ostream& out, bounds_t const& b) const; + std::string to_string(expr_ref_vector const& hard, vector const& objectives) const; + std::string to_string_internal() const; + void validate_lex(); diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index d4880f7d7..94ce453b4 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -1281,7 +1281,7 @@ namespace smt { The deletion event handler is ignored if binary clause optimization is applicable. */ clause * context::mk_clause(unsigned num_lits, literal * lits, justification * j, clause_kind k, clause_del_eh * del_eh) { - TRACE("mk_clause", tout << "creating clause:\n"; display_literals(tout, num_lits, lits); tout << "\n";); + TRACE("mk_clause", tout << "creating clause:\n"; display_literals_verbose(tout, num_lits, lits); tout << "\n";); switch (k) { case CLS_AUX: { literal_buffer simp_lits; diff --git a/src/smt/theory_arith_aux.h b/src/smt/theory_arith_aux.h index 163452d47..d2db3a603 100644 --- a/src/smt/theory_arith_aux.h +++ b/src/smt/theory_arith_aux.h @@ -1709,7 +1709,7 @@ namespace smt { SASSERT(!maintain_integrality || valid_assignment()); SASSERT(satisfy_bounds()); } - TRACE("opt", display(tout);); + TRACE("opt_verbose", display(tout);); return (best_efforts>0 || ctx.get_cancel_flag())?BEST_EFFORT:result; } diff --git a/src/smt/theory_arith_int.h b/src/smt/theory_arith_int.h index d3b1f0f10..c06f82c8b 100644 --- a/src/smt/theory_arith_int.h +++ b/src/smt/theory_arith_int.h @@ -1385,7 +1385,7 @@ namespace smt { m_branch_cut_counter++; // TODO: add giveup code if (m_branch_cut_counter % m_params.m_arith_branch_cut_ratio == 0) { - TRACE("opt", display(tout);); + TRACE("opt_verbose", display(tout);); move_non_base_vars_to_bounds(); if (!make_feasible()) { TRACE("arith_int", tout << "failed to move variables to bounds.\n";); diff --git a/src/smt/theory_pb.cpp b/src/smt/theory_pb.cpp index 90cd020c3..384c58d73 100644 --- a/src/smt/theory_pb.cpp +++ b/src/smt/theory_pb.cpp @@ -321,7 +321,8 @@ namespace smt { if (m_simplex.upper_valid(v)) { m_simplex.get_upper(v, last_bound); if (m_mpq_inf_mgr.gt(bound, last_bound)) { - literal lit = m_explain_upper.get(v, null_literal); + literal lit = m_explain_upper.get(v, null_literal); + TRACE("pb", tout << ~lit << " " << ~explain << "\n";); get_context().mk_clause(~lit, ~explain, justify(~lit, ~explain)); return false; } @@ -342,6 +343,7 @@ namespace smt { m_simplex.get_lower(v, last_bound); if (m_mpq_inf_mgr.gt(last_bound, bound)) { literal lit = m_explain_lower.get(v, null_literal); + TRACE("pb", tout << ~lit << " " << ~explain << "\n";); get_context().mk_clause(~lit, ~explain, justify(~lit, ~explain)); return false; } @@ -405,6 +407,7 @@ namespace smt { if (proofs_enabled()) { js = alloc(theory_lemma_justification, get_id(), ctx, lits.size(), lits.c_ptr()); } + TRACE("pb", tout << lits << "\n";); ctx.mk_clause(lits.size(), lits.c_ptr(), js, CLS_AUX_LEMMA, 0); return false; @@ -515,11 +518,10 @@ namespace smt { ++log; n *= 2; } - unsigned th = args.size()*log; // 10* + unsigned th = args.size()*log; c->m_compilation_threshold = th; - IF_VERBOSE(2, verbose_stream() << "(smt.pb setting compilation threhshold to " << th << ")\n";); + IF_VERBOSE(2, verbose_stream() << "(smt.pb setting compilation threshold to " << th << " " << c->k() << ")\n";); TRACE("pb", tout << "compilation threshold: " << th << "\n";); - //compile_ineq(*c); } else { c->m_compilation_threshold = UINT_MAX; @@ -1247,9 +1249,9 @@ namespace smt { literal_vector in; for (unsigned i = 0; i < num_args; ++i) { rational n = c.coeff(i); - lbool val = ctx.get_assignment(c.lit()); - if (val != l_undef && - ctx.get_assign_level(thl) == ctx.get_base_level()) { + literal lit = c.lit(i); + lbool val = ctx.get_assignment(lit); + if (val != l_undef && ctx.get_assign_level(lit) == ctx.get_base_level()) { if (val == l_true) { unsigned m = n.get_unsigned(); if (k < m) { @@ -1264,6 +1266,8 @@ namespace smt { n -= rational::one(); } } + + TRACE("pb", tout << in << " >= " << k << "\n";); if (ctx.get_assignment(thl) == l_true && @@ -1272,6 +1276,7 @@ namespace smt { psort_nw sortnw(ps); sortnw.m_stats.reset(); at_least_k = sortnw.ge(false, k, in.size(), in.c_ptr()); + TRACE("pb", tout << ~thl << " " << at_least_k << "\n";); ctx.mk_clause(~thl, at_least_k, justify(~thl, at_least_k)); m_stats.m_num_compiled_vars += sortnw.m_stats.m_num_compiled_vars; m_stats.m_num_compiled_clauses += sortnw.m_stats.m_num_compiled_clauses; @@ -1281,6 +1286,7 @@ namespace smt { psort_nw sortnw(ps); sortnw.m_stats.reset(); literal at_least_k = sortnw.ge(true, k, in.size(), in.c_ptr()); + TRACE("pb", tout << ~thl << " " << at_least_k << "\n";); ctx.mk_clause(~thl, at_least_k, justify(~thl, at_least_k)); ctx.mk_clause(~at_least_k, thl, justify(thl, ~at_least_k)); m_stats.m_num_compiled_vars += sortnw.m_stats.m_num_compiled_vars; @@ -1290,7 +1296,6 @@ namespace smt { << "(smt.pb compile sorting network bound: " << k << " literals: " << in.size() << ")\n";); - TRACE("pb", tout << thl << "\n";); // auxiliary clauses get removed when popping scopes. // we have to recompile the circuit after back-tracking. c.m_compiled = l_false; @@ -1300,7 +1305,6 @@ namespace smt { void theory_pb::init_search_eh() { - m_to_compile.reset(); } void theory_pb::push_scope_eh() { @@ -1329,6 +1333,7 @@ namespace smt { m_ineq_rep.erase(r_info.m_rep); } } + m_to_compile.erase(c); dealloc(c); } m_ineqs_lim.resize(new_lim); @@ -1454,6 +1459,7 @@ namespace smt { if (proofs_enabled()) { js = alloc(theory_lemma_justification, get_id(), ctx, lits.size(), lits.c_ptr()); } + TRACE("pb", tout << lits << "\n";); ctx.mk_clause(lits.size(), lits.c_ptr(), js, CLS_AUX_LEMMA, 0); } @@ -1760,6 +1766,7 @@ namespace smt { for (unsigned i = 0; i < m_ineq_literals.size(); ++i) { m_ineq_literals[i].neg(); } + TRACE("pb", tout << m_ineq_literals << "\n";); ctx.mk_clause(m_ineq_literals.size(), m_ineq_literals.c_ptr(), justify(m_ineq_literals), CLS_AUX_LEMMA, 0); break; default: { diff --git a/src/test/sorting_network.cpp b/src/test/sorting_network.cpp index 57e818542..8b2aadee3 100644 --- a/src/test/sorting_network.cpp +++ b/src/test/sorting_network.cpp @@ -332,31 +332,106 @@ void test_sorting5(unsigned n, unsigned k) { test_sorting_ge(n, k); } -void test_at_most_1(unsigned n) { +expr_ref naive_at_most1(expr_ref_vector const& xs) { + ast_manager& m = xs.get_manager(); + expr_ref_vector clauses(m); + for (unsigned i = 0; i < xs.size(); ++i) { + for (unsigned j = i + 1; j < xs.size(); ++j) { + clauses.push_back(m.mk_not(m.mk_and(xs[i], xs[j]))); + } + } + return mk_and(clauses); +} + +void test_at_most_1(unsigned n, bool full) { ast_manager m; reg_decl_plugins(m); expr_ref_vector in(m), out(m); for (unsigned i = 0; i < n; ++i) { in.push_back(m.mk_fresh_const("a",m.mk_bool_sort())); } - + ast_ext2 ext(m); + psort_nw sn(ext); + expr_ref result1(m), result2(m); + result1 = sn.le(full, 1, in.size(), in.c_ptr()); + result2 = naive_at_most1(in); + + std::cout << "clauses: " << ext.m_clauses << "\n-----\n"; + + smt_params fp; + smt::kernel solver(m, fp); + for (unsigned i = 0; i < ext.m_clauses.size(); ++i) { + solver.assert_expr(ext.m_clauses[i].get()); + } + lbool res; + if (full) { + solver.push(); + solver.assert_expr(m.mk_not(m.mk_eq(result1, result2))); + + std::cout << result1 << "\n"; + + res = solver.check(); + SASSERT(res == l_false); + + solver.pop(1); + } + + if (n >= 9) return; + for (unsigned i = 0; i < static_cast(1 << n); ++i) { + std::cout << "checking: " << n << ": " << i << "\n"; + solver.push(); + unsigned k = 0; + for (unsigned j = 0; j < n; ++j) { + bool is_true = (i & (1 << j)) != 0; + expr_ref atom(m); + atom = is_true ? in[j].get() : m.mk_not(in[j].get()); + solver.assert_expr(atom); + std::cout << atom << "\n"; + if (is_true) ++k; + } + res = solver.check(); + SASSERT(res == l_true); + if (k > 1) { + solver.assert_expr(result1); + } + else if (!full) { + solver.pop(1); + continue; + } + else { + solver.assert_expr(m.mk_not(result1)); + } + res = solver.check(); + SASSERT(res == l_false); + solver.pop(1); + } +} + + +static void test_at_most1() { + ast_manager m; + reg_decl_plugins(m); + expr_ref_vector in(m), out(m); + for (unsigned i = 0; i < 5; ++i) { + in.push_back(m.mk_fresh_const("a",m.mk_bool_sort())); + } + in[4] = in[3]; + ast_ext2 ext(m); psort_nw sn(ext); expr_ref result(m); - result = sn.le(false, 1, in.size(), in.c_ptr()); + result = sn.le(true, 1, in.size(), in.c_ptr()); std::cout << result << "\n"; std::cout << ext.m_clauses << "\n"; } void tst_sorting_network() { - test_at_most_1(1); - test_at_most_1(2); - test_at_most_1(3); - test_at_most_1(4); - test_at_most_1(5); - test_at_most_1(10); - return; + for (unsigned i = 1; i < 17; ++i) { + test_at_most_1(i, true); + test_at_most_1(i, false); + } + test_at_most1(); test_sorting_eq(11,7); for (unsigned n = 3; n < 20; n += 2) { diff --git a/src/util/mpq.cpp b/src/util/mpq.cpp index df4d207a6..feb051033 100644 --- a/src/util/mpq.cpp +++ b/src/util/mpq.cpp @@ -153,7 +153,7 @@ void mpq_manager::display_smt2(std::ostream & out, mpq const & a, bool de } template -void mpq_manager::display_decimal(std::ostream & out, mpq const & a, unsigned prec) { +void mpq_manager::display_decimal(std::ostream & out, mpq const & a, unsigned prec, bool truncate) { mpz n1, d1, v1; get_numerator(a, n1); get_denominator(a, d1); @@ -177,7 +177,7 @@ void mpq_manager::display_decimal(std::ostream & out, mpq const & a, unsi if (is_zero(n1)) goto end; // number is precise } - out << "?"; + if (!truncate) out << "?"; end: del(ten); del(n1); del(d1); del(v1); } diff --git a/src/util/mpq.h b/src/util/mpq.h index 0a643c650..093cc0a44 100644 --- a/src/util/mpq.h +++ b/src/util/mpq.h @@ -265,7 +265,7 @@ public: void display_smt2(std::ostream & out, mpq const & a, bool decimal) const; - void display_decimal(std::ostream & out, mpq const & a, unsigned prec); + void display_decimal(std::ostream & out, mpq const & a, unsigned prec, bool truncate = false); void add(mpz const & a, mpz const & b, mpz & c) { mpz_manager::add(a, b, c); } diff --git a/src/util/rational.h b/src/util/rational.h index fc25837c6..ba447fca6 100644 --- a/src/util/rational.h +++ b/src/util/rational.h @@ -86,7 +86,7 @@ public: void display(std::ostream & out) const { return m().display(out, m_val); } - void display_decimal(std::ostream & out, unsigned prec) const { return m().display_decimal(out, m_val, prec); } + void display_decimal(std::ostream & out, unsigned prec, bool truncate = false) const { return m().display_decimal(out, m_val, prec, truncate); } bool is_uint64() const { return m().is_uint64(m_val); } diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h index 242d4f43e..31ad8a452 100644 --- a/src/util/sorting_network.h +++ b/src/util/sorting_network.h @@ -234,19 +234,28 @@ Notes: private: literal mk_at_most_1(bool full, unsigned n, literal const* xs) { + TRACE("pb", tout << (full?"full":"partial") << " "; + for (unsigned i = 0; i < n; ++i) tout << xs[i] << " "; + tout << "\n";); + + if (!full && n >= 4) { + return mk_at_most_1_bimander(n, xs); + } literal_vector in(n, xs); - literal result = ctx.fresh(); + literal result = fresh(); unsigned inc_size = 4; + literal_vector ands; + ands.push_back(result); while (!in.empty()) { literal_vector ors; unsigned i = 0; unsigned n = in.size(); bool last = n <= inc_size; for (; i + inc_size < n; i += inc_size) { - mk_at_most_1_small(full, last, inc_size, in.c_ptr() + i, result, ors); + mk_at_most_1_small(full, last, inc_size, in.c_ptr() + i, result, ands, ors); } if (i < n) { - mk_at_most_1_small(full, last, n - i, in.c_ptr() + i, result, ors); + mk_at_most_1_small(full, last, n - i, in.c_ptr() + i, result, ands, ors); } if (last) { break; @@ -255,12 +264,22 @@ Notes: in.append(ors); ors.reset(); } + if (full) { + add_clause(ands); + } return result; } - void mk_at_most_1_small(bool full, bool last, unsigned n, literal const* xs, literal result, literal_vector& ors) { + void mk_at_most_1_small(bool full, bool last, unsigned n, literal const* xs, literal result, literal_vector& ands, literal_vector& ors) { + SASSERT(n > 0); + if (n == 1) { + if (!last) { + ors.push_back(xs[0]); + } + return; + } if (!last) { - literal ex = ctx.fresh(); + literal ex = fresh(); for (unsigned j = 0; j < n; ++j) { add_clause(ctx.mk_not(xs[j]), ex); } @@ -271,16 +290,59 @@ Notes: } ors.push_back(ex); } + // result => xs[0] + ... + xs[n-1] <= 1 for (unsigned i = 0; i < n; ++i) { for (unsigned j = i + 1; j < n; ++j) { add_clause(ctx.mk_not(result), ctx.mk_not(xs[i]), ctx.mk_not(xs[j])); } - if (full) { - add_clause(result, xs[i]); + } + // xs[0] + ... + xs[n-1] <= 1 => and_x + if (full) { + literal and_i = fresh(); + for (unsigned i = 0; i < n; ++i) { + literal_vector lits; + lits.push_back(and_i); + for (unsigned j = 0; j < n; ++j) { + if (j != i) lits.push_back(xs[j]); + } + add_clause(lits); } + ands.push_back(ctx.mk_not(and_i)); } } + literal mk_at_most_1_bimander(unsigned n, literal const* xs) { + literal_vector in(n, xs); + literal result = fresh(); + unsigned inc_size = 2; + bool last = false; + bool full = false; + literal_vector ors, ands; + unsigned i = 0; + for (; i + inc_size < n; i += inc_size) { + mk_at_most_1_small(full, last, inc_size, in.c_ptr() + i, result, ands, ors); + } + if (i < n) { + mk_at_most_1_small(full, last, n - i, in.c_ptr() + i, result, ands, ors); + } + + unsigned nbits = 0; + while (static_cast(1 << nbits) < ors.size()) { + ++nbits; + } + literal_vector bits; + for (unsigned k = 0; k < nbits; ++k) { + bits.push_back(fresh()); + } + for (i = 0; i < ors.size(); ++i) { + for (unsigned k = 0; k < nbits; ++k) { + bool bit_set = (i & (static_cast(1 << k))) != 0; + add_clause(ctx.mk_not(result), ctx.mk_not(ors[i]), bit_set ? bits[k] : ctx.mk_not(bits[k])); + } + } + return result; + } + std::ostream& pp(std::ostream& out, unsigned n, literal const* lits) { for (unsigned i = 0; i < n; ++i) ctx.pp(out, lits[i]) << " "; return out; @@ -344,9 +406,13 @@ Notes: literal lits[2] = { l1, l2 }; add_clause(2, lits); } + void add_clause(literal_vector const& lits) { + add_clause(lits.size(), lits.c_ptr()); + } void add_clause(unsigned n, literal const* ls) { m_stats.m_num_compiled_clauses++; literal_vector tmp(n, ls); + TRACE("pb", for (unsigned i = 0; i < n; ++i) tout << ls[i] << " "; tout << "\n";); ctx.mk_clause(n, tmp.c_ptr()); } @@ -383,7 +449,7 @@ Notes: } void card(unsigned k, unsigned n, literal const* xs, literal_vector& out) { - TRACE("pb", tout << "card k:" << k << " n: " << n << "\n";); + TRACE("pb", tout << "card k: " << k << " n: " << n << "\n";); if (n <= k) { psort_nw::sorting(n, xs, out); } @@ -397,7 +463,7 @@ Notes: card(k, n-l, xs + l, out2); smerge(k, out1.size(), out1.c_ptr(), out2.size(), out2.c_ptr(), out); } - TRACE("pb", tout << "card k:" << k << " n: " << n << "\n"; + TRACE("pb", tout << "card k: " << k << " n: " << n << "\n"; pp(tout << "in:", n, xs) << "\n"; pp(tout << "out:", out) << "\n";); @@ -743,7 +809,7 @@ Notes: if (j < b) { ls.push_back(as[i]); ls.push_back(bs[j]); - add_clause(ls.size(), ls.c_ptr()); + add_clause(ls); ls.pop_back(); ls.pop_back(); } @@ -804,7 +870,7 @@ Notes: pp(tout, lits) << "\n";); SASSERT(k + offset <= n); if (k == 0) { - add_clause(lits.size(), lits.c_ptr()); + add_clause(lits); return; } for (unsigned i = offset; i < n - k + 1; ++i) {