From 9179deb7463c707719e2461850a78be41cad69dd Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 19 Jan 2020 12:12:04 -0600 Subject: [PATCH] add get-interpolant command Signed-off-by: Nikolaj Bjorner --- src/cmd_context/basic_cmds.cpp | 44 ++++++++++++++ src/math/polynomial/algebraic_numbers.cpp | 6 +- src/nlsat/nlsat_interval_set.cpp | 73 +++++++++++++---------- src/qe/qe_mbi.cpp | 33 ++++++++++ src/qe/qe_mbi.h | 3 + 5 files changed, 127 insertions(+), 32 deletions(-) diff --git a/src/cmd_context/basic_cmds.cpp b/src/cmd_context/basic_cmds.cpp index 1c4f04ddd..83e3e1a44 100644 --- a/src/cmd_context/basic_cmds.cpp +++ b/src/cmd_context/basic_cmds.cpp @@ -31,6 +31,8 @@ Notes: #include "cmd_context/cmd_util.h" #include "cmd_context/simplify_cmd.h" #include "cmd_context/eval_cmd.h" +#include "qe/qe_mbp.h" +#include "qe/qe_mbi.h" class help_cmd : public cmd { svector m_cmds; @@ -849,6 +851,47 @@ public: void finalize(cmd_context & ctx) override {} }; +class get_interpolant_cmd : public cmd { + expr* m_a; + expr* m_b; +public: + get_interpolant_cmd():cmd("get-interpolant") {} + char const * get_usage() const override { return " "; } + char const * get_descr(cmd_context & ctx) const override { return "perform model based interpolation"; } + unsigned get_arity() const override { return 2; } + cmd_arg_kind next_arg_kind(cmd_context& ctx) const override { + return CPK_EXPR; + } + void set_next_arg(cmd_context& ctx, expr * arg) override { + if (m_a == nullptr) + m_a = arg; + else + m_b = arg; + } + void prepare(cmd_context & ctx) override { m_a = nullptr; m_b = nullptr; } + void execute(cmd_context & ctx) override { + ast_manager& m = ctx.m(); + qe::interpolator mbi(m); + expr_ref a(m_a, m); + expr_ref b(m_b, m); + expr_ref itp(m); + solver_factory& sf = ctx.get_solver_factory(); + params_ref p; + solver_ref sA = sf(m, p, false /* no proofs */, true, true, symbol::null); + solver_ref sB = sf(m, p, false /* no proofs */, true, true, symbol::null); + solver_ref sNotA = sf(m, p, false /* no proofs */, true, true, symbol::null); + sA->assert_expr(a); + sB->assert_expr(b); + qe::uflia_mbi pA(sA.get(), sNotA.get()); + qe::prop_mbi_plugin pB(sB.get()); + pA.set_shared(a, b); + pB.set_shared(a, b); + lbool res = mbi.pogo(pA, pB, itp); + ctx.regular_stream() << res << " " << itp << "\n"; + } +}; + + // provides "help" for builtin cmds class builtin_cmd : public cmd { char const * m_usage; @@ -898,6 +941,7 @@ void install_ext_basic_cmds(cmd_context & ctx) { ctx.insert(alloc(echo_cmd)); ctx.insert(alloc(labels_cmd)); ctx.insert(alloc(declare_map_cmd)); + ctx.insert(alloc(get_interpolant_cmd)); ctx.insert(alloc(builtin_cmd, "reset", nullptr, "reset the shell (all declarations and assertions will be erased)")); install_simplify_cmd(ctx); install_eval_cmd(ctx); diff --git a/src/math/polynomial/algebraic_numbers.cpp b/src/math/polynomial/algebraic_numbers.cpp index c8403e2b1..8e61d7ebb 100644 --- a/src/math/polynomial/algebraic_numbers.cpp +++ b/src/math/polynomial/algebraic_numbers.cpp @@ -1839,9 +1839,11 @@ namespace algebraic_numbers { m_compare_sturm++; upolynomial::scoped_upolynomial_sequence seq(upm()); upm().sturm_tarski_seq(cell_a->m_p_sz, cell_a->m_p, cell_b->m_p_sz, cell_b->m_p, seq); - int V = upm().sign_variations_at(seq, a_lower) - upm().sign_variations_at(seq, a_upper); + unsigned V1 = upm().sign_variations_at(seq, a_lower); + unsigned V2 = upm().sign_variations_at(seq, a_upper); + int V = V1 - V2; TRACE("algebraic", tout << "comparing using sturm\n"; display_interval(tout, a); tout << "\n"; display_interval(tout, b); tout << "\n"; - tout << "V: " << V << ", sign_lower(a): " << sign_lower(cell_a) << ", sign_lower(b): " << sign_lower(cell_b) << "\n";); + tout << "V: " << V << " V1 " << V1 << " V2 " << V2 << " sign_lower(a): " << sign_lower(cell_a) << ", sign_lower(b): " << sign_lower(cell_b) << "\n";); if (V == 0) return sign_zero; if ((V < 0) == (sign_lower(cell_b) < 0)) diff --git a/src/nlsat/nlsat_interval_set.cpp b/src/nlsat/nlsat_interval_set.cpp index 0957ac2cc..69b09e600 100644 --- a/src/nlsat/nlsat_interval_set.cpp +++ b/src/nlsat/nlsat_interval_set.cpp @@ -169,54 +169,58 @@ namespace nlsat { return new_set; } - inline int compare_lower_lower(anum_manager & am, interval const & i1, interval const & i2) { + inline ::sign compare_lower_lower(anum_manager & am, interval const & i1, interval const & i2) { if (i1.m_lower_inf && i2.m_lower_inf) - return 0; + return sign_zero; if (i1.m_lower_inf) - return -1; + return sign_neg; if (i2.m_lower_inf) - return 1; + return sign_pos; SASSERT(!i1.m_lower_inf && !i2.m_lower_inf); - int s = am.compare(i1.m_lower, i2.m_lower); - if (s != 0) + ::sign s = am.compare(i1.m_lower, i2.m_lower); + if (!is_zero(s)) return s; if (i1.m_lower_open == i2.m_lower_open) - return 0; + return sign_zero; if (i1.m_lower_open) - return 1; + return sign_pos; else - return -1; + return sign_neg; } - inline int compare_upper_upper(anum_manager & am, interval const & i1, interval const & i2) { + inline ::sign compare_upper_upper(anum_manager & am, interval const & i1, interval const & i2) { if (i1.m_upper_inf && i2.m_upper_inf) - return 0; + return sign_zero; if (i1.m_upper_inf) - return 1; + return sign_pos; if (i2.m_upper_inf) - return -1; + return sign_neg; SASSERT(!i1.m_upper_inf && !i2.m_upper_inf); - int s = am.compare(i1.m_upper, i2.m_upper); - if (s != 0) + auto s = am.compare(i1.m_upper, i2.m_upper); + if (!::is_zero(s)) return s; if (i1.m_upper_open == i2.m_upper_open) - return 0; + return sign_zero; if (i1.m_upper_open) - return -1; + return sign_neg; else - return 1; + return sign_pos; } - inline int compare_upper_lower(anum_manager & am, interval const & i1, interval const & i2) { - if (i1.m_upper_inf || i2.m_lower_inf) - return 1; + inline ::sign compare_upper_lower(anum_manager & am, interval const & i1, interval const & i2) { + if (i1.m_upper_inf || i2.m_lower_inf) { + TRACE("nlsat_interval", nlsat::display(tout << "i1: ", am, i1); nlsat::display(tout << "i2: ", am, i2);); + return sign_pos; + } SASSERT(!i1.m_upper_inf && !i2.m_lower_inf); - int s = am.compare(i1.m_upper, i2.m_lower); - if (s != 0) + auto s = am.compare(i1.m_upper, i2.m_lower); + TRACE("nlsat_interval", nlsat::display(tout << "i1: ", am, i1); nlsat::display(tout << " i2: ", am, i2); + tout << " compare: " << s << "\n";); + if (!::is_zero(s)) return s; if (!i1.m_upper_open && !i2.m_lower_open) - return 0; - return -1; + return sign_zero; + return sign_neg; } typedef sbuffer interval_buffer; @@ -227,9 +231,9 @@ namespace nlsat { bool adjacent(anum_manager & am, interval const & curr, interval const & next) { SASSERT(!curr.m_upper_inf); SASSERT(!next.m_lower_inf); - int sign = am.compare(curr.m_upper, next.m_lower); - SASSERT(sign <= 0); - if (sign == 0) { + auto sign = am.compare(curr.m_upper, next.m_lower); + SASSERT(sign != sign_pos); + if (is_zero(sign)) { SASSERT(curr.m_upper_open || next.m_lower_open); return !curr.m_upper_open || !next.m_lower_open; } @@ -271,6 +275,15 @@ namespace nlsat { } interval_set * interval_set_manager::mk_union(interval_set const * s1, interval_set const * s2) { +#if 0 + // issue #2867: + static unsigned s_count = 0; + s_count++; + if (s_count == 8442) { + enable_trace("nlsat_interval"); + enable_trace("algebraic"); + } +#endif TRACE("nlsat_interval", tout << "mk_union\ns1: "; display(tout, s1); tout << "\ns2: "; display(tout, s2); tout << "\n";); if (s1 == nullptr || s1 == s2) return const_cast(s2); @@ -421,7 +434,7 @@ namespace nlsat { // i2 may consume other intervals of s1 } else { - int u2_l1_sign = compare_upper_lower(m_am, int2, int1); + auto u2_l1_sign = compare_upper_lower(m_am, int2, int1); if (u2_l1_sign < 0) { TRACE("nlsat_interval", tout << "l1_l2_sign > 0, u1_u2_sign > 0, u2_l1_sign < 0\n";); // Case: @@ -430,7 +443,7 @@ namespace nlsat { push_back(m_am, result, int2); i2++; } - else if (u2_l1_sign == 0) { + else if (is_zero(u2_l1_sign)) { TRACE("nlsat_interval", tout << "l1_l2_sign > 0, u1_u2_sign > 0, u2_l1_sign == 0\n";); SASSERT(!int1.m_lower_open && !int2.m_upper_open); SASSERT(!int1.m_lower_inf); diff --git a/src/qe/qe_mbi.cpp b/src/qe/qe_mbi.cpp index fa6292c13..23ef7229b 100644 --- a/src/qe/qe_mbi.cpp +++ b/src/qe/qe_mbi.cpp @@ -29,6 +29,7 @@ Notes: --*/ #include "ast/ast_util.h" +#include "ast/ast_pp.h" #include "ast/for_each_expr.h" #include "ast/rewriter/expr_safe_replace.h" #include "ast/rewriter/bool_rewriter.h" @@ -43,6 +44,38 @@ Notes: namespace qe { + void mbi_plugin::set_shared(expr* a, expr* b) { + TRACE("qe", tout << mk_pp(a, m) << " " << mk_pp(b, m) << "\n";); + struct fun_proc { + obj_hashtable s; + void operator()(app* a) { if (is_uninterp(a)) s.insert(a->get_decl()); } + void operator()(expr*) {} + }; + fun_proc symbols_in_a; + expr_fast_mark1 marks; + quick_for_each_expr(symbols_in_a, marks, a); + marks.reset(); + m_shared_trail.reset(); + m_shared.reset(); + m_is_shared.reset(); + + struct intersect_proc { + mbi_plugin& p; + obj_hashtable& sA; + intersect_proc(mbi_plugin& p, obj_hashtable& sA):p(p), sA(sA) {} + void operator()(app* a) { + func_decl* f = a->get_decl(); + if (sA.contains(f) && !p.m_shared.contains(f)) { + p.m_shared_trail.push_back(f); + p.m_shared.insert(f); + } + } + void operator()(expr*) {} + }; + intersect_proc symbols_in_b(*this, symbols_in_a.s); + quick_for_each_expr(symbols_in_b, marks, b); + } + lbool mbi_plugin::check(expr_ref_vector& lits, model_ref& mdl) { while (true) { switch ((*this)(lits, mdl)) { diff --git a/src/qe/qe_mbi.h b/src/qe/qe_mbi.h index 0c33c933e..e602f407b 100644 --- a/src/qe/qe_mbi.h +++ b/src/qe/qe_mbi.h @@ -21,6 +21,7 @@ Revision History: #pragma once #include "qe/qe_arith.h" +#include "util/lbool.h" namespace qe { enum mbi_result { @@ -54,6 +55,8 @@ namespace qe { for (auto* f : vars) m_shared.insert(f); } + void set_shared(expr* a, expr* b); + /** * Set representative (shared) expression finder. */