3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-12 12:08:18 +00:00

exposing user propagators over .Net

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2022-05-07 11:08:40 -07:00
parent 3ae781039b
commit 14214c5a07
7 changed files with 296 additions and 18 deletions

View file

@ -42,6 +42,7 @@ IN_ARRAY = 3
OUT_ARRAY = 4 OUT_ARRAY = 4
INOUT_ARRAY = 5 INOUT_ARRAY = 5
OUT_MANAGED_ARRAY = 6 OUT_MANAGED_ARRAY = 6
FN_PTR = 7
# Primitive Types # Primitive Types
VOID = 0 VOID = 0
@ -61,11 +62,16 @@ FLOAT = 13
CHAR = 14 CHAR = 14
CHAR_PTR = 15 CHAR_PTR = 15
FIRST_FN_ID = 50
FIRST_OBJ_ID = 100 FIRST_OBJ_ID = 100
def is_obj(ty): def is_obj(ty):
return ty >= FIRST_OBJ_ID return ty >= FIRST_OBJ_ID
def is_fn(ty):
return FIRST_FN_ID <= ty and ty < FIRST_OBJ_ID
Type2Str = { VOID : 'void', VOID_PTR : 'void*', INT : 'int', UINT : 'unsigned', INT64 : 'int64_t', UINT64 : 'uint64_t', DOUBLE : 'double', Type2Str = { VOID : 'void', VOID_PTR : 'void*', INT : 'int', UINT : 'unsigned', INT64 : 'int64_t', UINT64 : 'uint64_t', DOUBLE : 'double',
FLOAT : 'float', STRING : 'Z3_string', STRING_PTR : 'Z3_string_ptr', BOOL : 'bool', SYMBOL : 'Z3_symbol', FLOAT : 'float', STRING : 'Z3_string', STRING_PTR : 'Z3_string_ptr', BOOL : 'bool', SYMBOL : 'Z3_symbol',
PRINT_MODE : 'Z3_ast_print_mode', ERROR_CODE : 'Z3_error_code', CHAR: 'char', CHAR_PTR: 'Z3_char_ptr' PRINT_MODE : 'Z3_ast_print_mode', ERROR_CODE : 'Z3_error_code', CHAR: 'char', CHAR_PTR: 'Z3_char_ptr'
@ -88,9 +94,12 @@ Type2ML = { VOID : 'unit', VOID_PTR : 'VOIDP', INT : 'int', UINT : 'int', INT64
FLOAT : 'float', STRING : 'string', STRING_PTR : 'char**', FLOAT : 'float', STRING : 'string', STRING_PTR : 'char**',
BOOL : 'bool', SYMBOL : 'z3_symbol', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'string' } BOOL : 'bool', SYMBOL : 'z3_symbol', PRINT_MODE : 'int', ERROR_CODE : 'int', CHAR : 'char', CHAR_PTR : 'string' }
Closures = []
class APITypes: class APITypes:
def __init__(self): def __init__(self):
self.next_type_id = FIRST_OBJ_ID self.next_type_id = FIRST_OBJ_ID
self.next_fntype_id = FIRST_FN_ID
def def_Type(self, var, c_type, py_type): def def_Type(self, var, c_type, py_type):
"""Process type definitions of the form def_Type(var, c_type, py_type) """Process type definitions of the form def_Type(var, c_type, py_type)
@ -103,24 +112,42 @@ class APITypes:
Type2Str[id] = c_type Type2Str[id] = c_type
Type2PyStr[id] = py_type Type2PyStr[id] = py_type
self.next_type_id += 1 self.next_type_id += 1
def def_Types(self, api_files): def def_Types(self, api_files):
global Closures
pat1 = re.compile(" *def_Type\(\'(.*)\',[^\']*\'(.*)\',[^\']*\'(.*)\'\)[ \t]*") pat1 = re.compile(" *def_Type\(\'(.*)\',[^\']*\'(.*)\',[^\']*\'(.*)\'\)[ \t]*")
pat2 = re.compile("Z3_DECLARE_CLOSURE\((.*),(.*), \((.*)\)\)")
for api_file in api_files: for api_file in api_files:
with open(api_file, 'r') as api: with open(api_file, 'r') as api:
for line in api: for line in api:
m = pat1.match(line) m = pat1.match(line)
if m: if m:
self.def_Type(m.group(1), m.group(2), m.group(3)) self.def_Type(m.group(1), m.group(2), m.group(3))
continue
m = pat2.match(line)
if m:
self.fun_Type(m.group(1))
Closures += [(m.group(1), m.group(2), m.group(3))]
continue
# #
# Populate object type entries in dotnet and ML bindings. # Populate object type entries in dotnet and ML bindings.
# #
for k in Type2Str: for k in Type2Str:
v = Type2Str[k] v = Type2Str[k]
if is_obj(k): if is_obj(k) or is_fn(k):
Type2Dotnet[k] = v Type2Dotnet[k] = v
Type2ML[k] = v.lower() Type2ML[k] = v.lower()
def fun_Type(self, var):
"""Process function type definitions"""
id = self.next_fntype_id
exec('%s = %s' % (var, id), globals())
Type2Str[id] = var
Type2PyStr[id] = var
self.next_fntype_id += 1
def type2str(ty): def type2str(ty):
global Type2Str global Type2Str
return Type2Str[ty] return Type2Str[ty]
@ -147,6 +174,9 @@ def _in(ty):
def _in_array(sz, ty): def _in_array(sz, ty):
return (IN_ARRAY, ty, sz) return (IN_ARRAY, ty, sz)
def _fnptr(ty):
return (FN_PTR, ty)
def _out(ty): def _out(ty):
return (OUT, ty) return (OUT, ty)
@ -180,7 +210,7 @@ def param_array_size_pos(p):
def param2str(p): def param2str(p):
if param_kind(p) == IN_ARRAY: if param_kind(p) == IN_ARRAY:
return "%s const *" % type2str(param_type(p)) return "%s const *" % type2str(param_type(p))
elif param_kind(p) == OUT_ARRAY or param_kind(p) == IN_ARRAY or param_kind(p) == INOUT_ARRAY: elif param_kind(p) == OUT_ARRAY or param_kind(p) == IN_ARRAY or param_kind(p) == INOUT_ARRAY or param_kind(p) == FN_PTR:
return "%s*" % type2str(param_type(p)) return "%s*" % type2str(param_type(p))
elif param_kind(p) == OUT: elif param_kind(p) == OUT:
return "%s*" % type2str(param_type(p)) return "%s*" % type2str(param_type(p))
@ -374,11 +404,20 @@ def mk_dotnet(dotnet):
v = Type2Str[k] v = Type2Str[k]
if is_obj(k): if is_obj(k):
dotnet.write(' using %s = System.IntPtr;\n' % v) dotnet.write(' using %s = System.IntPtr;\n' % v)
dotnet.write(' using voidp = System.IntPtr;\n')
dotnet.write('\n') dotnet.write('\n')
dotnet.write(' public class Native\n') dotnet.write(' public class Native\n')
dotnet.write(' {\n\n') dotnet.write(' {\n\n')
dotnet.write(' [UnmanagedFunctionPointer(CallingConvention.Cdecl)]\n')
dotnet.write(' public delegate void Z3_error_handler(Z3_context c, Z3_error_code e);\n\n') for name, ret, sig in Closures:
sig = sig.replace("void*","voidp").replace("unsigned","uint")
ret = ret.replace("void*","voidp").replace("unsigned","uint")
if "*" in sig or "*" in ret:
continue
dotnet.write(' [UnmanagedFunctionPointer(CallingConvention.Cdecl)]\n')
dotnet.write(f" public delegate {ret} {name}({sig});\n")
dotnet.write(' public class LIB\n') dotnet.write(' public class LIB\n')
dotnet.write(' {\n') dotnet.write(' {\n')
dotnet.write(' const string Z3_DLL_NAME = \"libz3\";\n' dotnet.write(' const string Z3_DLL_NAME = \"libz3\";\n'
@ -1070,6 +1109,9 @@ def def_API(name, result, params):
log_c.write(" }\n") log_c.write(" }\n")
log_c.write(" Ap(%s);\n" % sz_e) log_c.write(" Ap(%s);\n" % sz_e)
exe_c.write("reinterpret_cast<%s**>(in.get_obj_array(%s))" % (tstr, i)) exe_c.write("reinterpret_cast<%s**>(in.get_obj_array(%s))" % (tstr, i))
elif kind == FN_PTR:
log_c.write(" P(a%s);\n" % i)
exe_c.write("reinterpret_cast<%s>(in.get_obj(%s))" % (param2str(p), i))
else: else:
error ("unsupported parameter for %s, %s" % (name, p)) error ("unsupported parameter for %s, %s" % (name, p))
i = i + 1 i = i + 1

View file

@ -113,6 +113,7 @@ set(Z3_DOTNET_ASSEMBLY_SOURCES_IN_SRC_TREE
Tactic.cs Tactic.cs
TupleSort.cs TupleSort.cs
UninterpretedSort.cs UninterpretedSort.cs
UserPropagator.cs
Version.cs Version.cs
Z3Exception.cs Z3Exception.cs
Z3Object.cs Z3Object.cs

View file

@ -0,0 +1,222 @@
/*++
Copyright (c) 2012 Microsoft Corporation
Module Name:
UserPropagator.cs
Abstract:
User Propagator plugin
Author:
Nikolaj Bjorner (nbjorner) 2022-05-07
Notes:
// Todo: fresh, created, declare user function, register_cb, decide,
--*/
using System;
using System.Diagnostics;
using System.Linq;
using System.Collections.Generic;
using System.Runtime.InteropServices;
namespace Microsoft.Z3
{
using Z3_solver_callback = System.IntPtr;
using Z3_context = System.IntPtr;
using Z3_solver = System.IntPtr;
using voidp = System.IntPtr;
using Z3_ast = System.IntPtr;
/// <summary>
/// Propagator context for .Net
/// </summary>
public class UserPropagator
{
/// <summary>
/// Delegate type for fixed callback
/// </summary>
public delegate void FixedEh(Expr term, Expr value);
/// <summary>
/// Delegate type for equality or disequality callback
/// </summary>
public delegate void EqEh(Expr term, Expr value);
Solver solver;
GCHandle gch;
Z3_solver_callback callback;
FixedEh fixed_eh;
Action final_eh;
EqEh eq_eh;
EqEh diseq_eh;
unsafe static void _push(voidp ctx, Z3_solver_callback cb) {
var gch = GCHandle.FromIntPtr(ctx);
var prop = (UserPropagator)gch.Target;
prop.callback = cb;
prop.Push();
}
unsafe static void _pop(voidp ctx, Z3_solver_callback cb, uint num_scopes) {
var gch = GCHandle.FromIntPtr(ctx);
var prop = (UserPropagator)gch.Target;
prop.callback = cb;
prop.Pop(num_scopes);
}
unsafe static voidp _fresh(voidp ctx, Z3_context new_context) {
var gch = GCHandle.FromIntPtr(ctx);
var prop = (UserPropagator)gch.Target;
throw new Z3Exception("fresh is NYI");
}
unsafe static void _fixed(voidp ctx, Z3_solver_callback cb, Z3_ast _term, Z3_ast _value) {
var gch = GCHandle.FromIntPtr(ctx);
var prop = (UserPropagator)gch.Target;
var term = Expr.Create(prop.solver.Context, _term);
var value = Expr.Create(prop.solver.Context, _value);
prop.callback = cb;
prop.fixed_eh(term, value);
}
unsafe static void _final(voidp ctx, Z3_solver_callback cb) {
var gch = GCHandle.FromIntPtr(ctx);
var prop = (UserPropagator)gch.Target;
prop.callback = cb;
prop.final_eh();
}
unsafe static void _eq(voidp ctx, Z3_solver_callback cb, Z3_ast a, Z3_ast b) {
var gch = GCHandle.FromIntPtr(ctx);
var prop = (UserPropagator)gch.Target;
var s = Expr.Create(prop.solver.Context, a);
var t = Expr.Create(prop.solver.Context, b);
prop.callback = cb;
prop.eq_eh(s, t);
}
unsafe static void _diseq(voidp ctx, Z3_solver_callback cb, Z3_ast a, Z3_ast b) {
var gch = GCHandle.FromIntPtr(ctx);
var prop = (UserPropagator)gch.Target;
var s = Expr.Create(prop.solver.Context, a);
var t = Expr.Create(prop.solver.Context, b);
prop.callback = cb;
prop.diseq_eh(s, t);
}
/// <summary>
/// Propagator constructor from a solver class.
/// </summary>
public UserPropagator(Solver s)
{
gch = GCHandle.Alloc(this);
solver = s;
var cb = GCHandle.ToIntPtr(gch);
Native.Z3_solver_propagate_init(solver.Context.nCtx, solver.NativeObject, cb, _push, _pop, _fresh);
}
/// <summary>
/// Release provate memory.
/// </summary>
~UserPropagator()
{
gch.Free();
}
/// <summary>
/// Virtual method for push. It must be overwritten by inherited class.
/// </summary>
public virtual void Push() { throw new Z3Exception("Push method should be overwritten"); }
/// <summary>
/// Virtual method for pop. It must be overwritten by inherited class.
/// </summary>
public virtual void Pop(uint n) { throw new Z3Exception("Pop method should be overwritten"); }
/// <summary>
/// Virtual method for fresh. It must be overwritten by inherited class.
/// </summary>
public virtual UserPropagator Fresh(Context ctx) { throw new Z3Exception("Fresh method should be overwritten"); }
/// <summary>
/// Declare combination of assigned expressions a conflict
/// </summary>
void Conflict(params Expr[] terms) {
Propagate(terms, solver.Context.MkFalse());
}
/// <summary>
/// Propagate consequence
/// </summary>
void Propagate(Expr[] terms, Expr conseq) {
var nTerms = Z3Object.ArrayToNative(terms);
Native.Z3_solver_propagate_consequence(solver.Context.nCtx, this.callback, (uint)nTerms.Length, nTerms, 0u, null, null, conseq.NativeObject);
}
/// <summary>
/// Set fixed callback
/// </summary>
public FixedEh Fixed
{
set
{
this.fixed_eh = value;
Native.Z3_solver_propagate_fixed(solver.Context.nCtx, solver.NativeObject, _fixed);
}
}
/// <summary>
/// Set final callback
/// </summary>
public Action Final
{
set
{
this.final_eh = value;
Native.Z3_solver_propagate_final(solver.Context.nCtx, solver.NativeObject, _final);
}
}
/// <summary>
/// Set equality event callback
/// </summary>
public EqEh Eq
{
set
{
this.eq_eh = value;
Native.Z3_solver_propagate_eq(solver.Context.nCtx, solver.NativeObject, _eq);
}
}
/// <summary>
/// Set disequality event callback
/// </summary>
public EqEh Diseq
{
set
{
this.diseq_eh = value;
Native.Z3_solver_propagate_diseq(solver.Context.nCtx, solver.NativeObject, _diseq);
}
}
/// <summary>
/// Track assignments to a term
/// </summary>
public void Register(Expr term) {
Native.Z3_solver_propagate_register(solver.Context.nCtx, solver.NativeObject, term.NativeObject);
}
}
}

View file

@ -1444,7 +1444,7 @@ Z3_DECLARE_CLOSURE(Z3_fixed_eh, void, (void* ctx, Z3_solver_callback cb, Z3_as
Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast s, Z3_ast t)); Z3_DECLARE_CLOSURE(Z3_eq_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast s, Z3_ast t));
Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb)); Z3_DECLARE_CLOSURE(Z3_final_eh, void, (void* ctx, Z3_solver_callback cb));
Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t)); Z3_DECLARE_CLOSURE(Z3_created_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast t));
Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast*, unsigned*, Z3_lbool*)); Z3_DECLARE_CLOSURE(Z3_decide_eh, void, (void* ctx, Z3_solver_callback cb, Z3_ast* t, unsigned* idx, Z3_lbool* phase));
/** /**
@ -6733,6 +6733,8 @@ extern "C" {
\param push_eh - a callback invoked when scopes are pushed \param push_eh - a callback invoked when scopes are pushed
\param pop_eh - a callback invoked when scopes are poped \param pop_eh - a callback invoked when scopes are poped
\param fresh_eh - a solver may spawn new solvers internally. This callback is used to produce a fresh user_context to be associated with fresh solvers. \param fresh_eh - a solver may spawn new solvers internally. This callback is used to produce a fresh user_context to be associated with fresh solvers.
def_API('Z3_solver_propagate_init', VOID, (_in(CONTEXT), _in(SOLVER), _in(VOID_PTR), _fnptr(Z3_push_eh), _fnptr(Z3_pop_eh), _fnptr(Z3_fresh_eh)))
*/ */
void Z3_API Z3_solver_propagate_init( void Z3_API Z3_solver_propagate_init(
@ -6748,6 +6750,8 @@ extern "C" {
The supported expression types are The supported expression types are
- Booleans - Booleans
- Bit-vectors - Bit-vectors
def_API('Z3_solver_propagate_fixed', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_fixed_eh)))
*/ */
void Z3_API Z3_solver_propagate_fixed(Z3_context c, Z3_solver s, Z3_fixed_eh fixed_eh); void Z3_API Z3_solver_propagate_fixed(Z3_context c, Z3_solver s, Z3_fixed_eh fixed_eh);
@ -6764,22 +6768,30 @@ extern "C" {
The callback context can only be accessed (for propagation and for dynamically registering expressions) within a callback. The callback context can only be accessed (for propagation and for dynamically registering expressions) within a callback.
If the callback context gets used for propagation or conflicts, those propagations take effect and If the callback context gets used for propagation or conflicts, those propagations take effect and
may trigger new decision variables to be set. may trigger new decision variables to be set.
def_API('Z3_solver_propagate_final', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_final_eh)))
*/ */
void Z3_API Z3_solver_propagate_final(Z3_context c, Z3_solver s, Z3_final_eh final_eh); void Z3_API Z3_solver_propagate_final(Z3_context c, Z3_solver s, Z3_final_eh final_eh);
/** /**
\brief register a callback on expression equalities. \brief register a callback on expression equalities.
def_API('Z3_solver_propagate_eq', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_eq_eh)))
*/ */
void Z3_API Z3_solver_propagate_eq(Z3_context c, Z3_solver s, Z3_eq_eh eq_eh); void Z3_API Z3_solver_propagate_eq(Z3_context c, Z3_solver s, Z3_eq_eh eq_eh);
/** /**
\brief register a callback on expression dis-equalities. \brief register a callback on expression dis-equalities.
def_API('Z3_solver_propagate_diseq', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_eq_eh)))
*/ */
void Z3_API Z3_solver_propagate_diseq(Z3_context c, Z3_solver s, Z3_eq_eh eq_eh); void Z3_API Z3_solver_propagate_diseq(Z3_context c, Z3_solver s, Z3_eq_eh eq_eh);
/** /**
\brief register a callback when a new expression with a registered function is used by the solver \brief register a callback when a new expression with a registered function is used by the solver
The registered function appears at the top level and is created using \ref Z3_propagate_solver_declare. The registered function appears at the top level and is created using \ref Z3_propagate_solver_declare.
def_API('Z3_solver_propagate_created', VOID, (_in(CONTEXT), _in(SOLVER), _fnptr(Z3_created_eh)))
*/ */
void Z3_API Z3_solver_propagate_created(Z3_context c, Z3_solver s, Z3_created_eh created_eh); void Z3_API Z3_solver_propagate_created(Z3_context c, Z3_solver s, Z3_created_eh created_eh);
@ -6788,6 +6800,7 @@ extern "C" {
The callback may set the passed expression to another registered expression which will be selected instead. The callback may set the passed expression to another registered expression which will be selected instead.
In case the expression is a bitvector the bit to split on is determined by the bit argument and the In case the expression is a bitvector the bit to split on is determined by the bit argument and the
truth-value to try first is given by is_pos. In case the truth value is undefined the solver will decide. truth-value to try first is given by is_pos. In case the truth value is undefined the solver will decide.
*/ */
void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh); void Z3_API Z3_solver_propagate_decide(Z3_context c, Z3_solver s, Z3_decide_eh decide_eh);

