diff --git a/src/opt/inc_sat_solver.cpp b/src/opt/inc_sat_solver.cpp index 5f121726a..85fef687a 100644 --- a/src/opt/inc_sat_solver.cpp +++ b/src/opt/inc_sat_solver.cpp @@ -57,7 +57,7 @@ public: virtual lbool check_sat(unsigned num_assumptions, expr * const * assumptions) { SASSERT(num_assumptions == 0); - m_solver.pop(m_solver.scope_lvl()); + m_solver.pop_to_base_level(); goal_ref_buffer result; proof_converter_ref pc; model_converter_ref mc; @@ -128,13 +128,13 @@ public: m_preprocess->set_cancel(f); } virtual void push() { - IF_VERBOSE(0, verbose_stream() << "push ignored\n";); + m_solver.user_push(); } virtual void pop(unsigned n) { - IF_VERBOSE(0, verbose_stream() << "pop ignored\n";); + m_solver.user_pop(n); } virtual unsigned get_scope_level() const { - return 0; + return m_solver.scope_lvl(); } virtual void assert_expr(expr * t, expr * a) { if (a) { diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index c357f7a29..bf6d0953f 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -140,7 +140,16 @@ namespace sat { for (unsigned i = 0; i < num_lits; i++) SASSERT(m_eliminated[lits[i].var()] == false); }); - mk_clause_core(num_lits, lits, false); + + if (m_user_scope_literals.empty()) { + mk_clause_core(num_lits, lits, false); + } + else { + m_aux_literals.reset(); + m_aux_literals.append(num_lits, lits); + m_aux_literals.append(m_user_scope_literals); + mk_clause_core(m_aux_literals.size(), m_aux_literals.c_ptr(), false); + } } void solver::mk_clause(literal l1, literal l2) { @@ -686,6 +695,7 @@ namespace sat { // // ----------------------- lbool solver::check(unsigned num_lits, literal const* lits) { + pop_to_base_level(); IF_VERBOSE(2, verbose_stream() << "(sat.sat-solver using the efficient SAT solver)\n";); SASSERT(scope_lvl() == 0); #ifdef CLONE_BEFORE_SOLVING @@ -716,7 +726,7 @@ namespace sat { m_restart_threshold = m_config.m_restart_initial; } - // iff3_finder(*this)(); + // iff3_finder(*this)(); simplify_problem(); if (inconsistent()) return l_false; @@ -858,18 +868,27 @@ namespace sat { } void solver::init_assumptions(unsigned num_lits, literal const* lits) { - if (num_lits == 0) { + if (num_lits == 0 && m_user_scope_literals.empty()) { return; } - push(); m_assumptions.reset(); - m_assumption_set.reset(); + m_assumption_set.reset(); + push(); + + TRACE("sat", display(tout);); +#define _INSERT_LIT(_l_) \ + SASSERT(is_external((_l_).var())); \ + m_assumption_set.insert(_l_); \ + m_assumptions.push_back(_l_); \ + mk_clause_core(1, &(_l_), false); \ + for (unsigned i = 0; i < num_lits; ++i) { - literal l = lits[i]; - SASSERT(is_external(l.var())); - m_assumption_set.insert(l); - m_assumptions.push_back(l); - mk_clause(1, &l); + literal lit = lits[i]; + _INSERT_LIT(lit); + } + for (unsigned i = 0; i < m_user_scope_literals.size(); ++i) { + literal nlit = ~m_user_scope_literals[i]; + _INSERT_LIT(nlit); } TRACE("sat", display(tout);); } @@ -879,7 +898,7 @@ namespace sat { push(); for (unsigned i = 0; i < m_assumptions.size(); ++i) { literal l = m_assumptions[i]; - mk_clause(1, &l); + mk_clause_core(1, &l, false); } } } @@ -911,6 +930,12 @@ namespace sat { \brief Apply all simplifications. */ void solver::simplify_problem() { + + if (tracking_assumptions()) { + // NB. simplification is disabled when tracking assumptions. + return; + } + SASSERT(scope_lvl() == 0); m_cleaner(); @@ -2110,6 +2135,71 @@ namespace sat { m_clauses_to_reinit.shrink(j); } + // + // All new clauses that are added to the solver + // are relative to the user-scope literals. + // + + void solver::user_push() { + literal lit; + if (m_user_scope_literal_pool.empty()) { + bool_var new_v = mk_var(true, false); + lit = literal(new_v, false); + } + else { + lit = m_user_scope_literal_pool.back(); + m_user_scope_literal_pool.pop_back(); + } + m_user_scope_literals.push_back(lit); + } + + void solver::gc_lit(clause_vector &clauses, literal lit) { + unsigned j = 0; + for (unsigned i = 0; i < clauses.size(); ++i) { + clause & c = *(clauses[i]); + if (c.contains(lit)) { + dettach_clause(c); + del_clause(c); + } + else { + clauses[j] = &c; + ++j; + } + } + clauses.shrink(j); + } + + void solver::gc_bin(bool learned, literal nlit) { + m_user_bin_clauses.reset(); + collect_bin_clauses(m_user_bin_clauses, learned); + for (unsigned i = 0; i < m_user_bin_clauses.size(); ++i) { + literal l1 = m_user_bin_clauses[i].first; + literal l2 = m_user_bin_clauses[i].second; + if (nlit == l1 || nlit == l2) { + dettach_bin_clause(l1, l2, learned); + } + } + } + + void solver::user_pop(unsigned num_scopes) { + pop_to_base_level(); + while (num_scopes > 0) { + literal lit = m_user_scope_literals.back(); + m_user_scope_literal_pool.push_back(lit); + m_user_scope_literals.pop_back(); + gc_lit(m_learned, lit); + gc_lit(m_clauses, lit); + gc_bin(true, lit); + gc_bin(false, lit); + TRACE("sat", tout << "gc: " << lit << "\n"; display(tout);); + --num_scopes; + } + } + + void solver::pop_to_base_level() { + pop(scope_lvl()); + } + // ----------------------- // // Misc diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 48067cc48..8c39599bb 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -224,6 +224,7 @@ namespace sat { if (m_cancel) throw solver_exception(Z3_CANCELED_MSG); if (memory::get_allocation_size() > m_config.m_max_memory) throw solver_exception(Z3_MAX_MEMORY_MSG); } + typedef std::pair bin_clause; protected: watch_list & get_wlist(literal l) { return m_watches[l.index()]; } watch_list const & get_wlist(literal l) const { return m_watches[l.index()]; } @@ -350,13 +351,21 @@ namespace sat { // // ----------------------- void push(); - public: void pop(unsigned num_scopes); - protected: void unassign_vars(unsigned old_sz); void reinit_clauses(unsigned old_sz); + literal_vector m_user_scope_literals; + literal_vector m_user_scope_literal_pool; + literal_vector m_aux_literals; + svector m_user_bin_clauses; + void gc_lit(clause_vector& clauses, literal lit); + void gc_bin(bool learned, literal nlit); + public: + void user_push(); + void user_pop(unsigned num_scopes); + void pop_to_base_level(); // ----------------------- // // Simplification @@ -400,7 +409,6 @@ namespace sat { clause * const * end_clauses() const { return m_clauses.end(); } clause * const * begin_learned() const { return m_learned.begin(); } clause * const * end_learned() const { return m_learned.end(); } - typedef std::pair bin_clause; void collect_bin_clauses(svector & r, bool learned) const; // ----------------------- diff --git a/src/sat/tactic/sat_tactic.cpp b/src/sat/tactic/sat_tactic.cpp index 8b575c177..e4d5e6df5 100644 --- a/src/sat/tactic/sat_tactic.cpp +++ b/src/sat/tactic/sat_tactic.cpp @@ -116,7 +116,7 @@ class sat_tactic : public tactic { #if 0 IF_VERBOSE(TACTIC_VERBOSITY_LVL, verbose_stream() << "\"formula constains interpreted atoms, recovering formula from sat solver...\"\n";); #endif - m_solver.pop(m_solver.scope_lvl()); + m_solver.pop_to_base_level(); m_sat2goal(m_solver, map, m_params, *(g.get()), mc); } g->inc_depth(); diff --git a/src/test/main.cpp b/src/test/main.cpp index 3b0179944..d97a94a16 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -219,6 +219,7 @@ int main(int argc, char ** argv) { TST(sorting_network); TST(theory_pb); TST(simplex); + TST(sat_user_scope); //TST_ARGV(hs); } diff --git a/src/test/sat_user_scope.cpp b/src/test/sat_user_scope.cpp new file mode 100644 index 000000000..e39059563 --- /dev/null +++ b/src/test/sat_user_scope.cpp @@ -0,0 +1,102 @@ +#include "sat_solver.h" +#include "util.h" + +typedef sat::literal_vector clause_t; +typedef vector clauses_t; +typedef vector trail_t; + +// [ [c1, c2, ..], [ ...] ] + +static unsigned s_num_vars = 6; +static unsigned s_num_clauses_per_frame = 8; +static unsigned s_num_frames = 7; + +static void add_literal(random_gen& r, clause_t& c) { + c.push_back(sat::literal(r(s_num_vars) + 1, r(2) == 0)); +} + +static clause_t& last_clause(trail_t& t) { + return t.back().back(); +} + +static void add_clause(sat::solver& s, random_gen& r, trail_t& t) { + t.back().push_back(sat::literal_vector()); + clause_t& cls = last_clause(t); + for (unsigned i = 0; i < 3; ++i) { + add_literal(r, cls); + } + s.mk_clause(cls.size(), cls.c_ptr()); +} + +static void display_state(std::ostream& out, sat::solver& s, trail_t& t) { + s.display(out); +} + +static void pop_user_scope(sat::solver& s, trail_t& t) { + std::cout << "pop\n"; + s.user_pop(1); + t.pop_back(); +} + +static void push_user_scope(sat::solver& s, trail_t& t) { + std::cout << "push\n"; + s.user_push(); + t.push_back(clauses_t()); +} + +static void init_vars(sat::solver& s) { + for (unsigned i = 0; i <= s_num_vars; ++i) { + s.mk_var(); + } +} + +static void check_coherence(sat::solver& s1, trail_t& t) { + params_ref p; + sat::solver s2(p, 0); + init_vars(s2); + sat::literal_vector cls; + for (unsigned i = 0; i < t.size(); ++i) { + clauses_t& clss = t[i]; + for (unsigned j = 0; j < clss.size(); ++j) { + cls.reset(); + cls.append(clss[j]); + s2.mk_clause(cls.size(), cls.c_ptr()); + } + } + lbool is_sat1 = s1.check(); + lbool is_sat2 = s2.check(); + if (is_sat1 != is_sat2) { + s1.display(std::cout); + s2.display(std::cout); + } + std::cout << is_sat1 << "\n"; + SASSERT(is_sat1 == is_sat2); +} + +void tst_sat_user_scope() { + random_gen r(0); + trail_t trail; + params_ref p; + sat::solver s(p, 0); // incremental solver + init_vars(s); + while (true) { + for (unsigned i = 0; i < s_num_frames; ++i) { + // push 3 frames, pop 2 + for (unsigned k = 0; k < 3; ++k) { + push_user_scope(s, trail); + for (unsigned j = 0; j < s_num_clauses_per_frame; ++j) { + add_clause(s, r, trail); + } + check_coherence(s, trail); + } + for (unsigned k = 0; k < 2; ++k) { + pop_user_scope(s, trail); + check_coherence(s, trail); + } + } + for (unsigned i = 0; i < s_num_frames; ++i) { + pop_user_scope(s, trail); + check_coherence(s, trail); + } + } +}