diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index 7585e4f36..89bcbb1ae 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -23,8 +23,8 @@ Notes: namespace euf { - enode* egraph::mk_enode(expr* f, unsigned num_args, enode * const* args) { - enode* n = enode::mk(m_region, f, num_args, args); + enode* egraph::mk_enode(expr* f, unsigned generation, unsigned num_args, enode * const* args) { + enode* n = enode::mk(m_region, f, generation, num_args, args); m_nodes.push_back(n); m_exprs.push_back(f); if (is_app(f) && num_args > 0) { @@ -83,10 +83,10 @@ namespace euf { n->set_update_children(); } - enode* egraph::mk(expr* f, unsigned num_args, enode *const* args) { + enode* egraph::mk(expr* f, unsigned generation, unsigned num_args, enode *const* args) { SASSERT(!find(f)); force_push(); - enode *n = mk_enode(f, num_args, args); + enode *n = mk_enode(f, generation, num_args, args); SASSERT(n->class_size() == 1); if (num_args == 0 && m.is_unique_value(f)) n->mark_interpreted(); @@ -552,6 +552,7 @@ namespace euf { void egraph::push_congruence(enode* n1, enode* n2, bool comm) { SASSERT(is_app(n1->get_expr())); SASSERT(n1->get_decl() == n2->get_decl()); + m_uses_congruence = true; if (m_used_cc && !comm) { m_used_cc(to_app(n1->get_expr()), to_app(n2->get_expr())); } @@ -598,6 +599,7 @@ namespace euf { void egraph::begin_explain() { SASSERT(m_todo.empty()); + m_uses_congruence = false; } void egraph::end_explain() { @@ -672,15 +674,16 @@ namespace euf { out << " " << p->get_expr_id(); out << "] "; } - if (n->value() != l_undef) { + if (n->value() != l_undef) out << "[v" << n->bool_var() << " := " << (n->value() == l_true ? "T":"F") << "] "; - } if (n->has_th_vars()) { out << "[t"; for (auto v : enode_th_vars(n)) out << " " << v.get_id() << ":" << v.get_var(); out << "] "; } + if (n->generation() > 0) + out << "[g " << n->generation() << "] "; if (n->m_target && m_display_justification) n->m_justification.display(out << "[j " << n->m_target->get_expr_id() << " ", m_display_justification) << "] "; out << "\n"; @@ -722,7 +725,7 @@ namespace euf { for (unsigned j = 0; j < n1->num_args(); ++j) args.push_back(old_expr2new_enode[n1->get_arg(j)->get_expr_id()]); expr* e2 = tr(e1); - enode* n2 = mk(e2, args.size(), args.c_ptr()); + enode* n2 = mk(e2, n1->generation(), args.size(), args.c_ptr()); old_expr2new_enode.setx(e1->get_id(), n2, nullptr); n2->set_value(n2->value()); n2->m_bool_var = n1->m_bool_var; diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 1f718ba31..10399a8fd 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -164,6 +164,7 @@ namespace euf { bool_vector m_th_propagates_diseqs; enode_vector m_todo; stats m_stats; + bool m_uses_congruence { false }; std::function m_used_eq; std::function m_used_cc; std::function m_display_justification; @@ -180,7 +181,7 @@ namespace euf { void add_literal(enode* n, bool is_eq); void undo_eq(enode* r1, enode* n1, unsigned r2_num_parents); void undo_add_th_var(enode* n, theory_id id); - enode* mk_enode(expr* f, unsigned num_args, enode * const* args); + enode* mk_enode(expr* f, unsigned generation, unsigned num_args, enode * const* args); void force_push(); void set_conflict(enode* n1, enode* n2, justification j); void merge(enode* n1, enode* n2, justification j); @@ -217,7 +218,7 @@ namespace euf { egraph(ast_manager& m); ~egraph(); enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); } - enode* mk(expr* f, unsigned n, enode *const* args); + enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args); enode_vector const& enodes_of(func_decl* f); void push() { ++m_num_scopes; } void pop(unsigned num_scopes); @@ -272,6 +273,7 @@ namespace euf { void begin_explain(); void end_explain(); + bool uses_congruence() const { return m_uses_congruence; } template void explain(ptr_vector& justifications); template diff --git a/src/ast/euf/euf_enode.h b/src/ast/euf/euf_enode.h index fe7371ab6..398b38775 100644 --- a/src/ast/euf/euf_enode.h +++ b/src/ast/euf/euf_enode.h @@ -45,17 +45,18 @@ namespace euf { bool m_commutative{ false }; bool m_update_children{ false }; bool m_interpreted{ false }; - bool m_merge_enabled{ true }; - bool m_is_equality{ false }; - lbool m_value; - unsigned m_bool_var { UINT_MAX }; - unsigned m_class_size{ 1 }; - unsigned m_table_id{ UINT_MAX }; + bool m_merge_enabled{ true }; + bool m_is_equality{ false }; // Does the expression represent an equality + lbool m_value; // Assignment by SAT solver for Boolean node + unsigned m_bool_var { UINT_MAX }; // SAT solver variable associated with Boolean node + unsigned m_class_size{ 1 }; // Size of the equivalence class if the enode is the root. + unsigned m_table_id{ UINT_MAX }; + unsigned m_generation { 0 }; // Tracks how many quantifier instantiation rounds were needed to generate this enode. enode_vector m_parents; - enode* m_next{ nullptr }; - enode* m_root{ nullptr }; - enode* m_target{ nullptr }; - enode* m_cg { nullptr }; + enode* m_next { nullptr }; + enode* m_root { nullptr }; + enode* m_target { nullptr }; + enode* m_cg { nullptr }; th_var_list m_th_vars; justification m_justification; unsigned m_num_args{ 0 }; @@ -72,13 +73,14 @@ namespace euf { return sizeof(enode) + num_args * sizeof(enode*); } - static enode* mk(region& r, expr* f, unsigned num_args, enode* const* args) { + static enode* mk(region& r, expr* f, unsigned generation, unsigned num_args, enode* const* args) { SASSERT(num_args <= (is_app(f) ? to_app(f)->get_num_args() : 0)); void* mem = r.allocate(get_enode_size(num_args)); enode* n = new (mem) enode(); n->m_expr = f; n->m_next = n; n->m_root = n; + n->m_generation = generation, n->m_commutative = num_args == 2 && is_app(f) && to_app(f)->get_decl()->is_commutative(); n->m_num_args = num_args; n->m_merge_enabled = true; @@ -142,9 +144,12 @@ namespace euf { enode* get_arg(unsigned i) const { SASSERT(i < num_args()); return m_args[i]; } unsigned hash() const { return m_expr->hash(); } + unsigned get_table_id() const { return m_table_id; } void set_table_id(unsigned t) { m_table_id = t; } + unsigned generation() const { return m_generation; } + void mark1() { m_mark1 = true; } void unmark1() { m_mark1 = false; } bool is_marked1() { return m_mark1; } diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index c16caf5a1..71e992c0e 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -312,6 +312,7 @@ namespace arith { force_push(); expr* e1 = var2expr(v1); expr* e2 = var2expr(v2); + TRACE("arith", tout << "new eq: v" << v1 << " v" << v2 << "\n";); if (e1->get_id() > e2->get_id()) std::swap(e1, e2); diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index cfe7cbe31..a99c00b63 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -590,14 +590,17 @@ namespace arith { theory_var v = n->get_th_var(get_id()); expr* o = n->get_expr(); expr_ref value(m); - if (use_nra_model() && lp().external_to_local(v) != lp::null_lpvar) { + if (m.is_value(n->get_root()->get_expr())) { + value = n->get_root()->get_expr(); + } + else if (use_nra_model() && lp().external_to_local(v) != lp::null_lpvar) { anum const& an = nl_value(v, *m_a1); if (a.is_int(o) && !m_nla->am().is_int(an)) value = a.mk_numeral(rational::zero(), a.is_int(o)); else value = a.mk_numeral(m_nla->am(), nl_value(v, *m_a1), a.is_int(o)); } - else { + else if (v != euf::null_theory_var) { rational r = get_value(v); TRACE("arith", tout << mk_pp(o, m) << " v" << v << " := " << r << "\n";); SASSERT("integer variables should have integer values: " && (!a.is_int(o) || r.is_int() || m.limit().is_canceled())); @@ -605,9 +608,34 @@ namespace arith { r = floor(r); value = a.mk_numeral(r, m.get_sort(o)); } + else if (a.is_arith_expr(o)) { + expr_ref_vector args(m); + for (auto* arg : euf::enode_args(n)) { + if (m.is_value(arg->get_expr())) + args.push_back(arg->get_expr()); + else + args.push_back(values.get(arg->get_root_id())); + } + value = m.mk_app(to_app(o)->get_decl(), args.size(), args.c_ptr()); + ctx.get_rewriter()(value); + } + else { + UNREACHABLE(); + } values.set(n->get_root_id(), value); } + void solver::add_dep(euf::enode* n, top_sort& dep) { + expr* e = n->get_expr(); + if (a.is_arith_expr(e) && to_app(e)->get_num_args() > 0) { + for (auto* arg : euf::enode_args(n)) + dep.add(n, arg); + } + else { + dep.insert(n, nullptr); + } + } + void solver::push_core() { TRACE("arith_verbose", tout << "push\n";); m_scopes.push_back(scope()); diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index f64c4cc93..2191cb029 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -429,6 +429,7 @@ namespace arith { void init_model() override; void finalize_model(model& mdl) override { DEBUG_CODE(dbg_finalize_model(mdl);); } void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override; + void add_dep(euf::enode* n, top_sort& dep) override; sat::literal internalize(expr* e, bool sign, bool root, bool learned) override; void internalize(expr* e, bool redundant) override; void eq_internalized(euf::enode* n) override; diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index ef4450e38..2a9e2a22d 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -81,7 +81,7 @@ namespace euf { if (auto* s = expr2solver(e)) s->internalize(e, m_is_redundant); else - attach_node(m_egraph.mk(e, 0, nullptr)); + attach_node(m_egraph.mk(e, m_generation, 0, nullptr)); return true; } @@ -95,7 +95,7 @@ namespace euf { if (auto* s = expr2solver(e)) s->internalize(e, m_is_redundant); else - attach_node(m_egraph.mk(e, num, m_args.c_ptr())); + attach_node(m_egraph.mk(e, m_generation, num, m_args.c_ptr())); return true; } @@ -149,7 +149,7 @@ namespace euf { m_var_trail.push_back(v); enode* n = m_egraph.find(e); if (!n) - n = m_egraph.mk(e, 0, nullptr); + n = m_egraph.mk(e, m_generation, 0, nullptr); SASSERT(n->bool_var() == UINT_MAX || n->bool_var() == v); m_egraph.set_bool_var(n, v); if (m.is_eq(e) || m.is_or(e) || m.is_and(e) || m.is_not(e)) @@ -249,7 +249,7 @@ namespace euf { for (unsigned i = 0; i < sz; ++i) { expr_ref fapp(m.mk_app(f, e->get_arg(i)), m); expr_ref fresh(m.mk_fresh_const("dist-value", u), m); - enode* n = m_egraph.mk(fresh, 0, nullptr); + enode* n = m_egraph.mk(fresh, m_generation, 0, nullptr); n->mark_interpreted(); expr_ref eq = mk_eq(fapp, fresh); sat::literal lit = mk_literal(eq); diff --git a/src/sat/smt/euf_model.cpp b/src/sat/smt/euf_model.cpp index 95a475788..cc0107e30 100644 --- a/src/sat/smt/euf_model.cpp +++ b/src/sat/smt/euf_model.cpp @@ -112,6 +112,7 @@ namespace euf { void solver::dependencies2values(user_sort& us, deps_t& deps, model_ref& mdl) { for (enode* n : deps.top_sorted()) { + unsigned id = n->get_root_id(); if (m_values.get(id, nullptr)) continue; @@ -181,15 +182,14 @@ namespace euf { mdl->register_decl(f, v); else { auto* fi = mdl->get_func_interp(f); - if (!fi) { + if (!fi) { fi = alloc(func_interp, m, arity); mdl->register_decl(f, fi); } - args.reset(); - for (enode* arg : enode_args(n)) { - args.push_back(m_values.get(arg->get_root_id())); - SASSERT(args.back()); - } + args.reset(); + for (enode* arg : enode_args(n)) + args.push_back(m_values.get(arg->get_root_id())); + DEBUG_CODE(for (expr* arg : args) VERIFY(arg);); SASSERT(args.size() == arity); if (!fi->get_entry(args.c_ptr())) fi->insert_new_entry(args.c_ptr(), v); @@ -207,6 +207,11 @@ namespace euf { for (enode* n : m_egraph.nodes()) if (n->is_root() && m_values.get(n->get_expr_id())) m_values2root.insert(m_values.get(n->get_expr_id()), n); +#if 0 + for (auto kv : m_values2root) { + std::cout << mk_pp(kv.m_key, m) << " -> " << bpp(kv.m_value) << "\n"; + } +#endif return m_values2root; } diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 1285e7150..8e534dbff 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -41,7 +41,6 @@ namespace euf { m_lookahead(nullptr), m_to_m(&m), m_to_si(&si), - m_reinit_exprs(m), m_values(m) { updt_params(p); @@ -96,6 +95,8 @@ namespace euf { return ext; if (fid == m.get_basic_family_id()) return nullptr; + if (fid == m.get_user_sort_family_id()) + return nullptr; pb_util pb(m); bv_util bvu(m); array_util au(m); @@ -369,6 +370,8 @@ namespace euf { m_explain.reset(); m_egraph.explain_eq(m_explain, e.child(), e.root()); m_egraph.end_explain(); + if (m_egraph.uses_congruence()) + return false; for (auto p : m_explain) { if (is_literal(p)) return false; @@ -483,10 +486,11 @@ namespace euf { } void solver::start_reinit(unsigned n) { - m_reinit_exprs.reset(); + m_reinit.reset(); for (sat::bool_var v : s().get_vars_to_reinit()) { expr* e = bool_var2expr(v); - m_reinit_exprs.push_back(e); + if (e) + m_reinit.push_back(reinit_t(expr_ref(e, m), get_enode(e)?get_enode(e)->generation():0, v)); } } @@ -496,8 +500,7 @@ namespace euf { * and replaying internalization. */ void solver::finish_reinit() { - SASSERT(s().get_vars_to_reinit().size() == m_reinit_exprs.size()); - if (s().get_vars_to_reinit().empty()) + if (m_reinit.empty()) return; struct scoped_set_replay { @@ -513,26 +516,23 @@ namespace euf { scoped_set_replay replay(*this); scoped_suspend_rlimit suspend_rlimit(m.limit()); - unsigned i = 0; - for (sat::bool_var v : s().get_vars_to_reinit()) { - expr* e = m_reinit_exprs.get(i++); - if (e) - replay.m.insert(e, v); - } - if (replay.m.empty()) - return; + for (auto const& t : m_reinit) + replay.m.insert(std::get<0>(t), std::get<2>(t)); TRACE("euf", for (auto const& kv : replay.m) tout << kv.m_value << "\n";); - for (auto const& kv : replay.m) { - TRACE("euf", tout << "replay: " << kv.m_value << " " << mk_bounded_pp(kv.m_key, m) << "\n";); + for (auto const& t : m_reinit) { + expr_ref e = std::get<0>(t); + unsigned generation = std::get<1>(t); + sat::bool_var v = std::get<2>(t); + scoped_generation _sg(*this, generation); + TRACE("euf", tout << "replay: " << v << " " << mk_bounded_pp(e, m) << "\n";); sat::literal lit; - expr* e = kv.m_key; if (si.is_bool_op(e)) lit = literal(replay.m[e], false); else - lit = si.internalize(kv.m_key, true); - VERIFY(lit.var() == kv.m_value); - attach_lit(lit, kv.m_key); + lit = si.internalize(e, true); + VERIFY(lit.var() == v); + attach_lit(lit, e); } TRACE("euf", display(tout << "replay done\n");); } diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 29f62afa2..9d62ae4f4 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -95,6 +95,7 @@ namespace euf { scoped_ptr m_dual_solver; user::solver* m_user_propagator{ nullptr }; th_solver* m_qsolver { nullptr }; + unsigned m_generation { 0 }; ptr_vector m_bool_var2expr; ptr_vector m_explain; @@ -121,7 +122,8 @@ namespace euf { euf::enode* mk_false(); // replay - expr_ref_vector m_reinit_exprs; + typedef std::tuple reinit_t; + vector m_reinit; void start_reinit(unsigned num_scopes); void finish_reinit(); @@ -214,6 +216,19 @@ namespace euf { } }; + struct scoped_generation { + solver& s; + unsigned m_g; + scoped_generation(solver& s, unsigned g): + s(s), + m_g(s.m_generation) { + s.m_generation = g; + } + ~scoped_generation() { + s.m_generation = m_g; + } + }; + // accessors sat::sat_internalizer& get_si() { return si; } @@ -310,7 +325,7 @@ namespace euf { void attach_node(euf::enode* n); expr_ref mk_eq(expr* e1, expr* e2); expr_ref mk_eq(euf::enode* n1, euf::enode* n2) { return mk_eq(n1->get_expr(), n2->get_expr()); } - euf::enode* mk_enode(expr* e, unsigned n, enode* const* args) { return m_egraph.mk(e, n, args); } + euf::enode* mk_enode(expr* e, unsigned n, enode* const* args) { return m_egraph.mk(e, m_generation, n, args); } expr* bool_var2expr(sat::bool_var v) const { return m_bool_var2expr.get(v, nullptr); } expr_ref literal2expr(sat::literal lit) const { expr* e = bool_var2expr(lit.var()); return lit.sign() ? expr_ref(m.mk_not(e), m) : expr_ref(e, m); } diff --git a/src/sat/smt/q_mbi.cpp b/src/sat/smt/q_mbi.cpp index 1ba985fbf..080b27bdb 100644 --- a/src/sat/smt/q_mbi.cpp +++ b/src/sat/smt/q_mbi.cpp @@ -50,7 +50,7 @@ namespace q { for (expr* e : universe) eqs.push_back(m.mk_eq(sk, e)); expr_ref fml = mk_or(eqs); - std::cout << "restrict to universe " << fml << "\n"; + // std::cout << "restrict to universe " << fml << "\n"; m_solver->assert_expr(fml); } @@ -82,14 +82,15 @@ namespace q { quantifier* q_flat = m_qs.flatten(q); init_solver(); ::solver::scoped_push _sp(*m_solver); - std::cout << "quantifier\n" << mk_pp(q, m, 4) << "\n"; + std::cout << mk_pp(q, m, 4) << "\n"; // std::cout << *m_model << "\n"; auto* qb = specialize(q_flat); if (!qb) return l_undef; + // return l_undef; if (m.is_false(qb->mbody)) return l_true; - std::cout << "body\n" << qb->mbody << "\n"; + // std::cout << "body\n" << qb->mbody << "\n"; m_solver->assert_expr(qb->mbody); lbool r = m_solver->check_sat(0, nullptr); if (r == l_undef) @@ -103,37 +104,29 @@ namespace q { if (is_exists(q)) qlit.neg(); unsigned i = 0; + expr_ref_vector eqs(m); if (!qb->var_args.empty()) { ::solver::scoped_push _sp(*m_solver); add_domain_eqs(*mdl0, *qb); - std::cout << "check\n"; for (; i < m_max_cex && l_true == m_solver->check_sat(0, nullptr); ++i) { m_solver->get_model(mdl1); - proj = solver_project(*mdl1, *qb); + proj = solver_project(*mdl1, *qb, eqs, true); if (!proj) break; - TRACE("q", tout << "project: " << proj << "\n";); std::cout << "project\n" << proj << "\n"; - std::cout << *m_model << "\n"; + std::cout << "eqs: " << eqs << "\n"; - static unsigned s_count = 0; - ++s_count; - if (s_count == 3) - exit(0); - ++m_stats.m_num_instantiations; - m_qs.add_clause(~qlit, ~ctx.mk_literal(proj)); - m_solver->assert_expr(m.mk_not(proj)); + add_instantiation(qlit, proj); + m_solver->assert_expr(m.mk_not(mk_and(eqs))); } } if (i == 0) { add_domain_bounds(*mdl0, *qb); - proj = solver_project(*mdl0, *qb); + proj = solver_project(*mdl0, *qb, eqs, false); if (!proj) return l_undef; std::cout << "project-base\n" << proj << "\n"; - TRACE("q", tout << "project-base: " << proj << "\n";); - ++m_stats.m_num_instantiations; - m_qs.add_clause(~qlit, ~ctx.mk_literal(proj)); + add_instantiation(qlit, proj); } // TODO: add as top-level clause for relevancy return l_false; @@ -142,30 +135,45 @@ namespace q { mbqi::q_body* mbqi::specialize(quantifier* q) { mbqi::q_body* result = nullptr; var_subst subst(m); + unsigned sz = q->get_num_decls(); if (!m_q2body.find(q, result)) { - unsigned sz = q->get_num_decls(); result = alloc(q_body, m); m_q2body.insert(q, result); ctx.push(new_obj_trail(result)); ctx.push(insert_obj_map(m_q2body, q)); + obj_hashtable _vars; app_ref_vector& vars = result->vars; vars.resize(sz, nullptr); for (unsigned i = 0; i < sz; ++i) { sort* s = q->get_decl_sort(i); vars[i] = m.mk_fresh_const(q->get_decl_name(i), s, false); - if (m_model->has_uninterpreted_sort(s)) - restrict_to_universe(vars.get(i), m_model->get_universe(s)); + _vars.insert(vars.get(i)); } expr_ref fml = subst(q->get_expr(), vars); extract_var_args(q->get_expr(), *result); if (is_forall(q)) fml = m.mk_not(fml); flatten_and(fml, result->vbody); + for (expr* e : result->vbody) { + expr* e1 = nullptr, *e2 = nullptr; + if (m.is_not(e, e) && m.is_eq(e, e1, e2)) { + if (_vars.contains(e1) && !_vars.contains(e2) && is_app(e2)) + result->var_diff.push_back(std::make_pair(to_app(e1), to_app(e2)->get_decl())); + else if (_vars.contains(e2) && !_vars.contains(e1) && is_app(e1)) + result->var_diff.push_back(std::make_pair(to_app(e2), to_app(e1)->get_decl())); + } + } } expr_ref& mbody = result->mbody; if (!m_model->eval_expr(q->get_expr(), mbody, true)) return nullptr; + for (unsigned i = 0; i < sz; ++i) { + sort* s = q->get_decl_sort(i); + if (m_model->has_uninterpreted_sort(s)) + restrict_to_universe(result->vars.get(i), m_model->get_universe(s)); + } + mbody = subst(mbody, result->vars); if (is_forall(q)) mbody = mk_not(m, mbody); @@ -173,7 +181,8 @@ namespace q { return result; } - expr_ref mbqi::solver_project(model& mdl, q_body& qb) { + expr_ref mbqi::solver_project(model& mdl, q_body& qb, expr_ref_vector& eqs, bool use_inst) { + eqs.reset(); model::scoped_model_completion _sc(mdl, true); for (app* v : qb.vars) m_model->register_decl(v->get_decl(), mdl(v)); @@ -186,13 +195,14 @@ namespace q { tout << fmls << "\n"; tout << "model of projection\n" << mdl << "\n"; tout << "var args: " << qb.var_args.size() << "\n"; + tout << "domain eqs: " << qb.domain_eqs << "\n"; for (expr* f : fmls) if (m_model->is_false(f)) tout << mk_pp(f, m) << " := false\n"; tout << "vars: " << vars << "\n";); expr_safe_replace rep(m); - for (unsigned i = 0; i < vars.size(); ++i) { + for (unsigned i = 0; !use_inst && i < vars.size(); ++i) { app* v = vars.get(i); auto* p = get_plugin(v); if (p && !fmls_extracted) { @@ -213,6 +223,7 @@ namespace q { rep.insert(v, term); if (val != term) rep.insert(val, term); + eqs.push_back(m.mk_eq(v, val)); } rep(fmls); return mk_and(fmls); @@ -228,7 +239,59 @@ namespace q { void mbqi::add_domain_eqs(model& mdl, q_body& qb) { qb.domain_eqs.reset(); var_subst subst(m); + expr_mark diff_vars; + for (auto vd : qb.var_diff) { + app* v = vd.first; + func_decl* f = vd.second; + expr_ref_vector diff_set(m), vdiff_set(m); + typedef std::tuple tup; + svector todo; + expr_mark visited; + expr_ref val(m); + for (euf::enode* n : ctx.get_egraph().enodes_of(f)) { + euf::enode* r1 = n->get_root(); + expr* e1 = n->get_expr(); + todo.push_back(tup(r1, 2, 2)); + for (unsigned i = 0; i < todo.size(); ++i) { + auto t = todo[i]; + euf::enode* r2 = std::get<0>(t)->get_root(); + expr* e2 = r2->get_expr(); + if (visited.is_marked(e2)) + continue; + visited.mark(e2); + std::cout << "try: " << mk_bounded_pp(e2, m) << " " << std::get<1>(t) << " " << std::get<2>(t) << "\n"; + if (r1 != r2 && m.get_sort(e1) == m.get_sort(e2) && m_model->eval_expr(e2, val, true) && !visited.is_marked(val)) { + visited.mark(val); + diff_set.push_back(m.mk_eq(v, val)); + vdiff_set.push_back(m.mk_eq(v, e2)); + } + if (std::get<1>(t) > 0) + for (euf::enode* p : euf::enode_parents(r2)) + todo.push_back(tup(p, std::get<1>(t)-1, std::get<2>(t)+1)); + if (std::get<2>(t) > 0) + for (euf::enode* n : euf::enode_class(r2)) + for (euf::enode* arg : euf::enode_args(n)) + todo.push_back(tup(arg, 0, std::get<2>(t)-1)); + + } + todo.reset(); + } + if (!diff_set.empty()) { + diff_vars.mark(v); + expr_ref diff = mk_or(diff_set); + expr_ref vdiff = mk_or(vdiff_set); + std::cout << "diff: " << vdiff_set << "\n"; + m_solver->assert_expr(diff); + qb.domain_eqs.push_back(vdiff); + } + std::cout << "var-diff: " << mk_pp(vd.first, m) << " " << mk_pp(vd.second, m) << "\n"; + } + for (auto p : qb.var_args) { + expr_ref arg(p.first->get_arg(p.second), m); + arg = subst(arg, qb.vars); + if (diff_vars.is_marked(arg)) + continue; expr_ref bounds = m_model_fixer.restrict_arg(p.first, p.second); if (m.is_true(bounds)) continue; @@ -237,13 +300,15 @@ namespace q { if (!m_model->eval_expr(bounds, mbounds, true)) return; mbounds = subst(mbounds, qb.vars); + std::cout << "bounds: " << mk_pp(p.first, m) << " @ " << p.second << " - " << bounds << "\n"; std::cout << "domain eqs " << mbounds << "\n"; + std::cout << "vbounds " << vbounds << "\n"; + std::cout << *m_model << "\n"; m_solver->assert_expr(mbounds); qb.domain_eqs.push_back(vbounds); } } - /* * Add bounds to sub-terms under uninterpreted functions for projection. */ @@ -279,8 +344,13 @@ namespace q { expr_safe_replace rep(m); var_subst subst(m); expr_ref_vector eqs(m); + expr_mark visited; for (auto p : qb.var_args) { - expr_ref _term = subst(p.first, qb.vars); + expr* e = p.first; + if (visited.is_marked(e)) + continue; + visited.mark(e); + expr_ref _term = subst(e, qb.vars); app_ref term(to_app(_term), m); expr_ref value = (*m_model)(term); expr* s = m_model_fixer.invert_app(term, value); @@ -313,6 +383,7 @@ namespace q { lbool mbqi::operator()() { lbool result = l_true; m_model = nullptr; + m_instantiations.reset(); for (sat::literal lit : m_qs.m_universal) { quantifier* q = to_quantifier(ctx.bool_var2expr(lit.var())); if (!ctx.is_relevant(q)) @@ -331,6 +402,9 @@ namespace q { } } m_max_cex += ctx.get_config().m_mbqi_max_cexs; + for (auto p : m_instantiations) + m_qs.add_clause(~p.first, ~ctx.mk_literal(p.second)); + m_instantiations.reset(); return result; } diff --git a/src/sat/smt/q_mbi.h b/src/sat/smt/q_mbi.h index 7e6c80e40..48b2c25df 100644 --- a/src/sat/smt/q_mbi.h +++ b/src/sat/smt/q_mbi.h @@ -44,6 +44,7 @@ namespace q { expr_ref mbody; // body specialized with respect to model expr_ref_vector vbody; // (negation of) body specialized with respect to vars expr_ref_vector domain_eqs; // additional domain restrictions + svector> var_diff; // variable differences svector> var_args; // (uninterpreted) functions in vbody that contain arguments with variables q_body(ast_manager& m) : vars(m), mbody(m), vbody(m), domain_eqs(m) {} }; @@ -59,6 +60,7 @@ namespace q { scoped_ptr_vector m_plugins; obj_map m_q2body; unsigned m_max_cex{ 1 }; + vector> m_instantiations; void restrict_to_universe(expr * sk, ptr_vector const & universe); // void register_value(expr* e); @@ -66,7 +68,7 @@ namespace q { expr_ref choose_term(euf::enode* r); lbool check_forall(quantifier* q); q_body* specialize(quantifier* q); - expr_ref solver_project(model& mdl, q_body& qb); + expr_ref solver_project(model& mdl, q_body& qb, expr_ref_vector& eqs, bool use_inst); void add_domain_eqs(model& mdl, q_body& qb); void add_domain_bounds(model& mdl, q_body& qb); void eliminate_nested_vars(expr_ref_vector& fmls, q_body& qb); @@ -75,6 +77,11 @@ namespace q { void init_solver(); mbp::project_plugin* get_plugin(app* var); void add_plugin(mbp::project_plugin* p); + void add_instantiation(sat::literal qlit, expr_ref& proj) { + TRACE("q", tout << "project: " << proj << "\n";); + ++m_stats.m_num_instantiations; + m_instantiations.push_back(std::make_pair(qlit, proj)); + } public: diff --git a/src/sat/smt/q_model_fixer.cpp b/src/sat/smt/q_model_fixer.cpp index c5329b923..4d5a1a275 100644 --- a/src/sat/smt/q_model_fixer.cpp +++ b/src/sat/smt/q_model_fixer.cpp @@ -220,8 +220,10 @@ namespace q { euf::enode* r = nullptr; TRACE("q", tout << "invert-app " << mk_pp(t, m) << " = " << mk_pp(value, m) << "\n"; - if (ctx.values2root().find(value, r)) - tout << "inverse " << mk_pp(r->get_expr(), m) << "\n";); + if (ctx.values2root().find(value, r)) + tout << "inverse " << mk_pp(r->get_expr(), m) << "\n"; + ctx.display(tout); + ); if (ctx.values2root().find(value, r)) return r->get_expr(); return value;