From 070c5c624a7148aba6e5b56a1283984796dc5fab Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 3 Nov 2022 03:33:31 -0700 Subject: [PATCH] wip - converting the equation solver as a simplifier --- src/ast/arith_decl_plugin.h | 3 + src/ast/simplifiers/CMakeLists.txt | 3 +- src/ast/simplifiers/extract_eqs.cpp | 239 ++++++++++++++++++++++++++++ src/ast/simplifiers/extract_eqs.h | 47 ++++++ src/ast/simplifiers/solve_eqs.cpp | 74 +++------ src/ast/simplifiers/solve_eqs.h | 22 +-- 6 files changed, 316 insertions(+), 72 deletions(-) create mode 100644 src/ast/simplifiers/extract_eqs.cpp create mode 100644 src/ast/simplifiers/extract_eqs.h diff --git a/src/ast/arith_decl_plugin.h b/src/ast/arith_decl_plugin.h index 781996662..0c77867d4 100644 --- a/src/ast/arith_decl_plugin.h +++ b/src/ast/arith_decl_plugin.h @@ -453,6 +453,9 @@ public: app * mk_mul(expr * arg1, expr * arg2) const { return m_manager.mk_app(arith_family_id, OP_MUL, arg1, arg2); } app * mk_mul(expr * arg1, expr * arg2, expr* arg3) const { return m_manager.mk_app(arith_family_id, OP_MUL, arg1, arg2, arg3); } app * mk_mul(unsigned num_args, expr * const * args) const { return num_args == 1 && is_app(args[0]) ? to_app(args[0]) : m_manager.mk_app(arith_family_id, OP_MUL, num_args, args); } + app * mk_mul(ptr_buffer const& args) const { return mk_mul(args.size(), args.data()); } + app * mk_mul(ptr_vector const& args) const { return mk_mul(args.size(), args.data()); } + app * mk_mul(expr_ref_vector const& args) const { return mk_mul(args.size(), args.data()); } app * mk_uminus(expr * arg) const { return m_manager.mk_app(arith_family_id, OP_UMINUS, arg); } app * mk_div(expr * arg1, expr * arg2) { return m_manager.mk_app(arith_family_id, OP_DIV, arg1, arg2); } app * mk_idiv(expr * arg1, expr * arg2) { return m_manager.mk_app(arith_family_id, OP_IDIV, arg1, arg2); } diff --git a/src/ast/simplifiers/CMakeLists.txt b/src/ast/simplifiers/CMakeLists.txt index d07220fb5..a260dd3b7 100644 --- a/src/ast/simplifiers/CMakeLists.txt +++ b/src/ast/simplifiers/CMakeLists.txt @@ -1,7 +1,8 @@ z3_add_component(simplifiers SOURCES - euf_completion.cpp bv_slice.cpp + euf_completion.cpp + extract_eqs.cpp solve_eqs.cpp COMPONENT_DEPENDENCIES euf diff --git a/src/ast/simplifiers/extract_eqs.cpp b/src/ast/simplifiers/extract_eqs.cpp new file mode 100644 index 000000000..99be4268c --- /dev/null +++ b/src/ast/simplifiers/extract_eqs.cpp @@ -0,0 +1,239 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + extract_eqs.cpp + +Abstract: + + simplifier for solving equations + +Author: + + Nikolaj Bjorner (nbjorner) 2022-11-2. + +--*/ + + +#include "ast/ast_util.h" +#include "ast/for_each_expr.h" +#include "ast/ast_pp.h" +#include "ast/arith_decl_plugin.h" +#include "ast/simplifiers/extract_eqs.h" + + +namespace euf { + + class basic_extract_eq : public extract_eq { + ast_manager& m; + + public: + basic_extract_eq(ast_manager& m) : m(m) {} + + void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) override { + auto [f, d] = e(); + expr* x, * y; + if (m.is_eq(f, x, y)) { + if (is_uninterp_const(x)) + eqs.push_back(dependent_eq(to_app(x), expr_ref(y, m), d)); + if (is_uninterp_const(y)) + eqs.push_back(dependent_eq(to_app(y), expr_ref(x, m), d)); + } + expr* c, * th, * el, * x1, * y1, * x2, * y2; + if (m.is_ite(f, c, th, el)) { + if (m.is_eq(th, x1, y1) && m.is_eq(el, x2, y2)) { + if (x1 == y2 && is_uninterp_const(x1)) + std::swap(x2, y2); + if (x2 == y2 && is_uninterp_const(x2)) + std::swap(x2, y2), std::swap(x1, y1); + if (x2 == y1 && is_uninterp_const(x2)) + std::swap(x1, y1); + if (x1 == x2 && is_uninterp_const(x1)) + eqs.push_back(dependent_eq(to_app(x1), expr_ref(m.mk_ite(c, y1, y2), m), d)); + } + } + if (is_uninterp_const(f)) + eqs.push_back(dependent_eq(to_app(f), expr_ref(m.mk_true(), m), d)); + if (m.is_not(f, x) && is_uninterp_const(x)) + eqs.push_back(dependent_eq(to_app(x), expr_ref(m.mk_false(), m), d)); + } + }; + + class arith_extract_eq : public extract_eq { + ast_manager& m; + arith_util a; + expr_ref_vector m_args; + expr_sparse_mark m_nonzero; + + + // solve u mod r1 = y -> u = r1*mod!1 + y + void solve_mod(expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { + expr* u, * z; + rational r1, r2; + if (!a.is_mod(x, u, z)) + return; + if (!a.is_numeral(z, r1)) + return; + if (r1 <= 0) + return; + expr_ref term(m); + term = a.mk_add(a.mk_mul(z, m.mk_fresh_const("mod", a.mk_int())), y); + solve_eq(u, term, d, eqs); + } + + /*** + * Solve + * x + Y = Z -> x = Z - Y + * -1*x + Y = Z -> x = Y - Z + * a*x + Y = Z -> x = (Z - Y)/a for is-real(x) + */ + void solve_add(expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { + if (!a.is_add(x)) + return; + expr* u, * z; + rational r; + expr_ref term(m); + unsigned i = 0; + auto mk_term = [&](unsigned i) { + term = y; + unsigned j = 0; + for (expr* arg2 : *to_app(x)) { + if (i != j) + term = a.mk_sub(term, arg2); + ++j; + } + }; + for (expr* arg : *to_app(x)) { + if (is_uninterp_const(arg)) { + mk_term(i); + eqs.push_back(dependent_eq(to_app(arg), term, d)); + } + else if (a.is_mul(arg, u, z) && a.is_numeral(u, r) && is_uninterp_const(z)) { + if (r == -1) { + mk_term(i); + term = a.mk_uminus(term); + eqs.push_back(dependent_eq(to_app(z), term, d)); + } + else if (a.is_real(arg) && r != 0) { + mk_term(i); + term = a.mk_div(term, u); + eqs.push_back(dependent_eq(to_app(z), term, d)); + } + } + else if (a.is_real(arg) && a.is_mul(arg)) { + unsigned j = 0; + for (expr* xarg : *to_app(arg)) { + ++j; + if (!is_uninterp_const(xarg)) + continue; + unsigned k = 0; + bool nonzero = true; + for (expr* yarg : *to_app(arg)) { + ++k; + nonzero = k == j || m_nonzero.is_marked(yarg) || (a.is_numeral(yarg, r) && r != 0); + if (!nonzero) + break; + } + if (!nonzero) + continue; + + k = 0; + ptr_buffer args; + for (expr* yarg : *to_app(arg)) { + ++k; + if (k != j) + args.push_back(yarg); + } + mk_term(i); + term = a.mk_div(term, a.mk_mul(args.size(), args.data())); + eqs.push_back(dependent_eq(to_app(xarg), term, d)); + } + } + ++i; + } + } + + /*** + * Solve for x * Y = Z, where Y != 0 -> x = Z / Y + */ + void solve_mul(expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { + if (!a.is_mul(x)) + return; + rational r; + expr_ref term(m); + unsigned i = 0; + for (expr* arg : *to_app(x)) { + ++i; + if (!is_uninterp_const(arg)) + continue; + unsigned j = 0; + bool nonzero = true; + for (expr* arg2 : *to_app(x)) { + ++j; + nonzero = j == i || m_nonzero.is_marked(arg2) || (a.is_numeral(arg2, r) && r != 0); + if (!nonzero) + break; + } + if (!nonzero) + continue; + ptr_buffer args; + j = 0; + for (expr* arg2 : *to_app(x)) { + ++j; + if (j != i) + args.push_back(arg2); + } + term = a.mk_div(y, a.mk_mul(args)); + eqs.push_back(dependent_eq(to_app(arg), term, d)); + } + } + + void add_pos(expr* f) { + expr* lhs = nullptr, * rhs = nullptr; + rational val; + if (a.is_le(f, lhs, rhs) && a.is_numeral(rhs, val) && val.is_neg()) + m_nonzero.mark(lhs); + else if (a.is_ge(f, lhs, rhs) && a.is_numeral(rhs, val) && val.is_pos()) + m_nonzero.mark(lhs); + else if (m.is_not(f, f)) { + if (a.is_le(f, lhs, rhs) && a.is_numeral(rhs, val) && !val.is_neg()) + m_nonzero.mark(lhs); + else if (a.is_ge(f, lhs, rhs) && a.is_numeral(rhs, val) && !val.is_pos()) + m_nonzero.mark(lhs); + else if (m.is_eq(f, lhs, rhs) && a.is_numeral(rhs, val) && val.is_zero()) + m_nonzero.mark(lhs); + } + } + + void solve_eq(expr* x, expr* y, expr_dependency* d, dep_eq_vector& eqs) { + solve_add(x, y, d, eqs); + solve_mod(x, y, d, eqs); + solve_mul(x, y, d, eqs); + } + + public: + arith_extract_eq(ast_manager& m) : m(m), a(m), m_args(m) {} + void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) override { + auto [f, d] = e(); + expr* x, * y; + if (m.is_eq(f, x, y) && a.is_int_real(x)) { + solve_eq(x, y, d, eqs); + solve_eq(y, x, d, eqs); + } + } + + void pre_process(dependent_expr_state& fmls) override { + m_nonzero.reset(); + for (unsigned i = 0; i < fmls.size(); ++i) { + auto [f, d] = fmls[i](); + add_pos(f); + } + } + }; + + void register_extract_eqs(ast_manager& m, scoped_ptr_vector& ex) { + ex.push_back(alloc(arith_extract_eq, m)); + ex.push_back(alloc(basic_extract_eq, m)); + } +} diff --git a/src/ast/simplifiers/extract_eqs.h b/src/ast/simplifiers/extract_eqs.h new file mode 100644 index 000000000..e6c81bb20 --- /dev/null +++ b/src/ast/simplifiers/extract_eqs.h @@ -0,0 +1,47 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + extract_eqs.h + +Abstract: + + simplifier for solving equations + +Author: + + Nikolaj Bjorner (nbjorner) 2022-11-2. + +--*/ + + +#pragma once + +#include "ast/simplifiers/dependent_expr_state.h" +#include "ast/rewriter/th_rewriter.h" +#include "ast/expr_substitution.h" +#include "util/scoped_ptr_vector.h" + + +namespace euf { + + struct dependent_eq { + app* var; + expr_ref term; + expr_dependency* dep; + dependent_eq(app* var, expr_ref& term, expr_dependency* d) : var(var), term(term), dep(d) {} + }; + + typedef vector dep_eq_vector; + + class extract_eq { + public: + virtual ~extract_eq() {} + virtual void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) = 0; + virtual void pre_process(dependent_expr_state& fmls) {} + }; + + void register_extract_eqs(ast_manager& m, scoped_ptr_vector& ex); + +} diff --git a/src/ast/simplifiers/solve_eqs.cpp b/src/ast/simplifiers/solve_eqs.cpp index a47a05cee..d9fbf9664 100644 --- a/src/ast/simplifiers/solve_eqs.cpp +++ b/src/ast/simplifiers/solve_eqs.cpp @@ -27,59 +27,6 @@ Author: namespace euf { - class basic_extract_eq : public extract_eq { - ast_manager& m; - public: - basic_extract_eq(ast_manager& m) : m(m) {} - void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) { - auto [f, d] = e(); - expr* x, * y; - if (m.is_eq(f, x, y)) { - if (is_uninterp_const(x)) - eqs.push_back(dependent_eq(to_app(x), expr_ref(y, m), d)); - if (is_uninterp_const(y)) - eqs.push_back(dependent_eq(to_app(y), expr_ref(x, m), d)); - } - expr* c, * th, * el, * x1, * y1, * x2, * y2; - if (m.is_ite(f, c, th, el)) { - if (m.is_eq(th, x1, y1) && m.is_eq(el, x2, y2)) { - if (x1 == y2 && is_uninterp_const(x1)) - std::swap(x2, y2); - if (x2 == y2 && is_uninterp_const(x2)) - std::swap(x2, y2), std::swap(x1, y1); - if (x2 == y1 && is_uninterp_const(x2)) - std::swap(x1, y1); - if (x1 == x2 && is_uninterp_const(x1)) - eqs.push_back(dependent_eq(to_app(x1), expr_ref(m.mk_ite(c, y1, y2), m), d)); - } - } - if (is_uninterp_const(f)) - eqs.push_back(dependent_eq(to_app(f), expr_ref(m.mk_true(), m), d)); - if (m.is_not(f, x) && is_uninterp_const(x)) - eqs.push_back(dependent_eq(to_app(x), expr_ref(m.mk_false(), m), d)); - } - }; - - class arith_extract_eq : public extract_eq { - ast_manager& m; - arith_util a; -#if 0 - void solve_eq(expr* f, expr_depedency* d) { - - } -#endif - public: - arith_extract_eq(ast_manager& m) : m(m), a(m) {} - void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) { -#if 0 - auto [f, d] = e(); - expr* x, * y; - if (m.is_eq(f, x, y) && a.is_int_real(x)) - ; -#endif - } - }; - // initialize graph that maps variable ids to next ids void solve_eqs::extract_dep_graph(dep_eq_vector& eqs) { m_var2id.reset(); @@ -210,6 +157,11 @@ namespace euf { } void solve_eqs::reduce() { + + for (extract_eq* ex : m_extract_plugins) + ex->pre_process(m_fmls); + + // TODO add a loop. dep_eq_vector eqs; get_eqs(eqs); extract_dep_graph(eqs); @@ -218,4 +170,20 @@ namespace euf { advance_qhead(m_fmls.size()); } + solve_eqs::solve_eqs(ast_manager& m, dependent_expr_state& fmls) : + dependent_expr_simplifier(m, fmls), m_rewriter(m) { + register_extract_eqs(m, m_extract_plugins); + } + + void solve_eqs::updt_params(params_ref const& p) { + // TODO +#if 0 + tactic_params tp(m_params); + m_ite_solver = p.get_bool("ite_solver", tp.solve_eqs_ite_solver()); + m_theory_solver = p.get_bool("theory_solver", tp.solve_eqs_theory_solver()); + m_max_occs = p.get_uint("solve_eqs_max_occs", tp.solve_eqs_max_occs()); + m_context_solve = p.get_bool("context_solve", tp.solve_eqs_context_solve()); +#endif + } + } diff --git a/src/ast/simplifiers/solve_eqs.h b/src/ast/simplifiers/solve_eqs.h index 55cad7e67..942498a52 100644 --- a/src/ast/simplifiers/solve_eqs.h +++ b/src/ast/simplifiers/solve_eqs.h @@ -18,29 +18,13 @@ Author: #pragma once -#include "ast/simplifiers/dependent_expr_state.h" #include "ast/rewriter/th_rewriter.h" #include "ast/expr_substitution.h" #include "util/scoped_ptr_vector.h" - +#include "ast/simplifiers/extract_eqs.h" namespace euf { - struct dependent_eq { - app* var; - expr_ref term; - expr_dependency* dep; - dependent_eq(app* var, expr_ref& term, expr_dependency* d) : var(var), term(term), dep(d) {} - }; - - typedef vector dep_eq_vector; - - class extract_eq { - public: - virtual ~extract_eq() {} - virtual void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) = 0; - }; - class solve_eqs : public dependent_expr_simplifier { th_rewriter m_rewriter; scoped_ptr_vector m_extract_plugins; @@ -71,10 +55,12 @@ namespace euf { public: - solve_eqs(ast_manager& m, dependent_expr_state& fmls) : dependent_expr_simplifier(m, fmls), m_rewriter(m) {} + solve_eqs(ast_manager& m, dependent_expr_state& fmls); void push() override { dependent_expr_simplifier::push(); } void pop(unsigned n) override { dependent_expr_simplifier::pop(n); } void reduce() override; + + void updt_params(params_ref const& p) override; }; }