3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 17:15:31 +00:00

add EUF plugin framework.

plugin setting allows adding equality saturation within the E-graph propagation without involving externalizing theory solver dispatch. It makes equality saturation independent of SAT integration.
Add a special relation operator to support ad-hoc AC symbols.
This commit is contained in:
Nikolaj Bjorner 2023-11-30 13:58:24 -08:00
parent 5784c2da79
commit b52fd8d954
28 changed files with 3063 additions and 68 deletions

View file

@ -39,6 +39,8 @@ add_executable(test-z3
doc.cpp
egraph.cpp
escaped.cpp
euf_bv_plugin.cpp
euf_arith_plugin.cpp
ex.cpp
expr_rand.cpp
expr_substitution.cpp

View file

@ -0,0 +1,106 @@
/*++
Copyright (c) 2023 Microsoft Corporation
--*/
#include "util/util.h"
#include "util/timer.h"
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_arith_plugin.h"
#include "ast/reg_decl_plugins.h"
#include "ast/ast_pp.h"
#include <iostream>
unsigned s_var = 0;
static euf::enode* get_node(euf::egraph& g, arith_util& a, expr* e) {
auto* n = g.find(e);
if (n)
return n;
euf::enode_vector args;
for (expr* arg : *to_app(e))
args.push_back(get_node(g, a, arg));
n = g.mk(e, 0, args.size(), args.data());
g.add_th_var(n, s_var++, a.get_family_id());
return n;
}
//
static void test1() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::arith_plugin, g));
arith_util a(m);
sort_ref I(a.mk_int(), m);
expr_ref x(m.mk_const("x", I), m);
expr_ref y(m.mk_const("y", I), m);
auto* nx = get_node(g, a, a.mk_add(a.mk_add(y, y), a.mk_add(x, x)));
auto* ny = get_node(g, a, a.mk_add(a.mk_add(y, x), x));
TRACE("plugin", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("plugin", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("plugin", tout << "after propagate\n" << g << "\n");
g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr);
g.propagate();
std::cout << g << "\n";
}
static void test2() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::arith_plugin, g));
arith_util a(m);
sort_ref I(a.mk_int(), m);
expr_ref x(m.mk_const("x", I), m);
expr_ref y(m.mk_const("y", I), m);
auto* nxy = get_node(g, a, a.mk_add(x, y));
auto* nyx = get_node(g, a, a.mk_add(y, x));
auto* nx = get_node(g, a, x);
auto* ny = get_node(g, a, y);
TRACE("plugin", tout << "before merge\n" << g << "\n");
g.merge(nxy, nx, nullptr);
g.merge(nyx, ny, nullptr);
TRACE("plugin", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("plugin", tout << "after propagate\n" << g << "\n");
SASSERT(nx->get_root() == ny->get_root());
g.merge(get_node(g, a, a.mk_add(x, a.mk_add(y, y))), get_node(g, a, a.mk_add(y, x)), nullptr);
g.propagate();
std::cout << g << "\n";
}
static void test3() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::arith_plugin, g));
arith_util a(m);
sort_ref I(a.mk_int(), m);
expr_ref x(m.mk_const("x", I), m);
expr_ref y(m.mk_const("y", I), m);
auto* nxyy = get_node(g, a, a.mk_add(a.mk_add(x, y), y));
auto* nyxx = get_node(g, a, a.mk_add(a.mk_add(y, x), x));
auto* nx = get_node(g, a, x);
auto* ny = get_node(g, a, y);
g.merge(nxyy, nx, nullptr);
g.merge(nyxx, ny, nullptr);
TRACE("plugin", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("plugin", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
}
void tst_euf_arith_plugin() {
enable_trace("plugin");
test1();
test2();
test3();
}

183
src/test/euf_bv_plugin.cpp Normal file
View file

@ -0,0 +1,183 @@
/*++
Copyright (c) 2023 Microsoft Corporation
--*/
#include "util/util.h"
#include "util/timer.h"
#include "ast/euf/euf_egraph.h"
#include "ast/euf/euf_bv_plugin.h"
#include "ast/reg_decl_plugins.h"
#include "ast/ast_pp.h"
#include <iostream>
static unsigned s_var = 0;
static euf::enode* get_node(euf::egraph& g, bv_util& b, expr* e) {
auto* n = g.find(e);
if (n)
return n;
euf::enode_vector args;
for (expr* arg : *to_app(e))
args.push_back(get_node(g, b, arg));
n = g.mk(e, 0, args.size(), args.data());
g.add_th_var(n, s_var++, b.get_family_id());
return n;
}
// align slices, and propagate extensionality
static void test1() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref y(m.mk_const("y", u32), m);
expr_ref x3(bv.mk_extract(31, 16, x), m);
expr_ref x2(bv.mk_extract(15, 8, x), m);
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref y3(bv.mk_extract(31, 24, y), m);
expr_ref y2(bv.mk_extract(23, 8, y), m);
expr_ref y1(bv.mk_extract(7, 0, y), m);
expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m);
expr_ref yy(bv.mk_concat(y1, bv.mk_concat(y2, y3)), m);
auto* nx = get_node(g, bv, xx);
auto* ny = get_node(g, bv, yy);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("bv", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
SASSERT(nx->get_root() == ny->get_root());
}
// propagate values down
static void test2() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x3(bv.mk_extract(31, 16, x), m);
expr_ref x2(bv.mk_extract(15, 8, x), m);
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref xx(bv.mk_concat(x1, bv.mk_concat(x2, x3)), m);
g.merge(get_node(g, bv, xx), get_node(g, bv, bv.mk_numeral((1 << 27) + (1 << 17) + (1 << 3), 32)), nullptr);
g.propagate();
SASSERT(get_node(g, bv, x1)->get_root()->interpreted());
SASSERT(get_node(g, bv, x2)->get_root()->interpreted());
SASSERT(get_node(g, bv, x3)->get_root()->interpreted());
SASSERT(get_node(g, bv, x)->get_root()->interpreted());
}
// propagate values up
static void test3() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x3(bv.mk_extract(31, 16, x), m);
expr_ref x2(bv.mk_extract(15, 8, x), m);
expr_ref x1(bv.mk_extract(7, 0, x), m);
expr_ref xx(bv.mk_concat(bv.mk_concat(x1, x2), x3), m);
expr_ref y(m.mk_const("y", u32), m);
g.merge(get_node(g, bv, xx), get_node(g, bv, y), nullptr);
g.merge(get_node(g, bv, x1), get_node(g, bv, bv.mk_numeral(2, 8)), nullptr);
g.merge(get_node(g, bv, x2), get_node(g, bv, bv.mk_numeral(8, 8)), nullptr);
g.propagate();
SASSERT(get_node(g, bv, bv.mk_concat(x1, x2))->get_root()->interpreted());
SASSERT(get_node(g, bv, x1)->get_root()->interpreted());
SASSERT(get_node(g, bv, x2)->get_root()->interpreted());
}
// propagate extract up
static void test4() {
// concat(a, x[J]), a = x[I] => x[IJ] = concat(x[I],x[J])
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
sort_ref u8(bv.mk_sort(8), m);
sort_ref u16(bv.mk_sort(16), m);
expr_ref a(m.mk_const("a", u8), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref y(m.mk_const("y", u16), m);
expr_ref x1(bv.mk_extract(15, 8, x), m);
expr_ref x2(bv.mk_extract(23, 16, x), m);
g.merge(get_node(g, bv, bv.mk_concat(a, x2)), get_node(g, bv, y), nullptr);
g.merge(get_node(g, bv, x1), get_node(g, bv, a), nullptr);
g.propagate();
TRACE("bv", tout << g << "\n");
SASSERT(get_node(g, bv, bv.mk_extract(23, 8, x))->get_root() == get_node(g, bv, y)->get_root());
}
// iterative slicing
static void test5() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x1(bv.mk_extract(31, 4, x), m);
expr_ref x2(bv.mk_extract(27, 0, x), m);
auto* nx = get_node(g, bv, x1);
auto* ny = get_node(g, bv, x2);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("bv", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
}
// iterative slicing
static void test6() {
ast_manager m;
reg_decl_plugins(m);
euf::egraph g(m);
g.add_plugin(alloc(euf::bv_plugin, g));
bv_util bv(m);
sort_ref u32(bv.mk_sort(32), m);
expr_ref x(m.mk_const("x", u32), m);
expr_ref x1(bv.mk_extract(31, 3, x), m);
expr_ref x2(bv.mk_extract(28, 0, x), m);
auto* nx = get_node(g, bv, x1);
auto* ny = get_node(g, bv, x2);
TRACE("bv", tout << "before merge\n" << g << "\n");
g.merge(nx, ny, nullptr);
TRACE("bv", tout << "before propagate\n" << g << "\n");
g.propagate();
TRACE("bv", tout << "after propagate\n" << g << "\n");
std::cout << g << "\n";
}
void tst_euf_bv_plugin() {
enable_trace("bv");
enable_trace("plugin");
test6();
return;
test1();
test2();
test3();
test4();
test5();
test6();
}

View file

@ -265,4 +265,6 @@ int main(int argc, char ** argv) {
TST(finder);
TST(totalizer);
TST(distribution);
TST(euf_bv_plugin);
TST(euf_arith_plugin);
}