diff --git a/src/sat/sat_clause.h b/src/sat/sat_clause.h index 27a0ed739..11824b247 100644 --- a/src/sat/sat_clause.h +++ b/src/sat/sat_clause.h @@ -73,6 +73,8 @@ namespace sat { bool check_approx() const; // for debugging literal * begin() { return m_lits; } literal * end() { return m_lits + m_size; } + literal const * begin() const { return m_lits; } + literal const * end() const { return m_lits + m_size; } bool contains(literal l) const; bool contains(bool_var v) const; bool satisfied_by(model const & m) const; diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 6f117de72..e45373b5c 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -3004,7 +3004,8 @@ namespace sat { // Iterators // // ----------------------- - void solver::collect_bin_clauses(svector & r, bool learned) const { + void solver::collect_bin_clauses(svector & r, bool learned, bool learned_only) const { + SASSERT(learned || !learned_only); unsigned sz = m_watches.size(); for (unsigned l_idx = 0; l_idx < sz; l_idx++) { literal l = to_literal(l_idx); @@ -3017,6 +3018,8 @@ namespace sat { continue; if (!learned && it->is_learned()) continue; + else if (learned && learned_only && !it->is_learned()) + continue; literal l2 = it->get_literal(); if (l.index() > l2.index()) continue; @@ -3327,7 +3330,6 @@ namespace sat { m_user_bin_clauses.reset(); m_binary_clause_graph.reset(); collect_bin_clauses(m_user_bin_clauses, true); - collect_bin_clauses(m_user_bin_clauses, false); hashtable, default_eq > seen_bc; for (unsigned i = 0; i < m_user_bin_clauses.size(); ++i) { literal l1 = m_user_bin_clauses[i].first; diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index ea8246466..348e1cd22 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -586,7 +586,8 @@ namespace sat { clause * const * end_clauses() const { return m_clauses.end(); } clause * const * begin_learned() const { return m_learned.begin(); } clause * const * end_learned() const { return m_learned.end(); } - void collect_bin_clauses(svector & r, bool learned) const; + clause_vector const& learned() const { return m_learned; } + void collect_bin_clauses(svector & r, bool learned, bool learned_only = false) const; // ----------------------- // diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index c636c3f2a..eb931d428 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -302,6 +302,15 @@ public: } return expr_ref(lit2expr[l.index()].get(), m); } + virtual void get_lemmas(expr_ref_vector & lemmas) { + IF_VERBOSE(1, verbose_stream() << "(sat-get-lemmas " << lemmas.size() << ")\n";); + if (!m_internalized) return; + sat2goal s2g; + goal g(m, false, false, false); + s2g.get_learned(m_solver, m_map, m_params, lemmas); + // TBD: handle externals properly. + } + virtual lbool get_consequences_core(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq) { init_preprocess(); @@ -426,9 +435,9 @@ public: g.get_formulas(m_internalized_fmls); // g.display(std::cout); m_internalized_converted = true; - // if (mc) mc->display(std::cout << "mc"); - // if (m_mc) m_mc->display(std::cout << "m_mc\n"); - // if (m_mc0) m_mc0->display(std::cout << "m_mc0\n"); + // if (mc) mc->display(std::cout << "mc"); + // if (m_mc) m_mc->display(std::cout << "m_mc\n"); + // if (m_mc0) m_mc0->display(std::cout << "m_mc0\n"); } void init_preprocess() { diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 6604bab89..cb4a19e89 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -36,6 +36,7 @@ Notes: #include"model_v2_pp.h" #include"tactic.h" #include"ast_pp.h" +#include"ast_util.h" #include"pb_decl_plugin.h" #include"card_extension.h" #include @@ -1143,7 +1144,6 @@ struct sat2goal::imp { assert_clauses(s, s.begin_clauses(), s.end_clauses(), r, true); assert_clauses(s, s.begin_learned(), s.end_learned(), r, false); - // TBD: collect assertions from plugin sat::card_extension* ext = get_card_extension(s); if (ext) { for (unsigned i = 0; i < ext->num_pb(); ++i) { @@ -1158,6 +1158,73 @@ struct sat2goal::imp { } } + void add_clause(sat::literal_vector const& lits, expr_ref_vector& lemmas) { + expr_ref_vector lemma(m); + for (sat::literal l : lits) { + expr* e = m_lit2expr.get(l.index(), 0); + if (!e) return; + lemma.push_back(e); + } + lemmas.push_back(mk_or(lemma)); + } + + void add_clause(sat::clause const& c, expr_ref_vector& lemmas) { + expr_ref_vector lemma(m); + for (sat::literal l : c) { + expr* e = m_lit2expr.get(l.index(), 0); + if (!e) return; + lemma.push_back(e); + } + lemmas.push_back(mk_or(lemma)); + } + + void get_learned(sat::solver const& s, atom2bool_var const& map, expr_ref_vector& lemmas) { + if (s.inconsistent()) { + lemmas.push_back(m.mk_false()); + return; + } + + unsigned num_vars = s.num_vars(); + m_lit2expr.resize(num_vars * 2); + map.mk_inv(m_lit2expr); + + sat::literal_vector lits; + // collect units + for (sat::bool_var v = 0; v < num_vars; v++) { + checkpoint(); + lits.reset(); + switch (s.value(v)) { + case l_true: + lits.push_back(sat::literal(v, false)); + add_clause(lits, lemmas); + break; + case l_false: + lits.push_back(sat::literal(v, false)); + add_clause(lits, lemmas); + break; + case l_undef: + break; + } + } + // collect learned binary clauses + svector bin_clauses; + s.collect_bin_clauses(bin_clauses, true, true); + svector::iterator it = bin_clauses.begin(); + svector::iterator end = bin_clauses.end(); + for (; it != end; ++it) { + checkpoint(); + lits.reset(); + lits.push_back(it->first); + lits.push_back(it->second); + add_clause(lits, lemmas); + } + // collect clauses + for (sat::clause const* c : s.learned()) { + add_clause(*c, lemmas); + } + } + + }; sat2goal::sat2goal():m_imp(0) { @@ -1186,3 +1253,9 @@ void sat2goal::operator()(sat::solver const & t, atom2bool_var const & m, params proc(t, m, g, mc); } +void sat2goal::get_learned(sat::solver const & t, atom2bool_var const & m, params_ref const& p, expr_ref_vector& lemmas) { + imp proc(lemmas.get_manager(), p); + scoped_set_imp set(this, &proc); + proc.get_learned(t, m, lemmas); +} + diff --git a/src/sat/tactic/goal2sat.h b/src/sat/tactic/goal2sat.h index cd63cd497..5bfb28f60 100644 --- a/src/sat/tactic/goal2sat.h +++ b/src/sat/tactic/goal2sat.h @@ -85,6 +85,13 @@ public: or memory consumption limit is reached (set with param :max-memory). */ void operator()(sat::solver const & t, atom2bool_var const & m, params_ref const & p, goal & s, model_converter_ref & mc); + + + /** + \brief extract learned clauses only that are in the domain of m. + + */ + void get_learned(sat::solver const& s, atom2bool_var const& m, params_ref const& p, expr_ref_vector& learned); }; diff --git a/src/tactic/portfolio/bounded_int2bv_solver.cpp b/src/tactic/portfolio/bounded_int2bv_solver.cpp index cee8688e5..3645ba97a 100644 --- a/src/tactic/portfolio/bounded_int2bv_solver.cpp +++ b/src/tactic/portfolio/bounded_int2bv_solver.cpp @@ -150,6 +150,7 @@ public: virtual void get_labels(svector & r) { m_solver->get_labels(r); } virtual ast_manager& get_manager() const { return m; } virtual expr_ref lookahead(expr_ref_vector& candidates) { flush_assertions(); return m_solver->lookahead(candidates); } + virtual void get_lemmas(expr_ref_vector & lemmas) { flush_assertions(); m_solver->get_lemmas(lemmas); } virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return m_solver->find_mutexes(vars, mutexes); } virtual lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) { flush_assertions(); diff --git a/src/tactic/portfolio/enum2bv_solver.cpp b/src/tactic/portfolio/enum2bv_solver.cpp index ef7ee6cd8..a40f2988a 100644 --- a/src/tactic/portfolio/enum2bv_solver.cpp +++ b/src/tactic/portfolio/enum2bv_solver.cpp @@ -98,6 +98,7 @@ public: virtual ast_manager& get_manager() const { return m; } virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return m_solver->find_mutexes(vars, mutexes); } virtual expr_ref lookahead(expr_ref_vector& candidates) { return m_solver->lookahead(candidates); } + virtual void get_lemmas(expr_ref_vector & lemmas) { m_solver->get_lemmas(lemmas); } virtual lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) { datatype_util dt(m); diff --git a/src/tactic/portfolio/pb2bv_solver.cpp b/src/tactic/portfolio/pb2bv_solver.cpp index 673db03c0..8bf6b7e39 100644 --- a/src/tactic/portfolio/pb2bv_solver.cpp +++ b/src/tactic/portfolio/pb2bv_solver.cpp @@ -94,6 +94,7 @@ public: virtual void get_labels(svector & r) { m_solver->get_labels(r); } virtual ast_manager& get_manager() const { return m; } virtual expr_ref lookahead(expr_ref_vector& candidates) { flush_assertions(); return m_solver->lookahead(candidates); } + virtual void get_lemmas(expr_ref_vector & lemmas) { flush_assertions(); m_solver->get_lemmas(lemmas); } virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return m_solver->find_mutexes(vars, mutexes); } virtual lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) { flush_assertions();