diff --git a/src/api/js/src/high-level/high-level.test.ts b/src/api/js/src/high-level/high-level.test.ts index fb8805d99..e9a2b46e7 100644 --- a/src/api/js/src/high-level/high-level.test.ts +++ b/src/api/js/src/high-level/high-level.test.ts @@ -890,4 +890,74 @@ describe('high-level', () => { expect(model.eval(z).eqIdentity(Int.val(5))).toBeTruthy(); }); }); + + describe('datatypes', () => { + it('should create simple enum datatype', async () => { + const { Datatype, Int, Bool, Solver } = api.Context('main'); + + // Create a simple Color enum datatype + const Color = Datatype('Color'); + Color.declare('red'); + Color.declare('green'); + Color.declare('blue'); + + const ColorSort = Color.create(); + + // Test that we can access the constructors + expect(typeof (ColorSort as any).red).not.toBe('undefined'); + expect(typeof (ColorSort as any).green).not.toBe('undefined'); + expect(typeof (ColorSort as any).blue).not.toBe('undefined'); + + // Test that we can access the recognizers + expect(typeof (ColorSort as any).is_red).not.toBe('undefined'); + expect(typeof (ColorSort as any).is_green).not.toBe('undefined'); + expect(typeof (ColorSort as any).is_blue).not.toBe('undefined'); + }); + + it('should create recursive list datatype', async () => { + const { Datatype, Int, Solver } = api.Context('main'); + + // Create a recursive List datatype like in the Python example + const List = Datatype('List'); + List.declare('cons', ['car', Int.sort()], ['cdr', List]); + List.declare('nil'); + + const ListSort = List.create(); + + // Test that constructors and accessors exist + expect(typeof (ListSort as any).cons).not.toBe('undefined'); + expect(typeof (ListSort as any).nil).not.toBe('undefined'); + expect(typeof (ListSort as any).is_cons).not.toBe('undefined'); + expect(typeof (ListSort as any).is_nil).not.toBe('undefined'); + expect(typeof (ListSort as any).car).not.toBe('undefined'); + expect(typeof (ListSort as any).cdr).not.toBe('undefined'); + }); + + it('should create mutually recursive tree datatypes', async () => { + const { Datatype, Int } = api.Context('main'); + + // Create mutually recursive Tree and TreeList datatypes + const Tree = Datatype('Tree'); + const TreeList = Datatype('TreeList'); + + Tree.declare('leaf', ['value', Int.sort()]); + Tree.declare('node', ['children', TreeList]); + TreeList.declare('nil'); + TreeList.declare('cons', ['car', Tree], ['cdr', TreeList]); + + const [TreeSort, TreeListSort] = Datatype.createDatatypes(Tree, TreeList); + + // Test that both datatypes have their constructors + expect(typeof (TreeSort as any).leaf).not.toBe('undefined'); + expect(typeof (TreeSort as any).node).not.toBe('undefined'); + expect(typeof (TreeListSort as any).nil).not.toBe('undefined'); + expect(typeof (TreeListSort as any).cons).not.toBe('undefined'); + + // Test accessors exist + expect(typeof (TreeSort as any).value).not.toBe('undefined'); + expect(typeof (TreeSort as any).children).not.toBe('undefined'); + expect(typeof (TreeListSort as any).car).not.toBe('undefined'); + expect(typeof (TreeListSort as any).cdr).not.toBe('undefined'); + }); + }); }); diff --git a/src/api/js/src/high-level/high-level.ts b/src/api/js/src/high-level/high-level.ts index 7d19df982..f53f2d8ca 100644 --- a/src/api/js/src/high-level/high-level.ts +++ b/src/api/js/src/high-level/high-level.ts @@ -17,6 +17,8 @@ import { Z3_ast_print_mode, Z3_ast_vector, Z3_context, + Z3_constructor, + Z3_constructor_list, Z3_decl_kind, Z3_error_code, Z3_func_decl, @@ -88,6 +90,10 @@ import { FuncEntry, SMTSetSort, SMTSet, + Datatype, + DatatypeSort, + DatatypeExpr, + DatatypeCreation, } from './types'; import { allSatisfy, assert, assertExhaustive } from './utils'; @@ -825,6 +831,17 @@ export function createApi(Z3: Z3Core): Z3HighLevel { } } + const Datatype = Object.assign( + (name: string): DatatypeImpl => { + return new DatatypeImpl(ctx, name); + }, + { + createDatatypes(...datatypes: DatatypeImpl[]): DatatypeSortImpl[] { + return createDatatypes(...datatypes); + } + } + ); + //////////////// // Operations // //////////////// @@ -2647,6 +2664,185 @@ export function createApi(Z3: Z3Core): Z3HighLevel { } } + //////////////////////////// + // Datatypes + //////////////////////////// + + class DatatypeImpl implements Datatype { + readonly ctx: Context; + readonly name: string; + public constructors: Array<[string, Array<[string, Sort | Datatype]>]> = []; + + constructor(ctx: Context, name: string) { + this.ctx = ctx; + this.name = name; + } + + declare(name: string, ...fields: Array<[string, Sort | Datatype]>): this { + this.constructors.push([name, fields]); + return this; + } + + create(): DatatypeSort { + const datatypes = createDatatypes(this); + return datatypes[0]; + } + } + + class DatatypeSortImpl extends SortImpl implements DatatypeSort { + declare readonly __typename: DatatypeSort['__typename']; + + numConstructors(): number { + return Z3.get_datatype_sort_num_constructors(contextPtr, this.ptr); + } + + constructorDecl(idx: number): FuncDecl { + const ptr = Z3.get_datatype_sort_constructor(contextPtr, this.ptr, idx); + return new FuncDeclImpl(ptr); + } + + recognizer(idx: number): FuncDecl { + const ptr = Z3.get_datatype_sort_recognizer(contextPtr, this.ptr, idx); + return new FuncDeclImpl(ptr); + } + + accessor(constructorIdx: number, accessorIdx: number): FuncDecl { + const ptr = Z3.get_datatype_sort_constructor_accessor(contextPtr, this.ptr, constructorIdx, accessorIdx); + return new FuncDeclImpl(ptr); + } + + cast(other: CoercibleToExpr): DatatypeExpr; + cast(other: DatatypeExpr): DatatypeExpr; + cast(other: CoercibleToExpr | DatatypeExpr): DatatypeExpr { + if (isExpr(other)) { + assert(this.eqIdentity(other.sort), 'Value cannot be converted to this datatype'); + return other as DatatypeExpr; + } + throw new Error('Cannot coerce value to datatype expression'); + } + + subsort(other: Sort) { + _assertContext(other.ctx); + return this.eqIdentity(other); + } + } + + class DatatypeExprImpl extends ExprImpl implements DatatypeExpr { + declare readonly __typename: DatatypeExpr['__typename']; + } + + function createDatatypes(...datatypes: DatatypeImpl[]): DatatypeSortImpl[] { + if (datatypes.length === 0) { + throw new Error('At least one datatype must be provided'); + } + + // All datatypes must be from the same context + const dtCtx = datatypes[0].ctx; + for (const dt of datatypes) { + if (dt.ctx !== dtCtx) { + throw new Error('All datatypes must be from the same context'); + } + } + + const sortNames = datatypes.map(dt => dt.name); + const constructorLists: Z3_constructor_list[] = []; + const scopedConstructors: Z3_constructor[] = []; + + try { + // Create constructor lists for each datatype + for (const dt of datatypes) { + const constructors: Z3_constructor[] = []; + + for (const [constructorName, fields] of dt.constructors) { + const fieldNames: string[] = []; + const fieldSorts: Z3_sort[] = []; + const fieldRefs: number[] = []; + + for (const [fieldName, fieldSort] of fields) { + fieldNames.push(fieldName); + + if (fieldSort instanceof DatatypeImpl) { + // Reference to another datatype being defined + const refIndex = datatypes.indexOf(fieldSort); + if (refIndex === -1) { + throw new Error(`Referenced datatype "${fieldSort.name}" not found in datatypes being created`); + } + // For recursive references, we pass null and the ref index + fieldSorts.push(null as any); // null will be handled by the Z3 API + fieldRefs.push(refIndex); + } else { + // Regular sort + fieldSorts.push((fieldSort as Sort).ptr); + fieldRefs.push(0); + } + } + + const constructor = Z3.mk_constructor( + contextPtr, + Z3.mk_string_symbol(contextPtr, constructorName), + Z3.mk_string_symbol(contextPtr, `is_${constructorName}`), + fieldNames.map(name => Z3.mk_string_symbol(contextPtr, name)), + fieldSorts, + fieldRefs + ); + constructors.push(constructor); + scopedConstructors.push(constructor); + } + + const constructorList = Z3.mk_constructor_list(contextPtr, constructors); + constructorLists.push(constructorList); + } + + // Create the datatypes + const sortSymbols = sortNames.map(name => Z3.mk_string_symbol(contextPtr, name)); + const resultSorts = Z3.mk_datatypes(contextPtr, sortSymbols, constructorLists); + + // Create DatatypeSortImpl instances + const results: DatatypeSortImpl[] = []; + for (let i = 0; i < resultSorts.length; i++) { + const sortImpl = new DatatypeSortImpl(resultSorts[i]); + + // Attach constructor, recognizer, and accessor functions dynamically + const numConstructors = sortImpl.numConstructors(); + for (let j = 0; j < numConstructors; j++) { + const constructor = sortImpl.constructorDecl(j); + const recognizer = sortImpl.recognizer(j); + const constructorName = constructor.name().toString(); + + // Attach constructor function + if (constructor.arity() === 0) { + // Nullary constructor (constant) + (sortImpl as any)[constructorName] = constructor.call(); + } else { + (sortImpl as any)[constructorName] = constructor; + } + + // Attach recognizer function + (sortImpl as any)[`is_${constructorName}`] = recognizer; + + // Attach accessor functions + for (let k = 0; k < constructor.arity(); k++) { + const accessor = sortImpl.accessor(j, k); + const accessorName = accessor.name().toString(); + (sortImpl as any)[accessorName] = accessor; + } + } + + results.push(sortImpl); + } + + return results; + } finally { + // Clean up resources + for (const constructor of scopedConstructors) { + Z3.del_constructor(contextPtr, constructor); + } + for (const constructorList of constructorLists) { + Z3.del_constructor_list(contextPtr, constructorList); + } + } + } + class QuantifierImpl< QVarSorts extends NonEmptySortArray, QSort extends BoolSort | SMTArraySort, @@ -3029,6 +3225,7 @@ export function createApi(Z3: Z3Core): Z3HighLevel { BitVec, Array, Set, + Datatype, //////////////// // Operations // diff --git a/src/api/js/src/high-level/types.ts b/src/api/js/src/high-level/types.ts index 37d9c8f21..3c1ebaa10 100644 --- a/src/api/js/src/high-level/types.ts +++ b/src/api/js/src/high-level/types.ts @@ -3,6 +3,8 @@ import { Z3_ast_map, Z3_ast_vector, Z3_context, + Z3_constructor, + Z3_constructor_list, Z3_decl_kind, Z3_func_decl, Z3_func_entry, @@ -362,6 +364,8 @@ export interface Context { readonly Array: SMTArrayCreation; /** @category Expressions */ readonly Set: SMTSetCreation; + /** @category Expressions */ + readonly Datatype: DatatypeCreation; //////////////// // Operations // @@ -842,7 +846,8 @@ export interface Sort extends Ast { | BoolSort['__typename'] | ArithSort['__typename'] | BitVecSort['__typename'] - | SMTArraySort['__typename']; + | SMTArraySort['__typename'] + | DatatypeSort['__typename']; kind(): Z3_sort_kind; @@ -966,7 +971,8 @@ export interface Expr = AnySo | Bool['__typename'] | Arith['__typename'] | BitVec['__typename'] - | SMTArray['__typename']; + | SMTArray['__typename'] + | DatatypeExpr['__typename']; get sort(): S; @@ -1653,6 +1659,111 @@ export interface SMTSet): Bool; } +////////////////////////////////////////// +// +// Datatypes +// +////////////////////////////////////////// + +/** + * Helper class for declaring Z3 datatypes. + * + * Follows the same pattern as Python Z3 API for declaring constructors + * before creating the actual datatype sort. + * + * @example + * ```typescript + * const List = new ctx.Datatype('List'); + * List.declare('cons', ['car', ctx.Int.sort()], ['cdr', List]); + * List.declare('nil'); + * const ListSort = List.create(); + * ``` + * + * @category Datatypes + */ +export interface Datatype { + readonly ctx: Context; + readonly name: string; + + /** + * Declare a constructor for this datatype. + * + * @param name Constructor name + * @param fields Array of [field_name, field_sort] pairs + */ + declare(name: string, ...fields: Array<[string, AnySort | Datatype]>): this; + + /** + * Create the actual datatype sort from the declared constructors. + * For mutually recursive datatypes, use Context.createDatatypes instead. + */ + create(): DatatypeSort; +} + +/** + * @category Datatypes + */ +export interface DatatypeCreation { + /** + * Create a new datatype declaration helper. + */ + (name: string): Datatype; + + /** + * Create mutually recursive datatypes. + * + * @param datatypes Array of Datatype declarations + * @returns Array of created DatatypeSort instances + */ + createDatatypes(...datatypes: Datatype[]): DatatypeSort[]; +} + +/** + * A Sort representing an algebraic datatype. + * + * After creation, this sort will have constructor, recognizer, and accessor + * functions dynamically attached based on the declared constructors. + * + * @category Datatypes + */ +export interface DatatypeSort extends Sort { + /** @hidden */ + readonly __typename: 'DatatypeSort'; + + /** + * Number of constructors in this datatype + */ + numConstructors(): number; + + /** + * Get the idx'th constructor function declaration + */ + constructorDecl(idx: number): FuncDecl; + + /** + * Get the idx'th recognizer function declaration + */ + recognizer(idx: number): FuncDecl; + + /** + * Get the accessor function declaration for the idx_a'th field of the idx_c'th constructor + */ + accessor(constructorIdx: number, accessorIdx: number): FuncDecl; + + cast(other: CoercibleToExpr): DatatypeExpr; + + cast(other: DatatypeExpr): DatatypeExpr; +} + +/** + * Represents expressions of datatype sorts. + * + * @category Datatypes + */ +export interface DatatypeExpr extends Expr, Z3_ast> { + /** @hidden */ + readonly __typename: 'DatatypeExpr'; +} /** * Defines the expression type of the body of a quantifier expression diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index 5861d7511..2c67532dc 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -41,50 +41,62 @@ namespace smt { namespace smt { void parallel::worker::run() { - ast_translation tr(ctx->m, m); + ast_translation g2l(ctx->m, m); + ast_translation l2g(m, ctx->m); while (m.inc()) { vector cubes; - b.get_cubes(tr, cubes); + b.get_cubes(g2l, cubes); if (cubes.empty()) return; for (auto& cube : cubes) { if (!m.inc()) return; // stop if the main context is cancelled switch (check_cube(cube)) { - case l_undef: + case l_undef: { // return unprocessed cubes to the batch manager // add a split literal to the batch manager. // optionally process other cubes and delay sending back unprocessed cubes to batch manager. - b.m_cubes.push_back(cube); // TODO: add access funcs for m_cubes + vector returned_cubes; + returned_cubes.push_back(cube); + auto split_atoms = get_split_atoms(); + b.return_cubes(l2g, returned_cubes, split_atoms); break; + } case l_true: { model_ref mdl; ctx->get_model(mdl); - if (mdl) - ctx->set_model(mdl->translate(tr)); - //b.set_sat(tr, *mdl); + b.set_sat(l2g, *mdl); return; } - case l_false: + case l_false: { // if unsat core only contains (external) assumptions (i.e. all the unsat core are asms), then unsat and return as this does NOT depend on cubes // otherwise, extract lemmas that can be shared (units (and unsat core?)). // share with batch manager. // process next cube. - ctx->m_unsat_core.reset(); - for (expr* e : pctx.unsat_core()) // TODO: move this logic to the batch manager since this is per-thread - ctx->m_unsat_core.push_back(tr(e)); + auto const& unsat_core = ctx->unsat_core(); + // If the unsat core only contains assumptions, + // unsatisfiability does not depend on the current cube and the entire problem is unsat. + if (any_of(unsat_core, [&](expr* e) { return asms.contains(e); })) { + b.set_unsat(l2g, ctx->unsat_core()); + return; + } + // TODO: can share lemmas here, such as new units and not(and(unsat_core)), binary clauses, etc. + // TODO: remember assumptions used in core so that they get used for the final core. break; } + } } } } - parallel::worker::worker(parallel& p, context& _ctx, expr_ref_vector const& _asms): p(p), b(p.m_batch_manager), m_smt_params(_ctx.get_fparams()), asms(m) { - ast_translation g2l(_ctx.m, m); + parallel::worker::worker(unsigned id, parallel& p, expr_ref_vector const& _asms): id(id), p(p), b(p.m_batch_manager), m_smt_params(p.ctx.get_fparams()), asms(m) { + ast_translation g2l(p.ctx.m, m); for (auto e : _asms) asms.push_back(g2l(e)); m_smt_params.m_preprocess = false; - ctx = alloc(context, m, m_smt_params, _ctx.get_params()); + ctx = alloc(context, m, m_smt_params, p.ctx.get_params()); + context::copy(p.ctx, *ctx, true); + ctx->set_random_seed(id + m_smt_params.m_random_seed); } @@ -92,10 +104,21 @@ namespace smt { // THERE IS AN EDGE CASE: IF ALL THE CUBES ARE UNSAT, BUT DEPEND ON NONEMPTY ASSUMPTIONS, NEED TO TAKE THE UNION OF THESE ASMS WHEN LEARNING FROM UNSAT CORE // DON'T CODE THIS CASE YET: WE ARE JUST TESTING WITH EMPTY ASMS FOR NOW (I.E. WE ARE NOT PASSING IN ASMS). THIS DOES NOT APPLY TO THE INTERNAL "LEARNED" UNSAT CORE lbool parallel::worker::check_cube(expr_ref_vector const& cube) { - for (auto& atom : cube) { - asms.push_back(atom); + for (auto& atom : cube) + asms.push_back(atom); + lbool r = l_undef; + try { + r = ctx->check(asms.size(), asms.data()); + } + catch (z3_error& err) { + b.set_exception(err.error_code()); + } + catch (z3_exception& ex) { + b.set_exception(ex.what()); + } + catch (...) { + b.set_exception("unknown exception"); } - lbool r = ctx->check(asms.size(), asms.data()); asms.shrink(asms.size() - cube.size()); return r; } @@ -121,19 +144,64 @@ namespace smt { void parallel::batch_manager::set_sat(ast_translation& l2g, model& m) { std::scoped_lock lock(mux); - if (m_result == l_true || m_result == l_undef) { - m_result = l_true; + if (l_true == m_result) return; - } m_result = l_true; - for (auto& c : m_cubes) { - expr_ref_vector g_cube(l2g.to()); - for (auto& e : c) { - g_cube.push_back(l2g(e)); - } - share_lemma(l2g, mk_and(g_cube)); - } + p.ctx.set_model(m.translate(l2g)); + cancel_workers(); } + + void parallel::batch_manager::set_unsat(ast_translation& l2g, expr_ref_vector const& unsat_core) { + std::scoped_lock lock(mux); + if (l_false == m_result) + return; + m_result = l_false; + expr_ref_vector g_core(l2g.to()); + for (auto& e : unsat_core) + g_core.push_back(l2g(e)); + p.ctx.m_unsat_core.reset(); + for (expr* e : unsat_core) + p.ctx.m_unsat_core.push_back(l2g(e)); + cancel_workers(); + } + + void parallel::batch_manager::set_exception(unsigned error_code) { + std::scoped_lock lock(mux); + if (m_exception_kind != NO_EX) + return; // already set + m_exception_kind = ERROR_CODE_EX; + m_exception_code = error_code; + cancel_workers(); + } + + void parallel::batch_manager::set_exception(std::string const& msg) { + std::scoped_lock lock(mux); + if (m_exception_kind != NO_EX) + return; // already set + m_exception_kind = ERROR_MSG_EX; + m_exception_msg = msg; + cancel_workers(); + } + + lbool parallel::batch_manager::get_result() const { + if (m_exception_kind == ERROR_MSG_EX) + throw default_exception(m_exception_msg.c_str()); + if (m_exception_kind == ERROR_CODE_EX) + throw z3_error(m_exception_code); + if (m.limit().is_canceled()) + return l_undef; // the main context was cancelled, so we return undef. + return m_result; + } + +#if 0 + for (auto& c : m_cubes) { + expr_ref_vector g_cube(l2g.to()); + for (auto& e : c) { + g_cube.push_back(l2g(e)); + } + share_lemma(l2g, mk_and(g_cube)); + } +#endif // CALL GET_SPLIT_ATOMS AS ARGUMENT TO RETURN_CUBES void parallel::batch_manager::return_cubes(ast_translation& l2g, vectorconst& cubes, expr_ref_vector const& split_atoms) { @@ -172,25 +240,32 @@ namespace smt { expr_ref_vector top_lits(m); for (const auto& node : candidates) { - if (ctx->get_assignment(node.key) != l_undef) continue; + if (ctx->get_assignment(node.key) != l_undef) + continue; expr* e = ctx->bool_var2expr(node.key); - if (!e) continue; + if (!e) + continue; top_lits.push_back(expr_ref(e, m)); - if (top_lits.size() >= k) break; + if (top_lits.size() >= k) + break; } return top_lits; } lbool parallel::new_check(expr_ref_vector const& asms) { ast_manager& m = ctx.m; + + if (m.has_trace_stream()) + throw default_exception("trace streams have to be off in parallel mode"); + { scoped_limits sl(m.limit()); unsigned num_threads = std::min((unsigned)std::thread::hardware_concurrency(), ctx.get_fparams().m_threads); SASSERT(num_threads > 1); for (unsigned i = 0; i < num_threads; ++i) - m_workers.push_back(alloc(worker, *this, ctx, asms)); + m_workers.push_back(alloc(worker, i, *this, asms)); // THIS WILL ALLOW YOU TO CANCEL ALL THE CHILD THREADS // within the lexical scope of the code block, creates a data structure that allows you to push children @@ -210,6 +285,9 @@ namespace smt { // Wait for all threads to finish for (auto& th : threads) th.join(); + + for (auto w : m_workers) + w->collect_statistics(ctx.m_aux_stats); } m_workers.clear(); return m_batch_manager.get_result(); diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index 9d0a3de3f..e99c95367 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -27,6 +27,12 @@ namespace smt { unsigned num_threads; class batch_manager { + + enum exception_kind { + NO_EX, + ERROR_CODE_EX, + ERROR_MSG_EX + }; ast_manager& m; parallel& p; std::mutex mux; @@ -34,10 +40,19 @@ namespace smt { vector m_cubes; lbool m_result = l_false; // want states: init/undef, canceled/exception, sat, unsat unsigned m_max_batch_size = 10; + exception_kind m_exception_kind = NO_EX; + unsigned m_exception_code = 0; + std::string m_exception_msg; + + // called from batch manager to cancel other workers if we've reached a verdict + void cancel_workers() { + for (auto& w : p.m_workers) + w->cancel(); + } public: batch_manager(ast_manager& m, parallel& p) : m(m), p(p), m_split_atoms(m) { m_cubes.push_back(expr_ref_vector(m)); } - void set_unsat(); + void set_unsat(ast_translation& l2g, expr_ref_vector const& unsat_core); void set_sat(ast_translation& l2g, model& m); void set_exception(std::string const& msg); void set_exception(unsigned error_code); @@ -55,11 +70,11 @@ namespace smt { // void return_cubes(ast_translation& l2g, vectorconst& cubes, expr_ref_vector const& split_atoms); void share_lemma(ast_translation& l2g, expr* lemma); - void cancel_workers(); // called from batch manager to cancel other workers if we've reached a verdict - lbool get_result() const { return m.limit().is_canceled() ? l_undef : m_result; } + lbool get_result() const; }; class worker { + unsigned id; // unique identifier for the worker parallel& p; batch_manager& b; ast_manager m; @@ -71,7 +86,7 @@ namespace smt { void share_units(); lbool check_cube(expr_ref_vector const& cube); public: - worker(parallel& p, context& _ctx, expr_ref_vector const& _asms); + worker(unsigned id, parallel& p, expr_ref_vector const& _asms); void run(); expr_ref_vector get_split_atoms(); void cancel() {