From 5864fcba6b11c2cbe38e7adf5ed457a150d2f584 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 18 Oct 2024 09:34:49 -0700 Subject: [PATCH] fixing model construction for underspecified operators Signed-off-by: Nikolaj Bjorner --- src/ast/sls/sls_arith_base.cpp | 3 -- src/ast/sls/sls_arith_base.h | 1 - src/ast/sls/sls_arith_plugin.cpp | 4 -- src/ast/sls/sls_arith_plugin.h | 1 - src/ast/sls/sls_array_plugin.h | 1 - src/ast/sls/sls_basic_plugin.h | 1 - src/ast/sls/sls_bv_plugin.h | 1 - src/ast/sls/sls_context.cpp | 53 ++++++++++++++------ src/ast/sls/sls_context.h | 5 +- src/ast/sls/sls_datatype_plugin.cpp | 77 +++++++++++++++-------------- src/ast/sls/sls_datatype_plugin.h | 8 +-- src/ast/sls/sls_euf_plugin.cpp | 22 ++------- src/ast/sls/sls_euf_plugin.h | 4 +- 13 files changed, 93 insertions(+), 88 deletions(-) diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index e5868bf18..f5dc20260 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -2310,9 +2310,6 @@ namespace sls { SASSERT(val == i.m_args_value); } - template - void arith_base::mk_model(model& mdl) { - } template void arith_base::collect_statistics(statistics& st) const { diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 05ed99b00..fe9876660 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -277,7 +277,6 @@ namespace sls { void on_rescale() override; void on_restart() override; std::ostream& display(std::ostream& out) const override; - void mk_model(model& mdl) override; void collect_statistics(statistics& st) const override; void reset_statistics() override; }; diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp index f48c85642..310d4009f 100644 --- a/src/ast/sls/sls_arith_plugin.cpp +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -99,10 +99,6 @@ namespace sls { return m_arith->display(out); } - void arith_plugin::mk_model(model& mdl) { - WITH_FALLBACK(mk_model(mdl)); - } - bool arith_plugin::repair_down(app* e) { WITH_FALLBACK(repair_down(e)); } diff --git a/src/ast/sls/sls_arith_plugin.h b/src/ast/sls/sls_arith_plugin.h index 5c8d4b245..7d8491579 100644 --- a/src/ast/sls/sls_arith_plugin.h +++ b/src/ast/sls/sls_arith_plugin.h @@ -43,7 +43,6 @@ namespace sls { void on_rescale() override; void on_restart() override; std::ostream& display(std::ostream& out) const override; - void mk_model(model& mdl) override; bool set_value(expr* e, expr* v) override; void collect_statistics(statistics& st) const override; diff --git a/src/ast/sls/sls_array_plugin.h b/src/ast/sls/sls_array_plugin.h index d6557f245..4f8f051f4 100644 --- a/src/ast/sls/sls_array_plugin.h +++ b/src/ast/sls/sls_array_plugin.h @@ -82,7 +82,6 @@ namespace sls { void on_rescale() override {} void on_restart() override {} std::ostream& display(std::ostream& out) const override; - void mk_model(model& mdl) override {} bool set_value(expr* e, expr* v) override { return false; } void collect_statistics(statistics& st) const override {} void reset_statistics() override {} diff --git a/src/ast/sls/sls_basic_plugin.h b/src/ast/sls/sls_basic_plugin.h index d640415f4..63ac3ed5d 100644 --- a/src/ast/sls/sls_basic_plugin.h +++ b/src/ast/sls/sls_basic_plugin.h @@ -51,7 +51,6 @@ namespace sls { void on_rescale() override {} void on_restart() override {} std::ostream& display(std::ostream& out) const override; - void mk_model(model& mdl) override {} bool set_value(expr* e, expr* v) override; void collect_statistics(statistics& st) const override {} void reset_statistics() override {} diff --git a/src/ast/sls/sls_bv_plugin.h b/src/ast/sls/sls_bv_plugin.h index 013f4ad3c..7d9e338e7 100644 --- a/src/ast/sls/sls_bv_plugin.h +++ b/src/ast/sls/sls_bv_plugin.h @@ -54,7 +54,6 @@ namespace sls { void on_rescale() override {} void on_restart() override {} std::ostream& display(std::ostream& out) const override; - void mk_model(model& mdl) override {} bool set_value(expr* e, expr* v) override; void collect_statistics(statistics& st) const override {} void reset_statistics() override {} diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index aabd24980..07bfc3e8d 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -129,21 +129,46 @@ namespace sls { return l_undef; if (all_of(m_plugins, [&](auto* p) { return !p || p->is_sat(); })) { - model_ref mdl = alloc(model, m); - for (expr* e : subterms()) - if (is_uninterp_const(e)) - mdl->register_decl(to_app(e)->get_decl(), get_value(e)); - for (auto p : m_plugins) - if (p) - p->mk_model(*mdl); - s.on_model(mdl); - // verbose_stream() << *mdl << "\n"; - TRACE("sls", display(tout)); + values2model(); return l_true; } } return l_undef; } + + void context::values2model() { + model_ref mdl = alloc(model, m); + expr_ref_vector args(m); + for (expr* e : subterms()) + if (is_uninterp_const(e)) + mdl->register_decl(to_app(e)->get_decl(), get_value(e)); + + for (expr* e : subterms()) { + if (!is_app(e)) + continue; + auto f = to_app(e)->get_decl(); + if (!include_func_interp(f)) + continue; + auto v = get_value(e); + auto fi = mdl->get_func_interp(f); + if (!fi) { + fi = alloc(func_interp, m, f->get_arity()); + mdl->register_decl(f, fi); + } + args.reset(); + for (expr* arg : *to_app(e)) { + args.push_back(get_value(arg)); + SASSERT(args.back()); + } + SASSERT(f->get_arity() == args.size()); + if (!fi->get_entry(args.data())) + fi->insert_new_entry(args.data(), v); + } + + s.on_model(mdl); + // verbose_stream() << *mdl << "\n"; + TRACE("sls", display(tout)); + } void context::propagate_boolean_assignment() { reinit_relevant(); @@ -156,8 +181,7 @@ namespace sls { propagate_literal(lit); if (m_new_constraint) - return; - + return; while (!m_new_constraint && m.inc() && (!m_repair_up.empty() || !m_repair_down.empty())) { while (!m_repair_down.empty() && !m_new_constraint && m.inc()) { @@ -264,10 +288,7 @@ namespace sls { } bool context::set_value(expr * e, expr * v) { - for (auto p : m_plugins) - if (p && p->set_value(e, v)) - return true; - return false; + return any_of(m_plugins, [&](auto p) { return p && p->set_value(e, v); }); } bool context::is_relevant(expr* e) { diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index 6e53186bd..4d334e8c0 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -53,10 +53,10 @@ namespace sls { virtual void on_rescale() {}; virtual void on_restart() {}; virtual std::ostream& display(std::ostream& out) const = 0; - virtual void mk_model(model& mdl) = 0; virtual bool set_value(expr* e, expr* v) = 0; virtual void collect_statistics(statistics& st) const = 0; virtual void reset_statistics() = 0; + virtual bool include_func_interp(func_decl* f) const { return false; } }; using clause = ptr_iterator; @@ -139,6 +139,8 @@ namespace sls { void propagate_literal(sat::literal lit); void repair_literals(); + void values2model(); + void ensure_plugin(expr* e); void ensure_plugin(family_id fid); family_id get_fid(expr* e) const; @@ -181,6 +183,7 @@ namespace sls { bool is_unit(sat::literal lit) const { return m_unit_indices.contains(lit.index()); } void reinit_relevant(); void force_restart() { s.force_restart(); } + bool include_func_interp(func_decl* f) const { return any_of(m_plugins, [&](plugin* p) { return p && p->include_func_interp(f); }); } ptr_vector const& parents(expr* e) { m_parents.reserve(e->get_id() + 1); diff --git a/src/ast/sls/sls_datatype_plugin.cpp b/src/ast/sls/sls_datatype_plugin.cpp index 181b2893b..7b17fe5a8 100644 --- a/src/ast/sls/sls_datatype_plugin.cpp +++ b/src/ast/sls/sls_datatype_plugin.cpp @@ -384,7 +384,7 @@ namespace sls { m_model = nullptr; } - euf::enode* datatype_plugin::get_constructor(euf::enode* n) { + euf::enode* datatype_plugin::get_constructor(euf::enode* n) const { euf::enode* con = nullptr; for (auto sib : euf::enode_class(n)) if (dt.is_constructor(sib->get_expr())) @@ -395,22 +395,22 @@ namespace sls { bool datatype_plugin::propagate() { enum color_t { white, grey, black }; svector color; - svector> todo; + ptr_vector stack; obj_map> sorts; - auto set_conflict = [&](euf::enode* n, unsigned parent_idx) { + auto set_conflict = [&](euf::enode* n) { expr_ref_vector diseqs(m); while (true) { - auto [n2, parent_idx2] = todo[parent_idx]; + auto n2 = stack.back(); auto con2 = get_constructor(n2); if (n2 != con2) diseqs.push_back(m.mk_not(m.mk_eq(n2->get_expr(), con2->get_expr()))); - parent_idx = parent_idx2; if (n2->get_root() == n->get_root()) { if (n != n2) diseqs.push_back(m.mk_not(m.mk_eq(n->get_expr(), n2->get_expr()))); break; } + stack.pop_back(); } IF_VERBOSE(1, verbose_stream() << "cycle\n"; for (auto e : diseqs) verbose_stream() << mk_pp(e, m) << "\n";); ctx.add_constraint(m.mk_or(diseqs)); @@ -437,46 +437,41 @@ namespace sls { // is a node in the same congruence class as n that is a constructor. // For every cycle accumulate a conflict. - todo.push_back({ n, 0}); - while (!todo.empty()) { - auto [n, parent_idx] = todo.back(); + stack.push_back(n); + while (!stack.empty()) { + n = stack.back(); unsigned id = n->get_root_id(); c = color.get(id, white); euf::enode* con; - unsigned idx; switch (c) { case black: - todo.pop_back(); + stack.pop_back(); break; case grey: - case white: { - bool new_child = false; + case white: color.setx(id, grey, white); con = get_constructor(n); - idx = todo.size() - 1; - if (con) { - for (auto child : euf::enode_args(con)) { - auto c2 = color.get(child->get_root_id(), white); - switch (c2) { - case black: - break; - case grey: - set_conflict(child, idx); - return true; - case white: - todo.push_back({ child, idx }); - new_child = true; - break; - } + if (!con) + goto done_with_node; + for (auto child : euf::enode_args(con)) { + auto c2 = color.get(child->get_root_id(), white); + switch (c2) { + case black: + break; + case grey: + set_conflict(child); + return true; + case white: + stack.push_back(child); + goto node_pushed; } - } - if (!new_child) { - color[id] = black; - todo.pop_back(); } - break; - } + done_with_node: + color[id] = black; + stack.pop_back(); + node_pushed: + break; } } } @@ -493,6 +488,18 @@ namespace sls { return false; } + bool datatype_plugin::include_func_interp(func_decl* f) const { + if (!dt.is_accessor(f)) + return false; + func_decl* con_decl = dt.get_accessor_constructor(f); + for (euf::enode* app : g->enodes_of(f)) { + euf::enode* con = get_constructor(app->get_arg(0)); + if (con && con->get_decl() != con_decl) + return true; + } + return false; + } + std::ostream& datatype_plugin::display(std::ostream& out) const { for (auto a : m_axioms) out << mk_bounded_pp(a, m, 3) << "\n"; @@ -504,10 +511,8 @@ namespace sls { } bool datatype_plugin::is_sat() { return true; } + void datatype_plugin::register_term(expr* e) {} - - void datatype_plugin::mk_model(model& mdl) { - } void datatype_plugin::collect_statistics(statistics& st) const { st.update("sls-dt-axioms", m_axioms.size()); diff --git a/src/ast/sls/sls_datatype_plugin.h b/src/ast/sls/sls_datatype_plugin.h index 55b04e612..c29e42d9d 100644 --- a/src/ast/sls/sls_datatype_plugin.h +++ b/src/ast/sls/sls_datatype_plugin.h @@ -38,7 +38,7 @@ namespace sls { obj_map> m_dts; obj_map> m_parents; - datatype_util dt; + mutable datatype_util dt; expr_ref_vector m_axioms, m_values; model_ref m_model; stats m_stats; @@ -52,7 +52,7 @@ namespace sls { void init_values(); void add_dep(euf::enode* n, top_sort& dep); - euf::enode* get_constructor(euf::enode* n); + euf::enode* get_constructor(euf::enode* n) const; public: datatype_plugin(context& c); @@ -66,7 +66,6 @@ namespace sls { bool is_sat() override; void register_term(expr* e) override; std::ostream& display(std::ostream& out) const override; - void mk_model(model& mdl) override; bool set_value(expr* e, expr* v) override { return false; } void repair_up(app* e) override {} @@ -75,6 +74,9 @@ namespace sls { void collect_statistics(statistics& st) const override; void reset_statistics() override; + + bool include_func_interp(func_decl* f) const override; + }; } diff --git a/src/ast/sls/sls_euf_plugin.cpp b/src/ast/sls/sls_euf_plugin.cpp index 9903fcc6f..e7263fafc 100644 --- a/src/ast/sls/sls_euf_plugin.cpp +++ b/src/ast/sls/sls_euf_plugin.cpp @@ -339,6 +339,9 @@ namespace sls { return expr_ref(e, m); } + bool euf_plugin::include_func_interp(func_decl* f) const { + return is_uninterp(f) && f->get_arity() > 0; + } bool euf_plugin::is_sat() { for (auto& [f, ts] : m_app) { @@ -477,25 +480,6 @@ namespace sls { return out; } - void euf_plugin::mk_model(model& mdl) { - expr_ref_vector args(m); - for (auto& [f, ts] : m_app) { - func_interp* fi = alloc(func_interp, m, f->get_arity()); - mdl.register_decl(f, fi); - m_values.reset(); - for (auto* t : ts) { - if (m_values.contains(t)) - continue; - args.reset(); - expr_ref val = ctx.get_value(t); - for (auto arg : *t) - args.push_back(ctx.get_value(arg)); - fi->insert_new_entry(args.data(), val); - m_values.insert(t); - } - } - } - void euf_plugin::collect_statistics(statistics& st) const { st.update("sls-euf-conflict", m_stats.m_num_conflicts); } diff --git a/src/ast/sls/sls_euf_plugin.h b/src/ast/sls/sls_euf_plugin.h index ab6977823..6bd0dfcb2 100644 --- a/src/ast/sls/sls_euf_plugin.h +++ b/src/ast/sls/sls_euf_plugin.h @@ -79,8 +79,8 @@ namespace sls { bool is_sat() override; void register_term(expr* e) override; std::ostream& display(std::ostream& out) const override; - void mk_model(model& mdl) override; bool set_value(expr* e, expr* v) override { return false; } + bool include_func_interp(func_decl* f) const override; void repair_up(app* e) override {} bool repair_down(app* e) override { return false; } @@ -89,6 +89,8 @@ namespace sls { void collect_statistics(statistics& st) const override; void reset_statistics() override; + + scoped_ptr& egraph() { return m_g; } };