mirror of
				https://github.com/Z3Prover/z3
				synced 2025-11-03 21:09:11 +00:00 
			
		
		
		
	add get-interpolant command
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
		
							parent
							
								
									d3b105f9f8
								
							
						
					
					
						commit
						9179deb746
					
				
					 5 changed files with 127 additions and 32 deletions
				
			
		| 
						 | 
				
			
			@ -31,6 +31,8 @@ Notes:
 | 
			
		|||
#include "cmd_context/cmd_util.h"
 | 
			
		||||
#include "cmd_context/simplify_cmd.h"
 | 
			
		||||
#include "cmd_context/eval_cmd.h"
 | 
			
		||||
#include "qe/qe_mbp.h"
 | 
			
		||||
#include "qe/qe_mbi.h"
 | 
			
		||||
 | 
			
		||||
class help_cmd : public cmd {
 | 
			
		||||
    svector<symbol> m_cmds;
 | 
			
		||||
| 
						 | 
				
			
			@ -849,6 +851,47 @@ public:
 | 
			
		|||
    void finalize(cmd_context & ctx) override {}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class get_interpolant_cmd : public cmd {
 | 
			
		||||
    expr* m_a;
 | 
			
		||||
    expr* m_b;
 | 
			
		||||
public:
 | 
			
		||||
    get_interpolant_cmd():cmd("get-interpolant") {}
 | 
			
		||||
    char const * get_usage() const override { return "<expr> <expr>"; }
 | 
			
		||||
    char const * get_descr(cmd_context & ctx) const override { return "perform model based interpolation"; }
 | 
			
		||||
    unsigned get_arity() const override { return 2; }
 | 
			
		||||
    cmd_arg_kind next_arg_kind(cmd_context& ctx) const override {
 | 
			
		||||
        return CPK_EXPR; 
 | 
			
		||||
    }
 | 
			
		||||
    void set_next_arg(cmd_context& ctx, expr * arg) override { 
 | 
			
		||||
        if (m_a == nullptr) 
 | 
			
		||||
            m_a = arg; 
 | 
			
		||||
        else 
 | 
			
		||||
            m_b = arg; 
 | 
			
		||||
    }
 | 
			
		||||
    void prepare(cmd_context & ctx) override { m_a = nullptr; m_b = nullptr;  }
 | 
			
		||||
    void execute(cmd_context & ctx) override { 
 | 
			
		||||
        ast_manager& m = ctx.m();
 | 
			
		||||
        qe::interpolator mbi(m);
 | 
			
		||||
        expr_ref a(m_a, m);
 | 
			
		||||
        expr_ref b(m_b, m);
 | 
			
		||||
        expr_ref itp(m);
 | 
			
		||||
        solver_factory& sf = ctx.get_solver_factory();
 | 
			
		||||
        params_ref p;
 | 
			
		||||
        solver_ref sA = sf(m, p, false /* no proofs */, true, true, symbol::null);
 | 
			
		||||
        solver_ref sB = sf(m, p, false /* no proofs */, true, true, symbol::null);
 | 
			
		||||
        solver_ref sNotA = sf(m, p, false /* no proofs */, true, true, symbol::null);
 | 
			
		||||
        sA->assert_expr(a);
 | 
			
		||||
        sB->assert_expr(b);
 | 
			
		||||
        qe::uflia_mbi pA(sA.get(), sNotA.get());
 | 
			
		||||
        qe::prop_mbi_plugin pB(sB.get());
 | 
			
		||||
        pA.set_shared(a, b);
 | 
			
		||||
        pB.set_shared(a, b);
 | 
			
		||||
        lbool res = mbi.pogo(pA, pB, itp);
 | 
			
		||||
        ctx.regular_stream() << res << " " << itp << "\n";
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// provides "help" for builtin cmds
 | 
			
		||||
class builtin_cmd : public cmd {
 | 
			
		||||
    char const * m_usage;
 | 
			
		||||
| 
						 | 
				
			
			@ -898,6 +941,7 @@ void install_ext_basic_cmds(cmd_context & ctx) {
 | 
			
		|||
    ctx.insert(alloc(echo_cmd));
 | 
			
		||||
    ctx.insert(alloc(labels_cmd));
 | 
			
		||||
    ctx.insert(alloc(declare_map_cmd));
 | 
			
		||||
    ctx.insert(alloc(get_interpolant_cmd));
 | 
			
		||||
    ctx.insert(alloc(builtin_cmd, "reset", nullptr, "reset the shell (all declarations and assertions will be erased)"));
 | 
			
		||||
    install_simplify_cmd(ctx);
 | 
			
		||||
    install_eval_cmd(ctx);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1839,9 +1839,11 @@ namespace algebraic_numbers {
 | 
			
		|||
           m_compare_sturm++;
 | 
			
		||||
           upolynomial::scoped_upolynomial_sequence seq(upm());
 | 
			
		||||
           upm().sturm_tarski_seq(cell_a->m_p_sz, cell_a->m_p, cell_b->m_p_sz, cell_b->m_p, seq);
 | 
			
		||||
           int V = upm().sign_variations_at(seq, a_lower) - upm().sign_variations_at(seq, a_upper);
 | 
			
		||||
           unsigned V1 = upm().sign_variations_at(seq, a_lower);
 | 
			
		||||
           unsigned V2 = upm().sign_variations_at(seq, a_upper); 
 | 
			
		||||
           int V =  V1 - V2;
 | 
			
		||||
           TRACE("algebraic", tout << "comparing using sturm\n"; display_interval(tout, a); tout << "\n"; display_interval(tout, b); tout << "\n";
 | 
			
		||||
                 tout << "V: " << V << ", sign_lower(a): " << sign_lower(cell_a) << ", sign_lower(b): " << sign_lower(cell_b) << "\n";);
 | 
			
		||||
                 tout << "V: " << V << " V1 " << V1 << " V2 " << V2 << " sign_lower(a): " << sign_lower(cell_a) << ", sign_lower(b): " << sign_lower(cell_b) << "\n";);
 | 
			
		||||
           if (V == 0)
 | 
			
		||||
               return sign_zero;
 | 
			
		||||
           if ((V < 0) == (sign_lower(cell_b) < 0))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -169,54 +169,58 @@ namespace nlsat {
 | 
			
		|||
        return new_set;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    inline int compare_lower_lower(anum_manager & am, interval const & i1, interval const & i2) {
 | 
			
		||||
    inline ::sign compare_lower_lower(anum_manager & am, interval const & i1, interval const & i2) {
 | 
			
		||||
        if (i1.m_lower_inf && i2.m_lower_inf)
 | 
			
		||||
            return 0;
 | 
			
		||||
            return sign_zero;
 | 
			
		||||
        if (i1.m_lower_inf)
 | 
			
		||||
            return -1;
 | 
			
		||||
            return sign_neg;
 | 
			
		||||
        if (i2.m_lower_inf)
 | 
			
		||||
            return 1;
 | 
			
		||||
            return sign_pos;
 | 
			
		||||
        SASSERT(!i1.m_lower_inf && !i2.m_lower_inf);
 | 
			
		||||
        int s = am.compare(i1.m_lower, i2.m_lower);
 | 
			
		||||
        if (s != 0)
 | 
			
		||||
        ::sign s = am.compare(i1.m_lower, i2.m_lower);
 | 
			
		||||
        if (!is_zero(s)) 
 | 
			
		||||
            return s;
 | 
			
		||||
        if (i1.m_lower_open == i2.m_lower_open)
 | 
			
		||||
            return 0;
 | 
			
		||||
            return sign_zero;
 | 
			
		||||
        if (i1.m_lower_open)
 | 
			
		||||
            return 1;
 | 
			
		||||
            return sign_pos;
 | 
			
		||||
        else
 | 
			
		||||
            return -1;
 | 
			
		||||
            return sign_neg;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    inline int compare_upper_upper(anum_manager & am, interval const & i1, interval const & i2) {
 | 
			
		||||
    inline ::sign compare_upper_upper(anum_manager & am, interval const & i1, interval const & i2) {
 | 
			
		||||
        if (i1.m_upper_inf && i2.m_upper_inf)
 | 
			
		||||
            return 0;
 | 
			
		||||
            return sign_zero;
 | 
			
		||||
        if (i1.m_upper_inf)
 | 
			
		||||
            return 1;
 | 
			
		||||
            return sign_pos;
 | 
			
		||||
        if (i2.m_upper_inf)
 | 
			
		||||
            return -1;
 | 
			
		||||
            return sign_neg;
 | 
			
		||||
        SASSERT(!i1.m_upper_inf && !i2.m_upper_inf);
 | 
			
		||||
        int s = am.compare(i1.m_upper, i2.m_upper);
 | 
			
		||||
        if (s != 0)
 | 
			
		||||
        auto s = am.compare(i1.m_upper, i2.m_upper);
 | 
			
		||||
        if (!::is_zero(s)) 
 | 
			
		||||
            return s;
 | 
			
		||||
        if (i1.m_upper_open == i2.m_upper_open)
 | 
			
		||||
            return 0;
 | 
			
		||||
            return sign_zero;
 | 
			
		||||
        if (i1.m_upper_open)
 | 
			
		||||
            return -1;
 | 
			
		||||
            return sign_neg;
 | 
			
		||||
        else
 | 
			
		||||
            return 1;
 | 
			
		||||
            return sign_pos;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    inline int compare_upper_lower(anum_manager & am, interval const & i1, interval const & i2) {
 | 
			
		||||
        if (i1.m_upper_inf || i2.m_lower_inf)
 | 
			
		||||
            return 1;
 | 
			
		||||
    inline ::sign compare_upper_lower(anum_manager & am, interval const & i1, interval const & i2) {
 | 
			
		||||
        if (i1.m_upper_inf || i2.m_lower_inf) {
 | 
			
		||||
            TRACE("nlsat_interval", nlsat::display(tout << "i1: ", am, i1); nlsat::display(tout << "i2: ", am, i2););
 | 
			
		||||
            return sign_pos;
 | 
			
		||||
        }
 | 
			
		||||
        SASSERT(!i1.m_upper_inf && !i2.m_lower_inf);
 | 
			
		||||
        int s = am.compare(i1.m_upper, i2.m_lower);
 | 
			
		||||
        if (s != 0)
 | 
			
		||||
        auto s = am.compare(i1.m_upper, i2.m_lower);
 | 
			
		||||
        TRACE("nlsat_interval", nlsat::display(tout << "i1: ", am, i1); nlsat::display(tout << " i2: ", am, i2); 
 | 
			
		||||
              tout << " compare: " << s << "\n";);
 | 
			
		||||
        if (!::is_zero(s))
 | 
			
		||||
            return s;
 | 
			
		||||
        if (!i1.m_upper_open && !i2.m_lower_open)
 | 
			
		||||
            return 0;
 | 
			
		||||
        return -1;
 | 
			
		||||
            return sign_zero;
 | 
			
		||||
        return sign_neg;
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    typedef sbuffer<interval, 128> interval_buffer;
 | 
			
		||||
| 
						 | 
				
			
			@ -227,9 +231,9 @@ namespace nlsat {
 | 
			
		|||
    bool adjacent(anum_manager & am, interval const & curr, interval const & next) {
 | 
			
		||||
        SASSERT(!curr.m_upper_inf);
 | 
			
		||||
        SASSERT(!next.m_lower_inf);
 | 
			
		||||
        int sign = am.compare(curr.m_upper, next.m_lower);
 | 
			
		||||
        SASSERT(sign <= 0);
 | 
			
		||||
        if (sign == 0) {
 | 
			
		||||
        auto sign = am.compare(curr.m_upper, next.m_lower);
 | 
			
		||||
        SASSERT(sign != sign_pos);
 | 
			
		||||
        if (is_zero(sign)) {
 | 
			
		||||
            SASSERT(curr.m_upper_open || next.m_lower_open);
 | 
			
		||||
            return !curr.m_upper_open || !next.m_lower_open;
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			@ -271,6 +275,15 @@ namespace nlsat {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    interval_set * interval_set_manager::mk_union(interval_set const * s1, interval_set const * s2) {
 | 
			
		||||
#if 0
 | 
			
		||||
        // issue #2867:
 | 
			
		||||
        static unsigned s_count = 0;
 | 
			
		||||
        s_count++;
 | 
			
		||||
        if (s_count == 8442) {
 | 
			
		||||
            enable_trace("nlsat_interval");
 | 
			
		||||
            enable_trace("algebraic");
 | 
			
		||||
        }
 | 
			
		||||
#endif
 | 
			
		||||
        TRACE("nlsat_interval", tout << "mk_union\ns1: "; display(tout, s1); tout << "\ns2: "; display(tout, s2); tout << "\n";);
 | 
			
		||||
        if (s1 == nullptr || s1 == s2)
 | 
			
		||||
            return const_cast<interval_set*>(s2);
 | 
			
		||||
| 
						 | 
				
			
			@ -421,7 +434,7 @@ namespace nlsat {
 | 
			
		|||
                    // i2 may consume other intervals of s1
 | 
			
		||||
                }
 | 
			
		||||
                else {
 | 
			
		||||
                    int u2_l1_sign = compare_upper_lower(m_am, int2, int1);
 | 
			
		||||
                    auto u2_l1_sign = compare_upper_lower(m_am, int2, int1);
 | 
			
		||||
                    if (u2_l1_sign < 0) {
 | 
			
		||||
                        TRACE("nlsat_interval", tout << "l1_l2_sign > 0, u1_u2_sign > 0, u2_l1_sign < 0\n";);
 | 
			
		||||
                        // Case:
 | 
			
		||||
| 
						 | 
				
			
			@ -430,7 +443,7 @@ namespace nlsat {
 | 
			
		|||
                        push_back(m_am, result, int2);
 | 
			
		||||
                        i2++;
 | 
			
		||||
                    }
 | 
			
		||||
                    else if (u2_l1_sign == 0) {
 | 
			
		||||
                    else if (is_zero(u2_l1_sign)) {
 | 
			
		||||
                        TRACE("nlsat_interval", tout << "l1_l2_sign > 0, u1_u2_sign > 0, u2_l1_sign == 0\n";);
 | 
			
		||||
                        SASSERT(!int1.m_lower_open && !int2.m_upper_open);
 | 
			
		||||
                        SASSERT(!int1.m_lower_inf);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -29,6 +29,7 @@ Notes:
 | 
			
		|||
--*/
 | 
			
		||||
 | 
			
		||||
#include "ast/ast_util.h"
 | 
			
		||||
#include "ast/ast_pp.h"
 | 
			
		||||
#include "ast/for_each_expr.h"
 | 
			
		||||
#include "ast/rewriter/expr_safe_replace.h"
 | 
			
		||||
#include "ast/rewriter/bool_rewriter.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -43,6 +44,38 @@ Notes:
 | 
			
		|||
 | 
			
		||||
namespace qe {
 | 
			
		||||
 | 
			
		||||
    void mbi_plugin::set_shared(expr* a, expr* b) {
 | 
			
		||||
        TRACE("qe", tout << mk_pp(a, m) << " " << mk_pp(b, m) << "\n";);
 | 
			
		||||
        struct fun_proc {
 | 
			
		||||
            obj_hashtable<func_decl> s;
 | 
			
		||||
            void operator()(app* a) { if (is_uninterp(a)) s.insert(a->get_decl()); }
 | 
			
		||||
            void operator()(expr*) {}
 | 
			
		||||
        };
 | 
			
		||||
        fun_proc symbols_in_a;
 | 
			
		||||
        expr_fast_mark1 marks;
 | 
			
		||||
        quick_for_each_expr(symbols_in_a, marks, a);
 | 
			
		||||
        marks.reset();
 | 
			
		||||
        m_shared_trail.reset();
 | 
			
		||||
        m_shared.reset();
 | 
			
		||||
        m_is_shared.reset();
 | 
			
		||||
        
 | 
			
		||||
        struct intersect_proc {
 | 
			
		||||
            mbi_plugin& p;
 | 
			
		||||
            obj_hashtable<func_decl>& sA;
 | 
			
		||||
            intersect_proc(mbi_plugin& p, obj_hashtable<func_decl>& sA):p(p), sA(sA) {}
 | 
			
		||||
            void operator()(app* a) { 
 | 
			
		||||
                func_decl* f = a->get_decl();
 | 
			
		||||
                if (sA.contains(f) && !p.m_shared.contains(f)) {
 | 
			
		||||
                    p.m_shared_trail.push_back(f);
 | 
			
		||||
                    p.m_shared.insert(f);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            void operator()(expr*) {}
 | 
			
		||||
        };
 | 
			
		||||
        intersect_proc symbols_in_b(*this, symbols_in_a.s);
 | 
			
		||||
        quick_for_each_expr(symbols_in_b, marks, b);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    lbool mbi_plugin::check(expr_ref_vector& lits, model_ref& mdl) {
 | 
			
		||||
        while (true) {
 | 
			
		||||
            switch ((*this)(lits, mdl)) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,6 +21,7 @@ Revision History:
 | 
			
		|||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include "qe/qe_arith.h"
 | 
			
		||||
#include "util/lbool.h"
 | 
			
		||||
 | 
			
		||||
namespace qe {
 | 
			
		||||
    enum mbi_result {
 | 
			
		||||
| 
						 | 
				
			
			@ -54,6 +55,8 @@ namespace qe {
 | 
			
		|||
            for (auto* f : vars) m_shared.insert(f);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        void set_shared(expr* a, expr* b);
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Set representative (shared) expression finder.
 | 
			
		||||
         */
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue