From ec36a9c495746551c6187cfe2662b5475fc9e6dc Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 22 Nov 2018 12:40:23 -0800 Subject: [PATCH] fix user push/pop with ba constraints Signed-off-by: Nikolaj Bjorner --- src/sat/ba_solver.cpp | 17 ++++++++++++ src/sat/ba_solver.h | 2 ++ src/sat/sat_extension.h | 1 + src/sat/sat_solver.cpp | 22 ++++++++------- src/sat/sat_solver.h | 1 + src/sat/tactic/goal2sat.cpp | 54 ++++++++++++++++++------------------- 6 files changed, 59 insertions(+), 38 deletions(-) diff --git a/src/sat/ba_solver.cpp b/src/sat/ba_solver.cpp index dd690a6a9..762939b74 100644 --- a/src/sat/ba_solver.cpp +++ b/src/sat/ba_solver.cpp @@ -64,6 +64,13 @@ namespace sat { return static_cast(*this); } + unsigned ba_solver::constraint::fold_max_var(unsigned w) const { + if (lit() != null_literal) w = std::max(w, lit().var()); + for (unsigned i = 0; i < size(); ++i) w = std::max(w, get_lit(i).var()); + return w; + } + + std::ostream& operator<<(std::ostream& out, ba_solver::constraint const& cnstr) { if (cnstr.lit() != null_literal) out << cnstr.lit() << " == "; switch (cnstr.tag()) { @@ -2660,6 +2667,16 @@ namespace sat { } c.set_psm(r); } + + unsigned ba_solver::max_var(unsigned w) const { + for (constraint* cp : m_constraints) { + w = cp->fold_max_var(w); + } + for (constraint* cp : m_learned) { + w = cp->fold_max_var(w); + } + return w; + } void ba_solver::gc() { if (m_learned.size() >= 2 * m_constraints.size() && diff --git a/src/sat/ba_solver.h b/src/sat/ba_solver.h index 558b22bf3..141ca0887 100644 --- a/src/sat/ba_solver.h +++ b/src/sat/ba_solver.h @@ -101,6 +101,7 @@ namespace sat { bool is_clear() const { return m_watch == null_literal && m_lit != null_literal; } bool is_pure() const { return m_pure; } void set_pure() { m_pure = true; } + unsigned fold_max_var(unsigned w) const; size_t obj_size() const { return m_obj_size; } card& to_card(); @@ -552,6 +553,7 @@ namespace sat { void find_mutexes(literal_vector& lits, vector & mutexes) override; void pop_reinit() override; void gc() override; + unsigned max_var(unsigned w) const override; double get_reward(literal l, ext_justification_idx idx, literal_occs_fun& occs) const override; bool is_extended_binary(ext_justification_idx idx, literal_vector & r) override; void init_use_list(ext_use_list& ul) override; diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index e687ab2b0..41aebb97e 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -80,6 +80,7 @@ namespace sat { virtual void init_use_list(ext_use_list& ul) = 0; virtual bool is_blocked(literal l, ext_constraint_idx) = 0; virtual bool check_model(model const& m) const = 0; + virtual unsigned max_var(unsigned w) const = 0; }; }; diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 1881d0375..ee3bc7880 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -3274,17 +3274,19 @@ namespace sat { } void solver::gc_var(bool_var v) { - if (v > 0) { - bool_var w = max_var(m_learned, v-1); - w = max_var(m_clauses, w); - w = max_var(true, w); - w = max_var(false, w); - v = m_mc.max_var(w); - for (literal lit : m_trail) { - if (lit.var() > w) w = lit.var(); - } - v = std::max(v, w + 1); + bool_var w = max_var(m_learned, v); + w = max_var(m_clauses, w); + w = max_var(true, w); + w = max_var(false, w); + v = m_mc.max_var(w); + for (literal lit : m_trail) { + w = std::max(w, lit.var()); } + if (m_ext) { + w = m_ext->max_var(w); + } + v = w + 1; + // v is an index of a variable that does not occur in solver state. if (v < m_level.size()) { for (bool_var i = v; i < m_level.size(); ++i) { diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 27c2e123c..8402fc898 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -543,6 +543,7 @@ namespace sat { void user_push(); void user_pop(unsigned num_scopes); void pop_to_base_level(); + unsigned num_user_scopes() const { return m_user_scope_literals.size(); } reslimit& rlimit() { return m_rlimit; } // ----------------------- // diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 645ef8a3f..d484ecda4 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -444,13 +444,24 @@ struct goal2sat::imp { convert_to_wlits(t, lits, wlits); } + void push_result(bool root, sat::literal lit, unsigned num_args) { + if (root) { + m_result_stack.reset(); + mk_clause(lit); + } + else { + m_result_stack.shrink(m_result_stack.size() - num_args); + m_result_stack.push_back(lit); + } + } + void convert_pb_ge(app* t, bool root, bool sign) { rational k = pb.get_k(t); check_unsigned(k); svector wlits; convert_pb_args(t, wlits); unsigned sz = m_result_stack.size(); - if (root) { + if (root && m_solver.num_user_scopes() == 0) { m_result_stack.reset(); unsigned k1 = k.get_unsigned(); if (sign) { @@ -467,8 +478,7 @@ struct goal2sat::imp { sat::literal lit(v, sign); m_ext->add_pb_ge(v, wlits, k.get_unsigned()); TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); - m_result_stack.shrink(sz - t->get_num_args()); - m_result_stack.push_back(lit); + push_result(root, lit, t->get_num_args()); } } @@ -483,7 +493,7 @@ struct goal2sat::imp { } check_unsigned(k); unsigned sz = m_result_stack.size(); - if (root) { + if (root && m_solver.num_user_scopes() == 0) { m_result_stack.reset(); unsigned k1 = k.get_unsigned(); if (sign) { @@ -500,19 +510,19 @@ struct goal2sat::imp { sat::literal lit(v, sign); m_ext->add_pb_ge(v, wlits, k.get_unsigned()); TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); - m_result_stack.shrink(sz - t->get_num_args()); - m_result_stack.push_back(lit); + push_result(root, lit, t->get_num_args()); } } void convert_pb_eq(app* t, bool root, bool sign) { - IF_VERBOSE(0, verbose_stream() << "pbeq: " << mk_pp(t, m) << "\n";); + //IF_VERBOSE(0, verbose_stream() << "pbeq: " << mk_pp(t, m) << "\n";); rational k = pb.get_k(t); SASSERT(k.is_unsigned()); svector wlits; convert_pb_args(t, wlits); - sat::bool_var v1 = (root && !sign) ? sat::null_bool_var : m_solver.mk_var(true); - sat::bool_var v2 = (root && !sign) ? sat::null_bool_var : m_solver.mk_var(true); + bool base_assert = (root && !sign && m_solver.num_user_scopes() == 0); + sat::bool_var v1 = base_assert ? sat::null_bool_var : m_solver.mk_var(true); + sat::bool_var v2 = base_assert ? sat::null_bool_var : m_solver.mk_var(true); m_ext->add_pb_ge(v1, wlits, k.get_unsigned()); k.neg(); for (wliteral& wl : wlits) { @@ -521,7 +531,7 @@ struct goal2sat::imp { } check_unsigned(k); m_ext->add_pb_ge(v2, wlits, k.get_unsigned()); - if (root && !sign) { + if (base_assert) { m_result_stack.reset(); } else { @@ -532,13 +542,8 @@ struct goal2sat::imp { mk_clause(~l, l2); mk_clause(~l1, ~l2, l); m_cache.insert(t, l); - m_result_stack.shrink(m_result_stack.size() - t->get_num_args()); if (sign) l.neg(); - m_result_stack.push_back(l); - if (root) { - m_result_stack.reset(); - mk_clause(l); - } + push_result(root, l, t->get_num_args()); } } @@ -547,7 +552,7 @@ struct goal2sat::imp { sat::literal_vector lits; unsigned sz = m_result_stack.size(); convert_pb_args(t->get_num_args(), lits); - if (root) { + if (root && m_solver.num_user_scopes() == 0) { m_result_stack.reset(); m_ext->add_at_least(sat::null_bool_var, lits, k.get_unsigned()); } @@ -558,8 +563,7 @@ struct goal2sat::imp { m_cache.insert(t, lit); if (sign) lit.neg(); TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); - m_result_stack.shrink(sz - t->get_num_args()); - m_result_stack.push_back(lit); + push_result(root, lit, t->get_num_args()); } } @@ -571,7 +575,7 @@ struct goal2sat::imp { for (sat::literal& l : lits) { l.neg(); } - if (root) { + if (root && m_solver.num_user_scopes() == 0) { m_result_stack.reset(); m_ext->add_at_least(sat::null_bool_var, lits, lits.size() - k.get_unsigned()); } @@ -580,9 +584,8 @@ struct goal2sat::imp { sat::literal lit(v, false); m_ext->add_at_least(v, lits, lits.size() - k.get_unsigned()); m_cache.insert(t, lit); - m_result_stack.shrink(sz - t->get_num_args()); if (sign) lit.neg(); - m_result_stack.push_back(lit); + push_result(root, lit, t->get_num_args()); } } @@ -610,13 +613,8 @@ struct goal2sat::imp { mk_clause(~l, l2); mk_clause(~l1, ~l2, l); m_cache.insert(t, l); - m_result_stack.shrink(m_result_stack.size() - t->get_num_args()); if (sign) l.neg(); - m_result_stack.push_back(l); - if (root) { - mk_clause(l); - m_result_stack.reset(); - } + push_result(root, l, t->get_num_args()); } }