diff --git a/src/api/js/src/high-level/high-level.test.ts b/src/api/js/src/high-level/high-level.test.ts index fb8805d99..e9a2b46e7 100644 --- a/src/api/js/src/high-level/high-level.test.ts +++ b/src/api/js/src/high-level/high-level.test.ts @@ -890,4 +890,74 @@ describe('high-level', () => { expect(model.eval(z).eqIdentity(Int.val(5))).toBeTruthy(); }); }); + + describe('datatypes', () => { + it('should create simple enum datatype', async () => { + const { Datatype, Int, Bool, Solver } = api.Context('main'); + + // Create a simple Color enum datatype + const Color = Datatype('Color'); + Color.declare('red'); + Color.declare('green'); + Color.declare('blue'); + + const ColorSort = Color.create(); + + // Test that we can access the constructors + expect(typeof (ColorSort as any).red).not.toBe('undefined'); + expect(typeof (ColorSort as any).green).not.toBe('undefined'); + expect(typeof (ColorSort as any).blue).not.toBe('undefined'); + + // Test that we can access the recognizers + expect(typeof (ColorSort as any).is_red).not.toBe('undefined'); + expect(typeof (ColorSort as any).is_green).not.toBe('undefined'); + expect(typeof (ColorSort as any).is_blue).not.toBe('undefined'); + }); + + it('should create recursive list datatype', async () => { + const { Datatype, Int, Solver } = api.Context('main'); + + // Create a recursive List datatype like in the Python example + const List = Datatype('List'); + List.declare('cons', ['car', Int.sort()], ['cdr', List]); + List.declare('nil'); + + const ListSort = List.create(); + + // Test that constructors and accessors exist + expect(typeof (ListSort as any).cons).not.toBe('undefined'); + expect(typeof (ListSort as any).nil).not.toBe('undefined'); + expect(typeof (ListSort as any).is_cons).not.toBe('undefined'); + expect(typeof (ListSort as any).is_nil).not.toBe('undefined'); + expect(typeof (ListSort as any).car).not.toBe('undefined'); + expect(typeof (ListSort as any).cdr).not.toBe('undefined'); + }); + + it('should create mutually recursive tree datatypes', async () => { + const { Datatype, Int } = api.Context('main'); + + // Create mutually recursive Tree and TreeList datatypes + const Tree = Datatype('Tree'); + const TreeList = Datatype('TreeList'); + + Tree.declare('leaf', ['value', Int.sort()]); + Tree.declare('node', ['children', TreeList]); + TreeList.declare('nil'); + TreeList.declare('cons', ['car', Tree], ['cdr', TreeList]); + + const [TreeSort, TreeListSort] = Datatype.createDatatypes(Tree, TreeList); + + // Test that both datatypes have their constructors + expect(typeof (TreeSort as any).leaf).not.toBe('undefined'); + expect(typeof (TreeSort as any).node).not.toBe('undefined'); + expect(typeof (TreeListSort as any).nil).not.toBe('undefined'); + expect(typeof (TreeListSort as any).cons).not.toBe('undefined'); + + // Test accessors exist + expect(typeof (TreeSort as any).value).not.toBe('undefined'); + expect(typeof (TreeSort as any).children).not.toBe('undefined'); + expect(typeof (TreeListSort as any).car).not.toBe('undefined'); + expect(typeof (TreeListSort as any).cdr).not.toBe('undefined'); + }); + }); }); diff --git a/src/api/js/src/high-level/high-level.ts b/src/api/js/src/high-level/high-level.ts index 7d19df982..f53f2d8ca 100644 --- a/src/api/js/src/high-level/high-level.ts +++ b/src/api/js/src/high-level/high-level.ts @@ -17,6 +17,8 @@ import { Z3_ast_print_mode, Z3_ast_vector, Z3_context, + Z3_constructor, + Z3_constructor_list, Z3_decl_kind, Z3_error_code, Z3_func_decl, @@ -88,6 +90,10 @@ import { FuncEntry, SMTSetSort, SMTSet, + 
Datatype, + DatatypeSort, + DatatypeExpr, + DatatypeCreation, } from './types'; import { allSatisfy, assert, assertExhaustive } from './utils'; @@ -825,6 +831,17 @@ export function createApi(Z3: Z3Core): Z3HighLevel { } } + const Datatype = Object.assign( + (name: string): DatatypeImpl => { + return new DatatypeImpl(ctx, name); + }, + { + createDatatypes(...datatypes: DatatypeImpl[]): DatatypeSortImpl[] { + return createDatatypes(...datatypes); + } + } + ); + //////////////// // Operations // //////////////// @@ -2647,6 +2664,185 @@ export function createApi(Z3: Z3Core): Z3HighLevel { } } + //////////////////////////// + // Datatypes + //////////////////////////// + + class DatatypeImpl implements Datatype { + readonly ctx: Context; + readonly name: string; + public constructors: Array<[string, Array<[string, Sort | Datatype]>]> = []; + + constructor(ctx: Context, name: string) { + this.ctx = ctx; + this.name = name; + } + + declare(name: string, ...fields: Array<[string, Sort | Datatype]>): this { + this.constructors.push([name, fields]); + return this; + } + + create(): DatatypeSort { + const datatypes = createDatatypes(this); + return datatypes[0]; + } + } + + class DatatypeSortImpl extends SortImpl implements DatatypeSort { + declare readonly __typename: DatatypeSort['__typename']; + + numConstructors(): number { + return Z3.get_datatype_sort_num_constructors(contextPtr, this.ptr); + } + + constructorDecl(idx: number): FuncDecl { + const ptr = Z3.get_datatype_sort_constructor(contextPtr, this.ptr, idx); + return new FuncDeclImpl(ptr); + } + + recognizer(idx: number): FuncDecl { + const ptr = Z3.get_datatype_sort_recognizer(contextPtr, this.ptr, idx); + return new FuncDeclImpl(ptr); + } + + accessor(constructorIdx: number, accessorIdx: number): FuncDecl { + const ptr = Z3.get_datatype_sort_constructor_accessor(contextPtr, this.ptr, constructorIdx, accessorIdx); + return new FuncDeclImpl(ptr); + } + + cast(other: CoercibleToExpr): DatatypeExpr; + cast(other: DatatypeExpr): DatatypeExpr; + cast(other: CoercibleToExpr | DatatypeExpr): DatatypeExpr { + if (isExpr(other)) { + assert(this.eqIdentity(other.sort), 'Value cannot be converted to this datatype'); + return other as DatatypeExpr; + } + throw new Error('Cannot coerce value to datatype expression'); + } + + subsort(other: Sort) { + _assertContext(other.ctx); + return this.eqIdentity(other); + } + } + + class DatatypeExprImpl extends ExprImpl implements DatatypeExpr { + declare readonly __typename: DatatypeExpr['__typename']; + } + + function createDatatypes(...datatypes: DatatypeImpl[]): DatatypeSortImpl[] { + if (datatypes.length === 0) { + throw new Error('At least one datatype must be provided'); + } + + // All datatypes must be from the same context + const dtCtx = datatypes[0].ctx; + for (const dt of datatypes) { + if (dt.ctx !== dtCtx) { + throw new Error('All datatypes must be from the same context'); + } + } + + const sortNames = datatypes.map(dt => dt.name); + const constructorLists: Z3_constructor_list[] = []; + const scopedConstructors: Z3_constructor[] = []; + + try { + // Create constructor lists for each datatype + for (const dt of datatypes) { + const constructors: Z3_constructor[] = []; + + for (const [constructorName, fields] of dt.constructors) { + const fieldNames: string[] = []; + const fieldSorts: Z3_sort[] = []; + const fieldRefs: number[] = []; + + for (const [fieldName, fieldSort] of fields) { + fieldNames.push(fieldName); + + if (fieldSort instanceof DatatypeImpl) { + // Reference to another datatype being 
defined + const refIndex = datatypes.indexOf(fieldSort); + if (refIndex === -1) { + throw new Error(`Referenced datatype "${fieldSort.name}" not found in datatypes being created`); + } + // For recursive references, we pass null and the ref index + fieldSorts.push(null as any); // null will be handled by the Z3 API + fieldRefs.push(refIndex); + } else { + // Regular sort + fieldSorts.push((fieldSort as Sort).ptr); + fieldRefs.push(0); + } + } + + const constructor = Z3.mk_constructor( + contextPtr, + Z3.mk_string_symbol(contextPtr, constructorName), + Z3.mk_string_symbol(contextPtr, `is_${constructorName}`), + fieldNames.map(name => Z3.mk_string_symbol(contextPtr, name)), + fieldSorts, + fieldRefs + ); + constructors.push(constructor); + scopedConstructors.push(constructor); + } + + const constructorList = Z3.mk_constructor_list(contextPtr, constructors); + constructorLists.push(constructorList); + } + + // Create the datatypes + const sortSymbols = sortNames.map(name => Z3.mk_string_symbol(contextPtr, name)); + const resultSorts = Z3.mk_datatypes(contextPtr, sortSymbols, constructorLists); + + // Create DatatypeSortImpl instances + const results: DatatypeSortImpl[] = []; + for (let i = 0; i < resultSorts.length; i++) { + const sortImpl = new DatatypeSortImpl(resultSorts[i]); + + // Attach constructor, recognizer, and accessor functions dynamically + const numConstructors = sortImpl.numConstructors(); + for (let j = 0; j < numConstructors; j++) { + const constructor = sortImpl.constructorDecl(j); + const recognizer = sortImpl.recognizer(j); + const constructorName = constructor.name().toString(); + + // Attach constructor function + if (constructor.arity() === 0) { + // Nullary constructor (constant) + (sortImpl as any)[constructorName] = constructor.call(); + } else { + (sortImpl as any)[constructorName] = constructor; + } + + // Attach recognizer function + (sortImpl as any)[`is_${constructorName}`] = recognizer; + + // Attach accessor functions + for (let k = 0; k < constructor.arity(); k++) { + const accessor = sortImpl.accessor(j, k); + const accessorName = accessor.name().toString(); + (sortImpl as any)[accessorName] = accessor; + } + } + + results.push(sortImpl); + } + + return results; + } finally { + // Clean up resources + for (const constructor of scopedConstructors) { + Z3.del_constructor(contextPtr, constructor); + } + for (const constructorList of constructorLists) { + Z3.del_constructor_list(contextPtr, constructorList); + } + } + } + class QuantifierImpl< QVarSorts extends NonEmptySortArray, QSort extends BoolSort | SMTArraySort, @@ -3029,6 +3225,7 @@ export function createApi(Z3: Z3Core): Z3HighLevel { BitVec, Array, Set, + Datatype, //////////////// // Operations // diff --git a/src/api/js/src/high-level/types.ts b/src/api/js/src/high-level/types.ts index 37d9c8f21..3c1ebaa10 100644 --- a/src/api/js/src/high-level/types.ts +++ b/src/api/js/src/high-level/types.ts @@ -3,6 +3,8 @@ import { Z3_ast_map, Z3_ast_vector, Z3_context, + Z3_constructor, + Z3_constructor_list, Z3_decl_kind, Z3_func_decl, Z3_func_entry, @@ -362,6 +364,8 @@ export interface Context { readonly Array: SMTArrayCreation; /** @category Expressions */ readonly Set: SMTSetCreation; + /** @category Expressions */ + readonly Datatype: DatatypeCreation; //////////////// // Operations // @@ -842,7 +846,8 @@ export interface Sort extends Ast { | BoolSort['__typename'] | ArithSort['__typename'] | BitVecSort['__typename'] - | SMTArraySort['__typename']; + | SMTArraySort['__typename'] + | DatatypeSort['__typename']; 
kind(): Z3_sort_kind; @@ -966,7 +971,8 @@ export interface Expr = AnySo | Bool['__typename'] | Arith['__typename'] | BitVec['__typename'] - | SMTArray['__typename']; + | SMTArray['__typename'] + | DatatypeExpr['__typename']; get sort(): S; @@ -1653,6 +1659,111 @@ export interface SMTSet): Bool; } +////////////////////////////////////////// +// +// Datatypes +// +////////////////////////////////////////// + +/** + * Helper class for declaring Z3 datatypes. + * + * Follows the same pattern as Python Z3 API for declaring constructors + * before creating the actual datatype sort. + * + * @example + * ```typescript + * const List = new ctx.Datatype('List'); + * List.declare('cons', ['car', ctx.Int.sort()], ['cdr', List]); + * List.declare('nil'); + * const ListSort = List.create(); + * ``` + * + * @category Datatypes + */ +export interface Datatype { + readonly ctx: Context; + readonly name: string; + + /** + * Declare a constructor for this datatype. + * + * @param name Constructor name + * @param fields Array of [field_name, field_sort] pairs + */ + declare(name: string, ...fields: Array<[string, AnySort | Datatype]>): this; + + /** + * Create the actual datatype sort from the declared constructors. + * For mutually recursive datatypes, use Context.createDatatypes instead. + */ + create(): DatatypeSort; +} + +/** + * @category Datatypes + */ +export interface DatatypeCreation { + /** + * Create a new datatype declaration helper. + */ + (name: string): Datatype; + + /** + * Create mutually recursive datatypes. + * + * @param datatypes Array of Datatype declarations + * @returns Array of created DatatypeSort instances + */ + createDatatypes(...datatypes: Datatype[]): DatatypeSort[]; +} + +/** + * A Sort representing an algebraic datatype. + * + * After creation, this sort will have constructor, recognizer, and accessor + * functions dynamically attached based on the declared constructors. + * + * @category Datatypes + */ +export interface DatatypeSort extends Sort { + /** @hidden */ + readonly __typename: 'DatatypeSort'; + + /** + * Number of constructors in this datatype + */ + numConstructors(): number; + + /** + * Get the idx'th constructor function declaration + */ + constructorDecl(idx: number): FuncDecl; + + /** + * Get the idx'th recognizer function declaration + */ + recognizer(idx: number): FuncDecl; + + /** + * Get the accessor function declaration for the idx_a'th field of the idx_c'th constructor + */ + accessor(constructorIdx: number, accessorIdx: number): FuncDecl; + + cast(other: CoercibleToExpr): DatatypeExpr; + + cast(other: DatatypeExpr): DatatypeExpr; +} + +/** + * Represents expressions of datatype sorts. 
+ * + * @category Datatypes + */ +export interface DatatypeExpr extends Expr, Z3_ast> { + /** @hidden */ + readonly __typename: 'DatatypeExpr'; +} /** * Defines the expression type of the body of a quantifier expression diff --git a/src/ast/euf/euf_ac_plugin.cpp b/src/ast/euf/euf_ac_plugin.cpp index 278a19f0c..ce0134439 100644 --- a/src/ast/euf/euf_ac_plugin.cpp +++ b/src/ast/euf/euf_ac_plugin.cpp @@ -278,17 +278,9 @@ namespace euf { if (!m_shared.empty()) out << "shared monomials:\n"; for (auto const& s : m_shared) { - out << g.bpp(s.n) << ": " << s.m << " r: " << g.bpp(s.n->get_root()) << "\n"; + out << g.bpp(s.n) << " r " << g.bpp(s.n->get_root()) << " - " << s.m << ": " << m_pp_ll(*this, monomial(s.m)) << "\n"; } -#if 0 - i = 0; - for (auto m : m_monomials) { - out << i << ": "; - display_monomial_ll(out, m); - out << "\n"; - ++i; - } -#endif + for (auto n : m_nodes) { if (!n) continue; @@ -361,19 +353,16 @@ namespace euf { if (!orient_equation(eq)) return false; -#if 1 if (is_reducing(eq)) is_active = true; -#else - - is_active = true; // set to active by default -#endif if (!is_active) { m_passive.push_back(eq); return true; } + eq.status = eq_status::is_to_simplify_eq; + m_active.push_back(eq); auto& ml = monomial(eq.l); auto& mr = monomial(eq.r); @@ -621,9 +610,9 @@ namespace euf { // simplify eq using processed TRACE(plugin, for (auto other_eq : forward_iterator(eq_id)) - tout << "forward iterator " << eq_id << " vs " << other_eq << " " << is_processed(other_eq) << "\n"); + tout << "forward iterator " << eq_pp_ll(*this, m_active[eq_id]) << " vs " << eq_pp_ll(*this, m_active[other_eq]) << "\n"); for (auto other_eq : forward_iterator(eq_id)) - if (is_processed(other_eq) && forward_simplify(eq_id, other_eq)) + if ((is_processed(other_eq) || is_reducing(other_eq)) && forward_simplify(eq_id, other_eq)) goto loop_start; auto& eq = m_active[eq_id]; @@ -914,6 +903,8 @@ namespace euf { set_status(dst_eq, eq_status::is_dead_eq); return true; } + SASSERT(!are_equal(m_active[src_eq], m_active[dst_eq])); + if (!is_equation_oriented(src)) return false; // check that src.l is a subset of dst.r @@ -1088,23 +1079,18 @@ namespace euf { // rewrite monomial to normal form. bool ac_plugin::reduce(ptr_vector& m, justification& j) { bool change = false; - unsigned sz = m.size(); do { init_loop: - if (m.size() == 1) - return change; bloom b; init_ref_counts(m, m_m_counts); for (auto n : m) { if (n->is_zero) { m[0] = n; m.shrink(1); + change = true; break; } for (auto eq : n->eqs) { - continue; - if (!is_reducing(eq)) // also can use processed? 
- continue; auto& src = m_active[eq]; if (!is_equation_oriented(src)) @@ -1116,17 +1102,16 @@ namespace euf { TRACE(plugin, display_equation_ll(tout << "reduce ", src) << "\n"); SASSERT(is_correct_ref_count(monomial(src.l), m_eq_counts)); - //display_equation_ll(std::cout << "reduce ", src) << ": "; - //display_monomial_ll(std::cout, m); + for (auto n : m) + for (auto s : n->shared) + m_shared_todo.insert(s); rewrite1(m_eq_counts, monomial(src.r), m_m_counts, m); - //display_monomial_ll(std::cout << " -> ", m) << "\n"; j = join(j, eq); change = true; goto init_loop; } } } while (false); - VERIFY(sz >= m.size()); return change; } @@ -1287,6 +1272,8 @@ namespace euf { continue; } change = true; + for (auto s : n->shared) + m_shared_todo.insert(s); if (r.size() == 0) // if r is empty, we can remove n from l continue; @@ -1407,9 +1394,11 @@ namespace euf { TRACE(plugin_verbose, tout << "num shared todo " << m_shared_todo.size() << "\n"); if (m_shared_todo.empty()) return; + while (!m_shared_todo.empty()) { auto idx = *m_shared_todo.begin(); - m_shared_todo.remove(idx); + m_shared_todo.remove(idx); + TRACE(plugin, tout << "index " << idx << " shared size " << m_shared.size() << "\n"); if (idx < m_shared.size()) simplify_shared(idx, m_shared[idx]); } @@ -1431,7 +1420,7 @@ namespace euf { auto old_m = s.m; auto old_n = monomial(old_m).m_src; ptr_vector m1(monomial(old_m).m_nodes); - TRACE(plugin_verbose, tout << "simplify shared: " << g.bpp(old_n) << ": " << m_pp_ll(*this, monomial(old_m)) << "\n"); + TRACE(plugin, tout << "simplify shared: " << g.bpp(old_n) << ": " << m_pp_ll(*this, monomial(old_m)) << "\n"); if (!reduce(m1, j)) return; diff --git a/src/ast/simplifiers/euf_completion.cpp b/src/ast/simplifiers/euf_completion.cpp index ce4dd578d..f3c40aa6d 100644 --- a/src/ast/simplifiers/euf_completion.cpp +++ b/src/ast/simplifiers/euf_completion.cpp @@ -268,7 +268,6 @@ namespace euf { expr_ref r(f, m); m_rewriter(r); f = r.get(); - // verbose_stream() << r << "\n"; auto cons = m.mk_app(symbol("consequence"), 1, &f, m.mk_bool_sort()); m_fmls.add(dependent_expr(m, cons, nullptr, nullptr)); } @@ -317,35 +316,43 @@ namespace euf { expr_ref y1(y, m); m_rewriter(x1); m_rewriter(y1); - + add_quantifiers(x1); add_quantifiers(y1); enode* a = mk_enode(x1); enode* b = mk_enode(y1); + if (a->get_root() == b->get_root()) - return; - m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d))); - m_egraph.propagate(); + return; + + TRACE(euf, tout << "merge and propagate\n"); add_children(a); add_children(b); + m_egraph.merge(a, b, to_ptr(push_pr_dep(pr, d))); + m_egraph.propagate(); + m_should_propagate = true; + +#if 0 auto a1 = mk_enode(x); - if (a1->get_root() != a->get_root()) { + auto b1 = mk_enode(y); + + if (a->get_root() != a1->get_root()) { + add_children(a1);; m_egraph.merge(a, a1, nullptr); m_egraph.propagate(); - add_children(a1); - } - auto b1 = mk_enode(y); - if (b1->get_root() != b->get_root()) { - TRACE(euf, tout << "merge and propagate\n"); - m_egraph.merge(b, b1, nullptr); - m_egraph.propagate(); - add_children(b1); } - m_should_propagate = true; - if (m_side_condition_solver) + if (b->get_root() != b1->get_root()) { + add_children(b1); + m_egraph.merge(b, b1, nullptr); + m_egraph.propagate(); + } +#endif + + if (m_side_condition_solver && a->get_root() != b->get_root()) m_side_condition_solver->add_constraint(f, pr, d); - IF_VERBOSE(1, verbose_stream() << "eq: " << mk_pp(x1, m) << " == " << mk_pp(y1, m) << "\n"); + IF_VERBOSE(1, verbose_stream() << "eq: " << a->get_root_id() << " " << b->get_root_id() 
<< " " + << x1 << " == " << y1 << "\n"); } else if (m.is_not(f, f)) { enode* n = mk_enode(f); @@ -689,7 +696,7 @@ namespace euf { b = new (mem) binding(q, pat, max_generation, min_top, max_top); b->init(b); for (unsigned i = 0; i < n; ++i) - b->m_nodes[i] = _binding[i]; + b->m_nodes[i] = _binding[i]->get_root(); m_bindings.insert(b); get_trail().push(insert_map(m_bindings, b)); @@ -748,12 +755,13 @@ namespace euf { void completion::apply_binding(binding& b, quantifier* q, expr_ref_vector const& s) { var_subst subst(m); - expr_ref r = subst(q->get_expr(), s); + expr_ref r = subst(q->get_expr(), s); scoped_generation sg(*this, b.m_max_top_generation + 1); auto [pr, d] = get_dependency(q); if (pr) pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), r), s.size(), s.data()); m_consequences.push_back(r); + TRACE(euf_completion, tout << "new instantiation: " << r << " q: " << mk_pp(q, m) << "\n"); add_constraint(r, pr, d); propagate_rules(); m_egraph.propagate(); @@ -1022,7 +1030,7 @@ namespace euf { } enode* n = m_egraph.find(f); - + if (!n) n = mk_enode(f); enode* r = n->get_root(); d = m.mk_join(d, explain_eq(n, r)); d = m.mk_join(d, m_deps.get(r->get_id(), nullptr)); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 6eb070b7c..a6477761c 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -199,7 +199,7 @@ namespace smt { }; lit_node* m_dll_lits; - + // svector> m_lit_scores; svector m_lit_scores[2]; clause_vector m_aux_clauses; diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index 967b95773..44d32c6e8 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -40,7 +40,6 @@ namespace smt { namespace smt { - void parallel::worker::run() { ast_translation tr(ctx->m, m); while (m.inc()) { @@ -56,10 +55,13 @@ namespace smt { // return unprocessed cubes to the batch manager // add a split literal to the batch manager. // optionally process other cubes and delay sending back unprocessed cubes to batch manager. + b.m_cubes.push_back(cube); // TODO: add access funcs for m_cubes break; case l_true: { model_ref mdl; ctx->get_model(mdl); + if (mdl) + ctx->set_model(mdl->translate(tr)); //b.set_sat(tr, *mdl); return; } @@ -68,6 +70,9 @@ namespace smt { // otherwise, extract lemmas that can be shared (units (and unsat core?)). // share with batch manager. // process next cube. + ctx->m_unsat_core.reset(); + for (expr* e : pctx.unsat_core()) // TODO: move this logic to the batch manager since this is per-thread + ctx->m_unsat_core.push_back(tr(e)); break; } } @@ -75,7 +80,6 @@ namespace smt { } parallel::worker::worker(parallel& p, context& _ctx, expr_ref_vector const& _asms): p(p), b(p.m_batch_manager), m_smt_params(_ctx.get_fparams()), asms(m) { - ast_translation g2l(_ctx.m, m); for (auto e : _asms) asms.push_back(g2l(e)); @@ -85,8 +89,12 @@ namespace smt { lbool parallel::worker::check_cube(expr_ref_vector const& cube) { - - return l_undef; + for (auto& atom : cube) { + asms.push_back(atom); + } + lbool r = ctx->check(asms.size(), asms.data()); + asms.shrink(asms.size() - cube.size()); + return r; } void parallel::batch_manager::get_cubes(ast_translation& g2l, vector& cubes) { @@ -96,9 +104,8 @@ namespace smt { cubes.push_back(expr_ref_vector(g2l.to())); return; } - // TODO adjust to number of worker threads runnin. - // if the size of m_cubes is less than m_max_batch_size/ num_threads, then return fewer cubes. 
- for (unsigned i = 0; i < m_max_batch_size && !m_cubes.empty(); ++i) { + + for (unsigned i = 0; i < std::min(m_max_batch_size / p.num_threads, (unsigned)m_cubes.size()) && !m_cubes.empty(); ++i) { auto& cube = m_cubes.back(); expr_ref_vector l_cube(g2l.to()); for (auto& e : cube) { @@ -109,6 +116,21 @@ namespace smt { } } + void parallel::batch_manager::set_sat(ast_translation& l2g, model& m) { + std::scoped_lock lock(mux); + if (m_result == l_true || m_result == l_undef) { + m_result = l_true; + return; + } + m_result = l_true; + for (auto& c : m_cubes) { + expr_ref_vector g_cube(l2g.to()); + for (auto& e : c) { + g_cube.push_back(l2g(e)); + } + share_lemma(l2g, mk_and(g_cube)); + } + } void parallel::batch_manager::return_cubes(ast_translation& l2g, vectorconst& cubes, expr_ref_vector const& split_atoms) { std::scoped_lock lock(mux); @@ -120,6 +142,7 @@ namespace smt { // TODO: split this g_cube on m_split_atoms that are not already in g_cube as literals. m_cubes.push_back(g_cube); } + // TODO: avoid making m_cubes too large. for (auto& atom : split_atoms) { expr_ref g_atom(l2g.from()); @@ -136,9 +159,27 @@ namespace smt { } } + expr_ref_vector parallel::worker::get_split_atoms() { + unsigned k = 1; + + auto candidates = ctx->m_pq_scores.get_heap(); + std::sort(candidates.begin(), candidates.end(), + [](const auto& a, const auto& b) { return a.priority > b.priority; }); + + expr_ref_vector top_lits(m); + for (const auto& node : candidates) { + if (ctx->get_assignment(node.key) != l_undef) continue; + + expr* e = ctx->bool_var2expr(node.key); + if (!e) continue; + + top_lits.push_back(expr_ref(e, m)); + if (top_lits.size() >= k) break; + } + return top_lits; + } lbool parallel::new_check(expr_ref_vector const& asms) { - ast_manager& m = ctx.m; { scoped_limits sl(m.limit()); @@ -146,6 +187,11 @@ namespace smt { SASSERT(num_threads > 1); for (unsigned i = 0; i < num_threads; ++i) m_workers.push_back(alloc(worker, *this, ctx, asms)); + + // THIS WILL ALLOW YOU TO CANCEL ALL THE CHILD THREADS + // within the lexical scope of the code block, creates a data structure that allows you to push children + // objects to the limit object, so if someone cancels the parent object, the cancellation propagates to the children + // and that cancellation has the lifetime of the scope for (auto w : m_workers) sl.push_child(&(w->limit())); @@ -154,8 +200,7 @@ namespace smt { for (unsigned i = 0; i < num_threads; ++i) { threads[i] = std::thread([&, i]() { m_workers[i]->run(); - } - ); + }); } // Wait for all threads to finish @@ -175,18 +220,16 @@ namespace smt { unsigned max_conflicts = ctx.get_fparams().m_max_conflicts; // try first sequential with a low conflict budget to make super easy problems cheap - unsigned max_c = std::min(thread_max_conflicts, 40u); - flet _mc(ctx.get_fparams().m_max_conflicts, max_c); - result = ctx.check(asms.size(), asms.data()); - if (result != l_undef || ctx.m_num_conflicts < max_c) { - return result; - } + // GET RID OF THIS, AND IMMEDIATELY SEND TO THE MULTITHREADED CHECKER + // THE FIRST BATCH OF CUBES IS EMPTY, AND WE WILL SET ALL THREADS TO WORK ON THE ORIGINAL FORMULA enum par_exception_kind { DEFAULT_EX, ERROR_EX }; + // MOVE ALL OF THIS INSIDE THE WORKER THREAD AND CREATE/MANAGE LOCALLY + // SO THEN WE REMOVE THE ENCAPSULATING scoped_ptr_vector ETC, SMT_PARAMS BECOMES SMT_ vector smt_params; scoped_ptr_vector pms; scoped_ptr_vector pctxs; @@ -222,77 +265,6 @@ namespace smt { sl.push_child(&(new_m->limit())); } - - auto cube_pq = [&](context& ctx, expr_ref_vector& 
lasms, expr_ref& c) { - unsigned k = 3; // Number of top literals you want - - ast_manager& m = ctx.get_manager(); - - // Get the entire fixed-size priority queue (it's not that big) - auto candidates = ctx.m_pq_scores.get_heap(); // returns vector> - - // Sort descending by priority (higher priority first) - std::sort(candidates.begin(), candidates.end(), - [](const auto& a, const auto& b) { return a.priority > b.priority; }); - - expr_ref_vector conjuncts(m); - unsigned count = 0; - - for (const auto& node : candidates) { - if (ctx.get_assignment(node.key) != l_undef) continue; - - expr* e = ctx.bool_var2expr(node.key); - if (!e) continue; - - - expr_ref lit(e, m); - conjuncts.push_back(lit); - - if (++count >= k) break; - } - - c = mk_and(conjuncts); - lasms.push_back(c); - }; - - auto cube_score = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) { - vector> candidates; - unsigned k = 4; // Get top-k scoring literals - ast_manager& m = ctx.get_manager(); - - // Loop over first 100 Boolean vars - for (bool_var v = 0; v < 100; ++v) { - if (ctx.get_assignment(v) != l_undef) continue; - - expr* e = ctx.bool_var2expr(v); - if (!e) continue; - - literal lit(v, false); - double score = ctx.get_score(lit); - if (score == 0.0) continue; - - candidates.push_back(std::make_pair(expr_ref(e, m), score)); - } - - // Sort all candidate literals descending by score - std::sort(candidates.begin(), candidates.end(), - [](auto& a, auto& b) { return a.second > b.second; }); - - // Clear c and build it as conjunction of top-k - expr_ref_vector conjuncts(m); - - for (unsigned i = 0; i < std::min(k, (unsigned)candidates.size()); ++i) { - expr_ref lit = candidates[i].first; - conjuncts.push_back(lit); - } - - // Build conjunction and store in c - c = mk_and(conjuncts); - - // Add the single cube formula to lasms (not each literal separately) - lasms.push_back(c); - }; - obj_hashtable unit_set; expr_ref_vector unit_trail(ctx.m); unsigned_vector unit_lim; @@ -307,6 +279,9 @@ namespace smt { unsigned sz = pctx.assigned_literals().size(); for (unsigned j = unit_lim[i]; j < sz; ++j) { literal lit = pctx.assigned_literals()[j]; + //IF_VERBOSE(0, verbose_stream() << "(smt.thread " << i << " :unit " << lit << " " << pctx.is_relevant(lit.var()) << ")\n";); + if (!pctx.is_relevant(lit.var())) + continue; expr_ref e(pctx.bool_var2expr(lit.var()), pctx.m); if (lit.sign()) e = pctx.m.mk_not(e); expr_ref ce(tr(e.get()), ctx.m); @@ -331,275 +306,6 @@ namespace smt { IF_VERBOSE(1, verbose_stream() << "(smt.thread :units " << sz << ")\n"); }; - std::mutex mux; - - // Lambda defining the work each SMT thread performs - auto worker_thread = [&](int i, vector& cube_batch) { - try { - // Get thread-specific context and AST manager - context& pctx = *pctxs[i]; - ast_manager& pm = *pms[i]; - - // Initialize local assumptions and cube - expr_ref_vector lasms(pasms[i]); - - vector results; - for (expr_ref_vector& cube : cube_batch) { - expr_ref_vector lasms_copy(lasms); - - if (&cube.get_manager() != &pm) { - std::cerr << "Manager mismatch on cube: " << mk_bounded_pp(mk_and(cube), pm, 3) << "\n"; - UNREACHABLE(); // or throw - } - - for (expr* cube_lit : cube) { - lasms_copy.push_back(expr_ref(cube_lit, pm)); - } - - // Set the max conflict limit for this thread - pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts); - - // Optional verbose logging - IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i; - if (num_rounds > 0) verbose_stream() << " :round " << num_rounds; - verbose_stream() << " 
:cube " << mk_bounded_pp(mk_and(cube), pm, 3); - verbose_stream() << ")\n";); - - lbool r = pctx.check(lasms_copy.size(), lasms_copy.data()); - std::cout << "Thread " << i << " finished cube " << mk_bounded_pp(mk_and(cube), pm, 3) << " with result: " << r << "\n"; - results.push_back(r); - } - - lbool r = l_false; - for (lbool res : results) { - if (res == l_true) { - r = l_true; - } else if (res == l_undef) { - if (r == l_false) - r = l_undef; - } - } - - auto cube_intersects_core = [&](expr* cube, const expr_ref_vector &core) { - expr_ref_vector cube_lits(pctx.m); - flatten_and(cube, cube_lits); - for (expr* lit : cube_lits) - if (core.contains(lit)) - return true; - return false; - }; - - // Handle results based on outcome and conflict count - if (r == l_undef && pctx.m_num_conflicts >= max_conflicts) - ; // no-op, allow loop to continue - else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts) - return; // quit thread early - // If cube was unsat and it's in the core, learn from it. i.e. a thread can be UNSAT because the cube c contradicted F. In this case learn the negation of the cube ¬c - // else if (r == l_false) { - // // IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn cube batch " << mk_bounded_pp(cube, pm, 3) << ")" << " unsat_core: " << pctx.unsat_core() << ")"); - // for (expr* cube : cube_batch) { // iterate over each cube in the batch - // if (cube_intersects_core(cube, pctx.unsat_core())) { - // // IF_VERBOSE(1, verbose_stream() << "(pruning cube: " << mk_bounded_pp(cube, pm, 3) << " given unsat core: " << pctx.unsat_core() << ")"); - // pctx.assert_expr(mk_not(mk_and(pctx.unsat_core()))); - // } - // } - // } - - // Begin thread-safe update of shared result state - bool first = false; - { - std::lock_guard lock(mux); - if (finished_id == UINT_MAX) { - finished_id = i; - first = true; - result = r; - done = true; - } - if (!first && r != l_undef && result == l_undef) { - finished_id = i; - result = r; - } - else if (!first) return; // nothing new to contribute - } - - // Cancel limits on other threads now that a result is known - for (ast_manager* m : pms) { - if (m != &pm) m->limit().cancel(); - } - } catch (z3_error & err) { - if (finished_id == UINT_MAX) { - error_code = err.error_code(); - ex_kind = ERROR_EX; - done = true; - } - } catch (z3_exception & ex) { - if (finished_id == UINT_MAX) { - ex_msg = ex.what(); - ex_kind = DEFAULT_EX; - done = true; - } - } catch (...) 
{ - if (finished_id == UINT_MAX) { - ex_msg = "unknown exception"; - ex_kind = ERROR_EX; - done = true; - } - } - }; - - struct BatchManager { - std::mutex mtx; - vector> batches; - unsigned batch_idx = 0; - unsigned batch_size = 1; - - BatchManager(unsigned batch_size) : batch_size(batch_size) {} - - // translate the next SINGLE batch of batch_size cubes to the thread - vector get_next_batch( - ast_manager &main_ctx_m, - ast_manager &thread_m - ) { - std::lock_guard lock(mtx); - vector cube_batch; // ensure bound to thread manager - if (batch_idx >= batches.size()) return cube_batch; - - vector next_batch = batches[batch_idx]; - - for (const expr_ref_vector& cube : next_batch) { - expr_ref_vector translated_cube_lits(thread_m); - for (expr* lit : cube) { - // Translate each literal to the thread's manager - translated_cube_lits.push_back(translate(lit, main_ctx_m, thread_m)); - } - cube_batch.push_back(translated_cube_lits); - } - - ++batch_idx; - - return cube_batch; - } - - // returns a list (vector) of cubes, where each cube is an expr_ref_vector of literals - vector cube_batch_pq(context& ctx) { - unsigned k = 1; // generates 2^k cubes in the batch - ast_manager& m = ctx.get_manager(); - - auto candidates = ctx.m_pq_scores.get_heap(); - std::sort(candidates.begin(), candidates.end(), - [](const auto& a, const auto& b) { return a.priority > b.priority; }); - - expr_ref_vector top_lits(m); - for (const auto& node : candidates) { - if (ctx.get_assignment(node.key) != l_undef) continue; - - expr* e = ctx.bool_var2expr(node.key); - if (!e) continue; - - top_lits.push_back(expr_ref(e, m)); - if (top_lits.size() >= k) break; - } - - // std::cout << "Top lits:\n"; - // for (unsigned j = 0; j < top_lits.size(); ++j) { - // std::cout << " [" << j << "] " << mk_pp(top_lits[j].get(), m) << "\n"; - // } - - unsigned num_lits = top_lits.size(); - unsigned num_cubes = 1 << num_lits; // 2^num_lits combinations - - vector cube_batch; - - for (unsigned mask = 0; mask < num_cubes; ++mask) { - expr_ref_vector cube_lits(m); - for (unsigned i = 0; i < num_lits; ++i) { - expr_ref lit(top_lits[i].get(), m); - if ((mask >> i) & 1) - cube_lits.push_back(mk_not(lit)); - else - cube_lits.push_back(lit); - } - cube_batch.push_back(cube_lits); - } - - std::cout << "Cubes out:\n"; - for (unsigned j = 0; j < cube_batch.size(); ++j) { - std::cout << " [" << j << "]\n"; - for (unsigned k = 0; k < cube_batch[j].size(); ++k) { - std::cout << " [" << k << "] " << mk_pp(cube_batch[j][k].get(), m) << "\n"; - } - } - - return cube_batch; - }; - - // returns a vector of new cubes batches. 
each cube batch is a vector of expr_ref_vector cubes - vector> gen_new_batches(context& main_ctx) { - vector> cube_batches; - - // Get all cubes in the main context's manager - vector all_cubes = cube_batch_pq(main_ctx); - - ast_manager &m = main_ctx.get_manager(); - - // Partition into batches - for (unsigned start = 0; start < all_cubes.size(); start += batch_size) { - vector batch; - - unsigned end = std::min(start + batch_size, all_cubes.size()); - for (unsigned j = start; j < end; ++j) { - batch.push_back(all_cubes[j]); - } - - cube_batches.push_back(batch); - } - batch_idx = 0; // Reset index for next round - return cube_batches; - } - - void check_for_new_batches(context& main_ctx) { - std::lock_guard lock(mtx); - if (batch_idx >= batches.size()) { - batches = gen_new_batches(main_ctx); - } - } - }; - - BatchManager batch_manager(1); - - // Thread scheduling loop - while (true) { - vector threads(num_threads); - batch_manager.check_for_new_batches(ctx); - - // Launch threads - for (unsigned i = 0; i < num_threads; ++i) { - // [&, i] is the lambda's capture clause: capture all variables by reference (&) except i, which is captured by value. - threads[i] = std::thread([&, i]() { - while (!done) { - auto next_batch = batch_manager.get_next_batch(ctx.m, *pms[i]); - if (next_batch.empty()) break; // No more work - - worker_thread(i, next_batch); - } - }); - } - - // Wait for all threads to finish - for (auto & th : threads) { - th.join(); - } - - // Stop if one finished with a result - if (done) break; - - // Otherwise update shared state and retry - collect_units(); - ++num_rounds; - max_conflicts = (max_conflicts < thread_max_conflicts) ? 0 : (max_conflicts - thread_max_conflicts); - thread_max_conflicts *= 2; - } - // Gather statistics from all solver contexts for (context* c : pctxs) { c->collect_statistics(ctx.m_aux_stats); @@ -612,27 +318,7 @@ namespace smt { default: throw default_exception(std::move(ex_msg)); } } - - // Handle result: translate model/unsat core back to main context - model_ref mdl; - context& pctx = *pctxs[finished_id]; - ast_translation tr(*pms[finished_id], m); - switch (result) { - case l_true: - pctx.get_model(mdl); - if (mdl) - ctx.set_model(mdl->translate(tr)); - break; - case l_false: - ctx.m_unsat_core.reset(); - for (expr* e : pctx.unsat_core()) - ctx.m_unsat_core.push_back(tr(e)); - break; - default: - break; - } - - return result; + } } diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index 7bdea79e4..316213ad4 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -24,6 +24,7 @@ namespace smt { class parallel { context& ctx; + unsigned num_threads; class batch_manager { ast_manager& m; @@ -71,6 +72,7 @@ namespace smt { public: worker(parallel& p, context& _ctx, expr_ref_vector const& _asms); void run(); + expr_ref_vector get_split_atoms(); void cancel() { m.limit().cancel(); } @@ -88,7 +90,12 @@ namespace smt { lbool new_check(expr_ref_vector const& asms); public: - parallel(context& ctx): ctx(ctx), m_batch_manager(ctx.m, *this) {} + parallel(context& ctx) : + ctx(ctx), + num_threads(std::min( + (unsigned)std::thread::hardware_concurrency(), + ctx.get_fparams().m_threads)), + m_batch_manager(ctx.m, *this) {} lbool operator()(expr_ref_vector const& asms); diff --git a/src/solver/smt_logics.cpp b/src/solver/smt_logics.cpp index 1afea69dc..0942ed3fe 100644 --- a/src/solver/smt_logics.cpp +++ b/src/solver/smt_logics.cpp @@ -29,52 +29,56 @@ bool smt_logics::supported_logic(symbol const & s) { } bool 
smt_logics::logic_has_reals_only(symbol const& s) { + auto str = s.str(); return - s.str().find("LRA") != std::string::npos || - s.str().find("LRA") != std::string::npos || - s.str().find("NRA") != std::string::npos || - s.str().find("RDL") != std::string::npos; + str.find("LRA") != std::string::npos || + str.find("LRA") != std::string::npos || + str.find("NRA") != std::string::npos || + str.find("RDL") != std::string::npos; } bool smt_logics::logic_has_arith(symbol const & s) { + auto str = s.str(); return - s.str().find("LRA") != std::string::npos || - s.str().find("LIRA") != std::string::npos || - s.str().find("LIA") != std::string::npos || - s.str().find("LRA") != std::string::npos || - s.str().find("NRA") != std::string::npos || - s.str().find("NIRA") != std::string::npos || - s.str().find("NIA") != std::string::npos || - s.str().find("IDL") != std::string::npos || - s.str().find("RDL") != std::string::npos || - s == "QF_BVRE" || - s == "QF_FP" || - s == "FP" || - s == "QF_FPBV" || - s == "QF_BVFP" || - s == "QF_S" || + str.find("LRA") != std::string::npos || + str.find("LIRA") != std::string::npos || + str.find("LIA") != std::string::npos || + str.find("LRA") != std::string::npos || + str.find("NRA") != std::string::npos || + str.find("NIRA") != std::string::npos || + str.find("NIA") != std::string::npos || + str.find("IDL") != std::string::npos || + str.find("RDL") != std::string::npos || + str == "QF_BVRE" || + str == "QF_FP" || + str == "FP" || + str == "QF_FPBV" || + str == "QF_BVFP" || + str == "QF_S" || logic_is_all(s) || - s == "QF_FD" || - s == "HORN"; + str == "QF_FD" || + str == "HORN"; } bool smt_logics::logic_has_bv(symbol const & s) { + auto str = s.str(); return - s.str().find("BV") != std::string::npos || - s == "FP" || + str.find("BV") != std::string::npos || + str == "FP" || logic_is_all(s) || - s == "QF_FD" || - s == "SMTFD" || - s == "HORN"; + str == "QF_FD" || + str == "SMTFD" || + str == "HORN"; } bool smt_logics::logic_has_array(symbol const & s) { + auto str = s.str(); return - s.str().starts_with("QF_A") || - s.str().starts_with("A") || + str.starts_with("QF_A") || + str.starts_with("A") || logic_is_all(s) || - s == "SMTFD" || - s == "HORN"; + str == "SMTFD" || + str == "HORN"; } bool smt_logics::logic_has_seq(symbol const & s) { @@ -82,17 +86,28 @@ bool smt_logics::logic_has_seq(symbol const & s) { } bool smt_logics::logic_has_str(symbol const & s) { - return s == "QF_S" || s == "QF_SLIA" || s == "QF_SNIA" || logic_is_all(s); + auto str = s.str(); + return str == "QF_S" || + str == "QF_SLIA" || + str == "QF_SNIA" || + logic_is_all(s); } bool smt_logics::logic_has_fpa(symbol const & s) { - return s == "FP" || s == "QF_FP" || s == "QF_FPBV" || s == "QF_BVFP" || s == "QF_FPLRA" || logic_is_all(s); + auto str = s.str(); + return str == "FP" || + str == "QF_FP" || + str == "QF_FPBV" || + str == "QF_BVFP" || + str == "QF_FPLRA" || + logic_is_all(s); } bool smt_logics::logic_has_uf(symbol const & s) { + auto str = s.str(); return - s.str().find("UF") != std::string::npos || - s == "SMTFD"; + str.find("UF") != std::string::npos || + str == "SMTFD"; } bool smt_logics::logic_has_horn(symbol const& s) { @@ -104,9 +119,10 @@ bool smt_logics::logic_has_pb(symbol const& s) { } bool smt_logics::logic_has_datatype(symbol const& s) { + auto str = s.str(); return - s.str().find("DT") != std::string::npos || - s == "QF_FD" || + str.find("DT") != std::string::npos || + str == "QF_FD" || logic_is_all(s) || logic_has_horn(s); } diff --git a/src/util/event_handler.h 
b/src/util/event_handler.h index cabbca4c9..b3bbc8438 100644 --- a/src/util/event_handler.h +++ b/src/util/event_handler.h @@ -28,9 +28,8 @@ enum event_handler_caller_t { class event_handler { protected: - event_handler_caller_t m_caller_id; + event_handler_caller_t m_caller_id = UNSET_EH_CALLER; public: - event_handler(): m_caller_id(UNSET_EH_CALLER) {} virtual ~event_handler() = default; virtual void operator()(event_handler_caller_t caller_id) = 0; event_handler_caller_t caller_id() const { return m_caller_id; } diff --git a/src/util/gparams.cpp b/src/util/gparams.cpp index d2adc9f9f..7a81e000c 100644 --- a/src/util/gparams.cpp +++ b/src/util/gparams.cpp @@ -416,7 +416,7 @@ public: symbol sp(p.c_str()); std::ostringstream buffer; ps.display(buffer, sp); - return buffer.str(); + return std::move(buffer).str(); } std::string get_default(param_descrs const & d, std::string const & p, std::string const & m) {
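
Usage note (editorial addendum, not part of the patch): the high-level Datatype API added in high-level.ts and types.ts above is only exercised structurally by the new tests (existence of constructors, recognizers, and accessors). The sketch below shows how the recursive List example from the types.ts doc comment might be driven end to end. It assumes the dynamically attached members (`cons`, `nil`, `car`, `is_cons`) behave as the implementation attaches them; the solver interaction is an assumption not covered by the tests in this patch, and the `any` casts mirror the ones the tests use.

```typescript
import { init } from 'z3-solver';

async function listExample() {
  const { Context } = await init();
  const { Datatype, Int, Solver } = Context('main');

  // Declare a recursive List datatype: cons(car: Int, cdr: List) | nil,
  // following the pattern used in high-level.test.ts above.
  const List = Datatype('List');
  List.declare('cons', ['car', Int.sort()], ['cdr', List]);
  List.declare('nil');
  const ListSort = List.create();

  // Constructor, recognizer, and accessor declarations are attached
  // dynamically by createDatatypes, so they are reached through `any`.
  const { cons, nil, car, is_cons } = ListSort as any;

  // Build the term cons(5, nil); nil is a nullary constructor and is
  // attached as a constant expression rather than a FuncDecl.
  const five = Int.val(5);
  const l = cons.call(five, nil);

  // Assumed (untested in this patch): datatype terms can be asserted and
  // checked like any other expression.
  const solver = new Solver();
  solver.add(is_cons.call(l));        // l is built with cons
  solver.add(car.call(l).eq(five));   // car(cons(5, nil)) == 5
  console.log(await solver.check());  // expected: sat
}

listExample();
```

For mutually recursive declarations (the Tree/TreeList test above), the same flow applies except that `Datatype.createDatatypes(Tree, TreeList)` is called once instead of `create()` on each declaration, returning the finished sorts in declaration order.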