From c36d9f7b3eb3f5e9805d2acd333f56ace4b6fc0b Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Tue, 26 Nov 2019 19:45:19 -0800
Subject: [PATCH] fix #2741

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
---
 src/ast/arith_decl_plugin.cpp            |   5 +
 src/ast/arith_decl_plugin.h              |   1 +
 src/tactic/arith/purify_arith_tactic.cpp |  22 +++-
 src/tactic/fd_solver/smtfd_solver.cpp    | 143 ++++++++++++++++-------
 src/util/parray.h                        |  52 ++++++---
 5 files changed, 164 insertions(+), 59 deletions(-)

diff --git a/src/ast/arith_decl_plugin.cpp b/src/ast/arith_decl_plugin.cpp
index 82ab483aa..a93b6258d 100644
--- a/src/ast/arith_decl_plugin.cpp
+++ b/src/ast/arith_decl_plugin.cpp
@@ -822,3 +822,8 @@ func_decl* arith_util::mk_div0() {
     sort* rs[2] = { mk_real(), mk_real() };
     return m_manager.mk_func_decl(m_afid, OP_DIV0, 0, nullptr, 2, rs, mk_real());
 }
+
+func_decl* arith_util::mk_idiv0() {
+    sort* rs[2] = { mk_int(), mk_int() };
+    return m_manager.mk_func_decl(m_afid, OP_IDIV0, 0, nullptr, 2, rs, mk_int());
+}
diff --git a/src/ast/arith_decl_plugin.h b/src/ast/arith_decl_plugin.h
index 0bc572feb..86241a1fa 100644
--- a/src/ast/arith_decl_plugin.h
+++ b/src/ast/arith_decl_plugin.h
@@ -368,6 +368,7 @@ public:
     sort * mk_real() { return m_manager.mk_sort(m_afid, REAL_SORT); }
 
     func_decl* mk_div0();
