From 5c71824f2b270bc81a13c00a0be6cf8b7216f850 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 12 Nov 2023 16:48:10 -0800 Subject: [PATCH] adding unit test for arith_plugin --- src/ast/euf/euf_ac_plugin.cpp | 4 +-- src/ast/euf/euf_ac_plugin.h | 21 +++++++------ src/ast/euf/euf_arith_plugin.cpp | 5 +++ src/ast/euf/euf_arith_plugin.h | 4 +++ src/ast/euf/euf_bv_plugin.h | 2 ++ src/ast/euf/euf_egraph.cpp | 11 +++++-- src/ast/euf/euf_plugin.h | 2 ++ src/test/CMakeLists.txt | 1 + src/test/euf_arith_plugin.cpp | 53 ++++++++++++++++++++++++++++++++ src/test/euf_bv_plugin.cpp | 2 +- src/test/main.cpp | 1 + 11 files changed, 89 insertions(+), 17 deletions(-) create mode 100644 src/test/euf_arith_plugin.cpp diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp index 70283d6b4..329976213 100644 --- a/src/ast/euf/euf_ac_plugin.cpp +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -41,8 +41,6 @@ More notes: --*/ -#pragma once - #include "ast/euf/euf_ac_plugin.h" #include "ast/euf/euf_egraph.h" @@ -275,7 +273,7 @@ namespace euf { return m_nodes[id]; auto* r = node::mk(get_region(), n); push_undo(is_add_node); - m_nodes.set(id, r); + m_nodes.setx(id, r, nullptr); m_node_trail.push_back(r); return r; } diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h index 5594f5e9d..ee8b602eb 100644 --- a/src/ast/euf/euf_ac_plugin.h +++ b/src/ast/euf/euf_ac_plugin.h @@ -48,7 +48,7 @@ namespace euf { unsigned root_id() const { return root->n->get_id(); } ~node() {} - static node* mk(region& r, enode* n); + static node* mk(region& r, enode* n); }; class equiv { @@ -71,17 +71,17 @@ namespace euf { iterator end() const { return iterator(&n, &n); } }; - struct eq { + struct eq { unsigned l, r; // refer to monomials bool is_processed = false; - justification j; + justification j; }; unsigned m_fid; - unsigned m_op; + unsigned m_op; vector m_eqs; ptr_vector m_nodes; - vector> m_monomials; + vector> m_monomials; enode_vector m_monomial_enodes; justification::dependency_manager m_dep_manager; @@ -103,8 +103,8 @@ namespace euf { svector> m_update_eq_trail; node* mk_node(enode* n); - void merge(node* r1, node* r2, justification j); - + void merge(node* r1, node* r2, justification j); + bool is_op(enode* n) const { auto d = n->get_decl(); return d && m_fid == d->get_family_id() && m_op == d->get_kind(); } std::function m_undo_notify; @@ -114,9 +114,9 @@ namespace euf { unsigned to_monomial(enode* n, ptr_vector const& ms); ptr_vector const& monomial(unsigned i) const { return m_monomials[i]; } ptr_vector& monomial(unsigned i) { return m_monomials[i]; } - + void init_equation(eq const& e); - bool orient_equation(eq & e); + bool orient_equation(eq& e); void set_processed(unsigned eq_id, bool f); unsigned pick_next_eq(); bool is_trivial(unsigned eq_id) const { throw default_exception("NYI"); } @@ -143,7 +143,6 @@ namespace euf { bool is_processed(unsigned eq) const { return m_eqs[eq].is_processed; } justification justify_rewrite(unsigned eq1, unsigned eq2); - justification justify_superpose(justification j1, justification j2); justification::dependency* justify_monomial(justification::dependency* d, ptr_vector const& m); void propagate_shared(); @@ -152,6 +151,8 @@ namespace euf { public: ac_plugin(egraph& g, unsigned fid, unsigned op); + + ~ac_plugin() override {} unsigned get_id() const override { return m_fid; } diff --git a/src/ast/euf/euf_arith_plugin.cpp b/src/ast/euf/euf_arith_plugin.cpp index e43dd1465..3b134c640 100644 --- a/src/ast/euf/euf_arith_plugin.cpp +++ b/src/ast/euf/euf_arith_plugin.cpp @@ -52,6 +52,11 @@ namespace euf { // no-op } + void arith_plugin::propagate() { + m_add.propagate(); + m_mul.propagate(); + } + void arith_plugin::undo() { auto k = m_undo.back(); m_undo.pop_back(); diff --git a/src/ast/euf/euf_arith_plugin.h b/src/ast/euf/euf_arith_plugin.h index beadf7823..893b94a74 100644 --- a/src/ast/euf/euf_arith_plugin.h +++ b/src/ast/euf/euf_arith_plugin.h @@ -33,6 +33,8 @@ namespace euf { public: arith_plugin(egraph& g); + ~arith_plugin() override {} + unsigned get_id() const override { return a.get_family_id(); } void register_node(enode* n) override; @@ -44,6 +46,8 @@ namespace euf { void diseq_eh(enode* n1, enode* n2) override; void undo() override; + + void propagate() override; std::ostream& display(std::ostream& out) const override; diff --git a/src/ast/euf/euf_bv_plugin.h b/src/ast/euf/euf_bv_plugin.h index 4ab9d4618..f7ae53f97 100644 --- a/src/ast/euf/euf_bv_plugin.h +++ b/src/ast/euf/euf_bv_plugin.h @@ -80,6 +80,8 @@ namespace euf { public: bv_plugin(egraph& g); + ~bv_plugin() override {} + unsigned get_id() const override { return bv.get_family_id(); } void register_node(enode* n) override; diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index c9bc8ac1e..dbcfb51d2 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -19,6 +19,7 @@ Notes: #include "ast/euf/euf_egraph.h" #include "ast/euf/euf_bv_plugin.h" +#include "ast/euf/euf_arith_plugin.h" #include "ast/ast_pp.h" #include "ast/ast_translation.h" @@ -159,9 +160,13 @@ namespace euf { } void egraph::add_plugins() { - auto* plugin = alloc(bv_plugin, *this); - m_plugins.reserve(plugin->get_id() + 1); - m_plugins.set(plugin->get_id(), plugin); + auto insert = [&](plugin* p) { + m_plugins.reserve(p->get_id() + 1); + m_plugins.set(p->get_id(), p); + }; + + insert(alloc(bv_plugin, *this)); + insert(alloc(arith_plugin, *this)); } void egraph::propagate_plugins() { diff --git a/src/ast/euf/euf_plugin.h b/src/ast/euf/euf_plugin.h index 3e9d35771..f36ab38f3 100644 --- a/src/ast/euf/euf_plugin.h +++ b/src/ast/euf/euf_plugin.h @@ -38,6 +38,8 @@ namespace euf { g(g) {} + virtual ~plugin() {} + virtual unsigned get_id() const = 0; virtual void register_node(enode* n) = 0; diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index dca916803..59f25e924 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -38,6 +38,7 @@ add_executable(test-z3 dl_util.cpp doc.cpp egraph.cpp + euf_arith_plugin.cpp euf_bv_plugin.cpp escaped.cpp ex.cpp diff --git a/src/test/euf_arith_plugin.cpp b/src/test/euf_arith_plugin.cpp new file mode 100644 index 000000000..821ea1302 --- /dev/null +++ b/src/test/euf_arith_plugin.cpp @@ -0,0 +1,53 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +--*/ + +#include "util/util.h" +#include "util/timer.h" +#include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_arith_plugin.h" +#include "ast/reg_decl_plugins.h" +#include "ast/ast_pp.h" +#include + +static euf::enode* get_node(euf::egraph& g, expr* e) { + auto* n = g.find(e); + if (n) + return n; + euf::enode_vector args; + for (expr* arg : *to_app(e)) + args.push_back(get_node(g, arg)); + return g.mk(e, 0, args.size(), args.data()); +} + +// +static void test1() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugins(); + arith_util a(m); + sort_ref I(a.mk_int(), m); + + expr_ref x(m.mk_const("x", I), m); + expr_ref y(m.mk_const("y", I), m); + auto* nx = get_node(g, a.mk_add(a.mk_add(y, y), a.mk_add(x, x))); + auto* ny = get_node(g, a.mk_add(a.mk_add(y, x), x)); + TRACE("plugin", tout << "before merge\n" << g << "\n"); + g.merge(nx, ny, nullptr); + + TRACE("plugin", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("plugin", tout << "after propagate\n" << g << "\n"); + g.merge(get_node(g, a.mk_add(x, a.mk_add(y, y))), get_node(g, a.mk_add(y, x)), nullptr); + g.propagate(); + std::cout << g << "\n"; +} + + + +void tst_euf_arith_plugin() { + enable_trace("plugin"); + test1(); +} diff --git a/src/test/euf_bv_plugin.cpp b/src/test/euf_bv_plugin.cpp index bea98dfe6..a0946682b 100644 --- a/src/test/euf_bv_plugin.cpp +++ b/src/test/euf_bv_plugin.cpp @@ -11,7 +11,7 @@ Copyright (c) 2023 Microsoft Corporation #include "ast/ast_pp.h" #include -euf::enode* get_node(euf::egraph& g, expr* e) { +static euf::enode* get_node(euf::egraph& g, expr* e) { auto* n = g.find(e); if (n) return n; diff --git a/src/test/main.cpp b/src/test/main.cpp index f7085cfbc..96de5072b 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -271,4 +271,5 @@ int main(int argc, char ** argv) { TST(totalizer); TST(distribution); TST(euf_bv_plugin); + TST(euf_arith_plugin); }