diff --git a/src/sat/smt/polysat_core.cpp b/src/sat/smt/polysat_core.cpp index 931a92992..b97a223f7 100644 --- a/src/sat/smt/polysat_core.cpp +++ b/src/sat/smt/polysat_core.cpp @@ -66,16 +66,16 @@ namespace polysat { class core::mk_add_watch : public trail { core& c; - unsigned m_idx; public: - mk_add_watch(core& c, unsigned idx) : c(c), m_idx(idx) {} + mk_add_watch(core& c) : c(c) {} void undo() override { - auto& sc = c.m_prop_queue[m_idx].first; + auto& [sc, lit] = c.m_constraint_trail.back(); auto& vars = sc.vars(); if (vars.size() > 0) c.m_watch[vars[0]].pop_back(); if (vars.size() > 1) c.m_watch[vars[1]].pop_back(); + c.m_constraint_trail.pop_back(); } }; @@ -123,6 +123,22 @@ namespace polysat { m_var_queue.del_var_eh(v); } + unsigned core::register_constraint(signed_constraint& sc, solver_assertion as) { + unsigned idx = m_constraint_trail.size(); + m_constraint_trail.push_back({ sc, as }); + auto& vars = sc.vars(); + unsigned i = 0, j = 0, sz = vars.size(); + for (; i < sz && j < 2; ++i) + if (!is_assigned(vars[i])) + std::swap(vars[i], vars[j++]); + if (vars.size() > 0) + add_watch(idx, vars[0]); + if (vars.size() > 1) + add_watch(idx, vars[1]); + s.ctx.push(mk_add_watch(*this)); + return idx; + } + // case split on unassigned variables until all are assigned values. // create equality literal for unassigned variable. // return new_eq if there is a new literal. @@ -141,6 +157,7 @@ namespace polysat { s.propagate(m_constraints.eq(var2pdd(m_var), m_value), m_viable.explain()); return sat::check_result::CR_CONTINUE; case find_t::multiple: + s.add_eq_literal(m_var, m_value); return sat::check_result::CR_CONTINUE; case find_t::resource_out: return sat::check_result::CR_GIVEUP; @@ -155,33 +172,17 @@ namespace polysat { return false; s.ctx.push(value_trail(m_qhead)); for (; m_qhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_qhead) - propagate_constraint(m_qhead, m_prop_queue[m_qhead]); + propagate_constraint(m_prop_queue[m_qhead]); s.ctx.push(value_trail(m_vqhead)); for (; m_vqhead < m_prop_queue.size() && !s.ctx.inconsistent(); ++m_vqhead) - propagate_value(m_vqhead, m_prop_queue[m_vqhead]); + propagate_value(m_prop_queue[m_vqhead]); return true; } - void core::propagate_constraint(unsigned idx, dependent_constraint& dc) { - auto [sc, dep] = dc; - if (sc.is_eq(m_var, m_value)) { + void core::propagate_constraint(prop_item& dc) { + auto [idx, sc, dep] = dc; + if (sc.is_eq(m_var, m_value)) propagate_assignment(m_var, m_value, dep); - return; - } - add_watch(idx, sc); - } - - void core::add_watch(unsigned idx, signed_constraint& sc) { - auto& vars = sc.vars(); - unsigned i = 0, j = 0, sz = vars.size(); - for (; i < sz && j < 2; ++i) - if (!is_assigned(vars[i])) - std::swap(vars[i], vars[j++]); - if (vars.size() > 0) - add_watch(idx, vars[0]); - if (vars.size() > 1) - add_watch(idx, vars[1]); - s.ctx.push(mk_add_watch(*this, idx)); } void core::add_watch(unsigned idx, unsigned var) { @@ -205,7 +206,7 @@ namespace polysat { // for entries where there is only one free variable left add to viable set unsigned j = 0; for (auto idx : m_watch[v]) { - auto [sc, dep] = m_prop_queue[idx]; + auto [sc, as] = m_constraint_trail[idx]; auto& vars = sc.vars(); if (vars[0] != v) std::swap(vars[0], vars[1]); @@ -219,39 +220,51 @@ namespace polysat { break; } } - if (!swapped) { - m_watch[v][j++] = idx; - } - // constraint is unitary, add to viable set - if (vars.size() >= 2 && is_assigned(vars[0]) && !is_assigned(vars[1])) { - m_viable.add_unitary(vars[1], idx); - } + SASSERT(!swapped || vars.size() <= 1 || (!is_assigned(vars[0]) && !is_assigned(vars[1]))); + if (swapped) + continue; + m_watch[v][j++] = idx; + if (vars.size() <= 1) + continue; + auto v1 = vars[1]; + if (is_assigned(v1)) + continue; + SASSERT(is_assigned(vars[0]) && vars.size() >= 2); + // detect unitary, add to viable, detect conflict? + m_viable.add_unitary(v1, idx); } m_watch[v].shrink(j); } - void core::propagate_value(unsigned idx, dependent_constraint const& dc) { - auto [sc, dep] = dc; + void core::propagate_value(prop_item const& dc) { + auto [idx, sc, dep] = dc; // check if sc evaluates to false switch (eval(sc)) { case l_true: return; case l_false: - m_unsat_core = explain_eval(dc); + m_unsat_core = explain_eval({ sc, dep }); propagate_unsat_core(); return; default: break; } - // if sc is v == value, then check the watch list for v if they evaluate to false. + // if sc is v == value, then check the watch list for v to propagate truth assignments if (sc.is_eq(m_var, m_value)) { for (auto idx : m_watch[m_var]) { - auto [sc, dep] = m_prop_queue[idx]; - if (eval(sc) == l_false) { - m_unsat_core = explain_eval({ sc, dep }); - propagate_unsat_core(); - return; + auto [sc, as] = m_constraint_trail[idx]; + switch (eval(sc)) { + case l_false: + m_unsat_core = explain_eval({ sc, nullptr }); + s.propagate(as, true, m_unsat_core); + break; + case l_true: + m_unsat_core = explain_eval({ sc, nullptr }); + s.propagate(as, false, m_unsat_core); + break; + default: + break; } } } @@ -259,15 +272,14 @@ namespace polysat { throw default_exception("nyi"); } - bool core::propagate_unsat_core() { + void core::propagate_unsat_core() { // default is to use unsat core: s.set_conflict(m_unsat_core); // if core is based on viable, use s.set_lemma(); - throw default_exception("nyi"); } - void core::assign_eh(signed_constraint const& sc, dependency const& dep) { - m_prop_queue.push_back({ sc, m_dep.mk_leaf(dep) }); + void core::assign_eh(unsigned index, signed_constraint const& sc, dependency const& dep) { + m_prop_queue.push_back({ index, sc, m_dep.mk_leaf(dep) }); s.ctx.push(push_back_vector(m_prop_queue)); } diff --git a/src/sat/smt/polysat_core.h b/src/sat/smt/polysat_core.h index 075840a48..0c173da77 100644 --- a/src/sat/smt/polysat_core.h +++ b/src/sat/smt/polysat_core.h @@ -30,12 +30,25 @@ namespace polysat { class core; class solver; + struct solver_assertion { + unsigned m_var1; + unsigned m_var2 = 0; + public: + solver_assertion(sat::literal lit) : m_var1(2*lit.index()) {} + solver_assertion(unsigned v1, unsigned v2) : m_var1(1 + 2*v1), m_var2(v2) {} + bool is_literal() const { return m_var1 % 2 == 0; } + sat::literal get_literal() const { SASSERT(is_literal()); return sat::to_literal(m_var1 / 2); } + unsigned var1() const { SASSERT(!is_literal()); return (m_var1 - 1) / 2; } + unsigned var2() const { SASSERT(!is_literal()); return m_var2; } + }; + class core { class mk_add_var; class mk_dqueue_var; class mk_assign_var; class mk_add_watch; typedef svector> activity; + typedef std::tuple prop_item; friend class viable; friend class constraints; friend class assignment; @@ -45,7 +58,8 @@ namespace polysat { constraints m_constraints; assignment m_assignment; unsigned m_qhead = 0, m_vqhead = 0; - svector m_prop_queue; + svector m_prop_queue; + svector> m_constraint_trail; // stacked_dependency_manager m_dep; mutable scoped_ptr_vector m_pdd; dependency_vector m_unsat_core; @@ -69,12 +83,11 @@ namespace polysat { void del_var(); bool is_assigned(pvar v) { return nullptr != m_justification[v]; } - void propagate_constraint(unsigned idx, dependent_constraint& dc); - void propagate_value(unsigned idx, dependent_constraint const& dc); + void propagate_constraint(prop_item& dc); + void propagate_value(prop_item const& dc); void propagate_assignment(pvar v, rational const& value, stacked_dependency* dep); - bool propagate_unsat_core(); + void propagate_unsat_core(); - void add_watch(unsigned idx, signed_constraint& sc); void add_watch(unsigned idx, unsigned var); lbool eval(signed_constraint sc) { throw default_exception("nyi"); } @@ -85,8 +98,9 @@ namespace polysat { sat::check_result check(); + unsigned register_constraint(signed_constraint& sc, solver_assertion sa); bool propagate(); - void assign_eh(signed_constraint const& sc, dependency const& dep); + void assign_eh(unsigned idx, signed_constraint const& sc, dependency const& dep); expr_ref constraint2expr(signed_constraint const& sc) const { throw default_exception("nyi"); } diff --git a/src/sat/smt/polysat_internalize.cpp b/src/sat/smt/polysat_internalize.cpp index dd761fdc3..af53750fc 100644 --- a/src/sat/smt/polysat_internalize.cpp +++ b/src/sat/smt/polysat_internalize.cpp @@ -169,12 +169,15 @@ namespace polysat { } }; - solver::atom* solver::mk_atom(sat::bool_var bv) { + solver::atom* solver::mk_atom(sat::literal lit, signed_constraint& sc) { + auto bv = lit.var(); atom* a = get_bv2a(bv); if (a) return a; a = new (get_region()) atom(bv); insert_bv2a(bv, a); + a->m_sc = sc; + a->m_index = m_core.register_constraint(sc, lit); ctx.push(mk_atom_trail(bv, *this)); return a; } @@ -184,7 +187,7 @@ namespace polysat { auto q = expr2pdd(e->get_arg(1)); auto sc = ~fn(p, q); sat::literal lit = expr2literal(e); - mk_atom(lit.var())->m_sc = sc; + auto* a = mk_atom(lit, sc); } void solver::internalize_div_rem_i(app* e, bool is_div) { @@ -277,6 +280,7 @@ namespace polysat { template void solver::internalize_le(app* e) { + SASSERT(e->get_num_args() == 2); auto p = expr2pdd(e->get_arg(0)); auto q = expr2pdd(e->get_arg(1)); if (Rev) @@ -286,8 +290,7 @@ namespace polysat { sc = ~sc; sat::literal lit = expr2literal(e); - atom* a = mk_atom(lit.var()); - a->m_sc = sc; + atom* a = mk_atom(lit, sc); } void solver::internalize_bit2bool(atom* a, expr* e, unsigned idx) { diff --git a/src/sat/smt/polysat_solver.cpp b/src/sat/smt/polysat_solver.cpp index 58b2c97d1..e3aff6e5f 100644 --- a/src/sat/smt/polysat_solver.cpp +++ b/src/sat/smt/polysat_solver.cpp @@ -66,7 +66,7 @@ namespace polysat { auto sc = a->m_sc; if (l.sign()) sc = ~sc; - m_core.assign_eh(sc, dependency(l, s().lvl(l))); + m_core.assign_eh(a->m_index, sc, dependency(l, s().lvl(l))); } void solver::set_conflict(dependency_vector const& core) { @@ -151,8 +151,10 @@ namespace polysat { pdd q = var2pdd(v2); auto sc = m_core.eq(p, q); m_var_eqs.setx(m_var_eqs_head, std::make_pair(v1, v2), std::make_pair(v1, v2)); - ctx.push(value_trail(m_var_eqs_head)); - m_core.assign_eh(sc, dependency(m_var_eqs_head, s().scope_lvl())); + ctx.push(value_trail(m_var_eqs_head)); + unsigned index = 0; +// unsigned index = m_core.register_constraint(sc); + m_core.assign_eh(index, sc, dependency(m_var_eqs_head, s().scope_lvl())); m_var_eqs_head++; } @@ -162,8 +164,9 @@ namespace polysat { pdd q = var2pdd(v2); auto sc = ~m_core.eq(p, q); sat::literal neq = ~expr2literal(ne.eq()); + auto index = m_core.register_constraint(sc, neq); TRACE("bv", tout << neq << " := " << s().value(neq) << " @" << s().scope_lvl() << "\n"); - m_core.assign_eh(sc, dependency(neq, s().lvl(neq))); + m_core.assign_eh(index, sc, dependency(neq, s().lvl(neq))); } // Core uses the propagate callback to add unit propagations to the trail. @@ -176,6 +179,25 @@ namespace polysat { ctx.propagate(lit, ex); } + void solver::propagate(solver_assertion as, bool sign, dependency_vector const& deps) { + auto [core, eqs] = explain_deps(deps); + if (as.is_literal()) { + auto lit = as.get_literal(); + if (sign) + lit.neg(); + auto ex = euf::th_explain::propagate(*this, core, eqs, lit, nullptr); + ctx.propagate(lit, ex); + } + else if (sign) { + // equalities are always asserted so a negative propagation is a conflict. + auto n1 = var2enode(as.var1()); + auto n2 = var2enode(as.var2()); + eqs.push_back({ n1, n2 }); + auto ex = euf::th_explain::conflict(*this, core, eqs, nullptr); + ctx.set_conflict(ex); + } + } + void solver::add_lemma(vector const& lemma) { sat::literal_vector lits; for (auto sc : lemma) diff --git a/src/sat/smt/polysat_solver.h b/src/sat/smt/polysat_solver.h index 5d9cd19a3..7940a7223 100644 --- a/src/sat/smt/polysat_solver.h +++ b/src/sat/smt/polysat_solver.h @@ -42,6 +42,7 @@ namespace polysat { struct atom { bool_var m_bv; + unsigned m_index = 0; signed_constraint m_sc; atom(bool_var b) :m_bv(b) {} ~atom() { } @@ -91,7 +92,7 @@ namespace polysat { void erase_bv2a(bool_var bv) { m_bool_var2atom[bv] = nullptr; } atom* get_bv2a(bool_var bv) const { return m_bool_var2atom.get(bv, nullptr); } class mk_atom_trail; - atom* mk_atom(sat::bool_var bv); + atom* mk_atom(sat::literal lit, signed_constraint& sc); void set_bit_eh(theory_var v, literal l, unsigned idx); void init_bits(expr* e, expr_ref_vector const & bits); void mk_bits(theory_var v); @@ -133,6 +134,8 @@ namespace polysat { void set_conflict(dependency_vector const& core); void set_lemma(vector const& lemma, unsigned level, dependency_vector const& core); void propagate(signed_constraint sc, dependency_vector const& deps); + void propagate(solver_assertion as, bool sign, dependency_vector const& deps); + void add_lemma(vector const& lemma); std::pair explain_deps(dependency_vector const& deps);