diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 36ebdbacd..25ec308d8 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -78,6 +78,7 @@ namespace intblast { literals.push_back(a); } + m_core.reset(); m_solver = mk_smt2_solver(m, s.params(), symbol::null); expr_ref_vector es(m); @@ -86,11 +87,23 @@ namespace intblast { translate(es); - for (auto e : es) - m_solver->assert_expr(e); - + for (auto const& [src, vi] : m_vars) { + auto const& [v, b] = vi; + m_solver->assert_expr(a.mk_le(a.mk_int(0), v)); + m_solver->assert_expr(a.mk_lt(v, a.mk_int(b))); + } - lbool r = m_solver->check_sat(0, nullptr); + lbool r = m_solver->check_sat(es); + + if (r == l_false) { + expr_ref_vector core(m); + m_solver->get_unsat_core(core); + obj_map e2index; + for (unsigned i = 0; i < es.size(); ++i) + e2index.insert(es.get(i), i); + for (auto e : core) + m_core.push_back(literals[e2index[e]]); + } return r; }; @@ -290,11 +303,6 @@ namespace intblast { } for (unsigned i = 0; i < es.size(); ++i) es[i] = translated[es.get(i)]; - for (auto const& [src, vi] : m_vars) { - auto const& [v, b] = vi; - es.push_back(a.mk_le(a.mk_int(0), v)); - es.push_back(a.mk_lt(v, a.mk_int(b))); - } } rational solver::get_value(expr* e) const { @@ -313,4 +321,8 @@ namespace intblast { return val; } + sat::literal_vector const& solver::unsat_core() { + return m_core; + } + } diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index f2ec486d5..1df46c300 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -46,6 +46,7 @@ namespace intblast { scoped_ptr<::solver> m_solver; obj_map m_vars; expr_ref_vector m_trail; + sat::literal_vector m_core; @@ -58,6 +59,8 @@ namespace intblast { lbool check(); + sat::literal_vector const& unsat_core(); + rational get_value(expr* e) const; }; diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 460501cd0..8baf05de5 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -31,6 +31,8 @@ The result of polysat::core::check is one of: #include "sat/smt/polysat/polysat_umul_ovfl.h" + + namespace polysat { solver::solver(euf::solver& ctx, theory_id id): @@ -38,6 +40,7 @@ namespace polysat { bv(ctx.get_manager()), m_autil(ctx.get_manager()), m_core(*this), + m_intblast(ctx), m_lemma(ctx.get_manager()) { ctx.get_egraph().add_plugin(alloc(euf::bv_plugin, ctx.get_egraph())); @@ -56,7 +59,31 @@ namespace polysat { } sat::check_result solver::check() { - return m_core.check(); + switch (m_core.check()) { + case sat::check_result::CR_DONE: + return sat::check_result::CR_DONE; + case sat::check_result::CR_CONTINUE: + return sat::check_result::CR_CONTINUE; + case sat::check_result::CR_GIVEUP: { + if (!m.inc()) + return sat::check_result::CR_GIVEUP; + switch (m_intblast.check()) { + case l_true: + trail().push(value_trail(m_use_intblast_model)); + m_use_intblast_model = true; + return sat::check_result::CR_DONE; + case l_false: { + auto core = m_intblast.unsat_core(); + for (auto& lit : core) + lit.neg(); + s().add_clause(core.size(), core.data(), sat::status::th(true, get_id(), nullptr)); + return sat::check_result::CR_CONTINUE; + } + case l_undef: + return sat::check_result::CR_GIVEUP; + } + } + } } void solver::asserted(literal l) { @@ -136,6 +163,7 @@ namespace polysat { unsigned num_scopes = s().scope_lvl() - m_lemma_level; + NOT_IMPLEMENTED_YET(); // s().pop_reinit(num_scopes); sat::literal_vector lits; diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index b5e69c36a..e1a9221e9 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -19,6 +19,7 @@ Author: #include "sat/smt/sat_th.h" #include "math/dd/dd_pdd.h" #include "sat/smt/polysat/polysat_core.h" +#include "sat/smt/intblast_solver.h" namespace euf { class solver; @@ -57,7 +58,8 @@ namespace polysat { arith_util m_autil; stats m_stats; core m_core; - polysat_proof m_proof; + intblast::solver m_intblast; + bool m_use_intblast_model = false; vector m_var2pdd; // theory_var 2 pdd bool_vector m_var2pdd_valid; // valid flag