3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-02 20:31:21 +00:00

add match expression construct to SMT-LIB2.6 frontend

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-09-19 19:39:02 -07:00
parent 43e47271f7
commit caa02c3c02
6 changed files with 203 additions and 36 deletions

View file

@ -584,6 +584,8 @@ class smt2_printer {
string_buffer<> buf; string_buffer<> buf;
buf.append("(:var "); buf.append("(:var ");
buf.append(v->get_idx()); buf.append(v->get_idx());
buf.append(" ");
buf.append(v->get_sort()->get_name().str().c_str());
buf.append(")"); buf.append(")");
f = mk_string(m(), buf.c_str()); f = mk_string(m(), buf.c_str());
} }

View file

@ -42,6 +42,10 @@ void rewriter_tpl<Config>::process_var(var * v) {
unsigned index = m_bindings.size() - idx - 1; unsigned index = m_bindings.size() - idx - 1;
var * r = (var*)(m_bindings[index]); var * r = (var*)(m_bindings[index]);
if (r != 0) { if (r != 0) {
CTRACE("rewriter", v->get_sort() != m().get_sort(r),
tout << expr_ref(v, m()) << ":" << sort_ref(v->get_sort(), m()) << " != " << expr_ref(r, m()) << ":" << sort_ref(m().get_sort(r), m());
tout << "index " << index << " bindings " << m_bindings.size() << "\n";
display_bindings(tout););
SASSERT(v->get_sort() == m().get_sort(r)); SASSERT(v->get_sort() == m().get_sort(r));
if (!is_ground(r) && m_shifts[index] != m_bindings.size()) { if (!is_ground(r) && m_shifts[index] != m_bindings.size()) {

View file

@ -202,7 +202,7 @@ func_decl * func_decls::find(unsigned arity, sort * const * domain, sort * range
if (f->get_arity() != arity) if (f->get_arity() != arity)
continue; continue;
unsigned i = 0; unsigned i = 0;
for (i = 0; i < arity; i++) { for (i = 0; domain && i < arity; i++) {
if (f->get_domain(i) != domain[i]) if (f->get_domain(i) != domain[i])
break; break;
} }
@ -937,7 +937,7 @@ static builtin_decl const & peek_builtin_decl(builtin_decl const & first, family
func_decl * cmd_context::find_func_decl(symbol const & s, unsigned num_indices, unsigned const * indices, func_decl * cmd_context::find_func_decl(symbol const & s, unsigned num_indices, unsigned const * indices,
unsigned arity, sort * const * domain, sort * range) const { unsigned arity, sort * const * domain, sort * range) const {
builtin_decl d; builtin_decl d;
if (m_builtin_decls.find(s, d)) { if (domain && m_builtin_decls.find(s, d)) {
family_id fid = d.m_fid; family_id fid = d.m_fid;
decl_kind k = d.m_decl; decl_kind k = d.m_decl;
// Hack: if d.m_next != 0, we use domain[0] (if available) to decide which plugin we use. // Hack: if d.m_next != 0, we use domain[0] (if available) to decide which plugin we use.
@ -961,7 +961,7 @@ func_decl * cmd_context::find_func_decl(symbol const & s, unsigned num_indices,
return f; return f;
} }
if (contains_macro(s, arity, domain)) if (domain && contains_macro(s, arity, domain))
throw cmd_exception("invalid function declaration reference, named expressions (aka macros) cannot be referenced ", s); throw cmd_exception("invalid function declaration reference, named expressions (aka macros) cannot be referenced ", s);
if (num_indices > 0) if (num_indices > 0)

View file

@ -140,7 +140,6 @@ void func_interp::set_else(expr * e) {
return; return;
reset_interp_cache(); reset_interp_cache();
ptr_vector<expr> args; ptr_vector<expr> args;
while (e && is_fi_entry_expr(e, args)) { while (e && is_fi_entry_expr(e, args)) {
TRACE("func_interp", tout << "fi entry expr: " << mk_ismt2_pp(e, m()) << std::endl;); TRACE("func_interp", tout << "fi entry expr: " << mk_ismt2_pp(e, m()) << std::endl;);

View file

@ -24,6 +24,7 @@ Revision History:
#include "ast/ast_pp.h" #include "ast/ast_pp.h"
#include "ast/well_sorted.h" #include "ast/well_sorted.h"
#include "ast/rewriter/rewriter.h" #include "ast/rewriter/rewriter.h"
#include "ast/rewriter/var_subst.h"
#include "ast/has_free_vars.h" #include "ast/has_free_vars.h"
#include "ast/ast_smt2_pp.h" #include "ast/ast_smt2_pp.h"
#include "parsers/smt2/smt2parser.h" #include "parsers/smt2/smt2parser.h"
@ -68,6 +69,7 @@ namespace smt2 {
scoped_ptr<bv_util> m_bv_util; scoped_ptr<bv_util> m_bv_util;
scoped_ptr<arith_util> m_arith_util; scoped_ptr<arith_util> m_arith_util;
scoped_ptr<datatype_util> m_datatype_util;
scoped_ptr<seq_util> m_seq_util; scoped_ptr<seq_util> m_seq_util;
scoped_ptr<pattern_validator> m_pattern_validator; scoped_ptr<pattern_validator> m_pattern_validator;
scoped_ptr<var_shifter> m_var_shifter; scoped_ptr<var_shifter> m_var_shifter;
@ -136,7 +138,7 @@ namespace smt2 {
typedef psort_frame sort_frame; typedef psort_frame sort_frame;
enum expr_frame_kind { EF_APP, EF_LET, EF_LET_DECL, EF_QUANT, EF_MATCH, EF_ATTR_EXPR, EF_PATTERN }; enum expr_frame_kind { EF_APP, EF_LET, EF_LET_DECL, EF_MATCH, EF_QUANT, EF_ATTR_EXPR, EF_PATTERN };
struct expr_frame { struct expr_frame {
expr_frame_kind m_kind; expr_frame_kind m_kind;
@ -174,9 +176,7 @@ namespace smt2 {
}; };
struct match_frame : public expr_frame { struct match_frame : public expr_frame {
match_frame(): match_frame():expr_frame(EF_MATCH) {}
expr_frame(EF_MATCH)
{}
}; };
struct let_frame : public expr_frame { struct let_frame : public expr_frame {
@ -282,6 +282,12 @@ namespace smt2 {
return *(m_arith_util.get()); return *(m_arith_util.get());
} }
datatype_util & dtutil() {
if (m_datatype_util.get() == 0)
m_datatype_util = alloc(datatype_util, m());
return *(m_datatype_util.get());
}
seq_util & sutil() { seq_util & sutil() {
if (m_seq_util.get() == 0) if (m_seq_util.get() == 0)
m_seq_util = alloc(seq_util, m()); m_seq_util = alloc(seq_util, m());
@ -1266,6 +1272,23 @@ namespace smt2 {
return num; return num;
} }
void push_let_frame() {
next();
check_lparen_next("invalid let declaration, '(' expected");
void * mem = m_stack.allocate(sizeof(let_frame));
new (mem) let_frame(symbol_stack().size(), expr_stack().size());
m_num_expr_frames++;
}
void push_bang_frame(expr_frame * curr) {
TRACE("consume_attributes", tout << "begin bang, expr_stack.size(): " << expr_stack().size() << "\n";);
next();
void * mem = m_stack.allocate(sizeof(attr_expr_frame));
new (mem) attr_expr_frame(curr, symbol_stack().size(), expr_stack().size());
m_num_expr_frames++;
}
void push_quant_frame(bool is_forall) { void push_quant_frame(bool is_forall) {
SASSERT(curr_is_identifier()); SASSERT(curr_is_identifier());
SASSERT(curr_id_is_forall() || curr_id_is_exists()); SASSERT(curr_id_is_forall() || curr_id_is_exists());
@ -1286,40 +1309,179 @@ namespace smt2 {
* (match t ((p1 t1) ··· (pm+1 tm+1))) * (match t ((p1 t1) ··· (pm+1 tm+1)))
*/ */
void push_match_frame() { void push_match_frame() {
SASSERT(curr_is_identifier());
SASSERT(curr_id() == m_match);
next(); next();
#if 0 void * mem = m_stack.allocate(sizeof(match_frame));
// just use the stack for parsing these for now.
void * mem = m_stack.allocate(sizeof(match_frame));
new (mem) match_frame(); new (mem) match_frame();
m_num_expr_frames++; unsigned num_frames = m_num_expr_frames;
#endif
parse_expr(); parse_expr();
expr_ref t(expr_stack().back(), m()); expr_ref t(expr_stack().back(), m());
expr_stack().pop_back(); expr_stack().pop_back();
expr_ref_vector patterns(m()), cases(m()); expr_ref_vector patterns(m()), cases(m());
sort* srt = m().get_sort(t);
check_lparen_next("pattern bindings should be enclosed in a parenthesis"); check_lparen_next("pattern bindings should be enclosed in a parenthesis");
while (!curr_is_rparen()) { while (!curr_is_rparen()) {
m_env.begin_scope();
unsigned num_bindings = m_num_bindings;
check_lparen_next("invalid pattern binding, '(' expected"); check_lparen_next("invalid pattern binding, '(' expected");
parse_expr(); // TBD need to parse a pattern here. The sort of 't' provides context for how to interpret _. parse_match_pattern(srt);
patterns.push_back(expr_stack().back()); patterns.push_back(expr_stack().back());
expr_stack().pop_back(); expr_stack().pop_back();
parse_expr(); parse_expr();
cases.push_back(expr_stack().back()); cases.push_back(expr_stack().back());
expr_stack().pop_back(); expr_stack().pop_back();
m_num_bindings = num_bindings;
m_env.end_scope();
check_rparen_next("invalid pattern binding, ')' expected"); check_rparen_next("invalid pattern binding, ')' expected");
} }
next(); next();
m_num_expr_frames = num_frames + 1;
expr_stack().push_back(compile_patterns(t, patterns, cases)); expr_stack().push_back(compile_patterns(t, patterns, cases));
} }
expr_ref compile_patterns(expr* t, expr_ref_vector const& patterns, expr_ref_vector const& cases) { void pop_match_frame(match_frame* fr) {
NOT_IMPLEMENTED_YET(); m_stack.deallocate(fr);
return expr_ref(m()); m_num_expr_frames--;
} }
void pop_match_frame(match_frame * fr) { expr_ref compile_patterns(expr* t, expr_ref_vector const& patterns, expr_ref_vector const& cases) {
expr_ref result(m());
var_subst sub(m(), false);
TRACE("parse_expr", tout << "term\n" << expr_ref(t, m()) << "\npatterns\n" << patterns << "\ncases\n" << cases << "\n";);
for (unsigned i = patterns.size(); i > 0; ) {
--i;
expr_ref_vector subst(m());
expr_ref cond = bind_match(t, patterns[i], subst);
expr_ref new_case(m());
if (subst.empty()) {
new_case = cases[i];
}
else {
sub(cases[i], subst.size(), subst.c_ptr(), new_case);
inv_var_shifter inv(m());
inv(new_case, subst.size(), new_case);
}
if (result) {
result = m().mk_ite(cond, new_case, result);
}
else {
// pattern match binding is ignored.
result = new_case;
}
}
TRACE("parse_expr", tout << result << "\n";);
return result;
}
// compute match condition and substitution
// t is shifted by size of subst.
expr_ref bind_match(expr* t, expr* pattern, expr_ref_vector& subst) {
expr_ref tsh(m());
if (is_var(pattern)) {
shifter()(t, 1, tsh);
subst.push_back(tsh);
return expr_ref(m().mk_true(), m());
}
else {
SASSERT(is_app(pattern));
func_decl * f = to_app(pattern)->get_decl();
func_decl * r = dtutil().get_constructor_recognizer(f);
ptr_vector<func_decl> const * acc = dtutil().get_constructor_accessors(f);
shifter()(t, acc->size(), tsh);
for (func_decl* a : *acc) {
subst.push_back(m().mk_app(a, tsh));
}
return expr_ref(m().mk_app(r, t), m());
}
}
/**
* parse a match pattern
* (C x1 .... xn)
* C
* _
* x
*/
bool parse_constructor_pattern(sort * srt) {
if (!curr_is_lparen()) {
return false;
}
next();
svector<symbol> vars;
expr_ref_vector args(m());
symbol C(check_identifier_next("constructor symbol expected"));
while (!curr_is_rparen()) {
symbol v(check_identifier_next("variable symbol expected"));
if (v != m_underscore && vars.contains(v)) {
throw parser_exception("unexpected repeated variable in pattern expression");
}
vars.push_back(v);
}
next();
// now have C, vars
// look up constructor C,
// create bound variables based on constructor type.
// store expression in expr_stack().
// ensure that bound variables are adjusted to vars
func_decl* f = m_ctx.find_func_decl(C, 0, nullptr, vars.size(), nullptr, srt);
if (!f) {
throw parser_exception("expecting a constructor that has been declared");
}
if (!dtutil().is_constructor(f)) {
throw parser_exception("expecting a constructor");
}
if (f->get_arity() != vars.size()) {
throw parser_exception("mismatching number of variables supplied to constructor");
}
m_num_bindings += vars.size();
for (unsigned i = 0; i < vars.size(); ++i) {
var * v = m().mk_var(i, f->get_domain(i));
args.push_back(v);
if (vars[i] != m_underscore) {
m_env.insert(vars[i], local(v, m_num_bindings));
}
}
expr_stack().push_back(m().mk_app(f, args.size(), args.c_ptr()));
return true;
}
void parse_match_pattern(sort* srt) {
if (parse_constructor_pattern(srt)) {
// done
}
else if (curr_id() == m_underscore) {
// we have a wild-card.
// store dummy variable in expr_stack()
next();
var* v = m().mk_var(0, srt);
expr_stack().push_back(v);
}
else {
symbol xC(check_identifier_next("constructor symbol or variable expected"));
// check if xC is a constructor, otherwise make it a variable
// of sort srt.
try {
func_decl* f = m_ctx.find_func_decl(xC, 0, nullptr, 0, nullptr, srt);
if (!dtutil().is_constructor(f)) {
throw parser_exception("expecting a constructor, got a previously declared function");
}
if (f->get_arity() > 0) {
throw parser_exception("constructor expects arguments, but no arguments were supplied in pattern");
}
expr_stack().push_back(m().mk_const(f));
}
catch (cmd_exception &) {
var* v = m().mk_var(0, srt);
expr_stack().push_back(v);
m_env.insert(xC, local(v, m_num_bindings++));
}
}
} }
symbol parse_indexed_identifier_core() { symbol parse_indexed_identifier_core() {
@ -1613,8 +1775,7 @@ namespace smt2 {
new (mem) app_frame(f, expr_spos, param_spos, has_as); new (mem) app_frame(f, expr_spos, param_spos, has_as);
m_num_expr_frames++; m_num_expr_frames++;
} }
// return true if a new frame was created.
void push_expr_frame(expr_frame * curr) { void push_expr_frame(expr_frame * curr) {
SASSERT(curr_is_lparen()); SASSERT(curr_is_lparen());
next(); next();
@ -1622,11 +1783,7 @@ namespace smt2 {
if (curr_is_identifier()) { if (curr_is_identifier()) {
TRACE("push_expr_frame", tout << "push_expr_frame(), curr_id(): " << curr_id() << "\n";); TRACE("push_expr_frame", tout << "push_expr_frame(), curr_id(): " << curr_id() << "\n";);
if (curr_id_is_let()) { if (curr_id_is_let()) {
next(); push_let_frame();
check_lparen_next("invalid let declaration, '(' expected");
void * mem = m_stack.allocate(sizeof(let_frame));
new (mem) let_frame(symbol_stack().size(), expr_stack().size());
m_num_expr_frames++;
} }
else if (curr_id_is_forall()) { else if (curr_id_is_forall()) {
push_quant_frame(true); push_quant_frame(true);
@ -1635,14 +1792,9 @@ namespace smt2 {
push_quant_frame(false); push_quant_frame(false);
} }
else if (curr_id_is_bang()) { else if (curr_id_is_bang()) {
TRACE("consume_attributes", tout << "begin bang, expr_stack.size(): " << expr_stack().size() << "\n";); push_bang_frame(curr);
next();
void * mem = m_stack.allocate(sizeof(attr_expr_frame));
new (mem) attr_expr_frame(curr, symbol_stack().size(), expr_stack().size());
m_num_expr_frames++;
} }
else if (curr_id_is_as() || curr_id_is_underscore()) { else if (curr_id_is_as() || curr_id_is_underscore()) {
TRACE("push_expr_frame", tout << "push_expr_frame(): parse_qualified_name\n";);
parse_qualified_name(); parse_qualified_name();
} }
else if (curr_id_is_root_obj()) { else if (curr_id_is_root_obj()) {
@ -1825,12 +1977,12 @@ namespace smt2 {
m_stack.deallocate(static_cast<let_decl_frame*>(fr)); m_stack.deallocate(static_cast<let_decl_frame*>(fr));
m_num_expr_frames--; m_num_expr_frames--;
break; break;
case EF_QUANT:
pop_quant_frame(static_cast<quant_frame*>(fr));
break;
case EF_MATCH: case EF_MATCH:
pop_match_frame(static_cast<match_frame*>(fr)); pop_match_frame(static_cast<match_frame*>(fr));
break; break;
case EF_QUANT:
pop_quant_frame(static_cast<quant_frame*>(fr));
break;
case EF_ATTR_EXPR: case EF_ATTR_EXPR:
pop_attr_expr_frame(static_cast<attr_expr_frame*>(fr)); pop_attr_expr_frame(static_cast<attr_expr_frame*>(fr));
break; break;
@ -2287,8 +2439,10 @@ namespace smt2 {
throw cmd_exception("invalid assert command, expression required as argument"); throw cmd_exception("invalid assert command, expression required as argument");
} }
expr * f = expr_stack().back(); expr * f = expr_stack().back();
if (!m().is_bool(f)) if (!m().is_bool(f)) {
TRACE("smt2parser", tout << expr_ref(f, m()) << "\n";);
throw cmd_exception("invalid assert command, term is not Boolean"); throw cmd_exception("invalid assert command, term is not Boolean");
}
if (f == m_last_named_expr.second) { if (f == m_last_named_expr.second) {
m_ctx.assert_expr(m_last_named_expr.first, f); m_ctx.assert_expr(m_last_named_expr.first, f);
} }

View file

@ -4374,9 +4374,17 @@ namespace smt {
expr* fn = to_app(q->get_pattern(0))->get_arg(0); expr* fn = to_app(q->get_pattern(0))->get_arg(0);
expr* body = to_app(q->get_pattern(1))->get_arg(0); expr* body = to_app(q->get_pattern(1))->get_arg(0);
SASSERT(is_app(fn)); SASSERT(is_app(fn));
// reverse argument order so that variable 0 starts at the beginning.
expr_ref_vector subst(m);
for (expr* arg : *to_app(fn)) {
subst.push_back(arg);
}
expr_ref bodyr(m);
var_subst sub(m, false);
sub(body, subst.size(), subst.c_ptr(), bodyr);
func_decl* f = to_app(fn)->get_decl(); func_decl* f = to_app(fn)->get_decl();
func_interp* fi = alloc(func_interp, m, f->get_arity()); func_interp* fi = alloc(func_interp, m, f->get_arity());
fi->set_else(body); fi->set_else(bodyr);
m_model->register_decl(f, fi); m_model->register_decl(f, fi);
} }
} }