View file

@ -195,8 +195,10 @@ public:
void operator()(app* a) void operator()(app* a)
{ {
if (a->get_family_id() == null_family_id && m_au.is_array(a)) { if (a->get_family_id() == null_family_id && m_au.is_array(a)) {
if (m_sort && m_sort != a->get_sort()) { return; } if (m_sort && m_sort != a->get_sort())
if (!m_sort) { m_sort = a->get_sort(); } return;
if (!m_sort)
m_sort = a->get_sort();
m_symbs.insert(a->get_decl()); m_symbs.insert(a->get_decl());
} }
} }
@ -208,16 +210,10 @@ public:
bool lemma_array_eq_generalizer::is_array_eq (ast_manager &m, expr* e) { bool lemma_array_eq_generalizer::is_array_eq (ast_manager &m, expr* e) {
expr *e1 = nullptr, *e2 = nullptr; expr *e1 = nullptr, *e2 = nullptr;
if (m.is_eq(e, e1, e2) && is_app(e1) && is_app(e2)) { array_util au(m);
app *a1 = to_app(e1); return m.is_eq(e, e1, e2) &&
app *a2 = to_app(e2); is_uninterp(e1) && is_uninterp(e2) &&
array_util au(m); au.is_array(e1) && au.is_array(e2);
if (a1->get_family_id() == null_family_id &&
a2->get_family_id() == null_family_id &&
au.is_array(a1) && au.is_array(a2))
return true;
}
return false;
} }
void lemma_array_eq_generalizer::operator() (lemma_ref &lemma) void lemma_array_eq_generalizer::operator() (lemma_ref &lemma)

