From d465bdbb87e5b71b43d13db4da64cd356ad9d725 Mon Sep 17 00:00:00 2001
From: Nikolaj Bjorner <nbjorner@microsoft.com>
Date: Fri, 31 Jan 2025 11:06:40 -0800
Subject: [PATCH] include extensionality constraints for arrays

---
 src/ast/sls/sls_array_plugin.cpp | 94 +++++++++++++++++++++++++++++++-
 src/ast/sls/sls_array_plugin.h   |  4 ++
 src/ast/sls/sls_context.cpp      |  3 +-
 3 files changed, 99 insertions(+), 2 deletions(-)

diff --git a/src/ast/sls/sls_array_plugin.cpp b/src/ast/sls/sls_array_plugin.cpp
index 70e6335fb..9b2948ed9 100644
--- a/src/ast/sls/sls_array_plugin.cpp
+++ b/src/ast/sls/sls_array_plugin.cpp
@@ -40,7 +40,9 @@ namespace sls {
         if (m_g->inconsistent()) {
             resolve_conflict();
             return false;
-        }         
+        }        
+        if (saturate_extensionality(*m_g))
+            return false;
         return !m_g->inconsistent();
     }
 
@@ -122,6 +124,70 @@ namespace sls {
         IF_VERBOSE(10, display(verbose_stream() << "saturated\n"));
     }
 
+    bool array_plugin::saturate_extensionality(euf::egraph& g) {
+        bool new_axiom = false;
+        for (auto lit : ctx.root_literals()) {
+            if (!lit.sign() || !ctx.is_true(lit))
+                continue;
+            auto e = ctx.atom(lit.var());
+            expr* x, *y;
+            if (m.is_eq(e, x, y) && a.is_array(x) && add_extensionality_axiom(x, y))
+                new_axiom = true;            
+        }
+
+        euf::enode_vector shared;
+        collect_shared(g, shared);
+        for (unsigned i = shared.size(); i-- > 0; ) {
+            auto x = shared[i];
+            auto e1 = x->get_expr();
+            for (unsigned j = i; j-- > 0; ) {
+                auto y = shared[j];
+                auto e2 = y->get_expr();
+                if (e1->get_sort() != e2->get_sort())
+                    continue;
+                if (add_extensionality_axiom(e1, e2))
+                    new_axiom = true;
+            }
+        }
+        return new_axiom;
+    }
+
+    void array_plugin::collect_shared(euf::egraph& g, euf::enode_vector& shared) {
+        ptr_buffer<euf::enode> to_unmark;
+        for (auto n : g.nodes()) {
+            expr* e = n->get_expr();
+            if (!a.is_array(e))
+                continue;
+            if (!ctx.is_relevant(e))
+                continue;
+            euf::enode * r = n->get_root();
+            if (r->is_marked1()) 
+                continue;            
+            if (is_shared_arg(r))
+                shared.push_back(r);
+            r->mark1();            
+        }
+        for (auto* r : to_unmark)
+            r->unmark1();
+    }
+
+    bool array_plugin::is_shared_arg(euf::enode* r) {
+        SASSERT(r->is_root());
+        for (euf::enode* n : euf::enode_parents(r)) {
+            expr* e = n->get_expr();
+            if (a.is_select(e) || a.is_store(e)) {
+                for (unsigned i = 1; i < n->num_args(); ++i)
+                    if (r == n->get_arg(i)->get_root())
+                        return true;
+                continue;
+            }
+            if (m.is_eq(e))
+                continue;
+            return true;
+        }            
+        return false;
+    }
+
     void array_plugin::saturate_store(euf::egraph& g, euf::enode* n) {
         force_store_axiom1(g, n);
         for (auto p : euf::enode_parents(n->get_root()))
@@ -329,6 +395,32 @@ namespace sls {
         ctx.add_theory_axiom(m.mk_or(ors));
     }
 
+    bool array_plugin::add_extensionality_axiom(expr* x, expr* y) {
+        SASSERT(a.is_array(x));
+        SASSERT(x->get_sort() == y->get_sort());
+        auto s = x->get_sort();
+        auto dimension = get_array_arity(s);
+        func_decl_ref_vector funcs(m);
+        for (unsigned i = 0; i < dimension; ++i) 
+            funcs.push_back(a.mk_array_ext(s, i));
+
+        expr_ref_vector args1(m), args2(m);
+        args1.push_back(x);
+        args2.push_back(y);
+        for (func_decl* f : funcs) {
+            expr_ref k(m.mk_app(f, x, y), m);
+            args1.push_back(k);
+            args2.push_back(k);
+        }
+        expr_ref sel1(a.mk_select(args1), m);
+        expr_ref sel2(a.mk_select(args2), m);
+        bool r = ctx.add_constraint(m.mk_implies(m.mk_eq(sel1, sel2), m.mk_eq(x, y)));
+        if (r)
+            ++m_stats.m_num_axioms;        
+        return r;
+    }
+
+
     void array_plugin::init_egraph(euf::egraph& g) {
         ptr_vector<euf::enode> args;
         for (auto t : ctx.subterms()) {
diff --git a/src/ast/sls/sls_array_plugin.h b/src/ast/sls/sls_array_plugin.h
index 02d2b1865..ca72a454e 100644
--- a/src/ast/sls/sls_array_plugin.h
+++ b/src/ast/sls/sls_array_plugin.h
@@ -84,6 +84,9 @@ namespace sls {
         void init_egraph(euf::egraph& g);
         void init_kv(euf::egraph& g, kv& kv);
         void saturate(euf::egraph& g);
+        bool saturate_extensionality(euf::egraph& g);
+        void collect_shared(euf::egraph& g, euf::enode_vector& shared);
+        bool is_shared_arg(euf::enode* r);
         void saturate_store(euf::egraph& g, euf::enode* n);
         void saturate_const(euf::egraph& g, euf::enode* n);
         void saturate_map(euf::egraph& g, euf::enode* n);
@@ -94,6 +97,7 @@ namespace sls {
         void add_map_axiom(euf::egraph& g, euf::enode* n, euf::enode* sel);
         void add_store_axiom1(app* sto);
         void add_store_axiom2(app* sto, app* sel);
+        bool add_extensionality_axiom(expr* a, expr* b);
         bool are_distinct(euf::enode* a, euf::enode* b);
         bool eq_args(euf::enode* sto, euf::enode* sel);
         euf::enode* mk_select(euf::egraph& g, euf::enode* b, euf::enode* sel);
diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp
index ae140545a..91cfd5fae 100644
--- a/src/ast/sls/sls_context.cpp
+++ b/src/ast/sls/sls_context.cpp
@@ -148,7 +148,8 @@ namespace sls {
                 continue;            
 
             if (all_of(m_plugins, [&](auto* p) { return !p || p->is_sat(); })) {
-                VERIFY(unsat().empty() || !m_new_constraint);
+                if (!unsat().empty() || m_new_constraint)
+                    continue;
                 values2model();
                 return l_true;
             }