diff --git a/src/ast/rewriter/poly_rewriter.h b/src/ast/rewriter/poly_rewriter.h index c4b120ae5..f49980a57 100644 --- a/src/ast/rewriter/poly_rewriter.h +++ b/src/ast/rewriter/poly_rewriter.h @@ -37,9 +37,11 @@ protected: bool m_hoist_mul; bool m_hoist_cmul; bool m_ast_order; + bool m_hoist_ite; bool is_numeral(expr * n) const { return Config::is_numeral(n); } bool is_numeral(expr * n, numeral & r) const { return Config::is_numeral(n, r); } + bool is_int_numeral(expr * n, numeral & r) const { return Config::is_numeral(n, r) && r.is_int(); } bool is_minus_one(expr * n) const { return Config::is_minus_one(n); } void normalize(numeral & c) { Config::normalize(c, m_curr_sort); } app * mk_numeral(numeral const & r) { return Config::mk_numeral(r, m_curr_sort); } @@ -78,6 +80,11 @@ protected: br_status cancel_monomials(expr * lhs, expr * rhs, bool move, expr_ref & lhs_result, expr_ref & rhs_result); + bool is_nontrivial_gcd(numeral const& g) const { return !g.is_zero() && !g.is_one(); } + bool hoist_ite(expr_ref& e); + bool hoist_ite(expr* e, obj_hashtable& shared, numeral& g); + expr* apply_hoist(expr* e, numeral const& g, obj_hashtable const& shared); + bool hoist_multiplication(expr_ref& som); expr* merge_muls(expr* x, expr* y); diff --git a/src/ast/rewriter/poly_rewriter_def.h b/src/ast/rewriter/poly_rewriter_def.h index 25535856e..ffe721fad 100644 --- a/src/ast/rewriter/poly_rewriter_def.h +++ b/src/ast/rewriter/poly_rewriter_def.h @@ -16,6 +16,8 @@ Author: Notes: --*/ + +#include "util/container_util.h" #include "ast/rewriter/poly_rewriter.h" #include "ast/rewriter/poly_rewriter_params.hpp" #include "ast/rewriter/arith_rewriter_params.hpp" @@ -31,6 +33,7 @@ void poly_rewriter::updt_params(params_ref const & _p) { m_som = p.som(); m_hoist_mul = p.hoist_mul(); m_hoist_cmul = p.hoist_cmul(); + m_hoist_ite = p.hoist_ite(); m_som_blowup = p.som_blowup(); if (!m_flat) m_som = false; if (m_som) m_hoist_mul = false; @@ -628,11 +631,14 @@ br_status poly_rewriter::mk_nflat_add_core(unsigned num_args, expr * con if (hoist_multiplication(result)) { return BR_REWRITE_FULL; } + if (hoist_ite(result)) { + return BR_REWRITE_FULL; + } return BR_DONE; } else { SASSERT(!has_multiple); - if (ordered && !m_hoist_mul && !m_hoist_cmul) { + if (ordered && !m_hoist_mul && !m_hoist_cmul && !m_hoist_ite) { if (num_coeffs == 0) return BR_FAILED; if (num_coeffs == 1 && is_numeral(args[0], a) && !a.is_zero()) @@ -655,11 +661,14 @@ br_status poly_rewriter::mk_nflat_add_core(unsigned num_args, expr * con std::sort(new_args.c_ptr(), new_args.c_ptr() + new_args.size(), lt); else std::sort(new_args.c_ptr() + 1, new_args.c_ptr() + new_args.size(), lt); - } + } result = mk_add_app(new_args.size(), new_args.c_ptr()); if (hoist_multiplication(result)) { return BR_REWRITE_FULL; } + if (hoist_ite(result)) { + return BR_REWRITE_FULL; + } return BR_DONE; } } @@ -978,6 +987,101 @@ expr* poly_rewriter::merge_muls(expr* x, expr* y) { return mk_mul_app(k+1, m1.c_ptr()); } +template +bool poly_rewriter::hoist_ite(expr_ref& e) { + if (!m_hoist_ite) { + return false; + } + obj_hashtable shared; + ptr_buffer adds; + expr_ref_vector bs(m()), pinned(m()); + TO_BUFFER(is_add, adds, e); + unsigned i = 0; + for (expr* a : adds) { + if (m().is_ite(a)) { + shared.reset(); + numeral g(0); + if (hoist_ite(a, shared, g) && (is_nontrivial_gcd(g) || !shared.empty())) { + bs.reset(); + if (!shared.empty()) { + g = numeral(1); + } + bs.push_back(apply_hoist(a, g, shared)); + if (is_nontrivial_gcd(g)) { + bs.push_back(mk_numeral(g)); + bs[0] = mk_mul_app(2, bs.c_ptr()); + bs.pop_back(); + } + else { + for (expr* s : shared) { + bs.push_back(s); + } + } + adds[i] = mk_add_app(bs.size(), bs.c_ptr()); + pinned.push_back(adds[i]); + } + } + ++i; + } + if (!pinned.empty()) { + e = mk_add_app(adds.size(), adds.c_ptr()); + return true; + } + return false; +} + +template +bool poly_rewriter::hoist_ite(expr* a, obj_hashtable& shared, numeral& g) { + expr* c = nullptr, *t = nullptr, *e = nullptr; + if (m().is_ite(a, c, t, e)) { + return hoist_ite(t, shared, g) && hoist_ite(e, shared, g); + } + rational k, g1; + if (is_int_numeral(a, k)) { + g = gcd(g, k); + return shared.empty(); + } + ptr_buffer adds; + TO_BUFFER(is_add, adds, a); + if (g.is_zero()) { // first + for (expr* e : adds) { + shared.insert(e); + } + } + else { + obj_hashtable tmp; + for (expr* e : adds) { + tmp.insert(e); + } + set_intersection, obj_hashtable>(shared, tmp); + } + g = numeral(1); + return !shared.empty(); +} + +template +expr* poly_rewriter::apply_hoist(expr* a, numeral const& g, obj_hashtable const& shared) { + expr* c = nullptr, *t = nullptr, *e = nullptr; + if (m().is_ite(a, c, t, e)) { + return m().mk_ite(c, apply_hoist(t, g, shared), apply_hoist(e, g, shared)); + } + rational k; + if (is_nontrivial_gcd(g) && is_int_numeral(a, k)) { + return mk_numeral(k/g); + } + ptr_buffer adds; + TO_BUFFER(is_add, adds, a); + unsigned i = 0; + for (expr* e : adds) { + if (!shared.contains(e)) { + adds[i++] = e; + } + } + adds.shrink(i); + return mk_add_app(adds.size(), adds.c_ptr()); +} + + template bool poly_rewriter::is_times_minus_one(expr * n, expr* & r) const { if (is_mul(n) && to_app(n)->get_num_args() == 2 && is_minus_one(to_app(n)->get_arg(0))) { diff --git a/src/ast/rewriter/poly_rewriter_params.pyg b/src/ast/rewriter/poly_rewriter_params.pyg index 7e909d4aa..776ec890b 100644 --- a/src/ast/rewriter/poly_rewriter_params.pyg +++ b/src/ast/rewriter/poly_rewriter_params.pyg @@ -5,4 +5,5 @@ def_module_params(module_name='rewriter', ("som_blowup", UINT, 10, "maximum increase of monomials generated when putting a polynomial in sum-of-monomials normal form"), ("hoist_mul", BOOL, False, "hoist multiplication over summation to minimize number of multiplications"), ("hoist_cmul", BOOL, False, "hoist constant multiplication over summation to minimize number of multiplications"), + ("hoist_ite", BOOL, False, "hoist shared summands under ite expressions"), ("flat", BOOL, True, "create nary applications for and,or,+,*,bvadd,bvmul,bvand,bvor,bvxor"))) diff --git a/src/parsers/smt2/smt2parser.cpp b/src/parsers/smt2/smt2parser.cpp index 35b9bd9bb..99312c912 100644 --- a/src/parsers/smt2/smt2parser.cpp +++ b/src/parsers/smt2/smt2parser.cpp @@ -2579,36 +2579,9 @@ namespace smt2 { void parse_assumptions() { while (!curr_is_rparen()) { - bool sign; - expr_ref t_ref(m()); - if (curr_is_lparen()) { - next(); - check_id_next(m_not, "invalid check-sat command, 'not' expected, assumptions must be Boolean literals"); - check_identifier("invalid check-sat command, literal expected"); - sign = true; - } - else { - check_identifier("invalid check-sat command, literal or ')' expected"); - sign = false; - } - symbol n = curr_id(); - next(); - m_ctx.mk_const(n, t_ref); - if (!m().is_bool(t_ref)) + parse_expr(); + if (!m().is_bool(expr_stack().back())) throw parser_exception("invalid check-sat command, argument must be a Boolean literal"); - if (sign) { - if (!is_uninterp_const(t_ref)) - throw parser_exception("invalid check-sat command, argument must be a Boolean literal"); - t_ref = m().mk_not(t_ref.get()); - } - else { - expr * arg; - if (!is_uninterp_const(t_ref) && !(m().is_not(t_ref, arg) && is_uninterp_const(arg))) - throw parser_exception("invalid check-sat command, argument must be a Boolean literal"); - } - expr_stack().push_back(t_ref.get()); - if (sign) - check_rparen_next("invalid check-sat command, ')' expected"); } } diff --git a/src/smt/qi_queue.cpp b/src/smt/qi_queue.cpp index 9048a0893..4b3524db9 100644 --- a/src/smt/qi_queue.cpp +++ b/src/smt/qi_queue.cpp @@ -217,7 +217,7 @@ namespace smt { TRACE("checker", tout << "reduced to true, before:\n" << mk_ll_pp(instance, m);); if (m.has_trace_stream()) { - display_instance_profile(f, q, num_bindings, bindings, pr->get_id(), generation); + display_instance_profile(f, q, num_bindings, bindings, pr ? pr->get_id() : 0, generation); m.trace_stream() << "[end-of-instance]\n"; } diff --git a/src/smt/theory_seq.cpp b/src/smt/theory_seq.cpp index 8a29648e9..e51082ab7 100644 --- a/src/smt/theory_seq.cpp +++ b/src/smt/theory_seq.cpp @@ -5961,7 +5961,7 @@ void theory_seq::add_lt_axiom(expr* n) { add_axiom(lt, eq, e1xcy); add_axiom(lt, eq, emp2, ltdc); add_axiom(lt, eq, emp2, e2xdz); - if (e1->get_id() <= e2->get_id() || true) { + if (e1->get_id() <= e2->get_id()) { literal gt = mk_literal(m_util.str.mk_lex_lt(e2, e1)); add_axiom(lt, eq, gt); add_axiom(~eq, ~lt);