From 57c9987d25c346ffa80319ae8158ce5aea323e99 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 03:23:39 +0000 Subject: [PATCH] Add abstract machine for pattern-based term rewriting (rw_rule) Co-authored-by: NikolajBjorner <3085284+NikolajBjorner@users.noreply.github.com> --- src/ast/rewriter/CMakeLists.txt | 1 + src/ast/rewriter/rw_rule.cpp | 194 +++++++++++++++ src/ast/rewriter/rw_rule.h | 111 +++++++++ src/test/CMakeLists.txt | 1 + src/test/main.cpp | 1 + src/test/rw_rule.cpp | 403 ++++++++++++++++++++++++++++++++ 6 files changed, 711 insertions(+) create mode 100644 src/ast/rewriter/rw_rule.cpp create mode 100644 src/ast/rewriter/rw_rule.h create mode 100644 src/test/rw_rule.cpp diff --git a/src/ast/rewriter/CMakeLists.txt b/src/ast/rewriter/CMakeLists.txt index 9d529f9b5..05383e09b 100644 --- a/src/ast/rewriter/CMakeLists.txt +++ b/src/ast/rewriter/CMakeLists.txt @@ -1,6 +1,7 @@ z3_add_component(rewriter SOURCES arith_rewriter.cpp + rw_rule.cpp array_rewriter.cpp ast_counter.cpp bit2int.cpp diff --git a/src/ast/rewriter/rw_rule.cpp b/src/ast/rewriter/rw_rule.cpp new file mode 100644 index 000000000..152b2f2df --- /dev/null +++ b/src/ast/rewriter/rw_rule.cpp @@ -0,0 +1,194 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + rw_rule.cpp + +Abstract: + + Abstract machine for pattern-based term rewriting. See rw_rule.h. + +Author: + + Copilot 2026 + +Notes: + +--*/ + +#include "ast/rewriter/rw_rule.h" +#include "ast/rewriter/rewriter_def.h" +#include "ast/arith_decl_plugin.h" +#include "ast/rewriter/var_subst.h" + +// --------------------------------------------------------------------------- +// rw_table: internals +// --------------------------------------------------------------------------- + +bool rw_table::match(expr * pattern, expr * term, ptr_vector & bindings) { + if (is_var(pattern)) { + unsigned idx = to_var(pattern)->get_idx(); + if (idx >= bindings.size()) + bindings.resize(idx + 1, nullptr); + if (!bindings[idx]) + bindings[idx] = term; + else if (bindings[idx] != term) + return false; + return true; + } + if (!is_app(pattern) || !is_app(term)) + return false; + app * pat = to_app(pattern); + app * trm = to_app(term); + if (pat->get_decl() != trm->get_decl()) + return false; + unsigned n = pat->get_num_args(); + if (n != trm->get_num_args()) + return false; + for (unsigned i = 0; i < n; ++i) + if (!match(pat->get_arg(i), trm->get_arg(i), bindings)) + return false; + return true; +} + +void rw_table::add_rule(unsigned num_vars, expr * lhs, expr * rhs) { + SASSERT(is_app(lhs)); + func_decl * head = to_app(lhs)->get_decl(); + unsigned idx = m_rules.size(); + m_rules.push_back(alloc(rw_rule, m, num_vars, lhs, rhs)); + m_index.insert_if_not_there(head, unsigned_vector()).push_back(idx); +} + +br_status rw_table::apply(func_decl * f, unsigned num, expr * const * args, expr_ref & result) { + auto * entry = m_index.find_core(f); + if (!entry) + return BR_FAILED; + unsigned_vector & rule_ids = entry->get_data().m_value; + for (unsigned idx : rule_ids) { + rw_rule * rule = m_rules[idx]; + app * lhs = to_app(rule->m_lhs.get()); + if (lhs->get_num_args() != num) + continue; + ptr_vector bindings; + bindings.resize(rule->m_num_vars, nullptr); + bool ok = true; + for (unsigned i = 0; i < num && ok; ++i) + ok = match(lhs->get_arg(i), args[i], bindings); + if (ok) { + // verify all declared variables were bound + for (unsigned i = 0; i < rule->m_num_vars && ok; ++i) + if (!bindings[i]) ok = false; + } + if (!ok) + continue; + var_subst subst(m, false); // VAR(i) -> bindings[i] + result = subst(rule->m_rhs.get(), bindings.size(), bindings.data()); + return BR_DONE; + } + return BR_FAILED; +} + +// --------------------------------------------------------------------------- +// populate_rules: representative simplification rules +// --------------------------------------------------------------------------- + +void rw_table::populate_rules() { + arith_util arith(m); + + sort * int_sort = arith.mk_int(); + sort * real_sort = arith.mk_real(); + sort * bool_sort = m.mk_bool_sort(); + + // constant numerals used in patterns + expr_ref zero_i(arith.mk_int(0), m); + expr_ref one_i (arith.mk_int(1), m); + expr_ref zero_r(arith.mk_real(0), m); + expr_ref one_r (arith.mk_real(1), m); + + expr_ref t_true (m.mk_true(), m); + expr_ref t_false(m.mk_false(), m); + + // pattern variables (VAR(i) with explicit sort) + expr_ref v0i(m.mk_var(0, int_sort), m); + expr_ref v1i(m.mk_var(1, int_sort), m); + expr_ref v0r(m.mk_var(0, real_sort), m); + expr_ref v1r(m.mk_var(1, real_sort), m); + expr_ref v0b(m.mk_var(0, bool_sort), m); + expr_ref v1b(m.mk_var(1, bool_sort), m); + + // ------------------------------------------------------------------ + // Arithmetic: addition identity 0 + x -> x and x + 0 -> x + // ------------------------------------------------------------------ + add_rule(1, arith.mk_add(zero_i, v0i), v0i); // 0_i + x -> x (Int) + add_rule(1, arith.mk_add(v0i, zero_i), v0i); // x + 0_i -> x (Int) + add_rule(1, arith.mk_add(zero_r, v0r), v0r); // 0_r + x -> x (Real) + add_rule(1, arith.mk_add(v0r, zero_r), v0r); // x + 0_r -> x (Real) + + // Arithmetic: multiplication identity 1 * x -> x and x * 1 -> x + add_rule(1, arith.mk_mul(one_i, v0i), v0i); // 1_i * x -> x (Int) + add_rule(1, arith.mk_mul(v0i, one_i), v0i); // x * 1_i -> x (Int) + add_rule(1, arith.mk_mul(one_r, v0r), v0r); // 1_r * x -> x (Real) + add_rule(1, arith.mk_mul(v0r, one_r), v0r); // x * 1_r -> x (Real) + + // Arithmetic: multiplication by zero 0 * x -> 0 and x * 0 -> 0 + add_rule(1, arith.mk_mul(zero_i, v0i), zero_i); // 0_i * x -> 0 (Int) + add_rule(1, arith.mk_mul(v0i, zero_i), zero_i); // x * 0_i -> 0 (Int) + add_rule(1, arith.mk_mul(zero_r, v0r), zero_r); // 0_r * x -> 0 (Real) + add_rule(1, arith.mk_mul(v0r, zero_r), zero_r); // x * 0_r -> 0 (Real) + + // Arithmetic: subtraction x - 0 -> x + add_rule(1, arith.mk_sub(v0i, zero_i), v0i); // x - 0_i -> x (Int) + add_rule(1, arith.mk_sub(v0r, zero_r), v0r); // x - 0_r -> x (Real) + + // Arithmetic: unary minus double negation -(-x) -> x + add_rule(1, arith.mk_uminus(arith.mk_uminus(v0i)), v0i); // -(-x) -> x (Int) + add_rule(1, arith.mk_uminus(arith.mk_uminus(v0r)), v0r); // -(-x) -> x (Real) + + // ------------------------------------------------------------------ + // Boolean: and/or identities and annihilators + // ------------------------------------------------------------------ + add_rule(1, m.mk_and(t_true, v0b), v0b); // true /\ x -> x + add_rule(1, m.mk_and(v0b, t_true), v0b); // x /\ true -> x + add_rule(1, m.mk_and(t_false, v0b), t_false); // false /\ x -> false + add_rule(1, m.mk_and(v0b, t_false), t_false); // x /\ false -> false + + add_rule(1, m.mk_or(t_false, v0b), v0b); // false \/ x -> x + add_rule(1, m.mk_or(v0b, t_false), v0b); // x \/ false -> x + add_rule(1, m.mk_or(t_true, v0b), t_true); // true \/ x -> true + add_rule(1, m.mk_or(v0b, t_true), t_true); // x \/ true -> true + + // Boolean: double negation not(not(x)) -> x + add_rule(1, m.mk_not(m.mk_not(v0b)), v0b); + + // Boolean: negation of constants + add_rule(0, m.mk_not(m.mk_true()), m.mk_false()); // not(true) -> false + add_rule(0, m.mk_not(m.mk_false()), m.mk_true()); // not(false) -> true + + // ------------------------------------------------------------------ + // ITE simplifications (Bool, Int, Real branches) + // ------------------------------------------------------------------ + // ite(true, x, y) -> x + add_rule(2, m.mk_ite(t_true, v0b, v1b), v0b); // Bool + add_rule(2, m.mk_ite(t_true, v0i, v1i), v0i); // Int + add_rule(2, m.mk_ite(t_true, v0r, v1r), v0r); // Real + + // ite(false, x, y) -> y + add_rule(2, m.mk_ite(t_false, v0b, v1b), v1b); // Bool + add_rule(2, m.mk_ite(t_false, v0i, v1i), v1i); // Int + add_rule(2, m.mk_ite(t_false, v0r, v1r), v1r); // Real + + // ite(c, x, x) -> x (both branches identical, VAR(1) used twice) + add_rule(2, m.mk_ite(v0b, v1b, v1b), v1b); // Bool + add_rule(2, m.mk_ite(v0b, v1i, v1i), v1i); // Int + add_rule(2, m.mk_ite(v0b, v1r, v1r), v1r); // Real + + // ------------------------------------------------------------------ + // Equality: x = x -> true + // ------------------------------------------------------------------ + add_rule(1, m.mk_eq(v0b, v0b), t_true); // Bool + add_rule(1, m.mk_eq(v0i, v0i), t_true); // Int + add_rule(1, m.mk_eq(v0r, v0r), t_true); // Real +} + +template class rewriter_tpl; diff --git a/src/ast/rewriter/rw_rule.h b/src/ast/rewriter/rw_rule.h new file mode 100644 index 000000000..f6eb14325 --- /dev/null +++ b/src/ast/rewriter/rw_rule.h @@ -0,0 +1,111 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + rw_rule.h + +Abstract: + + Abstract machine for pattern-based term rewriting. + + A rewriting rule lhs -> rhs is represented by storing the lhs as a + pattern in which VAR(i) nodes act as wildcards, and the rhs as a + template where the same VAR(i) nodes are substituted with the + matched subterms. For example, the arithmetic simplification + + 0 + x -> x + + is encoded with lhs = (+ 0_int VAR(0)) and rhs = VAR(0). + + The abstract machine rw_table stores rules indexed by the head + function symbol of each lhs pattern and provides a reduce_app hook + used by the evaluator rw_evaluator. The evaluator derives from + rewriter_tpl so that it traverses terms bottom-up, applying the + first matching rule at every node. + + populate_rules() seeds the machine with a representative set of + arithmetic, Boolean, and ITE simplifications. + +Author: + + Copilot 2026 + +Notes: + +--*/ +#pragma once + +#include "ast/ast.h" +#include "ast/rewriter/rewriter.h" +#include "ast/rewriter/rewriter_types.h" +#include "util/obj_hashtable.h" +#include "util/scoped_ptr_vector.h" + +// --------------------------------------------------------------------------- +// rw_rule: one rewriting rule lhs -> rhs. +// VAR(i) nodes for i < m_num_vars act as pattern wildcards. +// --------------------------------------------------------------------------- +struct rw_rule { + unsigned m_num_vars; + expr_ref m_lhs; + expr_ref m_rhs; + + rw_rule(ast_manager & m, unsigned num_vars, expr * lhs, expr * rhs) + : m_num_vars(num_vars), m_lhs(lhs, m), m_rhs(rhs, m) {} +}; + +// --------------------------------------------------------------------------- +// rw_table: abstract machine. +// Rules are indexed by the head func_decl of the lhs pattern. +// The class also satisfies the Cfg concept expected by rewriter_tpl. +// --------------------------------------------------------------------------- +class rw_table : public default_rewriter_cfg { + ast_manager & m; + scoped_ptr_vector m_rules; + obj_map m_index; + + // Recursive structural matcher. + // VAR(i) in the pattern is unified with the corresponding subterm of term, + // extending bindings[i]. Returns false on conflict or mismatch. + bool match(expr * pattern, expr * term, ptr_vector & bindings); + +public: + explicit rw_table(ast_manager & m) : m(m) {} + + // Add a rewriting rule lhs -> rhs. + // lhs must be an application; VAR(i) for i < num_vars may appear inside + // lhs and rhs as pattern variables / replacement slots. + void add_rule(unsigned num_vars, expr * lhs, expr * rhs); + + // Try to rewrite the application (f args[0] ... args[num-1]). + // Returns BR_DONE with result set on success, BR_FAILED otherwise. + br_status apply(func_decl * f, unsigned num, expr * const * args, expr_ref & result); + + // Cfg hook called by rewriter_tpl for every application node. + br_status reduce_app(func_decl * f, unsigned num, expr * const * args, + expr_ref & result, proof_ref & /*result_pr*/) { + return apply(f, num, args, result); + } + + // Seed the machine with a representative set of simplification rules + // covering arithmetic (Int and Real), Boolean connectives, and ITE. + void populate_rules(); + + ast_manager & get_manager() { return m; } +}; + +// --------------------------------------------------------------------------- +// rw_evaluator: full-term bottom-up evaluator built on top of rw_table. +// --------------------------------------------------------------------------- +class rw_evaluator : public rewriter_tpl { + rw_table m_table; +public: + explicit rw_evaluator(ast_manager & m) + : rewriter_tpl(m, false, m_table) + , m_table(m) { + m_table.populate_rules(); + } + + using rewriter_tpl::operator(); +}; diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 1d5b5ce18..078945a5b 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -119,6 +119,7 @@ add_executable(test-z3 rational.cpp rcf.cpp region.cpp + rw_rule.cpp sat_local_search.cpp sat_lookahead.cpp sat_user_scope.cpp diff --git a/src/test/main.cpp b/src/test/main.cpp index c5d55ebe1..ca79dd80d 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -256,6 +256,7 @@ int main(int argc, char ** argv) { TST(quant_solve); TST(rcf); TST(polynorm); + TST(rw_rule); TST(qe_arith); TST(expr_substitution); TST(sorting_network); diff --git a/src/test/rw_rule.cpp b/src/test/rw_rule.cpp new file mode 100644 index 000000000..3b1de5820 --- /dev/null +++ b/src/test/rw_rule.cpp @@ -0,0 +1,403 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + rw_rule.cpp (test) + +Abstract: + + Tests for the rw_rule abstract machine and rw_evaluator. + +Author: + + Copilot 2026 + +Notes: + +--*/ + +#include "ast/rewriter/rw_rule.h" +#include "ast/arith_decl_plugin.h" +#include "ast/ast_pp.h" +#include "ast/reg_decl_plugins.h" +#include + +// Helper: print a test result and assert the expected condition. +static void check(ast_manager & m, const char * label, expr * result, expr * expected) { + bool ok = (result == expected); + std::cout << label << ": " << mk_pp(result, m) + << (ok ? " [OK]" : " [FAIL]") << "\n"; + ENSURE(ok); +} + +static void check_true(ast_manager & m, const char * label, expr * result) { + check(m, label, result, m.mk_true()); +} + +static void check_false(ast_manager & m, const char * label, expr * result) { + check(m, label, result, m.mk_false()); +} + +// --------------------------------------------------------------------------- +// Arithmetic tests +// --------------------------------------------------------------------------- + +static void test_arith_add_identity() { + ast_manager m; + reg_decl_plugins(m); + arith_util arith(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * int_sort = arith.mk_int(); + sort * real_sort = arith.mk_real(); + + expr_ref x(m.mk_const(symbol("x"), int_sort), m); + expr_ref y(m.mk_const(symbol("y"), real_sort), m); + + // 0 + x -> x (Int) + ev(arith.mk_add(arith.mk_int(0), x), result); + check(m, "0_i + x", result, x); + + // x + 0 -> x (Int) + ev(arith.mk_add(x, arith.mk_int(0)), result); + check(m, "x + 0_i", result, x); + + // 0 + y -> y (Real) + ev(arith.mk_add(arith.mk_real(0), y), result); + check(m, "0_r + y", result, y); + + // y + 0 -> y (Real) + ev(arith.mk_add(y, arith.mk_real(0)), result); + check(m, "y + 0_r", result, y); +} + +static void test_arith_mul_identity() { + ast_manager m; + reg_decl_plugins(m); + arith_util arith(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * int_sort = arith.mk_int(); + sort * real_sort = arith.mk_real(); + + expr_ref x(m.mk_const(symbol("x"), int_sort), m); + expr_ref y(m.mk_const(symbol("y"), real_sort), m); + + // 1 * x -> x (Int) + ev(arith.mk_mul(arith.mk_int(1), x), result); + check(m, "1_i * x", result, x); + + // x * 1 -> x (Int) + ev(arith.mk_mul(x, arith.mk_int(1)), result); + check(m, "x * 1_i", result, x); + + // 1 * y -> y (Real) + ev(arith.mk_mul(arith.mk_real(1), y), result); + check(m, "1_r * y", result, y); + + // y * 1 -> y (Real) + ev(arith.mk_mul(y, arith.mk_real(1)), result); + check(m, "y * 1_r", result, y); +} + +static void test_arith_mul_zero() { + ast_manager m; + reg_decl_plugins(m); + arith_util arith(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * int_sort = arith.mk_int(); + sort * real_sort = arith.mk_real(); + + expr_ref x(m.mk_const(symbol("x"), int_sort), m); + expr_ref y(m.mk_const(symbol("y"), real_sort), m); + + expr_ref zero_i(arith.mk_int(0), m); + expr_ref zero_r(arith.mk_real(0), m); + + // 0 * x -> 0 (Int) + ev(arith.mk_mul(zero_i, x), result); + ENSURE(arith.is_numeral(result) && arith.is_zero(result) && arith.is_int(result)); + + // x * 0 -> 0 (Int) + ev(arith.mk_mul(x, zero_i), result); + ENSURE(arith.is_numeral(result) && arith.is_zero(result) && arith.is_int(result)); + + // 0 * y -> 0 (Real) + ev(arith.mk_mul(zero_r, y), result); + ENSURE(arith.is_numeral(result) && arith.is_zero(result) && !arith.is_int(result)); + + // y * 0 -> 0 (Real) + ev(arith.mk_mul(y, zero_r), result); + ENSURE(arith.is_numeral(result) && arith.is_zero(result) && !arith.is_int(result)); + + std::cout << "mul-zero tests: [OK]\n"; +} + +static void test_arith_sub_zero() { + ast_manager m; + reg_decl_plugins(m); + arith_util arith(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * int_sort = arith.mk_int(); + sort * real_sort = arith.mk_real(); + + expr_ref x(m.mk_const(symbol("x"), int_sort), m); + expr_ref y(m.mk_const(symbol("y"), real_sort), m); + + // x - 0 -> x (Int) + ev(arith.mk_sub(x, arith.mk_int(0)), result); + check(m, "x - 0_i", result, x); + + // y - 0 -> y (Real) + ev(arith.mk_sub(y, arith.mk_real(0)), result); + check(m, "y - 0_r", result, y); +} + +static void test_arith_uminus() { + ast_manager m; + reg_decl_plugins(m); + arith_util arith(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * int_sort = arith.mk_int(); + sort * real_sort = arith.mk_real(); + + expr_ref x(m.mk_const(symbol("x"), int_sort), m); + expr_ref y(m.mk_const(symbol("y"), real_sort), m); + + // -(-x) -> x (Int) + ev(arith.mk_uminus(arith.mk_uminus(x)), result); + check(m, "-(-x)_i", result, x); + + // -(-y) -> y (Real) + ev(arith.mk_uminus(arith.mk_uminus(y)), result); + check(m, "-(-y)_r", result, y); +} + +// --------------------------------------------------------------------------- +// Boolean tests +// --------------------------------------------------------------------------- + +static void test_bool_and() { + ast_manager m; + reg_decl_plugins(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * bool_sort = m.mk_bool_sort(); + expr_ref bx(m.mk_const(symbol("bx"), bool_sort), m); + + // true /\ x -> x + ev(m.mk_and(m.mk_true(), bx), result); + check(m, "true /\\ x", result, bx); + + // x /\ true -> x + ev(m.mk_and(bx, m.mk_true()), result); + check(m, "x /\\ true", result, bx); + + // false /\ x -> false + ev(m.mk_and(m.mk_false(), bx), result); + check_false(m, "false /\\ x", result); + + // x /\ false -> false + ev(m.mk_and(bx, m.mk_false()), result); + check_false(m, "x /\\ false", result); +} + +static void test_bool_or() { + ast_manager m; + reg_decl_plugins(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * bool_sort = m.mk_bool_sort(); + expr_ref bx(m.mk_const(symbol("bx"), bool_sort), m); + + // false \/ x -> x + ev(m.mk_or(m.mk_false(), bx), result); + check(m, "false \\/ x", result, bx); + + // x \/ false -> x + ev(m.mk_or(bx, m.mk_false()), result); + check(m, "x \\/ false", result, bx); + + // true \/ x -> true + ev(m.mk_or(m.mk_true(), bx), result); + check_true(m, "true \\/ x", result); + + // x \/ true -> true + ev(m.mk_or(bx, m.mk_true()), result); + check_true(m, "x \\/ true", result); +} + +static void test_bool_not() { + ast_manager m; + reg_decl_plugins(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * bool_sort = m.mk_bool_sort(); + expr_ref bx(m.mk_const(symbol("bx"), bool_sort), m); + + // not(not(x)) -> x + ev(m.mk_not(m.mk_not(bx)), result); + check(m, "not(not(x))", result, bx); + + // not(true) -> false + ev(m.mk_not(m.mk_true()), result); + check_false(m, "not(true)", result); + + // not(false) -> true + ev(m.mk_not(m.mk_false()), result); + check_true(m, "not(false)", result); +} + +// --------------------------------------------------------------------------- +// ITE tests +// --------------------------------------------------------------------------- + +static void test_ite() { + ast_manager m; + reg_decl_plugins(m); + arith_util arith(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * int_sort = arith.mk_int(); + sort * bool_sort = m.mk_bool_sort(); + + expr_ref x(m.mk_const(symbol("x"), int_sort), m); + expr_ref y(m.mk_const(symbol("y"), int_sort), m); + expr_ref c(m.mk_const(symbol("c"), bool_sort), m); + + // ite(true, x, y) -> x + ev(m.mk_ite(m.mk_true(), x, y), result); + check(m, "ite(true,x,y)", result, x); + + // ite(false, x, y) -> y + ev(m.mk_ite(m.mk_false(), x, y), result); + check(m, "ite(false,x,y)", result, y); + + // ite(c, x, x) -> x + ev(m.mk_ite(c, x, x), result); + check(m, "ite(c,x,x)", result, x); +} + +// --------------------------------------------------------------------------- +// Equality tests +// --------------------------------------------------------------------------- + +static void test_eq_reflexivity() { + ast_manager m; + reg_decl_plugins(m); + arith_util arith(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * int_sort = arith.mk_int(); + sort * bool_sort = m.mk_bool_sort(); + + expr_ref x(m.mk_const(symbol("x"), int_sort), m); + expr_ref b(m.mk_const(symbol("b"), bool_sort), m); + + // x = x -> true (Int) + ev(m.mk_eq(x, x), result); + check_true(m, "x = x (Int)", result); + + // b = b -> true (Bool) + ev(m.mk_eq(b, b), result); + check_true(m, "b = b (Bool)", result); +} + +// --------------------------------------------------------------------------- +// Compound rewriting: verify multi-level simplification +// --------------------------------------------------------------------------- + +static void test_compound() { + ast_manager m; + reg_decl_plugins(m); + arith_util arith(m); + + rw_evaluator ev(m); + expr_ref result(m); + + sort * int_sort = arith.mk_int(); + expr_ref x(m.mk_const(symbol("x"), int_sort), m); + + // (0 + x) + (1 * x) should simplify to x + x + expr_ref lhs(arith.mk_add(arith.mk_int(0), x), m); + expr_ref rhs(arith.mk_mul(arith.mk_int(1), x), m); + expr_ref term(arith.mk_add(lhs, rhs), m); + ev(term, result); + // Both sub-terms simplify: result should be x + x + expr_ref expected(arith.mk_add(x, x), m); + check(m, "(0+x)+(1*x)", result, expected); +} + +// --------------------------------------------------------------------------- +// Direct rw_table API test (no evaluator) +// --------------------------------------------------------------------------- + +static void test_table_direct() { + ast_manager m; + reg_decl_plugins(m); + arith_util arith(m); + + rw_table table(m); + table.populate_rules(); + + expr_ref result(m); + proof_ref pr(m); + + sort * int_sort = arith.mk_int(); + expr_ref x(m.mk_const(symbol("x"), int_sort), m); + + // Directly call reduce_app for 0 + x + expr * args[2] = { arith.mk_int(0), x }; + func_decl * add_decl = arith.mk_add(arith.mk_int(0), x)->get_decl(); + br_status st = table.reduce_app(add_decl, 2, args, result, pr); + + std::cout << "table.reduce_app(0+x): status=" << st + << " result=" << mk_pp(result, m) << "\n"; + ENSURE(st == BR_DONE); + ENSURE(result.get() == x.get()); +} + +// --------------------------------------------------------------------------- +// Entry point +// --------------------------------------------------------------------------- + +void tst_rw_rule() { + std::cout << "=== rw_rule tests ===\n"; + test_arith_add_identity(); + test_arith_mul_identity(); + test_arith_mul_zero(); + test_arith_sub_zero(); + test_arith_uminus(); + test_bool_and(); + test_bool_or(); + test_bool_not(); + test_ite(); + test_eq_reflexivity(); + test_compound(); + test_table_direct(); + std::cout << "=== rw_rule: all tests passed ===\n"; +}