3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-13 12:28:44 +00:00

add get-interpolant command

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2020-01-19 12:12:04 -06:00
parent d3b105f9f8
commit 9179deb746
5 changed files with 127 additions and 32 deletions

View file

@ -31,6 +31,8 @@ Notes:
#include "cmd_context/cmd_util.h"
#include "cmd_context/simplify_cmd.h"
#include "cmd_context/eval_cmd.h"
#include "qe/qe_mbp.h"
#include "qe/qe_mbi.h"
class help_cmd : public cmd {
svector<symbol> m_cmds;
@ -849,6 +851,47 @@ public:
void finalize(cmd_context & ctx) override {}
};
class get_interpolant_cmd : public cmd {
expr* m_a;
expr* m_b;
public:
get_interpolant_cmd():cmd("get-interpolant") {}
char const * get_usage() const override { return "<expr> <expr>"; }
char const * get_descr(cmd_context & ctx) const override { return "perform model based interpolation"; }
unsigned get_arity() const override { return 2; }
cmd_arg_kind next_arg_kind(cmd_context& ctx) const override {
return CPK_EXPR;
}
void set_next_arg(cmd_context& ctx, expr * arg) override {
if (m_a == nullptr)
m_a = arg;
else
m_b = arg;
}
void prepare(cmd_context & ctx) override { m_a = nullptr; m_b = nullptr; }
void execute(cmd_context & ctx) override {
ast_manager& m = ctx.m();
qe::interpolator mbi(m);
expr_ref a(m_a, m);
expr_ref b(m_b, m);
expr_ref itp(m);
solver_factory& sf = ctx.get_solver_factory();
params_ref p;
solver_ref sA = sf(m, p, false /* no proofs */, true, true, symbol::null);
solver_ref sB = sf(m, p, false /* no proofs */, true, true, symbol::null);
solver_ref sNotA = sf(m, p, false /* no proofs */, true, true, symbol::null);
sA->assert_expr(a);
sB->assert_expr(b);
qe::uflia_mbi pA(sA.get(), sNotA.get());
qe::prop_mbi_plugin pB(sB.get());
pA.set_shared(a, b);
pB.set_shared(a, b);
lbool res = mbi.pogo(pA, pB, itp);
ctx.regular_stream() << res << " " << itp << "\n";
}
};
// provides "help" for builtin cmds
class builtin_cmd : public cmd {
char const * m_usage;
@ -898,6 +941,7 @@ void install_ext_basic_cmds(cmd_context & ctx) {
ctx.insert(alloc(echo_cmd));
ctx.insert(alloc(labels_cmd));
ctx.insert(alloc(declare_map_cmd));
ctx.insert(alloc(get_interpolant_cmd));
ctx.insert(alloc(builtin_cmd, "reset", nullptr, "reset the shell (all declarations and assertions will be erased)"));
install_simplify_cmd(ctx);
install_eval_cmd(ctx);

View file

@ -1839,9 +1839,11 @@ namespace algebraic_numbers {
m_compare_sturm++;
upolynomial::scoped_upolynomial_sequence seq(upm());
upm().sturm_tarski_seq(cell_a->m_p_sz, cell_a->m_p, cell_b->m_p_sz, cell_b->m_p, seq);
int V = upm().sign_variations_at(seq, a_lower) - upm().sign_variations_at(seq, a_upper);
unsigned V1 = upm().sign_variations_at(seq, a_lower);
unsigned V2 = upm().sign_variations_at(seq, a_upper);
int V = V1 - V2;
TRACE("algebraic", tout << "comparing using sturm\n"; display_interval(tout, a); tout << "\n"; display_interval(tout, b); tout << "\n";
tout << "V: " << V << ", sign_lower(a): " << sign_lower(cell_a) << ", sign_lower(b): " << sign_lower(cell_b) << "\n";);
tout << "V: " << V << " V1 " << V1 << " V2 " << V2 << " sign_lower(a): " << sign_lower(cell_a) << ", sign_lower(b): " << sign_lower(cell_b) << "\n";);
if (V == 0)
return sign_zero;
if ((V < 0) == (sign_lower(cell_b) < 0))

View file

@ -169,54 +169,58 @@ namespace nlsat {
return new_set;
}
inline int compare_lower_lower(anum_manager & am, interval const & i1, interval const & i2) {
inline ::sign compare_lower_lower(anum_manager & am, interval const & i1, interval const & i2) {
if (i1.m_lower_inf && i2.m_lower_inf)
return 0;
return sign_zero;
if (i1.m_lower_inf)
return -1;
return sign_neg;
if (i2.m_lower_inf)
return 1;
return sign_pos;
SASSERT(!i1.m_lower_inf && !i2.m_lower_inf);
int s = am.compare(i1.m_lower, i2.m_lower);
if (s != 0)
::sign s = am.compare(i1.m_lower, i2.m_lower);
if (!is_zero(s))
return s;
if (i1.m_lower_open == i2.m_lower_open)
return 0;
return sign_zero;
if (i1.m_lower_open)
return 1;
return sign_pos;
else
return -1;
return sign_neg;
}
inline int compare_upper_upper(anum_manager & am, interval const & i1, interval const & i2) {
inline ::sign compare_upper_upper(anum_manager & am, interval const & i1, interval const & i2) {
if (i1.m_upper_inf && i2.m_upper_inf)
return 0;
return sign_zero;
if (i1.m_upper_inf)
return 1;
return sign_pos;
if (i2.m_upper_inf)
return -1;
return sign_neg;
SASSERT(!i1.m_upper_inf && !i2.m_upper_inf);
int s = am.compare(i1.m_upper, i2.m_upper);
if (s != 0)
auto s = am.compare(i1.m_upper, i2.m_upper);
if (!::is_zero(s))
return s;
if (i1.m_upper_open == i2.m_upper_open)
return 0;
return sign_zero;
if (i1.m_upper_open)
return -1;
return sign_neg;
else
return 1;
return sign_pos;
}
inline int compare_upper_lower(anum_manager & am, interval const & i1, interval const & i2) {
if (i1.m_upper_inf || i2.m_lower_inf)
return 1;
inline ::sign compare_upper_lower(anum_manager & am, interval const & i1, interval const & i2) {
if (i1.m_upper_inf || i2.m_lower_inf) {
TRACE("nlsat_interval", nlsat::display(tout << "i1: ", am, i1); nlsat::display(tout << "i2: ", am, i2););
return sign_pos;
}
SASSERT(!i1.m_upper_inf && !i2.m_lower_inf);
int s = am.compare(i1.m_upper, i2.m_lower);
if (s != 0)
auto s = am.compare(i1.m_upper, i2.m_lower);
TRACE("nlsat_interval", nlsat::display(tout << "i1: ", am, i1); nlsat::display(tout << " i2: ", am, i2);
tout << " compare: " << s << "\n";);
if (!::is_zero(s))
return s;
if (!i1.m_upper_open && !i2.m_lower_open)
return 0;
return -1;
return sign_zero;
return sign_neg;
}
typedef sbuffer<interval, 128> interval_buffer;
@ -227,9 +231,9 @@ namespace nlsat {
bool adjacent(anum_manager & am, interval const & curr, interval const & next) {
SASSERT(!curr.m_upper_inf);
SASSERT(!next.m_lower_inf);
int sign = am.compare(curr.m_upper, next.m_lower);
SASSERT(sign <= 0);
if (sign == 0) {
auto sign = am.compare(curr.m_upper, next.m_lower);
SASSERT(sign != sign_pos);
if (is_zero(sign)) {
SASSERT(curr.m_upper_open || next.m_lower_open);
return !curr.m_upper_open || !next.m_lower_open;
}
@ -271,6 +275,15 @@ namespace nlsat {
}
interval_set * interval_set_manager::mk_union(interval_set const * s1, interval_set const * s2) {
#if 0
// issue #2867:
static unsigned s_count = 0;
s_count++;
if (s_count == 8442) {
enable_trace("nlsat_interval");
enable_trace("algebraic");
}
#endif
TRACE("nlsat_interval", tout << "mk_union\ns1: "; display(tout, s1); tout << "\ns2: "; display(tout, s2); tout << "\n";);
if (s1 == nullptr || s1 == s2)
return const_cast<interval_set*>(s2);
@ -421,7 +434,7 @@ namespace nlsat {
// i2 may consume other intervals of s1
}
else {
int u2_l1_sign = compare_upper_lower(m_am, int2, int1);
auto u2_l1_sign = compare_upper_lower(m_am, int2, int1);
if (u2_l1_sign < 0) {
TRACE("nlsat_interval", tout << "l1_l2_sign > 0, u1_u2_sign > 0, u2_l1_sign < 0\n";);
// Case:
@ -430,7 +443,7 @@ namespace nlsat {
push_back(m_am, result, int2);
i2++;
}
else if (u2_l1_sign == 0) {
else if (is_zero(u2_l1_sign)) {
TRACE("nlsat_interval", tout << "l1_l2_sign > 0, u1_u2_sign > 0, u2_l1_sign == 0\n";);
SASSERT(!int1.m_lower_open && !int2.m_upper_open);
SASSERT(!int1.m_lower_inf);

View file

@ -29,6 +29,7 @@ Notes:
--*/
#include "ast/ast_util.h"
#include "ast/ast_pp.h"
#include "ast/for_each_expr.h"
#include "ast/rewriter/expr_safe_replace.h"
#include "ast/rewriter/bool_rewriter.h"
@ -43,6 +44,38 @@ Notes:
namespace qe {
void mbi_plugin::set_shared(expr* a, expr* b) {
TRACE("qe", tout << mk_pp(a, m) << " " << mk_pp(b, m) << "\n";);
struct fun_proc {
obj_hashtable<func_decl> s;
void operator()(app* a) { if (is_uninterp(a)) s.insert(a->get_decl()); }
void operator()(expr*) {}
};
fun_proc symbols_in_a;
expr_fast_mark1 marks;
quick_for_each_expr(symbols_in_a, marks, a);
marks.reset();
m_shared_trail.reset();
m_shared.reset();
m_is_shared.reset();
struct intersect_proc {
mbi_plugin& p;
obj_hashtable<func_decl>& sA;
intersect_proc(mbi_plugin& p, obj_hashtable<func_decl>& sA):p(p), sA(sA) {}
void operator()(app* a) {
func_decl* f = a->get_decl();
if (sA.contains(f) && !p.m_shared.contains(f)) {
p.m_shared_trail.push_back(f);
p.m_shared.insert(f);
}
}
void operator()(expr*) {}
};
intersect_proc symbols_in_b(*this, symbols_in_a.s);
quick_for_each_expr(symbols_in_b, marks, b);
}
lbool mbi_plugin::check(expr_ref_vector& lits, model_ref& mdl) {
while (true) {
switch ((*this)(lits, mdl)) {

View file

@ -21,6 +21,7 @@ Revision History:
#pragma once
#include "qe/qe_arith.h"
#include "util/lbool.h"
namespace qe {
enum mbi_result {
@ -54,6 +55,8 @@ namespace qe {
for (auto* f : vars) m_shared.insert(f);
}
void set_shared(expr* a, expr* b);
/**
* Set representative (shared) expression finder.
*/