mirror of
https://github.com/Z3Prover/z3
synced 2025-04-08 10:25:18 +00:00
force-push on new_eq, new_diseq in user propagator, other fixes to Python bindings for user propagator
This update allows the python bindings for user-propagator to handle functions that are declared to be registered with the user propagator plugin. It fixes a bug in UserPropagateBase.add to allow registering terms dynamically during search. It also fixes a bug in theory_user_propagate as scopes were not fully pushed when the solver gets the callbacks for new equalities and new disequalities. It also adds equality and disequality interfaces to the sat/smt solver version (which isn't being exercised in earnest yet)
This commit is contained in:
parent
3e38bbb009
commit
5c2c0ae900
|
@ -11364,12 +11364,12 @@ def to_ContextObj(ptr,):
|
|||
return ctx
|
||||
|
||||
|
||||
def user_prop_fresh(ctx, new_ctx):
|
||||
def user_prop_fresh(ctx, _new_ctx):
|
||||
_prop_closures.set_threaded()
|
||||
prop = _prop_closures.get(ctx)
|
||||
nctx = Context()
|
||||
Z3_del_context(nctx.ctx)
|
||||
new_ctx = to_ContextObj(new_ctx)
|
||||
new_ctx = to_ContextObj(_new_ctx)
|
||||
nctx.ctx = new_ctx
|
||||
nctx.eh = Z3_set_error_handler(new_ctx, z3_error_handler)
|
||||
nctx.owner = False
|
||||
|
@ -11390,6 +11390,13 @@ def user_prop_fixed(ctx, cb, id, value):
|
|||
prop.fixed(id, value)
|
||||
prop.cb = None
|
||||
|
||||
def user_prop_created(ctx, cb, id):
|
||||
prop = _prop_closures.get(ctx)
|
||||
prop.cb = cb
|
||||
id = _to_expr_ref(to_Ast(id), prop.ctx())
|
||||
prop.created(id)
|
||||
prop.cb = None
|
||||
|
||||
def user_prop_final(ctx, cb):
|
||||
prop = _prop_closures.get(ctx)
|
||||
prop.cb = cb
|
||||
|
@ -11417,10 +11424,32 @@ _user_prop_push = Z3_push_eh(user_prop_push)
|
|||
_user_prop_pop = Z3_pop_eh(user_prop_pop)
|
||||
_user_prop_fresh = Z3_fresh_eh(user_prop_fresh)
|
||||
_user_prop_fixed = Z3_fixed_eh(user_prop_fixed)
|
||||
_user_prop_created = Z3_created_eh(user_prop_created)
|
||||
_user_prop_final = Z3_final_eh(user_prop_final)
|
||||
_user_prop_eq = Z3_eq_eh(user_prop_eq)
|
||||
_user_prop_diseq = Z3_eq_eh(user_prop_diseq)
|
||||
|
||||
def PropagateFunction(name, *sig):
|
||||
"""Create a function that gets tracked by user propagator.
|
||||
Every term headed by this function symbol is tracked.
|
||||
If a term is fixed and the fixed callback is registered a
|
||||
callback is invoked that the term headed by this function is fixed.
|
||||
"""
|
||||
sig = _get_args(sig)
|
||||
if z3_debug():
|
||||
_z3_assert(len(sig) > 0, "At least two arguments expected")
|
||||
arity = len(sig) - 1
|
||||
rng = sig[arity]
|
||||
if z3_debug():
|
||||
_z3_assert(is_sort(rng), "Z3 sort expected")
|
||||
dom = (Sort * arity)()
|
||||
for i in range(arity):
|
||||
if z3_debug():
|
||||
_z3_assert(is_sort(sig[i]), "Z3 sort expected")
|
||||
dom[i] = sig[i].ast
|
||||
ctx = rng.ctx
|
||||
return FuncDeclRef(Z3_solver_propagate_declare(ctx.ref(), to_symbol(name, ctx), arity, dom, rng.ast), ctx)
|
||||
|
||||
|
||||
class UserPropagateBase:
|
||||
|
||||
|
@ -11443,6 +11472,7 @@ class UserPropagateBase:
|
|||
self.final = None
|
||||
self.eq = None
|
||||
self.diseq = None
|
||||
self.created = None
|
||||
if ctx:
|
||||
self.fresh_ctx = ctx
|
||||
if s:
|
||||
|
@ -11473,6 +11503,13 @@ class UserPropagateBase:
|
|||
Z3_solver_propagate_fixed(self.ctx_ref(), self.solver.solver, _user_prop_fixed)
|
||||
self.fixed = fixed
|
||||
|
||||
def add_created(self, created):
|
||||
assert not self.created
|
||||
assert not self._ctx
|
||||
if self.solver:
|
||||
Z3_solver_propagate_created(self.ctx_ref(), self.solver.solver, _user_prop_created)
|
||||
self.created = created
|
||||
|
||||
def add_final(self, final):
|
||||
assert not self.final
|
||||
assert not self._ctx
|
||||
|
@ -11504,9 +11541,12 @@ class UserPropagateBase:
|
|||
raise Z3Exception("fresh needs to be overwritten")
|
||||
|
||||
def add(self, e):
|
||||
assert self.solver
|
||||
assert not self._ctx
|
||||
Z3_solver_propagate_register(self.ctx_ref(), self.solver.solver, e.ast)
|
||||
if self.solver:
|
||||
Z3_solver_propagate_register(self.ctx_ref(), self.solver.solver, e.ast)
|
||||
else:
|
||||
Z3_solver_propagate_register_cb(self.ctx_ref(), ctypes.c_void_p(self.cb), e.ast)
|
||||
|
||||
|
||||
#
|
||||
# Propagation can only be invoked as during a fixed or final callback.
|
||||
|
@ -11519,5 +11559,5 @@ class UserPropagateBase:
|
|||
Z3_solver_propagate_consequence(e.ctx.ref(), ctypes.c_void_p(
|
||||
self.cb), num_fixed, _ids, num_eqs, _lhs, _rhs, e.ast)
|
||||
|
||||
def conflict(self, deps):
|
||||
self.propagate(BoolVal(False, self.ctx()), deps, eqs=[])
|
||||
def conflict(self, deps = [], eqs = []):
|
||||
self.propagate(BoolVal(False, self.ctx()), deps, eqs)
|
||||
|
|
|
@ -1429,7 +1429,7 @@ ast_manager::~ast_manager() {
|
|||
}
|
||||
m_plugins.reset();
|
||||
while (!m_ast_table.empty()) {
|
||||
DEBUG_CODE(IF_VERBOSE(0, verbose_stream() << "ast_manager LEAKED: " << m_ast_table.size() << std::endl););
|
||||
DEBUG_CODE(IF_VERBOSE(1, verbose_stream() << "ast_manager LEAKED: " << m_ast_table.size() << std::endl););
|
||||
ptr_vector<ast> roots;
|
||||
ast_mark mark;
|
||||
for (ast * n : m_ast_table) {
|
||||
|
@ -1465,22 +1465,21 @@ ast_manager::~ast_manager() {
|
|||
break;
|
||||
}
|
||||
}
|
||||
for (ast * n : m_ast_table) {
|
||||
if (!mark.is_marked(n)) {
|
||||
for (ast * n : m_ast_table)
|
||||
if (!mark.is_marked(n))
|
||||
roots.push_back(n);
|
||||
}
|
||||
}
|
||||
|
||||
SASSERT(!roots.empty());
|
||||
for (unsigned i = 0; i < roots.size(); ++i) {
|
||||
ast* a = roots[i];
|
||||
DEBUG_CODE(
|
||||
std::cout << "Leaked: ";
|
||||
if (is_sort(a)) {
|
||||
std::cout << to_sort(a)->get_name() << "\n";
|
||||
}
|
||||
else {
|
||||
std::cout << mk_ll_pp(a, *this, false) << "id: " << a->get_id() << "\n";
|
||||
});
|
||||
IF_VERBOSE(1,
|
||||
verbose_stream() << "Leaked: ";
|
||||
if (is_sort(a))
|
||||
verbose_stream() << to_sort(a)->get_name() << "\n";
|
||||
else
|
||||
verbose_stream() << mk_ll_pp(a, *this, false) << "id: " << a->get_id() << "\n";
|
||||
););
|
||||
a->m_ref_count = 0;
|
||||
delete_node(a);
|
||||
}
|
||||
|
|
|
@ -131,6 +131,21 @@ namespace user_solver {
|
|||
m_id2justification.setx(v, lits, sat::literal_vector());
|
||||
m_fixed_eh(m_user_context, this, var2expr(v), lit.sign() ? m.mk_false() : m.mk_true());
|
||||
}
|
||||
|
||||
void solver::new_eq_eh(euf::th_eq const& eq) {
|
||||
if (!m_eq_eh)
|
||||
return;
|
||||
force_push();
|
||||
m_eq_eh(m_user_context, this, var2expr(eq.v1()), var2expr(eq.v2()));
|
||||
}
|
||||
|
||||
void solver::new_diseq_eh(euf::th_eq const& de) {
|
||||
if (!m_diseq_eh)
|
||||
return;
|
||||
force_push();
|
||||
m_diseq_eh(m_user_context, this, var2expr(eq.v1()), var2expr(eq.v2()));
|
||||
}
|
||||
|
||||
|
||||
void solver::push_core() {
|
||||
th_euf_solver::push_core();
|
||||
|
|
|
@ -144,6 +144,10 @@ namespace user_solver {
|
|||
bool get_case_split(sat::bool_var& var, lbool &phase) override;
|
||||
|
||||
void asserted(sat::literal lit) override;
|
||||
bool use_diseqs() const override { return (bool)m_diseq_eh; }
|
||||
void new_eq_eh(euf::th_eq const& eq) override;
|
||||
void new_diseq_eh(euf::th_eq const& de) override;
|
||||
|
||||
sat::check_result check() override;
|
||||
void push_core() override;
|
||||
void pop_core(unsigned n) override;
|
||||
|
|
|
@ -821,6 +821,8 @@ namespace smt {
|
|||
SASSERT(t2 != null_theory_id);
|
||||
theory_var v1 = m_fparams.m_new_core2th_eq ? get_closest_var(n1, t2) : r1->get_th_var(t2);
|
||||
|
||||
TRACE("merge_theory_vars", tout << get_theory(t2)->get_name() << ": " << v2 << " == " << v1 << "\n");
|
||||
|
||||
if (v1 != null_theory_var) {
|
||||
// only send the equality to the theory, if the equality was not propagated by it.
|
||||
if (t2 != from_th)
|
||||
|
@ -839,6 +841,7 @@ namespace smt {
|
|||
SASSERT(v1 != null_theory_var);
|
||||
SASSERT(t1 != null_theory_id);
|
||||
theory_var v2 = r2->get_th_var(t1);
|
||||
TRACE("merge_theory_vars", tout << get_theory(t1)->get_name() << ": " << v2 << " == " << v1 << "\n");
|
||||
if (v2 == null_theory_var) {
|
||||
r2->add_th_var(v1, t1, m_region);
|
||||
push_new_th_diseqs(r2, v1, get_theory(t1));
|
||||
|
|
|
@ -146,7 +146,9 @@ final_check_status theory_user_propagator::final_check_eh() {
|
|||
catch (...) {
|
||||
throw default_exception("Exception thrown in \"final\"-callback");
|
||||
}
|
||||
CTRACE("user_propagate", can_propagate(), tout << "can propagate\n");
|
||||
propagate();
|
||||
CTRACE("user_propagate", ctx.inconsistent(), tout << "inconsistent\n");
|
||||
// check if it became inconsistent or something new was propagated/registered
|
||||
bool done = (sz1 == m_prop.size()) && (sz2 == m_expr2var.size()) && !ctx.inconsistent();
|
||||
return done ? FC_DONE : FC_CONTINUE;
|
||||
|
@ -298,13 +300,17 @@ void theory_user_propagator::propagate_consequence(prop_info const& prop) {
|
|||
m_eqs.reset();
|
||||
for (expr* id : prop.m_ids)
|
||||
m_lits.append(m_id2justification[expr2var(id)]);
|
||||
for (auto const& p : prop.m_eqs)
|
||||
m_eqs.push_back(enode_pair(get_enode(expr2var(p.first)), get_enode(expr2var(p.second))));
|
||||
DEBUG_CODE(for (auto const& p : m_eqs) VERIFY(p.first->get_root() == p.second->get_root()););
|
||||
for (auto const& [a,b] : prop.m_eqs)
|
||||
if (a != b)
|
||||
m_eqs.push_back(enode_pair(get_enode(expr2var(a)), get_enode(expr2var(b))));
|
||||
DEBUG_CODE(for (auto const& [a, b] : m_eqs) VERIFY(a->get_root() == b->get_root()););
|
||||
DEBUG_CODE(for (expr* e : prop.m_ids) VERIFY(m_fixed.contains(expr2var(e))););
|
||||
DEBUG_CODE(for (literal lit : m_lits) VERIFY(ctx.get_assignment(lit) == l_true););
|
||||
|
||||
TRACE("user_propagate", tout << "propagating #" << prop.m_conseq->get_id() << ": " << prop.m_conseq << "\n");
|
||||
TRACE("user_propagate", tout << "propagating #" << prop.m_conseq->get_id() << ": " << prop.m_conseq << "\n";
|
||||
for (auto const& [a,b] : m_eqs) tout << enode_pp(a, ctx) << " == " << enode_pp(b, ctx) << "\n";
|
||||
for (expr* e : prop.m_ids) tout << mk_pp(e, m) << "\n";
|
||||
for (literal lit : m_lits) tout << lit << "\n");
|
||||
|
||||
if (m.is_false(prop.m_conseq)) {
|
||||
js = ctx.mk_justification(
|
||||
|
@ -341,9 +347,9 @@ void theory_user_propagator::propagate_new_fixed(prop_info const& prop) {
|
|||
|
||||
|
||||
void theory_user_propagator::propagate() {
|
||||
TRACE("user_propagate", tout << "propagating queue head: " << m_qhead << " prop queue: " << m_prop.size() << "\n");
|
||||
if (m_qhead == m_prop.size() && m_to_add_qhead == m_to_add.size())
|
||||
return;
|
||||
TRACE("user_propagate", tout << "propagating queue head: " << m_qhead << " prop queue: " << m_prop.size() << "\n");
|
||||
force_push();
|
||||
|
||||
unsigned qhead = m_to_add_qhead;
|
||||
|
|
|
@ -140,10 +140,11 @@ namespace smt {
|
|||
bool get_case_split(bool_var& var, bool& is_pos);
|
||||
|
||||
theory * mk_fresh(context * new_ctx) override;
|
||||
char const* get_name() const override { return "user_propagate"; }
|
||||
bool internalize_atom(app* atom, bool gate_ctx) override;
|
||||
bool internalize_term(app* term) override;
|
||||
void new_eq_eh(theory_var v1, theory_var v2) override { if (m_eq_eh) m_eq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); }
|
||||
void new_diseq_eh(theory_var v1, theory_var v2) override { if (m_diseq_eh) m_diseq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); }
|
||||
void new_eq_eh(theory_var v1, theory_var v2) override { force_push(); if (m_eq_eh) m_eq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); }
|
||||
void new_diseq_eh(theory_var v1, theory_var v2) override { force_push(); if (m_diseq_eh) m_diseq_eh(m_user_context, this, var2expr(v1), var2expr(v2)); }
|
||||
bool use_diseqs() const override { return ((bool)m_diseq_eh); }
|
||||
bool build_models() const override { return false; }
|
||||
final_check_status final_check_eh() override;
|
||||
|
|
Loading…
Reference in a new issue