From a7b3dae262c766331452ea89f6f4e05f859039b8 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 3 Jun 2023 11:52:15 -0700 Subject: [PATCH] save state of arith Signed-off-by: Nikolaj Bjorner --- src/ast/CMakeLists.txt | 1 + src/ast/expr_polarities.cpp | 86 +++++++++++++++++++++++++++++++++ src/ast/expr_polarities.h | 52 ++++++++++++++++++++ src/math/lp/int_solver.h | 2 +- src/sat/smt/arith_solver.cpp | 12 ++++- src/sat/smt/euf_internalize.cpp | 4 ++ src/sat/smt/euf_solver.cpp | 4 ++ src/sat/smt/euf_solver.h | 5 ++ src/sat/smt/q_solver.cpp | 4 ++ src/sat/tactic/goal2sat.cpp | 10 +++- src/smt/arith_eq_adapter.cpp | 7 ++- src/smt/smt_context.cpp | 10 ++-- src/smt/smt_context.h | 14 ++++-- src/smt/smt_internalizer.cpp | 2 + src/smt/theory_arith_core.h | 26 +++++++--- src/smt/theory_lra.cpp | 65 ++++++++++++++++++++----- 16 files changed, 267 insertions(+), 37 deletions(-) create mode 100644 src/ast/expr_polarities.cpp create mode 100644 src/ast/expr_polarities.h diff --git a/src/ast/CMakeLists.txt b/src/ast/CMakeLists.txt index 8dd870964..4beecb917 100644 --- a/src/ast/CMakeLists.txt +++ b/src/ast/CMakeLists.txt @@ -25,6 +25,7 @@ z3_add_component(ast expr_abstract.cpp expr_functors.cpp expr_map.cpp + expr_polarities.cpp expr_stat.cpp expr_substitution.cpp for_each_ast.cpp diff --git a/src/ast/expr_polarities.cpp b/src/ast/expr_polarities.cpp new file mode 100644 index 000000000..1c67d5a70 --- /dev/null +++ b/src/ast/expr_polarities.cpp @@ -0,0 +1,86 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + expr_polarities.cpp + +Abstract: + + Extract polarities of expressions. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-06-01 + +--*/ +#pragma once + +#include "ast/expr_polarities.h" + +void expr_polarities::push() { + m_fresh_lim.push_back(m_fresh.size()); + m_trail_lim.push_back(m_trail.size()); +} + +void expr_polarities::pop(unsigned n) { + unsigned sz = m_fresh_lim[m_fresh_lim.size() - n]; + for (unsigned i = m_fresh_lim.size(); --i > sz; ) { + auto const & [e, p] = m_fresh[i]; + if (p) + m_pos.mark(e, false); + else + m_neg.mark(e, false); + } + m_fresh.shrink(sz); + m_fresh_lim.shrink(m_fresh_lim.size() - n); + sz = m_trail_lim[m_trail_lim.size() - n]; + m_trail.shrink(sz); + m_trail_lim.shrink(m_trail_lim.size() - n); +} + + +void expr_polarities::add(expr* e) { + if (m_pos.is_marked(e)) + return; + m_trail.push_back(e); + buffer> frames; + frames.push_back({true, e}); + while (!frames.empty()) { + auto [p, e] = frames.back(); + frames.pop_back(); + if (p) { + if (m_pos.is_marked(e)) + continue; + m_pos.mark(e, true); + } + else { + if (m_neg.is_marked(e)) + continue; + m_neg.mark(e, true); + } + m_fresh.push_back({e, p}); + if (m.is_and(e) || m.is_or(e)) { + for (expr* arg : *to_app(e)) + frames.push_back({p, arg}); + } + else if (m.is_not(e, e)) { + frames.push_back({!p, e}); + } + else if (m.is_implies(e)) { + frames.push_back({!p, to_app(e)->get_arg(0)}); + frames.push_back({p, to_app(e)->get_arg(1)}); + } + else if (is_app(e)) { + for (expr* arg : *to_app(e)) { + frames.push_back({true, arg}); + frames.push_back({false, arg}); + } + } + else if (is_quantifier(e)) + frames.push_back({p, to_quantifier(e)->get_expr()}); + } +} + + + diff --git a/src/ast/expr_polarities.h b/src/ast/expr_polarities.h new file mode 100644 index 000000000..0231b9337 --- /dev/null +++ b/src/ast/expr_polarities.h @@ -0,0 +1,52 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + expr_polarities.h + +Abstract: + + Extract polarities of expressions. + +Author: + + Nikolaj Bjorner (nbjorner) 2023-06-01 + +--*/ +#pragma once + +#include "ast/ast.h" + +class expr_polarities { + ast_manager & m; + expr_ref_vector m_trail; + expr_mark m_pos, m_neg; + vector> m_fresh; + unsigned_vector m_fresh_lim, m_trail_lim; +public: + expr_polarities(ast_manager & m) : m(m), m_trail(m) {} + + void push(); + + void pop(unsigned n); + + // add expressions to annotate with polarities + void add(expr* e); + + void add(expr_ref_vector const& es) { + for (expr* e : es) + add(e); + } + + bool has_negative(expr* e) const { + return m_neg.is_marked(e); + } + + bool has_positive(expr* e) const { + return m_pos.is_marked(e); + } + +}; + + diff --git a/src/math/lp/int_solver.h b/src/math/lp/int_solver.h index 94802c648..e128871f5 100644 --- a/src/math/lp/int_solver.h +++ b/src/math/lp/int_solver.h @@ -51,7 +51,7 @@ class int_solver { public: patcher(int_solver& lia); bool should_apply() const { return true; } - lia_move operator()() { return patch_basic_columns(); } + lia_move operator()() { return patch_nbasic_columns(); } void patch_nbasic_column(unsigned j); bool patch_basic_column(unsigned v, row_cell const& c); void patch_basic_column(unsigned j); diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index bd5dd315f..f4693e1e8 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -46,9 +46,17 @@ namespace arith { del_bounds(0); } - void solver::asserted(literal l) { + void solver::asserted(literal lit) { force_push(); - m_asserted.push_back(l); + expr* e = ctx.bool_var2expr(lit.var()); + if (lit.sign() && !ctx.is_neg(e)) { + //verbose_stream() << "not negative " << mk_pp(e, m) << "\n"; + } + else if (!lit.sign() && !ctx.is_pos(e)) { + //verbose_stream() << "not positive " << mk_pp(e, m) << "\n"; + } + else + m_asserted.push_back(lit); } euf::th_solver* solver::clone(euf::solver& dst_ctx) { diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index a1d383e45..c62eccb87 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -34,6 +34,10 @@ Notes: namespace euf { + void solver::add_polarities(expr* f) { + m_polarities.add(f); + } + void solver::internalize(expr* e) { if (get_enode(e)) return; diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 0180fcc60..3f4286d61 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -48,6 +48,7 @@ namespace euf { m_unhandled_functions(m), m_to_m(&m), m_to_si(&si), + m_polarities(m), m_clause_visitor(m), m_smt_proof_checker(m, p), m_clause(m), @@ -642,10 +643,13 @@ namespace euf { e->push(); m_egraph.push(); m_relevancy.push(); + m_polarities.push(); + } void solver::pop(unsigned n) { start_reinit(n); + m_polarities.pop(n); m_trail.pop_scope(n); for (auto* e : m_solvers) e->pop(n); diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 72776b7ff..206d40d36 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -20,6 +20,7 @@ Author: #include "util/trail.h" #include "ast/ast_translation.h" #include "ast/ast_util.h" +#include "ast/expr_polarities.h" #include "ast/euf/euf_egraph.h" #include "ast/rewriter/th_rewriter.h" #include "ast/converters/model_converter.h" @@ -152,6 +153,7 @@ namespace euf { svector m_scopes; scoped_ptr_vector m_solvers; ptr_vector m_id2solver; + expr_polarities m_polarities; constraint* m_conflict = nullptr; constraint* m_eq = nullptr; @@ -451,6 +453,9 @@ namespace euf { bool to_formulas(std::function& l2e, expr_ref_vector& fmls) override; // internalize + void add_polarities(expr* f); + bool is_neg(expr* e) const { return m_polarities.has_negative(e); } + bool is_pos(expr* e) const { return m_polarities.has_positive(e); } sat::literal internalize(expr* e, bool sign, bool root) override; void internalize(expr* e) override; sat::literal mk_literal(expr* e); diff --git a/src/sat/smt/q_solver.cpp b/src/sat/smt/q_solver.cpp index f49611660..54045cc68 100644 --- a/src/sat/smt/q_solver.cpp +++ b/src/sat/smt/q_solver.cpp @@ -47,6 +47,10 @@ namespace q { if (l.sign() == is_forall(e)) { if (m_quantifiers_are_positive && is_forall(e)) return; + if (is_forall(e) && ctx.is_pos(e)) { + verbose_stream() << "is positive\n"; + return; + } sat::literal lit = skolemize(q); add_clause(~l, lit); return; diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 267888804..7bc1a85d7 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -224,7 +224,7 @@ struct goal2sat::imp : public sat::sat_internalizer { return v; } - unsigned m_num_scopes{ 0 }; + unsigned m_num_scopes = 0; void force_push() { for (; m_num_scopes > 0; --m_num_scopes) { @@ -785,6 +785,7 @@ struct goal2sat::imp : public sat::sat_internalizer { }; void process(expr* n, bool is_root) { + TRACE("goal2sat", tout << "process-begin " << mk_bounded_pp(n, m, 2) << " root: " << is_root << " result-stack: " << m_result_stack.size() @@ -840,6 +841,9 @@ struct goal2sat::imp : public sat::sat_internalizer { } sat::literal internalize(expr* n) override { + auto* ext = dynamic_cast(m_solver.get_extension()); + if (ext) + ext->add_polarities(n); bool is_not = m.is_not(n, n); flet _top(m_top_level, false); unsigned sz = m_result_stack.size(); @@ -889,6 +893,10 @@ struct goal2sat::imp : public sat::sat_internalizer { } void process(expr * n) { + auto* ext = dynamic_cast(m_solver.get_extension()); + if (ext) + ext->add_polarities(n); + flet _top(m_top_level, true); VERIFY(m_result_stack.empty()); TRACE("goal2sat", tout << "assert: " << mk_bounded_pp(n, m, 3) << "\n";); diff --git a/src/smt/arith_eq_adapter.cpp b/src/smt/arith_eq_adapter.cpp index b77a38927..cfad30eb1 100644 --- a/src/smt/arith_eq_adapter.cpp +++ b/src/smt/arith_eq_adapter.cpp @@ -258,12 +258,11 @@ namespace smt { TRACE("arith_eq_adapter", tout << "restart\n";); enode_pair_vector tmp(m_restart_pairs); m_restart_pairs.reset(); - for (auto const& p : tmp) { + for (auto const& [a, b] : tmp) { if (ctx.inconsistent()) break; - TRACE("arith_eq_adapter", tout << "creating arith_eq_adapter axioms at the base level #" << p.first->get_owner_id() << " #" << - p.second->get_owner_id() << "\n";); - mk_axioms(p.first, p.second); + TRACE("arith_eq_adapter", tout << "creating arith_eq_adapter axioms at the base level #" << a->get_owner_id() << " #" << b->get_owner_id() << "\n";); + mk_axioms(a, b); } } diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index d84535a9c..0a71f5b6d 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -59,12 +59,9 @@ namespace smt { m_relevancy_propagator(mk_relevancy_propagator(*this)), m_user_propagator(nullptr), m_random(p.m_random_seed), - m_flushing(false), - m_lemma_id(0), - m_progress_callback(nullptr), - m_next_progress_sample(0), m_clause_proof(*this), m_fingerprints(m, m_region), + m_polarities(m), m_b_internalized_stack(m), m_e_internalized_stack(m), m_l_internalized_stack(m), @@ -2975,6 +2972,7 @@ namespace smt { // logical context became inconsistent during user PUSH VERIFY(!resolve_conflict()); // build the proof } + m_polarities.push(); push_scope(); m_base_scopes.push_back(base_scope()); base_scope & bs = m_base_scopes.back(); @@ -2988,9 +2986,11 @@ namespace smt { void context::pop(unsigned num_scopes) { SASSERT (num_scopes > 0); - if (num_scopes > m_scope_lvl) return; + if (num_scopes > m_scope_lvl) + return; pop_to_base_lvl(); pop_scope(num_scopes); + m_polarities.pop(num_scopes); } /** diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 7a267fdec..85730f20e 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -19,6 +19,7 @@ Revision History: #pragma once #include "ast/quantifier_stat.h" +#include "ast/expr_polarities.h" #include "smt/smt_clause.h" #include "smt/smt_setup.h" #include "smt/smt_enode.h" @@ -92,13 +93,14 @@ namespace smt { scoped_ptr m_relevancy_propagator; theory_user_propagator* m_user_propagator; random_gen m_random; - bool m_flushing; // (debug support) true when flushing - mutable unsigned m_lemma_id; - progress_callback * m_progress_callback; - unsigned m_next_progress_sample; + bool m_flushing = false; // (debug support) true when flushing + mutable unsigned m_lemma_id = 0; + progress_callback * m_progress_callback = nullptr; + unsigned m_next_progress_sample = 0; clause_proof m_clause_proof; region m_region; fingerprint_set m_fingerprints; + expr_polarities m_polarities; expr_ref_vector m_b_internalized_stack; // stack of the boolean expressions already internalized. // Remark: boolean expressions can also be internalized as @@ -1550,6 +1552,7 @@ namespace smt { // ----------------------------------- void assert_expr_core(expr * e, proof * pr); + // copy plugins into a fresh context. void copy_plugins(context& src, context& dst); @@ -1626,6 +1629,9 @@ namespace smt { void assert_expr(expr * e, proof * pr); + bool is_positive(expr* e) const { return m_polarities.has_positive(e); } + bool is_negative(expr* e) const { return m_polarities.has_negative(e); } + void internalize_assertions(); void push(); diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index 68879b8ac..2349a9b4d 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -392,6 +392,8 @@ namespace smt { if (m.is_true(n) || m.is_false(n)) return; + m_polarities.add(n); + if (m.is_not(n) && gate_ctx) { // a boolean variable does not need to be created if n a NOT gate is in // the context of a gate. diff --git a/src/smt/theory_arith_core.h b/src/smt/theory_arith_core.h index 2e111fa99..300650a53 100644 --- a/src/smt/theory_arith_core.h +++ b/src/smt/theory_arith_core.h @@ -1379,6 +1379,18 @@ namespace smt { TRACE("arith_verbose", tout << "p" << v << " := " << (is_true?"true":"false") << "\n";); atom * a = get_bv2a(v); if (!a) return; + +#if 0 + expr* f = ctx.bool_var2expr(v); + if (ctx.is_positive(f) || ctx.is_negative(f)) { + if (is_true && !ctx.is_positive(f)) { + return; + } + if (!is_true && !ctx.is_negative(f)) { + return; + } + } +#endif SASSERT(ctx.get_assignment(a->get_bool_var()) != l_undef); SASSERT((ctx.get_assignment(a->get_bool_var()) == l_true) == is_true); a->assign_eh(is_true, get_epsilon(a->get_var())); @@ -1491,10 +1503,6 @@ namespace smt { final_check_status result = FC_DONE; final_check_status ok; - //display(verbose_stream()); - //exit(0); - - if (false) { verbose_stream() << "final\n"; @@ -1514,6 +1522,7 @@ namespace smt { switch (m_final_check_idx) { case 0: ok = check_int_feasibility(); + if (ok != FC_DONE) verbose_stream() << "int-feas\n"; TRACE("arith", tout << "check_int_feasibility(), ok: " << ok << "\n";); break; case 1: @@ -1521,6 +1530,7 @@ namespace smt { ok = FC_CONTINUE; else ok = FC_DONE; + if (ok != FC_DONE) verbose_stream() << "assume-eqs\n"; TRACE("arith", tout << "assume_eqs(), ok: " << ok << "\n";); break; default: @@ -1557,7 +1567,7 @@ namespace smt { template final_check_status theory_arith::final_check_eh() { - // verbose_stream() << "final " << ctx.get_scope_level() << " " << ctx.assigned_literals().size() << "\n"; + verbose_stream() << "final " << ctx.get_scope_level() << " " << ctx.assigned_literals().size() << "\n"; // ctx.display(verbose_stream()); // exit(0); @@ -1566,14 +1576,14 @@ namespace smt { if (!propagate_core()) return FC_CONTINUE; - if (delayed_assume_eqs()) + if (delayed_assume_eqs()) { + verbose_stream() << "delayed-eqs\n"; return FC_CONTINUE; + } ctx.push_trail(value_trail(m_final_check_idx)); m_liberal_final_check = true; m_changed_assignment = false; final_check_status result = final_check_core(); - //display(verbose_stream()); - //exit(0); if (result != FC_DONE) return result; if (!m_changed_assignment) diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 3b7494cba..10a2204c2 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -1255,7 +1255,6 @@ public: } void internalize_eq_eh(app * atom, bool_var) { - return; if (!ctx().get_fparams().m_arith_eager_eq_axioms) return; expr* lhs = nullptr, *rhs = nullptr; @@ -1265,15 +1264,26 @@ public: if (is_arith(n1) && is_arith(n2) && n1->get_th_var(get_id()) != null_theory_var && n2->get_th_var(get_id()) != null_theory_var && n1 != n2) { - verbose_stream() << "ineq\n"; + // verbose_stream() << "ineq\n"; m_arith_eq_adapter.mk_axioms(n1, n2); } - else - verbose_stream() << "skip\n"; + // else + // verbose_stream() << "skip\n"; } void assign_eh(bool_var v, bool is_true) { TRACE("arith", tout << mk_bounded_pp(ctx().bool_var2expr(v), m) << " " << (literal(v, !is_true)) << "\n";); +#if 0 + expr* f = ctx().bool_var2expr(v); + if (ctx().is_positive(f) || ctx().is_negative(f)) { + if (is_true && !ctx().is_positive(f)) { + return; + } + if (!is_true && !ctx().is_negative(f)) { + return; + } + } +#endif m_asserted_atoms.push_back(delayed_atom(v, is_true)); } @@ -1305,7 +1315,11 @@ public: if (!is_int(v1) && !is_real(v1)) return; - if (true) { + if (false) { + m_deqs.push_back({v1, v2}); + ctx().push_trail(push_back_vector(m_deqs)); + } + else if (false) { enode* n1 = get_enode(v1); enode* n2 = get_enode(v2); lpvar w1 = register_theory_var_in_lar_solver(v1); @@ -1342,6 +1356,7 @@ public: m_arith_eq_adapter.new_diseq_eh(v1, v2); } +// actually have to go over entire stack to ensure diseqs are respected after mutations bool delayed_diseqs() { if (m_diseqs_qhead == m_diseqs.size()) return false; @@ -1360,6 +1375,28 @@ public: return has_eq; } + svector> m_deqs; + unsigned m_eqs_qhead = 0; +// actually have to go over entire stack to ensure eqs are respected after mutations + bool delayed_eqs() { + if (m_eqs_qhead == m_deqs.size()) + return false; + ctx().push_trail(value_trail(m_eqs_qhead)); + bool has_eq = false; + while (m_eqs_qhead < m_deqs.size()) { + auto [v1,v2] = m_deqs[m_eqs_qhead]; + if (!is_eq(v1, v2)) { + //verbose_stream() << "bad diseq " << m_diseqs_qhead << "\n"; + m_arith_eq_adapter.new_eq_eh(v1, v2); + has_eq = true; + } + ++m_eqs_qhead; + } + return has_eq; + } + + + void apply_sort_cnstr(enode* n, sort*) { TRACE("arith", tout << "sort constraint: " << enode_pp(n, ctx()) << "\n";); #if 0 @@ -1845,7 +1882,10 @@ public: if (delayed_diseqs()) return true; - + + if (delayed_eqs()) + return true; + theory_var sz = static_cast(th.get_num_vars()); unsigned old_sz = m_assume_eq_candidates.size(); @@ -1989,7 +2029,7 @@ public: final_check_status final_check_eh() { - // verbose_stream() << "final " << ctx().get_scope_level() << " " << ctx().assigned_literals().size() << "\n"; + verbose_stream() << "final " << ctx().get_scope_level() << " " << ctx().assigned_literals().size() << "\n"; //ctx().display(verbose_stream()); //exit(0); @@ -1998,14 +2038,14 @@ public: if (propagate_core()) return FC_CONTINUE; - if (delayed_assume_eqs()) + if (delayed_assume_eqs()) { + verbose_stream() << "delayed-eqs\n"; return FC_CONTINUE; + } m_liberal_final_check = true; m_changed_assignment = false; ctx().push_trail(value_trail(m_final_check_idx)); final_check_status result = final_check_core(); - //display(verbose_stream()); - //exit(0); if (result != FC_DONE) return result; if (!m_changed_assignment) @@ -2019,7 +2059,6 @@ public: final_check_status final_check_core() { - if (false) { verbose_stream() << "final\n"; @@ -2094,10 +2133,12 @@ public: switch (m_final_check_idx) { case 0: st = check_lia(); + if (st != FC_DONE) verbose_stream() << "check-lia\n"; break; case 1: if (assume_eqs()) - st = FC_CONTINUE; + st = FC_CONTINUE; + if (st != FC_DONE) verbose_stream() << "assume-eqs\n"; break; case 2: st = check_nla();