diff --git a/src/cmd_context/extra_cmds/dbg_cmds.cpp b/src/cmd_context/extra_cmds/dbg_cmds.cpp index 7c1383d62..d9e575664 100644 --- a/src/cmd_context/extra_cmds/dbg_cmds.cpp +++ b/src/cmd_context/extra_cmds/dbg_cmds.cpp @@ -519,6 +519,7 @@ public: model_ref mdl; s->get_model(mdl); qe::euf_arith_mbi_plugin plugin(s.get(), se.get()); + plugin.set_shared(vars); plugin.project(mdl, lits); ctx.regular_stream() << lits << "\n"; } diff --git a/src/model/model_evaluator.cpp b/src/model/model_evaluator.cpp index 1991fc4d2..1bd274316 100644 --- a/src/model/model_evaluator.cpp +++ b/src/model/model_evaluator.cpp @@ -673,6 +673,15 @@ bool model_evaluator::is_true(expr_ref_vector const& ts) { return true; } +bool model_evaluator::are_equal(expr* s, expr* t) { + if (m().are_equal(s, t)) return true; + if (m().are_distinct(s, t)) return false; + expr_ref t1(m()), t2(m()); + eval(t, t1, true); + eval(s, t2, true); + return m().are_equal(t1, t2); +} + bool model_evaluator::eval(expr* t, expr_ref& r, bool model_completion) { set_model_completion(model_completion); try { diff --git a/src/model/model_evaluator.h b/src/model/model_evaluator.h index 2ffcf2e0b..9f2320a66 100644 --- a/src/model/model_evaluator.h +++ b/src/model/model_evaluator.h @@ -55,6 +55,8 @@ public: bool is_true(expr * t); bool is_false(expr * t); bool is_true(expr_ref_vector const& ts); + bool are_equal(expr* s, expr* t); + void set_solver(expr_solver* solver); bool has_solver(); diff --git a/src/qe/qe_arith.cpp b/src/qe/qe_arith.cpp index a0aea08da..38027efbc 100644 --- a/src/qe/qe_arith.cpp +++ b/src/qe/qe_arith.cpp @@ -604,6 +604,11 @@ namespace qe { return m_imp->maximize(fmls, mdl, t, ge, gt); } + void arith_project_plugin::saturate(model& model, func_decl_ref_vector const& shared, expr_ref_vector& lits) { + UNREACHABLE(); + } + + bool arith_project(model& model, app* var, expr_ref_vector& lits) { ast_manager& m = lits.get_manager(); arith_project_plugin ap(m); diff --git a/src/qe/qe_arith.h b/src/qe/qe_arith.h index b55e63fcf..acff8e13b 100644 --- a/src/qe/qe_arith.h +++ b/src/qe/qe_arith.h @@ -31,6 +31,7 @@ namespace qe { family_id get_family_id() override; void operator()(model& model, app_ref_vector& vars, expr_ref_vector& lits) override; vector project(model& model, app_ref_vector& vars, expr_ref_vector& lits) override; + void saturate(model& model, func_decl_ref_vector const& shared, expr_ref_vector& lits) override; opt::inf_eps maximize(expr_ref_vector const& fmls, model& mdl, app* t, expr_ref& ge, expr_ref& gt); diff --git a/src/qe/qe_arrays.cpp b/src/qe/qe_arrays.cpp index 4bead6684..460f9ea37 100644 --- a/src/qe/qe_arrays.cpp +++ b/src/qe/qe_arrays.cpp @@ -21,12 +21,14 @@ Revision History: #include "util/lbool.h" #include "ast/rewriter/rewriter_def.h" #include "ast/expr_functors.h" +#include "ast/for_each_expr.h" #include "ast/rewriter/expr_safe_replace.h" #include "ast/rewriter/th_rewriter.h" #include "ast/ast_util.h" #include "ast/ast_pp.h" #include "model/model_evaluator.h" #include "qe/qe_arrays.h" +#include "qe/qe_term_graph.h" namespace { @@ -357,7 +359,7 @@ namespace qe { ptr_vector sel_args; sel_args.push_back (arr); sel_args.append(I[i].size(), I[i].c_ptr()); - expr_ref val_term (m_arr_u.mk_select (sel_args.size (), sel_args.c_ptr ()), m); + expr_ref val_term (m_arr_u.mk_select (sel_args), m); // evaluate and assign to ith diff_val_const val = (*m_mev)(val_term); M->register_decl (diff_val_consts.get (i)->get_decl (), val); @@ -451,7 +453,7 @@ namespace qe { ptr_vector sel_args; sel_args.push_back (arr1); sel_args.append(idxs.size(), idxs.c_ptr()); - expr_ref arr1_idx (m_arr_u.mk_select (sel_args.size (), sel_args.c_ptr ()), m); + expr_ref arr1_idx (m_arr_u.mk_select (sel_args), m); expr_ref eq (m.mk_eq (arr1_idx, x), m); m_aux_lits_v.push_back (eq); @@ -821,7 +823,7 @@ namespace qe { ptr_vector args; args.push_back(array); args.append(arity, js); - expr* r = m_arr_u.mk_select (args.size(), args.c_ptr()); + expr* r = m_arr_u.mk_select (args); m_pinned.push_back (r); return r; } @@ -1191,7 +1193,7 @@ namespace qe { array_util a; scoped_ptr m_var; - imp(ast_manager& m): m(m), a(m) {} + imp(ast_manager& m): m(m), a(m), m_stores(m) {} ~imp() {} bool solve(model& model, app_ref_vector& vars, expr_ref_vector& lits) { @@ -1267,7 +1269,7 @@ namespace qe { args.push_back (s); args.append(idxs[i].m_values.size(), idxs[i].m_vars); - sel = a.mk_select (args.size (), args.c_ptr ()); + sel = a.mk_select (args); val = model(sel); model.register_decl (var->get_decl (), val); @@ -1306,7 +1308,7 @@ namespace qe { } args.push_back(t); args.append(n, s->get_args()+1); - lits.push_back(m.mk_eq(a.mk_select(args.size(), args.c_ptr()), s->get_arg(n+1))); + lits.push_back(m.mk_eq(a.mk_select(args), s->get_arg(n+1))); idxs.push_back(idx); return solve(model, to_app(s->get_arg(0)), t, idxs, vars, lits); case l_undef: @@ -1357,6 +1359,217 @@ namespace qe { } return l_undef; } + + void saturate(model& model, func_decl_ref_vector const& shared, expr_ref_vector& lits) { + term_graph tg(m); + tg.set_vars(shared, false); + tg.add_model_based_terms(model, lits); + + // need tg to take term and map it to optional rep over the + // shared vocabulary if it exists. + + // . collect shared store expressions, index sorts + // . collect shared index expressions + // . assert extensionality (add shared index expressions) + // . assert store axioms for collected expressions + + collect_store_expressions(tg, lits); + collect_index_expressions(tg, lits); + + TRACE("qe", + tout << "indices\n"; + for (auto& kv : m_indices) { + tout << sort_ref(kv.m_key, m) << " |-> " << *kv.m_value << "\n"; + } + tout << "stores " << m_stores << "\n"; + tout << "arrays\n"; + for (auto& kv : m_arrays) { + tout << sort_ref(kv.m_key, m) << " |-> " << *kv.m_value << "\n"; + }); + + assert_extensionality(model, tg, lits); + assert_store_select(model, tg, lits); + + TRACE("qe", tout << lits << "\n";); + + for (auto& kv : m_indices) { + dealloc(kv.m_value); + } + for (auto& kv : m_arrays) { + dealloc(kv.m_value); + } + m_stores.reset(); + m_indices.reset(); + m_arrays.reset(); + + TRACE("qe", tout << "done: " << lits << "\n";); + + } + + app_ref_vector m_stores; + obj_map m_indices; + obj_map m_arrays; + + void add_index_sort(expr* n) { + sort* s = m.get_sort(n); + if (!m_indices.contains(s)) { + m_indices.insert(s, alloc(app_ref_vector, m)); + } + } + + void add_array(app* n) { + sort* s = m.get_sort(n); + app_ref_vector* vs = nullptr; + if (!m_arrays.find(s, vs)) { + vs = alloc(app_ref_vector, m); + m_arrays.insert(s, vs); + } + vs->push_back(n); + } + + app_ref_vector* is_index(expr* n) { + app_ref_vector* result = nullptr; + m_indices.find(m.get_sort(n), result); + return result; + } + + struct for_each_store_proc { + imp& m_imp; + term_graph& tg; + for_each_store_proc(imp& i, term_graph& tg) : m_imp(i), tg(tg) {} + + void operator()(app* n) { + if (m_imp.a.is_array(n) && tg.get_model_based_rep(n)) { + m_imp.add_array(n); + } + + if (m_imp.a.is_store(n) && + (tg.get_model_based_rep(n->get_arg(0)) || + tg.get_model_based_rep(n->get_arg(n->get_num_args() - 1)))) { + m_imp.m_stores.push_back(n); + for (unsigned i = 1; i + 1 < n->get_num_args(); ++i) { + m_imp.add_index_sort(n->get_arg(i)); + } + } + } + + void operator()(expr* e) {} + }; + + struct for_each_index_proc { + imp& m_imp; + term_graph& tg; + for_each_index_proc(imp& i, term_graph& tg) : m_imp(i), tg(tg) {} + + void operator()(app* n) { + auto* v = m_imp.is_index(n); + if (v && tg.get_model_based_rep(n)) { + v->push_back(n); + } + } + + void operator()(expr* e) {} + + }; + + void collect_store_expressions(term_graph& tg, expr_ref_vector const& terms) { + for_each_store_proc proc(*this, tg); + for_each_expr(proc, terms); + } + + void collect_index_expressions(term_graph& tg, expr_ref_vector const& terms) { + for_each_index_proc proc(*this, tg); + for_each_expr(proc, terms); + } + + bool are_equal(model& mdl, expr* s, expr* t) { + return mdl.are_equal(s, t); + } + + void assert_extensionality(model & mdl, term_graph& tg, expr_ref_vector& lits) { + for (auto& kv : m_arrays) { + app_ref_vector const& vs = *kv.m_value; + if (vs.size() <= 1) continue; + func_decl_ref_vector ext(m); + sort* s = kv.m_key; + unsigned arity = get_array_arity(s); + for (unsigned i = 0; i < arity; ++i) { + ext.push_back(a.mk_array_ext(s, i)); + } + expr_ref_vector args(m); + args.resize(arity + 1); + for (unsigned i = 0; i < vs.size(); ++i) { + expr* s = vs[i]; + for (unsigned j = i + 1; j < vs.size(); ++j) { + expr* t = vs[j]; + if (are_equal(mdl, s, t)) { + lits.push_back(m.mk_eq(s, t)); + } + else { + for (unsigned k = 0; k < arity; ++k) { + args[k+1] = m.mk_app(ext.get(k), s, t); + } + args[0] = t; + expr* t1 = a.mk_select(args); + args[0] = s; + expr* s1 = a.mk_select(args); + lits.push_back(m.mk_not(m.mk_eq(t1, s1))); + } + } + } + } + } + + void assert_store_select(model & mdl, term_graph& tg, expr_ref_vector& lits) { + for (auto& store : m_stores) { + assert_store_select(store, mdl, tg, lits); + } + } + + void assert_store_select(app* store, model & mdl, term_graph& tg, expr_ref_vector& lits) { + SASSERT(a.is_store(store)); + ptr_vector indices; + for (unsigned i = 1; i + 1 < store->get_num_args(); ++i) { + SASSERT(indices.empty()); + assert_store_select(indices, store, mdl, tg, lits); + } + } + + void assert_store_select(ptr_vector& indices, app* store, model & mdl, term_graph& tg, expr_ref_vector& lits) { + unsigned sz = store->get_num_args(); + if (indices.size() + 2 == sz) { + ptr_vector args; + args.push_back(store); + for (expr* idx : indices) args.push_back(idx); + for (unsigned i = 1; i + 1 < sz; ++i) { + expr* idx1 = store->get_arg(i); + expr* idx2 = indices[i - 1]; + if (!are_equal(mdl, idx1, idx2)) { + lits.push_back(m.mk_not(m.mk_eq(idx1, idx2))); + lits.push_back(m.mk_eq(store->get_arg(sz-1), a.mk_select(args))); + return; + } + } + for (unsigned i = 1; i + 1 < sz; ++i) { + expr* idx1 = store->get_arg(i); + expr* idx2 = indices[i - 1]; + lits.push_back(m.mk_eq(idx1, idx2)); + } + expr* a1 = a.mk_select(args); + args[0] = store->get_arg(0); + expr* a2 = a.mk_select(args); + lits.push_back(m.mk_eq(a1, a2)); + } + else { + sort* s = m.get_sort(store->get_arg(indices.size() + 1)); + for (app* idx : *m_indices[s]) { + indices.push_back(idx); + assert_store_select(indices, store, mdl, tg, lits); + indices.pop_back(); + } + } + } + }; @@ -1417,4 +1630,9 @@ namespace qe { return vector(); } + void array_project_plugin::saturate(model& model, func_decl_ref_vector const& shared, expr_ref_vector& lits) { + m_imp->saturate(model, shared, lits); + } + + }; diff --git a/src/qe/qe_arrays.h b/src/qe/qe_arrays.h index 3bb90335d..84e318426 100644 --- a/src/qe/qe_arrays.h +++ b/src/qe/qe_arrays.h @@ -37,6 +37,8 @@ namespace qe { void operator()(model& model, app_ref_vector& vars, expr_ref& fml, app_ref_vector& aux_vars, bool reduce_all_selects); family_id get_family_id() override; vector project(model& model, app_ref_vector& vars, expr_ref_vector& lits) override; + void saturate(model& model, func_decl_ref_vector const& shared, expr_ref_vector& lits) override; + }; }; diff --git a/src/qe/qe_datatypes.cpp b/src/qe/qe_datatypes.cpp index 4109d7fd9..85bb14640 100644 --- a/src/qe/qe_datatypes.cpp +++ b/src/qe/qe_datatypes.cpp @@ -303,6 +303,11 @@ namespace qe { vector datatype_project_plugin::project(model& model, app_ref_vector& vars, expr_ref_vector& lits) { return vector(); } + + void datatype_project_plugin::saturate(model& model, func_decl_ref_vector const& shared, expr_ref_vector& lits) { + NOT_IMPLEMENTED_YET(); + } + family_id datatype_project_plugin::get_family_id() { return m_imp->dt.get_family_id(); diff --git a/src/qe/qe_datatypes.h b/src/qe/qe_datatypes.h index 0483f4cce..50a3930e9 100644 --- a/src/qe/qe_datatypes.h +++ b/src/qe/qe_datatypes.h @@ -36,6 +36,8 @@ namespace qe { bool solve(model& model, app_ref_vector& vars, expr_ref_vector& lits) override; family_id get_family_id() override; vector project(model& model, app_ref_vector& vars, expr_ref_vector& lits) override; + void saturate(model& model, func_decl_ref_vector const& shared, expr_ref_vector& lits) override; + }; }; diff --git a/src/qe/qe_mbi.cpp b/src/qe/qe_mbi.cpp index 750dedbce..71a783c4f 100644 --- a/src/qe/qe_mbi.cpp +++ b/src/qe/qe_mbi.cpp @@ -38,7 +38,7 @@ Notes: #include "qe/qe_mbi.h" #include "qe/qe_term_graph.h" #include "qe/qe_arith.h" -// include "opt/opt_context.h" +#include "qe/qe_arrays.h" namespace qe { @@ -263,6 +263,9 @@ namespace qe { TRACE("qe", tout << lits << "\n" << *mdl << "\n";); TRACE("qe", tout << m_solver->get_assertions() << "\n";); + // 0. saturation + array_project_plugin arp(m); + arp.saturate(*mdl, m_shared, lits); // . arithmetical variables - atomic and in purified positions app_ref_vector proxies(m); diff --git a/src/qe/qe_mbp.h b/src/qe/qe_mbp.h index 0bb8ba00f..1b7dedbd4 100644 --- a/src/qe/qe_mbp.h +++ b/src/qe/qe_mbp.h @@ -57,6 +57,13 @@ namespace qe { */ virtual vector project(model& model, app_ref_vector& vars, expr_ref_vector& lits) = 0; + /** + \brief model based saturation. Saturates theory axioms to equi-satisfiable literals over EUF, + such that 'shared' are not retained for EUF. + */ + virtual void saturate(model& model, func_decl_ref_vector const& shared, expr_ref_vector& lits) = 0; + + static expr_ref pick_equality(ast_manager& m, model& model, expr* t); static void erase(expr_ref_vector& lits, unsigned& i); static void push_back(expr_ref_vector& lits, expr* lit); diff --git a/src/qe/qe_term_graph.cpp b/src/qe/qe_term_graph.cpp index 274c25293..166b18bc9 100644 --- a/src/qe/qe_term_graph.cpp +++ b/src/qe/qe_term_graph.cpp @@ -217,6 +217,7 @@ namespace qe { bool term_graph::is_variable_proc::operator()(const expr * e) const { if (!is_app(e)) return false; const app *a = ::to_app(e); + TRACE("qe", tout << a->get_family_id() << " " << m_solved.contains(a->get_decl()) << " " << m_decls.contains(a->get_decl()) << "\n";); return a->get_family_id() == null_family_id && !m_solved.contains(a->get_decl()) && @@ -242,12 +243,13 @@ namespace qe { bool term_graph::term_eq::operator()(term const* a, term const* b) const { return term::cg_eq(a, b); } - term_graph::term_graph(ast_manager &man) : m(man), m_lits(m), m_pinned(m) { + term_graph::term_graph(ast_manager &man) : m(man), m_lits(m), m_pinned(m), m_projector(nullptr) { m_plugins.register_plugin(mk_basic_solve_plugin(m, m_is_var)); m_plugins.register_plugin(mk_arith_solve_plugin(m, m_is_var)); } term_graph::~term_graph() { + dealloc(m_projector); reset(); } @@ -582,12 +584,14 @@ namespace qe { u_map m_term2app; u_map m_root2rep; + model_ref m_model; expr_ref_vector m_pinned; // tracks expr in the maps expr* mk_pure(term const& t) { + TRACE("qe", t.display(tout);); expr* e = nullptr; - if (m_term2app.find(t.get_id(), e)) return e; + if (find_term2app(t, e)) return e; e = t.get_expr(); if (!is_app(e)) return nullptr; app* a = ::to_app(e); @@ -595,17 +599,20 @@ namespace qe { for (term* ch : term::children(t)) { // prefer a node that resembles current child, // otherwise, pick a root representative, if present. - if (m_term2app.find(ch->get_id(), e)) - kids.push_back(e); - else if (m_root2rep.find(ch->get_root().get_id(), e)) + if (find_term2app(*ch, e)) { kids.push_back(e); - else + } + else if (m_root2rep.find(ch->get_root().get_id(), e)) { + kids.push_back(e); + } + else { return nullptr; + } TRACE("qe_verbose", tout << *ch << " -> " << mk_pp(e, m) << "\n";); } expr* pure = m.mk_app(a->get_decl(), kids.size(), kids.c_ptr()); m_pinned.push_back(pure); - m_term2app.insert(t.get_id(), pure); + add_term2app(t, pure); return pure; } @@ -621,69 +628,15 @@ namespace qe { } }; - void purify() { - // - propagate representatives up over parents. - // use work-list + marking to propagate. - // - produce equalities over represented classes. - // - produce other literals over represented classes - // (walk disequalities in m_lits and represent - // lhs/rhs over decls or excluding decls) - - ptr_vector worklist; - for (term * t : m_tg.m_terms) { - worklist.push_back(t); - t->set_mark(true); - } - // traverse worklist in order of depth. - term_depth td; - std::sort(worklist.begin(), worklist.end(), td); - - for (unsigned i = 0; i < worklist.size(); ++i) { - term* t = worklist[i]; - t->set_mark(false); - if (m_term2app.contains(t->get_id())) - continue; - if (!t->is_theory() && is_projected(*t)) - continue; - - expr* pure = mk_pure(*t); - if (!pure) continue; - - m_term2app.insert(t->get_id(), pure); - TRACE("qe_verbose", tout << "purified " << *t << " " << mk_pp(pure, m) << "\n";); - expr* rep = nullptr; // ensure that the root has a representative - m_root2rep.find(t->get_root().get_id(), rep); - - // update rep with pure if it is better - if (pure != rep && is_better_rep(pure, rep)) { - m_root2rep.insert(t->get_root().get_id(), pure); - for (term * p : term::parents(t->get_root())) { - m_term2app.remove(p->get_id()); - if (!p->is_marked()) { - p->set_mark(true); - worklist.push_back(p); - } - } - } - } - - // Here we could also walk equivalence classes that - // contain interpreted values by sort and extract - // disequalities between non-unique value - // representatives. these disequalities are implied - // and can be mined using other means, such as theory - // aware core minimization - m_tg.reset_marks(); - TRACE("qe", display(tout << "after purify\n");); - } void solve_core() { ptr_vector worklist; for (term * t : m_tg.m_terms) { // skip pure terms - if (m_term2app.contains(t->get_id())) continue; - worklist.push_back(t); - t->set_mark(true); + if (!in_term2app(*t)) { + worklist.push_back(t); + t->set_mark(true); + } } term_depth td; std::sort(worklist.begin(), worklist.end(), td); @@ -691,13 +644,14 @@ namespace qe { for (unsigned i = 0; i < worklist.size(); ++i) { term* t = worklist[i]; t->set_mark(false); - if (m_term2app.contains(t->get_id())) + if (in_term2app(*t)) continue; expr* pure = mk_pure(*t); - if (!pure) continue; + if (!pure) + continue; - m_term2app.insert(t->get_id(), pure); + add_term2app(*t, pure); expr* rep = nullptr; // ensure that the root has a representative m_root2rep.find(t->get_root().get_id(), rep); @@ -705,7 +659,7 @@ namespace qe { if (!rep) { m_root2rep.insert(t->get_root().get_id(), pure); for (term * p : term::parents(t->get_root())) { - SASSERT(!m_term2app.contains(p->get_id())); + SASSERT(!in_term2app(*p)); if (!p->is_marked()) { p->set_mark(true); worklist.push_back(p); @@ -718,14 +672,14 @@ namespace qe { bool find_app(term &t, expr *&res) { return - m_term2app.find(t.get_id(), res) || + find_term2app(t, res) || m_root2rep.find(t.get_root().get_id(), res); } bool find_app(expr *lit, expr *&res) { term const* t = m_tg.get_term(lit); return - m_term2app.find(t->get_id(), res) || + find_term2app(*t, res) || m_root2rep.find(t->get_root().get_id(), res); } @@ -856,7 +810,7 @@ namespace qe { term const * r = &t; do { expr* member = nullptr; - if (m_term2app.find(r->get_id(), member) && !members.contains(member)) { + if (find_term2app(*r, member) && !members.contains(member)) { res.push_back (m.mk_eq (rep, member)); members.insert(member); } @@ -865,7 +819,9 @@ namespace qe { while (r != &t); } - bool is_projected(const term &t) {return m_tg.m_is_var(t);} + bool is_projected(const term &t) { + return m_tg.m_is_var(t); + } void mk_unpure_equalities(const term &t, expr_ref_vector &res) { expr *rep = nullptr; @@ -981,6 +937,28 @@ namespace qe { public: projector(term_graph &tg) : m_tg(tg), m(m_tg.m), m_pinned(m) {} + void add_term2app(term const& t, expr* a) { + m_term2app.insert(t.get_id(), a); + } + + void del_term2app(term const& t) { + m_term2app.remove(t.get_id()); + } + + bool find_term2app(term const& t, expr*& r) { + return m_term2app.find(t.get_id(), r); + } + + expr* find_term2app(term const& t) { + expr* r = nullptr; + find_term2app(t, r); + return r; + } + + bool in_term2app(term const& t) { + return m_term2app.contains(t.get_id()); + } + void set_model(model &mdl) { m_model = &mdl; } void reset() { @@ -1025,7 +1003,7 @@ namespace qe { return res; } - vector get_partition(model& mdl) { + vector get_partition(model& mdl, bool include_bool) { vector result; expr_ref_vector pinned(m); obj_map pid; @@ -1033,7 +1011,7 @@ namespace qe { for (term *t : m_tg.m_terms) { expr* a = t->get_expr(); if (!is_app(a)) continue; - if (m.is_bool(a)) continue; + if (m.is_bool(a) && !include_bool) continue; expr_ref val = mdl(a); unsigned p = 0; // NB. works for simple domains Integers, Rationals, @@ -1065,6 +1043,63 @@ namespace qe { } return result; } + + void purify() { + // - propagate representatives up over parents. + // use work-list + marking to propagate. + // - produce equalities over represented classes. + // - produce other literals over represented classes + // (walk disequalities in m_lits and represent + // lhs/rhs over decls or excluding decls) + + ptr_vector worklist; + for (term * t : m_tg.m_terms) { + worklist.push_back(t); + t->set_mark(true); + } + // traverse worklist in order of depth. + term_depth td; + std::sort(worklist.begin(), worklist.end(), td); + + for (unsigned i = 0; i < worklist.size(); ++i) { + term* t = worklist[i]; + t->set_mark(false); + if (in_term2app(*t)) + continue; + if (!t->is_theory() && is_projected(*t)) + continue; + + expr* pure = mk_pure(*t); + if (!pure) continue; + + add_term2app(*t, pure); + TRACE("qe_verbose", tout << "purified " << *t << " " << mk_pp(pure, m) << "\n";); + expr* rep = nullptr; // ensure that the root has a representative + m_root2rep.find(t->get_root().get_id(), rep); + + // update rep with pure if it is better + if (pure != rep && is_better_rep(pure, rep)) { + m_root2rep.insert(t->get_root().get_id(), pure); + for (term * p : term::parents(t->get_root())) { + del_term2app(*p); + if (!p->is_marked()) { + p->set_mark(true); + worklist.push_back(p); + } + } + } + } + + // Here we could also walk equivalence classes that + // contain interpreted values by sort and extract + // disequalities between non-unique value + // representatives. these disequalities are implied + // and can be mined using other means, such as theory + // aware core minimization + m_tg.reset_marks(); + TRACE("qe", display(tout << "after purify\n");); + } + }; void term_graph::set_vars(func_decl_ref_vector const& decls, bool exclude) { @@ -1094,13 +1129,15 @@ namespace qe { expr_ref_vector term_graph::get_ackerman_disequalities() { m_is_var.reset_solved(); - term_graph::projector p(*this); - return p.get_ackerman_disequalities(); + dealloc(m_projector); + m_projector = alloc(term_graph::projector, *this); + return m_projector->get_ackerman_disequalities(); } vector term_graph::get_partition(model& mdl) { - term_graph::projector p(*this); - return p.get_partition(mdl); + dealloc(m_projector); + m_projector = alloc(term_graph::projector, *this); + return m_projector->get_partition(mdl, false); } expr_ref_vector term_graph::shared_occurrences(family_id fid) { @@ -1108,4 +1145,42 @@ namespace qe { return p.shared_occurrences(fid); } + void term_graph::add_model_based_terms(model& mdl, expr_ref_vector const& terms) { + for (expr* t : terms) { + internalize_term(t); + } + m_is_var.reset_solved(); + + SASSERT(!m_projector); + m_projector = alloc(term_graph::projector, *this); + + // retrieve partition of terms + vector equivs = m_projector->get_partition(mdl, true); + + // merge term graph on equal terms. + for (auto const& cs : equivs) { + term* t0 = get_term(cs[0]); + for (unsigned i = 1; i < cs.size(); ++i) { + merge(*t0, *get_term(cs[i])); + } + } + TRACE("qe", + for (auto & es : equivs) { + tout << "equiv: "; + for (expr* t : es) tout << expr_ref(t, m) << " "; + tout << "\n"; + } + display(tout);); + // create representatives for shared/projected variables. + m_projector->set_model(mdl); + m_projector->purify(); + + } + + expr* term_graph::get_model_based_rep(expr* e) { + SASSERT(m_projector); + term* t = get_term(e); + SASSERT(t && "only get representatives"); + return m_projector->find_term2app(*t); + } } diff --git a/src/qe/qe_term_graph.h b/src/qe/qe_term_graph.h index 855a0f2bc..ef12ab683 100644 --- a/src/qe/qe_term_graph.h +++ b/src/qe/qe_term_graph.h @@ -52,6 +52,7 @@ namespace qe { expr_ref_vector m_lits; // NSB: expr_ref_vector? u_map m_app2term; ast_ref_vector m_pinned; + projector* m_projector; u_map m_term2app; plugin_manager m_plugins; ptr_hashtable m_cg_table; @@ -135,6 +136,12 @@ namespace qe { */ expr_ref_vector shared_occurrences(family_id fid); + /** + * Map expression that occurs in added literals into representative if it exists. + */ + void add_model_based_terms(model& mdl, expr_ref_vector const& terms); + expr* get_model_based_rep(expr* e); + }; }