+    func_decl* mk_idiv0();
 
 
     app * mk_numeral(rational const & val, bool is_int) const {
diff --git a/src/tactic/arith/purify_arith_tactic.cpp b/src/tactic/arith/purify_arith_tactic.cpp
index 1f467999e..f6e038d60 100644
--- a/src/tactic/arith/purify_arith_tactic.cpp
+++ b/src/tactic/arith/purify_arith_tactic.cpp
@@ -193,7 +193,7 @@ struct purify_arith_proc {
         expr_ref_vector      m_pinned;
         expr_ref_vector      m_new_cnstrs;
         proof_ref_vector     m_new_cnstr_prs;
-        svector<div_def>     m_divs;
+        svector<div_def>     m_divs, m_idivs;
         expr_ref             m_subst;
         proof_ref            m_subst_pr;
         expr_ref_vector      m_new_vars;
@@ -361,6 +361,7 @@ struct purify_arith_proc {
                 push_cnstr(OR(NOT(EQ(y, zero)), EQ(k2, u().mk_mod(x, zero))));
                 push_cnstr_pr(mod_pr);
             }
+            m_idivs.push_back(div_def(x, y, k1));
         }
    
         void process_mod(func_decl * f, unsigned num, expr * const * args, expr_ref & result, proof_ref & result_pr) { 
@@ -775,6 +776,7 @@ struct purify_arith_proc {
             m_goal.assert_expr(r.cfg().m_new_cnstrs.get(i), m_produce_proofs ? r.cfg().m_new_cnstr_prs.get(i) : nullptr, nullptr);
         }
         auto const& divs = r.cfg().m_divs;
+        auto const& idivs = r.cfg().m_idivs;
         for (unsigned i = 0; i < divs.size(); ++i) {
             auto const& p1 = divs[i];
             for (unsigned j = i + 1; j < divs.size(); ++j) {
@@ -784,6 +786,15 @@ struct purify_arith_proc {
                                        m().mk_eq(p1.d, p2.d)));
             }
         }
+        for (unsigned i = 0; i < idivs.size(); ++i) {
+            auto const& p1 = idivs[i];
+            for (unsigned j = i + 1; j < idivs.size(); ++j) {
+                auto const& p2 = idivs[j];
+                m_goal.assert_expr(m().mk_implies(
+                                       m().mk_and(m().mk_eq(p1.x, p2.x), m().mk_eq(p1.y, p2.y)), 
+                                       m().mk_eq(p1.d, p2.d)));
+            }
+        }
         
         // add generic_model_converter to eliminate auxiliary variables from model
         if (produce_models) {
@@ -804,6 +815,15 @@ struct purify_arith_proc {
                 }
                 fmc->add(u().mk_div0(), body);
             }
+            if (!idivs.empty()) {
+                expr_ref body(u().mk_int(0), m());
+                expr_ref v0(m().mk_var(0, u().mk_int()), m());
+                expr_ref v1(m().mk_var(1, u().mk_int()), m());
+                for (auto const& p : idivs) {
+                    body = m().mk_ite(m().mk_and(m().mk_eq(v0, p.x), m().mk_eq(v1, p.y)), p.d, body);
+                }
+                fmc->add(u().mk_idiv0(), body);
+            }
         }
         if (produce_models && !m_sin_cos.empty()) {
             generic_model_converter* emc = alloc(generic_model_converter, m(), "purify_sin_cos");
diff --git a/src/tactic/fd_solver/smtfd_solver.cpp b/src/tactic/fd_solver/smtfd_solver.cpp
index 90fcc14cf..f98c848a0 100644
--- a/src/tactic/fd_solver/smtfd_solver.cpp
+++ b/src/tactic/fd_solver/smtfd_solver.cpp
@@ -789,10 +789,8 @@ namespace smtfd {
                 expr_ref v = eval_abs(t);
                 val2elem_t& v2e = get_table(s);
                 expr* e;
-                if (v2e.find(v, e)) {
-                    if (e != t) {
-                        m_context.add(m.mk_not(m.mk_eq(e, t)), __FUNCTION__);
-                    }
+                if (v2e.find(v, e) && e != t && m.is_value(e)) {
+                    m_context.add(m.mk_not(m.mk_eq(e, t)), __FUNCTION__);
                 }
                 else {
                     m_pinned.push_back(v);
@@ -1333,22 +1331,23 @@ namespace smtfd {
                 }
             }
         }
-
     };
 
+
     class mbqi {
-        ast_manager&    m;
-        plugin_context& m_context;
-        obj_hashtable<quantifier>& m_enforced;
-        model_ref       m_model;
-        ref<::solver>   m_solver;
-        expr_ref_vector m_pinned;
+        ast_manager&                    m;
+        plugin_context&                 m_context;
+        obj_hashtable<quantifier>       m_enforced;
+        model_ref                       m_model;
+        ref<::solver>                   m_solver;
         obj_pair_map<expr, sort, expr*> m_val2term;
+        expr_ref_vector                 m_val2term_trail;
+        obj_map<sort, expr*>            m_fresh;
+        expr_ref_vector                 m_fresh_trail;
 
         expr* abs(expr* e) { return m_context.get_abs().abs(e);  }
         expr_ref eval_abs(expr* t) { return (*m_model)(abs(t)); }
 
-
         void restrict_to_universe(expr * sk, ptr_vector<expr> const & universe) {
             SASSERT(!universe.empty());
             expr_ref_vector eqs(m);
@@ -1359,6 +1358,32 @@ namespace smtfd {
             m_solver->assert_expr(fml);
         }
 
+        expr_ref fresh(sort* s) {
+            expr* e = nullptr;
+            if (!m_fresh.find(s, e)) {
+                e = m.mk_fresh_const(s->get_name(), s, false);
+                m_fresh.insert(s, e);
+                m_fresh_trail.push_back(e);
+            }
+            return expr_ref(e, m);
+        }
+
+        expr_ref replace_model_value(expr* e) {
+            if (m.is_model_value(e)) {                
+                expr_ref r = fresh(m.get_sort(e));
+                std::cout << expr_ref(e, m) << " |-> " << r << "\n";
+                return r;
+            }
+            if (is_app(e) && to_app(e)->get_num_args() > 0) {
+                expr_ref_vector args(m);
+                for (expr* arg : *to_app(e)) {
+                    args.push_back(replace_model_value(arg));
+                }
+                return expr_ref(m.mk_app(to_app(e)->get_decl(), args.size(), args.c_ptr()), m);
+            }
+            return expr_ref(e, m);
+        }
+
         // !Ex P(x) => !P(t)
         // Ax P(x) => P(t)
         // l_true: new instance
@@ -1408,11 +1433,7 @@ namespace smtfd {
             if (r == l_true) {
                 expr_ref qq(q->get_expr(), m);
                 for (expr* t : subterms(qq)) {
-                    if (is_ground(t)) {                       
-                        expr_ref v = eval_abs(t);
-                        m_pinned.push_back(v);
-                        m_val2term.insert(v, m.get_sort(t), t);
-                    }
+                    init_term(t);
                 }
                 m_solver->get_model(mdl);
                 TRACE("smtfd", tout << *mdl << "\n";);
@@ -1429,6 +1450,9 @@ namespace smtfd {
                     if (m_val2term.find(val, m.get_sort(v), t)) {
                         val = t;
                     }
+                    else {
+                        val = replace_model_value(val);
+                    }
                     vals[i] = val;
                 }
             }
@@ -1443,12 +1467,14 @@ namespace smtfd {
                 else {
                     body = m.mk_implies(body, q);
                 }
+                IF_VERBOSE(1, verbose_stream() << body << "\n");
                 m_context.add(body, __FUNCTION__);
             }
             m_solver->pop(1);
             return r;
         }
 
+        // 
         lbool check_exists(quantifier* q) {
             if (m_enforced.contains(q)) {
                 return l_true;
@@ -1473,36 +1499,51 @@ namespace smtfd {
             return l_true;
         }
 
-        void init_val2term(expr_ref_vector const& core) {
-            for (expr* t : subterms(core)) {
-                if (!m.is_bool(t) && is_ground(t)) {
-                    expr_ref v = eval_abs(t);
-                    m_pinned.push_back(v);
+        void init_term(expr* t) {
+            if (!m.is_bool(t) && is_ground(t)) {
+                expr_ref v = eval_abs(t);
+                if (!m_val2term.contains(v, m.get_sort(t))) {
                     m_val2term.insert(v, m.get_sort(t), t);
+                    m_val2term_trail.push_back(v);
                 }
             }
         }
 
     public:
 
-        mbqi(::solver* s, plugin_context& c, obj_hashtable<quantifier>& enforced, model_ref& mdl):
-            m(s->get_manager()),
+        mbqi(plugin_context& c):
+            m(c.get_manager()),
             m_context(c),
-            m_enforced(enforced),
-            m_model(mdl),
-            m_solver(s),
-            m_pinned(m)
+            m_model(nullptr),
+            m_solver(nullptr),
+            m_val2term_trail(m),
+            m_fresh_trail(m)
         {}
+
+        void set_model(model* mdl) { m_model = mdl; }
+
+        ref<::solver> & get_solver() { return m_solver; }
             
+        void init_val2term(expr_ref_vector const& fmls, expr_ref_vector const& core) {
+            m_val2term_trail.reset();
+            m_val2term.reset();
+            for (expr* t : subterms(core)) {
+                init_term(t);
+            }
+            for (expr* t : subterms(fmls)) {
+                init_term(t);
+            }
+        }
+
         bool check_quantifiers(expr_ref_vector const& core) {
             bool result = true;
-            init_val2term(core);
             IF_VERBOSE(9, 
                        for (expr* c : core) {
                            verbose_stream() << "core: " << mk_bounded_pp(c, m, 2) << "\n";
                        });
             for (expr* c : core) {
                 lbool r = l_false;
+                IF_VERBOSE(1, verbose_stream() << "core: " << mk_bounded_pp(c, m, 2) << "\n");
                 if (is_forall(c)) {
                     r = check_forall(to_quantifier(c));
                 }
@@ -1536,6 +1577,7 @@ namespace smtfd {
     class solver : public solver_na2as {
         ast_manager&    m;
         mutable smtfd_abs       m_abs;
+        unsigned        m_indent;
         plugin_context  m_context;
         uf_plugin       m_uf;
         ar_plugin       m_ar;
@@ -1545,18 +1587,18 @@ namespace smtfd {
         ref<::solver>   m_fd_sat_solver;
         ref<::solver>   m_fd_core_solver;
         ref<::solver>   m_smt_solver;
-        ref<solver>     m_mbqi_solver;
+        mbqi            m_mbqi;
         expr_ref_vector m_assertions;
         unsigned_vector m_assertions_lim;
         unsigned        m_assertions_qhead;
         expr_ref_vector m_axioms;
+        unsigned_vector m_axioms_lim;
         expr_ref_vector m_toggles;
         unsigned_vector m_toggles_lim;
         model_ref       m_model;
         std::string     m_reason_unknown;
         stats           m_stats;
         unsigned        m_max_conflicts;
-        obj_hashtable<quantifier> m_enforced_quantifier;
 
         void set_delay_simplify() {
             params_ref p;
@@ -1574,6 +1616,10 @@ namespace smtfd {
             return m_toggles.contains(e);
         }
 
+        void indent() {
+            for (unsigned i = 0; i < m_indent; ++i) verbose_stream() << " ";
+        }
+
         void flush_assertions() {
             SASSERT(m_assertions_qhead <= m_assertions.size());
             unsigned sz = m_assertions.size() - m_assertions_qhead;
@@ -1629,7 +1675,7 @@ namespace smtfd {
         }
 
         lbool check_smt(expr_ref_vector& core) {
-            IF_VERBOSE(10, verbose_stream() << "core: " << core.size() << "\n");
+            IF_VERBOSE(10, indent(); verbose_stream() << "core: " << core.size() << "\n");
             params_ref p;
             p.set_uint("max_conflicts", m_max_conflicts);
             m_smt_solver->updt_params(p);
@@ -1709,15 +1755,18 @@ namespace smtfd {
                           tout << m_context.term_covered(a) << " " << m_context.sort_covered(m.get_sort(a)) << "\n";
                       }
                   }
-                  tout << "has quantifier: " << has_q << "\n" << core << "\n";);
+                  tout << "has quantifier: " << has_q << "\n" << core << "\n";
+                  tout << *m_model.get() << "\n";
+                  );
             if (!has_q) {
                 return is_decided;
             }
-            if (!m_mbqi_solver) {
-                m_mbqi_solver = alloc(solver, m, get_params());
+            m_mbqi.set_model(m_model.get());
+            if (!m_mbqi.get_solver()) {
+                m_mbqi.get_solver() = alloc(solver, m_indent + 1, m, get_params());
             }
-            mbqi mb(m_mbqi_solver.get(), m_context, m_enforced_quantifier, m_model);
-            if (!mb.check_quantifiers(core) && m_context.empty()) {
+            m_mbqi.init_val2term(m_assertions, core);
+            if (!m_mbqi.check_quantifiers(core) && m_context.empty()) {
                 return l_false;
             }
             for (expr* f : m_context) {
@@ -1791,9 +1840,10 @@ namespace smtfd {
         }
         
     public:
-        solver(ast_manager& m, params_ref const& p):
+        solver(unsigned indent, ast_manager& m, params_ref const& p):
             solver_na2as(m),
             m(m),
+            m_indent(indent),
             m_abs(m),
             m_context(m_abs, m),
             m_uf(m_context),
@@ -1801,6 +1851,7 @@ namespace smtfd {
             m_bv(m_context),
             m_bs(m_context),
             m_pb(m_context),
+            m_mbqi(m_context),
             m_assertions(m),
             m_assertions_qhead(0),
             m_axioms(m),
@@ -1814,7 +1865,7 @@ namespace smtfd {
         ~solver() override {}
         
         ::solver* translate(ast_manager& dst_m, params_ref const& p) override {
-            solver* result = alloc(solver, dst_m, p);
+            solver* result = alloc(solver, m_indent, dst_m, p);
             if (m_smt_solver) result->m_smt_solver = m_smt_solver->translate(dst_m, p);
             if (m_fd_sat_solver) result->m_fd_sat_solver = m_fd_sat_solver->translate(dst_m, p);
             if (m_fd_core_solver) result->m_fd_core_solver = m_fd_core_solver->translate(dst_m, p);
@@ -1833,6 +1884,7 @@ namespace smtfd {
             m_fd_core_solver->push();
             m_smt_solver->push();
             m_assertions_lim.push_back(m_assertions.size());
+            m_axioms_lim.push_back(m_axioms.size());
             m_toggles_lim.push_back(m_toggles.size());
         }
         
@@ -1846,6 +1898,8 @@ namespace smtfd {
             m_toggles_lim.shrink(m_toggles_lim.size() - n);
             m_assertions.shrink(m_assertions_lim[m_assertions_lim.size() - n]);
             m_assertions_lim.shrink(m_assertions_lim.size() - n);
+            m_axioms.shrink(m_axioms_lim[m_axioms_lim.size() - n]);
+            m_axioms_lim.shrink(m_axioms_lim.size() - n);
             m_assertions_qhead = m_assertions.size();
         }
 
@@ -1861,6 +1915,11 @@ namespace smtfd {
         void assert_fd(expr* fml) {
             expr_ref _fml(fml, m);
             TRACE("smtfd", tout << mk_bounded_pp(fml, m, 3) << "\n";);
+            CTRACE("smtfd", m_axioms.contains(fml), 
+                   tout << "formula:\n" << _fml << "\n";
+                   tout << "axioms:\n" << m_axioms << "\n";
+                   tout << "assertions:\n" << m_assertions << "\n";);
+
             SASSERT(!m_axioms.contains(fml));
             m_axioms.push_back(fml);
             _fml = abs(fml);
@@ -1882,7 +1941,7 @@ namespace smtfd {
             lbool r = l_undef;
             expr_ref_vector core(m), axioms(m);
             while (true) {
-                IF_VERBOSE(1, verbose_stream() << "(smtfd-check-sat :rounds " << m_stats.m_num_rounds 
+                IF_VERBOSE(1, indent(); verbose_stream() << "(smtfd-check-sat :rounds " << m_stats.m_num_rounds 
                            << " :lemmas " << m_stats.m_num_lemmas << " :qi " << m_stats.m_num_mbqi << ")\n");
                 m_stats.m_num_rounds++;
                 checkpoint();
@@ -1949,7 +2008,7 @@ namespace smtfd {
                     ++round;
                     continue;
                 }
-                IF_VERBOSE(1, verbose_stream() << "(smtfd-round :round " << round << " :lemmas " << m_context.size() << ")\n");
+                IF_VERBOSE(1, indent(); verbose_stream() << "(smtfd-round :round " << round << " :lemmas " << m_context.size() << ")\n");
                 round = 0;
                 TRACE("smtfd_verbose", 
                       for (expr* f : m_context) tout << "refine " << mk_bounded_pp(f, m, 3) << "\n";
@@ -2059,7 +2118,7 @@ namespace smtfd {
 }
 
 solver * mk_smtfd_solver(ast_manager & m, params_ref const & p) {
-    return alloc(smtfd::solver, m, p);
+    return alloc(smtfd::solver, 0, m, p);
 }
 
 tactic * mk_smtfd_tactic(ast_manager & m, params_ref const & p) {
diff --git a/src/util/parray.h b/src/util/parray.h
index 38d6987d0..d02d47c83 100644
--- a/src/util/parray.h
+++ b/src/util/parray.h
@@ -216,9 +216,7 @@ private:
         unsigned sz = r->m_size;
         vs = nullptr;
         copy_values(r->m_values, sz, vs);
-        unsigned i = cs.size();
-        while (i > 0) {
-            --i;
+        for (unsigned i = cs.size(); i-- > 0; ) {
             cell * curr = cs[i];
             switch (curr->kind()) {
             case SET:
@@ -312,6 +310,26 @@ public:
         }
     }
 
+    void check_size(cell* c) const {
+        unsigned r;
+        while (c) {
+            switch (c->kind()) {
+            case SET:
+                break;
+            case PUSH_BACK:
+                r = size(c->next());
+                break;
+            case POP_BACK:
+                r = size(c->next());
+                SASSERT(c->idx() == r);
+                break;
+            case ROOT:
+                return;
+            }
+            c = c->next();
+        }
+    }
+
     bool empty(ref const & r) const { return size(r) == 0; }
 
     value const & get(ref const & r, unsigned i) const {
@@ -528,7 +546,7 @@ public:
         unsigned r_sz = size(r);
         unsigned trail_split_idx = r_sz / C::factor;
         unsigned i = 0;
-        cell * c   = r.m_ref;
+        cell * c   = r.m_ref;        
         while (c->kind() != ROOT && i < trail_split_idx) {
             cs.push_back(c);
             c = c->next();
@@ -538,10 +556,9 @@ public:
             // root is too far away.
             unfold(c);
         }
-        SASSERT(c->kind() == ROOT);
-        i = cs.size();
-        while (i > 0) {
-            --i;
+        DEBUG_CODE(check_size(c););
+        SASSERT(c->kind() == ROOT);      
+        for (i = cs.size(); i-- > 0; ) {
             cell * p = cs[i];
             SASSERT(c->m_kind == ROOT);
             unsigned sz = c->m_size;
@@ -558,10 +575,10 @@ public:
             case PUSH_BACK:
                 c->m_kind = POP_BACK;
                 if (sz == capacity(vs))
-                    expand(vs);
-                c->m_idx  = sz;
-                vs[sz]    = p->m_elem;
-                sz++;
+                    expand(vs);                
+                vs[sz] = p->m_elem;
+                ++sz;
+                c->m_idx = sz;
                 break;
             case POP_BACK:
                 c->m_kind = PUSH_BACK;
@@ -575,11 +592,12 @@ public:
             }
             inc_ref(p);
             c->m_next   = p;
-            // p does not point to c anymore
-            dec_ref(c);
+
             p->m_kind   = ROOT;
             p->m_size   = sz;
             p->m_values = vs;
+            // p does not point to c anymore
+            dec_ref(c);
             c = p;
         }
         SASSERT(c == r.m_ref);
@@ -604,9 +622,11 @@ public:
             case ROOT: out << "root, " << c->m_size << ", " << capacity(c->m_values); break;
             }
             out << "]#" << c->m_ref_count;
-            if (c->kind() == ROOT)
+            if (c->kind() == ROOT) {
+                out << "\n";
                 break;
-            out << " -> ";
+            }
+            out << " -> \n";
             c = c->next();
         }
     }