3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

adding unit test for arith_plugin

This commit is contained in:
Nikolaj Bjorner 2023-11-12 16:48:10 -08:00
parent 65a8c162f5
commit 5c71824f2b
11 changed files with 89 additions and 17 deletions

View file

@ -41,8 +41,6 @@ More notes:
--*/
#pragma once
#include "ast/euf/euf_ac_plugin.h"
#include "ast/euf/euf_egraph.h"
@ -275,7 +273,7 @@ namespace euf {
return m_nodes[id];
auto* r = node::mk(get_region(), n);
push_undo(is_add_node);
m_nodes.set(id, r);
m_nodes.setx(id, r, nullptr);
m_node_trail.push_back(r);
return r;
}

View file

@ -48,7 +48,7 @@ namespace euf {
unsigned root_id() const { return root->n->get_id(); }
~node() {}
static node* mk(region& r, enode* n);
static node* mk(region& r, enode* n);
};
class equiv {
@ -71,17 +71,17 @@ namespace euf {
iterator end() const { return iterator(&n, &n); }
};
struct eq {
struct eq {
unsigned l, r; // refer to monomials
bool is_processed = false;
justification j;
justification j;
};
unsigned m_fid;
unsigned m_op;
unsigned m_op;
vector<eq> m_eqs;
ptr_vector<node> m_nodes;
vector<ptr_vector<node>> m_monomials;
vector<ptr_vector<node>> m_monomials;
enode_vector m_monomial_enodes;
justification::dependency_manager m_dep_manager;
@ -103,8 +103,8 @@ namespace euf {
svector<std::pair<unsigned, eq>> m_update_eq_trail;
node* mk_node(enode* n);
void merge(node* r1, node* r2, justification j);
void merge(node* r1, node* r2, justification j);
bool is_op(enode* n) const { auto d = n->get_decl(); return d && m_fid == d->get_family_id() && m_op == d->get_kind(); }
std::function<void(void)> m_undo_notify;
@ -114,9 +114,9 @@ namespace euf {
unsigned to_monomial(enode* n, ptr_vector<node> const& ms);
ptr_vector<node> const& monomial(unsigned i) const { return m_monomials[i]; }
ptr_vector<node>& monomial(unsigned i) { return m_monomials[i]; }
void init_equation(eq const& e);
bool orient_equation(eq & e);
bool orient_equation(eq& e);
void set_processed(unsigned eq_id, bool f);
unsigned pick_next_eq();
bool is_trivial(unsigned eq_id) const { throw default_exception("NYI"); }
@ -143,7 +143,6 @@ namespace euf {
bool is_processed(unsigned eq) const { return m_eqs[eq].is_processed; }
justification justify_rewrite(unsigned eq1, unsigned eq2);
justification justify_superpose(justification j1, justification j2);
justification::dependency* justify_monomial(justification::dependency* d, ptr_vector<node> const& m);
void propagate_shared();
@ -152,6 +151,8 @@ namespace euf {
public:
ac_plugin(egraph& g, unsigned fid, unsigned op);
~ac_plugin() override {}
unsigned get_id() const override { return m_fid; }

View file

@ -52,6 +52,11 @@ namespace euf {
// no-op
}
void arith_plugin::propagate() {
m_add.propagate();
m_mul.propagate();
}
void arith_plugin::undo() {
auto k = m_undo.back();
m_undo.pop_back();

View file

@ -33,6 +33,8 @@ namespace euf {
public:
arith_plugin(egraph& g);
~arith_plugin() override {}
unsigned get_id() const override { return a.get_family_id(); }
void register_node(enode* n) override;
@ -44,6 +46,8 @@ namespace euf {
void diseq_eh(enode* n1, enode* n2) override;
void undo() override;
void propagate() override;
std::ostream& display(std::ostream& out) const override;

View file

@ -80,6 +80,8 @@ namespace euf {
public:
bv_plugin(egraph& g);
~bv_plugin() override {}
unsigned get_id() const override { return bv.get_family_id(); }
void register_node(enode* n) override;

View file

@ -19,6 +19,7 @@ Notes:
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_bv_plugin.h"
#include "ast/euf/euf_arith_plugin.h"
#include "ast/ast_pp.h"
#include "ast/ast_translation.h"
@ -159,9 +160,13 @@ namespace euf {
}
void egraph::add_plugins() {
auto* plugin = alloc(bv_plugin, *this);
m_plugins.reserve(plugin->get_id() + 1);
m_plugins.set(plugin->get_id(), plugin);
auto insert = [&](plugin* p) {
m_plugins.reserve(p->get_id() + 1);
m_plugins.set(p->get_id(), p);
};
insert(alloc(bv_plugin, *this));
insert(alloc(arith_plugin, *this));
}
void egraph::propagate_plugins() {

View file

@ -38,6 +38,8 @@ namespace euf {
g(g)
{}
virtual ~plugin() {}
virtual unsigned get_id() const = 0;
virtual void register_node(enode* n) = 0;

View file

@ -38,6 +38,7 @@ add_executable(test-z3
dl_util.cpp
doc.cpp
egraph.cpp
euf_arith_plugin.cpp
euf_bv_plugin.cpp
escaped.cpp
ex.cpp

View file

@ -0,0 +1,53 @@
/*++
Copyright (c) 2023 Microsoft Corporation
--*/
#include "util/util.h"
#include "util/timer.h"
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_arith_plugin.h"
#include "ast/reg_decl_plugins.h"
#include "ast/ast_pp.h"
#include <iostream>
static euf::enode* get_node(euf::egraph& g, expr* e) {
auto* n = g.find(e);
if (n)
return n;
euf::enode_vector args;
for (expr* arg : *to_app(e))
args.push_back(get_node(g, arg));
return g.mk(e, 0, args.size(), args.data());
}
//
static void test1() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugins();
arith_util a(m);
sort_ref I(a.mk_int(), m);
expr_ref x(m.mk_const("x", I), m);
expr_ref y(m.mk_const("y", I), m);
auto* nx = get_node(g, a.mk_add(a.mk_add(y, y), a.mk_add(x, x)));
auto* ny = get_node(g, a.mk_add(a.mk_add(y, x), x));
TRACE("plugin", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("plugin", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("plugin", tout << "after propagate\n" << g << "\n");
g.merge(get_node(g, a.mk_add(x, a.mk_add(y, y))), get_node(g, a.mk_add(y, x)), nullptr);
g.propagate();
std::cout << g << "\n";
}
void tst_euf_arith_plugin() {
enable_trace("plugin");
test1();
}

View file

@ -11,7 +11,7 @@ Copyright (c) 2023 Microsoft Corporation
#include "ast/ast_pp.h"
#include <iostream>
euf::enode* get_node(euf::egraph& g, expr* e) {
static euf::enode* get_node(euf::egraph& g, expr* e) {
auto* n = g.find(e);
if (n)
return n;

View file

@ -271,4 +271,5 @@ int main(int argc, char ** argv) {
TST(totalizer);
TST(distribution);
TST(euf_bv_plugin);
TST(euf_arith_plugin);
}