diff --git a/src/smt/CMakeLists.txt b/src/smt/CMakeLists.txt index f7a17a0ef..ba1c58cdb 100644 --- a/src/smt/CMakeLists.txt +++ b/src/smt/CMakeLists.txt @@ -80,6 +80,7 @@ z3_add_component(smt theory_str_regex.cpp theory_utvpi.cpp theory_wmaxsat.cpp + user_propagator.cpp uses_theory.cpp watch_list.cpp COMPONENT_DEPENDENCIES diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index e4dc193ca..b869be69e 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -853,7 +853,7 @@ namespace smt { assign(l, mk_justification(eq_root_propagation_justification(curr))); curr = curr->m_next; } - while(curr != r1); + while (curr != r1); } else { bool_var v1 = enode2bool_var(n1); @@ -1394,6 +1394,8 @@ namespace smt { bool sign = val == l_false; if (n->merge_tf()) add_eq(n, sign ? m_false_enode : m_true_enode, eq_justification(literal(v, sign))); + if (watches_fixed(n)) + assign_fixed(n, sign ? m.mk_false() : m.mk_true(), literal(v, sign)); enode * r = n->get_root(); if (r == m_true_enode || r == m_false_enode) return; @@ -1924,6 +1926,8 @@ namespace smt { for (theory* t : m_theory_set) t->push_scope_eh(); + if (m_user_propagator) + m_user_propagator->push_scope_eh(); CASSERT("context", check_invariant()); } @@ -2418,9 +2422,11 @@ namespace smt { unassign_vars(s.m_assigned_literals_lim); undo_trail_stack(s.m_trail_stack_lim); - for (theory* th : m_theory_set) { + for (theory* th : m_theory_set) th->pop_scope_eh(num_scopes); - } + + if (m_user_propagator) + m_user_propagator->pop_scope_eh(num_scopes); del_justifications(m_justifications, s.m_justifications_lim); @@ -2872,6 +2878,26 @@ namespace smt { } } + void context::register_user_propagator( + void* ctx, + std::function& fixed_eh, + std::function& push_eh, + std::function& pop_eh) { + m_user_propagator = alloc(user_propagator, *this); + m_user_propagator->add(ctx, fixed_eh, push_eh, pop_eh); + for (unsigned i = m_scopes.size(); i-- > 0; ) + m_user_propagator->push_scope_eh(); + } + + bool context::watches_fixed(enode* n) const { + return m_user_propagator && n->get_th_var(m_user_propagator->get_family_id()) != null_theory_var; + } + + void context::assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain) { + theory_var v = n->get_th_var(m_user_propagator->get_family_id()); + m_user_propagator->new_fixed_eh(v, val, sz, explain); + } + void context::push() { pop_to_base_lvl(); setup_context(false); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 0fed7e459..e425e025b 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -43,12 +43,13 @@ Revision History: #include "ast/ast_smt_pp.h" #include "smt/watch_list.h" #include "util/trail.h" -#include "smt/fingerprints.h" #include "util/ref.h" -#include "smt/proto_model/proto_model.h" -#include "model/model.h" #include "util/timer.h" #include "util/statistics.h" +#include "smt/fingerprints.h" +#include "smt/proto_model/proto_model.h" +#include "smt/user_propagator.h" +#include "model/model.h" #include "solver/progress_callback.h" #include @@ -92,6 +93,7 @@ namespace smt { scoped_ptr m_qmanager; scoped_ptr m_model_generator; scoped_ptr m_relevancy_propagator; + scoped_ptr m_user_propagator; random_gen m_random; bool m_flushing; // (debug support) true when flushing mutable unsigned m_lemma_id; @@ -1677,6 +1679,27 @@ namespace smt { void get_assertions(ptr_vector & result) { m_asserted_formulas.get_assertions(result); } + /* + * user-propagator + */ + void register_user_propagator( + void* ctx, + std::function& fixed_eh, + std::function& push_eh, + std::function& pop_eh); + + bool watches_fixed(enode* n) const; + + void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain); + + void assign_fixed(enode* n, expr* val, literal_vector const& explain) { + assign_fixed(n, val, explain.size(), explain.c_ptr()); + } + + void assign_fixed(enode* n, expr* val, literal explain) { + assign_fixed(n, val, 1, &explain); + } + void display(std::ostream & out) const; void display_unsat_core(std::ostream & out) const; diff --git a/src/smt/smt_kernel.cpp b/src/smt/smt_kernel.cpp index 4f48aec33..4f134ecac 100644 --- a/src/smt/smt_kernel.cpp +++ b/src/smt/smt_kernel.cpp @@ -232,6 +232,14 @@ namespace smt { void updt_params(params_ref const & p) { m_kernel.updt_params(p); } + + void register_user_propagator( + void* ctx, + std::function& fixed_eh, + std::function& push_eh, + std::function& pop_eh) { + m_kernel.register_user_propagator(ctx, fixed_eh, push_eh, pop_eh); + } }; kernel::kernel(ast_manager & m, smt_params & fp, params_ref const & p) { @@ -445,6 +453,14 @@ namespace smt { return m_imp->get_implied_upper_bound(e); } + void kernel::register_user_propagator( + void* ctx, + std::function& fixed_eh, + std::function& push_eh, + std::function& pop_eh) { + m_imp->register_user_propagator(ctx, fixed_eh, push_eh, pop_eh); + } + }; diff --git a/src/smt/smt_kernel.h b/src/smt/smt_kernel.h index b9f80ac01..c14ba66bf 100644 --- a/src/smt/smt_kernel.h +++ b/src/smt/smt_kernel.h @@ -284,6 +284,15 @@ namespace smt { */ static void collect_param_descrs(param_descrs & d); + /** + \brief register a user-propagator "theory" + */ + void register_user_propagator( + void* ctx, + std::function& fixed_eh, + std::function& push_eh, + std::function& pop_eh); + /** \brief Return a reference to smt::context. This is a temporary hack to support user theories. diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 495448eeb..62e7947d7 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -208,6 +208,15 @@ namespace { return m_context.get_trail(); } + void register_user_propagator( + void* ctx, + std::function& fixed_eh, + std::function& push_eh, + std::function& pop_eh) override { + m_context.register_user_propagator(ctx, fixed_eh, push_eh, pop_eh); + } + + struct scoped_minimize_core { smt_solver& s; expr_ref_vector m_assumptions; diff --git a/src/smt/smt_theory.cpp b/src/smt/smt_theory.cpp index 525285466..04a29e904 100644 --- a/src/smt/smt_theory.cpp +++ b/src/smt/smt_theory.cpp @@ -150,6 +150,18 @@ namespace smt { return lit; } + literal theory::mk_literal(expr* _e) { + expr_ref e(_e, m); + bool is_not = m.is_not(_e, _e); + if (!ctx.e_internalized(_e)) { + ctx.internalize(_e, is_quantifier(_e)); + } + literal lit = ctx.get_literal(_e); + ctx.mark_as_relevant(lit); + if (is_not) lit.neg(); + return lit; + } + enode* theory::ensure_enode(expr* e) { if (!ctx.e_internalized(e)) { ctx.internalize(e, is_quantifier(e)); diff --git a/src/smt/smt_theory.h b/src/smt/smt_theory.h index 991dbd5d7..61dc97c82 100644 --- a/src/smt/smt_theory.h +++ b/src/smt/smt_theory.h @@ -533,6 +533,8 @@ namespace smt { literal mk_preferred_eq(expr* a, expr* b); + literal mk_literal(expr* e); + enode* ensure_enode(expr* e); enode* get_root(expr* e) { return ensure_enode(e)->get_root(); } diff --git a/src/smt/theory_bv.cpp b/src/smt/theory_bv.cpp index c0218fdf8..033df7aa0 100644 --- a/src/smt/theory_bv.cpp +++ b/src/smt/theory_bv.cpp @@ -465,6 +465,11 @@ namespace smt { void theory_bv::fixed_var_eh(theory_var v) { numeral val; VERIFY(get_fixed_value(v, val)); + enode* n = get_enode(v); + if (ctx.watches_fixed(n)) { + expr_ref num(m_util.mk_numeral(val, m.get_sort(n->get_owner())), m); + ctx.assign_fixed(n, num, m_bits[v]); + } unsigned sz = get_bv_size(v); value_sort_pair key(val, sz); theory_var v2; @@ -528,8 +533,8 @@ namespace smt { void theory_bv::internalize_num(app * n) { SASSERT(!ctx.e_internalized(n)); numeral val; - unsigned sz; - m_util.is_numeral(n, val, sz); + unsigned sz = 0; + VERIFY(m_util.is_numeral(n, val, sz)); enode * e = mk_enode(n); // internalizer is marking enodes as interpreted whenever the associated ast is a value and a constant. // e->mark_as_interpreted(); diff --git a/src/smt/theory_seq.cpp b/src/smt/theory_seq.cpp index f1ef7832d..c1e182aa2 100644 --- a/src/smt/theory_seq.cpp +++ b/src/smt/theory_seq.cpp @@ -2785,18 +2785,6 @@ literal theory_seq::mk_simplified_literal(expr * _e) { return mk_literal(e); } -literal theory_seq::mk_literal(expr* _e) { - expr_ref e(_e, m); - bool is_not = m.is_not(_e, _e); - if (!ctx.e_internalized(_e)) { - ctx.internalize(_e, is_quantifier(_e)); - } - literal lit = ctx.get_literal(_e); - ctx.mark_as_relevant(lit); - if (is_not) lit.neg(); - return lit; -} - literal theory_seq::mk_seq_eq(expr* a, expr* b) { SASSERT(m_util.is_seq(a)); return mk_literal(m_sk.mk_eq(a, b)); diff --git a/src/smt/theory_seq.h b/src/smt/theory_seq.h index f0e5e06f2..f98eaf9cf 100644 --- a/src/smt/theory_seq.h +++ b/src/smt/theory_seq.h @@ -575,7 +575,6 @@ namespace smt { expr_ref add_elim_string_axiom(expr* n); void add_in_re_axiom(expr* n); - literal mk_literal(expr* n); literal mk_simplified_literal(expr* n); literal mk_eq_empty(expr* n, bool phase = true); literal mk_seq_eq(expr* a, expr* b); diff --git a/src/smt/theory_str_mc.cpp b/src/smt/theory_str_mc.cpp index 9b7df6c64..c8a6f3dcb 100644 --- a/src/smt/theory_str_mc.cpp +++ b/src/smt/theory_str_mc.cpp @@ -446,6 +446,10 @@ namespace smt { ast_manager & m = get_manager(); ast_manager & sub_m = subsolver.m(); + + // NSB code review: to remove dependencies on subsolver.get_context(). + // It uses a method that should be removed from smt_kernel. + // currently sub_ctx is used to retrieve a rewriter. Theory_str already has a rewriter attahed. context & sub_ctx = subsolver.get_context(); expr * str = nullptr, *re = nullptr; diff --git a/src/smt/user_propagator.cpp b/src/smt/user_propagator.cpp new file mode 100644 index 000000000..c4c15481f --- /dev/null +++ b/src/smt/user_propagator.cpp @@ -0,0 +1,85 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + user_propagator.cpp + +Abstract: + + User theory propagator plugin. + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-17 + +--*/ + + +#include "smt/user_propagator.h" +#include "smt/smt_context.h" + +using namespace smt; + +user_propagator::user_propagator(context& ctx): + theory(ctx, ctx.get_manager().mk_family_id("user_propagator")), + m_qhead(0) +{} + +unsigned user_propagator::add_expr(expr* e) { + return mk_var(ensure_enode(e)); +} + +void user_propagator::new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits) { + m_id2justification.setx(v, literal_vector(num_lits, jlits), literal_vector()); + m_fixed_eh(m_user_context, v, value); +} + +void user_propagator::push_scope_eh() { + theory::push_scope_eh(); + m_push_eh(m_user_context); + m_prop_lim.push_back(m_prop.size()); +} + +void user_propagator::pop_scope_eh(unsigned num_scopes) { + m_pop_eh(m_user_context, num_scopes); + theory::pop_scope_eh(num_scopes); + unsigned old_sz = m_prop_lim.size() - num_scopes; + m_prop.shrink(m_prop_lim[old_sz]); + m_prop_lim.shrink(old_sz); +} + +bool user_propagator::can_propagate() { + return m_qhead < m_prop.size(); +} + +void user_propagator::propagate() { + unsigned qhead = m_qhead; + literal_vector lits; + enode_pair_vector eqs; + justification* js; + while (qhead < m_prop.size() && !ctx.inconsistent()) { + auto const& prop = m_prop[qhead]; + lits.reset(); + for (unsigned id : prop.m_ids) + lits.append(m_id2justification[id]); + if (m.is_false(prop.m_conseq)) { + js = ctx.mk_justification( + ext_theory_conflict_justification( + get_id(), ctx.get_region(), lits.size(), lits.c_ptr(), eqs.size(), eqs.c_ptr(), 0, nullptr)); + ctx.set_conflict(js); + } + else { + literal lit = mk_literal(prop.m_conseq); + js = ctx.mk_justification( + ext_theory_propagation_justification( + get_id(), ctx.get_region(), lits.size(), lits.c_ptr(), eqs.size(), eqs.c_ptr(), lit)); + ctx.assign(lit, js); + } + ++qhead; + } + ctx.push_trail(value_trail(m_qhead)); + m_qhead = qhead; +} + + diff --git a/src/smt/user_propagator.h b/src/smt/user_propagator.h new file mode 100644 index 000000000..4eaebc0d6 --- /dev/null +++ b/src/smt/user_propagator.h @@ -0,0 +1,97 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + user_propagator.h + +Abstract: + + User-propagator plugin. + Or, user-propagator in response to registered + terms receiveing fixed values. + +Author: + + Nikolaj Bjorner (nbjorner) 2020-08-17 + +Notes: + +- could also be complemented with disequalities to fixed values to narrow range of variables. + + +--*/ + +#pragma once + +#include "smt/smt_theory.h" + +namespace smt { + class user_propagator : public theory { + void* m_user_context; + std::function m_fixed_eh; + std::function m_push_eh; + std::function m_pop_eh; + struct prop_info { + unsigned_vector m_ids; + expr_ref m_conseq; + prop_info(unsigned sz, unsigned const* ids, expr_ref const& c): + m_ids(sz, ids), + m_conseq(c) + {} + }; + unsigned m_qhead; + vector m_prop; + unsigned_vector m_prop_lim; + vector m_id2justification; + + public: + user_propagator(context& ctx); + + ~user_propagator() override {} + + /* + * \brief initial setup for user propagator. + */ + void add( + void* ctx, + std::function& fixed_eh, + std::function& push_eh, + std::function& pop_eh) { + m_user_context = ctx; + m_fixed_eh = fixed_eh; + m_push_eh = push_eh; + m_pop_eh = pop_eh; + } + + unsigned add_expr(expr* e); + + void add_propagation(unsigned sz, unsigned const* ids, expr* conseq) { + m_prop.push_back(prop_info(sz, ids, expr_ref(conseq, m))); + } + + void new_fixed_eh(theory_var v, expr* value, unsigned num_lits, literal const* jlits); + + theory * mk_fresh(context * new_ctx) override { UNREACHABLE(); return alloc(user_propagator, *new_ctx); } + bool internalize_atom(app * atom, bool gate_ctx) override { UNREACHABLE(); return false; } + bool internalize_term(app * term) override { UNREACHABLE(); return false; } + void new_eq_eh(theory_var v1, theory_var v2) override { UNREACHABLE(); } + void new_diseq_eh(theory_var v1, theory_var v2) override { UNREACHABLE(); } + bool use_diseqs() const override { return false; } + bool build_models() const override { return false; } + final_check_status final_check_eh() override { UNREACHABLE(); return FC_DONE; } + void reset_eh() override {} + void assign_eh(bool_var v, bool is_true) override { UNREACHABLE(); } + void init_search_eh() override {} + void push_scope_eh() override; + void pop_scope_eh(unsigned num_scopes) override; + void restart_eh() override {} + void collect_statistics(::statistics & st) const override {} + model_value_proc * mk_value(enode * n, model_generator & mg) override { return nullptr; } + void init_model(model_generator & m) override {} + bool include_func_interp(func_decl* f) override { return false; } + bool can_propagate() override; + void propagate() override; + void display(std::ostream& out) const {} + }; +}; diff --git a/src/solver/solver.h b/src/solver/solver.h index e3d980021..6adefc6e9 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -238,6 +238,14 @@ public: virtual expr_ref get_implied_upper_bound(expr* e) = 0; + virtual void register_user_propagator( + void* ctx, + std::function& fixed_eh, + std::function& push_eh, + std::function& pop_eh) { + throw default_exception("user-propagators are only supported on the SMT solver"); + } + /** \brief Display the content of this solver.