diff --git a/src/api/CMakeLists.txt b/src/api/CMakeLists.txt index a537bb3b9..c1e78f473 100644 --- a/src/api/CMakeLists.txt +++ b/src/api/CMakeLists.txt @@ -44,6 +44,7 @@ z3_add_component(api api_context.cpp api_datalog.cpp api_datatype.cpp + api_finite_set.cpp api_fpa.cpp api_goal.cpp api_log.cpp diff --git a/src/api/api_ast.cpp b/src/api/api_ast.cpp index ded8b7089..69ac1303f 100644 --- a/src/api/api_ast.cpp +++ b/src/api/api_ast.cpp @@ -1465,6 +1465,25 @@ extern "C" { } } + if (mk_c(c)->fsutil().get_family_id() == _d->get_family_id()) { + switch(_d->get_decl_kind()) { + case OP_FINITE_SET_EMPTY: return Z3_OP_FINITE_SET_EMPTY; + case OP_FINITE_SET_SINGLETON: return Z3_OP_FINITE_SET_SINGLETON; + case OP_FINITE_SET_UNION: return Z3_OP_FINITE_SET_UNION; + case OP_FINITE_SET_INTERSECT: return Z3_OP_FINITE_SET_INTERSECT; + case OP_FINITE_SET_DIFFERENCE: return Z3_OP_FINITE_SET_DIFFERENCE; + case OP_FINITE_SET_IN: return Z3_OP_FINITE_SET_IN; + case OP_FINITE_SET_SIZE: return Z3_OP_FINITE_SET_SIZE; + case OP_FINITE_SET_SUBSET: return Z3_OP_FINITE_SET_SUBSET; + case OP_FINITE_SET_MAP: return Z3_OP_FINITE_SET_MAP; + case OP_FINITE_SET_FILTER: return Z3_OP_FINITE_SET_FILTER; + case OP_FINITE_SET_RANGE: return Z3_OP_FINITE_SET_RANGE; + case OP_FINITE_SET_EXT: return Z3_OP_FINITE_SET_EXT; + case OP_FINITE_SET_MAP_INVERSE: return Z3_OP_FINITE_SET_MAP_INVERSE; + default: return Z3_OP_INTERNAL; + } + } + if (mk_c(c)->recfun().get_family_id() == _d->get_family_id()) return Z3_OP_RECURSIVE; diff --git a/src/api/api_context.cpp b/src/api/api_context.cpp index ce1eed52c..7fe4e4063 100644 --- a/src/api/api_context.cpp +++ b/src/api/api_context.cpp @@ -133,6 +133,7 @@ namespace api { m_fpa_util(m()), m_sutil(m()), m_recfun(m()), + m_finite_set_util(m()), m_ast_trail(m()), m_pmanager(m_limit) { diff --git a/src/api/api_context.h b/src/api/api_context.h index e570daca3..780372588 100644 --- a/src/api/api_context.h +++ b/src/api/api_context.h @@ -31,6 +31,7 @@ Revision History: #include "ast/fpa_decl_plugin.h" #include "ast/recfun_decl_plugin.h" #include "ast/special_relations_decl_plugin.h" +#include "ast/finite_set_decl_plugin.h" #include "ast/rewriter/seq_rewriter.h" #include "params/smt_params.h" #include "smt/smt_kernel.h" @@ -77,6 +78,7 @@ namespace api { fpa_util m_fpa_util; seq_util m_sutil; recfun::util m_recfun; + finite_set_util m_finite_set_util; // Support for old solver API smt_params m_fparams; @@ -146,6 +148,7 @@ namespace api { datatype_util& dtutil() { return m_dt_plugin->u(); } seq_util& sutil() { return m_sutil; } recfun::util& recfun() { return m_recfun; } + finite_set_util& fsutil() { return m_finite_set_util; } family_id get_basic_fid() const { return basic_family_id; } family_id get_array_fid() const { return m_array_fid; } family_id get_arith_fid() const { return arith_family_id; } diff --git a/src/api/api_finite_set.cpp b/src/api/api_finite_set.cpp new file mode 100644 index 000000000..2a2787e2a --- /dev/null +++ b/src/api/api_finite_set.cpp @@ -0,0 +1,187 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + api_finite_set.cpp + +Abstract: + + API for finite sets. + +Author: + + Copilot 2025-01-21 + +Revision History: + +--*/ +#include "api/z3.h" +#include "api/api_log_macros.h" +#include "api/api_context.h" +#include "api/api_util.h" +#include "ast/ast_pp.h" + +extern "C" { + + Z3_sort Z3_API Z3_mk_finite_set_sort(Z3_context c, Z3_sort elem_sort) { + Z3_TRY; + LOG_Z3_mk_finite_set_sort(c, elem_sort); + RESET_ERROR_CODE(); + parameter param(to_sort(elem_sort)); + sort* ty = mk_c(c)->m().mk_sort(mk_c(c)->fsutil().get_family_id(), FINITE_SET_SORT, 1, ¶m); + mk_c(c)->save_ast_trail(ty); + RETURN_Z3(of_sort(ty)); + Z3_CATCH_RETURN(nullptr); + } + + bool Z3_API Z3_is_finite_set_sort(Z3_context c, Z3_sort s) { + Z3_TRY; + LOG_Z3_is_finite_set_sort(c, s); + RESET_ERROR_CODE(); + return mk_c(c)->fsutil().is_finite_set(to_sort(s)); + Z3_CATCH_RETURN(false); + } + + Z3_sort Z3_API Z3_get_finite_set_sort_basis(Z3_context c, Z3_sort s) { + Z3_TRY; + LOG_Z3_get_finite_set_sort_basis(c, s); + RESET_ERROR_CODE(); + sort* elem_sort = nullptr; + if (!mk_c(c)->fsutil().is_finite_set(to_sort(s), elem_sort)) { + SET_ERROR_CODE(Z3_INVALID_ARG, "expected finite set sort"); + RETURN_Z3(nullptr); + } + RETURN_Z3(of_sort(elem_sort)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_empty(Z3_context c, Z3_sort set_sort) { + Z3_TRY; + LOG_Z3_mk_finite_set_empty(c, set_sort); + RESET_ERROR_CODE(); + app* a = mk_c(c)->fsutil().mk_empty(to_sort(set_sort)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_singleton(Z3_context c, Z3_ast elem) { + Z3_TRY; + LOG_Z3_mk_finite_set_singleton(c, elem); + RESET_ERROR_CODE(); + CHECK_IS_EXPR(elem, nullptr); + app* a = mk_c(c)->fsutil().mk_singleton(to_expr(elem)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_union(Z3_context c, Z3_ast s1, Z3_ast s2) { + Z3_TRY; + LOG_Z3_mk_finite_set_union(c, s1, s2); + RESET_ERROR_CODE(); + CHECK_IS_EXPR(s1, nullptr); + CHECK_IS_EXPR(s2, nullptr); + app* a = mk_c(c)->fsutil().mk_union(to_expr(s1), to_expr(s2)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_intersect(Z3_context c, Z3_ast s1, Z3_ast s2) { + Z3_TRY; + LOG_Z3_mk_finite_set_intersect(c, s1, s2); + RESET_ERROR_CODE(); + CHECK_IS_EXPR(s1, nullptr); + CHECK_IS_EXPR(s2, nullptr); + app* a = mk_c(c)->fsutil().mk_intersect(to_expr(s1), to_expr(s2)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_difference(Z3_context c, Z3_ast s1, Z3_ast s2) { + Z3_TRY; + LOG_Z3_mk_finite_set_difference(c, s1, s2); + RESET_ERROR_CODE(); + CHECK_IS_EXPR(s1, nullptr); + CHECK_IS_EXPR(s2, nullptr); + app* a = mk_c(c)->fsutil().mk_difference(to_expr(s1), to_expr(s2)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_member(Z3_context c, Z3_ast elem, Z3_ast set) { + Z3_TRY; + LOG_Z3_mk_finite_set_member(c, elem, set); + RESET_ERROR_CODE(); + CHECK_IS_EXPR(elem, nullptr); + CHECK_IS_EXPR(set, nullptr); + app* a = mk_c(c)->fsutil().mk_in(to_expr(elem), to_expr(set)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_size(Z3_context c, Z3_ast set) { + Z3_TRY; + LOG_Z3_mk_finite_set_size(c, set); + RESET_ERROR_CODE(); + CHECK_IS_EXPR(set, nullptr); + app* a = mk_c(c)->fsutil().mk_size(to_expr(set)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_subset(Z3_context c, Z3_ast s1, Z3_ast s2) { + Z3_TRY; + LOG_Z3_mk_finite_set_subset(c, s1, s2); + RESET_ERROR_CODE(); + CHECK_IS_EXPR(s1, nullptr); + CHECK_IS_EXPR(s2, nullptr); + app* a = mk_c(c)->fsutil().mk_subset(to_expr(s1), to_expr(s2)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_map(Z3_context c, Z3_ast f, Z3_ast set) { + Z3_TRY; + LOG_Z3_mk_finite_set_map(c, f, set); + RESET_ERROR_CODE(); + CHECK_IS_EXPR(f, nullptr); + CHECK_IS_EXPR(set, nullptr); + app* a = mk_c(c)->fsutil().mk_map(to_expr(f), to_expr(set)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_filter(Z3_context c, Z3_ast f, Z3_ast set) { + Z3_TRY; + LOG_Z3_mk_finite_set_filter(c, f, set); + RESET_ERROR_CODE(); + CHECK_IS_EXPR(f, nullptr); + CHECK_IS_EXPR(set, nullptr); + app* a = mk_c(c)->fsutil().mk_filter(to_expr(f), to_expr(set)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + + Z3_ast Z3_API Z3_mk_finite_set_range(Z3_context c, Z3_ast low, Z3_ast high) { + Z3_TRY; + LOG_Z3_mk_finite_set_range(c, low, high); + RESET_ERROR_CODE(); + CHECK_IS_EXPR(low, nullptr); + CHECK_IS_EXPR(high, nullptr); + app* a = mk_c(c)->fsutil().mk_range(to_expr(low), to_expr(high)); + mk_c(c)->save_ast_trail(a); + RETURN_Z3(of_ast(a)); + Z3_CATCH_RETURN(nullptr); + } + +}; diff --git a/src/api/c++/z3++.h b/src/api/c++/z3++.h index 08dc1489f..45e9ef205 100644 --- a/src/api/c++/z3++.h +++ b/src/api/c++/z3++.h @@ -323,6 +323,10 @@ namespace z3 { \brief Return a regular expression sort over sequences \c seq_sort. */ sort re_sort(sort& seq_sort); + /** + \brief Return a finite set sort over element sort \c s. + */ + sort finite_set_sort(sort& s); /** \brief Return an array sort for arrays from \c d to \c r. @@ -3663,6 +3667,7 @@ namespace z3 { inline sort context::char_sort() { Z3_sort s = Z3_mk_char_sort(m_ctx); check_error(); return sort(*this, s); } inline sort context::seq_sort(sort& s) { Z3_sort r = Z3_mk_seq_sort(m_ctx, s); check_error(); return sort(*this, r); } inline sort context::re_sort(sort& s) { Z3_sort r = Z3_mk_re_sort(m_ctx, s); check_error(); return sort(*this, r); } + inline sort context::finite_set_sort(sort& s) { Z3_sort r = Z3_mk_finite_set_sort(m_ctx, s); check_error(); return sort(*this, r); } inline sort context::fpa_sort(unsigned ebits, unsigned sbits) { Z3_sort s = Z3_mk_fpa_sort(m_ctx, ebits, sbits); check_error(); return sort(*this, s); } template<> @@ -4264,6 +4269,54 @@ namespace z3 { MK_EXPR2(Z3_mk_set_subset, a, b); } + // finite set operations + + inline expr finite_set_empty(sort const& s) { + Z3_ast r = Z3_mk_finite_set_empty(s.ctx(), s); + s.check_error(); + return expr(s.ctx(), r); + } + + inline expr finite_set_singleton(expr const& e) { + MK_EXPR1(Z3_mk_finite_set_singleton, e); + } + + inline expr finite_set_union(expr const& a, expr const& b) { + MK_EXPR2(Z3_mk_finite_set_union, a, b); + } + + inline expr finite_set_intersect(expr const& a, expr const& b) { + MK_EXPR2(Z3_mk_finite_set_intersect, a, b); + } + + inline expr finite_set_difference(expr const& a, expr const& b) { + MK_EXPR2(Z3_mk_finite_set_difference, a, b); + } + + inline expr finite_set_member(expr const& e, expr const& s) { + MK_EXPR2(Z3_mk_finite_set_member, e, s); + } + + inline expr finite_set_size(expr const& s) { + MK_EXPR1(Z3_mk_finite_set_size, s); + } + + inline expr finite_set_subset(expr const& a, expr const& b) { + MK_EXPR2(Z3_mk_finite_set_subset, a, b); + } + + inline expr finite_set_map(expr const& f, expr const& s) { + MK_EXPR2(Z3_mk_finite_set_map, f, s); + } + + inline expr finite_set_filter(expr const& f, expr const& s) { + MK_EXPR2(Z3_mk_finite_set_filter, f, s); + } + + inline expr finite_set_range(expr const& low, expr const& high) { + MK_EXPR2(Z3_mk_finite_set_range, low, high); + } + // sequence and regular expression operations. // union is + // concat is overloaded to handle sequences and regular expressions diff --git a/src/api/dotnet/CMakeLists.txt b/src/api/dotnet/CMakeLists.txt index d3cb87bc7..c309f4027 100644 --- a/src/api/dotnet/CMakeLists.txt +++ b/src/api/dotnet/CMakeLists.txt @@ -64,6 +64,7 @@ set(Z3_DOTNET_ASSEMBLY_SOURCES_IN_SRC_TREE FiniteDomainExpr.cs FiniteDomainNum.cs FiniteDomainSort.cs + FiniteSetSort.cs Fixedpoint.cs FPExpr.cs FPNum.cs diff --git a/src/api/dotnet/Context.cs b/src/api/dotnet/Context.cs index 8ea4d70bc..52c9af8f6 100644 --- a/src/api/dotnet/Context.cs +++ b/src/api/dotnet/Context.cs @@ -2442,6 +2442,180 @@ namespace Microsoft.Z3 #endregion + #region Finite Sets + + /// + /// Create a finite set sort over the given element sort. + /// + public FiniteSetSort MkFiniteSetSort(Sort elemSort) + { + Debug.Assert(elemSort != null); + + CheckContextMatch(elemSort); + return new FiniteSetSort(this, elemSort); + } + + /// + /// Check if a sort is a finite set sort. + /// + public bool IsFiniteSetSort(Sort s) + { + Debug.Assert(s != null); + + CheckContextMatch(s); + return Native.Z3_is_finite_set_sort(nCtx, s.NativeObject) != 0; + } + + /// + /// Get the element sort (basis) of a finite set sort. + /// + public Sort GetFiniteSetSortBasis(Sort s) + { + Debug.Assert(s != null); + + CheckContextMatch(s); + return Sort.Create(this, Native.Z3_get_finite_set_sort_basis(nCtx, s.NativeObject)); + } + + /// + /// Create an empty finite set. + /// + public Expr MkFiniteSetEmpty(Sort setSort) + { + Debug.Assert(setSort != null); + + CheckContextMatch(setSort); + return Expr.Create(this, Native.Z3_mk_finite_set_empty(nCtx, setSort.NativeObject)); + } + + /// + /// Create a singleton finite set. + /// + public Expr MkFiniteSetSingleton(Expr elem) + { + Debug.Assert(elem != null); + + CheckContextMatch(elem); + return Expr.Create(this, Native.Z3_mk_finite_set_singleton(nCtx, elem.NativeObject)); + } + + /// + /// Create the union of two finite sets. + /// + public Expr MkFiniteSetUnion(Expr s1, Expr s2) + { + Debug.Assert(s1 != null); + Debug.Assert(s2 != null); + + CheckContextMatch(s1); + CheckContextMatch(s2); + return Expr.Create(this, Native.Z3_mk_finite_set_union(nCtx, s1.NativeObject, s2.NativeObject)); + } + + /// + /// Create the intersection of two finite sets. + /// + public Expr MkFiniteSetIntersect(Expr s1, Expr s2) + { + Debug.Assert(s1 != null); + Debug.Assert(s2 != null); + + CheckContextMatch(s1); + CheckContextMatch(s2); + return Expr.Create(this, Native.Z3_mk_finite_set_intersect(nCtx, s1.NativeObject, s2.NativeObject)); + } + + /// + /// Create the difference of two finite sets. + /// + public Expr MkFiniteSetDifference(Expr s1, Expr s2) + { + Debug.Assert(s1 != null); + Debug.Assert(s2 != null); + + CheckContextMatch(s1); + CheckContextMatch(s2); + return Expr.Create(this, Native.Z3_mk_finite_set_difference(nCtx, s1.NativeObject, s2.NativeObject)); + } + + /// + /// Check for membership in a finite set. + /// + public BoolExpr MkFiniteSetMember(Expr elem, Expr set) + { + Debug.Assert(elem != null); + Debug.Assert(set != null); + + CheckContextMatch(elem); + CheckContextMatch(set); + return (BoolExpr)Expr.Create(this, Native.Z3_mk_finite_set_member(nCtx, elem.NativeObject, set.NativeObject)); + } + + /// + /// Get the cardinality of a finite set. + /// + public Expr MkFiniteSetSize(Expr set) + { + Debug.Assert(set != null); + + CheckContextMatch(set); + return Expr.Create(this, Native.Z3_mk_finite_set_size(nCtx, set.NativeObject)); + } + + /// + /// Check if one finite set is a subset of another. + /// + public BoolExpr MkFiniteSetSubset(Expr s1, Expr s2) + { + Debug.Assert(s1 != null); + Debug.Assert(s2 != null); + + CheckContextMatch(s1); + CheckContextMatch(s2); + return (BoolExpr)Expr.Create(this, Native.Z3_mk_finite_set_subset(nCtx, s1.NativeObject, s2.NativeObject)); + } + + /// + /// Map a function over all elements in a finite set. + /// + public Expr MkFiniteSetMap(Expr f, Expr set) + { + Debug.Assert(f != null); + Debug.Assert(set != null); + + CheckContextMatch(f); + CheckContextMatch(set); + return Expr.Create(this, Native.Z3_mk_finite_set_map(nCtx, f.NativeObject, set.NativeObject)); + } + + /// + /// Filter a finite set with a predicate. + /// + public Expr MkFiniteSetFilter(Expr f, Expr set) + { + Debug.Assert(f != null); + Debug.Assert(set != null); + + CheckContextMatch(f); + CheckContextMatch(set); + return Expr.Create(this, Native.Z3_mk_finite_set_filter(nCtx, f.NativeObject, set.NativeObject)); + } + + /// + /// Create a finite set containing integers in the range [low, high]. + /// + public Expr MkFiniteSetRange(Expr low, Expr high) + { + Debug.Assert(low != null); + Debug.Assert(high != null); + + CheckContextMatch(low); + CheckContextMatch(high); + return Expr.Create(this, Native.Z3_mk_finite_set_range(nCtx, low.NativeObject, high.NativeObject)); + } + + #endregion + #region Sequence, string and regular expressions /// diff --git a/src/api/dotnet/FiniteSetSort.cs b/src/api/dotnet/FiniteSetSort.cs new file mode 100644 index 000000000..dda981cf9 --- /dev/null +++ b/src/api/dotnet/FiniteSetSort.cs @@ -0,0 +1,53 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + FiniteSetSort.cs + +Abstract: + + Z3 Managed API: Finite Set Sorts + +Author: + + GitHub Copilot + +Notes: + +--*/ + +using System.Diagnostics; +using System; + +namespace Microsoft.Z3 +{ + /// + /// Finite set sorts. + /// + public class FiniteSetSort : Sort + { + #region Internal + internal FiniteSetSort(Context ctx, IntPtr obj) + : base(ctx, obj) + { + Debug.Assert(ctx != null); + } + + internal FiniteSetSort(Context ctx, Sort elemSort) + : base(ctx, Native.Z3_mk_finite_set_sort(ctx.nCtx, elemSort.NativeObject)) + { + Debug.Assert(ctx != null); + Debug.Assert(elemSort != null); + } + #endregion + + /// + /// Get the element sort (basis) of this finite set sort. + /// + public Sort Basis + { + get { return Sort.Create(Context, Native.Z3_get_finite_set_sort_basis(Context.nCtx, NativeObject)); } + } + } +} diff --git a/src/api/java/CMakeLists.txt b/src/api/java/CMakeLists.txt index 9bde1bb20..194b25232 100644 --- a/src/api/java/CMakeLists.txt +++ b/src/api/java/CMakeLists.txt @@ -124,6 +124,7 @@ set(Z3_JAVA_JAR_SOURCE_FILES FiniteDomainExpr.java FiniteDomainNum.java FiniteDomainSort.java + FiniteSetSort.java Fixedpoint.java FPExpr.java FPNum.java diff --git a/src/api/java/Context.java b/src/api/java/Context.java index 2833916f2..b67463431 100644 --- a/src/api/java/Context.java +++ b/src/api/java/Context.java @@ -2134,6 +2134,145 @@ public class Context implements AutoCloseable { } + /** + * Finite Sets + */ + + /** + * Create a finite set sort over the given element sort. + **/ + public final FiniteSetSort mkFiniteSetSort(Sort elemSort) + { + checkContextMatch(elemSort); + return new FiniteSetSort(this, elemSort); + } + + /** + * Check if a sort is a finite set sort. + **/ + public final boolean isFiniteSetSort(Sort s) + { + checkContextMatch(s); + return Native.isFiniteSetSort(nCtx(), s.getNativeObject()); + } + + /** + * Get the element sort (basis) of a finite set sort. + **/ + public final Sort getFiniteSetSortBasis(Sort s) + { + checkContextMatch(s); + return Sort.create(this, Native.getFiniteSetSortBasis(nCtx(), s.getNativeObject())); + } + + /** + * Create an empty finite set. + **/ + public final Expr mkFiniteSetEmpty(Sort setSort) + { + checkContextMatch(setSort); + return Expr.create(this, Native.mkFiniteSetEmpty(nCtx(), setSort.getNativeObject())); + } + + /** + * Create a singleton finite set. + **/ + public final Expr mkFiniteSetSingleton(Expr elem) + { + checkContextMatch(elem); + return Expr.create(this, Native.mkFiniteSetSingleton(nCtx(), elem.getNativeObject())); + } + + /** + * Create the union of two finite sets. + **/ + public final Expr mkFiniteSetUnion(Expr s1, Expr s2) + { + checkContextMatch(s1); + checkContextMatch(s2); + return Expr.create(this, Native.mkFiniteSetUnion(nCtx(), s1.getNativeObject(), s2.getNativeObject())); + } + + /** + * Create the intersection of two finite sets. + **/ + public final Expr mkFiniteSetIntersect(Expr s1, Expr s2) + { + checkContextMatch(s1); + checkContextMatch(s2); + return Expr.create(this, Native.mkFiniteSetIntersect(nCtx(), s1.getNativeObject(), s2.getNativeObject())); + } + + /** + * Create the difference of two finite sets. + **/ + public final Expr mkFiniteSetDifference(Expr s1, Expr s2) + { + checkContextMatch(s1); + checkContextMatch(s2); + return Expr.create(this, Native.mkFiniteSetDifference(nCtx(), s1.getNativeObject(), s2.getNativeObject())); + } + + /** + * Check for membership in a finite set. + **/ + public final BoolExpr mkFiniteSetMember(Expr elem, Expr set) + { + checkContextMatch(elem); + checkContextMatch(set); + return (BoolExpr) Expr.create(this, Native.mkFiniteSetMember(nCtx(), elem.getNativeObject(), set.getNativeObject())); + } + + /** + * Get the cardinality of a finite set. + **/ + public final Expr mkFiniteSetSize(Expr set) + { + checkContextMatch(set); + return Expr.create(this, Native.mkFiniteSetSize(nCtx(), set.getNativeObject())); + } + + /** + * Check if one finite set is a subset of another. + **/ + public final BoolExpr mkFiniteSetSubset(Expr s1, Expr s2) + { + checkContextMatch(s1); + checkContextMatch(s2); + return (BoolExpr) Expr.create(this, Native.mkFiniteSetSubset(nCtx(), s1.getNativeObject(), s2.getNativeObject())); + } + + /** + * Map a function over all elements in a finite set. + **/ + public final Expr mkFiniteSetMap(Expr f, Expr set) + { + checkContextMatch(f); + checkContextMatch(set); + return Expr.create(this, Native.mkFiniteSetMap(nCtx(), f.getNativeObject(), set.getNativeObject())); + } + + /** + * Filter a finite set with a predicate. + **/ + public final Expr mkFiniteSetFilter(Expr f, Expr set) + { + checkContextMatch(f); + checkContextMatch(set); + return Expr.create(this, Native.mkFiniteSetFilter(nCtx(), f.getNativeObject(), set.getNativeObject())); + } + + /** + * Create a finite set containing integers in the range [low, high]. + **/ + public final Expr mkFiniteSetRange(Expr low, Expr high) + { + checkContextMatch(low); + checkContextMatch(high); + return Expr.create(this, Native.mkFiniteSetRange(nCtx(), low.getNativeObject(), high.getNativeObject())); + } + + /** * Sequences, Strings and regular expressions. */ diff --git a/src/api/java/FiniteSetSort.java b/src/api/java/FiniteSetSort.java new file mode 100644 index 000000000..031199539 --- /dev/null +++ b/src/api/java/FiniteSetSort.java @@ -0,0 +1,42 @@ +/** +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + FiniteSetSort.java + +Abstract: + +Author: + + GitHub Copilot + +Notes: + +**/ + +package com.microsoft.z3; + +/** + * Finite set sorts. + **/ +public class FiniteSetSort extends Sort +{ + FiniteSetSort(Context ctx, long obj) + { + super(ctx, obj); + } + + FiniteSetSort(Context ctx, Sort elemSort) + { + super(ctx, Native.mkFiniteSetSort(ctx.nCtx(), elemSort.getNativeObject())); + } + + /** + * Get the element sort (basis) of this finite set sort. + **/ + public Sort getBasis() + { + return Sort.create(getContext(), Native.getFiniteSetSortBasis(getContext().nCtx(), getNativeObject())); + } +} diff --git a/src/api/julia/z3jl.cpp b/src/api/julia/z3jl.cpp index 6bc53f78e..ec3efa86b 100644 --- a/src/api/julia/z3jl.cpp +++ b/src/api/julia/z3jl.cpp @@ -309,6 +309,17 @@ JLCXX_MODULE define_julia_module(jlcxx::Module &m) m.method("sqrt", static_cast(&sqrt)); m.method("fma", static_cast(&fma)); m.method("range", &range); + m.method("finite_set_empty", &finite_set_empty); + m.method("finite_set_singleton", &finite_set_singleton); + m.method("finite_set_union", &finite_set_union); + m.method("finite_set_intersect", &finite_set_intersect); + m.method("finite_set_difference", &finite_set_difference); + m.method("finite_set_member", &finite_set_member); + m.method("finite_set_size", &finite_set_size); + m.method("finite_set_subset", &finite_set_subset); + m.method("finite_set_map", &finite_set_map); + m.method("finite_set_filter", &finite_set_filter); + m.method("finite_set_range", &finite_set_range); // ------------------------------------------------------------------------- @@ -618,6 +629,7 @@ JLCXX_MODULE define_julia_module(jlcxx::Module &m) .MM(context, string_sort) .MM(context, seq_sort) .MM(context, re_sort) + .MM(context, finite_set_sort) .method("array_sort", static_cast(&context::array_sort)) .method("array_sort", static_cast(&context::array_sort)) .method("fpa_sort", static_cast(&context::fpa_sort)) diff --git a/src/api/ml/CMakeLists.txt b/src/api/ml/CMakeLists.txt index 2727c55ed..b35c4ac4c 100644 --- a/src/api/ml/CMakeLists.txt +++ b/src/api/ml/CMakeLists.txt @@ -1,3 +1,4 @@ + find_package(OCaml REQUIRED) set(exe_ext ${CMAKE_EXECUTABLE_SUFFIX}) diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index f859ab961..9438af21f 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -695,6 +695,8 @@ def is_sort(s : Any) -> bool: def _to_sort_ref(s, ctx): if z3_debug(): _z3_assert(isinstance(s, Sort), "Z3 Sort expected") + if Z3_is_finite_set_sort(ctx.ref(), s): + return FiniteSetSortRef(s, ctx) k = _sort_kind(ctx, s) if k == Z3_BOOL_SORT: return BoolSortRef(s, ctx) @@ -858,7 +860,7 @@ class FuncDeclRef(AstRef): elif k == Z3_PARAMETER_RATIONAL: result[i] = Z3_get_decl_rational_parameter(self.ctx_ref(), self.ast, i) elif k == Z3_PARAMETER_SYMBOL: - result[i] = Z3_get_decl_symbol_parameter(self.ctx_ref(), self.ast, i) + result[i] = _symbol2py(ctx, Z3_get_decl_symbol_parameter(self.ctx_ref(), self.ast, i)) elif k == Z3_PARAMETER_SORT: result[i] = SortRef(Z3_get_decl_sort_parameter(self.ctx_ref(), self.ast, i), ctx) elif k == Z3_PARAMETER_AST: @@ -1225,7 +1227,11 @@ def _to_expr_ref(a, ctx): k = Z3_get_ast_kind(ctx_ref, a) if k == Z3_QUANTIFIER_AST: return QuantifierRef(a, ctx) - sk = Z3_get_sort_kind(ctx_ref, Z3_get_sort(ctx_ref, a)) + # Check for finite set sort before checking sort kind + s = Z3_get_sort(ctx_ref, a) + if Z3_is_finite_set_sort(ctx_ref, s): + return FiniteSetRef(a, ctx) + sk = Z3_get_sort_kind(ctx_ref, s) if sk == Z3_BOOL_SORT: return BoolRef(a, ctx) if sk == Z3_INT_SORT: @@ -5066,6 +5072,25 @@ def Ext(a, b): _z3_assert(is_array_sort(a) and (is_array(b) or b.is_lambda()), "arguments must be arrays") return _to_expr_ref(Z3_mk_array_ext(ctx.ref(), a.as_ast(), b.as_ast()), ctx) + +def AsArray(f): + """Return a Z3 as-array expression for the given function declaration. + + >>> f = Function('f', IntSort(), IntSort()) + >>> a = AsArray(f) + >>> a.sort() + Array(Int, Int) + >>> is_as_array(a) + True + >>> get_as_array_func(a) == f + True + """ + if z3_debug(): + _z3_assert(isinstance(f, FuncDeclRef), "function declaration expected") + ctx = f.ctx + return ArrayRef(Z3_mk_as_array(ctx.ref(), f.ast), ctx) + + def is_select(a): """Return `True` if `a` is a Z3 array select application. @@ -5108,6 +5133,8 @@ def EmptySet(s): K(Int, False) """ ctx = s.ctx + if is_finite_set_sort(s): + return FiniteSetEmpty(s) return ArrayRef(Z3_mk_empty_set(ctx.ref(), s.ast), ctx) @@ -5128,6 +5155,9 @@ def SetUnion(*args): union(a, b) """ args = _get_args(args) + if len(args) > 0 and is_finite_set(args[0]): + from functools import reduce + return reduce(FiniteSetUnion, args) ctx = _ctx_from_ast_arg_list(args) _args, sz = _to_ast_array(args) return ArrayRef(Z3_mk_set_union(ctx.ref(), sz, _args), ctx) @@ -5142,6 +5172,9 @@ def SetIntersect(*args): """ args = _get_args(args) ctx = _ctx_from_ast_arg_list(args) + if len(args) > 0 and is_finite_set(args[0]): + from functools import reduce + return reduce(FiniteSetIntersect, args) _args, sz = _to_ast_array(args) return ArrayRef(Z3_mk_set_intersect(ctx.ref(), sz, _args), ctx) @@ -5152,8 +5185,10 @@ def SetAdd(s, e): >>> SetAdd(a, 1) Store(a, 1, True) """ - ctx = _ctx_from_ast_arg_list([s, e]) + ctx = _ctx_from_ast_arg_list([s, e]) e = _py2expr(e, ctx) + if is_finite_set(s): + return FiniteSetSingleton(e) | s return ArrayRef(Z3_mk_set_add(ctx.ref(), s.as_ast(), e.as_ast()), ctx) @@ -5165,6 +5200,8 @@ def SetDel(s, e): """ ctx = _ctx_from_ast_arg_list([s, e]) e = _py2expr(e, ctx) + if is_finite_set(s): + return s - FiniteSetSingleton(e) return ArrayRef(Z3_mk_set_del(ctx.ref(), s.as_ast(), e.as_ast()), ctx) @@ -5186,6 +5223,8 @@ def SetDifference(a, b): setminus(a, b) """ ctx = _ctx_from_ast_arg_list([a, b]) + if is_finite_set(a): + return FiniteSetDifference(a, b) return ArrayRef(Z3_mk_set_difference(ctx.ref(), a.as_ast(), b.as_ast()), ctx) @@ -5197,6 +5236,8 @@ def IsMember(e, s): """ ctx = _ctx_from_ast_arg_list([s, e]) e = _py2expr(e, ctx) + if is_finite_set(s): + return FiniteSetIsMember(e, s) return BoolRef(Z3_mk_set_member(ctx.ref(), e.as_ast(), s.as_ast()), ctx) @@ -5208,9 +5249,228 @@ def IsSubset(a, b): subset(a, b) """ ctx = _ctx_from_ast_arg_list([a, b]) + if is_finite_set(a): + return FiniteSetIsSubset(a, b) return BoolRef(Z3_mk_set_subset(ctx.ref(), a.as_ast(), b.as_ast()), ctx) +######################################### +# +# Finite Sets +# +######################################### + + +class FiniteSetSortRef(SortRef): + """Finite set sort.""" + + def element_sort(self): + """Return the element sort of this finite set sort.""" + return _to_sort_ref(Z3_get_finite_set_sort_basis(self.ctx_ref(), self.ast), self.ctx) + + def cast(self, val): + """Try to cast val as a finite set expression.""" + if is_expr(val): + if self.eq(val.sort()): + return val + else: + _z3_assert(False, "Cannot cast to finite set sort") + if isinstance(val, set): + elem_sort = self.element_sort() + result = FiniteSetEmpty(self) + for e in val: + result = FiniteSetUnion(result, Singleton(_py2expr(e, self.ctx, elem_sort))) + return result + _z3_assert(False, "Cannot cast to finite set sort") + + def subsort(self, other): + return False + + def is_int(self): + return False + + def is_bool(self): + return False + + def is_datatype(self): + return False + + def is_array(self): + return False + + def is_bv(self): + return False + + +def is_finite_set(a): + """Return True if a is a Z3 finite set expression. + >>> s = FiniteSetSort(IntSort()) + >>> is_finite_set(FiniteSetEmpty(s)) + True + >>> is_finite_set(IntVal(1)) + False + """ + return isinstance(a, FiniteSetRef) + + +def is_finite_set_sort(s): + """Return True if s is a Z3 finite set sort. + >>> is_finite_set_sort(FiniteSetSort(IntSort())) + True + >>> is_finite_set_sort(IntSort()) + False + """ + return isinstance(s, FiniteSetSortRef) + + +class FiniteSetRef(ExprRef): + """Finite set expression.""" + + def sort(self): + return FiniteSetSortRef(Z3_get_sort(self.ctx_ref(), self.as_ast()), self.ctx) + + def __or__(self, other): + """Return the union of self and other.""" + return FiniteSetUnion(self, other) + + def __and__(self, other): + """Return the intersection of self and other.""" + return FiniteSetIntersect(self, other) + + def __sub__(self, other): + """Return the set difference of self and other.""" + return FiniteSetDifference(self, other) + + +def FiniteSetSort(elem_sort): + """Create a finite set sort over element sort elem_sort. + >>> s = FiniteSetSort(IntSort()) + >>> s + FiniteSet(Int) + """ + return FiniteSetSortRef(Z3_mk_finite_set_sort(elem_sort.ctx_ref(), elem_sort.ast), elem_sort.ctx) + + +def FiniteSetEmpty(set_sort): + """Create an empty finite set of the given sort. + >>> s = FiniteSetSort(IntSort()) + >>> FiniteSetEmpty(s) + set.empty + """ + ctx = set_sort.ctx + return FiniteSetRef(Z3_mk_finite_set_empty(ctx.ref(), set_sort.ast), ctx) + + +def Singleton(elem): + """Create a singleton finite set containing elem. + >>> Singleton(IntVal(1)) + set.singleton(1) + """ + ctx = elem.ctx + return FiniteSetRef(Z3_mk_finite_set_singleton(ctx.ref(), elem.as_ast()), ctx) + + +def FiniteSetUnion(s1, s2): + """Create the union of two finite sets. + >>> a = Const('a', FiniteSetSort(IntSort())) + >>> b = Const('b', FiniteSetSort(IntSort())) + >>> FiniteSetUnion(a, b) + set.union(a, b) + """ + ctx = _ctx_from_ast_arg_list([s1, s2]) + return FiniteSetRef(Z3_mk_finite_set_union(ctx.ref(), s1.as_ast(), s2.as_ast()), ctx) + + +def FiniteSetIntersect(s1, s2): + """Create the intersection of two finite sets. + >>> a = Const('a', FiniteSetSort(IntSort())) + >>> b = Const('b', FiniteSetSort(IntSort())) + >>> FiniteSetIntersect(a, b) + set.intersect(a, b) + """ + ctx = _ctx_from_ast_arg_list([s1, s2]) + return FiniteSetRef(Z3_mk_finite_set_intersect(ctx.ref(), s1.as_ast(), s2.as_ast()), ctx) + + +def FiniteSetDifference(s1, s2): + """Create the set difference of two finite sets. + >>> a = Const('a', FiniteSetSort(IntSort())) + >>> b = Const('b', FiniteSetSort(IntSort())) + >>> FiniteSetDifference(a, b) + set.difference(a, b) + """ + ctx = _ctx_from_ast_arg_list([s1, s2]) + return FiniteSetRef(Z3_mk_finite_set_difference(ctx.ref(), s1.as_ast(), s2.as_ast()), ctx) + + +def FiniteSetMember(elem, set): + """Check if elem is a member of the finite set. + >>> a = Const('a', FiniteSetSort(IntSort())) + >>> FiniteSetMember(IntVal(1), a) + set.in(1, a) + """ + ctx = _ctx_from_ast_arg_list([elem, set]) + return BoolRef(Z3_mk_finite_set_member(ctx.ref(), elem.as_ast(), set.as_ast()), ctx) + +def In(elem, set): + return FiniteSetMember(elem, set) + +def FiniteSetSize(set): + """Get the size (cardinality) of a finite set. + >>> a = Const('a', FiniteSetSort(IntSort())) + >>> FiniteSetSize(a) + set.size(a) + """ + ctx = set.ctx + return ArithRef(Z3_mk_finite_set_size(ctx.ref(), set.as_ast()), ctx) + + +def FiniteSetSubset(s1, s2): + """Check if s1 is a subset of s2. + >>> a = Const('a', FiniteSetSort(IntSort())) + >>> b = Const('b', FiniteSetSort(IntSort())) + >>> FiniteSetSubset(a, b) + set.subset(a, b) + """ + ctx = _ctx_from_ast_arg_list([s1, s2]) + return BoolRef(Z3_mk_finite_set_subset(ctx.ref(), s1.as_ast(), s2.as_ast()), ctx) + + +def FiniteSetMap(f, set): + """Apply function f to all elements of the finite set. + >>> f = Array('f', IntSort(), IntSort()) + >>> a = Const('a', FiniteSetSort(IntSort())) + >>> FiniteSetMap(f, a) + set.map(f, a) + """ + if isinstance(f, FuncDeclRef): + f = AsArray(f) + ctx = _ctx_from_ast_arg_list([f, set]) + return FiniteSetRef(Z3_mk_finite_set_map(ctx.ref(), f.as_ast(), set.as_ast()), ctx) + + +def FiniteSetFilter(f, set): + """Filter a finite set using predicate f. + >>> f = Array('f', IntSort(), BoolSort()) + >>> a = Const('a', FiniteSetSort(IntSort())) + >>> FiniteSetFilter(f, a) + set.filter(f, a) + """ + if isinstance(f, FuncDeclRef): + f = AsArray(f) + ctx = _ctx_from_ast_arg_list([f, set]) + return FiniteSetRef(Z3_mk_finite_set_filter(ctx.ref(), f.as_ast(), set.as_ast()), ctx) + + +def FiniteSetRange(low, high): + """Create a finite set of integers in the range [low, high). + >>> FiniteSetRange(IntVal(0), IntVal(5)) + set.range(0, 5) + """ + ctx = _ctx_from_ast_arg_list([low, high]) + return FiniteSetRef(Z3_mk_finite_set_range(ctx.ref(), low.as_ast(), high.as_ast()), ctx) + + ######################################### # # Datatypes @@ -11653,6 +11913,14 @@ def Union(*args): sz = len(args) if z3_debug(): _z3_assert(sz > 0, "At least one argument expected.") + arg0 = args[0] + if is_finite_set(arg0): + for a in args[1:]: + if not is_finite_set(a): + raise Z3Exception("All arguments must be regular expressions or finite sets.") + arg0 = arg0 | a + return arg0 + if z3_debug(): _z3_assert(all([is_re(a) for a in args]), "All arguments must be regular expressions.") if sz == 1: return args[0] @@ -11671,6 +11939,14 @@ def Intersect(*args): sz = len(args) if z3_debug(): _z3_assert(sz > 0, "At least one argument expected.") + arg0 = args[0] + if is_finite_set(arg0): + for a in args[1:]: + if not is_finite_set(a): + raise Z3Exception("All arguments must be regular expressions or finite sets.") + arg0 = arg0 & a + return arg0 + if z3_debug(): _z3_assert(all([is_re(a) for a in args]), "All arguments must be regular expressions.") if sz == 1: return args[0] diff --git a/src/api/python/z3/z3printer.py b/src/api/python/z3/z3printer.py index 29d3e0db1..ddcb7b100 100644 --- a/src/api/python/z3/z3printer.py +++ b/src/api/python/z3/z3printer.py @@ -760,6 +760,8 @@ class Formatter: return seq1("Seq", (self.pp_sort(s.basis()), )) elif isinstance(s, z3.CharSortRef): return to_format("Char") + elif isinstance(s, z3.FiniteSetSortRef): + return seq1("FiniteSet", (self.pp_sort(s.element_sort()), )) else: return to_format(s.name()) diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 141c32e5a..4d2c72b86 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -980,6 +980,32 @@ typedef enum 3 = 011 = Z3_OP_FPA_RM_TOWARD_NEGATIVE, 4 = 100 = Z3_OP_FPA_RM_TOWARD_ZERO. + - Z3_OP_FINITE_SET_EMPTY: Empty finite set. + + - Z3_OP_FINITE_SET_SINGLETON: Finite set containing a single element. + + - Z3_OP_FINITE_SET_UNION: Union of two finite sets. + + - Z3_OP_FINITE_SET_INTERSECT: Intersection of two finite sets. + + - Z3_OP_FINITE_SET_DIFFERENCE: Difference of two finite sets. + + - Z3_OP_FINITE_SET_IN: Membership predicate for finite sets. + + - Z3_OP_FINITE_SET_SIZE: Cardinality of a finite set. + + - Z3_OP_FINITE_SET_SUBSET: Subset predicate for finite sets. + + - Z3_OP_FINITE_SET_MAP: Map operation on finite sets. + + - Z3_OP_FINITE_SET_FILTER: Filter operation on finite sets. + + - Z3_OP_FINITE_SET_RANGE: Range operation for finite sets of integers. + + - Z3_OP_FINITE_SET_EXT: Finite set extensionality. Returns a witness element that is in one set but not the other, demonstrating that two sets are different. + + - Z3_OP_FINITE_SET_MAP_INVERSE: Inverse image under a finite set map operation. Related to reasoning about the pre-image of elements under set mappings. + - Z3_OP_INTERNAL: internal (often interpreted) symbol, but no additional information is exposed. Tools may use the string representation of the function declaration to obtain more information. @@ -1313,6 +1339,21 @@ typedef enum { Z3_OP_FPA_BVWRAP, Z3_OP_FPA_BV2RM, + // Finite Sets + Z3_OP_FINITE_SET_EMPTY = 0xc000, + Z3_OP_FINITE_SET_SINGLETON, + Z3_OP_FINITE_SET_UNION, + Z3_OP_FINITE_SET_INTERSECT, + Z3_OP_FINITE_SET_DIFFERENCE, + Z3_OP_FINITE_SET_IN, + Z3_OP_FINITE_SET_SIZE, + Z3_OP_FINITE_SET_SUBSET, + Z3_OP_FINITE_SET_MAP, + Z3_OP_FINITE_SET_FILTER, + Z3_OP_FINITE_SET_RANGE, + Z3_OP_FINITE_SET_EXT, + Z3_OP_FINITE_SET_MAP_INVERSE, + Z3_OP_INTERNAL, Z3_OP_RECURSIVE, @@ -3413,6 +3454,107 @@ extern "C" { Z3_ast Z3_API Z3_mk_array_ext(Z3_context c, Z3_ast arg1, Z3_ast arg2); /**@}*/ + /** @name Finite Sets */ + /**@{*/ + /** + \brief Create a finite set sort. + + def_API('Z3_mk_finite_set_sort', SORT, (_in(CONTEXT), _in(SORT))) + */ + Z3_sort Z3_API Z3_mk_finite_set_sort(Z3_context c, Z3_sort elem_sort); + + /** + \brief Check if a sort is a finite set sort. + + def_API('Z3_is_finite_set_sort', BOOL, (_in(CONTEXT), _in(SORT))) + */ + bool Z3_API Z3_is_finite_set_sort(Z3_context c, Z3_sort s); + + /** + \brief Get the element sort of a finite set sort. + + def_API('Z3_get_finite_set_sort_basis', SORT, (_in(CONTEXT), _in(SORT))) + */ + Z3_sort Z3_API Z3_get_finite_set_sort_basis(Z3_context c, Z3_sort s); + + /** + \brief Create an empty finite set of the given sort. + + def_API('Z3_mk_finite_set_empty', AST, (_in(CONTEXT), _in(SORT))) + */ + Z3_ast Z3_API Z3_mk_finite_set_empty(Z3_context c, Z3_sort set_sort); + + /** + \brief Create a singleton finite set. + + def_API('Z3_mk_finite_set_singleton', AST, (_in(CONTEXT), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_finite_set_singleton(Z3_context c, Z3_ast elem); + + /** + \brief Create the union of two finite sets. + + def_API('Z3_mk_finite_set_union', AST, (_in(CONTEXT), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_finite_set_union(Z3_context c, Z3_ast s1, Z3_ast s2); + + /** + \brief Create the intersection of two finite sets. + + def_API('Z3_mk_finite_set_intersect', AST, (_in(CONTEXT), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_finite_set_intersect(Z3_context c, Z3_ast s1, Z3_ast s2); + + /** + \brief Create the set difference of two finite sets. + + def_API('Z3_mk_finite_set_difference', AST, (_in(CONTEXT), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_finite_set_difference(Z3_context c, Z3_ast s1, Z3_ast s2); + + /** + \brief Check if an element is a member of a finite set. + + def_API('Z3_mk_finite_set_member', AST, (_in(CONTEXT), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_finite_set_member(Z3_context c, Z3_ast elem, Z3_ast set); + + /** + \brief Get the size (cardinality) of a finite set. + + def_API('Z3_mk_finite_set_size', AST, (_in(CONTEXT), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_finite_set_size(Z3_context c, Z3_ast set); + + /** + \brief Check if one finite set is a subset of another. + + def_API('Z3_mk_finite_set_subset', AST, (_in(CONTEXT), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_finite_set_subset(Z3_context c, Z3_ast s1, Z3_ast s2); + + /** + \brief Apply a function to all elements of a finite set. + + def_API('Z3_mk_finite_set_map', AST, (_in(CONTEXT), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_finite_set_map(Z3_context c, Z3_ast f, Z3_ast set); + + /** + \brief Filter a finite set using a predicate. + + def_API('Z3_mk_finite_set_filter', AST, (_in(CONTEXT), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_finite_set_filter(Z3_context c, Z3_ast f, Z3_ast set); + + /** + \brief Create a finite set of integers in the range [low, high]. + + def_API('Z3_mk_finite_set_range', AST, (_in(CONTEXT), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_finite_set_range(Z3_context c, Z3_ast low, Z3_ast high); + /**@}*/ + /** @name Numerals */ /**@{*/ /** diff --git a/src/ast/CMakeLists.txt b/src/ast/CMakeLists.txt index 7a4a03a27..6a50c3b05 100644 --- a/src/ast/CMakeLists.txt +++ b/src/ast/CMakeLists.txt @@ -28,6 +28,7 @@ z3_add_component(ast expr_map.cpp expr_stat.cpp expr_substitution.cpp + finite_set_decl_plugin.cpp for_each_ast.cpp for_each_expr.cpp format.cpp diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 1583fb4ec..cf5e7af87 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -1,3 +1,4 @@ + /*++ Copyright (c) 2006 Microsoft Corporation @@ -1669,7 +1670,7 @@ bool ast_manager::slow_not_contains(ast const * n) { } #endif -#if 0 +#if 1 static unsigned s_count = 0; static void track_id(ast_manager& m, ast* n, unsigned id) { @@ -1708,9 +1709,8 @@ ast * ast_manager::register_node_core(ast * n) { n->m_id = is_decl(n) ? m_decl_id_gen.mk() : m_expr_id_gen.mk(); - // track_id(*this, n, 9213); - -// TRACE(ast, tout << (s_count++) << " Object " << n->m_id << " was created.\n";); + + // TRACE(ast, tout << (s_count++) << " Object " << n->m_id << " was created.\n";); TRACE(mk_var_bug, tout << "mk_ast: " << n->m_id << "\n";); // increment reference counters switch (n->get_kind()) { @@ -3343,15 +3343,14 @@ proof * ast_manager::mk_th_lemma( if (proofs_disabled()) return nullptr; - ptr_buffer args; vector parameters; parameters.push_back(parameter(get_family_name(tid))); - for (unsigned i = 0; i < num_params; ++i) { - parameters.push_back(params[i]); - } + for (unsigned i = 0; i < num_params; ++i) + parameters.push_back(params[i]); + ptr_buffer args; args.append(num_proofs, (expr**) proofs); args.push_back(fact); - return mk_app(basic_family_id, PR_TH_LEMMA, num_params+1, parameters.data(), args.size(), args.data()); + return mk_app(basic_family_id, PR_TH_LEMMA, parameters.size(), parameters.data(), args.size(), args.data()); } proof* ast_manager::mk_hyper_resolve(unsigned num_premises, proof* const* premises, expr* concl, diff --git a/src/ast/ast_smt2_pp.cpp b/src/ast/ast_smt2_pp.cpp index d601635be..70c1bb525 100644 --- a/src/ast/ast_smt2_pp.cpp +++ b/src/ast/ast_smt2_pp.cpp @@ -437,6 +437,11 @@ format_ns::format * smt2_pp_environment::pp_sort(sort * s) { fs.push_back(pp_sort(to_sort(s->get_parameter(0).get_ast()))); return mk_seq1(m, fs.begin(), fs.end(), f2f(), get_sutil().is_seq(s)?"Seq":"RegEx"); } + if ((get_fsutil().is_finite_set(s))) { + ptr_buffer fs; + fs.push_back(pp_sort(to_sort(s->get_parameter(0).get_ast()))); + return mk_seq1(m, fs.begin(), fs.end(), f2f(), "FiniteSet"); + } std::string name = ensure_quote(s->get_name()); if (get_dtutil().is_datatype(s)) { diff --git a/src/ast/ast_smt2_pp.h b/src/ast/ast_smt2_pp.h index 64ea2aec9..85e8e4a9a 100644 --- a/src/ast/ast_smt2_pp.h +++ b/src/ast/ast_smt2_pp.h @@ -30,6 +30,7 @@ Revision History: #include "ast/dl_decl_plugin.h" #include "ast/seq_decl_plugin.h" #include "ast/datatype_decl_plugin.h" +#include "ast/finite_set_decl_plugin.h" #include "ast/ast_smt_pp.h" #include "util/smt2_util.h" @@ -53,6 +54,7 @@ public: virtual array_util & get_arutil() = 0; virtual fpa_util & get_futil() = 0; virtual seq_util & get_sutil() = 0; + virtual finite_set_util &get_fsutil() = 0; virtual datalog::dl_decl_util& get_dlutil() = 0; virtual datatype_util& get_dtutil() = 0; virtual bool uses(symbol const & s) const = 0; @@ -80,9 +82,12 @@ class smt2_pp_environment_dbg : public smt2_pp_environment { fpa_util m_futil; seq_util m_sutil; datatype_util m_dtutil; + finite_set_util m_fsutil; datalog::dl_decl_util m_dlutil; public: - smt2_pp_environment_dbg(ast_manager & m):m_manager(m), m_autil(m), m_bvutil(m), m_arutil(m), m_futil(m), m_sutil(m), m_dtutil(m), m_dlutil(m) {} + smt2_pp_environment_dbg(ast_manager &m) + : m_manager(m), m_autil(m), m_bvutil(m), m_arutil(m), m_futil(m), m_sutil(m), m_dtutil(m), m_fsutil(m), + m_dlutil(m) {} ast_manager & get_manager() const override { return m_manager; } arith_util & get_autil() override { return m_autil; } bv_util & get_bvutil() override { return m_bvutil; } @@ -91,6 +96,7 @@ public: fpa_util & get_futil() override { return m_futil; } datalog::dl_decl_util& get_dlutil() override { return m_dlutil; } datatype_util& get_dtutil() override { return m_dtutil; } + finite_set_util &get_fsutil() override { return m_fsutil; } bool uses(symbol const & s) const override { return false; } }; diff --git a/src/ast/converters/expr_inverter.cpp b/src/ast/converters/expr_inverter.cpp index bc4a279c7..0e756ebe6 100644 --- a/src/ast/converters/expr_inverter.cpp +++ b/src/ast/converters/expr_inverter.cpp @@ -823,6 +823,47 @@ public: }; #endif +class finite_set_inverter : public iexpr_inverter { + finite_set_util fs; +public: + finite_set_inverter(ast_manager& m): iexpr_inverter(m), fs(m) {} + + family_id get_fid() const override { return fs.get_family_id(); } + + bool operator()(func_decl* f, unsigned num, expr* const* args, expr_ref& r) override { + switch (f->get_decl_kind()) { + case OP_FINITE_SET_UNION: + // x union y -> x + // y := x + if (num == 2 && uncnstr(args[0]) && uncnstr(args[1])) { + r = args[0]; + if (m_mc) { + add_def(args[1], r); + } + return true; + } + return false; + case OP_FINITE_SET_INTERSECT: + // x intersect y -> x + // y := x + if (num == 2 && uncnstr(args[0]) && uncnstr(args[1])) { + r = args[0]; + if (m_mc) { + add_def(args[1], r); + } + return true; + } + return false; + default: + break; + } + return false; + } + + bool mk_diff(expr* t, expr_ref& r) override { + return false; + } +}; class seq_expr_inverter : public iexpr_inverter { seq_util seq; @@ -972,6 +1013,7 @@ expr_inverter::expr_inverter(ast_manager& m): iexpr_inverter(m) { add(alloc(basic_expr_inverter, m, *this)); add(alloc(seq_expr_inverter, m)); //add(alloc(pb_expr_inverter, m)); + add(alloc(finite_set_inverter, m)); } diff --git a/src/ast/finite_set_decl_plugin.cpp b/src/ast/finite_set_decl_plugin.cpp new file mode 100644 index 000000000..49ee4e2bf --- /dev/null +++ b/src/ast/finite_set_decl_plugin.cpp @@ -0,0 +1,347 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + finite_set_decl_plugin.cpp + +Abstract: + + Declaration plugin for finite sets + +Author: + + GitHub Copilot Agent 2025 + +Revision History: + +--*/ +#include +#include "ast/finite_set_decl_plugin.h" +#include "ast/arith_decl_plugin.h" +#include "ast/array_decl_plugin.h" +#include "ast/polymorphism_util.h" +#include "ast/ast_pp.h" +#include "util/warning.h" + +finite_set_decl_plugin::finite_set_decl_plugin(): + m_init(false) { + m_names.resize(LAST_FINITE_SET_OP, nullptr); + m_names[OP_FINITE_SET_EMPTY] = "set.empty"; + m_names[OP_FINITE_SET_SINGLETON] = "set.singleton"; + m_names[OP_FINITE_SET_UNION] = "set.union"; + m_names[OP_FINITE_SET_INTERSECT] = "set.intersect"; + m_names[OP_FINITE_SET_DIFFERENCE] = "set.difference"; + m_names[OP_FINITE_SET_IN] = "set.in"; + m_names[OP_FINITE_SET_SIZE] = "set.size"; + m_names[OP_FINITE_SET_SUBSET] = "set.subset"; + m_names[OP_FINITE_SET_MAP] = "set.map"; + m_names[OP_FINITE_SET_FILTER] = "set.filter"; + m_names[OP_FINITE_SET_RANGE] = "set.range"; + m_names[OP_FINITE_SET_EXT] = "set.diff"; + m_names[OP_FINITE_SET_MAP_INVERSE] = "set.map.inverse"; + m_names[OP_FINITE_SET_UNIQUE_SET] = "set.unique"; +} + +finite_set_decl_plugin::~finite_set_decl_plugin() { + for (polymorphism::psig* s : m_sigs) + dealloc(s); +} + +void finite_set_decl_plugin::init() { + if (m_init) return; + ast_manager& m = *m_manager; + array_util autil(m); + m_init = true; + + sort* A = m.mk_type_var(symbol("A")); + sort* B = m.mk_type_var(symbol("B")); + parameter paramA(A); + parameter paramB(B); + sort* setA = m.mk_sort(m_family_id, FINITE_SET_SORT, 1, ¶mA); + sort* setB = m.mk_sort(m_family_id, FINITE_SET_SORT, 1, ¶mB); + sort* boolT = m.mk_bool_sort(); + sort* intT = arith_util(m).mk_int(); + parameter paramInt(intT); + sort* setInt = m.mk_sort(m_family_id, FINITE_SET_SORT, 1, ¶mInt); + sort* arrAB = autil.mk_array_sort(A, B); + sort* arrABool = autil.mk_array_sort(A, boolT); + + sort* setAsetA[2] = { setA, setA }; + sort* AsetA[2] = { A, setA }; + sort* arrABsetA[2] = { arrAB, setA }; + sort* arrABoolsetA[2] = { arrABool, setA }; + sort* intintT[2] = { intT, intT }; + sort *arrABBsetA[3] = {arrAB, B, setA}; + + m_sigs.resize(LAST_FINITE_SET_OP); + m_sigs[OP_FINITE_SET_EMPTY] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_EMPTY], 1, 0, nullptr, setA); + m_sigs[OP_FINITE_SET_SINGLETON] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_SINGLETON], 1, 1, &A, setA); + m_sigs[OP_FINITE_SET_UNION] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_UNION], 1, 2, setAsetA, setA); + m_sigs[OP_FINITE_SET_INTERSECT] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_INTERSECT], 1, 2, setAsetA, setA); + m_sigs[OP_FINITE_SET_DIFFERENCE] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_DIFFERENCE], 1, 2, setAsetA, setA); + m_sigs[OP_FINITE_SET_IN] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_IN], 1, 2, AsetA, boolT); + m_sigs[OP_FINITE_SET_SIZE] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_SIZE], 1, 1, &setA, intT); + m_sigs[OP_FINITE_SET_SUBSET] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_SUBSET], 1, 2, setAsetA, boolT); + m_sigs[OP_FINITE_SET_MAP] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_MAP], 2, 2, arrABsetA, setB); + m_sigs[OP_FINITE_SET_FILTER] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_FILTER], 1, 2, arrABoolsetA, setA); + m_sigs[OP_FINITE_SET_RANGE] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_RANGE], 0, 2, intintT, setInt); + m_sigs[OP_FINITE_SET_EXT] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_EXT], 1, 2, setAsetA, A); + m_sigs[OP_FINITE_SET_MAP_INVERSE] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_MAP_INVERSE], 2, 3, arrABBsetA, A); + m_sigs[OP_FINITE_SET_UNIQUE_SET] = alloc(polymorphism::psig, m, m_names[OP_FINITE_SET_UNIQUE_SET], 1, 2, intintT, setA); +} + +sort * finite_set_decl_plugin::mk_sort(decl_kind k, unsigned num_parameters, parameter const * parameters) { + if (k == FINITE_SET_SORT) { + if (num_parameters != 1) { + m_manager->raise_exception("FiniteSet sort expects exactly one parameter (element sort)"); + return nullptr; + } + if (!parameters[0].is_ast() || !is_sort(parameters[0].get_ast())) { + m_manager->raise_exception("FiniteSet sort parameter must be a sort"); + return nullptr; + } + sort * element_sort = to_sort(parameters[0].get_ast()); + sort_size sz; + + // Compute the size of the finite_set sort based on the element sort + sort_size const& elem_sz = element_sort->get_num_elements(); + if (elem_sz.is_finite() && !elem_sz.is_very_big()) { + uint64_t elem_size = elem_sz.size(); + // If elem_size > 30, the powerset would be > 2^30, so mark as very_big + if (elem_size > 30) { + sz = sort_size::mk_very_big(); + } + else { + // Compute 2^elem_size + sz = sort_size(rational::power_of_two(static_cast(elem_size))); + } + } + else { + // If element sort is infinite or very_big, the finite_set has the same size + sz = elem_sz; + } + + sort_info info(m_family_id, FINITE_SET_SORT, sz, num_parameters, parameters); + return m_manager->mk_sort(symbol("FiniteSet"), info); + } + m_manager->raise_exception("unknown finite set sort"); + return nullptr; +} + +bool finite_set_decl_plugin::is_finite_set(sort* s) const { + return s->get_family_id() == m_family_id && s->get_decl_kind() == FINITE_SET_SORT; +} + +sort * finite_set_decl_plugin::get_element_sort(sort* finite_set_sort) const { + if (!is_finite_set(finite_set_sort)) { + return nullptr; + } + if (finite_set_sort->get_num_parameters() != 1) { + return nullptr; + } + parameter const* params = finite_set_sort->get_parameters(); + if (!params[0].is_ast() || !is_sort(params[0].get_ast())) { + return nullptr; + } + return to_sort(params[0].get_ast()); +} + +func_decl * finite_set_decl_plugin::mk_empty(sort* finite_set_sort) { + parameter param(finite_set_sort); + if (!is_finite_set(finite_set_sort)) + m_manager->raise_exception("set.empty range must be a finite set sort"); + sort * const * no_domain = nullptr; + return m_manager->mk_func_decl(m_sigs[OP_FINITE_SET_EMPTY]->m_name, 0u, no_domain, finite_set_sort, + func_decl_info(m_family_id, OP_FINITE_SET_EMPTY, 1, ¶m)); +} + +func_decl * finite_set_decl_plugin::mk_finite_set_op(decl_kind k, unsigned arity, sort * const * domain, sort* range) { + ast_manager& m = *m_manager; + polymorphism::util poly_util(m); + sort_ref rng(m); + poly_util.match(*m_sigs[k], arity, domain, range, rng); + unsigned np = k == OP_FINITE_SET_UNIQUE_SET ? 1 : 0; + parameter p(rng.get()); + func_decl_info info(m_family_id, k, np, &p); + if (k == OP_FINITE_SET_UNION || k == OP_FINITE_SET_INTERSECT) { + info.set_commutative(true); + info.set_associative(true); + } + return m.mk_func_decl(m_sigs[k]->m_name, arity, domain, rng, info); +} + +func_decl * finite_set_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, + parameter const * parameters, + unsigned arity, sort * const * domain, + sort * range) { + init(); + + switch (k) { + case OP_FINITE_SET_EMPTY: + if (!range) { + if ((num_parameters != 1 || !parameters[0].is_ast() || !is_sort(parameters[0].get_ast()))) { + m_manager->raise_exception("set.empty requires one sort parameter"); + return nullptr; + } + range = to_sort(parameters[0].get_ast()); + } + return mk_empty(range); + case OP_FINITE_SET_UNIQUE_SET: + if (!range) { + if ((num_parameters != 1 || !parameters[0].is_ast() || !is_sort(parameters[0].get_ast()))) { + m_manager->raise_exception("set.unique requires one sort parameter"); + return nullptr; + } + range = to_sort(parameters[0].get_ast()); + } + return mk_finite_set_op(k, arity, domain, range); + case OP_FINITE_SET_UNION: + case OP_FINITE_SET_INTERSECT: + return mk_finite_set_op(k, 2, domain, range); + case OP_FINITE_SET_SINGLETON: + case OP_FINITE_SET_DIFFERENCE: + case OP_FINITE_SET_IN: + case OP_FINITE_SET_SIZE: + case OP_FINITE_SET_SUBSET: + case OP_FINITE_SET_MAP: + case OP_FINITE_SET_MAP_INVERSE: + case OP_FINITE_SET_FILTER: + case OP_FINITE_SET_RANGE: + case OP_FINITE_SET_EXT: + return mk_finite_set_op(k, arity, domain, range); + default: + return nullptr; + } +} + +void finite_set_decl_plugin::get_op_names(svector& op_names, symbol const & logic) { + for (unsigned i = 0; i < m_names.size(); ++i) + if (m_names[i] && i != OP_FINITE_SET_UNIQUE_SET) + op_names.push_back(builtin_name(std::string(m_names[i]), i)); +} + +void finite_set_decl_plugin::get_sort_names(svector& sort_names, symbol const & logic) { + sort_names.push_back(builtin_name("FiniteSet", FINITE_SET_SORT)); +} + +expr * finite_set_decl_plugin::get_some_value(sort * s) { + if (is_finite_set(s)) { + // Return empty set for the given sort + parameter param(s); + return m_manager->mk_app(m_family_id, OP_FINITE_SET_EMPTY, 1, ¶m, 0, nullptr); + } + return nullptr; +} + +bool finite_set_decl_plugin::is_fully_interp(sort * s) const { + SASSERT(is_finite_set(s)); + sort* element_sort = get_element_sort(s); + return element_sort && m_manager->is_fully_interp(element_sort); +} + +bool finite_set_decl_plugin::is_value(app * e) const { + // Check if e is a union of either empty sets or singleton sets, + // where the singleton member is a value. + // Use ptr_buffer and expr_fast_mark1 to avoid exponential overhead. + + ptr_buffer todo; + expr_fast_mark1 visited; + + todo.push_back(e); + + while (!todo.empty()) { + expr* current = todo.back(); + todo.pop_back(); + + // Skip if already visited + if (visited.is_marked(current)) + continue; + visited.mark(current); + + // Check if current is an app + if (!is_app(current)) + return false; + + app* a = to_app(current); + + // Check if it's an empty set + if (is_app_of(a, m_family_id, OP_FINITE_SET_EMPTY)) + continue; + + // Check if it's a singleton with a value element + if (is_app_of(a, m_family_id, OP_FINITE_SET_SINGLETON)) { + if (a->get_num_args() != 1) + return false; + expr* elem = a->get_arg(0); + if (!m_manager->is_value(elem)) + return false; + continue; + } + + bool is_setop = + is_app_of(a, m_family_id, OP_FINITE_SET_UNION) + || is_app_of(a, m_family_id, OP_FINITE_SET_INTERSECT) + || is_app_of(a, m_family_id, OP_FINITE_SET_DIFFERENCE); + // Check if it's a union + if (is_setop) { + // Add arguments to todo list + for (auto arg : *a) + todo.push_back(arg); + continue; + } + + if (is_app_of(a, m_family_id, OP_FINITE_SET_RANGE)) { + for (auto arg : *a) + if (!m_manager->is_value(arg)) + return false; + continue; + } + + // can add also ranges where lo and hi are values. + + // If it's none of the above, it's not a value + return false; + } + + return true; +} + +bool finite_set_decl_plugin::is_unique_value(app* e) const { + // Empty set is a value + // A singleton of a unique value is tagged as unique + // ranges are not considered unique even if the bounds are values + return is_app_of(e, m_family_id, OP_FINITE_SET_EMPTY) || + (is_app_of(e, m_family_id, OP_FINITE_SET_SINGLETON) && m_manager->is_unique_value(to_app(e->get_arg(0)))); +} + +bool finite_set_decl_plugin::are_distinct(app* e1, app* e2) const { + if (is_unique_value(e1) && is_unique_value(e2)) + return e1 != e2; + finite_set_recognizers r(get_family_id()); + if (r.is_empty(e1) && r.is_singleton(e2)) + return true; + if (r.is_singleton(e1) && r.is_empty(e2)) + return true; + expr *x = nullptr, *y = nullptr; + if(r.is_singleton(e1, x) && r.is_singleton(e2, y)) + return m_manager->are_distinct(x, y); + + // TODO: could be extended to cases where we can prove the sets are different by containing one element + // that the other doesn't contain. Such as (union (singleton a) (singleton b)) and (singleton c) where c is different from a, b. + return false; +} + +func_decl *finite_set_util::mk_range_decl() { + arith_util a(m_manager); + sort *i = a.mk_int(); + sort *domain[2] = {i, i}; + return m_manager.mk_func_decl(m_fid, OP_FINITE_SET_RANGE, 0, nullptr, 2, domain, nullptr); +} + +app* finite_set_util::mk_unique_set(expr* index, expr* cardinality, sort* s) { + parameter params[1] = { parameter(s) }; + expr *args[2] = {index, cardinality}; + return m_manager.mk_app(m_fid, OP_FINITE_SET_UNIQUE_SET, 1, params, 2, args); +} + diff --git a/src/ast/finite_set_decl_plugin.h b/src/ast/finite_set_decl_plugin.h new file mode 100644 index 000000000..ec927dfa9 --- /dev/null +++ b/src/ast/finite_set_decl_plugin.h @@ -0,0 +1,222 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + finite_set_decl_plugin.h + +Abstract: + Declaration plugin for finite sets signatures + +Sort: + FiniteSet(S) + +Operators: + set.empty : (FiniteSet S) + set.singleton : S -> (FiniteSet S) + set.union : (FiniteSet S) (FiniteSet S) -> (FiniteSet S) + set.intersect : (FiniteSet S) (FiniteSet S) -> (FiniteSet S) + set.difference : (FiniteSet S) (FiniteSet S) -> (FiniteSet S) + set.in : S (FiniteSet S) -> Bool + set.size : (FiniteSet S) -> Int + set.subset : (FiniteSet S) (FiniteSet S) -> Bool + set.map : (S -> T) (FiniteSet S) -> (FiniteSet T) + set.filter : (S -> Bool) (FiniteSet S) -> (FiniteSet S) + set.range : Int Int -> (FiniteSet Int) + set.diff : (FiniteSet S) (FiniteSet S) -> S + +--*/ +#pragma once + +#include "ast/ast.h" +#include "ast/polymorphism_util.h" + +enum finite_set_sort_kind { + FINITE_SET_SORT +}; + +enum finite_set_op_kind { + OP_FINITE_SET_EMPTY, + OP_FINITE_SET_SINGLETON, + OP_FINITE_SET_UNION, + OP_FINITE_SET_INTERSECT, + OP_FINITE_SET_DIFFERENCE, + OP_FINITE_SET_IN, + OP_FINITE_SET_SIZE, + OP_FINITE_SET_SUBSET, + OP_FINITE_SET_MAP, + OP_FINITE_SET_FILTER, + OP_FINITE_SET_RANGE, + OP_FINITE_SET_EXT, + OP_FINITE_SET_MAP_INVERSE, + OP_FINITE_SET_UNIQUE_SET, + LAST_FINITE_SET_OP +}; + +class finite_set_decl_plugin : public decl_plugin { + ptr_vector m_sigs; + svector m_names; + bool m_init = false; + + void init(); + func_decl * mk_empty(sort* set_sort); + func_decl * mk_finite_set_op(decl_kind k, unsigned arity, sort * const * domain, sort* range); + sort * get_element_sort(sort* finite_set_sort) const; + bool is_finite_set(sort* s) const; + +public: + finite_set_decl_plugin(); + ~finite_set_decl_plugin() override; + + decl_plugin * mk_fresh() override { + return alloc(finite_set_decl_plugin); + } + + void finalize() override { + for (polymorphism::psig* s : m_sigs) + dealloc(s); + m_sigs.reset(); + } + + // + // Contract for sort: + // parameters[0] - element sort + // + sort * mk_sort(decl_kind k, unsigned num_parameters, parameter const * parameters) override; + + // + // Contract for func_decl: + // For OP_FINITE_SET_MAP and OP_FINITE_SET_FILTER: + // parameters[0] - function declaration + // For others: + // no parameters + func_decl * mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range) override; + + void get_op_names(svector & op_names, symbol const & logic) override; + + void get_sort_names(svector & sort_names, symbol const & logic) override; + + expr * get_some_value(sort * s) override; + + bool is_fully_interp(sort * s) const override; + + bool is_value(app * e) const override; + + bool is_unique_value(app* e) const override; + + bool are_distinct(app *e1, app *e2) const override; + +}; + +class finite_set_recognizers { +protected: + family_id m_fid; +public: + finite_set_recognizers(family_id fid):m_fid(fid) {} + family_id get_family_id() const { return m_fid; } + bool is_finite_set(sort* s) const { return is_sort_of(s, m_fid, FINITE_SET_SORT); } + bool is_finite_set(sort* s, sort*& elem_sort) const { + if (is_finite_set(s)) { + elem_sort = to_sort(s->get_parameter(0).get_ast()); + return true; + } + return false; + } + bool is_finite_set(expr const* n) const { return is_finite_set(n->get_sort()); } + bool is_empty(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_EMPTY); } + bool is_singleton(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_SINGLETON); } + bool is_union(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_UNION); } + bool is_intersect(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_INTERSECT); } + bool is_difference(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_DIFFERENCE); } + bool is_in(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_IN); } + bool is_size(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_SIZE); } + bool is_subset(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_SUBSET); } + bool is_map(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_MAP); } + bool is_filter(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_FILTER); } + bool is_range(expr const* n) const { return is_app_of(n, m_fid, OP_FINITE_SET_RANGE); } + bool is_unique_set(expr const *n) const { return is_app_of(n, m_fid, OP_FINITE_SET_UNIQUE_SET); } + + MATCH_UNARY(is_singleton); + MATCH_UNARY(is_size); + MATCH_BINARY(is_union); + MATCH_BINARY(is_intersect); + MATCH_BINARY(is_difference); + MATCH_BINARY(is_in); + MATCH_BINARY(is_subset); + MATCH_BINARY(is_map); + MATCH_BINARY(is_filter); + MATCH_BINARY(is_range); + MATCH_BINARY(is_unique_set); +}; + +class finite_set_util : public finite_set_recognizers { + ast_manager& m_manager; +public: + finite_set_util(ast_manager& m): + finite_set_recognizers(m.mk_family_id("finite_set")), m_manager(m) {} + + ast_manager& get_manager() const { return m_manager; } + + sort *mk_finite_set_sort(sort *elem_sort) { + parameter param(elem_sort); + return m_manager.mk_sort(m_fid, FINITE_SET_SORT, 1, ¶m); + } + + app * mk_empty(sort* set_sort) { + parameter param(set_sort); + return m_manager.mk_app(m_fid, OP_FINITE_SET_EMPTY, 1, ¶m, 0, nullptr); + } + + app * mk_singleton(expr* elem) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_SINGLETON, elem); + } + + app * mk_union(expr* s1, expr* s2) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_UNION, s1, s2); + } + + app * mk_intersect(expr* s1, expr* s2) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_INTERSECT, s1, s2); + } + + app * mk_difference(expr* s1, expr* s2) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_DIFFERENCE, s1, s2); + } + + app *mk_ext(expr *s1, expr *s2) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_EXT, s1, s2); + } + + app * mk_in(expr* elem, expr* set) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_IN, elem, set); + } + + app * mk_size(expr* set) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_SIZE, set); + } + + app * mk_subset(expr* s1, expr* s2) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_SUBSET, s1, s2); + } + + app * mk_map(expr* arr, expr* set) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_MAP, arr, set); + } + + app *mk_map_inverse(expr *f, expr *x, expr *b) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_MAP_INVERSE, f, x, b); + } + + app * mk_filter(expr* arr, expr* set) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_FILTER, arr, set); + } + + func_decl *mk_range_decl(); + + app *mk_range(expr *low, expr *high) { + return m_manager.mk_app(m_fid, OP_FINITE_SET_RANGE, low, high); + } + + app *mk_unique_set(expr *s1, expr *s2, sort *s); +}; diff --git a/src/ast/polymorphism_util.cpp b/src/ast/polymorphism_util.cpp index 3431dd034..62dd4e949 100644 --- a/src/ast/polymorphism_util.cpp +++ b/src/ast/polymorphism_util.cpp @@ -350,4 +350,31 @@ namespace polymorphism { proc proc(m, tvs); for_each_ast(proc, e, true); } + + void util::match(psig& sig, unsigned dsz, sort* const* dom, sort* range, sort_ref& range_out) { + if (dsz != sig.m_dom.size()) { + std::ostringstream strm; + strm << "Incorrect number of arguments to '" << sig.m_name << "' "; + strm << "expected " << sig.m_dom.size() << " given " << dsz; + m.raise_exception(strm.str()); + } + + substitution sub(m); + bool is_match = true; + for (unsigned i = 0; is_match && i < dsz; ++i) { + SASSERT(dom[i]); + is_match = sub.match(sig.m_dom.get(i), dom[i]); + } + if (range && is_match) { + is_match = sub.match(sig.m_range, range); + } + if (!is_match) { + std::ostringstream strm; + strm << "Sort mismatch for function '" << sig.m_name << "'"; + m.raise_exception(strm.str()); + } + + // Apply substitution to get the range + range_out = sub(sig.m_range); + } } diff --git a/src/ast/polymorphism_util.h b/src/ast/polymorphism_util.h index 3023d0338..d7591b7fb 100644 --- a/src/ast/polymorphism_util.h +++ b/src/ast/polymorphism_util.h @@ -77,6 +77,24 @@ namespace polymorphism { }; typedef hashtable substitutions; + + /** + * Polymorphic signature for operators + */ + struct psig { + symbol m_name; + unsigned m_num_params; + sort_ref_vector m_dom; + sort_ref m_range; + psig(ast_manager& m, char const* name, unsigned n, unsigned dsz, sort* const* dom, sort* rng): + m_name(name), + m_num_params(n), + m_dom(m), + m_range(rng, m) + { + m_dom.append(dsz, dom); + } + }; class util { ast_manager& m; @@ -99,6 +117,13 @@ namespace polymorphism { substitution& sub); bool match(substitution& sub, sort* s1, sort* s_ground); + + /** + * Match a polymorphic signature against concrete argument sorts. + * Raises exception if arity mismatch or type mismatch. + * Returns the instantiated range sort via range_out. + */ + void match(psig& sig, unsigned dsz, sort* const* dom, sort* range, sort_ref& range_out); // collect instantiations of polymorphic functions void collect_poly_instances(expr* e, ptr_vector& instances); diff --git a/src/ast/reg_decl_plugins.cpp b/src/ast/reg_decl_plugins.cpp index 8cb0bbe5a..18b3e79cb 100644 --- a/src/ast/reg_decl_plugins.cpp +++ b/src/ast/reg_decl_plugins.cpp @@ -29,6 +29,7 @@ Revision History: #include "ast/pb_decl_plugin.h" #include "ast/fpa_decl_plugin.h" #include "ast/special_relations_decl_plugin.h" +#include "ast/finite_set_decl_plugin.h" void reg_decl_plugins(ast_manager & m) { if (!m.get_plugin(m.mk_family_id(symbol("arith")))) { @@ -64,4 +65,7 @@ void reg_decl_plugins(ast_manager & m) { if (!m.get_plugin(m.mk_family_id(symbol("specrels")))) { m.register_plugin(symbol("specrels"), alloc(special_relations_decl_plugin)); } + if (!m.get_plugin(m.mk_family_id(symbol("finite_set")))) { + m.register_plugin(symbol("finite_set"), alloc(finite_set_decl_plugin)); + } } diff --git a/src/ast/rewriter/CMakeLists.txt b/src/ast/rewriter/CMakeLists.txt index 8e6306c7e..9d529f9b5 100644 --- a/src/ast/rewriter/CMakeLists.txt +++ b/src/ast/rewriter/CMakeLists.txt @@ -22,6 +22,8 @@ z3_add_component(rewriter expr_safe_replace.cpp factor_equivs.cpp factor_rewriter.cpp + finite_set_axioms.cpp + finite_set_rewriter.cpp fpa_rewriter.cpp func_decl_replace.cpp inj_axiom.cpp diff --git a/src/ast/rewriter/finite_set_axioms.cpp b/src/ast/rewriter/finite_set_axioms.cpp new file mode 100644 index 000000000..025040be0 --- /dev/null +++ b/src/ast/rewriter/finite_set_axioms.cpp @@ -0,0 +1,407 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + finite_set_axioms.cpp + +Abstract: + + This module implements axiom schemas that are invoked by saturating constraints + with respect to the semantics of set operations. + +Author: + + nbjorner October 2025 + +--*/ + +#include "ast/ast.h" +#include "ast/ast_pp.h" +#include "ast/ast_util.h" +#include "ast/finite_set_decl_plugin.h" +#include "ast/arith_decl_plugin.h" +#include "ast/array_decl_plugin.h" +#include "ast/rewriter/finite_set_axioms.h" + + +std::ostream& operator<<(std::ostream& out, theory_axiom const& ax) { + auto &m = ax.clause.get_manager(); + for (auto e : ax.clause) + out << mk_pp(e, m) << " "; + return out; +} + +void finite_set_axioms::add_unit(char const *name, expr *p1, expr *unit) { + expr_ref _f1(unit, m); + if (is_true(unit)) + return; + theory_axiom *ax = alloc(theory_axiom, m, name, p1); + ax->clause.push_back(unit); + m_add_clause(ax); +} + + +bool finite_set_axioms::is_true(expr *f) { + if (m.is_true(f)) + return true; + if (m.is_not(f, f) && m.is_false(f)) + return true; + return false; +} + + +bool finite_set_axioms::is_false(expr* f) { + if (m.is_false(f)) + return true; + if (m.is_not(f, f) && m.is_true(f)) + return true; + return false; +} + +void finite_set_axioms::add_binary(char const *name, expr *p1, expr *p2, expr *f1, expr *f2) { + expr_ref _f1(f1, m), _f2(f2, m); + if (is_true(f1) || is_true(f2)) + return; + theory_axiom *ax = alloc(theory_axiom, m, name, p1, p2); + if (!is_false(f1)) + ax->clause.push_back(f1); + if (!is_false(f2)) + ax->clause.push_back(f2); + m_add_clause(ax); +} + +void finite_set_axioms::add_ternary(char const *name, expr *p1, expr *p2, expr *f1, expr *f2, expr *f3) { + expr_ref _f1(f1, m), _f2(f2, m), _f3(f3, m); + if (is_true(f1) || is_true(f2) || is_true(f3)) + return; + theory_axiom *ax = alloc(theory_axiom, m, name, p1, p2); + if (!is_false(f1)) + ax->clause.push_back(f1); + if (!is_false(f2)) + ax->clause.push_back(f2); + if (!is_false(f3)) + ax->clause.push_back(f3); + m_add_clause(ax); +} + +// a ~ set.empty => not (x in a) +// x is an element, generate axiom that x is not in any empty set of x's type +void finite_set_axioms::in_empty_axiom(expr *x) { + // Generate: not (x in empty_set) + // where empty_set is the empty set of x's type + sort* elem_sort = x->get_sort(); + sort *set_sort = u.mk_finite_set_sort(elem_sort); + expr_ref empty_set(u.mk_empty(set_sort), m); + expr_ref x_in_empty(u.mk_in(x, empty_set), m); + add_unit("in-empty", x, m.mk_not(x_in_empty)); +} + +// a := set.union(b, c) +// (x in a) <=> (x in b) or (x in c) +void finite_set_axioms::in_union_axiom(expr *x, expr *a) { + expr* b = nullptr, *c = nullptr; + if (!u.is_union(a, b, c)) + return; + + expr_ref x_in_a(u.mk_in(x, a), m); + expr_ref x_in_b(u.mk_in(x, b), m); + expr_ref x_in_c(u.mk_in(x, c), m); + + // (x in a) => (x in b) or (x in c) + theory_axiom *ax1 = alloc(theory_axiom, m, "in-union", x, a); + ax1->clause.push_back(m.mk_not(x_in_a)); + ax1->clause.push_back(x_in_b); + ax1->clause.push_back(x_in_c); + m_add_clause(ax1); + + // (x in b) => (x in a) + add_binary("in-union", x, a, m.mk_not(x_in_b), x_in_a); + + // (x in c) => (x in a) + add_binary("in-union", x, a, m.mk_not(x_in_c), x_in_a); +} + +// a := set.intersect(b, c) +// (x in a) <=> (x in b) and (x in c) +void finite_set_axioms::in_intersect_axiom(expr *x, expr *a) { + expr* b = nullptr, *c = nullptr; + if (!u.is_intersect(a, b, c)) + return; + + expr_ref x_in_a(u.mk_in(x, a), m); + expr_ref x_in_b(u.mk_in(x, b), m); + expr_ref x_in_c(u.mk_in(x, c), m); + expr_ref nx_in_a(m.mk_not(x_in_a), m); + expr_ref nx_in_b(m.mk_not(x_in_b), m); + expr_ref nx_in_c(m.mk_not(x_in_c), m); + + // (x in a) => (x in b) + add_binary("in-intersect", x, a, nx_in_a, x_in_b); + + // (x in a) => (x in c) + add_binary("in-intersect", x, a, nx_in_a, x_in_c); + + // (x in b) and (x in c) => (x in a) + add_ternary("in-intersect", x, a, nx_in_b, nx_in_c, x_in_a); +} + +// a := set.difference(b, c) +// (x in a) <=> (x in b) and not (x in c) +void finite_set_axioms::in_difference_axiom(expr *x, expr *a) { + expr* b = nullptr, *c = nullptr; + if (!u.is_difference(a, b, c)) + return; + + expr_ref x_in_a(u.mk_in(x, a), m); + expr_ref x_in_b(u.mk_in(x, b), m); + expr_ref x_in_c(u.mk_in(x, c), m); + expr_ref nx_in_a(m.mk_not(x_in_a), m); + expr_ref nx_in_b(m.mk_not(x_in_b), m); + expr_ref nx_in_c(m.mk_not(x_in_c), m); + + // (x in a) => (x in b) + add_binary("in-difference", x, a, nx_in_a, x_in_b); + + // (x in a) => not (x in c) + add_binary("in-difference", x, a, nx_in_a, nx_in_c); + + // (x in b) and not (x in c) => (x in a) + add_ternary("in-difference", x, a, nx_in_b, x_in_c, x_in_a); +} + +// a := set.singleton(b) +// (x in a) <=> (x == b) +void finite_set_axioms::in_singleton_axiom(expr *x, expr *a) { + expr* b = nullptr; + if (!u.is_singleton(a, b)) + return; + + expr_ref x_in_a(u.mk_in(x, a), m); + + if (x == b) { + // If x and b are syntactically identical, then (x in a) is always true + theory_axiom* ax = alloc(theory_axiom, m, "in-singleton", x, a); + ax->clause.push_back(x_in_a); + m_add_clause(ax); + return; + } + + expr_ref x_eq_b(m.mk_eq(x, b), m); + + // (x in a) => (x == b) + add_binary("in-singleton", x, a, m.mk_not(x_in_a), x_eq_b); + + // (x == b) => (x in a) + add_binary("in-singleton", x, a, m.mk_not(x_eq_b), x_in_a); +} + +void finite_set_axioms::in_singleton_axiom(expr* a) { + expr *b = nullptr; + if (!u.is_singleton(a, b)) + return; + add_unit("in-singleton", a, u.mk_in(b, a)); +} + +// a := set.range(lo, hi) +// (x in a) <=> (lo <= x <= hi) +// we use the rewriter to simplify inequalitiess because the arithmetic solver +// makes some assumptions that inequalities are in normal form. +// this complicates proof checking. +// Options are to include a proof of the rewrite within the justification +// fix the arithmetic solver to use the inequalities without rewriting (it really should) +// the same issue applies to everywhere we apply rewriting when adding theory axioms. + +void finite_set_axioms::in_range_axiom(expr *x, expr *a) { + expr* lo = nullptr, *hi = nullptr; + if (!u.is_range(a, lo, hi)) + return; + + arith_util arith(m); + expr_ref x_in_a(u.mk_in(x, a), m); + expr_ref lo_le_x(arith.mk_le(arith.mk_sub(lo, x), arith.mk_int(0)), m); + expr_ref x_le_hi(arith.mk_le(arith.mk_sub(x, hi), arith.mk_int(0)), m); + m_rewriter(lo_le_x); + m_rewriter(x_le_hi); + expr_ref nx_le_hi(m.mk_not(x_le_hi), m); + expr_ref nlo_le_x(m.mk_not(lo_le_x), m); + + // (x in a) => (lo <= x) + add_binary("in-range", x, a, m.mk_not(x_in_a), lo_le_x); + + // (x in a) => (x <= hi) + add_binary("in-range", x, a, m.mk_not(x_in_a), x_le_hi); + + // (lo <= x) and (x <= hi) => (x in a) + add_ternary("in-range", x, a, nlo_le_x, nx_le_hi, x_in_a); +} + +// a := set.range(lo, hi) +// (not (set.in (- lo 1) r)) +// (not (set.in (+ hi 1) r)) +// (set.in lo r) +// (set.in hi r) +void finite_set_axioms::in_range_axiom(expr* r) { + expr *lo = nullptr, *hi = nullptr; + if (!u.is_range(r, lo, hi)) + return; + + arith_util a(m); + expr_ref lo_le_hi(a.mk_le(a.mk_sub(lo, hi), a.mk_int(0)), m); + m_rewriter(lo_le_hi); + + add_binary("range-bounds", r, nullptr, m.mk_not(lo_le_hi), u.mk_in(lo, r)); + add_binary("range-bounds", r, nullptr, m.mk_not(lo_le_hi), u.mk_in(hi, r)); + add_unit("range-bounds", r, m.mk_not(u.mk_in(a.mk_add(hi, a.mk_int(1)), r))); + add_unit("range-bounds", r, m.mk_not(u.mk_in(a.mk_add(lo, a.mk_int(-1)), r))); +} + +// a := set.map(f, b) +// (x in a) <=> set.map_inverse(f, x, b) in b +// +void finite_set_axioms::in_map_axiom(expr *x, expr *a) { + expr *f = nullptr, *b = nullptr; + sort *elem_sort = nullptr; + VERIFY(u.is_finite_set(a->get_sort(), elem_sort)); + if (x->get_sort() != elem_sort) + return; + if (!u.is_map(a, f, b)) + return; + + expr_ref inv(u.mk_map_inverse(f, x, b), m); + expr_ref f1(u.mk_in(x, a), m); + expr_ref f2(u.mk_in(inv, b), m); + add_binary("map-inverse", x, a, m.mk_not(f1), f2); + add_binary("map-inverse", x, b, f1, m.mk_not(f2)); +} + +// a := set.map(f, b) +// (x in b) => f(x) in a +void finite_set_axioms::in_map_image_axiom(expr *x, expr *a) { + expr* f = nullptr, *b = nullptr; + sort *elem_sort = nullptr; + if (!u.is_map(a, f, b)) + return; + VERIFY(u.is_finite_set(b->get_sort(), elem_sort)); + if (x->get_sort() != elem_sort) + return; + + expr_ref x_in_b(u.mk_in(x, b), m); + + // Apply function f to x using array select + array_util autil(m); + expr_ref fx(autil.mk_select(f, x), m); + expr_ref fx_in_a(u.mk_in(fx, a), m); + m_rewriter(fx); + + // (x in b) => f(x) in a + add_binary("in-map", x, a, m.mk_not(x_in_b), fx_in_a); +} + +// a := set.filter(p, b) +// (x in a) <=> (x in b) and p(x) +void finite_set_axioms::in_filter_axiom(expr *x, expr *a) { + expr* p = nullptr, *b = nullptr; + if (!u.is_filter(a, p, b)) + return; + + expr_ref x_in_a(u.mk_in(x, a), m); + expr_ref x_in_b(u.mk_in(x, b), m); + + // Apply predicate p to x using array select + array_util autil(m); + expr_ref px(autil.mk_select(p, x), m); + m_rewriter(px); + expr_ref npx(mk_not(m, px), m); + + // (x in a) => (x in b) + add_binary("in-filter", x, a, m.mk_not(x_in_a), x_in_b); + + // (x in a) => p(x) + add_binary("in-filter", x, a, m.mk_not(x_in_a), px); + + // (x in b) and p(x) => (x in a) + add_ternary("in-filter", x, a, m.mk_not(x_in_b), npx, x_in_a); +} + +// Auxiliary algebraic axioms to ease reasoning about set.size +// The axioms are not required for completenss for the base fragment +// as they are handled by creating semi-linear sets. +void finite_set_axioms::size_ub_axiom(expr *sz) { + expr *b = nullptr, *e = nullptr, *x = nullptr, *y = nullptr; + if (!u.is_size(sz, e)) + return; + arith_util a(m); + expr_ref ineq(m); + + if (u.is_singleton(e, b)) + add_unit("size", e, m.mk_eq(sz, a.mk_int(1))); + else if (u.is_empty(e)) + add_unit("size", e, m.mk_eq(sz, a.mk_int(0))); + else if (u.is_union(e, x, y)) { + ineq = a.mk_le(sz, a.mk_add(u.mk_size(x), u.mk_size(y))); + m_rewriter(ineq); + add_unit("size", e, ineq); + } + else if (u.is_intersect(e, x, y)) { + ineq = a.mk_le(sz, u.mk_size(x)); + m_rewriter(ineq); + add_unit("size", e, ineq); + ineq = a.mk_le(sz, u.mk_size(y)); + m_rewriter(ineq); + add_unit("size", e, ineq); + } + else if (u.is_difference(e, x, y)) { + ineq = a.mk_le(sz, u.mk_size(x)); + m_rewriter(ineq); + add_unit("size", e, ineq); + } + else if (u.is_filter(e, x, y)) { + ineq = a.mk_le(sz, u.mk_size(y)); + m_rewriter(ineq); + add_unit("size", e, ineq); + } + else if (u.is_map(e, x, y)) { + ineq = a.mk_le(sz, u.mk_size(y)); + m_rewriter(ineq); + add_unit("size", e, ineq); + } + else if (u.is_range(e, x, y)) { + ineq = a.mk_eq(sz, m.mk_ite(a.mk_le(x, y), a.mk_add(a.mk_sub(y, x), a.mk_int(1)), a.mk_int(0))); + m_rewriter(ineq); + add_unit("size", e, ineq); + } +} + +void finite_set_axioms::size_lb_axiom(expr* e) { + VERIFY(u.is_size(e)); + arith_util a(m); + expr_ref ineq(m); + ineq = a.mk_le(a.mk_int(0), e); + m_rewriter(ineq); + add_unit("size", e, ineq); +} + +void finite_set_axioms::subset_axiom(expr* a) { + expr *b = nullptr, *c = nullptr; + if (!u.is_subset(a, b, c)) + return; + expr_ref eq(m.mk_eq(u.mk_intersect(b, c), b), m); + add_binary("subset", a, nullptr, m.mk_not(a), eq); + add_binary("subset", a, nullptr, a, m.mk_not(eq)); +} + +void finite_set_axioms::extensionality_axiom(expr *a, expr* b) { + // a != b => set.in (set.diff(a, b) a) != set.in (set.diff(a, b) b) + expr_ref diff_ab(u.mk_ext(a, b), m); + + expr_ref a_eq_b(m.mk_eq(a, b), m); + expr_ref diff_in_a(u.mk_in(diff_ab, a), m); + expr_ref diff_in_b(u.mk_in(diff_ab, b), m); + expr_ref ndiff_in_a(m.mk_not(diff_in_a), m); + expr_ref ndiff_in_b(m.mk_not(diff_in_b), m); + + // (a != b) => (x in diff_ab != x in diff_ba) + add_ternary("extensionality", a, b, a_eq_b, ndiff_in_a, ndiff_in_b); + add_ternary("extensionality", a, b, a_eq_b, diff_in_a, diff_in_b); +} \ No newline at end of file diff --git a/src/ast/rewriter/finite_set_axioms.h b/src/ast/rewriter/finite_set_axioms.h new file mode 100644 index 000000000..029833fd0 --- /dev/null +++ b/src/ast/rewriter/finite_set_axioms.h @@ -0,0 +1,137 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + finite_set_axioms.h + +Abstract: + + This module implements axiom schemas that are invoked by saturating constraints + with respect to the semantics of set operations. + +--*/ + +#pragma once +#include "ast/rewriter/th_rewriter.h" + +struct theory_axiom { + expr_ref_vector clause; + vector params; + unsigned weight = 0; // can be used to prioritize instantiation of axioms + theory_axiom(ast_manager& m, symbol const& th): clause(m) { + params.push_back(parameter(th)); + } + theory_axiom(ast_manager &m, char const* rule) : clause(m) { + params.push_back(parameter(symbol(rule))); + } + theory_axiom(ast_manager &m) : clause(m) { + } + + theory_axiom(ast_manager &m, char const *rule, expr* x, expr* y = nullptr) : clause(m) { + params.push_back(parameter(symbol(rule))); + params.push_back(parameter(x)); + if (y) + params.push_back(parameter(y)); + } +}; + +std::ostream &operator<<(std::ostream &out, theory_axiom const &ax); + + +class finite_set_axioms { + ast_manager& m; + finite_set_util u; + th_rewriter m_rewriter; + + std::function m_add_clause; + + void add_unit(char const* name, expr* p1, expr *e); + + void add_binary(char const *name, expr *p1, expr *p2, expr *f1, expr *f2); + + void add_ternary(char const *name, expr *p1, expr *p2, expr *f1, expr *f2, expr *f3); + + bool is_true(expr *f); + + bool is_false(expr *f); + +public: + + finite_set_axioms(ast_manager &m) : m(m), u(m), m_rewriter(m) {} + + void set_add_clause(std::function &ac) { + m_add_clause = ac; + } + + // a ~ set.empty => not (x in a) + void in_empty_axiom(expr *x); + + // a := set.union(b, c) + // (x in a) <=> (x in b) or (x in c) + void in_union_axiom(expr *x, expr *a); + + // a := set.intersect(b, c) + // (x in a) <=> (x in b) and (x in c) + void in_intersect_axiom(expr *x, expr *a); + + // a := set.difference(b, c) + // (x in a) <=> (x in b) and not (x in c) + void in_difference_axiom(expr *x, expr *a); + + // a := set.singleton(b) + // (x in a) <=> (x == b) + void in_singleton_axiom(expr *x, expr *a); + + // a := set.singleton(b) + // b in a + // b-1 not in a + // b+1 not in a + void in_singleton_axiom(expr *a); + + // a := set.range(lo, hi) + // (x in a) <=> (lo <= x <= hi) + void in_range_axiom(expr *x, expr *a); + + // a := set.range(lo, hi) + // (not (set.in (- lo 1) a)) + // (not (set.in (+ hi 1) a)) + // lo <= hi => (set.in lo a) + // lo <= hi => (set.in hi a) + void in_range_axiom(expr *a); + + // a := set.map(f, b) + // (x in a) <=> set.map_inverse(f, x, b) in b + void in_map_axiom(expr *x, expr *a); + + // a := set.map(f, b) + // (x in b) => f(x) in a + void in_map_image_axiom(expr *x, expr *a); + + // a := set.filter(p, b) + // (x in a) <=> (x in b) and p(x) + void in_filter_axiom(expr *x, expr *a); + + // a := set.subset(b, c) + // (a) <=> (set.intersect(b, c) = b) + void subset_axiom(expr *a); + + + // set.size(empty) = 0 + // set.size(set.singleton(b)) = 1 + // set.size(a u b) <= set.size(a) + set.size(b) + // set.size(a n b) <= min(set.size(a), set.size(b)) + // set.size(a \ b) <= set.size(a) + // set.size(set.map(f, b)) <= set.size(b) + // set.size(set.filter(p, b)) <= set.size(b) + // set.size([l..u]) = if(l <= u, u - l + 1, 0) + void size_ub_axiom(expr *a); + + // 0 <= set.size(e) + void size_lb_axiom(expr *e); + + + // a != b => set.in (set.diff(a, b) a) != set.in (set.diff(a, b) b) + void extensionality_axiom(expr *a, expr *b); + +}; \ No newline at end of file diff --git a/src/ast/rewriter/finite_set_rewriter.cpp b/src/ast/rewriter/finite_set_rewriter.cpp new file mode 100644 index 000000000..b86f211f1 --- /dev/null +++ b/src/ast/rewriter/finite_set_rewriter.cpp @@ -0,0 +1,431 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + finite_set_rewriter.cpp + +Abstract: + + Rewriting Simplification for finite sets + +Author: + + Nikolaj Bjorner (nbjorner) - October 2025 + +--*/ + +#include "ast/rewriter/finite_set_rewriter.h" +#include "ast/arith_decl_plugin.h" +#include "ast/ast_pp.h" + +br_status finite_set_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) { + SASSERT(f->get_family_id() == get_fid()); + + switch (f->get_decl_kind()) { + case OP_FINITE_SET_UNION: + return mk_union(num_args, args, result); + case OP_FINITE_SET_INTERSECT: + return mk_intersect(num_args, args, result); + case OP_FINITE_SET_DIFFERENCE: + SASSERT(num_args == 2); + return mk_difference(args[0], args[1], result); + case OP_FINITE_SET_SUBSET: + SASSERT(num_args == 2); + return mk_subset(args[0], args[1], result); + case OP_FINITE_SET_SINGLETON: + SASSERT(num_args == 1); + return mk_singleton(args[0], result); + case OP_FINITE_SET_IN: + SASSERT(num_args == 2); + return mk_in(args[0], args[1], result); + case OP_FINITE_SET_SIZE: + return mk_size(args[0], result); + default: + return BR_FAILED; + } +} + +br_status finite_set_rewriter::mk_union(unsigned num_args, expr * const * args, expr_ref & result) { + VERIFY(num_args == 2); + // Idempotency: set.union(x, x) -> x + if (args[0] == args[1]) { + result = args[0]; + return BR_DONE; + } + + // Identity: set.union(x, empty) -> x or set.union(empty, x) -> x + if (u.is_empty(args[0])) { + result = args[1]; + return BR_DONE; + } + if (u.is_empty(args[1])) { + result = args[0]; + return BR_DONE; + } + + // Absorption: set.union(x, set.intersect(x, y)) -> x + expr *a1, *a2; + if (u.is_intersect(args[1], a1, a2)) { + if (args[0] == a1 || args[0] == a2) { + result = args[0]; + return BR_DONE; + } + } + + // Absorption: set.union(set.intersect(x, y), x) -> x + if (u.is_intersect(args[0], a1, a2)) { + if (args[1] == a1 || args[1] == a2) { + result = args[1]; + return BR_DONE; + } + } + + + return BR_FAILED; +} + +br_status finite_set_rewriter::mk_intersect(unsigned num_args, expr * const * args, expr_ref & result) { + if (num_args != 2) + return BR_FAILED; + + // Idempotency: set.intersect(x, x) -> x + if (args[0] == args[1]) { + result = args[0]; + return BR_DONE; + } + + // Annihilation: set.intersect(x, empty) -> empty or set.intersect(empty, x) -> empty + if (u.is_empty(args[0])) { + result = args[0]; + return BR_DONE; + } + if (u.is_empty(args[1])) { + result = args[1]; + return BR_DONE; + } + + // Absorption: set.intersect(x, set.union(x, y)) -> x + expr *a1, *a2; + if (u.is_union(args[1], a1, a2)) { + if (args[0] == a1 || args[0] == a2) { + result = args[0]; + return BR_DONE; + } + } + + // Absorption: set.intersect(set.union(x, y), x) -> x + if (u.is_union(args[0], a1, a2)) { + if (args[1] == a1 || args[1] == a2) { + result = args[1]; + return BR_DONE; + } + } + expr *l1, *l2, *u1, *u2; + if (u.is_range(args[0], l1, u1) && u.is_range(args[1], l2, u2)) { + arith_util a(m); + auto max_l = m.mk_ite(a.mk_ge(l1, l2), l1, l2); + auto min_u = m.mk_ite(a.mk_ge(u1, u2), u2, u1); + result = u.mk_range(max_l, min_u); + return BR_REWRITE_FULL; + } + + return BR_FAILED; +} + +br_status finite_set_rewriter::mk_difference(expr * arg1, expr * arg2, expr_ref & result) { + // set.difference(x, x) -> set.empty + if (arg1 == arg2) { + sort* set_sort = arg1->get_sort(); + SASSERT(u.is_finite_set(set_sort)); + result = u.mk_empty(set_sort); + return BR_DONE; + } + + // Identity: set.difference(x, empty) -> x + if (u.is_empty(arg2)) { + result = arg1; + return BR_DONE; + } + + // Annihilation: set.difference(empty, x) -> empty + if (u.is_empty(arg1)) { + result = arg1; + return BR_DONE; + } + + return BR_FAILED; +} + +br_status finite_set_rewriter::mk_subset(expr * arg1, expr * arg2, expr_ref & result) { + // set.subset(x, x) -> true + if (arg1 == arg2) { + result = m.mk_true(); + return BR_DONE; + } + + // set.subset(empty, x) -> true + if (u.is_empty(arg1)) { + result = m.mk_true(); + return BR_DONE; + } + + // set.subset(x, empty) -> x = empty + if (u.is_empty(arg2)) { + result = m.mk_eq(arg1, arg2); + return BR_REWRITE1; + } + + // General case: set.subset(x, y) -> set.intersect(x, y) = x + expr_ref intersect(m); + intersect = u.mk_intersect(arg1, arg2); + result = m.mk_eq(intersect, arg1); + return BR_REWRITE3; +} + +br_status finite_set_rewriter::mk_singleton(expr * arg, expr_ref & result) { + // Singleton is already in normal form, no simplifications + return BR_FAILED; +} + +br_status finite_set_rewriter::mk_size(expr * arg, expr_ref & result) { + arith_util a(m); + if (u.is_empty(arg)) { + // size(empty) -> 0 + result = a.mk_int(0); + return BR_DONE; + } + if (u.is_singleton(arg)) { + // size(singleton(x)) -> 1 + result = a.mk_int(1); + return BR_DONE; + } + expr *lower, *upper; + if (u.is_range(arg, lower, upper)) { + // size(range(a, b)) -> b - a + 1 + expr_ref size_expr(m); + size_expr = a.mk_add(a.mk_sub(upper, lower), a.mk_int(1)); + result = m.mk_ite(a.mk_gt(lower, upper), a.mk_int(0), size_expr); + return BR_REWRITE3; + } + // Size is already in normal form, no simplifications + return BR_FAILED; +} + +br_status finite_set_rewriter::mk_in(expr * elem, expr * set, expr_ref & result) { + // set.in(x, empty) -> false + if (u.is_empty(set)) { + result = m.mk_false(); + return BR_DONE; + } + + // set.in(x, singleton(y)) checks + expr* singleton_elem; + if (u.is_singleton(set, singleton_elem)) { + // set.in(x, singleton(x)) -> true (when x is the same) + if (elem == singleton_elem) { + result = m.mk_true(); + return BR_DONE; + } + // set.in(x, singleton(y)) -> x = y (when x != y) + result = m.mk_eq(elem, singleton_elem); + return BR_REWRITE1; + } + expr *lo = nullptr, *hi = nullptr; + if (u.is_range(set, lo, hi)) { + arith_util a(m); + result = m.mk_and(a.mk_le(lo, elem), a.mk_le(elem, hi)); + return BR_REWRITE2; + } + // NB we don't rewrite (set.in x (set.union s t)) to (or (set.in x s) (set.in x t)) + // because it creates two new sub-expressions. The expression (set.union s t) could + // be shared with other expressions so the net effect of this rewrite could be to create + // a larger formula for the solver. + return BR_FAILED; +} + + +/** +* if a, b are set expressions we can create an on-the-fly heap for their min-elements +* a, b are normalized to the form (set.union s t) or (set.empty) where +* s is a singleton or range expression such that every element in t are above s. +* we distinguish numerical values from value expressions: +* - for numerical values we use the ordering over numerals to pick minimal ranges +* - for unique value expressions ranging over non-numerals use expression identifiers +* - for other expressions use identifiers to sort expressions, but make sure to be inconclusive +* for set difference +* We want mk_eq_core to produce a result true/false if the arguments are both (unique) values. +* This allows to evaluate models for being well-formed conclusively. +* +* A way to convert a set expression to a heap is as follows: +* +* min({s}) = {s} u {} +* min({}) = {} +* min([l..u]) = [l..u] u {} +* min(s u t) = +* let {x} u s1 = min(s) +* let {y} u t1 = min(t) +* if x = y then +* { x } u (s1 u t1) +* else if x < y then +* {x} u (s1 u ({y} u t1) +* else // x > y +* {y} u (t1 u ({x} u s1) +* +* Handling ranges is TBD +* For proper range handling we have to change is_less on numeric singleton sets +* to use the numerical value, not the expression identifier. Then the ordering +* has to make all numeric values less than symbolic values. +*/ + +bool finite_set_rewriter::is_less(expr *a, expr *b) { + return a->get_id() < b->get_id(); +} + +expr* finite_set_rewriter::mk_union(expr* a, expr* b) { + if (u.is_empty(a)) + return b; + if (u.is_empty(b)) + return a; + if (a == b) + return a; + return u.mk_union(a, b); +} + +expr* finite_set_rewriter::min(expr* e) { + if (m_is_min.is_marked(e)) + return e; + expr *a = nullptr, *b = nullptr; + if (u.is_union(e, a, b)) { + a = min(a); + b = min(b); + if (u.is_empty(a)) + return b; + if (u.is_empty(b)) + return a; + auto [x,a1] = get_min(a); + auto [y,b1] = get_min(b); + if (x == y) + a = mk_union(x, mk_union(a1, b1)); + else if (is_less(x, y)) + a = mk_union(x, mk_union(a1, b)); + else + a = mk_union(y, mk_union(a, b1)); + m_pinned.push_back(a); + m_is_min.mark(a); + return a; + } + if (u.is_intersect(e, a, b)) { + if (!from_unique_values(a) || !from_unique_values(b)) { + m_pinned.push_back(e); + m_is_min.mark(e); + return e; + } + while (true) { + a = min(a); + b = min(b); + if (u.is_empty(a)) + return a; + if (u.is_empty(b)) + return b; + auto [x, a1] = get_min(a); + auto [y, b1] = get_min(b); + if (x == y) { + a = mk_union(x, u.mk_intersect(a1, b1)); + m_pinned.push_back(a); + m_is_min.mark(a); + return a; + } + else if (is_less(x, y)) + a = a1; + else + b = b1; + } + } + if (u.is_difference(e, a, b)) { + if (!from_unique_values(a) || !from_unique_values(b)) { + m_pinned.push_back(e); + m_is_min.mark(e); + return e; + } + while (true) { + a = min(a); + b = min(b); + if (u.is_empty(a) || u.is_empty(b)) + return a; + auto [x, a1] = get_min(a); + auto [y, b1] = get_min(b); + if (x == y) { + a = a1; + b = b1; + } + else if (is_less(x, y)) { + a = mk_union(x, u.mk_difference(a1, b)); + m_pinned.push_back(a); + m_is_min.mark(a); + return a; + } + else { + b = b1; + } + } + } + // set.filter, set.map don't have decompositions + m_pinned.push_back(e); + m_is_min.mark(e); + return e; +} + +std::pair finite_set_rewriter::get_min(expr* a) { + expr *x = nullptr, *y = nullptr; + if (u.is_union(a, x, y)) + return {x, y}; + auto empty = u.mk_empty(a->get_sort()); + m_pinned.push_back(empty); + return {a, empty}; +} + +br_status finite_set_rewriter::mk_eq_core(expr *a, expr *b, expr_ref &result) { + m_is_min.reset(); + m_pinned.reset(); + bool are_unique = true; + while (true) { + if (a == b) { + result = m.mk_true(); + return BR_DONE; + } + TRACE(finite_set, tout << mk_pp(a, m) << " == " << mk_pp(b, m) << "\n"); + a = min(a); + b = min(b); + auto [x, a1] = get_min(a); + auto [y, b1] = get_min(b); + + // only empty sets and singletons of unique values are unique. + // ranges are not counted as unique. + are_unique &= m.is_unique_value(x) && m.is_unique_value(y); + a = a1; + b = b1; + if (x == y) + continue; + + if (m.are_distinct(x, y) && are_unique) { + are_unique &= from_unique_values(a); + are_unique &= from_unique_values(b); + if (are_unique) { + result = m.mk_false(); + return BR_DONE; + } + } + return BR_FAILED; + } +} + +bool finite_set_rewriter::from_unique_values(expr *a) { + while (!u.is_empty(a)) { + auto [x, a1] = get_min(min(a)); + if (!m.is_unique_value(x)) + return false; + a = a1; + } + return true; +} diff --git a/src/ast/rewriter/finite_set_rewriter.h b/src/ast/rewriter/finite_set_rewriter.h new file mode 100644 index 000000000..062bf9a08 --- /dev/null +++ b/src/ast/rewriter/finite_set_rewriter.h @@ -0,0 +1,69 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + finite_set_rewriter.h + +Abstract: + + Rewriting Simplification for finite sets + +Sample rewrite rules: + set.union s set.empty -> s + set.intersect s set.empty -> set.empty + set.in x (set.singleton y) -> x = y + set.subset(x,y) -> set.intersect(x,y) = x + set.union(x, x) -> x + set.intersect(x, x) -> x + set.difference(x, x) -> set.empty + + +Generally this module implements basic algebraic simplification rules for finite sets +where the signature is defined in finite_set_decl_plugin.h. + +--*/ +#pragma once + +#include "ast/finite_set_decl_plugin.h" +#include "ast/rewriter/rewriter_types.h" +#include "util/params.h" + +/** + \brief Cheap rewrite rules for finite sets +*/ +class finite_set_rewriter { + friend class finite_set_rewriter_test; + ast_manager &m; + finite_set_util u; + expr_ref_vector m_pinned; + expr_mark m_is_min; + + expr * min(expr *a); + std::pair get_min(expr *a); + bool is_less(expr *a, expr *b); + expr *mk_union(expr *a, expr *b); + bool from_unique_values(expr *a); + + // Rewrite rules for set operations + br_status mk_union(unsigned num_args, expr *const *args, expr_ref &result); + br_status mk_intersect(unsigned num_args, expr *const *args, expr_ref &result); + br_status mk_difference(expr *arg1, expr *arg2, expr_ref &result); + br_status mk_subset(expr *arg1, expr *arg2, expr_ref &result); + br_status mk_singleton(expr *arg1, expr_ref &result); + br_status mk_in(expr *arg1, expr *arg2, expr_ref &result); + br_status mk_size(expr *arg, expr_ref &result); + +public: + finite_set_rewriter(ast_manager & m, params_ref const & p = params_ref()): + m(m), u(m), m_pinned(m) { + } + + family_id get_fid() const { return u.get_family_id(); } + finite_set_util& util() { return u; } + + br_status mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result); + + br_status mk_eq_core(expr *a, expr *b, expr_ref &result); +}; + diff --git a/src/ast/rewriter/th_rewriter.cpp b/src/ast/rewriter/th_rewriter.cpp index e5d52ce5a..f77bc1a68 100644 --- a/src/ast/rewriter/th_rewriter.cpp +++ b/src/ast/rewriter/th_rewriter.cpp @@ -30,6 +30,7 @@ Notes: #include "ast/rewriter/pb_rewriter.h" #include "ast/rewriter/recfun_rewriter.h" #include "ast/rewriter/seq_rewriter.h" +#include "ast/rewriter/finite_set_rewriter.h" #include "ast/rewriter/rewriter_def.h" #include "ast/rewriter/var_subst.h" #include "ast/rewriter/der.h" @@ -55,6 +56,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg { seq_rewriter m_seq_rw; char_rewriter m_char_rw; recfun_rewriter m_rec_rw; + finite_set_rewriter m_fs_rw; arith_util m_a_util; bv_util m_bv_util; der m_der; @@ -230,6 +232,8 @@ struct th_rewriter_cfg : public default_rewriter_cfg { return m_char_rw.mk_app_core(f, num, args, result); if (fid == m_rec_rw.get_fid()) return m_rec_rw.mk_app_core(f, num, args, result); + if (fid == m_fs_rw.get_fid()) + return m_fs_rw.mk_app_core(f, num, args, result); return BR_FAILED; } @@ -685,6 +689,8 @@ struct th_rewriter_cfg : public default_rewriter_cfg { st = m_ar_rw.mk_eq_core(a, b, result); else if (s_fid == m_seq_rw.get_fid()) st = m_seq_rw.mk_eq_core(a, b, result); + else if (s_fid == m_fs_rw.get_fid()) + st = m_fs_rw.mk_eq_core(a, b, result); if (st != BR_FAILED) return st; st = extended_bv_eq(a, b, result); @@ -883,7 +889,8 @@ struct th_rewriter_cfg : public default_rewriter_cfg { m_pb_rw(m), m_seq_rw(m, p), m_char_rw(m), - m_rec_rw(m), + m_rec_rw(m), + m_fs_rw(m), m_a_util(m), m_bv_util(m), m_der(m), diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index b065607f6..a4363d006 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -32,6 +32,7 @@ Notes: #include "ast/pb_decl_plugin.h" #include "ast/fpa_decl_plugin.h" #include "ast/special_relations_decl_plugin.h" +#include "ast/finite_set_decl_plugin.h" #include "ast/ast_pp.h" #include "ast/pp.h" #include "ast/ast_smt2_pp.h" @@ -532,6 +533,7 @@ protected: fpa_util m_futil; seq_util m_sutil; datatype_util m_dtutil; + finite_set_util m_fsutil; datalog::dl_decl_util m_dlutil; @@ -553,7 +555,8 @@ protected: } public: - pp_env(cmd_context & o):m_owner(o), m_autil(o.m()), m_bvutil(o.m()), m_arutil(o.m()), m_futil(o.m()), m_sutil(o.m()), m_dtutil(o.m()), m_dlutil(o.m()) {} + pp_env(cmd_context & o):m_owner(o), m_autil(o.m()), m_bvutil(o.m()), m_arutil(o.m()), m_futil(o.m()), + m_sutil(o.m()), m_dtutil(o.m()), m_fsutil(o.m()), m_dlutil(o.m()) {} ast_manager & get_manager() const override { return m_owner.m(); } arith_util & get_autil() override { return m_autil; } bv_util & get_bvutil() override { return m_bvutil; } @@ -561,6 +564,7 @@ public: fpa_util & get_futil() override { return m_futil; } seq_util & get_sutil() override { return m_sutil; } datatype_util & get_dtutil() override { return m_dtutil; } + finite_set_util &get_fsutil() override { return m_fsutil; } datalog::dl_decl_util& get_dlutil() override { return m_dlutil; } bool uses(symbol const & s) const override { @@ -829,6 +833,7 @@ void cmd_context::init_manager_core(bool new_manager) { register_plugin(symbol("fpa"), alloc(fpa_decl_plugin), logic_has_fpa()); register_plugin(symbol("datalog_relation"), alloc(datalog::dl_decl_plugin), !has_logic()); register_plugin(symbol("specrels"), alloc(special_relations_decl_plugin), !has_logic()); + register_plugin(symbol("finite_set"), alloc(finite_set_decl_plugin), !has_logic() || smt_logics::logic_has_finite_sets(m_logic)); } else { // the manager was created by an external module @@ -845,6 +850,7 @@ void cmd_context::init_manager_core(bool new_manager) { load_plugin(symbol("seq"), logic_has_seq(), fids); load_plugin(symbol("fpa"), logic_has_fpa(), fids); load_plugin(symbol("pb"), logic_has_pb(), fids); + load_plugin(symbol("finite_set"), smt_logics::logic_has_finite_sets(m_logic) || !has_logic(), fids); for (family_id fid : fids) { decl_plugin * p = m_manager->get_plugin(fid); diff --git a/src/model/CMakeLists.txt b/src/model/CMakeLists.txt index 9ba93b8e1..12fce27e8 100644 --- a/src/model/CMakeLists.txt +++ b/src/model/CMakeLists.txt @@ -2,6 +2,7 @@ z3_add_component(model SOURCES array_factory.cpp datatype_factory.cpp + finite_set_factory.cpp func_interp.cpp model2expr.cpp model_core.cpp diff --git a/src/model/array_factory.cpp b/src/model/array_factory.cpp index 9e34846a2..518ea4108 100644 --- a/src/model/array_factory.cpp +++ b/src/model/array_factory.cpp @@ -63,8 +63,8 @@ void array_factory::get_some_args_for(sort * s, ptr_buffer & args) { expr * array_factory::get_some_value(sort * s) { TRACE(array_factory, tout << mk_pp(s, m_manager) << "\n";); value_set * set = nullptr; - if (m_sort2value_set.find(s, set) && !set->empty()) - return *(set->begin()); + if (m_sort2value_set.find(s, set) && !set->set.empty()) + return *(set->set.begin()); func_interp * fi; expr * val = mk_array_interp(s, fi); fi->set_else(m_model.get_some_value(get_array_range(s))); @@ -75,7 +75,7 @@ bool array_factory::mk_two_diff_values_for(sort * s) { TRACE(array_factory, tout << mk_pp(s, m_manager) << "\n";); DEBUG_CODE({ value_set * set = 0; - SASSERT(!m_sort2value_set.find(s, set) || set->size() <= 1); + SASSERT(!m_sort2value_set.find(s, set) || set->set.size() <= 1); }); expr_ref r1(m_manager); expr_ref r2(m_manager); @@ -92,24 +92,24 @@ bool array_factory::mk_two_diff_values_for(sort * s) { fi2->insert_entry(args.data(), r2); DEBUG_CODE({ value_set * set = 0; - SASSERT(m_sort2value_set.find(s, set) && set->size() >= 2); + SASSERT(m_sort2value_set.find(s, set) && set->set.size() >= 2); }); return true; } bool array_factory::get_some_values(sort * s, expr_ref & v1, expr_ref & v2) { value_set * set = nullptr; - if (!m_sort2value_set.find(s, set) || set->size() < 2) { + if (!m_sort2value_set.find(s, set) || set->set.size() < 2) { if (!mk_two_diff_values_for(s)) { TRACE(array_factory_bug, tout << "could not create diff values: " << mk_pp(s, m_manager) << "\n";); return false; } } m_sort2value_set.find(s, set); - SASSERT(set != 0); - SASSERT(set->size() >= 2); - - value_set::iterator it = set->begin(); + SASSERT(set); + SASSERT(set->set.size() >= 2); + + auto it = set->set.begin(); v1 = *it; ++it; v2 = *it; @@ -126,8 +126,8 @@ bool array_factory::get_some_values(sort * s, expr_ref & v1, expr_ref & v2) { // is set with the result of some entry. // expr * array_factory::get_fresh_value(sort * s) { - value_set * set = get_value_set(s); - if (set->empty()) { + auto& [set, values] = get_value_set(s); + if (set.empty()) { // easy case return get_some_value(s); } diff --git a/src/model/datatype_factory.cpp b/src/model/datatype_factory.cpp index e0c2f27fe..b93703acd 100644 --- a/src/model/datatype_factory.cpp +++ b/src/model/datatype_factory.cpp @@ -30,8 +30,8 @@ expr * datatype_factory::get_some_value(sort * s) { if (!m_util.is_datatype(s)) return m_model.get_some_value(s); value_set * set = nullptr; - if (m_sort2value_set.find(s, set) && !set->empty()) - return *(set->begin()); + if (m_sort2value_set.find(s, set) && !set->set.empty()) + return *(set->set.begin()); func_decl * c = m_util.get_non_rec_constructor(s); ptr_vector args; unsigned num = c->get_arity(); @@ -50,11 +50,11 @@ expr * datatype_factory::get_last_fresh_value(sort * s) { expr * val = nullptr; if (m_last_fresh_value.find(s, val)) return val; - value_set * set = get_value_set(s); - if (set->empty()) + auto& [set, values] = get_value_set(s); + if (set.empty()) val = get_some_value(s); else - val = *(set->begin()); + val = *(set.begin()); if (m_util.is_recursive(s)) m_last_fresh_value.insert(s, val); return val; @@ -78,8 +78,8 @@ bool datatype_factory::is_subterm_of_last_value(app* e) { expr * datatype_factory::get_almost_fresh_value(sort * s) { if (!m_util.is_datatype(s)) return m_model.get_some_value(s); - value_set * set = get_value_set(s); - if (set->empty()) { + auto& [set, values] = get_value_set(s); + if (set.empty()) { expr * val = get_some_value(s); SASSERT(val); if (m_util.is_recursive(s)) @@ -117,7 +117,7 @@ expr * datatype_factory::get_almost_fresh_value(sort * s) { } if (recursive || found_fresh_arg) { app * new_value = m_manager.mk_app(constructor, args); - SASSERT(!found_fresh_arg || !set->contains(new_value)); + SASSERT(!found_fresh_arg || !set.contains(new_value)); register_value(new_value); if (m_util.is_recursive(s)) { if (is_subterm_of_last_value(new_value)) { @@ -140,10 +140,10 @@ expr * datatype_factory::get_fresh_value(sort * s) { if (!m_util.is_datatype(s)) return m_model.get_fresh_value(s); TRACE(datatype, tout << "generating fresh value for: " << s->get_name() << "\n";); - value_set * set = get_value_set(s); + auto& [set, values] = get_value_set(s); // Approach 0) // if no value for s was generated so far, then used get_some_value - if (set->empty()) { + if (set.empty()) { expr * val = get_some_value(s); if (m_util.is_recursive(s)) m_last_fresh_value.insert(s, val); @@ -178,12 +178,11 @@ expr * datatype_factory::get_fresh_value(sort * s) { expr * some_arg = m_model.get_some_value(s_arg); args.push_back(some_arg); } - new_value = m_manager.mk_app(constructor, args); - CTRACE(datatype, found_fresh_arg && set->contains(new_value), tout << "seen: " << new_value << "\n";); - if (found_fresh_arg && set->contains(new_value)) + CTRACE(datatype, found_fresh_arg && set.contains(new_value), tout << "seen: " << new_value << "\n";); + if (found_fresh_arg && set.contains(new_value)) goto retry_value; - if (!set->contains(new_value)) { + if (!set.contains(new_value)) { register_value(new_value); if (m_util.is_recursive(s)) m_last_fresh_value.insert(s, new_value); @@ -241,7 +240,7 @@ expr * datatype_factory::get_fresh_value(sort * s) { new_value = m_manager.mk_app(constructor, args); TRACE(datatype, tout << "potential new value: " << mk_pp(new_value, m_manager) << "\n";); m_last_fresh_value.insert(s, new_value); - if (!set->contains(new_value)) { + if (!set.contains(new_value)) { register_value(new_value); TRACE(datatype, tout << "2. result: " << mk_pp(new_value, m_manager) << "\n";); return new_value; diff --git a/src/model/finite_set_factory.cpp b/src/model/finite_set_factory.cpp new file mode 100644 index 000000000..7e263663a --- /dev/null +++ b/src/model/finite_set_factory.cpp @@ -0,0 +1,70 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + finite_set_factory.cpp + +Abstract: + + Factory for creating finite set values + +--*/ +#include "model/finite_set_factory.h" +#include "model/model_core.h" + +finite_set_factory::finite_set_factory(ast_manager & m, family_id fid, model_core & md): + struct_factory(m, fid, md), + u(m) { +} + +expr * finite_set_factory::get_some_value(sort * s) { + // Check if we already have a value for this sort + value_set * vset = nullptr; + SASSERT(u.is_finite_set(s)); + if (m_sort2value_set.find(s, vset) && !vset->set.empty()) + return *(vset->set.begin()); + return u.mk_empty(s); +} + +/** + * create sets {}, {a}, {b}, {a,b}, {c}, {a,c}, {b,c}, {a,b,c}, {d}, ... + */ +expr * finite_set_factory::get_fresh_value(sort * s) { + sort* elem_sort = nullptr; + VERIFY(u.is_finite_set(s, elem_sort)); + + auto& [set, values] = get_value_set(s); + + // Case 1: If no values have been generated yet, return empty set + if (values.size() == 0) { + auto r = u.mk_empty(s); + register_value(r); + return r; + } + + // Helper lambda to check if a number is a power of 2 + auto is_power_of_2 = [](unsigned n) { + return n > 0 && (n & (n - 1)) == 0; + }; + + // Case 2: If values.size() is a power of 2, create a fresh singleton + if (is_power_of_2(values.size())) { + auto e = m_model.get_fresh_value(elem_sort); + if (!e) + return nullptr; + auto r = u.mk_singleton(e); + register_value(r); + return r; + } + + // Case 3: Find greatest power of 2 N < values.size() and create union + // Find the greatest N that is a power of 2 and N < values.size() + unsigned N = 1; + while (N * 2 < values.size()) + N *= 2; + + auto r = u.mk_union(values.get(values.size() - N), values.get(N)); + register_value(r); + return r; +} \ No newline at end of file diff --git a/src/model/finite_set_factory.h b/src/model/finite_set_factory.h new file mode 100644 index 000000000..d2d73a4b1 --- /dev/null +++ b/src/model/finite_set_factory.h @@ -0,0 +1,29 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + finite_set_value_factory.h + +Abstract: + + Factory for creating finite set values + +--*/ +#pragma once + +#include "model/struct_factory.h" +#include "ast/finite_set_decl_plugin.h" + +/** + \brief Factory for finite set values. +*/ +class finite_set_factory : public struct_factory { + finite_set_util u; +public: + finite_set_factory(ast_manager & m, family_id fid, model_core & md); + + expr * get_some_value(sort * s) override; + + expr * get_fresh_value(sort * s) override; +}; \ No newline at end of file diff --git a/src/model/model.cpp b/src/model/model.cpp index fa4e50e54..3b1769ef1 100644 --- a/src/model/model.cpp +++ b/src/model/model.cpp @@ -40,6 +40,7 @@ Revision History: #include "model/numeral_factory.h" #include "model/fpa_factory.h" #include "model/char_factory.h" +#include "model/finite_set_factory.h" model::model(ast_manager & m): @@ -111,6 +112,7 @@ value_factory* model::get_factory(sort* s) { m_factories.register_plugin(alloc(arith_factory, m)); m_factories.register_plugin(alloc(seq_factory, m, su.get_family_id(), *this)); m_factories.register_plugin(alloc(fpa_value_factory, m, fu.get_family_id())); + m_factories.register_plugin(alloc(finite_set_factory, m, m.get_family_id("finite_set"), *this)); //m_factories.register_plugin(alloc(char_factory, m, char_decl_plugin(m).get_family_id()); } family_id fid = s->get_family_id(); diff --git a/src/model/struct_factory.cpp b/src/model/struct_factory.cpp index 31f86f883..77171a9fa 100644 --- a/src/model/struct_factory.cpp +++ b/src/model/struct_factory.cpp @@ -19,41 +19,39 @@ Revision History: #include "model/struct_factory.h" #include "model/model_core.h" -struct_factory::value_set * struct_factory::get_value_set(sort * s) { +struct_factory::value_set& struct_factory::get_value_set(sort * s) { value_set * set = nullptr; if (!m_sort2value_set.find(s, set)) { - set = alloc(value_set); + set = alloc(value_set, m_model.get_manager()); m_sort2value_set.insert(s, set); m_sorts.push_back(s); m_sets.push_back(set); } SASSERT(set != 0); - return set; + return *set; } struct_factory::struct_factory(ast_manager & m, family_id fid, model_core & md): value_factory(m, fid), m_model(md), - m_values(m), m_sorts(m) { } struct_factory::~struct_factory() { - std::for_each(m_sets.begin(), m_sets.end(), delete_proc()); } void struct_factory::register_value(expr * new_value) { sort * s = new_value->get_sort(); - value_set * set = get_value_set(s); - if (!set->contains(new_value)) { - m_values.push_back(new_value); - set->insert(new_value); + auto& [set, values] = get_value_set(s); + if (!set.contains(new_value)) { + values.push_back(new_value); + set.insert(new_value); } } bool struct_factory::get_some_values(sort * s, expr_ref & v1, expr_ref & v2) { - value_set * set = get_value_set(s); - switch (set->size()) { + auto& [set, values] = get_value_set(s); + switch (set.size()) { case 0: v1 = get_fresh_value(s); v2 = get_fresh_value(s); @@ -63,7 +61,7 @@ bool struct_factory::get_some_values(sort * s, expr_ref & v1, expr_ref & v2) { v2 = get_fresh_value(s); return v2 != 0; default: - obj_hashtable::iterator it = set->begin(); + obj_hashtable::iterator it = set.begin(); v1 = *it; ++it; v2 = *it; diff --git a/src/model/struct_factory.h b/src/model/struct_factory.h index f8235a89b..1ac90ffe8 100644 --- a/src/model/struct_factory.h +++ b/src/model/struct_factory.h @@ -20,6 +20,7 @@ Revision History: #include "model/value_factory.h" #include "util/obj_hashtable.h" +#include "util/scoped_ptr_vector.h" class model_core; @@ -28,16 +29,19 @@ class model_core; */ class struct_factory : public value_factory { protected: - typedef obj_hashtable value_set; - typedef obj_map sort2value_set; + struct value_set { + obj_hashtable set; + expr_ref_vector values; + value_set(ast_manager &m) : values(m) {} + }; + using sort2value_set = obj_map; model_core & m_model; sort2value_set m_sort2value_set; - expr_ref_vector m_values; sort_ref_vector m_sorts; - ptr_vector m_sets; + scoped_ptr_vector m_sets; - value_set * get_value_set(sort * s); + value_set& get_value_set(sort * s); public: struct_factory(ast_manager & m, family_id fid, model_core & md); diff --git a/src/parsers/smt2/smt2parser.cpp b/src/parsers/smt2/smt2parser.cpp index 3f04ad9f0..3601c5a5e 100644 --- a/src/parsers/smt2/smt2parser.cpp +++ b/src/parsers/smt2/smt2parser.cpp @@ -1817,8 +1817,12 @@ namespace smt2 { void check_qualifier(expr * t, bool has_as) { if (has_as) { sort * s = sort_stack().back(); - if (s != t->get_sort()) - throw parser_exception("invalid qualified identifier, sort mismatch"); + if (s != t->get_sort()) { + std::ostringstream str; + str << "sort mismatch in qualified identifier, expected: " << mk_pp(s, m()) + << ", got: " << mk_pp(t->get_sort(), m()); + throw parser_exception(str.str()); + } sort_stack().pop_back(); } } diff --git a/src/smt/CMakeLists.txt b/src/smt/CMakeLists.txt index 98e79f484..a35b809e1 100644 --- a/src/smt/CMakeLists.txt +++ b/src/smt/CMakeLists.txt @@ -56,6 +56,8 @@ z3_add_component(smt theory_char.cpp theory_datatype.cpp theory_dense_diff_logic.cpp + theory_finite_set.cpp + theory_finite_set_size.cpp theory_diff_logic.cpp theory_dl.cpp theory_dummy.cpp diff --git a/src/smt/smt_arith_value.cpp b/src/smt/smt_arith_value.cpp index bc512350d..806598e76 100644 --- a/src/smt/smt_arith_value.cpp +++ b/src/smt/smt_arith_value.cpp @@ -163,4 +163,10 @@ namespace smt { return th->final_check_eh(level); } + lbool arith_value::check_lp_feasible(vector>& ineqs, literal_vector& lit_core, + enode_pair_vector& eq_core) { + if (!m_thr) + return l_undef; + return m_thr->check_lp_feasible(ineqs, lit_core, eq_core); + } }; diff --git a/src/smt/smt_arith_value.h b/src/smt/smt_arith_value.h index 09bd03d29..7e351e43d 100644 --- a/src/smt/smt_arith_value.h +++ b/src/smt/smt_arith_value.h @@ -48,5 +48,7 @@ namespace smt { expr_ref get_up(expr* e) const; expr_ref get_fixed(expr* e) const; final_check_status final_check(unsigned ); + lbool check_lp_feasible(vector> &ineqs, literal_vector &lit_core, + enode_pair_vector &eq_core); }; }; diff --git a/src/smt/smt_clause_proof.cpp b/src/smt/smt_clause_proof.cpp index bc4105e13..324674a99 100644 --- a/src/smt/smt_clause_proof.cpp +++ b/src/smt/smt_clause_proof.cpp @@ -151,15 +151,16 @@ namespace smt { update(st, m_lits, pr); } - void clause_proof::propagate(literal lit, justification const& jst, literal_vector const& ante) { + void clause_proof::propagate(literal lit, justification * jst, literal_vector const& ante) { if (!is_enabled()) return; m_lits.reset(); for (literal l : ante) m_lits.push_back(ctx.literal2expr(~l)); m_lits.push_back(ctx.literal2expr(lit)); - proof_ref pr(m.mk_app(symbol("smt"), 0, nullptr, m.mk_proof_sort()), m); - update(clause_proof::status::th_lemma, m_lits, pr); + auto st = clause_proof::status::th_lemma; + auto pr = justification2proof(st, jst); + update(st, m_lits, pr); } void clause_proof::del(clause& c) { diff --git a/src/smt/smt_clause_proof.h b/src/smt/smt_clause_proof.h index d7cc421cf..28191cfa2 100644 --- a/src/smt/smt_clause_proof.h +++ b/src/smt/smt_clause_proof.h @@ -82,7 +82,7 @@ namespace smt { void add(literal lit1, literal lit2, clause_kind k, justification* j, literal_buffer const* simp_lits = nullptr); void add(clause& c, literal_buffer const* simp_lits = nullptr); void add(unsigned n, literal const* lits, clause_kind k, justification* j); - void propagate(literal lit, justification const& j, literal_vector const& ante); + void propagate(literal lit, justification* j, literal_vector const& ante); void del(clause& c); proof_ref get_proof(bool inconsistent); bool is_enabled() const { return m_enabled; } diff --git a/src/smt/smt_conflict_resolution.cpp b/src/smt/smt_conflict_resolution.cpp index 0d81e8d9e..c8e378936 100644 --- a/src/smt/smt_conflict_resolution.cpp +++ b/src/smt/smt_conflict_resolution.cpp @@ -347,7 +347,7 @@ namespace smt { literal_vector & antecedents = m_tmp_literal_vector; antecedents.reset(); justification2literals_core(js, antecedents); - m_ctx.get_clause_proof().propagate(consequent, *js, antecedents); + m_ctx.get_clause_proof().propagate(consequent, js, antecedents); for (literal l : antecedents) process_antecedent(l, num_marks); (void)consequent; diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index ff4f5b964..dbf0c14de 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -66,7 +66,7 @@ namespace smt { m_progress_callback(nullptr), m_next_progress_sample(0), m_clause_proof(*this), - m_fingerprints(m, m_region), + m_fingerprints(m, get_region()), m_b_internalized_stack(m), m_e_internalized_stack(m), m_l_internalized_stack(m), @@ -120,6 +120,9 @@ namespace smt { if (!m_setup.already_configured()) { m_fparams.updt_params(p); } + for (auto th : m_theory_set) + if (th) + th->updt_params(); } unsigned context::relevancy_lvl() const { @@ -797,7 +800,7 @@ namespace smt { } else { // uncommon case: r2 will have two theory_vars attached to it. - r2->add_th_var(v1, t1, m_region); + r2->add_th_var(v1, t1, get_region()); push_new_th_diseqs(r2, v1, get_theory(t1)); push_new_th_diseqs(r1, v2, get_theory(t2)); } @@ -848,7 +851,7 @@ namespace smt { theory_var v2 = r2->get_th_var(t1); TRACE(merge_theory_vars, tout << get_theory(t1)->get_name() << ": " << v2 << " == " << v1 << "\n"); if (v2 == null_theory_var) { - r2->add_th_var(v1, t1, m_region); + r2->add_th_var(v1, t1, get_region()); push_new_th_diseqs(r2, v1, get_theory(t1)); } l1 = l1->get_next(); @@ -1523,16 +1526,24 @@ namespace smt { } lbool context::find_assignment(expr * n) const { - if (m.is_false(n)) - return l_false; + expr* arg = nullptr; if (m.is_not(n, arg)) { + if (b_internalized(arg)) return ~get_assignment_core(arg); + if (m.is_false(arg)) + return l_true; + if (m.is_true(arg)) + return l_false; return l_undef; } if (b_internalized(n)) return get_assignment(n); + if (m.is_false(n)) + return l_false; + if (m.is_true(n)) + return l_true; return l_undef; } @@ -1938,13 +1949,13 @@ namespace smt { m_scope_lvl++; m_region.push_scope(); + get_trail_stack().push_scope(); m_scopes.push_back(scope()); scope & s = m_scopes.back(); // TRACE(context, tout << "push " << m_scope_lvl << "\n";); m_relevancy_propagator->push(); s.m_assigned_literals_lim = m_assigned_literals.size(); - s.m_trail_stack_lim = m_trail_stack.size(); s.m_aux_clauses_lim = m_aux_clauses.size(); s.m_justifications_lim = m_justifications.size(); s.m_units_to_reassert_lim = m_units_to_reassert.size(); @@ -1960,12 +1971,6 @@ namespace smt { CASSERT("context", check_invariant()); } - /** - \brief Execute generic undo-objects. - */ - void context::undo_trail_stack(unsigned old_size) { - ::undo_trail_stack(m_trail_stack, old_size); - } /** \brief Remove watch literal idx from the given clause. @@ -2452,23 +2457,25 @@ namespace smt { m_relevancy_propagator->pop(num_scopes); m_fingerprints.pop_scope(num_scopes); + + + unassign_vars(s.m_assigned_literals_lim); - undo_trail_stack(s.m_trail_stack_lim); + m_trail_stack.pop_scope(num_scopes); for (theory* th : m_theory_set) th->pop_scope_eh(num_scopes); - del_justifications(m_justifications, s.m_justifications_lim); - m_asserted_formulas.pop_scope(num_scopes); CTRACE(propagate_atoms, !m_atom_propagation_queue.empty(), tout << m_atom_propagation_queue << "\n";); + m_eq_propagation_queue.reset(); m_th_eq_propagation_queue.reset(); + m_region.pop_scope(num_scopes); m_th_diseq_propagation_queue.reset(); m_atom_propagation_queue.reset(); - m_region.pop_scope(num_scopes); m_scopes.shrink(new_lvl); m_conflict_resolution->reset(); @@ -3056,7 +3063,7 @@ namespace smt { del_clauses(m_lemmas, 0); del_justifications(m_justifications, 0); reset_tmp_clauses(); - undo_trail_stack(0); + m_trail_stack.reset(); m_qmanager = nullptr; if (m_is_diseq_tmp) { m_is_diseq_tmp->del_eh(m, false); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index a914c8a70..80ab6803d 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -101,6 +101,7 @@ namespace smt { setup m_setup; unsigned m_relevancy_lvl; timer m_timer; + region m_region; asserted_formulas m_asserted_formulas; th_rewriter m_rewriter; scoped_ptr m_qmanager; @@ -113,7 +114,6 @@ namespace smt { progress_callback * m_progress_callback; unsigned m_next_progress_sample; clause_proof m_clause_proof; - region m_region; fingerprint_set m_fingerprints; expr_ref_vector m_b_internalized_stack; // stack of the boolean expressions already internalized. @@ -153,6 +153,7 @@ namespace smt { vector m_decl2enodes; // decl -> enode (for decls with arity > 0) enode_vector m_empty_vector; cg_table m_cg_table; + struct new_eq { enode * m_lhs; enode * m_rhs; @@ -643,7 +644,6 @@ namespace smt { // // ----------------------------------- protected: - typedef ptr_vector trail_stack; trail_stack m_trail_stack; #ifdef Z3DEBUG bool m_trail_enabled { true }; @@ -653,11 +653,15 @@ namespace smt { template void push_trail(const TrailObject & obj) { SASSERT(m_trail_enabled); - m_trail_stack.push_back(new (m_region) TrailObject(obj)); + m_trail_stack.push(obj); } void push_trail_ptr(trail * ptr) { - m_trail_stack.push_back(ptr); + m_trail_stack.push_ptr(ptr); + } + + trail_stack& get_trail_stack() { + return m_trail_stack; } protected: @@ -667,7 +671,6 @@ namespace smt { unsigned m_search_lvl { 0 }; // It is greater than m_base_lvl when assumptions are used. Otherwise, it is equals to m_base_lvl struct scope { unsigned m_assigned_literals_lim; - unsigned m_trail_stack_lim; unsigned m_aux_clauses_lim; unsigned m_justifications_lim; unsigned m_units_to_reassert_lim; @@ -687,8 +690,6 @@ namespace smt { void pop_scope(unsigned num_scopes); - void undo_trail_stack(unsigned old_size); - void unassign_vars(unsigned old_lim); void remove_watch_literal(clause * cls, unsigned idx); @@ -1021,7 +1022,7 @@ namespace smt { template justification * mk_justification(Justification const & j) { - justification * js = new (m_region) Justification(j); + justification * js = new (get_region()) Justification(j); SASSERT(js->in_region()); if (js->has_del_eh()) m_justifications.push_back(js); diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index 8889409c2..081d12ebd 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -607,7 +607,7 @@ namespace smt { m_lambdas.insert(lam_node, q); m_app2enode.setx(q->get_id(), lam_node, nullptr); m_l_internalized_stack.push_back(q); - m_trail_stack.push_back(&m_mk_lambda_trail); + m_trail_stack.push_ptr(&m_mk_lambda_trail); bool_var bv = get_bool_var(fa); assign(literal(bv, false), nullptr); mark_as_relevant(bv); @@ -958,7 +958,7 @@ namespace smt { m_activity[v] = 0.0; m_case_split_queue->mk_var_eh(v); m_b_internalized_stack.push_back(n); - m_trail_stack.push_back(&m_mk_bool_var_trail); + m_trail_stack.push_ptr(&m_mk_bool_var_trail); m_stats.m_num_mk_bool_var++; SASSERT(check_bool_var_vector_sizes()); return v; @@ -1009,7 +1009,8 @@ namespace smt { CTRACE(cached_generation, generation != m_generation, tout << "cached_generation: #" << n->get_id() << " " << generation << " " << m_generation << "\n";); } - enode * e = enode::mk(m, m_region, m_app2enode, n, generation, suppress_args, merge_tf, m_scope_lvl, cgc_enabled, true); + enode *e = enode::mk(m, get_region(), m_app2enode, n, generation, suppress_args, merge_tf, m_scope_lvl, + cgc_enabled, true); TRACE(mk_enode_detail, tout << "e.get_num_args() = " << e->get_num_args() << "\n";); if (m.is_unique_value(n)) e->mark_as_interpreted(); @@ -1017,7 +1018,7 @@ namespace smt { TRACE(generation, tout << "mk_enode: " << id << " " << generation << "\n";); m_app2enode.setx(id, e, nullptr); m_e_internalized_stack.push_back(n); - m_trail_stack.push_back(&m_mk_enode_trail); + m_trail_stack.push_ptr(&m_mk_enode_trail); m_enodes.push_back(e); if (e->get_num_args() > 0) { if (e->is_true_eq()) { @@ -1859,11 +1860,11 @@ namespace smt { if (old_v == null_theory_var) { enode * r = n->get_root(); theory_var v2 = r->get_th_var(th_id); - n->add_th_var(v, th_id, m_region); + n->add_th_var(v, th_id, get_region()); push_trail(add_th_var_trail(n, th_id)); if (v2 == null_theory_var) { if (r != n) - r->add_th_var(v, th_id, m_region); + r->add_th_var(v, th_id, get_region()); push_new_th_diseqs(r, v, th); } else if (r != n) { diff --git a/src/smt/smt_justification.cpp b/src/smt/smt_justification.cpp index d7b9bdff0..1ae673217 100644 --- a/src/smt/smt_justification.cpp +++ b/src/smt/smt_justification.cpp @@ -351,8 +351,9 @@ namespace smt { proof * ext_theory_propagation_justification::mk_proof(conflict_resolution & cr) { ptr_buffer prs; - if (!antecedent2proof(cr, prs)) + if (!antecedent2proof(cr, prs)) { return nullptr; + } context & ctx = cr.get_context(); ast_manager & m = cr.get_manager(); expr_ref fact(m); diff --git a/src/smt/smt_setup.cpp b/src/smt/smt_setup.cpp index a27bc99f2..69dec1348 100644 --- a/src/smt/smt_setup.cpp +++ b/src/smt/smt_setup.cpp @@ -40,6 +40,7 @@ Revision History: #include "smt/theory_pb.h" #include "smt/theory_fpa.h" #include "smt/theory_polymorphism.h" +#include "smt/theory_finite_set.h" namespace smt { @@ -784,6 +785,10 @@ namespace smt { m_context.register_plugin(alloc(smt::theory_char, m_context)); } + void setup::setup_finite_set() { + m_context.register_plugin(alloc(smt::theory_finite_set, m_context)); + } + void setup::setup_special_relations() { m_context.register_plugin(alloc(smt::theory_special_relations, m_context, m_manager)); } @@ -807,6 +812,7 @@ namespace smt { setup_dl(); setup_seq_str(st); setup_fpa(); + setup_finite_set(); setup_special_relations(); setup_polymorphism(); setup_relevancy(st); @@ -839,6 +845,7 @@ namespace smt { setup_bv(); setup_dl(); setup_seq_str(st); + setup_finite_set(); setup_fpa(); setup_recfuns(); setup_special_relations(); diff --git a/src/smt/smt_setup.h b/src/smt/smt_setup.h index 3d2bf47f3..897755ef7 100644 --- a/src/smt/smt_setup.h +++ b/src/smt/smt_setup.h @@ -102,6 +102,7 @@ namespace smt { void setup_seq_str(static_features const & st); void setup_seq(); void setup_char(); + void setup_finite_set(); void setup_card(); void setup_sls(); void setup_i_arith(); diff --git a/src/smt/smt_theory.h b/src/smt/smt_theory.h index 7050e27dd..abd8ae798 100644 --- a/src/smt/smt_theory.h +++ b/src/smt/smt_theory.h @@ -428,6 +428,8 @@ namespace smt { smt_params const& get_fparams() const; + virtual void updt_params() {} + enode * get_enode(theory_var v) const { SASSERT(v < static_cast(m_var2enode.size())); return m_var2enode[v]; diff --git a/src/smt/theory_datatype.cpp b/src/smt/theory_datatype.cpp index c7be4804c..e59ba7f10 100644 --- a/src/smt/theory_datatype.cpp +++ b/src/smt/theory_datatype.cpp @@ -354,6 +354,7 @@ namespace smt { for (unsigned i = 0; i < num_args; ++i) { enode * arg = e->get_arg(i); sort * s = arg->get_sort(); + sort *e_sort = nullptr; if (m_autil.is_array(s) && m_util.is_datatype(get_array_range(s))) { app_ref def(m_autil.mk_default(arg->get_expr()), m); if (!ctx.e_internalized(def)) { @@ -361,6 +362,13 @@ namespace smt { } arg = ctx.get_enode(def); } + if (m_fsutil.is_finite_set(s, e_sort) && m_util.is_datatype(e_sort)) { + app_ref def(m_fsutil.mk_empty(s), m); + if (!ctx.e_internalized(def)) { + ctx.internalize(def, false); + } + arg = ctx.get_enode(def); + } if (!m_util.is_datatype(s) && !m_sutil.is_seq(s)) continue; if (is_attached_to_var(arg)) @@ -799,8 +807,9 @@ namespace smt { found = true; } sort * s = arg->get_sort(); - if (m_autil.is_array(s) && m_util.is_datatype(get_array_range(s))) { - for (enode* aarg : get_array_args(arg)) { + sort *se = nullptr; + auto add_args = [&](ptr_vector const &args) { + for (enode *aarg : args) { if (aarg->get_root() == child->get_root()) { if (aarg != child) { m_used_eqs.push_back(enode_pair(aarg, child)); @@ -808,17 +817,16 @@ namespace smt { found = true; } } + }; + if (m_autil.is_array(s) && m_util.is_datatype(get_array_range(s))) { + add_args(get_array_args(arg)); + } + if (m_fsutil.is_finite_set(s, se) && m_util.is_datatype(se)) { + add_args(get_finite_set_args(arg)); } - sort* se = nullptr; if (m_sutil.is_seq(s, se) && m_util.is_datatype(se)) { - enode* sibling; - for (enode* aarg : get_seq_args(arg, sibling)) { - if (aarg->get_root() == child->get_root()) { - if (aarg != child) - m_used_eqs.push_back(enode_pair(aarg, child)); - found = true; - } - } + enode *sibling = nullptr; + add_args(get_seq_args(arg, sibling)); if (sibling && sibling != arg) m_used_eqs.push_back(enode_pair(arg, sibling)); @@ -907,6 +915,11 @@ namespace smt { return true; } } + else if (m_fsutil.is_finite_set(s, se) && m_util.is_datatype(se)) { + for (enode *aarg : get_finite_set_args(arg)) + if (process_arg(aarg)) + return true; + } else if (m_autil.is_array(s) && m_util.is_datatype(get_array_range(s))) { for (enode* aarg : get_array_args(arg)) if (process_arg(aarg)) @@ -916,6 +929,33 @@ namespace smt { return false; } + ptr_vector const &theory_datatype::get_finite_set_args(enode *n) { + m_args.reset(); + m_todo.reset(); + auto add_todo = [&](enode *n) { + if (!n->is_marked()) { + n->set_mark(); + m_todo.push_back(n); + } + }; + add_todo(n); + + for (unsigned i = 0; i < m_todo.size(); ++i) { + enode *n = m_todo[i]; + expr *e = n->get_expr(); + if (m_fsutil.is_singleton(e)) + m_args.push_back(n->get_arg(0)); + else if (m_fsutil.is_union(e)) + for (auto k : enode::args(n)) + add_todo(k); + } + for (enode *n : m_todo) + n->unset_mark(); + + return m_args; + } + + ptr_vector const& theory_datatype::get_seq_args(enode* n, enode*& sibling) { m_args.reset(); m_todo.reset(); @@ -1028,6 +1068,7 @@ namespace smt { m_util(m), m_autil(m), m_sutil(m), + m_fsutil(m), m_find(*this) { } diff --git a/src/smt/theory_datatype.h b/src/smt/theory_datatype.h index 7287b7da3..88c3be3a0 100644 --- a/src/smt/theory_datatype.h +++ b/src/smt/theory_datatype.h @@ -21,6 +21,7 @@ Revision History: #include "util/union_find.h" #include "ast/array_decl_plugin.h" #include "ast/seq_decl_plugin.h" +#include "ast/finite_set_decl_plugin.h" #include "ast/datatype_decl_plugin.h" #include "model/datatype_factory.h" #include "smt/smt_theory.h" @@ -60,6 +61,7 @@ namespace smt { datatype_util m_util; array_util m_autil; seq_util m_sutil; + finite_set_util m_fsutil; ptr_vector m_var_data; th_union_find m_find; trail_stack m_trail_stack; @@ -116,6 +118,7 @@ namespace smt { ptr_vector m_args, m_todo; ptr_vector const& get_array_args(enode* n); ptr_vector const& get_seq_args(enode* n, enode*& sibling); + ptr_vector const& get_finite_set_args(enode *n); // class for managing state of final_check class final_check_st { diff --git a/src/smt/theory_finite_set.cpp b/src/smt/theory_finite_set.cpp new file mode 100644 index 000000000..44c51d7dc --- /dev/null +++ b/src/smt/theory_finite_set.cpp @@ -0,0 +1,1037 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + theory_finite_set.cpp + +Abstract: + + Theory solver for finite sets. + Implements axiom schemas for finite set operations. + +--*/ + +#include "smt/theory_finite_set.h" +#include "smt/smt_context.h" +#include "smt/smt_model_generator.h" +#include "smt/smt_arith_value.h" +#include "ast/ast_pp.h" + +namespace smt { + + /** + * Constructor. + * Set up callback that adds axiom instantiations as clauses. + **/ + theory_finite_set::theory_finite_set(context& ctx): + theory(ctx, ctx.get_manager().mk_family_id("finite_set")), + u(m), + m_axioms(m), m_rw(m), m_find(*this), + m_cardinality_solver(*this) + { + // Setup the add_clause callback for axioms + std::function add_clause_fn = + [this](theory_axiom* ax) { + this->add_clause(ax); + }; + m_axioms.set_add_clause(add_clause_fn); + } + + theory_finite_set::~theory_finite_set() { + reset_set_members(); + } + + void theory_finite_set::reset_set_members() { + for (auto [k, s] : m_set_members) + dealloc(s); + m_set_members.reset(); + } + + /** + * When creating a theory variable, we associate extra data structures with it. + * if n := (set.in x S) + * then for every T in the equivalence class of S (including S), assert theory axioms for x in T. + * + * if n := (setop T U) + * then for every (set.in x S) where either S ~ T, S ~ U, assert theory axioms for x in n. + * Since n is fresh there are no other (set.in x S) with S ~ n in the state. + * + * if n := (set.filter p S) + * then for every (set.in x T) where S ~ T, assert theory axiom for x in (set.filter p S) + * + * if n := (set.map f S) + * then for every (set.in x T) where S ~ T, assert theory axiom for (set.in x S) and map. + * In other words, assert + * (set.in (f x) (set.map f S)) + */ + theory_var theory_finite_set::mk_var(enode *n) { + TRACE(finite_set, tout << "mk_var: " << enode_pp(n, ctx) << "\n"); + theory_var r = theory::mk_var(n); + VERIFY(r == static_cast(m_find.mk_var())); + SASSERT(r == static_cast(m_var_data.size())); + m_var_data.push_back(alloc(var_data, m)); + ctx.push_trail(push_back_vector>(m_var_data)); + ctx.push_trail(new_obj_trail(m_var_data.back())); + expr *e = n->get_expr(); + if (u.is_in(e)) { + auto set = n->get_arg(1); + auto v = set->get_root()->get_th_var(get_id()); + SASSERT(v != null_theory_var); + m_var_data[v]->m_parent_in.push_back(n); + ctx.push_trail(push_back_trail(m_var_data[v]->m_parent_in)); + add_in_axioms(n, m_var_data[v]); // add axioms x in S x S ~ T, T := setop, or T is arg of setop. + auto f = to_app(e)->get_decl(); + if (!m_set_in_decls.contains(f)) { + m_set_in_decls.push_back(f); + ctx.push_trail(push_back_vector(m_set_in_decls)); + } + } + else if (u.is_union(e) || u.is_intersect(e) || + u.is_difference(e) || u.is_singleton(e) || + u.is_empty(e) || u.is_range(e) || u.is_filter(e) || u.is_map(e)) { + m_var_data[r]->m_setops.push_back(n); + ctx.push_trail(push_back_trail(m_var_data[r]->m_setops)); + for (auto arg : enode::args(n)) { + expr *e = arg->get_expr(); + if (!u.is_finite_set(e)) + continue; + auto v = arg->get_root()->get_th_var(get_id()); + SASSERT(v != null_theory_var); + // add axioms for x in S, e := setop S T ... + for (auto in : m_var_data[v]->m_parent_in) + add_membership_axioms(in->get_arg(0)->get_expr(), e); + m_var_data[v]->m_parent_setops.push_back(n); + ctx.push_trail(push_back_trail(m_var_data[v]->m_parent_setops)); + } + } + else if (u.is_range(e)) { + + } + else if (u.is_size(e)) { + auto f = to_app(e)->get_decl(); + m_cardinality_solver.add_set_size(f); + } + return r; + } + + trail_stack& theory_finite_set::get_trail_stack() { + return ctx.get_trail_stack(); + } + + /* + * Merge the equivalence classes of two variables. + * parent_in := vector of (set.in x S) terms where S is in the equivalence class of r. + * parent_setops := vector of (set.op S T) where S or T is in the equivalence class of r. + * setops := vector of (set.op S T) where (set.op S T) is in the equivalence class of r. + * + */ + void theory_finite_set::merge_eh(theory_var root, theory_var other, theory_var, theory_var) { + // r is the new root + TRACE(finite_set, tout << "merging v" << root << " v" << other << "\n"; display_var(tout, root); + tout << " <- " << mk_pp(get_enode(other)->get_expr(), m) << "\n";); + SASSERT(root == find(root)); + var_data *d_root= m_var_data[root]; + var_data *d_other = m_var_data[other]; + ctx.push_trail(restore_vector(d_root->m_setops)); + ctx.push_trail(restore_vector(d_root->m_parent_setops)); + ctx.push_trail(restore_vector(d_root->m_parent_in)); + add_in_axioms(root, other); + add_in_axioms(other, root); + d_root->m_setops.append(d_other->m_setops); + d_root->m_parent_setops.append(d_other->m_parent_setops); + d_root->m_parent_in.append(d_other->m_parent_in); + TRACE(finite_set, tout << "after merge\n"; display_var(tout, root);); + } + + /* + * for each (set.in x S) in d1->parent_in, + * add axioms for (set.in x S) + */ + void theory_finite_set::add_in_axioms(theory_var v1, theory_var v2) { + auto d1 = m_var_data[v1]; + auto d2 = m_var_data[v2]; + for (enode *in : d1->m_parent_in) + add_in_axioms(in, d2); + } + + /* + * let (set.in x S) + * + * for each T := (set.op U V) in d2->parent_setops + * then S ~ U or S ~ V by construction + * add axioms for (set.in x T) + * + * for each T := (set.op U V) in d2->setops + * then S ~ T by construction + * add axioms for (set.in x T) + * + */ + + void theory_finite_set::add_in_axioms(enode *in, var_data *d) { + SASSERT(u.is_in(in->get_expr())); + auto e = in->get_arg(0)->get_expr(); + auto set1 = in->get_arg(1); + for (enode *setop : d->m_parent_setops) { + SASSERT( + any_of(enode::args(setop), [&](enode *arg) { return in->get_arg(1)->get_root() == arg->get_root(); })); + add_membership_axioms(e, setop->get_expr()); + } + + for (enode *setop : d->m_setops) { + SASSERT(in->get_arg(1)->get_root() == setop->get_root()); + add_membership_axioms(e, setop->get_expr()); + } + } + + /** + * Boolean atomic formulas for finite sets are one of: + * (set.in x S) + * (set.subset S T) + * When an atomic formula is first created it is to be registered with the solver. + * The internalize_atom method takes care of this. + * Atomic formulas are special cases of terms (of non-Boolean type) so they are registered as terms. + * + */ + bool theory_finite_set::internalize_atom(app * atom, bool gate_ctx) { + return internalize_term(atom); + } + + /** + * When terms are registered with the solver , we need to ensure that: + * - All subterms have an associated E-node + * - Boolean terms are registered as boolean variables + * Registering a Boolean variable ensures that the solver will be notified about its truth value. + * - Non-Boolean terms have an associated theory variable + * Registering a theory variable ensures that the solver will be notified about equalities and disequalites. + * The solver can use the theory variable to track auxiliary information about E-nodes. + */ + bool theory_finite_set::internalize_term(app * term) { + TRACE(finite_set, tout << "internalize_term: " << mk_pp(term, m) << "\n";); + + // Internalize all arguments first + for (expr* arg : *term) + ctx.internalize(arg, false); + + // Create boolean variable for Boolean terms + if (m.is_bool(term) && !ctx.b_internalized(term)) { + bool_var bv = ctx.mk_bool_var(term); + ctx.set_var_theory(bv, get_id()); + } + + // Create enode for the term if needed + enode* e = nullptr; + if (ctx.e_internalized(term)) + e = ctx.get_enode(term); + else + e = ctx.mk_enode(term, false, m.is_bool(term), true); + + // Attach theory variable if this is a set + if (!is_attached_to_var(e)){ + ctx.attach_th_var(e, this, mk_var(e)); + TRACE(finite_set, tout << "create_theory_var: " << e->get_th_var(get_id()) << " enode:" << e->get_expr() << "\n";); + } + + + // Assert immediate axioms + if (!ctx.relevancy()) + add_immediate_axioms(term); + + return true; + } + + void theory_finite_set::relevant_eh(app* t) { + add_immediate_axioms(t); + } + + void theory_finite_set::apply_sort_cnstr(enode* n, sort* s) { + SASSERT(u.is_finite_set(s)); + if (!is_attached_to_var(n)) + ctx.attach_th_var(n, this, mk_var(n)); + } + + void theory_finite_set::new_eq_eh(theory_var v1, theory_var v2) { + TRACE(finite_set, tout << "new_eq_eh: v" << v1 << " = v" << v2 << "\n";); + auto n1 = get_enode(v1); + auto n2 = get_enode(v2); + if (u.is_finite_set(n1->get_expr()) && u.is_finite_set(n2->get_expr())) { + m_eqs.push_back({v1, v2}); + ctx.push_trail(push_back_vector(m_eqs)); + m_find.merge(v1, v2); // triggers merge_eh, which triggers incremental generation of theory axioms + } + + // Check if Z3 has a boolean variable for it + TRACE(finite_set, tout << "new_eq_eh_r1: " << n1->get_root() << "r2: "<< n2->get_root() <<"\n";); + } + + /** + * Every dissequality has to be supported by at distinguishing element. + * + */ + void theory_finite_set::new_diseq_eh(theory_var v1, theory_var v2) { + TRACE(finite_set, tout << "new_diseq_eh: v" << v1 << " != v" << v2 << "\n"); + auto n1 = get_enode(v1); + auto n2 = get_enode(v2); + auto e1 = n1->get_expr(); + auto e2 = n2->get_expr(); + if (u.is_finite_set(e1) && u.is_finite_set(e2)) { + if (e1->get_id() > e2->get_id()) + std::swap(e1, e2); + if (!is_new_axiom(e1, e2)) + return; + if (are_forced_distinct(n1, n2)) + return; + m_diseqs.push_back({v1, v2}); + ctx.push_trail(push_back_vector(m_diseqs)); + m_axioms.extensionality_axiom(e1, e2); + } + } + + // + // TODO: add implementation that detects sets that are known to be distinct. + // for example, + // . x in a is assigned to true and y in b is assigned to false and x ~ y + // . or upper-bound(set.size(a)) < lower-bound(set.size(b)) + // where upper and lower bounds can be queried using arith_value + // + bool theory_finite_set::are_forced_distinct(enode* a, enode* b) { + return false; + } + + /** + * Final check for the finite set theory. + * The Final Check method is called when the solver has assigned truth values to all Boolean variables. + * It is responsible for asserting any remaining axioms and checking for inconsistencies. + * + * It ensures saturation with respect to the theory axioms: + * - membership axioms + * - assume eqs axioms + */ + final_check_status theory_finite_set::final_check_eh(unsigned) { + TRACE(finite_set, tout << "final_check_eh\n";); + + if (activate_unasserted_clause()) + return FC_CONTINUE; + + if (activate_range_local_axioms()) + return FC_CONTINUE; + + if (assume_eqs()) + return FC_CONTINUE; + + switch (m_cardinality_solver.final_check()) { + case l_true: break; + case l_false: return FC_CONTINUE; + case l_undef: return FC_GIVEUP; + } + + return is_fully_solved() ? FC_DONE : FC_GIVEUP; + } + + /** + * Determine if the constraints are fully solved. + * They can be fully solved if: + * - the model that is going to be produced satisfies all constraints + * The model will always satisfy the constraints if: + * - there is no occurrence of set.map + * - there is not both set.size and set.filter + */ + bool theory_finite_set::is_fully_solved() { + bool has_map = false, has_filter = false, has_size = false, has_range = false; + for (unsigned v = 0; v < get_num_vars(); ++v) { + auto n = get_enode(v); + auto e = n->get_expr(); + if (u.is_filter(e)) + has_filter = true; + if (u.is_map(e)) + has_map = true; + if (u.is_size(e)) + has_size = true; + if (u.is_range(e)) + has_range = true; + } + TRACE(finite_set, tout << "has-map " << has_map << " has-filter-size " << has_filter << has_size << "\n"); + if (has_map) + return false; // todo use more expensive model check here + if (has_filter && has_size) + return false; // todo use more expensive model check here + if (has_range && has_size) + return false; // todo use more expensive model check here or even ensure range expressions are handled natively by size. + return true; + } + + + /** + * Add immediate axioms that can be asserted when the atom is created. + * These are unit clauses that can be added immediately. + * - (set.in x set.empty) is false + * - (set.subset S T) <=> (= (set.union S T) T) (or (= (set.intersect S T) S)) + * - (set.singleton x) -> (set.in x (set.singleton x)) + * - (set.range lo hi) -> lo-1,hi+1 not in range, lo, hi in range if lo <= hi * + * + * Other axioms: + * - (set.size s) -> 0 <= (set.size s) <= upper-bound(s) + */ + void theory_finite_set::add_immediate_axioms(app* term) { + expr *elem = nullptr, *set = nullptr; + expr *lo = nullptr, *hi = nullptr; + unsigned sz = m_clauses.axioms.size(); + if (u.is_in(term, elem, set) && u.is_empty(set)) + add_membership_axioms(elem, set); + else if (u.is_subset(term)) + m_axioms.subset_axiom(term); + else if (u.is_singleton(term)) + m_axioms.in_singleton_axiom(term); + else if (u.is_range(term, lo, hi)) { + m_axioms.in_range_axiom(term); + auto range = ctx.get_enode(term); + auto v = range->get_th_var(get_id()); + // declare lo-1, lo, hi, hi+1 as range local. + // we don't have to add additional range local variables for them. + auto &range_local = m_var_data[v]->m_range_local; + ctx.push_trail(push_back_vector(range_local)); + arith_util a(m); + range_local.push_back(lo); + range_local.push_back(hi); + range_local.push_back(a.mk_add(lo, a.mk_int(-1))); + range_local.push_back(a.mk_add(hi, a.mk_int(1))); + } + else if (u.is_size(term)) { + m_axioms.size_lb_axiom(term); + m_axioms.size_ub_axiom(term); + } + + // Assert all new lemmas as clauses + for (unsigned i = sz; i < m_clauses.axioms.size(); ++i) { + m_clauses.squeue.push_back(i); + ctx.push_trail(push_back_vector(m_clauses.squeue)); + } + } + + void theory_finite_set::assign_eh(bool_var v, bool is_true) { + TRACE(finite_set, tout << "assign_eh: v" << v << " is_true: " << is_true << "\n";); + expr *e = ctx.bool_var2expr(v); + TRACE(finite_set, tout << "assign_eh_expr: " << mk_pp(e, m) << "\n";); + + // retrieve the watch list for clauses where e appears with opposite polarity + unsigned idx = 2 * e->get_id() + (is_true ? 1 : 0); + if (idx >= m_clauses.watch.size()) + return; + + // walk the watch list and try to find new watches or propagate + unsigned j = 0; + for (unsigned i = 0; i < m_clauses.watch[idx].size(); ++i) { + TRACE(finite_set, tout << "watch[" << i << "] size: " << m_clauses.watch[i].size() << "\n";); + auto clause_idx = m_clauses.watch[idx][i]; + auto* ax = m_clauses.axioms[clause_idx]; + auto &clause = ax->clause; + if (any_of(clause, [&](expr *lit) { return ctx.find_assignment(lit) == l_true; })) { + TRACE(finite_set, tout << "satisfied\n";); + m_clauses.watch[idx][j++] = clause_idx; + continue; // clause is already satisfied + } + auto is_complement_to = [&](bool is_true, expr* lit, expr* arg) { + if (is_true) + return m.is_not(lit) && to_app(lit)->get_arg(0) == arg; + else + return lit == arg; + }; + auto lit1 = clause.get(0); + auto lit2 = clause.get(1); + auto position = 0; + if (is_complement_to(is_true, lit1, e)) + position = 0; + else { + SASSERT(is_complement_to(is_true, lit2, e)); + position = 1; + } + + bool found_swap = false; + for (unsigned k = 2; k < clause.size(); ++k) { + expr *lit = clause.get(k); + if (ctx.find_assignment(lit) == l_false) + continue; + // found a new watch + clause.swap(position, k); + // std::swap(clause[position], clause[k]); + bool litneg = m.is_not(lit, lit); + auto litid = 2 * lit->get_id() + litneg; + m_clauses.watch.reserve(litid + 1); + m_clauses.watch[litid].push_back(clause_idx); + TRACE(finite_set, tout << "new watch for " << mk_pp(lit, m) << "\n";); + found_swap = true; + break; + } + if (found_swap) + continue; // the clause is removed from this watch list + // either all literals are false, or the other watch literal is propagating. + m_clauses.squeue.push_back(clause_idx); + ctx.push_trail(push_back_vector(m_clauses.squeue)); + TRACE(finite_set, tout << "propagate clause\n";); + m_clauses.watch[idx][j++] = clause_idx; + ++i; + for (; i < m_clauses.watch[idx].size(); ++i) + m_clauses.watch[idx][j++] = m_clauses.watch[idx][i]; + break; + } + m_clauses.watch[idx].shrink(j); + } + + bool theory_finite_set::can_propagate() { + return m_clauses.can_propagate(); + } + + void theory_finite_set::propagate() { + // TRACE(finite_set, tout << "propagate\n";); + ctx.push_trail(value_trail(m_clauses.aqhead)); + ctx.push_trail(value_trail(m_clauses.sqhead)); + while (can_propagate() && !ctx.inconsistent()) { + // activate newly created clauses + while (m_clauses.aqhead < m_clauses.axioms.size()) + activate_clause(m_clauses.aqhead++); + + // empty the propagation queue of clauses to assert + while (m_clauses.sqhead < m_clauses.squeue.size() && !ctx.inconsistent()) { + auto clause_idx = m_clauses.squeue[m_clauses.sqhead++]; + auto ax = m_clauses.axioms[clause_idx]; + assert_clause(ax); + } + } + } + + void theory_finite_set::activate_clause(unsigned clause_idx) { + TRACE(finite_set, tout << "activate_clause: " << clause_idx << "\n";); + auto* ax = m_clauses.axioms[clause_idx]; + auto &clause = ax->clause; + if (any_of(clause, [&](expr *e) { return ctx.find_assignment(e) == l_true; })) + return; + if (clause.size() <= 1) { + m_clauses.squeue.push_back(clause_idx); + ctx.push_trail(push_back_vector(m_clauses.squeue)); + return; + } + expr *w1 = nullptr, *w2 = nullptr; + for (unsigned i = 0; i < clause.size(); ++i) { + expr *lit = clause.get(i); + switch (ctx.find_assignment(lit)) { + case l_true: + UNREACHABLE(); + return; + case l_false: + break; + case l_undef: + if (!w1) { + w1 = lit; + clause.swap(0, i); + } + else if (!w2) { + w2 = lit; + clause.swap(1, i); + } + break; + } + } + if (!w2) { + m_clauses.squeue.push_back(clause_idx); + ctx.push_trail(push_back_vector(m_clauses.squeue)); + return; + } + bool w1neg = m.is_not(w1, w1); + bool w2neg = m.is_not(w2, w2); + unsigned w1id = 2 * w1->get_id() + w1neg; + unsigned w2id = 2 * w2->get_id() + w2neg; + unsigned sz = std::max(w1id, w2id) + 1; + m_clauses.watch.reserve(sz); + m_clauses.watch[w1id].push_back(clause_idx); + m_clauses.watch[w2id].push_back(clause_idx); + + struct unwatch_clause : public trail { + theory_finite_set &th; + unsigned index; + unwatch_clause(theory_finite_set &th, unsigned index) : th(th), index(index) {} + void undo() override { + auto* ax = th.m_clauses.axioms[index]; + auto &clause = ax->clause; + expr *w1 = clause.get(0); + expr *w2 = clause.get(1); + bool w1neg = th.m.is_not(w1, w1); + bool w2neg = th.m.is_not(w2, w2); + unsigned w1id = 2 * w1->get_id() + w1neg; + unsigned w2id = 2 * w2->get_id() + w2neg; + auto &watch1 = th.m_clauses.watch[w1id]; + auto &watch2 = th.m_clauses.watch[w2id]; + watch1.erase(index); + watch2.erase(index); + } + }; + ctx.push_trail(unwatch_clause(*this, clause_idx)); + } + + + + /** + * Saturate with respect to equality sharing: + * - Sets corresponding to shared variables having the same interpretation should also be congruent + */ + bool theory_finite_set::assume_eqs() { + collect_members(); + expr_ref_vector trail(m); // make sure reference counts to union expressions are valid in this scope + obj_map set_reprs; + + auto start = ctx.get_random_value(); + auto sz = get_num_vars(); + for (unsigned w = 0; w < sz; ++w) { + auto v = (w + start) % sz; + enode* n = get_enode(v); + if (!u.is_finite_set(n->get_expr())) + continue; + if (!is_relevant_and_shared(n)) + continue; + auto r = n->get_root(); + // Create a union expression that is canonical (sorted) + if (!m_set_members.contains(r)) + continue; + auto& set = *m_set_members[r]; + ptr_vector elems; + for (auto [e,b] : set) + if (b) + elems.push_back(e->get_expr()); + std::sort(elems.begin(), elems.end(), [](expr *a, expr *b) { return a->get_id() < b->get_id(); }); + expr *s = mk_union(elems.size(), elems.data(), n->get_expr()->get_sort()); + trail.push_back(s); + enode *n2 = nullptr; + if (!set_reprs.find(s, n2)) { + set_reprs.insert(s, r); + continue; + } + if (n2->get_root() == r) + continue; + if (is_new_axiom(n->get_expr(), n2->get_expr()) && assume_eq(n, n2)) { + TRACE(finite_set, + tout << "assume " << mk_pp(n->get_expr(), m) << " = " << mk_pp(n2->get_expr(), m) << "\n";); + return true; + } + } + return false; + } + + app* theory_finite_set::mk_union(unsigned num_elems, expr* const* elems, sort* set_sort) { + app *s = nullptr; + for (unsigned i = 0; i < num_elems; ++i) + s = s ? u.mk_union(s, u.mk_singleton(elems[i])) : u.mk_singleton(elems[i]); + return s ? s : u.mk_empty(set_sort); + } + + + bool theory_finite_set::is_new_axiom(expr* a, expr* b) { + struct insert_obj_pair_table : public trail { + obj_pair_hashtable &table; + expr *a, *b; + insert_obj_pair_table(obj_pair_hashtable &t, expr *a, expr *b) : table(t), a(a), b(b) {} + void undo() override { + table.erase({a, b}); + } + }; + if (m_clauses.members.contains({a, b})) + return false; + m_clauses.members.insert({a, b}); + ctx.push_trail(insert_obj_pair_table(m_clauses.members, a, b)); + return true; + } + + /** + * Instantiate axioms for a given element in a set. + */ + void theory_finite_set::add_membership_axioms(expr *elem, expr *set) { + TRACE(finite_set, tout << "add_membership_axioms: " << mk_pp(elem, m) << " in " << mk_pp(set, m) << "\n";); + + // Instantiate appropriate axiom based on set structure + if (!is_new_axiom(elem, set)) + ; + else if (u.is_empty(set)) + m_axioms.in_empty_axiom(elem); + else if (u.is_singleton(set)) + m_axioms.in_singleton_axiom(elem, set); + else if (u.is_union(set)) + m_axioms.in_union_axiom(elem, set); + else if (u.is_intersect(set)) + m_axioms.in_intersect_axiom(elem, set); + else if (u.is_difference(set)) + m_axioms.in_difference_axiom(elem, set); + else if (u.is_range(set)) + m_axioms.in_range_axiom(elem, set); + else if (u.is_map(set)) { + // sort *elem_sort = u.finte_set_elem_sort(set->get_sort()); + + // set.map_inverse can loop. need to check instance generation. + m_axioms.in_map_axiom(elem, set); + + // this can also loop: + m_axioms.in_map_image_axiom(elem, set); + } + else if (u.is_filter(set)) { + m_axioms.in_filter_axiom(elem, set); + } + } + + void theory_finite_set::add_clause(theory_axiom* ax) { + TRACE(finite_set, tout << "add_clause: " << *ax << "\n"); + ctx.push_trail(push_back_vector(m_clauses.axioms)); + ctx.push_trail(new_obj_trail(ax)); + m_clauses.axioms.push_back(ax); + m_stats.m_num_axioms_created++; + } + + theory * theory_finite_set::mk_fresh(context * new_ctx) { + return alloc(theory_finite_set, *new_ctx); + } + + void theory_finite_set::display(std::ostream & out) const { + out << "theory_finite_set:\n"; + for (unsigned i = 0; i < m_clauses.axioms.size(); ++i) + out << "[" << i << "]: " << *m_clauses.axioms[i] << "\n"; + for (unsigned v = 0; v < get_num_vars(); ++v) + display_var(out, v); + for (unsigned i = 0; i < m_clauses.watch.size(); ++i) { + if (m_clauses.watch[i].empty()) + continue; + out << "watch[" << i << "] := " << m_clauses.watch[i] << "\n"; + } + m_cardinality_solver.display(out); + } + + void theory_finite_set::init_model(model_generator & mg) { + TRACE(finite_set, tout << "init_model\n";); + // Model generation will use default interpretation for sets + // The model will be constructed based on the membership literals that are true + m_factory = alloc(finite_set_factory, m, u.get_family_id(), mg.get_model()); + mg.register_factory(m_factory); + collect_members(); + m_cardinality_solver.init_model(mg); + } + + void theory_finite_set::collect_members() { + // This method can be used to collect all elements that are members of sets + // and ensure that the model factory has values for them. + // For now, we rely on the default model construction. + reset_set_members(); + + for (auto f : m_set_in_decls) { + for (auto n : ctx.enodes_of(f)) { + if (!ctx.is_relevant(n)) + continue; + SASSERT(u.is_in(n->get_expr())); + auto x = n->get_arg(0)->get_root(); + if (x->is_marked()) + continue; + x->set_mark(); // make sure we only do this once per element + for (auto p : enode::parents(x)) { + if (!ctx.is_relevant(p)) + continue; + if (!u.is_in(p->get_expr())) + continue; + bool b = ctx.find_assignment(p->get_expr()) == l_true; + auto set = p->get_arg(1)->get_root(); + auto elem = p->get_arg(0)->get_root(); + if (elem != x) + continue; + if (!m_set_members.contains(set)) { + using om = obj_map; + auto m = alloc(om); + m_set_members.insert(set, m); + } + m_set_members.find(set)->insert(x, b); + } + } + } + for (auto f : m_set_in_decls) { + for (auto n : ctx.enodes_of(f)) { + SASSERT(u.is_in(n->get_expr())); + auto x = n->get_arg(0)->get_root(); + if (x->is_marked()) + x->unset_mark(); + } + } + } + + // to collect range interpretations for S: + // walk parents of S that are (set.in x S) + // evaluate truth value of (set.in x S), evaluate x + // add (eval(x), eval(set.in x S)) into a vector. + // sort the vector by eval(x) + // [(v1, b1), (v2, b2), ..., (vn, bn)] + // for the first i, with b_i true, add the range [vi, v_{i+1}-1]. + // the last bn should be false by construction. + + void theory_finite_set::add_range_interpretation(enode* s) { + vector> elements; + arith_value av(m); + av.init(&ctx); + for (auto p : enode::parents(s)) { + if (!ctx.is_relevant(p)) + continue; + if (u.is_in(p->get_expr())) { + rational val; + auto x = p->get_arg(0)->get_root(); + VERIFY(av.get_value_equiv(x->get_expr(), val) && val.is_int()); + elements.push_back({val, x, ctx.find_assignment(p->get_expr()) == l_true}); + } + } + std::stable_sort(elements.begin(), elements.end(), + [](auto const &a, auto const &b) { return std::get<0>(a) < std::get<0>(b); }); + + + } + + struct finite_set_value_proc : model_value_proc { + theory_finite_set &th; + app_ref m_unique_value; + enode *n = nullptr; + obj_map* m_elements = nullptr; + + // use range interpretations if there is a range constraint and the base sort is integer + bool use_range() { + auto &m = th.m; + sort *base_s = nullptr; + VERIFY(th.u.is_finite_set(n->get_expr()->get_sort(), base_s)); + arith_util a(m); + if (!a.is_int(base_s)) + return false; + func_decl_ref range_fn(th.u.mk_range_decl(), m); + return th.ctx.get_num_enodes_of(range_fn.get()) > 0; + } + + finite_set_value_proc(theory_finite_set &th, enode *n, obj_map *elements) + : th(th), m_unique_value(th.m), n(n), m_elements(elements) {} + + finite_set_value_proc(theory_finite_set &th, app* value) + : th(th), m_unique_value(value, th.m) {} + + void get_dependencies(buffer &result) override { + if (m_unique_value) + return; + if (!m_elements) + return; + bool ur = use_range(); + for (auto [n, b] : *m_elements) + if (!ur || b) + result.push_back(model_value_dependency(n)); + } + + app *mk_value(model_generator &mg, expr_ref_vector const &values) override { + if (m_unique_value) + return m_unique_value; + auto s = n->get_sort(); + if (values.empty()) + return th.u.mk_empty(s); + + SASSERT(m_elements); + if (use_range()) + return mk_range_value(mg, values); + else + return th.mk_union(values.size(), values.data(), s); + } + + app *mk_range_value(model_generator &mg, expr_ref_vector const &values) { + unsigned i = 0; + arith_value av(th.m); + av.init(&th.ctx); + vector> elems; + for (auto [n, b] : *m_elements) { + rational r; + av.get_value(n->get_expr(), r); + elems.push_back({r, n, b}); + } + std::stable_sort(elems.begin(), elems.end(), + [](auto const &a, auto const &b) { return std::get<0>(a) < std::get<0>(b); }); + app *range = nullptr; + arith_util a(th.m); + + for (unsigned j = 0; j < elems.size(); ++j) { + auto [r, n, b] = elems[j]; + if (!b) + continue; + rational lo = r; + rational hi = j + 1 < elems.size() ? std::get<0>(elems[j + 1]) - rational(1) : r; + while (j + 1 < elems.size() && std::get<0>(elems[j + 1]) == hi + rational(1) && std::get<2>(elems[j + 1])) { + hi = std::get<0>(elems[j + 1]); + ++j; + } + auto new_range = th.u.mk_range(a.mk_int(lo), a.mk_int(hi)); + range = range ? th.u.mk_union(range, new_range) : new_range; + } + return range ? range : th.u.mk_empty(n->get_sort()); + } + }; + + model_value_proc * theory_finite_set::mk_value(enode * n, model_generator & mg) { + TRACE(finite_set, tout << "mk_value: " << mk_pp(n->get_expr(), m) << "\n";); + app *value = m_cardinality_solver.get_unique_value(n->get_expr()); + if (value) + return alloc(finite_set_value_proc, *this, value); + obj_map*elements = nullptr; + n = n->get_root(); + m_set_members.find(n, elements); + return alloc(finite_set_value_proc, *this, n, elements); + } + + + /** + * a theory axiom can be unasserted if it contains two or more literals that have + * not been internalized yet. + */ + bool theory_finite_set::activate_unasserted_clause() { + for (auto const &clause : m_clauses.axioms) { + if (assert_clause(clause)) + return true; + } + return false; + } + + /* + * Add x-1, x+1 in range axioms for every x in setop(range, S) + * then x-1, x+1 will also propagate against setop(range, S). + */ + bool theory_finite_set::activate_range_local_axioms() { + bool new_axiom = false; + func_decl_ref range_fn(u.mk_range_decl(), m); + for (auto range : ctx.enodes_of(range_fn.get())) { + SASSERT(u.is_range(range->get_expr())); + auto v = range->get_th_var(get_id()); + for (auto p : m_var_data[v]->m_parent_setops) { + auto w = p->get_th_var(get_id()); + for (auto in : m_var_data[w]->m_parent_in) { + if (activate_range_local_axioms(in->get_arg(0)->get_expr(), range)) + new_axiom = true; + } + } + } + return new_axiom; + } + + + bool theory_finite_set::activate_range_local_axioms(expr* elem, enode* range) { + auto v = range->get_th_var(get_id()); + auto &range_local = m_var_data[v]->m_range_local; + auto &parent_in = m_var_data[v]->m_parent_in; + + // simplify elem to canonical form (e.g., (1+1) -> 2) + expr_ref elem_simplified(elem, m); + m_rw(elem_simplified); + + if (range_local.contains(elem_simplified)) + return false; + arith_util a(m); + expr_ref lo(a.mk_add(elem_simplified, a.mk_int(-1)), m); + expr_ref hi(a.mk_add(elem_simplified, a.mk_int(1)), m); + + // simplify lo and hi to avoid nested expressions like ((1+1)+1) + m_rw(lo); + m_rw(hi); + bool new_axiom = false; + if (!range_local.contains(lo) && all_of(parent_in, [lo](enode* in) { return in->get_arg(0)->get_expr() != lo; })) { + // lo is not range local and lo is not already in an expression (lo in range) + // checking that lo is not in range_local is actually redundant because we will instantiate + // membership expressions for every range local expression. + // but we keep this set and check for now in case we want to change the saturation strategy. + ctx.push_trail(push_back_vector(range_local)); + range_local.push_back(lo); + m_axioms.in_range_axiom(lo, range->get_expr()); + new_axiom = true; + } + if (!range_local.contains(hi) && + all_of(parent_in, [&hi](enode *in) { return in->get_arg(0)->get_expr() != hi; })) { + ctx.push_trail(push_back_vector(range_local)); + range_local.push_back(hi); + m_axioms.in_range_axiom(hi, range->get_expr()); + new_axiom = true; + } + return new_axiom; + } + + bool theory_finite_set::assert_clause(theory_axiom const *ax) { + expr *unit = nullptr; + unsigned undef_count = 0; + auto &clause = ax->clause; + for (auto e : clause) { + switch (ctx.find_assignment(e)) { + case l_true: + return false; // clause is already satisfied + case l_false: + break; + case l_undef: + ++undef_count; + unit = e; + break; + } + } + + if (undef_count == 1) { + TRACE(finite_set, tout << "propagate unit:" << clause << "\n";); + auto lit = mk_literal(unit); + literal_vector antecedent; + for (auto e : clause) { + if (e != unit) + antecedent.push_back(~mk_literal(e)); + } + m_stats.m_num_axioms_propagated++; + enode_pair_vector eqs; + auto just = ext_theory_propagation_justification(get_id(), ctx, antecedent.size(), antecedent.data(), eqs.size(), eqs.data(), + lit, ax->params.size(), ax->params.data()); + auto bjust = ctx.mk_justification(just); + if (ctx.clause_proof_active()) { + // assume all justifications is a non-empty list of symbol parameters + // proof logging is basically broken: it doesn't log propagations, but instead + // only propagations that are processed by conflict resolution. + // this misses conflicts at base level. + proof_ref pr(m); + proof_ref_vector args(m); + for (auto a : antecedent) + args.push_back(m.mk_hypothesis(ctx.literal2expr(a))); + pr = m.mk_th_lemma(get_id(), unit, args.size(), args.data(), ax->params.size(), ax->params.data()); + justification_proof_wrapper jp(ctx, pr.get(), false); + ctx.get_clause_proof().propagate(lit, &jp, antecedent); + jp.del_eh(m); + } + ctx.assign(lit, bjust); + return true; + } + bool is_conflict = (undef_count == 0); + if (is_conflict) + m_stats.m_num_axioms_conflicts++; + else + m_stats.m_num_axioms_case_splits++; + TRACE(finite_set, tout << "assert " << (is_conflict ? "conflict" : "case split") << clause << "\n";); + literal_vector lclause; + for (auto e : clause) + lclause.push_back(mk_literal(e)); + ctx.mk_th_axiom(get_id(), lclause, ax->params.size(), ax->params.data()); + return true; + } + + std::ostream& theory_finite_set::display_var(std::ostream& out, theory_var v) const { + out << "v" << v << " := " << enode_pp(get_enode(v), ctx) << "\n"; + auto d = m_var_data[v]; + if (!d->m_setops.empty()) { + out << " setops: "; + for (auto n : d->m_setops) + out << enode_pp(n, ctx) << " "; + out << "\n"; + } + if (!d->m_parent_setops.empty()) { + out << " parent_setops: "; + for (auto n : d->m_parent_setops) + out << enode_pp(n, ctx) << " "; + out << "\n"; + } + if (!d->m_parent_in.empty()) { + out << " parent_in: "; + for (auto n : d->m_parent_in) + out << enode_pp(n, ctx) << " "; + out << "\n"; + } + + return out; + } + +} // namespace smt diff --git a/src/smt/theory_finite_set.h b/src/smt/theory_finite_set.h new file mode 100644 index 000000000..472249960 --- /dev/null +++ b/src/smt/theory_finite_set.h @@ -0,0 +1,212 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + theory_finite_set.h + +Abstract: + + The theory solver relies on instantiating axiom schemas for finite sets. + The instantation rules can be represented as implementing inference rules + that encode the semantics of set operations. + It reduces satisfiability into a combination of satisfiability of arithmetic and uninterpreted functions. + + This module implements axiom schemas that are invoked by saturating constraints + with respect to the semantics of set operations. + + Let v1 ~ v2 mean that v1 and v2 are congruent + + The set-based decision procedure relies on saturating with respect + to rules of the form: + + x in v1 a term, v1 ~ set.empty + ----------------------------------- + not (x in set.empty) + + + x in v1 a term , v1 ~ v3, v3 := (set.union v4 v5) + ----------------------------------------------- + x in v3 <=> x in v4 or x in v5 + + x in v1 a term, v1 ~ v3, v3 := (set.intersect v4 v5) + --------------------------------------------------- + x in v3 <=> x in v4 and x in v5 + + x in v1 a term, v1 ~ v3, v3 := (set.difference v4 v5) + --------------------------------------------------- + x in v3 <=> x in v4 and not (x in v5) + + x in v1 a term, v1 ~ v3, v3 := (set.singleton v4) + ----------------------------------------------- + x in v3 <=> x == v4 + + x in v1 a term, v1 ~ v3, v3 := (set.range lo hi) + ----------------------------------------------- + x in v3 <=> (lo <= x <= hi) + + x in v1 a term, v1 ~ v3, v3 := (set.map f v4) + ----------------------------------------------- + x in v3 <=> set.map_inverse(f, x, v4) in v4 + + x in v1 a term, v1 ~ v3, v3 := (set.map f v4) + ----------------------------------------------- + x in v4 => f(x) in v3 + + + x in v1 is a term, v1 ~ v3, v3 == (set.filter p v4) + ----------------------------------------------- + x in v3 <=> p(x) and x in v4 + +Rules are encoded in src/ast/rewriter/finite_set_axioms.cpp as clauses. + +The central claim is that the above rules are sufficient to +decide satisfiability of finite set constraints for a subset +of operations, namely union, intersection, difference, singleton, membership. +Model construction proceeds by selecting every set.in(x_i, v) for a +congruence root v. Then the set of elements { x_i | set.in(x_i, v) } +is the interpretation. + +This approach for model-construction, however, does not work with ranges, or is impractical. +For ranges we can adapt a different model construction approach. + +When introducing select and map, decidability can be lost. + + +--*/ + +#pragma once + +#include "ast/ast.h" +#include "ast/ast_pp.h" +#include "ast/finite_set_decl_plugin.h" +#include "ast/rewriter/finite_set_axioms.h" +#include "util/obj_pair_hashtable.h" +#include "util/union_find.h" +#include "smt/smt_theory.h" +#include "smt/theory_finite_set_size.h" +#include "model/finite_set_factory.h" + +namespace smt { + class context; + + class theory_finite_set : public theory { + using th_union_find = union_find; + friend class theory_finite_set_test; + friend class theory_finite_set_size; + friend struct finite_set_value_proc; + + struct var_data { + ptr_vector m_setops; // set operations equivalent to this + ptr_vector m_parent_in; // x in A expressions + ptr_vector m_parent_setops; // set of set expressions where this appears as sub-expression + expr_ref_vector m_range_local; // set of range local variables associated with range + var_data(ast_manager &m) : m_range_local(m) {} + }; + + struct theory_clauses { + ptr_vector axioms; // vector of created theory axioms + unsigned aqhead = 0; // queue head of created axioms + unsigned_vector squeue; // propagation queue of axioms to be added to the solver + unsigned sqhead = 0; // head into propagation queue axioms to be added to solver + obj_pair_hashtable members; // set of membership axioms that were instantiated + vector watch; // watch list from expression index to clause occurrence + + bool can_propagate() const { + return sqhead < squeue.size() || aqhead < axioms.size(); + } + }; + + struct stats { + unsigned m_num_axioms_created = 0; + unsigned m_num_axioms_conflicts = 0; + unsigned m_num_axioms_propagated = 0; + unsigned m_num_axioms_case_splits = 0; + + void collect_statistics(::statistics & st) const { + st.update("finite-set-axioms-created", m_num_axioms_created); + st.update("finite-set-axioms-propagated", m_num_axioms_propagated); + st.update("finite-set-axioms-conflicts", m_num_axioms_conflicts); + st.update("finite-set-axioms-case-splits", m_num_axioms_case_splits); + } + }; + + finite_set_util u; + finite_set_axioms m_axioms; + th_rewriter m_rw; + th_union_find m_find; + theory_clauses m_clauses; + theory_finite_set_size m_cardinality_solver; + finite_set_factory *m_factory = nullptr; + obj_map *> m_set_members; + ptr_vector m_set_in_decls; + ptr_vector m_var_data; + svector> m_diseqs, m_eqs; + stats m_stats; + + protected: + // Override relevant methods from smt::theory + bool internalize_atom(app * atom, bool gate_ctx) override; + bool internalize_term(app * term) override; + void new_eq_eh(theory_var v1, theory_var v2) override; + void new_diseq_eh(theory_var v1, theory_var v2) override; + void apply_sort_cnstr(enode *n, sort *s) override; + final_check_status final_check_eh(unsigned) override; + bool can_propagate() override; + void propagate() override; + void assign_eh(bool_var v, bool is_true) override; + void relevant_eh(app *n) override; + + theory * mk_fresh(context * new_ctx) override; + char const * get_name() const override { return "finite_set"; } + void display(std::ostream & out) const override; + void init_model(model_generator & mg) override; + model_value_proc * mk_value(enode * n, model_generator & mg) override; + theory_var mk_var(enode *n) override; + + void collect_statistics(::statistics & st) const override { + m_stats.collect_statistics(st); + } + + void add_in_axioms(theory_var v1, theory_var v2); + void add_in_axioms(enode *in, var_data *d); + + // Helper methods for axiom instantiation + void add_membership_axioms(expr* elem, expr* set); + void add_clause(theory_axiom * ax); + bool assert_clause(theory_axiom const *ax); + void activate_clause(unsigned index); + bool activate_unasserted_clause(); + void add_immediate_axioms(app *atom); + bool activate_range_local_axioms(); + bool activate_range_local_axioms(expr *elem, enode *range); + bool assume_eqs(); + bool is_new_axiom(expr *a, expr *b); + app *mk_union(unsigned num_elems, expr *const *elems, sort* set_sort); + bool is_fully_solved(); + + // model construction + void collect_members(); + void reset_set_members(); + void add_range_interpretation(enode *s); + + // manage union-find of theory variables + theory_var find(theory_var v) const { return m_find.find(v); } + bool is_root(theory_var v) const { return m_find.is_root(v); } + + std::ostream &display_var(std::ostream &out, theory_var v) const; + + bool are_forced_distinct(enode *a, enode *b); + + public: + theory_finite_set(context& ctx); + ~theory_finite_set() override; + + // for union-find + trail_stack &get_trail_stack(); + void merge_eh(theory_var v1, theory_var v2, theory_var, theory_var); + void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) {} + void unmerge_eh(theory_var v1, theory_var v2) {} + }; + +} // namespace smt diff --git a/src/smt/theory_finite_set_size.cpp b/src/smt/theory_finite_set_size.cpp new file mode 100644 index 000000000..2bae076d6 --- /dev/null +++ b/src/smt/theory_finite_set_size.cpp @@ -0,0 +1,479 @@ + +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + theory_finite_set_size.cpp + +Abstract: + + Theory solver for finite sets. + Implements axiom schemas for finite set operations. + +--*/ + +#include "smt/theory_finite_set.h" +#include "smt/theory_finite_set_size.h" +#include "smt/smt_context.h" +#include "smt/smt_model_generator.h" +#include "smt/smt_arith_value.h" +#include "ast/ast_pp.h" + +namespace smt { + + theory_finite_set_size::theory_finite_set_size(theory_finite_set& th): + m(th.m), ctx(th.ctx), th(th), u(m), bs(m), m_assumption(m), m_slacks(m), m_pinned(m) {} + + void theory_finite_set_size::add_set_size(func_decl* f) { + if (!m_set_size_decls.contains(f)) { + m_set_size_decls.push_back(f); + ctx.push_trail(push_back_trail(m_set_size_decls)); + } + } + + void theory_finite_set_size::initialize_solver() { + struct clear_solver : public trail { + theory_finite_set_size &s; + public: + clear_solver(theory_finite_set_size &s) : s(s) {} + void undo() override { + s.m_solver = nullptr; + s.n2b.reset(); + s.m_assumptions.reset(); + s.bs.reset(); + s.m_slacks.reset(); + s.m_slack_members.reset(); + s.m_pinned.reset(); + s.m_unique_values.reset(); + } + }; + ctx.push_trail(clear_solver(*this)); + m_solver = alloc(context, m, ctx.get_fparams(), ctx.get_params()); + // collect all expressons that use cardinality constraints + // collect cone of influence of sets terminals used in cardinality constraints + // for every visited uninterpreted set variable add a Boolean variable + // for every visited set expression add definitions as constraints to m_cardinality solver + // introduce fresh variables for set membership constraints + // assume that distinct enodes in set memberships produce different sets + // assert added disequalities + // assert equalities based on union-find equalities + // assert set membership constraints by in + + enode_vector ns; + collect_subexpressions(ns); + + // + // we now got all subexpressions from equalities, disequalities, set.in + // + // associate a Boolean variable with every set enode + for (auto n : ns) { + std::ostringstream strm; + strm << "|" << enode_pp(n, ctx) << "|"; + symbol name = symbol(strm.str()); + expr_ref b(m.mk_const(name, m.mk_bool_sort()), m); + bs.push_back(b); + n2b.insert(n, b); + TRACE(finite_set, tout << "assoc " << name << " to " << enode_pp(n, ctx) << " " << enode_pp(n->get_root(), ctx) << "\n";); + } + + add_def_axioms(ns); + // b_{s u t} <-> b_{s} or b_{t}, + // b_{s n t} <-> b_{s} and b_{t}, + // b_{s\t} <-> b_{s} and not b_{t} + add_singleton_axioms(ns); // (set.in x s) -> b_{x} => b_s - for occurrences of (set.in x s) + add_diseq_axioms(ns); // s = t or |s\t| > 0 or |t\s| > 0 - for disqualities + add_eq_axioms(ns); // s = t -> b_{s} <=> b_{t} - for equalities + + TRACE(finite_set, display(tout)); + } + + /** + * For every (supported) set expression ensure associated Boolean expressions follow semantics + */ + void theory_finite_set_size::add_def_axioms(enode_vector const& ns) { + for (auto n : ns) { + expr *e = n->get_expr(); + if (u.is_union(e)) { + auto a = n2b[n->get_arg(0)]; + auto b = n2b[n->get_arg(1)]; + m_solver->assert_expr(m.mk_iff(n2b[n], m.mk_or(a, b))); + } + else if (u.is_intersect(e)) { + auto a = n2b[n->get_arg(0)]; + auto b = n2b[n->get_arg(1)]; + m_solver->assert_expr(m.mk_iff(n2b[n], m.mk_and(a, b))); + } + else if (u.is_difference(e)) { + auto a = n2b[n->get_arg(0)]; + auto not_b = m.mk_not(n2b[n->get_arg(1)]); + m_solver->assert_expr(m.mk_iff(n2b[n], m.mk_and(a, not_b))); + } + } + } + + enode* theory_finite_set_size::mk_singleton(enode* n) { + expr_ref s(u.mk_singleton(n->get_expr()), m); + ctx.ensure_internalized(s); + ctx.mark_as_relevant(s.get()); + return ctx.get_enode(s); + } + + enode* theory_finite_set_size::mk_diff(enode* a, enode* b) { + expr_ref d(u.mk_difference(a->get_expr(), b->get_expr()), m); + ctx.ensure_internalized(d); + ctx.mark_as_relevant(d.get()); + return ctx.get_enode(d); + } + + /** + * For every set membership (set.in x s) track propositional + * (set.in x S) => b_{x} => b_S + * ~(set.in x S) => b_{x} => not b_S + * + * Constrain singletons with cardinality constraints: + * |{x}| = 1 + */ + + void theory_finite_set_size::add_singleton_axioms(enode_vector const &ns) { + for (auto n : ns) { + for (auto p : enode::parents(n)) { + if (!u.is_in(p->get_expr())) + continue; + if (!ctx.is_relevant(p)) + continue; + auto v = ctx.get_assignment(p); + if (v == l_undef) + continue; + auto e = p->get_arg(0)->get_root(); + TRACE(finite_set, tout << "singleton axiom for " << enode_pp(e, ctx) << " in " << enode_pp(p, ctx) + << " is " << v << "\n";); + auto s = mk_singleton(e); + SASSERT(n2b.contains(p->get_arg(1))); + SASSERT(n2b.contains(s)); + auto X = n2b[s]; + auto S = n2b[p->get_arg(1)]; + if (v == l_false) + S = m.mk_not(S); + auto is_in = m.mk_implies(X, S); + in lit(p, v == l_true); + std::ostringstream strm; + strm << "|" << (v == l_false ? "~":"") << enode_pp(p, ctx) << "|"; + symbol name = symbol(strm.str()); + expr* t = m.mk_const(name, m.mk_bool_sort()); + bs.push_back(t); + m_assumptions.insert(t, lit); + m_solver->assert_expr(m.mk_implies(t, is_in)); + + // add size axiom |s| = 1 + arith_util a(m); + auto l = th.mk_literal(m.mk_eq(u.mk_size(s->get_expr()), a.mk_int(1))); + ctx.mk_th_axiom(th.get_id(), l); + } + } + } + + /** + * For every asserted equality ensure equivalence + */ + void theory_finite_set_size::add_eq_axioms(enode_vector const &ns) { + for (auto [a, b] : th.m_eqs) { + auto x = th.get_enode(a); + auto y = th.get_enode(b); + if (n2b.contains(x) && n2b.contains(y)) { + eq e = {a, b}; + std::ostringstream strm; + strm << "|" << enode_pp(x, ctx) << " == " << enode_pp(y, ctx) << "|"; + symbol name = symbol(strm.str()); + auto t = m.mk_const(name, m.mk_bool_sort()); + bs.push_back(t); + m_assumptions.insert(t, e); + m_solver->assert_expr(m.mk_implies(t, m.mk_iff(n2b[x], n2b[y]))); + } + } + } + + /* + * For every disequality include the assertions x = y or |x\y| >= 1 or |y\z| >= 1 + * The expressions x\y and y\x are created when ns is created. + */ + void theory_finite_set_size::add_diseq_axioms(enode_vector const &ns) { + for (auto [a, b] : th.m_diseqs) { + auto x = th.get_enode(a); + auto y = th.get_enode(b); + diseq d = {a, b}; + if (n2b.contains(x) && n2b.contains(y)) { + arith_util a(m); + auto d1 = mk_diff(x, y); + auto d2 = mk_diff(y, x); + expr_ref sz1(u.mk_size(d1->get_expr()), m); + expr_ref sz2(u.mk_size(d2->get_expr()), m); + literal l_eq = th.mk_literal(m.mk_eq(x->get_expr(), y->get_expr())); + literal l1 = th.mk_literal(a.mk_ge(sz1, a.mk_int(1))); + literal l2 = th.mk_literal(a.mk_ge(sz2, a.mk_int(1))); + ctx.mk_th_axiom(th.get_id(), l_eq, l1, l2); + } + } + } + + /** + * Walk the cone of influence of expresions that depend on ns + */ + void theory_finite_set_size::collect_subexpressions(enode_vector &ns) { + // initialize disequality watch list + u_map v2diseqs, v2eqs; + for (auto [a, b] : th.m_diseqs) { + v2diseqs.insert_if_not_there(a, unsigned_vector()).push_back(b); + v2diseqs.insert_if_not_there(b, unsigned_vector()).push_back(a); + } + for (auto [a, b] : th.m_eqs) { + v2eqs.insert_if_not_there(a, unsigned_vector()).push_back(b); + v2eqs.insert_if_not_there(b, unsigned_vector()).push_back(a); + } + + auto add_expression = [&](enode *e) { + if (!ctx.is_relevant(e)) + return; + if (e->is_marked()) + return; + e->set_mark(); + ns.push_back(e); + }; + + auto is_setop = [&](enode *n) { + auto e = n->get_expr(); + return u.is_union(e) || u.is_intersect(e) || u.is_difference(e); + }; + + for (auto f : m_set_size_decls) { + for (auto n : ctx.enodes_of(f)) { + SASSERT(u.is_size(n->get_expr())); + add_expression(n->get_arg(0)); + } + } + for (unsigned i = 0; i < ns.size(); ++i) { + auto n = ns[i]; + // add children under set operations + if (is_setop(n)) + for (auto arg : enode::args(n)) + add_expression(arg); + // add parents that are operations and use n + for (auto p : enode::parents(n)) + if (is_setop(p) && any_of(enode::args(p), [n](auto a) { return a == n; })) + add_expression(p); + // add equalities and disequalities + auto v = th.get_th_var(n); + if (v2eqs.contains(v)) { + auto const &other = v2eqs[v]; + for (auto w : other) + add_expression(th.get_enode(w)); + } + if (v2diseqs.contains(v)) { + auto const &other = v2diseqs[v]; + for (auto w : other) { + auto n2 = th.get_enode(w); + add_expression(n2); + auto D1 = mk_diff(n, n2); + auto D2 = mk_diff(n2, n); + ctx.mark_as_relevant(D1->get_expr()); + ctx.mark_as_relevant(D2->get_expr()); + add_expression(D1); + add_expression(D2); + } + } + for (auto p : enode::parents(n)) { + if (!u.is_in(p->get_expr())) + continue; + if (!ctx.is_relevant(p)) + continue; + auto x = p->get_arg(0)->get_root(); + auto X = mk_singleton(x); + ctx.mark_as_relevant(X->get_expr()); + add_expression(X); + } + } + for (auto n : ns) + n->unset_mark(); + } + + + /** + * 1. Base implementation: + * Enumerate all satisfying assignments to m_solver for atoms based on |s| + * Extract Core from enumeration + * Assert Core => |s_i| = sum_ij n_ij for each |s_i| cardinality expression + * NB. Soundness of using Core has not been rigorously established. + * 2. We can check with theory_lra if slack_sums constraints are linear + * feasible. If they are we can possibly terminate by extracting a model + * If they are infeasible, temporarily strengthen m_solver using the negation of unsat core + * that comes from infeasible slack propositions. Then the next model releaxes at least one + * slack variable that is part of the infeasible subset. + */ + lbool theory_finite_set_size::run_solver() { + expr_ref_vector asms(m); + for (auto [k, v] : m_assumptions) + asms.push_back(k); + + m_slacks.reset(); + m_slack_members.reset(); + expr_ref_vector slack_exprs(m); + obj_map slack_sums; + arith_util a(m); + expr_ref z(a.mk_int(0), m); + for (auto f : m_set_size_decls) + for (auto n : ctx.enodes_of(f)) + slack_sums.insert(n->get_expr(), z); + + while (true) { + lbool r = m_solver->check(asms.size(), asms.data()); + if (r == l_false) { + auto const& core = m_solver->unsat_core(); + literal_vector lits; + for (auto c : core) { + auto exp = m_assumptions[c]; + if (std::holds_alternative(exp)) { + auto [a, b] = std::get(exp); + expr_ref eq(m.mk_eq(th.get_expr(a), th.get_expr(b)), m); + lits.push_back(~th.mk_literal(eq)); + } + else if (std::holds_alternative(exp)) { + auto [a, b] = std::get(exp); + expr_ref eq(m.mk_eq(th.get_expr(a), th.get_expr(b)), m); + lits.push_back(th.mk_literal(eq)); + } + else if (std::holds_alternative(exp)) { + auto [n, is_pos] = std::get(exp); + auto lit = th.mk_literal(n->get_expr()); + lits.push_back(is_pos ? ~lit : lit); + } + } + for (auto f : m_set_size_decls) { + for (auto n : ctx.enodes_of(f)) { + expr_ref eq(m.mk_eq(n->get_expr(), slack_sums[n->get_expr()]), m); + auto lit = th.mk_literal(eq); + literal_vector lemma(lits); + lemma.push_back(lit); + TRACE(finite_set, tout << "Asserting cardinality lemma\n"; + for (auto lit : lemma) tout << ctx.literal2expr(lit) << "\n";); + ctx.mk_th_axiom(th.get_id(), lemma); + } + } + ctx.push_trail(value_trail(m_solver_ran)); + TRACE(finite_set, ctx.display(tout << "Core " << core << "\n")); + m_solver_ran = true; + return l_false; + } + if (r != l_true) + return r; + + expr_ref slack(m.mk_fresh_const(symbol("slack"), a.mk_int()), m); + ctx.mk_th_axiom(th.get_id(), th.mk_literal(a.mk_ge(slack, a.mk_int(0)))); // slack is non-negative + model_ref mdl; + m_solver->get_model(mdl); + + + expr_ref_vector props(m); + for (auto f : m_set_size_decls) { + for (auto n : ctx.enodes_of(f)) { + auto arg = n->get_arg(0); + auto b = n2b[arg]; + auto b_is_true = mdl->is_true(b); + props.push_back(b_is_true ? b : m.mk_not(b)); + if (b_is_true) { + auto s = slack_sums[n->get_expr()]; + s = s == z ? slack : a.mk_add(s, slack); + slack_exprs.push_back(s); + slack_sums.insert(n->get_expr(), s); + } + } + } + m_slacks.push_back(slack); + ptr_vector members; + for (auto [n, b] : n2b) { + expr *e = n->get_expr(); + if (is_uninterp_const(e) && mdl->is_true(b)) + members.push_back(e); + } + m_slack_members.push_back(members); + TRACE(finite_set, tout << *mdl << "\nPropositional model:\n" << props << "\n"); + m_solver->assert_expr(m.mk_not(m.mk_and(props))); + } + return l_undef; + } + + lbool theory_finite_set_size::final_check() { + if (m_set_size_decls.empty()) + return l_true; + if (!m_solver) { + initialize_solver(); + return l_false; + } + if (!m_solver_ran) + return run_solver(); + + // + // at this point we assume that + // cardinality constraints are satisfied + // by arithmetic solver. + // + // a refinement checks if this is really necessary + // + return l_true; + } + + // + // construct model based on set variables that have cardinality constraints + // In this case the model construction is not explicit. It uses unique sets + // to represent sets of given cardinality. + // + void theory_finite_set_size::init_model(model_generator &mg) { + if (!m_solver || !m_solver_ran) + return; + TRACE(finite_set, tout << "Constructing model for finite set cardinalities\n";); + // + // construct model based on set variables that have cardinality constraints + // slack -> (set variable x truth assignment)* + // slack -> integer assignment from arithmetic solver + // u.mk_unique_set(unique_index, slack_value, type); + // add to model of set variables that are true for slack. + // + arith_value av(m); + av.init(&ctx); + rational value; + arith_util a(m); + SASSERT(m_slacks.size() == m_slack_members.size()); + unsigned unique_index = 0; + for (unsigned i = 0; i < m_slacks.size(); ++i) { + auto s = m_slacks.get(i); + // + // slack s is equivalent to some integer value + // create a unique set corresponding to this slack value. + // The signature of the unique set is given by the sets that are + // satisfiable in the propositional assignment where the slack variable + // was introduced. + // + if (av.get_value_equiv(s, value)) { + if (value == 0) + continue; + if (m_slack_members[i].empty()) + continue; + + ++unique_index; + for (auto e : m_slack_members[i]) { + app *unique_value = u.mk_unique_set(a.mk_int(unique_index), a.mk_int(value), e->get_sort()); + if (m_unique_values.contains(e)) + unique_value = u.mk_union(m_unique_values[e], unique_value); + m_unique_values.insert(e, unique_value); + m_pinned.push_back(unique_value); + } + } + } + } + + + std::ostream& theory_finite_set_size::display(std::ostream& out) const { + if (m_solver) + m_solver->display(out << "set.size-solver\n"); + return out; + } +} // namespace smt \ No newline at end of file diff --git a/src/smt/theory_finite_set_size.h b/src/smt/theory_finite_set_size.h new file mode 100644 index 000000000..2be7f1f0d --- /dev/null +++ b/src/smt/theory_finite_set_size.h @@ -0,0 +1,79 @@ + +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + theory_finite_set_size.h + +Abstract: + + sub-solver for cardinality constraints of finite sets + +--*/ + +#pragma once + +#include "ast/ast.h" +#include "ast/ast_pp.h" +#include "ast/finite_set_decl_plugin.h" +#include "ast/rewriter/finite_set_axioms.h" +#include "util/obj_pair_hashtable.h" +#include "util/union_find.h" +#include "smt/smt_theory.h" +#include "model/finite_set_factory.h" + +namespace smt { + class context; + class theory_finite_set; + + class theory_finite_set_size { + struct diseq { + theory_var x, y; + }; + struct eq { + theory_var x, y; + }; + struct in { + enode *n; + bool is_pos; + }; + using tracking_literal = std::variant; + ast_manager &m; + context &ctx; + theory_finite_set &th; + finite_set_util u; + scoped_ptr m_solver; + bool m_solver_ran = false; + ptr_vector m_set_size_decls; + expr_ref_vector bs; + obj_map n2b; + obj_map m_assumptions; + expr_ref m_assumption; + expr_ref_vector m_slacks; + vector> m_slack_members; + obj_map m_unique_values; + app_ref_vector m_pinned; + + void collect_subexpressions(enode_vector& ns); + void add_def_axioms(enode_vector const &ns); + void add_singleton_axioms(enode_vector const &ns); + void add_eq_axioms(enode_vector const &ns); + void add_diseq_axioms(enode_vector const &ns); + enode *mk_singleton(enode* n); + enode *mk_diff(enode *a, enode *b); + void initialize_solver(); + + lbool run_solver(); + + public: + theory_finite_set_size(theory_finite_set &th); + void add_set_size(func_decl *f); + lbool final_check(); + std::ostream &display(std::ostream &out) const; + void init_model(model_generator &mg); + app *get_unique_value(expr *n) { + return m_unique_values.contains(n) ? m_unique_values[n] : nullptr; + } + }; +} \ No newline at end of file diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index b3b2afce4..b5f249a00 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -870,15 +870,10 @@ public: get_zero(true); get_zero(false); + lp().updt_params(ctx().get_params()); lp().settings().set_resource_limit(m_resource_limit); lp().settings().bound_propagation() = bound_prop_mode::BP_NONE != propagation_mode(); - - // todo : do not use m_arith_branch_cut_ratio for deciding on cheap cuts - unsigned branch_cut_ratio = ctx().get_fparams().m_arith_branch_cut_ratio; - lp().set_cut_strategy(branch_cut_ratio); - - lp().settings().set_run_gcd_test(ctx().get_fparams().m_arith_gcd_test); lp().settings().set_random_seed(ctx().get_fparams().m_random_seed); m_lia = alloc(lp::int_solver, *m_solver.get()); } @@ -890,6 +885,9 @@ public: mk_is_int_axiom(n); } + ptr_vector m_delay_ineqs; + unsigned m_delay_ineqs_qhead = 0; + bool internalize_atom(app * atom, bool gate_ctx) { TRACE(arith_internalize, tout << bpp(atom) << "\n";); SASSERT(!ctx().b_internalized(atom)); @@ -920,6 +918,11 @@ public: internalize_is_int(atom); return true; } + else if (a.is_le(atom) || a.is_ge(atom)) { + m_delay_ineqs.push_back(atom); + ctx().push_trail(push_back_vector>(m_delay_ineqs)); + return true; + } else { TRACE(arith, tout << "Could not internalize " << mk_pp(atom, m) << "\n";); found_unsupported(atom); @@ -1629,6 +1632,61 @@ public: return FC_DONE; return FC_GIVEUP; } + + /** + * Check if a set of equalities are lp feasible. + * push local scope + * internalize ineqs + * assert ineq constraints + * check lp feasibility + * extract core + * pop local scope + * return verdict + */ + + lbool check_lp_feasible(vector> &ineqs, literal_vector& lit_core, enode_pair_vector& eq_core) { + lbool st = l_undef; + push_scope_eh(); // pushes an arithmetic scope + u_map ci2index; + unsigned index = 0; + for (auto &[in_core, f] : ineqs) { + expr *x, *y; + rational r; + in_core = false; + if (m.is_eq(f, x, y) && a.is_numeral(y, r)) { + internalize_term(to_app(x)); + auto j = get_lpvar(th.get_th_var(x)); + auto ci = lp().add_var_bound(j, lp::EQ, r); + ci2index.insert(ci, index); + lp().activate(ci); + if (is_infeasible()) { + st = l_false; + break; + } + } + else { + NOT_IMPLEMENTED_YET(); + } + ++index; + } + if (st != l_false) { + st = make_feasible(); + SASSERT(st != l_false || is_infeasible()); + } + if (st == l_false) { + m_explanation.clear(); + lp().get_infeasibility_explanation(m_explanation); + for (auto ev : m_explanation) { + unsigned index; + if (ci2index.find(ev.ci(), index)) + ineqs[index].first = true; + else + set_evidence(ev.ci(), lit_core, eq_core); + } + } + pop_scope_eh(1); + return st; + } final_check_status final_check_eh(unsigned level) { if (propagate_core()) @@ -2126,6 +2184,8 @@ public: unsigned total_conflicts = ctx().get_num_conflicts(); if (total_conflicts < 10) return true; + if (m_delay_ineqs_qhead < m_delay_ineqs.size()) + return true; double f = static_cast(m_num_conflicts)/static_cast(total_conflicts); return f >= adaptive_assertion_threshold(); } @@ -2135,7 +2195,8 @@ public: } bool can_propagate_core() { - return m_asserted_atoms.size() > m_asserted_qhead || m_new_def || lp().has_changed_columns(); + return m_asserted_atoms.size() > m_asserted_qhead || m_new_def || lp().has_changed_columns() || + m_delay_ineqs_qhead < m_delay_ineqs.size(); } bool propagate() { @@ -2150,6 +2211,29 @@ public: return true; if (!can_propagate_core()) return false; + + for (; m_delay_ineqs_qhead < m_delay_ineqs.size() && !ctx().inconsistent() && m.inc(); ++m_delay_ineqs_qhead) { + auto atom = m_delay_ineqs[m_delay_ineqs_qhead]; + ctx().push_trail(value_trail(m_delay_ineqs_qhead)); + if (!ctx().is_relevant(atom)) + continue; + expr *x, *y; + if (a.is_le(atom, x, y)) { + auto lit1 = mk_literal(atom); + auto lit2 = mk_literal(a.mk_le(a.mk_sub(x, y), a.mk_numeral(rational(0), a.is_int(x->get_sort())))); + mk_axiom(~lit1, lit2); + mk_axiom(lit1, ~lit2); + } + else if (a.is_ge(atom, x, y)) { + auto lit1 = mk_literal(atom); + auto lit2 = mk_literal(a.mk_ge(a.mk_sub(x, y), a.mk_numeral(rational(0), a.is_int(x->get_sort())))); + mk_axiom(~lit1, lit2); + mk_axiom(lit1, ~lit2); + } + else { + UNREACHABLE(); + } + } m_new_def = false; while (m_asserted_qhead < m_asserted_atoms.size() && !ctx().inconsistent() && m.inc()) { @@ -4201,6 +4285,13 @@ public: m_bound_predicate = nullptr; } + void updt_params() { + if (m_solver) + m_solver->updt_params(ctx().get_params()); + if (m_nla) + m_nla->updt_params(ctx().get_params()); + } + void validate_model(proto_model& mdl) { @@ -4361,10 +4452,18 @@ void theory_lra::setup() { m_imp->setup(); } +void theory_lra::updt_params() { + m_imp->updt_params(); +} + void theory_lra::validate_model(proto_model& mdl) { m_imp->validate_model(mdl); } +lbool theory_lra::check_lp_feasible(vector>& ineqs, literal_vector& lit_core, enode_pair_vector& eq_core) { + return m_imp->check_lp_feasible(ineqs, lit_core, eq_core); +} + } template class lp::lp_bound_propagator; template void lp::lar_solver::propagate_bounds_for_touched_rows(lp::lp_bound_propagator&); diff --git a/src/smt/theory_lra.h b/src/smt/theory_lra.h index 8804d52eb..fb1a16b15 100644 --- a/src/smt/theory_lra.h +++ b/src/smt/theory_lra.h @@ -98,6 +98,14 @@ namespace smt { bool get_lower(enode* n, rational& r, bool& is_strict); bool get_upper(enode* n, rational& r, bool& is_strict); void solve_for(vector& s) override; + + + // check if supplied set of linear constraints are LP feasible within current backtracking context + // identify core by setting Boolean flags to true for constraints used in the proof of infeasibility + // and return l_false if infeasible. + lbool check_lp_feasible(vector> &ineqs, literal_vector& lit_core, enode_pair_vector& eq_core); + + void updt_params() override; void display(std::ostream & out) const override; diff --git a/src/solver/smt_logics.cpp b/src/solver/smt_logics.cpp index 0942ed3fe..a02b90880 100644 --- a/src/solver/smt_logics.cpp +++ b/src/solver/smt_logics.cpp @@ -24,8 +24,8 @@ Revision History: bool smt_logics::supported_logic(symbol const & s) { return logic_has_uf(s) || logic_is_all(s) || logic_has_fd(s) || logic_has_arith(s) || logic_has_bv(s) || - logic_has_array(s) || logic_has_seq(s) || logic_has_str(s) || - logic_has_horn(s) || logic_has_fpa(s) || logic_has_datatype(s); + logic_has_array(s) || logic_has_seq(s) || logic_has_str(s) || logic_has_horn(s) || logic_has_fpa(s) || + logic_has_datatype(s) || logic_has_finite_sets(s); } bool smt_logics::logic_has_reals_only(symbol const& s) { @@ -71,6 +71,13 @@ bool smt_logics::logic_has_bv(symbol const & s) { str == "HORN"; } +bool smt_logics::logic_has_finite_sets(symbol const &s) { + auto str = s.str(); + return + str.find("FS") != std::string::npos || + logic_is_all(s); +} + bool smt_logics::logic_has_array(symbol const & s) { auto str = s.str(); return diff --git a/src/solver/smt_logics.h b/src/solver/smt_logics.h index 80bebabcc..9a32e5708 100644 --- a/src/solver/smt_logics.h +++ b/src/solver/smt_logics.h @@ -27,6 +27,7 @@ public: static bool logic_has_arith(symbol const & s); static bool logic_has_bv(symbol const & s); static bool logic_has_array(symbol const & s); + static bool logic_has_finite_sets(symbol const &s); static bool logic_has_seq(symbol const & s); static bool logic_has_str(symbol const & s); static bool logic_has_fpa(symbol const & s); diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index dcf0856bb..31dd88078 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -56,6 +56,8 @@ add_executable(test-z3 expr_substitution.cpp ext_numeral.cpp f2n.cpp + finite_set.cpp + finite_set_rewriter.cpp factor_rewriter.cpp finder.cpp fixed_bit_vector.cpp diff --git a/src/test/finite_set.cpp b/src/test/finite_set.cpp new file mode 100644 index 000000000..0d113b15d --- /dev/null +++ b/src/test/finite_set.cpp @@ -0,0 +1,286 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + tst_finite_set.cpp + +Abstract: + + Test finite sets decl plugin + +Author: + + GitHub Copilot Agent 2025 + +Revision History: + +--*/ +#include "ast/ast.h" +#include "ast/finite_set_decl_plugin.h" +#include "ast/reg_decl_plugins.h" +#include "ast/arith_decl_plugin.h" +#include "ast/array_decl_plugin.h" + +static void tst_finite_set_basic() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + arith_util arith(m); + + // Test creating a finite set sort + sort_ref int_sort(arith.mk_int(), m); + parameter param(int_sort.get()); + sort_ref finite_set_int(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, ¶m), m); + + ENSURE(fsets.is_finite_set(finite_set_int.get())); + + // Test creating empty set + app_ref empty_set(fsets.mk_empty(finite_set_int), m); + ENSURE(fsets.is_empty(empty_set.get())); + ENSURE(empty_set->get_sort() == finite_set_int.get()); + + // Test set.singleton + expr_ref five(arith.mk_int(5), m); + app_ref singleton_set(fsets.mk_singleton(five), m); + ENSURE(fsets.is_singleton(singleton_set.get())); + ENSURE(singleton_set->get_sort() == finite_set_int.get()); + + // Test set.range + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref range_set(fsets.mk_range(zero, ten), m); + ENSURE(fsets.is_range(range_set.get())); + ENSURE(range_set->get_sort() == finite_set_int.get()); + + // Test set.union + app_ref union_set(fsets.mk_union(empty_set, range_set), m); + ENSURE(fsets.is_union(union_set.get())); + ENSURE(union_set->get_sort() == finite_set_int.get()); + + // Test set.intersect + app_ref intersect_set(fsets.mk_intersect(range_set, range_set), m); + ENSURE(fsets.is_intersect(intersect_set.get())); + ENSURE(intersect_set->get_sort() == finite_set_int.get()); + + // Test set.difference + app_ref diff_set(fsets.mk_difference(range_set, empty_set), m); + ENSURE(fsets.is_difference(diff_set.get())); + ENSURE(diff_set->get_sort() == finite_set_int.get()); + + // Test set.in + app_ref in_expr(fsets.mk_in(five, range_set), m); + ENSURE(fsets.is_in(in_expr.get())); + ENSURE(m.is_bool(in_expr->get_sort())); + + // Test set.size + app_ref size_expr(fsets.mk_size(range_set), m); + ENSURE(fsets.is_size(size_expr.get())); + ENSURE(arith.is_int(size_expr->get_sort())); + + // Test set.subset + app_ref subset_expr(fsets.mk_subset(empty_set, range_set), m); + ENSURE(fsets.is_subset(subset_expr.get())); + ENSURE(m.is_bool(subset_expr->get_sort())); +} + +static void tst_finite_set_map_filter() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + arith_util arith(m); + array_util autil(m); + + // Create Int and Bool sorts + sort_ref int_sort(arith.mk_int(), m); + sort_ref bool_sort(m.mk_bool_sort(), m); + + // Create finite set sorts + parameter int_param(int_sort.get()); + sort_ref finite_set_int(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, &int_param), m); + + // Create Array (Int Int) sort for map + sort_ref arr_int_int(autil.mk_array_sort(int_sort, int_sort), m); + + // Create a const array (conceptually represents the function) + app_ref arr_map(autil.mk_const_array(arr_int_int, arith.mk_int(42)), m); + + // Create a set and test map + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref range_set(fsets.mk_range(zero, ten), m); + + app_ref mapped_set(fsets.mk_map(arr_map, range_set), m); + ENSURE(fsets.is_map(mapped_set.get())); + ENSURE(fsets.is_finite_set(mapped_set->get_sort())); + + // Create Array (Int Bool) sort for filter + sort_ref arr_int_bool(autil.mk_array_sort(int_sort, bool_sort), m); + + // Create a const array for filter (conceptually represents predicate) + app_ref arr_filter(autil.mk_const_array(arr_int_bool, m.mk_true()), m); + + app_ref filtered_set(fsets.mk_filter(arr_filter, range_set), m); + ENSURE(fsets.is_filter(filtered_set.get())); + ENSURE(filtered_set->get_sort() == finite_set_int.get()); +} + +static void tst_finite_set_is_value() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + arith_util arith(m); + finite_set_decl_plugin* plugin = static_cast(m.get_plugin(fsets.get_family_id())); + + // Create Int sort and finite set sort + + // Test with Int sort (should be fully interpreted) + sort_ref int_sort(arith.mk_int(), m); + parameter int_param(int_sort.get()); + sort_ref finite_set_int(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, &int_param), m); + + + // Test 1: Empty set is a value + app_ref empty_set(fsets.mk_empty(finite_set_int), m); + ENSURE(plugin->is_value(empty_set.get())); + + // Test 2: Singleton with value element is a value + expr_ref five(arith.mk_int(5), m); + app_ref singleton_five(fsets.mk_singleton(five), m); + ENSURE(plugin->is_value(singleton_five.get())); + + // Test 3: Union of empty and singleton is a value + app_ref union_empty_singleton(fsets.mk_union(empty_set, singleton_five), m); + ENSURE(plugin->is_value(union_empty_singleton.get())); + + // Test 4: Union of two singletons with value elements is a value + expr_ref seven(arith.mk_int(7), m); + app_ref singleton_seven(fsets.mk_singleton(seven), m); + app_ref union_two_singletons(fsets.mk_union(singleton_five, singleton_seven), m); + ENSURE(plugin->is_value(union_two_singletons.get())); + + // Test 5: Nested union of singletons and empty sets is a value + app_ref union_nested(fsets.mk_union(union_empty_singleton, singleton_seven), m); + ENSURE(plugin->is_value(union_nested.get())); + + // Test 6: Union with empty set is a value + app_ref union_empty_empty(fsets.mk_union(empty_set, empty_set), m); + ENSURE(plugin->is_value(union_empty_empty.get())); + + // Test 7: Triple union is a value + expr_ref nine(arith.mk_int(9), m); + app_ref singleton_nine(fsets.mk_singleton(nine), m); + app_ref union_temp(fsets.mk_union(singleton_five, singleton_seven), m); + app_ref union_triple(fsets.mk_union(union_temp, singleton_nine), m); + ENSURE(plugin->is_value(union_triple.get())); + + // Test 8: Range is a value + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref range_set(fsets.mk_range(zero, ten), m); + ENSURE(plugin->is_value(range_set.get())); + + // Test 9: Union with range is a value + app_ref union_with_range(fsets.mk_union(singleton_five, range_set), m); + ENSURE(plugin->is_value(union_with_range.get())); + + // Test 10: Intersect is a value + app_ref intersect_set(fsets.mk_intersect(singleton_five, singleton_seven), m); + ENSURE(plugin->is_value(intersect_set.get())); + ENSURE(m.is_fully_interp(int_sort)); + ENSURE(m.is_fully_interp(finite_set_int)); +} + +static void tst_finite_set_is_fully_interp() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + + // Test with Bool sort (should be fully interpreted) + sort_ref bool_sort(m.mk_bool_sort(), m); + parameter bool_param(bool_sort.get()); + sort_ref finite_set_bool(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, &bool_param), m); + + ENSURE(m.is_fully_interp(bool_sort)); + ENSURE(m.is_fully_interp(finite_set_bool)); + + // Test with uninterpreted sort (should not be fully interpreted) + sort_ref uninterp_sort(m.mk_uninterpreted_sort(symbol("U")), m); + parameter uninterp_param(uninterp_sort.get()); + sort_ref finite_set_uninterp(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, &uninterp_param), m); + + ENSURE(!m.is_fully_interp(uninterp_sort)); + ENSURE(!m.is_fully_interp(finite_set_uninterp)); +} + +static void tst_finite_set_sort_size() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + + // Test 1: Bool sort (size 2) -> FiniteSet(Bool) should have size 2^2 = 4 + sort_ref bool_sort(m.mk_bool_sort(), m); + parameter bool_param(bool_sort.get()); + sort_ref finite_set_bool(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, &bool_param), m); + + sort_size const& bool_set_sz = finite_set_bool->get_num_elements(); + ENSURE(bool_set_sz.is_finite()); + ENSURE(!bool_set_sz.is_very_big()); + ENSURE(bool_set_sz.size() == 4); // 2^2 = 4 + + // Test 2: Create a finite sort with known size (e.g., BV with size 3) + // BV[3] has 2^3 = 8 elements, so FiniteSet(BV[3]) should have 2^8 = 256 elements + parameter bv_param(3); + sort_ref bv3_sort(m.mk_sort(m.mk_family_id("bv"), 0, 1, &bv_param), m); + parameter bv3_param(bv3_sort.get()); + sort_ref finite_set_bv3(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, &bv3_param), m); + + sort_size const& bv3_set_sz = finite_set_bv3->get_num_elements(); + ENSURE(bv3_set_sz.is_finite()); + ENSURE(!bv3_set_sz.is_very_big()); + ENSURE(bv3_set_sz.size() == 256); // 2^8 = 256 + + // Test 3: BV with size 5 -> BV[5] has 2^5 = 32 elements + // Since 32 > 30, FiniteSet(BV[5]) should be marked as very_big + parameter bv5_param(5); + sort_ref bv5_sort(m.mk_sort(m.mk_family_id("bv"), 0, 1, &bv5_param), m); + parameter bv5_set_param(bv5_sort.get()); + sort_ref finite_set_bv5(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, &bv5_set_param), m); + + sort_size const& bv5_set_sz = finite_set_bv5->get_num_elements(); + ENSURE(bv5_set_sz.is_very_big()); + + // Test 4: Int sort (infinite) -> FiniteSet(Int) should be infinite + arith_util arith(m); + sort_ref int_sort(arith.mk_int(), m); + parameter int_param(int_sort.get()); + sort_ref finite_set_int(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, &int_param), m); + + sort_size const& int_set_sz = finite_set_int->get_num_elements(); + ENSURE(int_set_sz.is_infinite()); + + // Test 5: BV with size 4 -> BV[4] has 2^4 = 16 elements + // FiniteSet(BV[4]) should have 2^16 = 65536 elements + parameter bv4_param(4); + sort_ref bv4_sort(m.mk_sort(m.mk_family_id("bv"), 0, 1, &bv4_param), m); + parameter bv4_set_param(bv4_sort.get()); + sort_ref finite_set_bv4(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, &bv4_set_param), m); + + sort_size const& bv4_set_sz = finite_set_bv4->get_num_elements(); + ENSURE(bv4_set_sz.is_finite()); + ENSURE(!bv4_set_sz.is_very_big()); + ENSURE(bv4_set_sz.size() == 65536); // 2^16 = 65536 +} + +void tst_finite_set() { + tst_finite_set_basic(); + tst_finite_set_map_filter(); + tst_finite_set_is_value(); + tst_finite_set_is_fully_interp(); + tst_finite_set_sort_size(); +} diff --git a/src/test/finite_set_rewriter.cpp b/src/test/finite_set_rewriter.cpp new file mode 100644 index 000000000..b2d80ab98 --- /dev/null +++ b/src/test/finite_set_rewriter.cpp @@ -0,0 +1,350 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + finite_set_rewriter.cpp + +Abstract: + + Test finite set rewriter + +Author: + + GitHub Copilot Agent 2025 + +--*/ + +#include "ast/ast.h" +#include "ast/finite_set_decl_plugin.h" +#include "ast/reg_decl_plugins.h" +#include "ast/arith_decl_plugin.h" +#include "ast/rewriter/finite_set_rewriter.h" + +class finite_set_rewriter_test { +public: + void test_union_idempotent() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create a set + sort_ref int_sort(arith.mk_int(), m); + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref s1(fsets.mk_range(zero, ten), m); + + // Test set.union(s1, s1) -> s1 + app_ref union_app(fsets.mk_union(s1, s1), m); + expr_ref result(m); + br_status st = rw.mk_app_core(union_app->get_decl(), union_app->get_num_args(), union_app->get_args(), result); + + ENSURE(st == BR_DONE); + ENSURE(result == s1); + } + + void test_intersect_idempotent() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create a set + sort_ref int_sort(arith.mk_int(), m); + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref s1(fsets.mk_range(zero, ten), m); + + // Test set.intersect(s1, s1) -> s1 + app_ref intersect_app(fsets.mk_intersect(s1, s1), m); + expr_ref result(m); + br_status st = + rw.mk_app_core(intersect_app->get_decl(), intersect_app->get_num_args(), intersect_app->get_args(), result); + + ENSURE(st == BR_DONE); + ENSURE(result == s1); + } + + void test_difference_same() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create a set + sort_ref int_sort(arith.mk_int(), m); + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref s1(fsets.mk_range(zero, ten), m); + + // Test set.difference(s1, s1) -> empty + app_ref diff_app(fsets.mk_difference(s1, s1), m); + expr_ref result(m); + br_status st = rw.mk_app_core(diff_app->get_decl(), diff_app->get_num_args(), diff_app->get_args(), result); + + ENSURE(st == BR_DONE); + ENSURE(fsets.is_empty(result)); + } + + void test_subset_rewrite() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create two sets + sort_ref int_sort(arith.mk_int(), m); + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + expr_ref twenty(arith.mk_int(20), m); + app_ref s1(fsets.mk_range(zero, ten), m); + app_ref s2(fsets.mk_range(zero, twenty), m); + + // Test set.subset(s1, s2) -> set.intersect(s1, s2) = s1 + app_ref subset_app(fsets.mk_subset(s1, s2), m); + expr_ref result(m); + br_status st = + rw.mk_app_core(subset_app->get_decl(), subset_app->get_num_args(), subset_app->get_args(), result); + + ENSURE(st == BR_REWRITE3); + ENSURE(m.is_eq(result)); + + // Check that result is an equality + app *eq = to_app(result); + ENSURE(eq->get_num_args() == 2); + + // The left side should be set.intersect(s1, s2) + expr *lhs = eq->get_arg(0); + ENSURE(fsets.is_intersect(lhs)); + + // The right side should be s1 + expr *rhs = eq->get_arg(1); + ENSURE(rhs == s1); + } + + void test_mk_app_core() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create sets + sort_ref int_sort(arith.mk_int(), m); + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref s1(fsets.mk_range(zero, ten), m); + + // Test union through mk_app_core + app_ref union_app(fsets.mk_union(s1, s1), m); + expr_ref result(m); + br_status st = rw.mk_app_core(union_app->get_decl(), union_app->get_num_args(), union_app->get_args(), result); + + ENSURE(st == BR_DONE); + ENSURE(result == s1); + } + + void test_union_with_empty() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create a set and empty set + sort_ref int_sort(arith.mk_int(), m); + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref s1(fsets.mk_range(zero, ten), m); + app_ref empty_set(fsets.mk_empty(s1->get_sort()), m); + + // Test set.union(s1, empty) -> s1 + app_ref union_app1(fsets.mk_union(s1, empty_set), m); + expr_ref result1(m); + br_status st1 = + rw.mk_app_core(union_app1->get_decl(), union_app1->get_num_args(), union_app1->get_args(), result1); + ENSURE(st1 == BR_DONE); + ENSURE(result1 == s1); + + // Test set.union(empty, s1) -> s1 + app_ref union_app2(fsets.mk_union(empty_set, s1), m); + expr_ref result2(m); + br_status st2 = + rw.mk_app_core(union_app2->get_decl(), union_app2->get_num_args(), union_app2->get_args(), result2); + ENSURE(st2 == BR_DONE); + ENSURE(result2 == s1); + } + + void test_intersect_with_empty() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create a set and empty set + sort_ref int_sort(arith.mk_int(), m); + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref s1(fsets.mk_range(zero, ten), m); + app_ref empty_set(fsets.mk_empty(s1->get_sort()), m); + + // Test set.intersect(s1, empty) -> empty + app_ref intersect_app1(fsets.mk_intersect(s1, empty_set), m); + expr_ref result1(m); + br_status st1 = rw.mk_app_core(intersect_app1->get_decl(), intersect_app1->get_num_args(), + intersect_app1->get_args(), result1); + ENSURE(st1 == BR_DONE); + ENSURE(result1 == empty_set); + + // Test set.intersect(empty, s1) -> empty + app_ref intersect_app2(fsets.mk_intersect(empty_set, s1), m); + expr_ref result2(m); + br_status st2 = rw.mk_app_core(intersect_app2->get_decl(), intersect_app2->get_num_args(), + intersect_app2->get_args(), result2); + ENSURE(st2 == BR_DONE); + ENSURE(result2 == empty_set); + } + + void test_difference_with_empty() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create a set and empty set + sort_ref int_sort(arith.mk_int(), m); + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref s1(fsets.mk_range(zero, ten), m); + app_ref empty_set(fsets.mk_empty(s1->get_sort()), m); + + // Test set.difference(s1, empty) -> s1 + app_ref diff_app1(fsets.mk_difference(s1, empty_set), m); + expr_ref result1(m); + br_status st1 = + rw.mk_app_core(diff_app1->get_decl(), diff_app1->get_num_args(), diff_app1->get_args(), result1); + ENSURE(st1 == BR_DONE); + ENSURE(result1 == s1); + + // Test set.difference(empty, s1) -> empty + app_ref diff_app2(fsets.mk_difference(empty_set, s1), m); + expr_ref result2(m); + br_status st2 = + rw.mk_app_core(diff_app2->get_decl(), diff_app2->get_num_args(), diff_app2->get_args(), result2); + ENSURE(st2 == BR_DONE); + ENSURE(result2 == empty_set); + } + + void test_subset_with_empty() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create a set and empty set + sort_ref int_sort(arith.mk_int(), m); + expr_ref zero(arith.mk_int(0), m); + expr_ref ten(arith.mk_int(10), m); + app_ref s1(fsets.mk_range(zero, ten), m); + app_ref empty_set(fsets.mk_empty(s1->get_sort()), m); + + // Test set.subset(empty, s1) -> true + app_ref subset_app1(fsets.mk_subset(empty_set, s1), m); + expr_ref result1(m); + br_status st1 = + rw.mk_app_core(subset_app1->get_decl(), subset_app1->get_num_args(), subset_app1->get_args(), result1); + ENSURE(st1 == BR_DONE); + ENSURE(m.is_true(result1)); + + // Test set.subset(s1, s1) -> true + app_ref subset_app2(fsets.mk_subset(s1, s1), m); + expr_ref result2(m); + br_status st2 = + rw.mk_app_core(subset_app2->get_decl(), subset_app2->get_num_args(), subset_app2->get_args(), result2); + ENSURE(st2 == BR_DONE); + ENSURE(m.is_true(result2)); + } + + void test_in_singleton() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create elements and singleton + expr_ref five(arith.mk_int(5), m); + expr_ref ten(arith.mk_int(10), m); + app_ref singleton_five(fsets.mk_singleton(five), m); + + // Test set.in(five, singleton(five)) -> true + app_ref in_app1(fsets.mk_in(five, singleton_five), m); + expr_ref result1(m); + br_status st1 = rw.mk_app_core(in_app1->get_decl(), in_app1->get_num_args(), in_app1->get_args(), result1); + ENSURE(st1 == BR_DONE); + ENSURE(m.is_true(result1)); + + // Test set.in(ten, singleton(five)) -> ten = five + app_ref in_app2(fsets.mk_in(ten, singleton_five), m); + expr_ref result2(m); + br_status st2 = rw.mk_app_core(in_app2->get_decl(), in_app2->get_num_args(), in_app2->get_args(), result2); + ENSURE(st2 == BR_REWRITE1); + ENSURE(m.is_eq(result2)); + } + + void test_in_empty() { + ast_manager m; + reg_decl_plugins(m); + + finite_set_util fsets(m); + finite_set_rewriter rw(m); + arith_util arith(m); + + // Create element and empty set + sort_ref int_sort(arith.mk_int(), m); + expr_ref five(arith.mk_int(5), m); + parameter param(int_sort.get()); + sort_ref set_sort(m.mk_sort(fsets.get_family_id(), FINITE_SET_SORT, 1, ¶m), m); + app_ref empty_set(fsets.mk_empty(set_sort), m); + + // Test set.in(five, empty) -> false + app_ref in_app(fsets.mk_in(five, empty_set), m); + expr_ref result(m); + br_status st = rw.mk_app_core(in_app->get_decl(), in_app->get_num_args(), in_app->get_args(), result); + ENSURE(st == BR_DONE); + ENSURE(m.is_false(result)); + } +}; + +void tst_finite_set_rewriter() { + finite_set_rewriter_test test; + test.test_union_idempotent(); + test.test_intersect_idempotent(); + test.test_difference_same(); + test.test_subset_rewrite(); + test.test_mk_app_core(); + test.test_union_with_empty(); + test.test_intersect_with_empty(); + test.test_difference_with_empty(); + test.test_subset_with_empty(); + test.test_in_singleton(); + test.test_in_empty(); +} diff --git a/src/test/main.cpp b/src/test/main.cpp index 68dfbb301..a8158b99e 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -285,4 +285,6 @@ int main(int argc, char ** argv) { TST(scoped_vector); TST(sls_seq_plugin); TST(ho_matcher); + TST(finite_set); + TST(finite_set_rewriter); } diff --git a/src/util/ref_vector.h b/src/util/ref_vector.h index f8bd4061b..c20970df1 100644 --- a/src/util/ref_vector.h +++ b/src/util/ref_vector.h @@ -278,6 +278,10 @@ public: SASSERT(&(this->m_manager) == &(other.m_manager)); this->m_nodes.swap(other.m_nodes); } + + void swap(unsigned idx1, unsigned idx2) noexcept { + this->super::swap(idx1, idx2); + } class element_ref { T * & m_ref; diff --git a/src/util/trace_tags.def b/src/util/trace_tags.def index 8eefa1d04..1ad305c2d 100644 --- a/src/util/trace_tags.def +++ b/src/util/trace_tags.def @@ -70,6 +70,8 @@ X(eq_der, top_sort, "top sort") X(expr_substitution_simplifier, expr_substitution_simplifier, "expr substitution simplifier") X(expr_substitution_simplifier, propagate_values, "propagate values") +X(finite_set, finite_set, "finite set") + X(fm_model_converter, fm_model_converter, "fm model converter") X(fm_model_converter, fm_mc, "fm mc") diff --git a/src/util/trail.h b/src/util/trail.h index 43e698234..5b96fdad0 100644 --- a/src/util/trail.h +++ b/src/util/trail.h @@ -423,4 +423,13 @@ public: m_scopes.shrink(new_lvl); m_region.pop_scope(num_scopes); } + + unsigned size() const { + return m_trail_stack.size(); + } + + void shrink(unsigned new_size) { + SASSERT(new_size <= m_trail_stack.size()); + m_trail_stack.shrink(new_size); + } };