diff --git a/src/sat/smt/q_solver.cpp b/src/sat/smt/q_solver.cpp index 97adc6ae6..bba21723a 100644 --- a/src/sat/smt/q_solver.cpp +++ b/src/sat/smt/q_solver.cpp @@ -33,14 +33,13 @@ namespace q { expr* e = bool_var2expr(l.var()); SASSERT(is_forall(e) || is_exists(e)); if (l.sign() == is_forall(e)) { - // existential force add_clause(~l, skolemize(to_quantifier(e))); } - else { - // universal force -// add_clause(~l, uskolemize(to_quantifier(e))); + else { + add_clause(~l, specialize(to_quantifier(e))); ctx.push_vec(m_universal, l); } + m_stats.m_num_quantifier_asserts++; } sat::check_result solver::check() { @@ -59,7 +58,7 @@ namespace q { } void solver::collect_statistics(statistics& st) const { - st.update("quantifier inst", m_stats.m_num_inst); + st.update("quantifier asserts", m_stats.m_num_quantifier_asserts); } euf::th_solver* solver::clone(sat::solver* s, euf::solver& ctx) { @@ -77,26 +76,28 @@ namespace q { return v; } - sat::literal solver::skolemize(quantifier* q) { + sat::literal solver::instantiate(quantifier* q, std::function& mk_var) { sat::literal sk; - if (m_skolems.find(q, sk)) - return sk; expr_ref tmp(m); expr_ref_vector vars(m); - unsigned sz = q->get_num_decls(); + quantifier* q_flat = flatten(q); + unsigned sz = q_flat->get_num_decls(); vars.resize(sz, nullptr); - for (unsigned i = 0; i < sz; ++i) { - vars[i] = m.mk_fresh_const(q->get_decl_name(i), q->get_decl_sort(i)); - } + for (unsigned i = 0; i < sz; ++i) + vars[i] = mk_var(q_flat, i); var_subst subst(m); - expr_ref body = subst(q->get_expr(), vars.size(), vars.c_ptr()); + expr_ref body = subst(q_flat->get_expr(), vars); ctx.get_rewriter()(body); - sk = b_internalize(body); + return b_internalize(body); + } + + sat::literal solver::skolemize(quantifier* q) { + std::function mk_var = [&](quantifier* q, unsigned i) { + return m.mk_fresh_const(q->get_decl_name(i), q->get_decl_sort(i)); + }; + sat::literal sk = instantiate(q, mk_var); if (is_forall(q)) sk.neg(); - m_skolems.insert(q, sk); - // TODO find a different way than rely on backtrack stack, e,g., save body/q in ref-counted stack - ctx.push(insert_map(m_skolems, q)); return sk; } @@ -104,9 +105,14 @@ namespace q { * Find initial values to instantiate quantifier with so to make it as hard as possible for solver * to find values to free variables. */ - sat::literal solver::uskolemize(quantifier* q) { - NOT_IMPLEMENTED_YET(); - return sat::null_literal; + sat::literal solver::specialize(quantifier* q) { + std::function mk_var = [&](quantifier* q, unsigned i) { + return get_unit(q->get_decl_sort(i)); + }; + sat::literal sk = instantiate(q, mk_var); + if (is_exists(q)) + sk.neg(); + return sk; } void solver::init_search() { @@ -143,4 +149,34 @@ namespace q { ctx.push(insert_ref2_map(m, m_flat, q, q_flat)); return q_flat; } + + void solver::init_units() { + if (!m_unit_table.empty()) + return; + for (euf::enode* n : ctx.get_egraph().nodes()) { + if (!n->interpreted()) + continue; + expr* e = n->get_expr(); + sort* s = m.get_sort(e); + if (m_unit_table.contains(s)) + continue; + m_unit_table.insert(s, e); + ctx.push(insert_map, sort*>(m_unit_table, s)); + } + } + + expr* solver::get_unit(sort* s) { + expr* u = nullptr; + if (m_unit_table.find(s, u)) + return u; + init_units(); + if (m_unit_table.find(s, u)) + return u; + model mdl(m); + expr* val = mdl.get_some_value(s); + m.inc_ref(val); + m.inc_ref(s); + ctx.push(insert_ref2_map(m, m_unit_table, s, val)); + return val; + } } diff --git a/src/sat/smt/q_solver.h b/src/sat/smt/q_solver.h index 58574544e..19778b099 100644 --- a/src/sat/smt/q_solver.h +++ b/src/sat/smt/q_solver.h @@ -29,12 +29,11 @@ namespace q { class solver : public euf::th_euf_solver { - typedef obj_map skolem_table; typedef obj_map flat_table; friend class mbqi; struct stats { - unsigned m_num_inst; + unsigned m_num_quantifier_asserts; void reset() { memset(this, 0, sizeof(*this)); } stats() { reset(); } }; @@ -42,12 +41,15 @@ namespace q { stats m_stats; mbqi m_mbqi; - skolem_table m_skolems; flat_table m_flat; sat::literal_vector m_universal; + obj_map m_unit_table; + sat::literal instantiate(quantifier* q, std::function& mk_var); sat::literal skolemize(quantifier* q); - sat::literal uskolemize(quantifier* q); + sat::literal specialize(quantifier* q); + void init_units(); + expr* get_unit(sort* s); public: