From 14483dcd6ed7b7528a97917684618c35563725ee Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Mon, 20 Nov 2023 16:15:30 -0800 Subject: [PATCH] n/a Signed-off-by: Nikolaj Bjorner --- src/ast/euf/euf_ac_plugin.cpp | 25 +++++++++++------------ src/ast/euf/euf_ac_plugin.h | 1 + src/test/euf_arith_plugin.cpp | 38 +++++++++++++++++++++++++++++------ 3 files changed, 45 insertions(+), 19 deletions(-) diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp index 74a98ea59..272a14320 100644 --- a/src/ast/euf/euf_ac_plugin.cpp +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -598,8 +598,7 @@ namespace euf { return false; if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) return false; - if (backward_subsumes(src_eq, dst_eq)) { - + if (backward_subsumes(src_eq, dst_eq)) { set_status(dst_eq, eq_status::is_dead); return true; } @@ -613,8 +612,8 @@ namespace euf { return true; } - // dst_eq is fixed, dst_count is pre-computed for monomial(dst.l) - // dst2_counts is pre-computed for monomial(dst.r). + // dst_eq is fixed, dst_l_count is pre-computed for monomial(dst.l) + // dst_r_counts is pre-computed for monomial(dst.r). // is dst_eq subsumed by src_eq? bool ac_plugin::backward_subsumes(unsigned src_eq, unsigned dst_eq) { auto& src = m_eqs[src_eq]; @@ -626,22 +625,22 @@ namespace euf { unsigned size_diff = monomial(dst.l).size() - monomial(src.l).size(); if (size_diff != monomial(dst.r).size() - monomial(src.r).size()) return false; - if (!is_subset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) + if (!is_superset(m_dst_l_counts, m_src_l_counts, monomial(src.l))) return false; - if (!is_subset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) + if (!is_superset(m_dst_r_counts, m_src_r_counts, monomial(src.r))) return false; // add difference betwen src and dst1 to dst2 // (also add it to dst1 to make sure same difference isn't counted twice). for (auto n : monomial(src.l)) { unsigned id = n->root_id(); - SASSERT(m_src_l_counts[id] >= m_dst_l_counts[id]); - unsigned diff = m_src_l_counts[id] - m_dst_l_counts[id]; + SASSERT(m_dst_l_counts[id] >= m_src_l_counts[id]); + unsigned diff = m_dst_l_counts[id] - m_src_l_counts[id]; if (diff > 0) { - m_dst_l_counts.inc(id, diff); - m_dst_r_counts.inc(id, diff); + m_src_l_counts.inc(id, diff); + m_src_r_counts.inc(id, diff); } } - // now dst2 and src2 should align and have the same elements. + // now dst.r and src.r should align and have the same elements. // since src.r is a subset of dst.r we iterate over dst.r return all_of(monomial(dst.r), [&](node* n) { unsigned id = n->root_id(); return m_src_r_counts[id] == m_dst_r_counts[id]; }); } @@ -744,7 +743,7 @@ namespace euf { for (auto n : monomial(src.l)) { unsigned id = n->root_id(); if (m_dst_l_counts[id] > 0) - m_dst_l_counts.inc(id, -1); + m_dst_l_counts.dec(id, 1); else m_dst_r.push_back(n); } @@ -781,7 +780,7 @@ namespace euf { unsigned id = n->root_id(); if (m_eq_counts[id] == 0) return false; - m_eq_counts.inc(id, -1); + m_eq_counts.dec(id, 1); } return true; } diff --git a/src/ast/euf/euf_ac_plugin.h b/src/ast/euf/euf_ac_plugin.h index 88b4f6bb1..5fd1b272b 100644 --- a/src/ast/euf/euf_ac_plugin.h +++ b/src/ast/euf/euf_ac_plugin.h @@ -212,6 +212,7 @@ namespace euf { void reset() { for (auto idx : ids) counts[idx] = 0; ids.reset(); } unsigned operator[](unsigned idx) const { return counts.get(idx, 0); } void inc(unsigned idx, unsigned amount) { counts.reserve(idx + 1, 0); ids.push_back(idx); counts[idx] += amount; } + void dec(unsigned idx, unsigned amount) { counts.reserve(idx + 1, 0); ids.push_back(idx); counts[idx] -= amount; } unsigned const* begin() const { return ids.begin(); } unsigned const* end() const { return ids.end(); } }; diff --git a/src/test/euf_arith_plugin.cpp b/src/test/euf_arith_plugin.cpp index 218570854..217928977 100644 --- a/src/test/euf_arith_plugin.cpp +++ b/src/test/euf_arith_plugin.cpp @@ -55,23 +55,49 @@ static void test2() { expr_ref x(m.mk_const("x", I), m); expr_ref y(m.mk_const("y", I), m); - auto* nx = get_node(g, a.mk_add(x, y)); - auto* ny = get_node(g, a.mk_add(y, x)); + auto* nxy = get_node(g, a.mk_add(x, y)); + auto* nyx = get_node(g, a.mk_add(y, x)); + auto* nx = get_node(g, x); + auto* ny = get_node(g, y); + TRACE("plugin", tout << "before merge\n" << g << "\n"); - g.merge(nx, get_node(g, x), nullptr); - g.merge(ny, get_node(g, y), nullptr); - + g.merge(nxy, nx, nullptr); + g.merge(nyx, ny, nullptr); TRACE("plugin", tout << "before propagate\n" << g << "\n"); g.propagate(); TRACE("plugin", tout << "after propagate\n" << g << "\n"); - SASSERT(get_node(g, x)->get_root() == get_node(g, y)->get_root()); + SASSERT(nx->get_root() == ny->get_root()); g.merge(get_node(g, a.mk_add(x, a.mk_add(y, y))), get_node(g, a.mk_add(y, x)), nullptr); g.propagate(); std::cout << g << "\n"; } +static void test3() { + ast_manager m; + reg_decl_plugins(m); + euf::egraph g(m); + g.add_plugins(); + arith_util a(m); + sort_ref I(a.mk_int(), m); + + expr_ref x(m.mk_const("x", I), m); + expr_ref y(m.mk_const("y", I), m); + auto* nxyy = get_node(g, a.mk_add(a.mk_add(x, y), y)); + auto* nyxx = get_node(g, a.mk_add(a.mk_add(y, x), x)); + auto* nx = get_node(g, x); + auto* ny = get_node(g, y); + g.merge(nxyy, nx, nullptr); + g.merge(nyxx, ny, nullptr); + TRACE("plugin", tout << "before propagate\n" << g << "\n"); + g.propagate(); + TRACE("plugin", tout << "after propagate\n" << g << "\n"); + SASSERT(nx->get_root() == ny->get_root()); + std::cout << g << "\n"; +} + void tst_euf_arith_plugin() { enable_trace("plugin"); + test3(); test1(); test2(); }