diff --git a/src/api/api_datatype.cpp b/src/api/api_datatype.cpp index f3a275508..fae8a21bc 100644 --- a/src/api/api_datatype.cpp +++ b/src/api/api_datatype.cpp @@ -618,4 +618,25 @@ extern "C" { Z3_CATCH_RETURN(0); } + Z3_ast Z3_datatype_update_field( + __in Z3_context c, __in Z3_func_decl f, __in Z3_ast t, __in Z3_ast v) { + Z3_TRY; + LOG_Z3_datatype_update_field(c, f, t, v); + RESET_ERROR_CODE(); + ast_manager & m = mk_c(c)->m(); + func_decl* _f = to_func_decl(f); + expr* _t = to_expr(t); + expr* _v = to_expr(v); + expr* args[2] = { _t, _v }; + sort* domain[2] = { m.get_sort(_t), m.get_sort(_v) }; + parameter param(_f); + func_decl * d = m.mk_func_decl(mk_c(c)->get_array_fid(), OP_DT_UPDATE_FIELD, 1, ¶m, 2, domain); + app* r = m.mk_app(d, 2, args); + mk_c(c)->save_ast_trail(r); + check_sorts(c, r); + RETURN_Z3(of_ast(r)); + Z3_CATCH_RETURN(0); + } + + }; diff --git a/src/api/dotnet/Context.cs b/src/api/dotnet/Context.cs index 989d1d7d7..13e78e495 100644 --- a/src/api/dotnet/Context.cs +++ b/src/api/dotnet/Context.cs @@ -449,6 +449,19 @@ namespace Microsoft.Z3 return MkDatatypeSorts(MkSymbols(names), c); } + /// + /// Update a datatype field at expression t with value v. + /// The function performs a record update at t. The field + /// that is passed in as argument is updated with value v, + /// the remainig fields of t are unchanged. + /// + public Expr MkUpdateField(FuncDecl field, Expr t, Expr v) + { + return Expr.Create(this, Native.Z3_datatype_update_field( + nCtx, field.NativeObject, + t.NativeObject, v.NativeObject)); + } + #endregion #endregion diff --git a/src/api/java/Context.java b/src/api/java/Context.java index 4fbd79be2..d52edfb82 100644 --- a/src/api/java/Context.java +++ b/src/api/java/Context.java @@ -375,6 +375,22 @@ public class Context extends IDisposable return mkDatatypeSorts(MkSymbols(names), c); } + /** + * Update a datatype field at expression t with value v. + * The function performs a record update at t. The field + * that is passed in as argument is updated with value v, + * the remainig fields of t are unchanged. + **/ + public Expr MkUpdateField(FuncDecl field, Expr t, Expr v) + { + return Expr.Create + (this, + Native.datatypeUpdateField + (nCtx(), field.getNativeObject(), + t.getNativeObject(), v.getNativeObject())); + } + + /** * Creates a new function declaration. **/ diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 4550e7976..792174233 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -877,6 +877,8 @@ typedef enum - Z3_OP_DT_ACCESSOR: datatype accessor. + - Z3_OP_DT_UPDATE_FIELD: datatype field update. + - Z3_OP_PB_AT_MOST: Cardinality constraint. E.g., x + y + z <= 2 @@ -1066,6 +1068,7 @@ typedef enum { Z3_OP_DT_CONSTRUCTOR=0x800, Z3_OP_DT_RECOGNISER, Z3_OP_DT_ACCESSOR, + Z3_OP_DT_UPDATE_FIELD, // Pseudo Booleans Z3_OP_PB_AT_MOST=0x900, @@ -3751,6 +3754,28 @@ END_MLAPI_EXCLUDE Z3_func_decl Z3_API Z3_get_datatype_sort_constructor_accessor( __in Z3_context c, __in Z3_sort t, unsigned idx_c, unsigned idx_a); + /** + \brief Update record field with a value. + + This corresponds to the 'with' construct in OCaml. + It has the effect of updating a record field with a given value. + The remaining fields are left unchanged. It is the record + equivalent of an array store (see \sa Z3_mk_store). + If the datatype has more than one constructor, then the update function + behaves as identity if there is a miss-match between the accessor and + constructor. For example ((_ update-field car) nil 1) is nil, + while ((_ update-field car) (cons 2 nil) 1) is (cons 1 nil). + + + \pre Z3_get_sort_kind(Z3_get_sort(c, t)) == Z3_get_domain(c, field_access, 1) == Z3_DATATYPE_SORT + \pre Z3_get_sort(c, value) == Z3_get_range(c, field_access) + + + def_API('Z3_datatype_update_field', AST, (_in(CONTEXT), _in(FUNC_DECL), _in(AST), _in(AST))) + */ + Z3_ast Z3_API Z3_datatype_update_field( + __in Z3_context c, __in Z3_func_decl field_access, + __in Z3_ast t, __in Z3_ast value); /** \brief Return arity of relation. diff --git a/src/ast/datatype_decl_plugin.cpp b/src/ast/datatype_decl_plugin.cpp index 5b29ec20b..d707f0178 100644 --- a/src/ast/datatype_decl_plugin.cpp +++ b/src/ast/datatype_decl_plugin.cpp @@ -422,8 +422,55 @@ static sort * get_type(ast_manager & m, family_id datatype_fid, sort * source_da } } +func_decl * datatype_decl_plugin::mk_update_field( + unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range) { + decl_kind k = OP_DT_UPDATE_FIELD; + ast_manager& m = *m_manager; + + if (num_parameters != 1 || !parameters[0].is_ast()) { + m.raise_exception("invalid parameters for datatype field update"); + return 0; + } + if (arity != 2) { + m.raise_exception("invalid number of arguments for datatype field update"); + return 0; + } + func_decl* acc = 0; + if (is_func_decl(parameters[0].get_ast())) { + acc = to_func_decl(parameters[0].get_ast()); + } + if (acc && !get_util().is_accessor(acc)) { + acc = 0; + } + if (!acc) { + m.raise_exception("datatype field update requires a datatype accessor as the second argument"); + return 0; + } + sort* dom = acc->get_domain(0); + sort* rng = acc->get_range(); + if (dom != domain[0]) { + m.raise_exception("first argument to field update should be a data-type"); + return 0; + } + if (rng != domain[1]) { + std::ostringstream buffer; + buffer << "second argument to field update should be " << mk_ismt2_pp(rng, m) + << " instead of " << mk_ismt2_pp(domain[1], m); + m.raise_exception(buffer.str().c_str()); + return 0; + } + range = domain[0]; + func_decl_info info(m_family_id, k, num_parameters, parameters); + return m.mk_func_decl(symbol("update_field"), arity, domain, range, info); +} + func_decl * datatype_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, unsigned arity, sort * const * domain, sort * range) { + + if (k == OP_DT_UPDATE_FIELD) { + return mk_update_field(num_parameters, parameters, arity, domain, range); + } if (num_parameters < 2 || !parameters[0].is_ast() || !is_sort(parameters[0].get_ast())) { m_manager->raise_exception("invalid parameters for datatype operator"); return 0; @@ -521,20 +568,9 @@ func_decl * datatype_decl_plugin::mk_func_decl(decl_kind k, unsigned num_paramet return m_manager->mk_func_decl(a_name, arity, domain, a_type, info); } break; - case OP_DT_UPDATE_FIELD: - if (num_parameters != 2 || arity != 2 || domain[0] != datatype) { - m_manager->raise_exception("invalid parameters for datatype field update"); - return 0; - } - else { - symbol con_name = parameters[0].get_symbol(); - symbol acc_name = parameters[1].get_symbol(); - func_decl_info info(m_family_id, k, num_parameters, parameters); - info.m_private_parameters = true; - SASSERT(info.private_parameters()); - return m_manager->mk_func_decl(symbol("update_field"), arity, domain, datatype, info); - } - + case OP_DT_UPDATE_FIELD: + UNREACHABLE(); + return 0; default: m_manager->raise_exception("invalid datatype operator kind"); return 0; @@ -687,12 +723,9 @@ bool datatype_decl_plugin::is_value(app * e) const { } void datatype_decl_plugin::get_op_names(svector & op_names, symbol const & logic) { -#if 0 - // disabled if (logic == symbol::null) { - op_names.push_back(builtin_name("update_field", OP_DT_UPDATE_FIELD)); + op_names.push_back(builtin_name("update-field", OP_DT_UPDATE_FIELD)); } -#endif } diff --git a/src/ast/datatype_decl_plugin.h b/src/ast/datatype_decl_plugin.h index af7c689bc..2218963c3 100644 --- a/src/ast/datatype_decl_plugin.h +++ b/src/ast/datatype_decl_plugin.h @@ -154,6 +154,10 @@ public: private: bool is_value_visit(expr * arg, ptr_buffer & todo) const; + + func_decl * mk_update_field( + unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range); }; class datatype_util { @@ -184,9 +188,11 @@ public: bool is_constructor(func_decl * f) const { return is_decl_of(f, m_family_id, OP_DT_CONSTRUCTOR); } bool is_recognizer(func_decl * f) const { return is_decl_of(f, m_family_id, OP_DT_RECOGNISER); } bool is_accessor(func_decl * f) const { return is_decl_of(f, m_family_id, OP_DT_ACCESSOR); } + bool is_update_field(func_decl * f) const { return is_decl_of(f, m_family_id, OP_DT_UPDATE_FIELD); } bool is_constructor(app * f) const { return is_app_of(f, m_family_id, OP_DT_CONSTRUCTOR); } bool is_recognizer(app * f) const { return is_app_of(f, m_family_id, OP_DT_RECOGNISER); } bool is_accessor(app * f) const { return is_app_of(f, m_family_id, OP_DT_ACCESSOR); } + bool is_update_field(app * f) const { return is_app_of(f, m_family_id, OP_DT_UPDATE_FIELD); } ptr_vector const * get_datatype_constructors(sort * ty); unsigned get_datatype_num_constructors(sort * ty) { return get_datatype_constructors(ty)->size(); } unsigned get_constructor_idx(func_decl * f) const { SASSERT(is_constructor(f)); return f->get_parameter(1).get_int(); } diff --git a/src/ast/rewriter/datatype_rewriter.cpp b/src/ast/rewriter/datatype_rewriter.cpp index 8c55ba498..be198c3d9 100644 --- a/src/ast/rewriter/datatype_rewriter.cpp +++ b/src/ast/rewriter/datatype_rewriter.cpp @@ -60,6 +60,32 @@ br_status datatype_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr UNREACHABLE(); break; } + case OP_DT_UPDATE_FIELD: { + SASSERT(num_args == 2); + if (!is_app(args[0]) || !m_util.is_constructor(to_app(args[0]))) + return BR_FAILED; + app * a = to_app(args[0]); + func_decl * c_decl = a->get_decl(); + if (c_decl != m_util.get_accessor_constructor(f)) { + result = a; + return BR_DONE; + } + ptr_vector const * acc = m_util.get_constructor_accessors(c_decl); + SASSERT(acc && acc->size() == a->get_num_args()); + unsigned num = acc->size(); + ptr_buffer new_args; + for (unsigned i = 0; i < num; ++i) { + + if (f == (*acc)[i]) { + new_args.push_back(args[1]); + } + else { + new_args.push_back(a->get_arg(i)); + } + } + result = m().mk_app(c_decl, num, new_args.c_ptr()); + return BR_DONE; + } default: UNREACHABLE(); } diff --git a/src/ast/simplifier/datatype_simplifier_plugin.cpp b/src/ast/simplifier/datatype_simplifier_plugin.cpp index 129c0e34b..b434a8bd0 100644 --- a/src/ast/simplifier/datatype_simplifier_plugin.cpp +++ b/src/ast/simplifier/datatype_simplifier_plugin.cpp @@ -81,6 +81,8 @@ bool datatype_simplifier_plugin::reduce(func_decl * f, unsigned num_args, expr * } UNREACHABLE(); } + case OP_DT_UPDATE_FIELD: + return false; default: UNREACHABLE(); } diff --git a/src/smt/theory_datatype.cpp b/src/smt/theory_datatype.cpp index 8c6543eff..b5cdbcfe2 100644 --- a/src/smt/theory_datatype.cpp +++ b/src/smt/theory_datatype.cpp @@ -61,6 +61,13 @@ namespace smt { if (antecedent == null_literal) { ctx.assign_eq(lhs, ctx.get_enode(rhs), eq_justification::mk_axiom()); } + else if (ctx.get_assignment(antecedent) != l_true) { + literal l(mk_eq(lhs->get_owner(), rhs, true)); + ctx.mark_as_relevant(l); + ctx.mark_as_relevant(antecedent); + literal lits[2] = {l, ~antecedent}; + ctx.mk_th_axiom(get_id(), 2, lits); + } else { SASSERT(ctx.get_assignment(antecedent) == l_true); region & r = ctx.get_region(); @@ -143,6 +150,48 @@ namespace smt { ctx.set_conflict(ctx.mk_justification(ext_theory_conflict_justification(get_id(), reg, 1, &l, 1, &p))); } + /** + \brief Given a field update n := { r with field := v } for constructor C, assert the axioms: + (=> (is-C r) (= (acc_j n) (acc_j r))) for acc_j != field + (=> (is-C r) (= (field n) v)) for acc_j != field + (=> (not (is-C r)) (= n r)) + */ + void theory_datatype::assert_update_field_axioms(enode * n) { + m_stats.m_assert_update_field++; + SASSERT(is_update_field(n)); + context & ctx = get_context(); + ast_manager & m = get_manager(); + app* own = n->get_owner(); + expr* arg1 = own->get_arg(0); + expr* arg2 = own->get_arg(1); + func_decl * upd = n->get_decl(); + func_decl * acc = to_func_decl(upd->get_parameter(0).get_ast()); + func_decl * con = m_util.get_accessor_constructor(acc); + func_decl * rec = m_util.get_constructor_recognizer(con); + ptr_vector const * accessors = m_util.get_constructor_accessors(con); + ptr_vector::const_iterator it = accessors->begin(); + ptr_vector::const_iterator end = accessors->end(); + app_ref rec_app(m.mk_app(rec, arg1), m); + ctx.internalize(rec_app, false); + literal is_con(ctx.get_bool_var(rec_app)); + for (; it != end; ++it) { + enode* arg; + func_decl * acc1 = *it; + if (acc1 == acc) { + arg = n->get_arg(1); + } + else { + app* acc_app = m.mk_app(acc1, arg1); + ctx.internalize(acc_app, false); + arg = ctx.get_enode(acc_app); + } + app * acc_own = m.mk_app(acc1, own); + assert_eq_axiom(arg, acc_own, is_con); + } + // update_field is identity if 'n' is not created by a matching constructor. + assert_eq_axiom(n, arg1, ~is_con); + } + theory_var theory_datatype::mk_var(enode * n) { theory_var r = theory::mk_var(n); theory_var r2 = m_find.mk_var(); @@ -150,15 +199,17 @@ namespace smt { SASSERT(r == static_cast(m_var_data.size())); m_var_data.push_back(alloc(var_data)); var_data * d = m_var_data[r]; + context & ctx = get_context(); + ctx.attach_th_var(n, this, r); if (is_constructor(n)) { d->m_constructor = n; - get_context().attach_th_var(n, this, r); assert_accessor_axioms(n); } + else if (is_update_field(n)) { + assert_update_field_axioms(n); + } else { ast_manager & m = get_manager(); - context & ctx = get_context(); - ctx.attach_th_var(n, this, r); sort * s = m.get_sort(n->get_owner()); if (m_util.get_datatype_num_constructors(s) == 1) { func_decl * c = m_util.get_datatype_constructors(s)->get(0); @@ -192,7 +243,7 @@ namespace smt { ctx.set_var_theory(bv, get_id()); ctx.set_enode_flag(bv, true); } - if (is_constructor(term)) { + if (is_constructor(term) || is_update_field(term)) { SASSERT(!is_attached_to_var(e)); // *** We must create a theory variable for each argument that has sort datatype *** // @@ -478,6 +529,7 @@ namespace smt { st.update("datatype splits", m_stats.m_splits); st.update("datatype constructor ax", m_stats.m_assert_cnstr); st.update("datatype accessor ax", m_stats.m_assert_accessor); + st.update("datatype update ax", m_stats.m_assert_update_field); } void theory_datatype::display_var(std::ostream & out, theory_var v) const { diff --git a/src/smt/theory_datatype.h b/src/smt/theory_datatype.h index cf2f933ad..5ebfc220f 100644 --- a/src/smt/theory_datatype.h +++ b/src/smt/theory_datatype.h @@ -41,7 +41,7 @@ namespace smt { struct stats { unsigned m_occurs_check, m_splits; - unsigned m_assert_cnstr, m_assert_accessor; + unsigned m_assert_cnstr, m_assert_accessor, m_assert_update_field; void reset() { memset(this, 0, sizeof(stats)); } stats() { reset(); } }; @@ -58,14 +58,17 @@ namespace smt { bool is_constructor(app * f) const { return m_util.is_constructor(f); } bool is_recognizer(app * f) const { return m_util.is_recognizer(f); } bool is_accessor(app * f) const { return m_util.is_accessor(f); } - + bool is_update_field(app * f) const { return m_util.is_update_field(f); } + bool is_constructor(enode * n) const { return is_constructor(n->get_owner()); } bool is_recognizer(enode * n) const { return is_recognizer(n->get_owner()); } bool is_accessor(enode * n) const { return is_accessor(n->get_owner()); } + bool is_update_field(enode * n) const { return m_util.is_update_field(n->get_owner()); } void assert_eq_axiom(enode * lhs, expr * rhs, literal antecedent); void assert_is_constructor_axiom(enode * n, func_decl * c, literal antecedent); void assert_accessor_axioms(enode * n); + void assert_update_field_axioms(enode * n); void add_recognizer(theory_var v, enode * recognizer); void propagate_recognizer(theory_var v, enode * r); void sign_recognizer_conflict(enode * c, enode * r);