diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index a121b19e6..86de919dc 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -687,6 +687,11 @@ public: return ensure_euf()->user_propagate_register_expr(e); } + void user_propagate_register_created(user_propagator::created_eh_t& r) { + ensure_euf()->user_propagate_register_created(r); + } + + private: lbool internalize_goal(goal_ref& g) { diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index f1d5733d7..d299c92d0 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -430,6 +430,10 @@ namespace euf { check_for_user_propagator(); m_user_propagator->register_diseq(diseq_eh); } + void user_propagate_register_created(user_propagator::created_eh_t& ceh) { + check_for_user_propagator(); + m_user_propagator->register_created(ceh); + } unsigned user_propagate_register_expr(expr* e) { check_for_user_propagator(); return m_user_propagator->add_expr(e); diff --git a/src/sat/smt/user_solver.cpp b/src/sat/smt/user_solver.cpp index 58de15855..febbe9383 100644 --- a/src/sat/smt/user_solver.cpp +++ b/src/sat/smt/user_solver.cpp @@ -186,5 +186,51 @@ namespace user_solver { return result; } + sat::literal solver::internalize(expr* e, bool sign, bool root, bool redundant) { + if (!visit_rec(m, e, sign, root, redundant)) { + TRACE("array", tout << mk_pp(e, m) << "\n";); + return sat::null_literal; + } + sat::literal lit = ctx.expr2literal(e); + if (sign) + lit.neg(); + if (root) + add_unit(lit); + return lit; + } + + void solver::internalize(expr* e, bool redundant) { + visit_rec(m, e, false, false, redundant); + } + + bool solver::visit(expr* e) { + if (visited(e)) + return true; + if (!is_app(e) || to_app(e)->get_family_id() != get_id()) { + ctx.internalize(e, m_is_redundant); + return true; + } + m_stack.push_back(sat::eframe(e)); + return false; + } + + bool solver::visited(expr* e) { + euf::enode* n = expr2enode(e); + return n && n->is_attached_to(get_id()); + } + + bool solver::post_visit(expr* e, bool sign, bool root) { + euf::enode* n = expr2enode(e); + SASSERT(!n || !n->is_attached_to(get_id())); + if (!n) + n = mk_enode(e, false); + auto v = add_expr(e); + if (m_created_eh) + m_created_eh(m_user_context, this, e, v); + return true; + } + + + } diff --git a/src/sat/smt/user_solver.h b/src/sat/smt/user_solver.h index 087633ece..a30bc6a6d 100644 --- a/src/sat/smt/user_solver.h +++ b/src/sat/smt/user_solver.h @@ -64,6 +64,7 @@ namespace user_solver { user_propagator::fixed_eh_t m_fixed_eh; user_propagator::eq_eh_t m_eq_eh; user_propagator::eq_eh_t m_diseq_eh; + user_propagator::created_eh_t m_created_eh; user_propagator::context_obj* m_api_context = nullptr; unsigned m_qhead = 0; vector m_prop; @@ -94,6 +95,10 @@ namespace user_solver { void validate_propagation(); + bool visit(expr* e) override; + bool visited(expr* e) override; + bool post_visit(expr* e, bool sign, bool root) override; + public: solver(euf::solver& ctx); @@ -119,6 +124,7 @@ namespace user_solver { void register_fixed(user_propagator::fixed_eh_t& fixed_eh) { m_fixed_eh = fixed_eh; } void register_eq(user_propagator::eq_eh_t& eq_eh) { m_eq_eh = eq_eh; } void register_diseq(user_propagator::eq_eh_t& diseq_eh) { m_diseq_eh = diseq_eh; } + void register_created(user_propagator::created_eh_t& created_eh) { m_created_eh = created_eh; } bool has_fixed() const { return (bool)m_fixed_eh; } @@ -134,8 +140,8 @@ namespace user_solver { bool unit_propagate() override; void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override; void collect_statistics(statistics& st) const override; - sat::literal internalize(expr* e, bool sign, bool root, bool learned) override { UNREACHABLE(); return sat::null_literal; } - void internalize(expr* e, bool redundant) override { UNREACHABLE(); } + sat::literal internalize(expr* e, bool sign, bool root, bool learned) override; + void internalize(expr* e, bool redundant) override; std::ostream& display(std::ostream& out) const override; std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override; std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override;