diff --git a/src/sat/smt/euf_model.cpp b/src/sat/smt/euf_model.cpp index 84aa98e3d..83d02fada 100644 --- a/src/sat/smt/euf_model.cpp +++ b/src/sat/smt/euf_model.cpp @@ -24,11 +24,11 @@ namespace euf { void solver::update_model(model_ref& mdl) { deps_t deps; - expr_ref_vector values(m); + m_values.reset(); collect_dependencies(deps); deps.topological_sort(); - dependencies2values(deps, values, mdl); - values2model(deps, values, mdl); + dependencies2values(deps, mdl); + values2model(deps, mdl); } bool solver::include_func_interp(func_decl* f) { @@ -91,26 +91,26 @@ namespace euf { } }; - void solver::dependencies2values(deps_t& deps, expr_ref_vector& values, model_ref& mdl) { - user_sort user_sort(*this, values, mdl); + void solver::dependencies2values(deps_t& deps, model_ref& mdl) { + user_sort user_sort(*this, m_values, mdl); for (enode* n : deps.top_sorted()) { unsigned id = n->get_root_id(); - if (values.get(id, nullptr)) + if (m_values.get(id, nullptr)) continue; expr* e = n->get_expr(); - values.reserve(id + 1); + m_values.reserve(id + 1); if (m.is_bool(e) && is_uninterp_const(e) && mdl->get_const_interp(to_app(e)->get_decl())) { - values.set(id, mdl->get_const_interp(to_app(e)->get_decl())); + m_values.set(id, mdl->get_const_interp(to_app(e)->get_decl())); continue; } // model of s() must have been fixed. if (m.is_bool(e)) { if (m.is_true(e)) { - values.set(id, m.mk_true()); + m_values.set(id, m.mk_true()); continue; } if (m.is_false(e)) { - values.set(id, m.mk_false()); + m_values.set(id, m.mk_false()); continue; } if (is_app(e) && to_app(e)->get_family_id() == m.get_basic_family_id()) @@ -119,10 +119,10 @@ namespace euf { SASSERT(v != sat::null_bool_var); switch (s().value(v)) { case l_true: - values.set(id, m.mk_true()); + m_values.set(id, m.mk_true()); break; case l_false: - values.set(id, m.mk_false()); + m_values.set(id, m.mk_false()); break; default: break; @@ -134,16 +134,16 @@ namespace euf { if (m.is_uninterp(srt)) user_sort.add(id, srt); else if (auto* mbS = sort2solver(srt)) - mbS->add_value(n, *mdl, values); + mbS->add_value(n, *mdl, m_values); else if (auto* mbE = expr2solver(e)) - mbE->add_value(n, *mdl, values); + mbE->add_value(n, *mdl, m_values); else { IF_VERBOSE(1, verbose_stream() << "no model values created for " << mk_pp(e, m) << "\n"); } } } - void solver::values2model(deps_t const& deps, expr_ref_vector const& values, model_ref& mdl) { + void solver::values2model(deps_t const& deps, model_ref& mdl) { ptr_vector args; for (enode* n : deps.top_sorted()) { expr* e = n->get_expr(); @@ -155,7 +155,7 @@ namespace euf { continue; if (m.is_bool(e) && is_uninterp_const(e) && mdl->get_const_interp(f)) continue; - expr* v = values.get(n->get_root_id()); + expr* v = m_values.get(n->get_root_id()); CTRACE("euf", !v, tout << "no value for " << mk_pp(e, m) << "\n";); if (!v) continue; @@ -170,7 +170,7 @@ namespace euf { } args.reset(); for (enode* arg : enode_args(n)) { - args.push_back(values.get(arg->get_root_id())); + args.push_back(m_values.get(arg->get_root_id())); SASSERT(args.back()); } SASSERT(args.size() == arity); @@ -184,4 +184,13 @@ namespace euf { // TODO } + obj_map const& solver::values2root() { + m_values2root.reset(); + for (enode* n : m_egraph.nodes()) + if (n->is_root()) + m_values2root.insert(m_values.get(n->get_root_id()), n); + return m_values2root; + } + + } diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 0f6d69a8c..de60726db 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -37,7 +37,8 @@ namespace euf { m_lookahead(nullptr), m_to_m(&m), m_to_si(&si), - m_reinit_exprs(m) + m_reinit_exprs(m), + m_values(m) { updt_params(p); diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 6b8f33b07..844f0acb6 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -101,6 +101,7 @@ namespace euf { svector m_scopes; scoped_ptr_vector m_solvers; ptr_vector m_id2solver; + std::function<::solver*(void)> m_mk_solver; constraint* m_conflict{ nullptr }; constraint* m_eq{ nullptr }; @@ -135,11 +136,13 @@ namespace euf { void init_ackerman(); // model building + expr_ref_vector m_values; + obj_map m_values2root; bool include_func_interp(func_decl* f); void register_macros(model& mdl); - void dependencies2values(deps_t& deps, expr_ref_vector& values, model_ref& mdl); + void dependencies2values(deps_t& deps, model_ref& mdl); void collect_dependencies(deps_t& deps); - void values2model(deps_t const& deps, expr_ref_vector const& values, model_ref& mdl); + void values2model(deps_t const& deps, model_ref& mdl); // solving void propagate_literals(); @@ -299,6 +302,7 @@ namespace euf { // model construction void update_model(model_ref& mdl); + obj_map const& values2root(); // diagnostics func_decl_ref_vector const& unhandled_functions() { return m_unhandled_functions; } @@ -335,6 +339,10 @@ namespace euf { return m_user_propagator->add_expr(e); } + // solver factory + ::solver* mk_solver() { return m_mk_solver(); } + void set_mk_solver(std::function<::solver*(void)>& mk) { m_mk_solver = mk; } + }; }; diff --git a/src/sat/smt/q_mbi.cpp b/src/sat/smt/q_mbi.cpp index 50662b49d..dd9cb53c0 100644 --- a/src/sat/smt/q_mbi.cpp +++ b/src/sat/smt/q_mbi.cpp @@ -27,7 +27,7 @@ Author: namespace q { mbqi::mbqi(euf::solver& ctx, solver& s): - ctx(ctx), qs(s), m(s.get_manager()) {} + ctx(ctx), qs(s), m(s.get_manager()), m_fresh_trail(m) {} void mbqi::restrict_to_universe(expr * sk, ptr_vector const & universe) { @@ -49,16 +49,12 @@ namespace q { m_values.push_back(values); } if (!values->contains(e)) { - NOT_IMPLEMENTED_YET(); -#if 0 for (expr* b : *values) { - m_context.add(m.mk_not(m.mk_eq(e, b)), __FUNCTION__); + expr_ref eq = ctx.mk_eq(e, b); + qs.add_unit(~qs.b_internalize(eq)); } -#endif values->insert(e); -#if 0 m_fresh_trail.push_back(e); -#endif } } @@ -76,10 +72,20 @@ namespace q { for (expr* arg : *to_app(e)) { args.push_back(replace_model_value(arg)); } - return expr_ref(m.mk_app(to_app(e)->get_decl(), args.size(), args.c_ptr()), m); + return expr_ref(m.mk_app(to_app(e)->get_decl(), args), m); } return expr_ref(e, m); } + + expr_ref mbqi::choose_term(euf::enode* r) { + unsigned sz = r->class_size(); + unsigned start = ctx.s().rand()() % sz; + unsigned i = 0; + for (euf::enode* n : euf::enode_class(r)) + if (i++ >= start) + return expr_ref(n->get_expr(), m); + return expr_ref(nullptr, m); + } lbool mbqi::check_forall(quantifier* q) { expr_ref_vector vars(m); @@ -137,6 +143,7 @@ namespace q { unsigned sz = q->get_num_decls(); expr_ref_vector vals(m); vals.resize(sz, nullptr); + auto const& v2r = ctx.values2root(); for (unsigned i = 0; i < sz; ++i) { app* v = to_app(vars.get(i)); func_decl* f = v->get_decl(); @@ -144,22 +151,16 @@ namespace q { if (!val) return expr_ref(m); expr* t = nullptr; - NOT_IMPLEMENTED_YET(); -#if 0 - if (m_val2term.find(val, m.get_sort(v), t)) { - val = t; - } - else { - val = replace_model_value(val); - } - vals[i] = val; -#endif + euf::enode* r = nullptr; + if (v2r.find(val, r)) + vals[i] = choose_term(r); + if (!vals.get(i)) + vals[i] = replace_model_value(val); } var_subst subst(m); return subst(q->get_expr(), vals); } - lbool mbqi::operator()() { lbool result = l_true; m_model = nullptr; @@ -191,9 +192,8 @@ namespace q { } void mbqi::init_solver() { - if (m_solver) - return; - NOT_IMPLEMENTED_YET(); + if (!m_solver) + m_solver = ctx.mk_solver(); } } diff --git a/src/sat/smt/q_mbi.h b/src/sat/smt/q_mbi.h index 3d34ff313..17dad0bfc 100644 --- a/src/sat/smt/q_mbi.h +++ b/src/sat/smt/q_mbi.h @@ -35,10 +35,12 @@ namespace q { ref<::solver> m_solver; obj_map*> m_fresh; scoped_ptr_vector> m_values; + expr_ref_vector m_fresh_trail; void restrict_to_universe(expr * sk, ptr_vector const & universe); void register_value(expr* e); expr_ref replace_model_value(expr* e); + expr_ref choose_term(euf::enode* r); lbool check_forall(quantifier* q); expr_ref specialize(quantifier* q, expr_ref_vector& vars); expr_ref project(model& mdl, quantifier* q, expr_ref_vector& vars); diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index ad0f3adad..1cad3c0ae 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -608,6 +608,12 @@ struct goal2sat::imp : public sat::sat_internalizer { m_solver.set_extension(euf); for (unsigned i = m_solver.num_scopes(); i-- > 0; ) euf->push(); +#if 0 + std::function mk_solver = [&]() { + return mk_inc_sat_solver(m, m_params, true); + }; + euf->set_mk_solver(mk_solver); +#endif } else { euf = dynamic_cast(ext);