View file

@ -288,7 +288,7 @@ namespace sat {
inline clause_allocator& cls_allocator() { return m_cls_allocator[m_cls_allocator_idx]; } inline clause_allocator& cls_allocator() { return m_cls_allocator[m_cls_allocator_idx]; }
inline clause_allocator const& cls_allocator() const { return m_cls_allocator[m_cls_allocator_idx]; } inline clause_allocator const& cls_allocator() const { return m_cls_allocator[m_cls_allocator_idx]; }
inline clause * alloc_clause(unsigned num_lits, literal const * lits, bool learned) { return cls_allocator().mk_clause(num_lits, lits, learned); } inline clause * alloc_clause(unsigned num_lits, literal const * lits, bool learned) { return cls_allocator().mk_clause(num_lits, lits, learned); }
inline void dealloc_clause(clause* c) { cls_allocator().del_clause(c); } inline void dealloc_clause(clause* c) { cls_allocator().del_clause(c); }
struct cmp_activity; struct cmp_activity;
void defrag_clauses(); void defrag_clauses();
bool should_defrag(); bool should_defrag();

View file

@ -1190,6 +1190,10 @@ namespace arith {
} }
void solver::assign(literal lit, literal_vector const& core, svector<enode_pair> const& eqs, vector<parameter> const& params) { void solver::assign(literal lit, literal_vector const& core, svector<enode_pair> const& eqs, vector<parameter> const& params) {
std::cout << "assign: ";
for (auto const& p : params)
std::cout << p << " ";
std::cout << "\n";
if (core.size() < small_lemma_size() && eqs.empty()) { if (core.size() < small_lemma_size() && eqs.empty()) {
m_core2.reset(); m_core2.reset();
for (auto const& c : core) for (auto const& c : core)