From 5fce4a1d1af81cd912a12d0f7dbb57acb7783a9d Mon Sep 17 00:00:00 2001 From: Arie Gurfinkel Date: Tue, 12 Jun 2018 11:59:18 -0700 Subject: [PATCH] Wire qe_solve_plugin into qe_term_graph Compiles. Not tested. --- src/qe/qe_solve_plugin.cpp | 45 +++---- src/qe/qe_term_graph.cpp | 241 +++++++++++++------------------------ src/qe/qe_term_graph.h | 41 ++++--- src/qe/qe_vartest.h | 7 +- 4 files changed, 136 insertions(+), 198 deletions(-) diff --git a/src/qe/qe_solve_plugin.cpp b/src/qe/qe_solve_plugin.cpp index ad4a03b5d..0a499d3b8 100644 --- a/src/qe/qe_solve_plugin.cpp +++ b/src/qe/qe_solve_plugin.cpp @@ -27,7 +27,7 @@ Revision History: namespace qe { expr_ref solve_plugin::operator()(expr* lit) { - if (m.is_not(lit, lit)) + if (m.is_not(lit, lit)) return solve(lit, false); else return solve(lit, true); @@ -39,9 +39,9 @@ namespace qe { arith_solve_plugin(ast_manager& m, is_variable_proc& is_var): solve_plugin(m, m.get_family_id("arith"), is_var), a(m) {} typedef std::pair signed_expr; - + /** - *\brief + *\brief * return r * (sum_{(sign,e) \in exprs} sign * e) */ expr_ref mk_term(bool is_int, rational const& r, bool sign, svector const& exprs) { @@ -124,7 +124,7 @@ namespace qe { return false; } - // is arg of the form a_val * v, where a_val + // is arg of the form a_val * v, where a_val // is a constant that we can safely divide by. bool is_invertible_mul(bool is_int, expr*& arg, rational& a_val) { if (is_variable(arg)) { @@ -144,9 +144,9 @@ namespace qe { } return false; } - - expr_ref mk_eq_core (expr *e1, expr *e2) { + + expr_ref mk_eq_core (expr *e1, expr *e2) { expr_ref v(m), t(m); if (solve(e1, e2, v, t)) { return expr_ref(m.mk_eq(v, t), m); @@ -172,7 +172,6 @@ namespace qe { app* mk_le_zero(expr *arg) { expr *e1, *e2, *e3; - // XXX currently disabled if (a.is_add(arg, e1, e2)) { // e1-e2<=0 --> e1<=e2 if (a.is_times_minus_one(e2, e3)) { @@ -188,7 +187,6 @@ namespace qe { app* mk_ge_zero(expr *arg) { expr *e1, *e2, *e3; - // XXX currently disabled if (a.is_add(arg, e1, e2)) { // e1-e2>=0 --> e1>=e2 if (a.is_times_minus_one(e2, e3)) { @@ -249,17 +247,22 @@ namespace qe { return false; } - expr_ref solve(expr* lit, bool is_pos) override { + expr_ref solve(expr* atom, bool is_pos) override { expr *e1, *e2; - expr_ref res(lit, m); - if (m.is_eq (lit, e1, e2)) { - res = mk_eq_core(e1, e2); + expr_ref res(atom, m); + if (m.is_eq (atom, e1, e2)) { + expr_ref v(m), t(m); + v = e1; t = e2; + // -- attempt to solve using arithmetic + solve(e1, e2, v, t); + // -- normalize equality + res = mk_eq_core(v, t); } - else if (a.is_le(lit, e1, e2)) { + else if (a.is_le(atom, e1, e2)) { mk_le_core(e1, e2, res); } - else if (a.is_ge(lit, e1, e2)) { + else if (a.is_ge(atom, e1, e2)) { mk_ge_core(e1, e2, res); } @@ -273,7 +276,7 @@ namespace qe { class basic_solve_plugin : public solve_plugin { public: - basic_solve_plugin(ast_manager& m, is_variable_proc& is_var): + basic_solve_plugin(ast_manager& m, is_variable_proc& is_var): solve_plugin(m, m.get_basic_family_id(), is_var) {} expr_ref solve(expr *atom, bool is_pos) override { @@ -288,7 +291,7 @@ namespace qe { } else if (is_variable(rhs) && !is_variable(lhs)) { res = m.mk_eq(rhs, lhs); - } + } } // (ite cond (= VAR t) (= VAR t2)) case expr* cond = nullptr, *th = nullptr, *el = nullptr; @@ -296,7 +299,7 @@ namespace qe { expr_ref r1 = solve(th, true); expr_ref r2 = solve(el, true); expr* v1 = nullptr, *t1 = nullptr, *v2 = nullptr, *t2 = nullptr; - if (m.is_eq(r1, v1, t1) && m.is_eq(r2, v2, t2) && v1 == v2) { + if (m.is_eq(r1, v1, t1) && m.is_eq(r2, v2, t2) && v1 == v2) { res = m.mk_eq(v1, m.mk_ite(cond, t1, t2)); } } @@ -313,8 +316,8 @@ namespace qe { class dt_solve_plugin : public solve_plugin { datatype_util dt; public: - dt_solve_plugin(ast_manager& m, is_variable_proc& is_var): - solve_plugin(m, m.get_family_id("datatype"), is_var), + dt_solve_plugin(ast_manager& m, is_variable_proc& is_var): + solve_plugin(m, m.get_family_id("datatype"), is_var), dt(m) {} expr_ref solve(expr *atom, bool is_pos) override { @@ -350,11 +353,11 @@ namespace qe { } } // TBD: can also solve for is_nil(x) by x = nil - // + // return is_pos ? res : mk_not(res); } }; - + class bv_solve_plugin : public solve_plugin { public: bv_solve_plugin(ast_manager& m, is_variable_proc& is_var): solve_plugin(m, m.get_family_id("bv"), is_var) {} diff --git a/src/qe/qe_term_graph.cpp b/src/qe/qe_term_graph.cpp index 396b5f092..9bf007428 100644 --- a/src/qe/qe_term_graph.cpp +++ b/src/qe/qe_term_graph.cpp @@ -27,6 +27,28 @@ Notes: namespace qe { + namespace is_pure_ns { + struct found{}; + struct proc { + is_variable_proc &m_is_var; + proc(is_variable_proc &is_var) : m_is_var(is_var) {} + void operator()(var *n) const {if (m_is_var(n)) throw found();} + void operator()(app const *n) const {if (m_is_var(n)) throw found();} + void operator()(quantifier *n) const {} + }; + } + + bool is_pure(is_variable_proc &is_var, expr *e) { + try { + is_pure_ns::proc v(is_var); + quick_for_each_expr(v, e); + } + catch (is_pure_ns::found) { + return false; + } + return true; + } + class term { // -- an app represented by this term expr* m_expr; // NSB: to make usable with exprs @@ -160,154 +182,46 @@ namespace qe { }; - class arith_term_graph_plugin : public term_graph_plugin { - term_graph &m_g; - ast_manager &m; - arith_util m_arith; - public: - arith_term_graph_plugin(term_graph &g) : - term_graph_plugin (g.get_ast_manager().mk_family_id("arith")), - m_g(g), m(g.get_ast_manager()), m_arith(m) {(void)m_g;} + bool term_graph::is_variable_proc::operator()(const expr * e) const { + if (!is_app(e)) return false; + const app *a = ::to_app(e); + if (a->get_family_id() != null_family_id) return false; + if (m_solved.contains(a->get_decl()->get_id())) return false; + return m_exclude == m_decls.contains(a->get_decl()->get_id()); + } + bool term_graph::is_variable_proc::operator()(const term &t) const { + return !t.is_theory() && m_exclude == m_decls.contains(t.get_decl_id()); + } - virtual ~arith_term_graph_plugin() {} + void term_graph::is_variable_proc::set_decls(const func_decl_ref_vector &decls, bool exclude) { + reset(); + m_exclude = exclude; + for (auto *d : decls) m_decls.insert(d->get_id(), true); + } + void term_graph::is_variable_proc::mark_solved(const expr *e) { + if ((*this)(e)) + m_solved.insert(::to_app(e)->get_decl()->get_id(), true); + } - bool mk_eq_core (expr *_e1, expr *_e2, expr_ref &res) { - expr *e1, *e2; - e1 = _e1; - e2 = _e2; - if (m_arith.is_zero(e1)) { - std::swap(e1, e2); - } - // y + -1*x == 0 --> y = x - expr *a0 = 0, *a1 = 0, *x = 0; - if (m_arith.is_zero(e2) && m_arith.is_add(e1, a0, a1)) { - if (m_arith.is_times_minus_one(a1, x)) { - e1 = a0; - e2 = x; - } - else if (m_arith.is_times_minus_one(a0, x)) { - e1 = a1; - e2 = x; - } - } - res = m.mk_eq(e1, e2); - return true; - } - - app* mk_le_zero(expr *arg) { - expr *e1, *e2, *e3; - if (m_arith.is_add(arg, e1, e2)) { - // e1-e2<=0 --> e1<=e2 - if (m_arith.is_times_minus_one(e2, e3)) { - return m_arith.mk_le(e1, e3); - } - // -e1+e2<=0 --> e2<=e1 - else if (m_arith.is_times_minus_one(e1, e3)) { - return m_arith.mk_le(e2, e3); - } - } - return m_arith.mk_le(arg, mk_zero()); - } - - app* mk_ge_zero(expr *arg) { - expr *e1, *e2, *e3; - if (m_arith.is_add(arg, e1, e2)) { - // e1-e2>=0 --> e1>=e2 - if (m_arith.is_times_minus_one(e2, e3)) { - return m_arith.mk_ge(e1, e3); - } - // -e1+e2>=0 --> e2>=e1 - else if (m_arith.is_times_minus_one(e1, e3)) { - return m_arith.mk_ge(e2, e3); - } - } - return m_arith.mk_ge(arg, mk_zero()); - } - - bool mk_le_core (expr *arg1, expr * arg2, expr_ref &result) { - // t <= -1 ==> t < 0 ==> ! (t >= 0) - rational n; - if (m_arith.is_int (arg1) && m_arith.is_minus_one (arg2)) { - result = m.mk_not (mk_ge_zero (arg1)); - return true; - } - else if (m_arith.is_zero(arg2)) { - result = mk_le_zero(arg1); - return true; - } - else if (m_arith.is_int(arg1) && m_arith.is_numeral(arg2, n) && n < 0) { - // t <= n ==> t < n + 1 ==> ! (t >= n + 1) - result = m.mk_not(m_arith.mk_ge(arg1, m_arith.mk_numeral(n+1, true))); - return true; - } - return false; - } - expr * mk_zero () {return m_arith.mk_numeral (rational (0), true);} - bool is_one (expr const * n) const { - rational val; - return m_arith.is_numeral (n, val) && val.is_one (); - } - - bool mk_ge_core (expr * arg1, expr * arg2, expr_ref &result) { - // t >= 1 ==> t > 0 ==> ! (t <= 0) - rational n; - if (m_arith.is_int (arg1) && is_one (arg2)) { - result = m.mk_not (mk_le_zero (arg1)); - return true; - } - else if (m_arith.is_zero(arg2)) { - result = mk_ge_zero(arg1); - return true; - } - else if (m_arith.is_int(arg1) && m_arith.is_numeral(arg2, n) && n > 0) { - // t >= n ==> t > n - 1 ==> ! (t <= n - 1) - result = m.mk_not(m_arith.mk_le(arg1, m_arith.mk_numeral(n-1, true))); - return true; - } - return false; - } - - expr_ref process_lit (expr *_lit) override { - expr *lit = _lit; - expr *e1, *e2; - - // strip negation - bool is_neg = m.is_not(lit); - if (is_neg) { - lit = to_app(to_app(lit)->get_arg(0)); - } - - expr_ref res(m); - res = lit; - if (m.is_eq (lit, e1, e2)) { - mk_eq_core(e1, e2, res); - } - else if (m_arith.is_le(lit, e1, e2)) { - mk_le_core(e1, e2, res); - } - else if (m_arith.is_ge(lit, e1, e2)) { - mk_ge_core(e1, e2, res); - } - // restore negation - if (is_neg) { - res = mk_not(m, res); - } - return res; - } - }; unsigned term_graph::term_hash::operator()(term const* t) const { return t->get_hash(); } bool term_graph::term_eq::operator()(term const* a, term const* b) const { return term::cg_eq(a, b); } term_graph::term_graph(ast_manager &man) : m(man), m_lits(m), m_pinned(m) { - m_plugins.register_plugin (alloc(arith_term_graph_plugin, *this)); + m_plugins.register_plugin(mk_basic_solve_plugin(m, m_is_var)); + m_plugins.register_plugin(mk_arith_solve_plugin(m, m_is_var)); } term_graph::~term_graph() { reset(); } + bool term_graph::is_pure_def(expr *atom, expr *v) { + expr *e = nullptr; + return m.is_eq(atom, v, e) && m_is_var(v) && is_pure(m_is_var, e); + } + static family_id get_family_id(ast_manager &m, expr *lit) { if (m.is_not(lit, lit)) return get_family_id(m, lit); @@ -328,13 +242,9 @@ namespace qe { void term_graph::add_lit(expr *l) { expr_ref lit(m); - family_id fid = get_family_id (m, l); - term_graph_plugin *pin = m_plugins.get_plugin(fid); - if (pin) { - lit = pin->process_lit(l); - } else { - lit = l; - } + family_id fid = get_family_id(m, l); + qe::solve_plugin *pin = m_plugins.get_plugin(fid); + lit = pin ? (*pin)(l) : l; m_lits.push_back(lit); internalize_lit(lit); } @@ -620,8 +530,6 @@ namespace qe { ast_manager &m; u_map m_term2app; u_map m_root2rep; - u_map m_decls; - bool m_exclude; expr_ref_vector m_pinned; // tracks expr in the maps @@ -700,7 +608,7 @@ namespace qe { m_tg.reset_marks(); } - void solve() { + void solve_core() { ptr_vector worklist; for (term * t : m_tg.m_terms) { // skip pure terms @@ -772,9 +680,7 @@ namespace qe { while (r != &t); } - bool is_projected(const term &t) { - return m_exclude == m_decls.contains(t.get_decl_id()); - } + bool is_projected(const term &t) {return m_tg.m_is_var(t);} void mk_unpure_equalities(const term &t, expr_ref_vector &res) { expr *rep = nullptr; @@ -834,24 +740,20 @@ namespace qe { m_tg.reset_marks(); m_term2app.reset(); m_root2rep.reset(); - m_decls.reset(); m_pinned.reset(); } - expr_ref_vector project(func_decl_ref_vector const &decls, bool exclude) { + expr_ref_vector project() { expr_ref_vector res(m); - m_exclude = exclude; - for (auto *d : decls) {m_decls.insert(d->get_id(), true);} purify(); mk_lits(res); mk_pure_equalities(res); reset(); return res; } - expr_ref_vector solve(func_decl_ref_vector const &decls, bool exclude) { + expr_ref_vector solve() { expr_ref_vector res(m); - m_exclude = exclude; purify(); - solve(); + solve_core(); mk_lits(res); mk_unpure_equalities(res); reset(); @@ -860,14 +762,41 @@ namespace qe { }; } + void term_graph::solve_for_vars() { + expr_ref new_lit(m); + expr *old_lit = nullptr, *v = nullptr; + for (unsigned i = 0, sz = m_lits.size(); i < sz; ++i) { + old_lit = m_lits.get(i); + qe::solve_plugin *pin = m_plugins.get_plugin(get_family_id(m, old_lit)); + if (pin) { + new_lit = (*pin)(old_lit); + if (new_lit.get() != old_lit) { + m_lits.set(i, new_lit); + internalize_lit(new_lit); + } + if (is_pure_def(new_lit, v)) { + m_is_var.mark_solved(v); + } + } + } + m_is_var.reset_solved(); + } expr_ref_vector term_graph::project(func_decl_ref_vector const& decls, bool exclude) { + m_is_var.set_decls(decls, exclude); + solve_for_vars(); projector p(*this); - return p.project(decls, exclude); + m_is_var.reset(); + expr_ref_vector v = p.project(); + return v; } expr_ref_vector term_graph::solve(func_decl_ref_vector const &decls, bool exclude) { + m_is_var.set_decls(decls, exclude); + solve_for_vars(); projector p(*this); - return p.solve(decls, exclude); + expr_ref_vector v = p.solve(); + m_is_var.reset(); + return v; } } diff --git a/src/qe/qe_term_graph.h b/src/qe/qe_term_graph.h index 210941dac..e60b535c0 100644 --- a/src/qe/qe_term_graph.h +++ b/src/qe/qe_term_graph.h @@ -21,28 +21,31 @@ Notes: #include "ast/ast.h" #include "util/plugin_manager.h" +#include "qe/qe_solve_plugin.h" +#include "qe/qe_vartest.h" namespace qe { class term; - namespace {class projector;} - class term_graph_plugin { - family_id m_id; - public: - term_graph_plugin(family_id fid) : m_id(fid) {} - virtual ~term_graph_plugin() {} - - family_id get_family_id() const {return m_id;} - - /// Process (and potentially augment) a literal - virtual expr_ref process_lit (expr *lit) = 0; - }; - - class term_graph { friend class projector; + + class is_variable_proc : public ::is_variable_proc { + bool m_exclude; + u_map m_decls; + u_map m_solved; + public: + bool operator()(const expr *e) const override; + bool operator()(const term &t) const; + + void set_decls(const func_decl_ref_vector &decls, bool exclude); + void mark_solved(const expr *e); + void reset_solved() {m_solved.reset();} + void reset() {m_decls.reset(); m_solved.reset(); m_exclude = true;} + }; + struct term_hash { unsigned operator()(term const* t) const; }; struct term_eq { bool operator()(term const* a, term const* b) const; }; ast_manager & m; @@ -51,10 +54,11 @@ namespace qe { u_map m_app2term; ast_ref_vector m_pinned; u_map m_term2app; - plugin_manager m_plugins; + plugin_manager m_plugins; ptr_hashtable m_cg_table; vector> m_merge; + term_graph::is_variable_proc m_is_var; void merge(term &t1, term &t2); void merge_flush(); @@ -80,9 +84,10 @@ namespace qe { void mk_equalities(term const &t, expr_ref_vector &out); void mk_all_equalities(term const &t, expr_ref_vector &out); void display(std::ostream &out); - void project_core(func_decl_ref_vector const &decls, bool exclude, expr_ref_vector &result); - void solve_core(func_decl_ref_vector const &decls, bool exclude, expr_ref_vector &result); - bool is_solved_eq(expr *lhs, expr *rhs); + + bool is_pure_def(expr* atom, expr *v); + void solve_for_vars(); + public: term_graph(ast_manager &m); diff --git a/src/qe/qe_vartest.h b/src/qe/qe_vartest.h index 56d9229b8..52609893f 100644 --- a/src/qe/qe_vartest.h +++ b/src/qe/qe_vartest.h @@ -22,9 +22,10 @@ Revision History: #include "ast/ast.h" #include "util/uint_set.h" -class is_variable_proc { +// TBD: move under qe namespace +class is_variable_proc : public std::unary_function { public: - virtual bool operator()(expr* e) const = 0; + virtual bool operator()(const expr* e) const = 0; }; class is_variable_test : public is_variable_proc { @@ -42,7 +43,7 @@ public: m_num_decls(num_decls), m_var_kind(BY_NUM_DECLS) {} - bool operator()(expr* e) const override { + bool operator()(const expr* e) const override { if (!is_var(e)) { return false; }