diff --git a/param-tuning-experiment.py b/param-tuning-experiment.py index 72068329a..bdaf10ae6 100644 --- a/param-tuning-experiment.py +++ b/param-tuning-experiment.py @@ -1,9 +1,15 @@ -import os -from more_itertools import iterate -from z3 import * from multiprocessing import Process import math, random +import sys, os +sys.path.insert(0, os.path.abspath("build/python")) +os.environ["Z3_LIBRARY_PATH"] = os.path.abspath("build") + +# import z3 +# print("Using z3 from:", z3.__file__) + +from z3 import * + MAX_CONFLICTS = 100 MAX_EXAMPLES = 5 bench_dir = "../z3-poly-testing/inputs/QF_NIA_small" @@ -68,7 +74,8 @@ def stats_tuple(st): def run_prefix_step(S, K, clause_limit): clauses = [] - def on_clause(premises, deps, clause): + def on_clause(premises, deps, clause, status): + print(f" [OnClause] collected clause status: {status}, clause: {clause}") if len(clauses) < clause_limit: clauses.append(clause) @@ -87,10 +94,7 @@ def replay_prefix_on_pps(PPS_solver, clauses, param_state, budget): # For each learned clause Cj = [l1, l2, ...], check ¬(l1 ∨ l2 ∨ ...) for idx, Cj in enumerate(clauses): - if isinstance(Cj, AstVector): - lits = [Cj[i].translate(PPS_solver.ctx) for i in range(len(Cj))] - else: - lits = [l.translate(PPS_solver.ctx) for l in Cj] + lits = [l.translate(PPS_solver.ctx) for l in Cj] negated_lits = [] for l in lits: diff --git a/scripts/update_api.py b/scripts/update_api.py index 5c28bcd3e..08eeaaf68 100755 --- a/scripts/update_api.py +++ b/scripts/update_api.py @@ -1941,7 +1941,7 @@ _error_handler_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint) _lib.Z3_set_error_handler.restype = None _lib.Z3_set_error_handler.argtypes = [ContextObj, _error_handler_type] -Z3_on_clause_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p) +Z3_on_clause_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint, ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p, ctypes.c_uint) Z3_push_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p) Z3_pop_eh = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint) Z3_fresh_eh = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p) diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 05b93d38b..75b6b0588 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -1031,14 +1031,14 @@ extern "C" { Z3_TRY; RESET_ERROR_CODE(); init_solver(c, s); - user_propagator::on_clause_eh_t _on_clause = [=](void* user_ctx, expr* proof, unsigned nd, unsigned const* deps, unsigned n, expr* const* _literals) { + user_propagator::on_clause_eh_t _on_clause = [=](void* user_ctx, expr* proof, unsigned nd, unsigned const* deps, unsigned n, expr* const* _literals, unsigned const status) { Z3_ast_vector_ref * literals = alloc(Z3_ast_vector_ref, *mk_c(c), mk_c(c)->m()); mk_c(c)->save_object(literals); expr_ref pr(proof, mk_c(c)->m()); scoped_ast_vector _sc(literals); for (unsigned i = 0; i < n; ++i) literals->m_ast_vector.push_back(_literals[i]); - on_clause_eh(user_ctx, of_expr(pr.get()), nd, deps, of_ast_vector(literals)); + on_clause_eh(user_ctx, of_expr(pr.get()), nd, deps, of_ast_vector(literals), status); }; to_solver_ref(s)->register_on_clause(user_context, _on_clause); auto& solver = *to_solver(s); diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 71f3ff79b..9332d9074 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -4288,20 +4288,20 @@ namespace z3 { return expr(ctx(), r); } - typedef std::function const& deps, expr_vector const& clause)> on_clause_eh_t; + typedef std::function const& deps, expr_vector const& clause, unsigned const status)> on_clause_eh_t; class on_clause { context& c; on_clause_eh_t m_on_clause; - static void _on_clause_eh(void* _ctx, Z3_ast _proof, unsigned n, unsigned const* dep, Z3_ast_vector _literals) { + static void _on_clause_eh(void* _ctx, Z3_ast _proof, unsigned n, unsigned const* dep, Z3_ast_vector _literals, unsigned const status) { on_clause* ctx = static_cast(_ctx); expr_vector lits(ctx->c, _literals); expr proof(ctx->c, _proof); std::vector deps; for (unsigned i = 0; i < n; ++i) deps.push_back(dep[i]); - ctx->m_on_clause(proof, deps, lits); + ctx->m_on_clause(proof, deps, lits, status); } public: on_clause(solver& s, on_clause_eh_t& on_clause_eh): c(s.ctx()) { diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 128726dae..b9aa9856e 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -11697,12 +11697,12 @@ def to_AstVectorObj(ptr,): # for UserPropagator we use a global dictionary, which isn't great code. _my_hacky_class = None -def on_clause_eh(ctx, p, n, dep, clause): +def on_clause_eh(ctx, p, n, dep, clause, status): onc = _my_hacky_class p = _to_expr_ref(to_Ast(p), onc.ctx) clause = AstVector(to_AstVectorObj(clause), onc.ctx) deps = [dep[i] for i in range(n)] - onc.on_clause(p, deps, clause) + onc.on_clause(p, deps, clause, status) _on_clause_eh = Z3_on_clause_eh(on_clause_eh) diff --git a/src/api/z3_api.h b/src/api/z3_api.h index baa2fa34c..3264226f3 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -1441,7 +1441,7 @@ Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb)); Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t)); Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t, unsigned idx, bool phase)); Z3_DECLARE_CLOSURE(Z3_on_binding_eh, bool, (void* ctx, Z3_solver_callback cb, Z3_ast q, Z3_ast inst)); -Z3_DECLARE_CLOSURE(Z3_on_clause_eh, void, (void* ctx, Z3_ast proof_hint, unsigned n, unsigned const* deps, Z3_ast_vector literals)); +Z3_DECLARE_CLOSURE(Z3_on_clause_eh, void, (void* ctx, Z3_ast proof_hint, unsigned n, unsigned const* deps, Z3_ast_vector literals, unsigned const status)); /** diff --git a/src/cmd_context/extra_cmds/proof_cmds.cpp b/src/cmd_context/extra_cmds/proof_cmds.cpp index ea585bfae..22348fd2b 100644 --- a/src/cmd_context/extra_cmds/proof_cmds.cpp +++ b/src/cmd_context/extra_cmds/proof_cmds.cpp @@ -314,7 +314,7 @@ public: if (m_trim) trim().assume(m_lits); if (m_on_clause_eh) - m_on_clause_eh(m_on_clause_ctx, assumption(), m_deps.size(), m_deps.data(), m_lits.size(), m_lits.data()); + m_on_clause_eh(m_on_clause_ctx, assumption(), m_deps.size(), m_deps.data(), m_lits.size(), m_lits.data(), 0u); m_lits.reset(); m_proof_hint.reset(); m_deps.reset(); @@ -328,7 +328,7 @@ public: if (m_trim) trim().infer(m_lits, m_proof_hint); if (m_on_clause_eh) - m_on_clause_eh(m_on_clause_ctx, m_proof_hint, m_deps.size(), m_deps.data(), m_lits.size(), m_lits.data()); + m_on_clause_eh(m_on_clause_ctx, m_proof_hint, m_deps.size(), m_deps.data(), m_lits.size(), m_lits.data(), 0u); m_lits.reset(); m_proof_hint.reset(); m_deps.reset(); @@ -342,7 +342,7 @@ public: if (m_trim) trim().del(m_lits); if (m_on_clause_eh) - m_on_clause_eh(m_on_clause_ctx, del(), m_deps.size(), m_deps.data(), m_lits.size(), m_lits.data()); + m_on_clause_eh(m_on_clause_ctx, del(), m_deps.size(), m_deps.data(), m_lits.size(), m_lits.data(), 0u); m_lits.reset(); m_proof_hint.reset(); m_deps.reset(); diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index 6f240a88d..b28f5e9a2 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -382,7 +382,7 @@ namespace euf { for (unsigned i = 0; i < n; ++i) m_clause.push_back(literal2expr(lits[i])); auto hint = status2proof_hint(st); - m_on_clause(m_on_clause_ctx, hint, 0, nullptr, m_clause.size(), m_clause.data()); + m_on_clause(m_on_clause_ctx, hint, 0, nullptr, m_clause.size(), m_clause.data(), 0u); } void solver::on_proof(unsigned n, literal const* lits, sat::status st) { diff --git a/src/smt/smt_clause_proof.cpp b/src/smt/smt_clause_proof.cpp index bc4105e13..b537a44dc 100644 --- a/src/smt/smt_clause_proof.cpp +++ b/src/smt/smt_clause_proof.cpp @@ -192,8 +192,19 @@ namespace smt { TRACE(clause_proof, tout << m_trail.size() << " " << st << " " << v << "\n";); if (ctx.get_fparams().m_clause_proof) m_trail.push_back(info(st, v, p)); - if (m_on_clause_eh) - m_on_clause_eh(m_on_clause_ctx, p, 0, nullptr, v.size(), v.data()); + if (m_on_clause_eh) { + // Encode status as an integer flag for simplicity. + unsigned st_code = 0; + switch (st) { + case status::assumption: st_code = 1; break; + case status::lemma: st_code = 2; break; + case status::th_lemma: st_code = 3; break; + case status::th_assumption: st_code = 4; break; + case status::deleted: st_code = 5; break; + default: st_code = 0; break; + } + m_on_clause_eh(m_on_clause_ctx, p, 0, nullptr, v.size(), v.data(), st_code); + } if (m_has_log) { init_pp_out(); diff --git a/src/tactic/user_propagator_base.h b/src/tactic/user_propagator_base.h index 1b480fb04..ffe87d27c 100644 --- a/src/tactic/user_propagator_base.h +++ b/src/tactic/user_propagator_base.h @@ -27,7 +27,7 @@ namespace user_propagator { typedef std::function pop_eh_t; typedef std::function created_eh_t; typedef std::function decide_eh_t; - typedef std::function on_clause_eh_t; + typedef std::function on_clause_eh_t; typedef std::function binding_eh_t; class plugin : public decl_plugin {