From 8b23a1701a2fe0972642bda96e71f84f891ac3bb Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 6 Feb 2020 09:16:23 -0800 Subject: [PATCH] move flatten functionality to asserted_formulas, sort variables in lut_finder Signed-off-by: Nikolaj Bjorner --- src/math/lp/gomory.cpp | 2 ++ src/sat/sat_aig_cuts.cpp | 4 ++-- src/sat/sat_lut_finder.cpp | 7 ++++++- src/smt/asserted_formulas.cpp | 39 ++++++++++++++++++++++++++++++++++- src/smt/asserted_formulas.h | 10 +++++++++ src/smt/smt_context.cpp | 35 ++++++++++++++++++++++--------- src/smt/smt_context.h | 8 +++---- src/smt/smt_context_inv.cpp | 3 ++- src/smt/smt_internalizer.cpp | 37 --------------------------------- 9 files changed, 88 insertions(+), 57 deletions(-) diff --git a/src/math/lp/gomory.cpp b/src/math/lp/gomory.cpp index 46f745d03..c90e499b2 100644 --- a/src/math/lp/gomory.cpp +++ b/src/math/lp/gomory.cpp @@ -78,6 +78,7 @@ class gomory::imp { m_lcm_den = lcm(m_lcm_den, denominator(new_a)); TRACE("gomory_cut_detail", tout << "new_a = " << new_a << ", k = " << m_k << ", lcm_den = " << m_lcm_den << "\n";); #if SMALL_CUTS + // if (numerator(new_a).is_big()) throw found_big(); if (numerator(new_a) > m_big_number) throw found_big(); #endif } @@ -110,6 +111,7 @@ class gomory::imp { TRACE("gomory_cut_detail_real", tout << a << "*v" << j << " k: " << m_k << "\n";); m_t.add_monomial(new_a, j); #if SMALL_CUTS + // if (numerator(new_a).is_big()) throw found_big(); if (numerator(new_a) > m_big_number) throw found_big(); #endif } diff --git a/src/sat/sat_aig_cuts.cpp b/src/sat/sat_aig_cuts.cpp index 069a418a0..bb99ab116 100644 --- a/src/sat/sat_aig_cuts.cpp +++ b/src/sat/sat_aig_cuts.cpp @@ -276,13 +276,13 @@ namespace sat { } void aig_cuts::add_cut(bool_var v, uint64_t lut, bool_var_vector const& args) { + // args can be assumed to be sorted + DEBUG_CODE(for (unsigned i = 0; i + 1 < args.size(); ++i) VERIFY(args[i] < args[i+1]);); reserve(v); for (bool_var w : args) reserve(w); - // optional: reshuffle lut and sort variables. cut c; for (bool_var w : args) VERIFY(c.add(w)); c.set_table(lut); - // add-don't care? insert_cut(v, c, m_cuts[v]); } diff --git a/src/sat/sat_lut_finder.cpp b/src/sat/sat_lut_finder.cpp index f701c669b..0a90dc377 100644 --- a/src/sat/sat_lut_finder.cpp +++ b/src/sat/sat_lut_finder.cpp @@ -59,7 +59,13 @@ namespace sat { s.init_visited(); unsigned mask = 0, i = 0; m_vars.reset(); + m_clause.reset(); for (literal l : c) { + m_clause.push_back(l); + } + // ensure that variables in returned LUT are sorted + std::sort(m_clause.begin(), m_clause.end()); + for (literal l : m_clause) { m_vars.push_back(l.var()); m_var_position[l.var()] = i; s.mark_visited(l.var()); @@ -67,7 +73,6 @@ namespace sat { } m_clauses_to_remove.reset(); m_clauses_to_remove.push_back(&c); - m_clause.resize(c.size()); m_combination = 0; m_num_combinations = 0; set_combination(mask); diff --git a/src/smt/asserted_formulas.cpp b/src/smt/asserted_formulas.cpp index 963356a1b..6120caf68 100644 --- a/src/smt/asserted_formulas.cpp +++ b/src/smt/asserted_formulas.cpp @@ -57,7 +57,8 @@ asserted_formulas::asserted_formulas(ast_manager & m, smt_params & sp, params_re m_find_macros(*this), m_propagate_values(*this), m_nnf_cnf(*this), - m_apply_quasi_macros(*this) { + m_apply_quasi_macros(*this), + m_flatten_clauses(*this) { m_macro_finder = alloc(macro_finder, m, m_macro_manager); @@ -267,6 +268,7 @@ void asserted_formulas::reduce() { if (!invoke(m_max_bv_sharing_fn)) return; if (!invoke(m_elim_bvs_from_quantifiers)) return; if (!invoke(m_reduce_asserted_formulas)) return; + if (!invoke(m_flatten_clauses)) return; // if (!invoke(m_propagate_values)) return; IF_VERBOSE(10, verbose_stream() << "(smt.simplifier-done)\n";); @@ -344,6 +346,41 @@ void asserted_formulas::find_macros_core() { reduce_and_solve(); } +/** + \brief rewrite (a or (b & c)) to (a or b), (a or c) if the reference count of (b & c) is 1. + This avoids the literal for (b & c) +*/ +void asserted_formulas::flatten_clauses() { + if (m.proofs_enabled()) return; + bool change = true; + vector new_fmls; + while (change) { + change = false; + new_fmls.reset(); + unsigned sz = m_formulas.size(); + for (unsigned i = m_qhead; i < sz; ++i) { + auto const& j = m_formulas.get(i); + expr* f = j.get_fml(), *a = nullptr, *b = nullptr; + bool decomposed = false; + if (m.is_or(f, a, b) && m.is_not(b, b) && m.is_or(b) && b->get_ref_count() == 1) { + decomposed = true; + } + else if (m.is_or(f, b, a) && m.is_not(b, b) && m.is_or(b) && b->get_ref_count() == 1) { + decomposed = true; + } + if (decomposed) { + for (expr* arg : *to_app(b)) { + justified_expr j1(m, m.mk_or(a, m.is_not(arg, arg) ? arg : m.mk_not(arg)), nullptr); + new_fmls.push_back(j1); + } + continue; + } + new_fmls.push_back(j); + } + swap_asserted_formulas(new_fmls); + } +} + void asserted_formulas::apply_quasi_macros() { TRACE("before_quasi_macros", display(tout);); diff --git a/src/smt/asserted_formulas.h b/src/smt/asserted_formulas.h index 69b6bb26a..d56057788 100644 --- a/src/smt/asserted_formulas.h +++ b/src/smt/asserted_formulas.h @@ -160,6 +160,15 @@ class asserted_formulas { void pop(unsigned n) { m_elim.pop(n); } }; + class flatten_clauses_fn : public simplify_fmls { + public: + flatten_clauses_fn(asserted_formulas& af): simplify_fmls(af, "flatten-clauses") {} + void operator()() override { af.flatten_clauses(); } + bool should_apply() const override { return true; } + void simplify(justified_expr const& j, expr_ref& n, proof_ref& p) override { UNREACHABLE(); } + }; + void flatten_clauses(); + #define MK_SIMPLIFIERA(NAME, FUNCTOR, MSG, APP, ARG, REDUCE) \ class NAME : public simplify_fmls { \ FUNCTOR m_functor; \ @@ -198,6 +207,7 @@ class asserted_formulas { propagate_values_fn m_propagate_values; nnf_cnf_fn m_nnf_cnf; apply_quasi_macros_fn m_apply_quasi_macros; + flatten_clauses_fn m_flatten_clauses; bool invoke(simplify_fmls& s); void swap_asserted_formulas(vector& new_fmls); diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 7e607153f..5f71bd4e5 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -1989,22 +1989,37 @@ namespace smt { } /** - \brief Update the index used for backward subsumption. + \brief Remove clause */ - void context::remove_lit_occs(clause const& cls) { - int nbv = get_num_bool_vars(); - for (literal l : cls) { - if (l.var() < nbv) - dec_ref(l); - } - } void context::remove_cls_occs(clause * cls) { remove_watch_literal(cls, 0); remove_watch_literal(cls, 1); - remove_lit_occs(*cls); + remove_lit_occs(*cls, get_num_bool_vars()); } + /** + \brief Update occurrence count of literals + */ + + void context::add_lit_occs(clause const& cls) { + for (literal l : cls) { + inc_ref(l); + } + } + + void context::remove_lit_occs(clause const& cls, unsigned nbv) { + for (literal l : cls) { + if (l.var() < static_cast(nbv)) + dec_ref(l); + } + } + + // TBD: enable as assertion when ready to re-check + void context::dec_ref(literal l) { if (m_lit_occs[l.index()] > 0) m_lit_occs[l.index()]--; } + + void context::inc_ref(literal l) { m_lit_occs[l.index()]++; } + /** \brief Delete the given clause. @@ -2257,7 +2272,7 @@ namespace smt { unsigned num = cls->get_num_literals(); - remove_lit_occs(*cls); + remove_lit_occs(*cls, num_bool_vars); unsigned ilvl = 0; (void)ilvl; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 47d5fbfae..73a091181 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -839,11 +839,11 @@ namespace smt { void mk_ite_cnstr(app * n); - void dec_ref(literal l) { if (m_lit_occs[l.index()] > 0) m_lit_occs[l.index()]--; } + void dec_ref(literal l); - void inc_ref(literal l) { m_lit_occs[l.index()]++; } + void inc_ref(literal l); - void remove_lit_occs(clause const& cls); + void remove_lit_occs(clause const& cls, unsigned num_bool_vars); void add_lit_occs(clause const& cls); public: @@ -1570,8 +1570,6 @@ namespace smt { void get_guessed_literals(expr_ref_vector & result); - bool split_binary(app* o, expr*& a, expr_ref& b, expr_ref& c); - void internalize_assertion(expr * n, proof * pr, unsigned generation); void internalize_proxies(expr_ref_vector const& asms, vector>& asm2proxy); diff --git a/src/smt/smt_context_inv.cpp b/src/smt/smt_context_inv.cpp index de044ffda..8413d06e5 100644 --- a/src/smt/smt_context_inv.cpp +++ b/src/smt/smt_context_inv.cpp @@ -33,7 +33,8 @@ namespace smt { SASSERT(is_watching_clause(~cls->get_literal(0), cls)); SASSERT(is_watching_clause(~cls->get_literal(1), cls)); for (literal l : *cls) { - // currently does not hold: SASSERT(m_lit_occs[l.index()] > 0); + // holds, TBD re-enable when ready to re-check + // SASSERT(m_lit_occs[l.index()] > 0); } return true; } diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index ea4787c02..7917f9976 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -182,27 +182,6 @@ namespace smt { } } - bool context::split_binary(app* o, expr*& a, expr_ref& b, expr_ref& c) { - expr* x = nullptr, *y = nullptr, *ny = nullptr, *z = nullptr, *u = nullptr; - if (m.is_or(o, x, y) && - m.is_not(y, ny) && - m.is_or(ny, z, u)) { - a = x; - b = m.is_not(z, z) ? z : m.mk_not(z); - c = m.is_not(u, u) ? u : m.mk_not(u); - return true; - } - if (m.is_or(o, y, x) && - m.is_not(y, ny) && - m.is_or(ny, z, u)) { - a = x; - b = m.is_not(z, z) ? z : m.mk_not(z); - c = m.is_not(u, u) ? u : m.mk_not(u); - return true; - } - return false; - } - #define DEEP_EXPR_THRESHOLD 1024 @@ -251,16 +230,6 @@ namespace smt { expr* a = nullptr; expr_ref b(m), c(m); // perform light-weight rewriting on clauses. - if (!relevancy() && split_binary(to_app(n), a, b, c)) { - internalize(a, true); - internalize(b, true); - internalize(c, true); - literal lits2[2] = { get_literal(a), get_literal(b) }; - literal lits3[2] = { get_literal(a), get_literal(c) }; - mk_root_clause(2, lits2, pr); - mk_root_clause(2, lits3, pr); - break; - } for (expr * arg : *to_app(n)) { internalize(arg, true); lits.push_back(get_literal(arg)); @@ -1479,12 +1448,6 @@ namespace smt { }} } - void context::add_lit_occs(clause const& cls) { - for (literal l : cls) { - inc_ref(l); - } - } - void context::mk_clause(literal l1, literal l2, justification * j) { literal ls[2] = { l1, l2 }; mk_clause(2, ls, j);