diff --git a/scripts/mk_project.py b/scripts/mk_project.py index 4c99b1a7a..adb90f22e 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -33,7 +33,7 @@ def init_project_def(): add_lib('macros', ['rewriter'], 'ast/macros') add_lib('model', ['macros']) add_lib('converters', ['model'], 'ast/converters') - add_lib('ast_sls', ['ast','normal_forms','converters','smt_params'], 'ast/sls') + add_lib('ast_sls', ['ast','normal_forms','converters','smt_params','euf'], 'ast/sls') add_lib('sat', ['params', 'util', 'dd', 'ast_sls', 'grobner']) add_lib('nlsat', ['polynomial', 'sat']) add_lib('lp', ['util', 'nlsat', 'grobner', 'interval', 'smt_params'], 'math/lp') diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index 519497e9f..c2a1376ee 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -15,6 +15,7 @@ z3_add_component(ast_sls sls_context.cpp sls_euf_plugin.cpp sls_smt_solver.cpp + sls_user_sort_plugin.cpp COMPONENT_DEPENDENCIES ast euf diff --git a/src/ast/sls/sls_array_plugin.cpp b/src/ast/sls/sls_array_plugin.cpp index 749970be5..deec71b51 100644 --- a/src/ast/sls/sls_array_plugin.cpp +++ b/src/ast/sls/sls_array_plugin.cpp @@ -16,10 +16,11 @@ Author: --*/ #include "ast/sls/sls_array_plugin.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" namespace sls { - array_plugin::array_plugin(context& ctx): plugin(ctx), @@ -28,12 +29,180 @@ namespace sls { m_fid = a.get_family_id(); } - bool array_plugin::is_sat() { - euf::egraph g(m); - init_egraph(g); + m_g = alloc(euf::egraph, m); + m_kv = nullptr; + init_egraph(*m_g); + saturate_store(*m_g); + return true; + } - return false; + // b ~ a[i -> v] + // ensure b[i] ~ v + // ensure b[j] ~ a[j] for j != i + + void array_plugin::saturate_store(euf::egraph& g) { + unsigned sz = 0; + while (sz < g.nodes().size()) { + sz = g.nodes().size(); + for (unsigned i = 0; i < sz; ++i) { + auto n = g.nodes()[i]; + if (!a.is_store(n->get_expr())) + continue; + + force_store_axiom1(g, n); + + for (auto p : euf::enode_parents(n->get_root())) + if (a.is_select(p->get_expr())) + force_store_axiom2_down(g, n, p); + + auto arr = n->get_arg(0); + for (auto p : euf::enode_parents(arr->get_root())) + if (a.is_select(p->get_expr())) + force_store_axiom2_up(g, n, p); + } + } + display(verbose_stream() << "saturated\n"); + } + + euf::enode* array_plugin::mk_select(euf::egraph& g, euf::enode* b, euf::enode* sel) { + auto arity = get_array_arity(b->get_sort()); + ptr_buffer args; + ptr_buffer eargs; + args.push_back(b->get_expr()); + eargs.push_back(b); + for (unsigned i = 1; i <= arity; ++i) { + auto idx = sel->get_arg(i); + eargs.push_back(idx); + args.push_back(idx->get_expr()); + } + expr_ref esel(a.mk_select(args), m); + auto n = g.find(esel); + return n ? n : g.mk(esel, 0, eargs.size(), eargs.data()); + } + + // ensure a[i->v][i] = v exists in the e-graph + void array_plugin::force_store_axiom1(euf::egraph& g, euf::enode* n) { + SASSERT(a.is_store(n->get_expr())); + auto val = n->get_arg(n->num_args() - 1); + auto nsel = mk_select(g, n, n); + if (are_distinct(nsel, val)) + add_store_axiom1(n->get_app()); + else + g.merge(nsel, val, nullptr); + } + + // i /~ j, b ~ a[i->v], b[j] occurs -> a[j] = b[j] + void array_plugin::force_store_axiom2_down(euf::egraph& g, euf::enode* sto, euf::enode* sel) { + SASSERT(a.is_store(sto->get_expr())); + SASSERT(a.is_select(sel->get_expr())); + if (sel->get_arg(0)->get_root() != sto->get_root()) + return; + if (eq_args(sto, sel)) + return; + auto nsel = mk_select(g, sto->get_arg(0), sel); + if (are_distinct(nsel, sel)) + add_store_axiom2(sto->get_app(), sel->get_app()); + else + g.merge(nsel, sel, nullptr); + } + + // a ~ b, i /~ j, b[j] occurs -> a[i -> v][j] = b[j] + void array_plugin::force_store_axiom2_up(euf::egraph& g, euf::enode* sto, euf::enode* sel) { + SASSERT(a.is_store(sto->get_expr())); + SASSERT(a.is_select(sel->get_expr())); + if (sel->get_arg(0)->get_root() != sto->get_arg(0)->get_root()) + return; + if (eq_args(sto, sel)) + return; + auto nsel = mk_select(g, sto, sel); + if (are_distinct(nsel, sel)) + add_store_axiom2(sto->get_app(), sel->get_app()); + else + g.merge(nsel, sel, nullptr); + } + + bool array_plugin::are_distinct(euf::enode* a, euf::enode* b) { + a = a->get_root(); + b = b->get_root(); + return a->interpreted() && b->interpreted() && a != b; // TODO work with nested arrays? + } + + bool array_plugin::eq_args(euf::enode* sto, euf::enode* sel) { + SASSERT(a.is_store(sto->get_expr())); + SASSERT(a.is_select(sel->get_expr())); + unsigned arity = get_array_arity(sto->get_sort()); + for (unsigned i = 1; i < arity; ++i) { + if (sto->get_arg(i)->get_root() != sel->get_arg(i)->get_root()) + return false; + } + return true; + } + + void array_plugin::add_store_axiom1(app* sto) { + if (!m_add_conflicts) + return; + ptr_vector args; + args.push_back(sto); + for (unsigned i = 1; i < sto->get_num_args() - 1; ++i) + args.push_back(sto->get_arg(i)); + expr_ref sel(a.mk_select(args), m); + expr_ref eq(m.mk_eq(sel, to_app(sto)->get_arg(sto->get_num_args() - 1)), m); + verbose_stream() << "add store axiom 1 " << mk_bounded_pp(sto, m) << "\n"; + ctx.add_clause(eq); + } + + void array_plugin::add_store_axiom2(app* sto, app* sel) { + if (!m_add_conflicts) + return; + ptr_vector args1, args2; + args1.push_back(sto); + args2.push_back(sto->get_arg(0)); + for (unsigned i = 1; i < sel->get_num_args() - 1; ++i) { + args1.push_back(sel->get_arg(i)); + args2.push_back(sel->get_arg(i)); + } + expr_ref sel1(a.mk_select(args1), m); + expr_ref sel2(a.mk_select(args2), m); + expr_ref eq(m.mk_eq(sel1, sel2), m); + expr_ref_vector ors(m); + ors.push_back(eq); + for (unsigned i = 1; i < sel->get_num_args() - 1; ++i) + ors.push_back(m.mk_eq(sel->get_arg(i), sto->get_arg(i))); + verbose_stream() << "add store axiom 2 " << mk_bounded_pp(sto, m) << " " << mk_bounded_pp(sel, m) << "\n"; + ctx.add_clause(m.mk_or(ors)); + } + + void array_plugin::init_egraph(euf::egraph& g) { + ptr_vector args; + for (auto t : ctx.subterms()) { + args.reset(); + if (is_app(t)) + for (auto* arg : *to_app(t)) + args.push_back(g.find(arg)); + + euf::enode* n1, * n2; + n1 = g.find(t); + n1 = n1 ? n1 : g.mk(t, 0, args.size(), args.data()); + if (a.is_array(t)) + continue; + auto v = ctx.get_value(t); + verbose_stream() << "init " << mk_bounded_pp(t, m) << " := " << mk_bounded_pp(v, m) << "\n"; + n2 = g.find(v); + n2 = n2 ? n2: g.mk(v, 0, 0, nullptr); + g.merge(n1, n2, nullptr); + } + for (auto lit : ctx.root_literals()) { + if (!ctx.is_true(lit)) + continue; + auto e = ctx.atom(lit.var()); + expr* x, * y; + if (e && m.is_eq(e, x, y)) + g.merge(g.find(x), g.find(y), nullptr); + } + + display(verbose_stream()); + } void array_plugin::init_kv(euf::egraph& g, kv& kv) { @@ -44,61 +213,56 @@ namespace sls { for (auto p : euf::enode_parents(n)) { if (!a.is_select(p->get_expr())) continue; - SASSERT(n->num_args() == 2); + SASSERT(p->num_args() == 2); if (p->get_arg(0)->get_root() != n->get_root()) continue; - auto idx = n->get_arg(1)->get_root(); + auto idx = p->get_arg(1)->get_root(); auto val = p->get_root(); kv[n].insert(idx, val); } } + display(verbose_stream()); } - void array_plugin::saturate_store(euf::egraph& g, kv& kv) { - for (auto n : g.nodes()) { - if (!a.is_store(n->get_expr())) - continue; - SASSERT(n->num_args() == 3); - auto idx = n->get_arg(1)->get_root(); - auto val = n->get_arg(2)->get_root(); - auto arr = n->get_arg(0)->get_root(); -#if 0 - auto it = kv.find(arr); - if (it == kv.end()) - continue; - auto it2 = it->get_value().find(idx); - if (it2 == nullptr) - continue; - g.merge(val, it2->get_value(), nullptr); -#endif + expr_ref array_plugin::get_value(expr* e) { + SASSERT(a.is_array(e)); + if (!m_g) { + m_g = alloc(euf::egraph, m); + init_egraph(*m_g); + flet _strong(m_add_conflicts, false); + saturate_store(*m_g); } + if (!m_kv) { + m_kv = alloc(kv); + init_kv(*m_g, *m_kv); + } + auto& kv = *m_kv; + auto n = m_g->find(e)->get_root(); + expr_ref r(n->get_expr(), m); + for (auto [k, v] : kv[n]) { + ptr_vector args; + args.push_back(r); + args.push_back(k->get_expr()); + args.push_back(v->get_expr()); + r = a.mk_store(args); + } + return r; } - void array_plugin::init_egraph(euf::egraph& g) { - ptr_vector args; - for (auto t : ctx.subterms()) { - args.reset(); - if (is_app(t)) { - for (auto* arg : *to_app(t)) { - args.push_back(g.find(arg)); + std::ostream& array_plugin::display(std::ostream& out) const { + if (m_g) + m_g->display(out); + if (m_kv) { + for (auto& [n, kvs] : *m_kv) { + out << m_g->pp(n) << " -> {"; + char const* sp = ""; + for (auto& [k, v] : kvs) { + out << sp << m_g->pp(k) << " -> " << m_g->pp(v); + sp = " "; } + out << "}\n"; } - auto n = g.mk(t, 0, args.size(), args.data()); - if (a.is_array(t)) - continue; - auto v = ctx.get_value(t); - auto n2 = g.mk(v, 0, 0, nullptr); - g.merge(n, n2, nullptr); - } - for (auto lit : ctx.root_literals()) { - if (!ctx.is_true(lit)) - continue; - auto e = ctx.atom(lit.var()); - expr* x, * y; - if (e && m.is_eq(e, x, y)) - g.merge(g.find(x), g.find(y), nullptr); } + return out; } - - } diff --git a/src/ast/sls/sls_array_plugin.h b/src/ast/sls/sls_array_plugin.h index f19d112fd..ebeec6fbd 100644 --- a/src/ast/sls/sls_array_plugin.h +++ b/src/ast/sls/sls_array_plugin.h @@ -22,31 +22,42 @@ Author: namespace sls { - class array_plugin : public plugin { - array_util a; - typedef obj_map> kv; + + array_util a; + scoped_ptr m_g; + scoped_ptr m_kv; + bool m_add_conflicts = true; + void init_egraph(euf::egraph& g); void init_kv(euf::egraph& g, kv& kv); - void saturate_store(euf::egraph& g, kv& kv); + void saturate_store(euf::egraph& g); + void force_store_axiom1(euf::egraph& g, euf::enode* n); + void force_store_axiom2_down(euf::egraph& g, euf::enode* sto, euf::enode* sel); + void force_store_axiom2_up(euf::egraph& g, euf::enode* sto, euf::enode* sel); + void add_store_axiom1(app* sto); + void add_store_axiom2(app* sto, app* sel); + bool are_distinct(euf::enode* a, euf::enode* b); + bool eq_args(euf::enode* sto, euf::enode* sel); + euf::enode* mk_select(euf::egraph& g, euf::enode* b, euf::enode* sel); public: array_plugin(context& ctx); ~array_plugin() override {} void register_term(expr* e) override { } - expr_ref get_value(expr* e) override { return expr_ref(m); } - void initialize() override {} - void propagate_literal(sat::literal lit) override {} + expr_ref get_value(expr* e) override; + void initialize() override { m_g = nullptr; } + void propagate_literal(sat::literal lit) override { m_g = nullptr; } bool propagate() override { return false; } - bool repair_down(app* e) override { return false; } + bool repair_down(app* e) override { return true; } void repair_up(app* e) override {} - void repair_literal(sat::literal lit) override {} + void repair_literal(sat::literal lit) override { m_g = nullptr; } bool is_sat() override; void on_rescale() override {} void on_restart() override {} - std::ostream& display(std::ostream& out) const override { return out; } + std::ostream& display(std::ostream& out) const override; void mk_model(model& mdl) override {} bool set_value(expr* e, expr* v) override { return false; } void collect_statistics(statistics& st) const override {} diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 87f4d5890..3c2c0ce6e 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -18,8 +18,11 @@ Author: #include "ast/sls/sls_context.h" #include "ast/sls/sls_euf_plugin.h" #include "ast/sls/sls_arith_plugin.h" +#include "ast/sls/sls_array_plugin.h" #include "ast/sls/sls_bv_plugin.h" #include "ast/sls/sls_basic_plugin.h" +#include "ast/sls/sls_model_value_plugin.h" +#include "ast/sls/sls_user_sort_plugin.h" #include "ast/ast_ll_pp.h" #include "ast/ast_pp.h" #include "smt/params/smt_params_helper.hpp" @@ -42,6 +45,9 @@ namespace sls { register_plugin(alloc(arith_plugin, *this)); register_plugin(alloc(bv_plugin, *this)); register_plugin(alloc(basic_plugin, *this)); + register_plugin(alloc(array_plugin, *this)); + register_plugin(alloc(user_sort_plugin, *this)); + register_plugin(alloc(model_value_plugin, *this)); } void context::updt_params(params_ref const& p) { @@ -205,6 +211,7 @@ namespace sls { auto p = m_plugins.get(fid, nullptr); if (p) return p->get_value(e); + verbose_stream() << fid << " " << m.get_family_name(fid) << " " << mk_pp(e, m) << "\n"; UNREACHABLE(); return expr_ref(e, m); } diff --git a/src/ast/sls/sls_model_value_plugin.h b/src/ast/sls/sls_model_value_plugin.h new file mode 100644 index 000000000..d5664cb27 --- /dev/null +++ b/src/ast/sls/sls_model_value_plugin.h @@ -0,0 +1,47 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_model_value_plugin.h + +Abstract: + + Theory plugin for model values + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-06 + +--*/ +#pragma once + +#include "ast/sls/sls_context.h" + +namespace sls { + + class model_value_plugin : public plugin { + + public: + model_value_plugin(context& ctx) : plugin(ctx) { m_fid = m.get_family_id("model-value"); } + ~model_value_plugin() override {} + void register_term(expr* e) override { } + expr_ref get_value(expr* e) override { return expr_ref(e, m); } + void initialize() override { } + void propagate_literal(sat::literal lit) override { } + bool propagate() override { return false; } + bool repair_down(app* e) override { return true; } + void repair_up(app* e) override {} + void repair_literal(sat::literal lit) override { } + bool is_sat() override { return true; } + + void on_rescale() override {} + void on_restart() override {} + std::ostream& display(std::ostream& out) const override { return out;} + void mk_model(model& mdl) override {} + bool set_value(expr* e, expr* v) override { return false; } + void collect_statistics(statistics& st) const override {} + void reset_statistics() override {} + }; + +} diff --git a/src/ast/sls/sls_user_sort_plugin.cpp b/src/ast/sls/sls_user_sort_plugin.cpp new file mode 100644 index 000000000..f731eb007 --- /dev/null +++ b/src/ast/sls/sls_user_sort_plugin.cpp @@ -0,0 +1,92 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_user_sort_plugin.cpp + +Abstract: + + Theory plugin for user sort local search + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-06 + +--*/ + +#include "ast/sls/sls_user_sort_plugin.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" + + +namespace sls { + + user_sort_plugin::user_sort_plugin(context& ctx): + plugin(ctx) + { + m_fid = user_sort_family_id; + } + + + void user_sort_plugin::init_egraph(euf::egraph& g) { + ptr_vector args; + for (auto t : ctx.subterms()) { + args.reset(); + if (is_app(t)) { + for (auto* arg : *to_app(t)) { + args.push_back(g.find(arg)); + } + } + g.mk(t, 0, args.size(), args.data()); + } + for (auto lit : ctx.root_literals()) { + if (!ctx.is_true(lit) || lit.sign()) + continue; + auto e = ctx.atom(lit.var()); + expr* x, * y; + if (e && m.is_eq(e, x, y)) + g.merge(g.find(x), g.find(y), nullptr); + } + display(verbose_stream()); + + + typedef obj_map map1; + typedef obj_map map2; + + m_num_elems = alloc(map1); + m_root2value = alloc(map2); + m_pinned = alloc(expr_ref_vector, m); + + for (auto n : g.nodes()) { + if (n->is_root() && is_user_sort(n->get_sort())) { + unsigned num = 0; + m_num_elems->find(n->get_sort(), num); + expr* v = m.mk_model_value(num, n->get_sort()); + m_pinned->push_back(v); + m_root2value->insert(n, v); + m_num_elems->insert(n->get_sort(), num + 1); + } + } + } + + + expr_ref user_sort_plugin::get_value(expr* e) { + if (m.is_model_value(e)) + return expr_ref(e, m); + + if (!m_g) { + m_g = alloc(euf::egraph, m); + init_egraph(*m_g); + } + auto n = m_g->find(e)->get_root(); + VERIFY(m_root2value->find(n, e)); + return expr_ref(e, m); + } + + std::ostream& user_sort_plugin::display(std::ostream& out) const { + if (m_g) + m_g->display(out); + return out; + } +} diff --git a/src/ast/sls/sls_user_sort_plugin.h b/src/ast/sls/sls_user_sort_plugin.h new file mode 100644 index 000000000..121798bd2 --- /dev/null +++ b/src/ast/sls/sls_user_sort_plugin.h @@ -0,0 +1,55 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_user_sort_plugin.h + +Abstract: + + Theory plugin for arrays local search + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-06 + +--*/ +#pragma once + +#include "ast/sls/sls_context.h" +#include "ast/euf/euf_egraph.h" + +namespace sls { + + class user_sort_plugin : public plugin { + scoped_ptr m_g; + scoped_ptr> m_num_elems; + scoped_ptr> m_root2value; + scoped_ptr m_pinned; + + void init_egraph(euf::egraph& g); + bool is_user_sort(sort* s) { return s->get_family_id() == user_sort_family_id; } + + public: + user_sort_plugin(context& ctx); + ~user_sort_plugin() override {} + void register_term(expr* e) override { } + expr_ref get_value(expr* e) override; + void initialize() override { m_g = nullptr; } + void propagate_literal(sat::literal lit) override { m_g = nullptr; } + bool propagate() override { return false; } + bool repair_down(app* e) override { return true; } + void repair_up(app* e) override {} + void repair_literal(sat::literal lit) override { m_g = nullptr; } + bool is_sat() override { return true; } + + void on_rescale() override {} + void on_restart() override {} + std::ostream& display(std::ostream& out) const override; + void mk_model(model& mdl) override {} + bool set_value(expr* e, expr* v) override { return false; } + void collect_statistics(statistics& st) const override {} + void reset_statistics() override {} + }; + +}