diff --git a/src/sat/smt/q_mam.cpp b/src/ast/euf/euf_mam.cpp similarity index 97% rename from src/sat/smt/q_mam.cpp rename to src/ast/euf/euf_mam.cpp index 0fa35cfd4..7fb90cdc5 100644 --- a/src/sat/smt/q_mam.cpp +++ b/src/ast/euf/euf_mam.cpp @@ -30,9 +30,7 @@ Revision History: #include "ast/ast_smt2_pp.h" #include "ast/euf/euf_enode.h" #include "ast/euf/euf_egraph.h" -#include "sat/smt/q_mam.h" -#include "sat/smt/q_ematch.h" -#include "sat/smt/euf_solver.h" +#include "ast/euf/euf_mam.h" @@ -545,9 +543,9 @@ namespace q { return m_root; } - void add_candidate(euf::solver& ctx, enode * n) { + void add_candidate(euf::mam_solver& ctx, enode * n) { m_candidates.push_back(n); - ctx.push(push_back_trail(m_candidates)); + ctx.get_trail().push(push_back_trail(m_candidates)); } void unmark(unsigned head) { @@ -570,8 +568,8 @@ namespace q { return m_qhead < m_candidates.size(); } - void save_qhead(euf::solver& ctx) { - ctx.push(value_trail(m_qhead)); + void save_qhead(euf::mam_solver& ctx) { + ctx.get_trail().push(value_trail(m_qhead)); } enode* next_candidate() { @@ -628,7 +626,7 @@ namespace q { // ------------------------------------ class code_tree_manager { - euf::solver & ctx; + euf::mam_solver & ctx; label_hasher & m_lbl_hasher; region & m_region; @@ -662,7 +660,7 @@ namespace q { } public: - code_tree_manager(label_hasher & h, euf::solver& ctx): + code_tree_manager(label_hasher & h, euf::mam_solver& ctx): ctx(ctx), m_lbl_hasher(h), m_region(ctx.get_region()) { @@ -787,20 +785,20 @@ namespace q { } void set_next(instruction * instr, instruction * new_next) { - ctx.push(mam_value_trail(instr->m_next)); + ctx.get_trail().push(mam_value_trail(instr->m_next)); instr->m_next = new_next; } void save_num_regs(code_tree * tree) { - ctx.push(mam_value_trail(tree->m_num_regs)); + ctx.get_trail().push(mam_value_trail(tree->m_num_regs)); } void save_num_choices(code_tree * tree) { - ctx.push(mam_value_trail(tree->m_num_choices)); + ctx.get_trail().push(mam_value_trail(tree->m_num_choices)); } void insert_new_lbl_hash(filter * instr, unsigned h) { - ctx.push(mam_value_trail(instr->m_lbl_set)); + ctx.get_trail().push(mam_value_trail(instr->m_lbl_set)); instr->m_lbl_set.insert(h); } }; @@ -1858,7 +1856,7 @@ namespace q { typedef svector backtrack_stack; class interpreter { - euf::solver& ctx; + euf::mam_solver& ctx; ast_manager & m; mam & m_mam; bool m_use_filters; @@ -1992,7 +1990,7 @@ namespace q { #define INIT_ARGS_SIZE 16 public: - interpreter(euf::solver& ctx, mam & ma, bool use_filters): + interpreter(euf::mam_solver& ctx, mam & ma, bool use_filters): ctx(ctx), m(ctx.get_manager()), m_mam(ma), @@ -2019,7 +2017,7 @@ namespace q { if (t->filter_candidates()) { code_tree::scoped_unmark _unmark(t); while ((app = t->next_candidate()) && !ctx.resource_limits_exceeded()) { - TRACE("trigger_bug", tout << "candidate\n" << ctx.bpp(app) << "\n";); + TRACE("trigger_bug", tout << "candidate\n" << ctx.get_egraph().bpp(app) << "\n";); if (!app->is_marked3() && app->is_cgr()) { execute_core(t, app); app->mark3(); @@ -2028,7 +2026,7 @@ namespace q { } else { while ((app = t->next_candidate()) && !ctx.resource_limits_exceeded()) { - TRACE("trigger_bug", tout << "candidate\n" << ctx.bpp(app) << "\n";); + TRACE("trigger_bug", tout << "candidate\n" << ctx.get_egraph().bpp(app) << "\n";); if (app->is_cgr()) execute_core(t, app); } @@ -2820,7 +2818,7 @@ namespace q { ast_manager & m; compiler & m_compiler; ptr_vector m_trees; // mapping: func_label -> tree - euf::solver& ctx; + euf::mam_solver& ctx; #ifdef Z3DEBUG egraph * m_egraph; #endif @@ -2837,7 +2835,7 @@ namespace q { }; public: - code_tree_map(ast_manager & m, compiler & c, euf::solver& ctx): + code_tree_map(ast_manager & m, compiler & c, euf::mam_solver& ctx): m(m), m_compiler(c), ctx(ctx) { @@ -2871,7 +2869,7 @@ namespace q { m_trees[lbl_id] = m_compiler.mk_tree(qa, mp, first_idx, false); SASSERT(m_trees[lbl_id]->expected_num_args() == p->get_num_args()); DEBUG_CODE(m_trees[lbl_id]->set_egraph(m_egraph);); - ctx.push(mk_tree_trail(m_trees, lbl_id)); + ctx.get_trail().push(mk_tree_trail(m_trees, lbl_id)); } else { code_tree * tree = m_trees[lbl_id]; @@ -2884,7 +2882,7 @@ namespace q { } DEBUG_CODE(if (first_idx == 0) { m_trees[lbl_id]->get_patterns().push_back(std::make_pair(qa, mp)); - ctx.push(push_back_trail, false>(m_trees[lbl_id]->get_patterns())); + ctx.get_trail().push(push_back_trail, false>(m_trees[lbl_id]->get_patterns())); }); TRACE("trigger_bug", tout << "after add_pattern, first_idx: " << first_idx << "\n"; m_trees[lbl_id]->display(tout);); } @@ -3038,9 +3036,9 @@ namespace q { // // ------------------------------------ class mam_impl : public mam { - euf::solver& ctx; + euf::mam_solver& ctx; egraph & m_egraph; - ematch & m_ematch; + euf::on_binding_callback & m_ematch; ast_manager & m; bool m_use_filters; label_hasher m_lbl_hasher; @@ -3095,9 +3093,9 @@ namespace q { void add_candidate(code_tree * t, enode * app) { if (!t) return; - TRACE("q", tout << "candidate " << ctx.bpp(app) << "\n";); + TRACE("q", tout << "candidate " << ctx.get_egraph().bpp(app) << "\n";); if (!t->has_candidates()) { - ctx.push(push_back_trail(m_to_match)); + ctx.get_trail().push(push_back_trail(m_to_match)); m_to_match.push_back(t); } t->add_candidate(ctx, app); @@ -3120,7 +3118,7 @@ namespace q { void update_lbls(enode * n, unsigned elem) { approx_set & r_lbls = n->get_root()->get_lbls(); if (!r_lbls.may_contain(elem)) { - ctx.push(mam_value_trail(r_lbls)); + ctx.get_trail().push(mam_value_trail(r_lbls)); r_lbls.insert(elem); } } @@ -3132,7 +3130,7 @@ namespace q { TRACE("mam_bug", tout << "update_clbls: " << lbl->get_name() << " is already clbl: " << m_is_clbl[lbl_id] << "\n";); if (m_is_clbl[lbl_id]) return; - ctx.push(set_bitvector_trail(m_is_clbl, lbl_id)); + ctx.get_trail().push(set_bitvector_trail(m_is_clbl, lbl_id)); SASSERT(m_is_clbl[lbl_id]); unsigned h = m_lbl_hasher(lbl); for (enode* app : m_egraph.enodes_of(lbl)) { @@ -3152,7 +3150,7 @@ namespace q { enode * c = app->get_arg(i); approx_set & r_plbls = c->get_root()->get_plbls(); if (!r_plbls.may_contain(elem)) { - ctx.push(mam_value_trail(r_plbls)); + ctx.get_trail().push(mam_value_trail(r_plbls)); r_plbls.insert(elem); TRACE("trigger_bug", tout << "updating plabels of:\n" << mk_ismt2_pp(c->get_root()->get_expr(), m) << "\n"; tout << "new_elem: " << static_cast(elem) << "\n"; @@ -3173,7 +3171,7 @@ namespace q { TRACE("mam_bug", tout << "update_plbls: " << lbl->get_name() << " is already plbl: " << m_is_plbl[lbl_id] << "\n";); if (m_is_plbl[lbl_id]) return; - ctx.push(set_bitvector_trail(m_is_plbl, lbl_id)); + ctx.get_trail().push(set_bitvector_trail(m_is_plbl, lbl_id)); SASSERT(m_is_plbl[lbl_id]); SASSERT(is_plbl(lbl)); unsigned h = m_lbl_hasher(lbl); @@ -3220,7 +3218,7 @@ namespace q { p = p->m_child; } curr->m_code = mk_code(qa, mp, pat_idx); - ctx.push(new_obj_trail(curr->m_code)); + ctx.get_trail().push(new_obj_trail(curr->m_code)); return head; } @@ -3243,7 +3241,7 @@ namespace q { insert_code(t, qa, mp, p->m_pattern_idx); } else { - ctx.push(set_ptr_trail(t->m_first_child)); + ctx.get_trail().push(set_ptr_trail(t->m_first_child)); t->m_first_child = mk_path_tree(p->m_child, qa, mp); } } @@ -3253,9 +3251,9 @@ namespace q { insert_code(t, qa, mp, p->m_pattern_idx); } else { - ctx.push(set_ptr_trail(t->m_code)); + ctx.get_trail().push(set_ptr_trail(t->m_code)); t->m_code = mk_code(qa, mp, p->m_pattern_idx); - ctx.push(new_obj_trail(t->m_code)); + ctx.get_trail().push(new_obj_trail(t->m_code)); } } else { @@ -3268,10 +3266,10 @@ namespace q { prev_sibling = t; t = t->m_sibling; } - ctx.push(set_ptr_trail(prev_sibling->m_sibling)); + ctx.get_trail().push(set_ptr_trail(prev_sibling->m_sibling)); prev_sibling->m_sibling = mk_path_tree(p, qa, mp); if (!found_label) { - ctx.push(value_trail(head->m_filter)); + ctx.get_trail().push(value_trail(head->m_filter)); head->m_filter.insert(m_lbl_hasher(p->m_label)); } } @@ -3281,7 +3279,7 @@ namespace q { insert(m_pc[h1][h2], p, qa, mp); } else { - ctx.push(set_ptr_trail(m_pc[h1][h2])); + ctx.get_trail().push(set_ptr_trail(m_pc[h1][h2])); m_pc[h1][h2] = mk_path_tree(p, qa, mp); } TRACE("mam_path_tree_updt", @@ -3298,7 +3296,7 @@ namespace q { insert(m_pp[h1][h2].first, p2, qa, mp); } else { - ctx.push(set_ptr_trail(m_pp[h1][h2].first)); + ctx.get_trail().push(set_ptr_trail(m_pp[h1][h2].first)); m_pp[h1][h2].first = mk_path_tree(p1, qa, mp); insert(m_pp[h1][h2].first, p2, qa, mp); } @@ -3316,8 +3314,8 @@ namespace q { } else { SASSERT(m_pp[h1][h2].second == nullptr); - ctx.push(set_ptr_trail(m_pp[h1][h2].first)); - ctx.push(set_ptr_trail(m_pp[h1][h2].second)); + ctx.get_trail().push(set_ptr_trail(m_pp[h1][h2].first)); + ctx.get_trail().push(set_ptr_trail(m_pp[h1][h2].second)); m_pp[h1][h2].first = mk_path_tree(p1, qa, mp); m_pp[h1][h2].second = mk_path_tree(p2, qa, mp); } @@ -3462,7 +3460,7 @@ namespace q { \brief Collect new E-matching candidates using the inverted path index t. */ void collect_parents(enode * r, path_tree * t) { - TRACE("mam", tout << ctx.bpp(r) << " " << t << "\n";); + TRACE("mam", tout << ctx.get_egraph().bpp(r) << " " << t << "\n";); if (t == nullptr) return; #ifdef _PROFILE_PATH_TREE @@ -3682,7 +3680,7 @@ namespace q { void propagate_new_patterns() { if (m_new_patterns_qhead >= m_new_patterns.size()) return; - ctx.push(value_trail(m_new_patterns_qhead)); + ctx.get_trail().push(value_trail(m_new_patterns_qhead)); TRACE("mam_new_pat", tout << "matching new patterns:\n";); m_tmp_trees_to_delete.reset(); @@ -3725,7 +3723,7 @@ namespace q { } public: - mam_impl(euf::solver & ctx, ematch& ematch, bool use_filters): + mam_impl(euf::mam_solver & ctx, euf::on_binding_callback& ematch, bool use_filters): ctx(ctx), m_egraph(ctx.get_egraph()), m_ematch(ematch), @@ -3754,7 +3752,7 @@ namespace q { return; // ignore multi-pattern containing ground pattern. update_filters(qa, mp); m_new_patterns.push_back(qp_pair(qa, mp)); - ctx.push(push_back_trail(m_new_patterns)); + ctx.get_trail().push(push_back_trail(m_new_patterns)); // The matching abstract machine implements incremental // e-matching. So, for a multi-pattern [ p_1, ..., p_n ], // we have to make n insertions. In the i-th insertion, @@ -3782,7 +3780,7 @@ namespace q { void propagate_to_match() { if (m_to_match_head >= m_to_match.size()) return; - ctx.push(value_trail(m_to_match_head)); + ctx.get_trail().push(value_trail(m_to_match_head)); for (; m_to_match_head < m_to_match.size(); ++m_to_match_head) m_interpreter.execute(m_to_match[m_to_match_head]); } @@ -3872,8 +3870,8 @@ namespace q { approx_set other_lbls = other->get_lbls(); approx_set & root_lbls = root->get_lbls(); - ctx.push(mam_value_trail(root_lbls)); - ctx.push(mam_value_trail(root_plbls)); + ctx.get_trail().push(mam_value_trail(root_lbls)); + ctx.get_trail().push(mam_value_trail(root_plbls)); root_lbls |= other_lbls; root_plbls |= other_plbls; TRACE("mam_inc_bug", @@ -3909,7 +3907,7 @@ namespace q { } } - mam* mam::mk(euf::solver& ctx, ematch& em) { + mam* mam::mk(euf::mam_solver& ctx, euf::on_binding_callback& em) { return alloc(mam_impl, ctx, em, true); } diff --git a/src/sat/smt/q_mam.h b/src/ast/euf/euf_mam.h similarity index 66% rename from src/sat/smt/q_mam.h rename to src/ast/euf/euf_mam.h index e642f9e44..72d2f8f0e 100644 --- a/src/sat/smt/q_mam.h +++ b/src/ast/euf/euf_mam.h @@ -22,7 +22,23 @@ Author: #include namespace euf { - class solver; + + class mam_solver { + public: + virtual ~mam_solver() = default; + virtual trail_stack& get_trail() = 0; + virtual region& get_region() = 0; + virtual ast_manager& get_manager() = 0; + virtual egraph& get_egraph() = 0; + virtual bool is_relevant(euf::enode* n) const = 0; + virtual bool resource_limits_exceeded() const = 0; + }; + + class on_binding_callback { + public: + virtual ~on_binding_callback() = default; + virtual void on_binding(quantifier* q, app* pat, euf::enode* const* binding, unsigned max_generation, unsigned min_gen, unsigned max_gen) = 0; + }; }; namespace q { @@ -43,7 +59,7 @@ namespace q { public: - static mam * mk(euf::solver& ctx, ematch& em); + static mam * mk(euf::mam_solver& ctx, euf::on_binding_callback& em); virtual ~mam() = default; diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 1e2193501..02aa22806 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -36,7 +36,6 @@ z3_add_component(sat_smt q_clause.cpp q_ematch.cpp q_eval.cpp - q_mam.cpp q_mbi.cpp q_model_fixer.cpp q_theory_checker.cpp diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 06ee45df0..73e8878e3 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -21,6 +21,7 @@ Author: #include "ast/ast_translation.h" #include "ast/ast_util.h" #include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_mam.h" #include "ast/rewriter/th_rewriter.h" #include "ast/converters/model_converter.h" #include "sat/sat_extension.h" @@ -83,7 +84,7 @@ namespace euf { expr* get_hint(euf::solver& s) const override; }; - class solver : public sat::extension, public th_internalizer, public th_decompile, public sat::clause_eh { + class solver : public sat::extension, public th_internalizer, public th_decompile, public sat::clause_eh, public mam_solver { typedef top_sort deps_t; friend class ackerman; friend class eq_proof_hint; @@ -331,6 +332,7 @@ namespace euf { push(push_back_trail< V, false>(vec)); } trail_stack& get_trail_stack() { return m_trail; } + trail_stack& get_trail() override { return m_trail; } void updt_params(params_ref const& p); void set_solver(sat::solver* s) override { m_solver = s; use_drat(); } @@ -398,7 +400,7 @@ namespace euf { bool is_blocked(literal l, ext_constraint_idx) override; bool check_model(sat::model const& m) const override; void gc_vars(unsigned num_vars) override; - bool resource_limits_exceeded() const { return false; } // TODO + bool resource_limits_exceeded() const override { return false; } // TODO // proof diff --git a/src/sat/smt/q_ematch.cpp b/src/sat/smt/q_ematch.cpp index 76d234d8d..e1a5ed175 100644 --- a/src/sat/smt/q_ematch.cpp +++ b/src/sat/smt/q_ematch.cpp @@ -29,6 +29,7 @@ Done: --*/ #include "ast/ast_util.h" +#include "ast/euf/euf_mam.h" #include "ast/rewriter/var_subst.h" #include "ast/rewriter/rewriter_def.h" #include "ast/normal_forms/pull_quant.h" @@ -36,7 +37,7 @@ Done: #include "sat/smt/sat_th.h" #include "sat/smt/euf_solver.h" #include "sat/smt/q_solver.h" -#include "sat/smt/q_mam.h" + #include "sat/smt/q_ematch.h" diff --git a/src/sat/smt/q_ematch.h b/src/sat/smt/q_ematch.h index f7de55fb8..5edefff05 100644 --- a/src/sat/smt/q_ematch.h +++ b/src/sat/smt/q_ematch.h @@ -22,7 +22,7 @@ Author: #include "ast/normal_forms/nnf.h" #include "solver/solver.h" #include "sat/smt/sat_th.h" -#include "sat/smt/q_mam.h" +#include "ast/euf/euf_mam.h" #include "sat/smt/q_clause.h" #include "sat/smt/q_queue.h" #include "sat/smt/q_eval.h" @@ -35,7 +35,7 @@ namespace q { class solver; - class ematch { + class ematch : public euf::on_binding_callback { struct stats { unsigned m_num_instantiations; unsigned m_num_propagations; @@ -149,7 +149,7 @@ namespace q { void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing); // callback from mam - void on_binding(quantifier* q, app* pat, euf::enode* const* binding, unsigned max_generation, unsigned min_gen, unsigned max_gen); + void on_binding(quantifier* q, app* pat, euf::enode* const* binding, unsigned max_generation, unsigned min_gen, unsigned max_gen) override; // callbacks from queue lbool evaluate(euf::enode* const* binding, clause& c) { m_evidence.reset(); return m_eval(binding, c, m_evidence); }