From 05a39cb2cfea64d834da636f61d6a8181dabf1ab Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 9 May 2014 08:51:07 -0700 Subject: [PATCH] fix wrong simplex backtracking Signed-off-by: Nikolaj Bjorner --- src/ast/pb_decl_plugin.cpp | 2 +- src/math/simplex/simplex.h | 3 +- src/math/simplex/simplex_def.h | 64 ++++++++++++++++++++-- src/opt/opt_sls_solver.h | 2 + src/opt/pb_sls.cpp | 21 +++++--- src/opt/weighted_maxsat.cpp | 67 +++++++++++++---------- src/smt/theory_pb.cpp | 12 ++--- src/smt/theory_pb.h | 5 +- src/test/simplex.cpp | 2 - src/test/theory_pb.cpp | 99 ++++++++++++++++++++++++++++++++++ src/util/sorting_network.h | 10 ++-- 11 files changed, 231 insertions(+), 56 deletions(-) diff --git a/src/ast/pb_decl_plugin.cpp b/src/ast/pb_decl_plugin.cpp index c04ed5e8c..0cfe1c096 100644 --- a/src/ast/pb_decl_plugin.cpp +++ b/src/ast/pb_decl_plugin.cpp @@ -144,7 +144,7 @@ app * pb_util::mk_lt(unsigned num_args, rational const * _coeffs, expr * const * args.push_back(f); } else { - args.push_back(m.mk_not(f)); + args.push_back(m.mk_not(_args[i])); } } if (!d.is_one()) { diff --git a/src/math/simplex/simplex.h b/src/math/simplex/simplex.h index 7a5ba2d70..3f164f275 100644 --- a/src/math/simplex/simplex.h +++ b/src/math/simplex/simplex.h @@ -126,7 +126,7 @@ namespace simplex { row get_infeasible_row(); var_t get_base_var(row const& r) const { return m_row2base[r.id()]; } numeral const& get_base_coeff(row const& r) const { return m_vars[m_row2base[r.id()]].m_base_coeff; } - void del_row(row const& r); + void del_row(var_t base_var); void set_lower(var_t var, eps_numeral const& b); void set_upper(var_t var, eps_numeral const& b); void get_lower(var_t var, scoped_eps_numeral& b) const { b = m_vars[var].m_lower; } @@ -157,6 +157,7 @@ namespace simplex { private: + void del_row(row const& r); var_t select_var_to_fix(); pivot_strategy_t pivot_strategy(); var_t select_smallest_var() { return m_to_patch.empty()?null_var:m_to_patch.erase_min(); } diff --git a/src/math/simplex/simplex_def.h b/src/math/simplex/simplex_def.h index c3a2bef95..c1a4ec36e 100644 --- a/src/math/simplex/simplex_def.h +++ b/src/math/simplex/simplex_def.h @@ -114,6 +114,7 @@ namespace simplex { em.set(m_vars[base_var].m_value, value); add_patch(base_var); SASSERT(well_formed_row(r)); + SASSERT(well_formed()); return r; } @@ -135,10 +136,51 @@ namespace simplex { template void simplex::del_row(row const& r) { - TRACE("simplex", tout << r.id() << "\n";); - m_vars[m_row2base[r.id()]].m_is_base = false; + var_t var = m_row2base[r.id()]; + m_vars[var].m_is_base = false; + m_vars[var].m_lower_valid = false; + m_vars[var].m_upper_valid = false; m_row2base[r.id()] = null_var; M.del(r); + SASSERT(M.col_begin(var) == M.col_end(var)); + SASSERT(well_formed()); + } + + template + void simplex::del_row(var_t var) { + TRACE("simplex", tout << var << "\n";); + row r; + if (is_base(var)) { + r = row(m_vars[var].m_base2row); + } + else { + col_iterator it = M.col_begin(var), end = M.col_end(var); + if (it == end) { + return; + } + typename matrix::row_entry const& re = it.get_row_entry(); + r = it.get_row(); + var_t old_base = m_row2base[r.id()]; + scoped_eps_numeral new_value(em); + var_info& vi = m_vars[old_base]; + if (below_lower(old_base)) { + new_value = vi.m_lower; + } + else if (above_upper(old_base)) { + new_value = vi.m_upper; + } + else { + new_value = vi.m_value; + } + // need to move var such that old_base comes in bound. + update_and_pivot(old_base, var, re.m_coeff, new_value); + SASSERT(is_base(var)); + SASSERT(m_vars[var].m_base2row == r.id()); + SASSERT(!below_lower(old_base) && !above_upper(old_base)); + } + del_row(r); + TRACE("simplex", display(tout);); + SASSERT(well_formed()); } template @@ -164,6 +206,7 @@ namespace simplex { em.sub(b, vi.m_value, delta); update_value(var, delta); } + SASSERT(well_formed()); } template @@ -177,6 +220,7 @@ namespace simplex { em.sub(b, vi.m_value, delta); update_value(var, delta); } + SASSERT(well_formed()); } template @@ -194,6 +238,7 @@ namespace simplex { scoped_eps_numeral delta(em); em.sub(b, m_vars[var].m_value, delta); update_value(var, delta); + SASSERT(well_formed()); } template @@ -345,6 +390,7 @@ namespace simplex { SASSERT(well_formed_row(row(r_k))); } } + SASSERT(well_formed()); } template @@ -883,7 +929,13 @@ namespace simplex { var_t s = m_row2base[i]; if (s == null_var) continue; SASSERT(i == m_vars[s].m_base2row); - SASSERT(well_formed_row(row(i))); + VERIFY(well_formed_row(row(i))); + } + for (unsigned i = 0; i < m_vars.size(); ++i) { + if (!is_base(i)) { + SASSERT(!above_upper(i)); + SASSERT(!below_lower(i)); + } } return true; } @@ -909,7 +961,11 @@ namespace simplex { sum += tmp; SASSERT(s != it->m_var || m.eq(m_vars[s].m_base_coeff, it->m_coeff)); } - SASSERT(em.is_zero(sum)); + if (!em.is_zero(sum)) { + IF_VERBOSE(0, M.display_row(verbose_stream(), r);); + TRACE("pb", display(tout << "non-well formed row\n"); M.display_row(tout << "row: ", r);); + throw default_exception("non-well formed row"); + } return true; } diff --git a/src/opt/opt_sls_solver.h b/src/opt/opt_sls_solver.h index e2cac1d89..817537a75 100644 --- a/src/opt/opt_sls_solver.h +++ b/src/opt/opt_sls_solver.h @@ -215,6 +215,7 @@ namespace opt { } (*m_pbsls.get())(); m_pbsls->get_model(m_model); + mdl = m_model.get(); } void bvsls_opt(model_ref& mdl) { @@ -235,6 +236,7 @@ namespace opt { SASSERT(res.is_sat == l_true || res.is_sat == l_undef); if (res.is_sat == l_true) { m_bvsls->get_model(m_model); + mdl = m_model.get(); } } diff --git a/src/opt/pb_sls.cpp b/src/opt/pb_sls.cpp index 1052e44b6..a702b6b7b 100644 --- a/src/opt/pb_sls.cpp +++ b/src/opt/pb_sls.cpp @@ -129,7 +129,8 @@ namespace smt { m_trail(m), one(mgr) { - init_max_flips(); + reset(); + one = mpz(1); } ~imp() { @@ -153,7 +154,6 @@ namespace smt { m_assignment.push_back(true); m_hard_occ.push_back(unsigned_vector()); m_soft_occ.push_back(unsigned_vector()); - one = mpz(1); } void init_max_flips() { @@ -406,7 +406,14 @@ namespace smt { else if (break_count == min_bc && m_rng(5) == 1) { min_bc_index = i; } - VERIFY(-break_count == flip(~lit)); + int new_break_count = flip(~lit); + if (-break_count != new_break_count) { + verbose_stream() << lit << "\n"; + IF_VERBOSE(0, display(verbose_stream(), cls);); + display(verbose_stream()); + exit(0); + } + // VERIFY(-break_count == flip(~lit)); } if (m_rng(100) <= m_non_greedy_percent) { lit = cls.m_lits[m_rng(cls.m_lits.size())]; @@ -472,11 +479,12 @@ namespace smt { m_assignment[l.var()] = !m_assignment[l.var()]; int break_count = 0; unsigned_vector const& occh = m_hard_occ[l.var()]; - scoped_mpz value(mgr); + scoped_mpz value(mgr); for (unsigned i = 0; i < occh.size(); ++i) { unsigned j = occh[i]; - value = m_clauses[j].m_value; - if (eval(m_clauses[j])) { + clause& cls = m_clauses[j]; + value = cls.m_value; + if (eval(cls)) { if (m_hard_false.contains(j)) { break_count--; m_hard_false.remove(j); @@ -488,7 +496,6 @@ namespace smt { m_hard_false.insert(j); } else if (value < m_clauses[j].m_value) { - break_count++; } } } diff --git a/src/opt/weighted_maxsat.cpp b/src/opt/weighted_maxsat.cpp index fbd028338..112b0522a 100644 --- a/src/opt/weighted_maxsat.cpp +++ b/src/opt/weighted_maxsat.cpp @@ -82,6 +82,7 @@ namespace opt { } } virtual void get_model(model_ref& mdl) { mdl = m_model.get(); } + void set_model() { s().get_model(m_model); } virtual void updt_params(params_ref& p) { m_params.copy(p); s().updt_params(p); @@ -104,8 +105,12 @@ namespace opt { m_upper.reset(); m_assignment.reset(); for (unsigned i = 0; i < m_weights.size(); ++i) { - m_upper += m_weights[i]; - m_assignment.push_back(false); + expr_ref val(m); + VERIFY(m_model->eval(m_soft[i].get(), val)); + m_assignment.push_back(m.is_true(val)); + if (!m_assignment.back()) { + m_upper += m_weights[i]; + } } } expr* mk_not(expr* e) { @@ -255,6 +260,14 @@ namespace opt { m_upper += rational(1); } + void process_sat() { + svector assignment; + update_assignment(assignment); + if (check_lazy_soft(assignment)) { + update_sigmas(); + } + } + public: bcd2(solver* s, ast_manager& m): maxsmt_solver_base(s, m), @@ -263,7 +276,6 @@ namespace opt { m_trail(m), m_soft_constraints(m), m_enable_lazy(true) { - m_enable_lazy = true; } virtual ~bcd2() {} @@ -272,11 +284,15 @@ namespace opt { expr_ref fml(m), r(m); lbool is_sat = l_undef; expr_ref_vector asms(m); - bool first = true; enable_sls(); solver::scoped_push _scope1(s()); init(); init_bcd(); + if (m_cancel) { + normalize_bounds(); + return l_undef; + } + process_sat(); while (m_lower < m_upper) { IF_VERBOSE(1, verbose_stream() << "(wmaxsat.bcd2 [" << m_lower << ":" << m_upper << "])\n";); assert_soft(); @@ -293,15 +309,9 @@ namespace opt { case l_undef: normalize_bounds(); return l_undef; - case l_true: { - svector assignment; - update_assignment(assignment); - first = false; - if (check_lazy_soft(assignment)) { - update_sigmas(); - } - break; - } + case l_true: + process_sat(); + break; case l_false: { ptr_vector unsat_core; uint_set subC, soft; @@ -322,7 +332,7 @@ namespace opt { r = mk_fresh(); relax(subC, soft, c_s.m_R, delta); c_s.m_lower = refine(c_s.m_R, lower + delta - rational(1)); - c_s.m_upper = rational(first?1:0); + c_s.m_upper = rational::one(); c_s.m_upper += sum_of_sigmas(c_s.m_R); c_s.m_mid = div(c_s.m_lower + c_s.m_upper, rational(2)); c_s.m_r = r; @@ -337,12 +347,7 @@ namespace opt { m_lower = compute_lower(); } normalize_bounds(); - if (first) { - return is_sat; - } - else { - return l_true; - } + return l_true; } @@ -533,13 +538,18 @@ namespace opt { expr_ref fml(m); vector ws; ptr_vector rs; + rational w(0); for (unsigned j = 0; j < core.m_R.size(); ++j) { unsigned idx = core.m_R[j]; ws.push_back(m_weights[idx]); - rs.push_back(m_soft_aux[idx].get()); // TBD: check + w += ws.back(); + rs.push_back(m_soft_aux[idx].get()); } + w.neg(); + w += core.m_mid; + ws.push_back(w); + rs.push_back(core.m_r); fml = pb.mk_le(ws.size(), ws.c_ptr(), rs.c_ptr(), core.m_mid); - fml = m.mk_or(core.m_r, fml); s().assert_expr(fml); } void display(std::ostream& out) { @@ -604,7 +614,7 @@ namespace opt { } lbool is_sat = l_true; while (l_true == is_sat) { - IF_VERBOSE(1, verbose_stream() << "(wmaxsat.pb solve with upper bound: " << m_upper << ")\n";); + TRACE("opt", s().display(tout<<"looping\n");); m_upper.reset(); for (unsigned i = 0; i < m_soft.size(); ++i) { VERIFY(m_model->eval(nsoft[i].get(), val)); @@ -614,11 +624,10 @@ namespace opt { m_upper += m_weights[i]; } } + IF_VERBOSE(1, verbose_stream() << "(wmaxsat.pb solve with upper bound: " << m_upper << ")\n";); TRACE("opt", tout << "new upper: " << m_upper << "\n";); - fml = u.mk_lt(nsoft.size(), m_weights.c_ptr(), nsoft.c_ptr(), m_upper); - - TRACE("opt", s().display(tout<<"looping\n");); + fml = m.mk_not(u.mk_ge(nsoft.size(), m_weights.c_ptr(), nsoft.c_ptr(), m_upper)); solver::scoped_push _scope2(s()); s().assert_expr(fml); is_sat = s().check_sat(0,0); @@ -840,7 +849,11 @@ namespace opt { tout << mk_pp(bs[i], m) << " " << ws[i] << "\n"; }); maxs->init_soft(ws, nbs); - lbool is_sat = (*maxs)(); + lbool is_sat = maxs->s().check_sat(0,0); + if (is_sat == l_true) { + maxs->set_model(); + is_sat = (*maxs)(); + } SASSERT(maxs->get_lower() > k); k = maxs->get_lower(); return is_sat; diff --git a/src/smt/theory_pb.cpp b/src/smt/theory_pb.cpp index 495c757fa..f19eb706c 100644 --- a/src/smt/theory_pb.cpp +++ b/src/smt/theory_pb.cpp @@ -529,11 +529,9 @@ namespace smt { numeral k = rep.k(); theory_var slack; bool_var abv2; - row r; TRACE("pb", display(tout << abv <<"\n", rep);); if (m_ineq_rep.find(rep, abv2)) { slack = abv2; - r = m_ineq_row_info.find(abv2).m_row; TRACE("pb", tout << "Old row: " << abv << " |-> " << slack << " "; tout << m_ineq_row_info.find(abv2).m_bound << " vs. " << k << "\n"; @@ -572,10 +570,10 @@ namespace smt { m_simplex.ensure_var(slack); vars.push_back(slack); coeffs.push_back(mpz(-1)); - r = m_simplex.add_row(slack, vars.size(), vars.c_ptr(), coeffs.c_ptr()); + m_simplex.add_row(slack, vars.size(), vars.c_ptr(), coeffs.c_ptr()); TRACE("pb", tout << "New row: " << abv << " " << k << "\n"; display(tout, rep);); } - m_ineq_row_info.insert(abv, row_info(slack, k, rep, r)); + m_ineq_row_info.insert(abv, row_info(slack, k, rep)); } TRACE("pb", display(tout, *c);); @@ -1310,7 +1308,7 @@ namespace smt { m_ineq_row_info.erase(v); bool_var v2 = m_ineq_rep.find(r_info.m_rep); if (v == v2) { - m_simplex.del_row(r_info.m_row); + m_simplex.del_row(r_info.m_slack); m_ineq_rep.erase(r_info.m_rep); } } @@ -1433,7 +1431,7 @@ namespace smt { justification* js = 0; - if (m_conflict_frequency == 0 || (0 == (c.m_num_propagations % m_conflict_frequency))) { + if (m_conflict_frequency == 0 || (m_conflict_frequency -1 == (c.m_num_propagations % m_conflict_frequency))) { resolve_conflict(c); } @@ -1854,6 +1852,8 @@ namespace smt { case l_undef: maxsum += c.coeff(i); break; + case l_false: + break; } } TRACE("pb", display(tout << "validate: ", c, true); diff --git a/src/smt/theory_pb.h b/src/smt/theory_pb.h index 595cf5e6e..b7922fbd0 100644 --- a/src/smt/theory_pb.h +++ b/src/smt/theory_pb.h @@ -185,9 +185,8 @@ namespace smt { unsigned m_slack; // slack variable in simplex tableau numeral m_bound; // bound arg_t m_rep; // representative - row m_row; - row_info(theory_var slack, numeral const& b, arg_t const& r, row const& ro): - m_slack(slack), m_bound(b), m_rep(r), m_row(ro) {} + row_info(theory_var slack, numeral const& b, arg_t const& r): + m_slack(slack), m_bound(b), m_rep(r) {} row_info(): m_slack(0) {} }; diff --git a/src/test/simplex.cpp b/src/test/simplex.cpp index 37d4501f7..a70d0d8cf 100644 --- a/src/test/simplex.cpp +++ b/src/test/simplex.cpp @@ -124,8 +124,6 @@ static void test4() { feas(S); } - - void tst_simplex() { Simplex S; diff --git a/src/test/theory_pb.cpp b/src/test/theory_pb.cpp index 8c9ef405b..ee1ec126a 100644 --- a/src/test/theory_pb.cpp +++ b/src/test/theory_pb.cpp @@ -3,6 +3,7 @@ #include "model_v2_pp.h" #include "reg_decl_plugins.h" #include "theory_pb.h" +#include "th_rewriter.h" unsigned populate_literals(unsigned k, smt::literal_vector& lits) { SASSERT(k < (1u << lits.size())); @@ -19,7 +20,105 @@ unsigned populate_literals(unsigned k, smt::literal_vector& lits) { return t; } +class pb_fuzzer { + ast_manager& m; + random_gen rand; + smt_params params; + smt::context ctx; + expr_ref_vector vars; + +public: + pb_fuzzer(ast_manager& m): m(m), rand(0), ctx(m, params), vars(m) { + params.m_model = true; + params.m_pb_enable_simplex = true; + unsigned N = 3; + for (unsigned i = 0; i < N; ++i) { + std::stringstream strm; + strm << "b" << i; + vars.push_back(m.mk_const(symbol(strm.str().c_str()), m.mk_bool_sort())); + std::cout << "(declare-const " << strm.str() << " Bool)\n"; + } + } + + void fuzz() { + enable_trace("pb"); + enable_trace("simplex"); + unsigned nr = 0; + for (unsigned i = 0; i < 100000; ++i) { + fuzz_round(nr, 2); + } + } + +private: + + void add_ineq() { + pb_util pb(m); + expr_ref fml(m), tmp(m); + th_rewriter rw(m); + vector coeffs(vars.size()); + expr_ref_vector args(vars); + while (true) { + rational k(rand(6)); + for (unsigned i = 0; i < coeffs.size(); ++i) { + int v = 3 - rand(5); + coeffs[i] = rational(v); + if (coeffs[i].is_neg()) { + args[i] = m.mk_not(args[i].get()); + coeffs[i].neg(); + k += coeffs[i]; + } + } + fml = pb.mk_ge(args.size(), coeffs.c_ptr(), args.c_ptr(), k); + rw(fml, tmp); + rw(tmp, tmp); + if (pb.is_ge(tmp)) { + fml = tmp; + break; + } + } + std::cout << "(assert " << fml << ")\n"; + ctx.assert_expr(fml); + } + + + + void fuzz_round(unsigned& num_rounds, unsigned lvl) { + unsigned num_rounds2 = 0; + lbool is_sat = l_true; + std::cout << "(push)\n"; + ctx.push(); + unsigned r = 0; + while (is_sat == l_true && r <= num_rounds + 1) { + add_ineq(); + std::cout << "(check-sat)\n"; + is_sat = ctx.check(); + if (lvl > 0 && is_sat == l_true) { + fuzz_round(num_rounds2, lvl-1); + } + ++r; + } + num_rounds = r; + std::cout << "; number of rounds: " << num_rounds << " level: " << lvl << "\n"; + ctx.pop(1); + std::cout << "(pop)\n"; + } + +}; + + + +static void fuzz_pb() +{ + ast_manager m; + reg_decl_plugins(m); + pb_fuzzer fuzzer(m); + fuzzer.fuzz(); +} + void tst_theory_pb() { + + fuzz_pb(); + ast_manager m; smt_params params; params.m_model = true; diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h index b9cf86433..f403f9e16 100644 --- a/src/util/sorting_network.h +++ b/src/util/sorting_network.h @@ -187,7 +187,7 @@ Notes: else { SASSERT(2*k <= n); m_t = full?GE_FULL:GE; - card(k, n, xs, out); + psort_nw::card(k, n, xs, out); return out[k-1]; } } @@ -322,8 +322,8 @@ Notes: void cmp(literal x1, literal x2, literal y1, literal y2) { switch(m_t) { - case LE: cmp_le(x1, x2, y1, y2); break; - case GE: cmp_ge(x1, x2, y1, y2); break; + case LE: case LE_FULL: cmp_le(x1, x2, y1, y2); break; + case GE: case GE_FULL: cmp_ge(x1, x2, y1, y2); break; case EQ: cmp_eq(x1, x2, y1, y2); break; } } @@ -334,7 +334,7 @@ Notes: void card(unsigned k, unsigned n, literal const* xs, literal_vector& out) { TRACE("pb", tout << "card k:" << k << " n: " << n << "\n";); if (n <= k) { - sorting(n, xs, out); + psort_nw::sorting(n, xs, out); } else if (use_dcard(k, n)) { dsorting(k, n, xs, out); @@ -485,7 +485,7 @@ Notes: out.push_back(xs[0]); break; case 2: - merge(1, xs, 1, xs+1, out); + psort_nw::merge(1, xs, 1, xs+1, out); break; default: if (use_dsorting(n)) {