diff --git a/src/muz/rel/check_relation.cpp b/src/muz/rel/check_relation.cpp index 46d1171d0..c1b840c61 100644 --- a/src/muz/rel/check_relation.cpp +++ b/src/muz/rel/check_relation.cpp @@ -41,6 +41,12 @@ namespace datalog { return get_plugin().ground(*this, fml); } + expr_ref check_relation_plugin::ground(relation_base const& dst) const { + expr_ref fml(m); + dst.to_formula(fml); + return ground(dst, fml); + } + expr_ref check_relation_plugin::ground(relation_base const& dst, expr* fml) const { relation_signature const& sig = dst.get_signature(); var_subst sub(m, false); @@ -187,7 +193,7 @@ namespace datalog { check_relation const& t2 = get(r2); check_relation_plugin& p = t1.get_plugin(); relation_base* r = (*m_join)(t1.rb(), t2.rb()); - p.verify_join(r1, r2, *r, m_cols1.size(), m_cols1.c_ptr(), m_cols2.c_ptr()); + p.verify_join(r1, r2, *r, m_cols1, m_cols2); return alloc(check_relation, p, r->get_signature(), r); } }; @@ -199,6 +205,38 @@ namespace datalog { return j?alloc(join_fn, j, t1.get_signature(), t2.get_signature(), col_cnt, cols1, cols2):0; } + class check_relation_plugin::join_project_fn : public convenient_relation_join_project_fn { + scoped_ptr m_join; + public: + join_project_fn( + relation_join_fn* j, + const relation_signature & o1_sig, const relation_signature & o2_sig, unsigned col_cnt, + const unsigned * cols1, const unsigned * cols2, + unsigned removed_col_cnt, const unsigned* removed_cols) + : convenient_join_project_fn(o1_sig, o2_sig, col_cnt, cols1, cols2, + removed_col_cnt, removed_cols), m_join(j) + {} + virtual ~join_project_fn() {} + virtual relation_base * operator()(const relation_base & r1, const relation_base & r2) { + check_relation const& t1 = get(r1); + check_relation const& t2 = get(r2); + check_relation_plugin& p = t1.get_plugin(); + relation_base* r = (*m_join)(t1.rb(), t2.rb()); + p.verify_join_project(r1, r2, *r, m_cols1, m_cols2, m_removed_cols); + return alloc(check_relation, p, r->get_signature(), r); + } + }; + + relation_join_fn * check_relation_plugin::mk_join_project_fn( + const relation_base & t1, const relation_base & t2, + unsigned col_cnt, const unsigned * cols1, const unsigned * cols2, + unsigned removed_col_cnt, const unsigned * removed_cols) { + relation_join_fn* j = m_base->mk_join_project_fn(get(t1).rb(), get(t2).rb(), col_cnt, cols1, cols2, + removed_col_cnt, removed_cols); + return j?alloc(join_project_fn, j, t1.get_signature(), t2.get_signature(), col_cnt, cols1, cols2, + removed_col_cnt, removed_cols):0; + } + void check_relation_plugin::verify_filter_project( relation_base const& src, relation_base const& dst, app* cond, unsigned_vector const& removed_cols) { @@ -222,39 +260,88 @@ namespace datalog { relation_base const& src, expr* f1, relation_base const& dst, expr* f2, unsigned_vector const& removed_cols) { + expr_ref fml1 = ground(dst, mk_project(src.get_signature(), f1, removed_cols)); + expr_ref fml2 = ground(dst, f2); + check_equiv("project", fml1, fml2); + } + expr_ref check_relation_plugin::mk_project( + relation_signature const& sig, + expr* fml, unsigned_vector const& removed_cols) { expr_ref fml1(m); - expr_ref fml2(m); - expr_ref_vector vars1(m), vars2(m); ptr_vector bound; svector names; - relation_signature const& sig1 = src.get_signature(); - relation_signature const& sig2 = dst.get_signature(); - for (unsigned i = 0; i < sig2.size(); ++i) { - vars2.push_back(m.mk_const(symbol(i), sig2[i])); - } - for (unsigned i = 0, j = 0, k = 0; i < sig1.size(); ++i) { - if (j < removed_cols.size() && removed_cols[j] == i) { + expr_ref_vector vars(m); + unsigned rm_cnt = removed_cols.size(); + for (unsigned i = 0, j = 0, k = 0; i < sig.size(); ++i) { + if (j < rm_cnt && removed_cols[j] == i) { std::ostringstream strm; strm << "x" << j; - bound.push_back(sig1[i]); + bound.push_back(sig[i]); names.push_back(symbol(strm.str().c_str())); - vars1.push_back(m.mk_var(j, sig1[i])); + vars.push_back(m.mk_var(j, sig[i])); ++j; } else { - vars1.push_back(vars2[k].get()); - SASSERT(m.get_sort(vars2[k].get()) == sig1[i]); + vars.push_back(m.mk_var(k + rm_cnt, sig[i])); ++k; } } var_subst sub(m, false); - sub(f1, vars1.size(), vars1.c_ptr(), fml1); - sub(f2, vars2.size(), vars2.c_ptr(), fml2); + sub(fml, vars.size(), vars.c_ptr(), fml1); bound.reverse(); fml1 = m.mk_exists(bound.size(), bound.c_ptr(), names.c_ptr(), fml1); - check_equiv("project", fml1, fml2); + return fml1; } + void check_relation_plugin::verify_join_project( + relation_base const& t1, relation_base const& t2, relation_base const& t, + unsigned_vector const& cols1, unsigned_vector const& cols2, unsigned_vector const& rm_cols) { + ast_manager& m = get_ast_manager(); + relation_signature const& sig2 = t.get_signature(); + relation_signature const& sigA = t1.get_signature(); + relation_signature const& sigB = t2.get_signature(); + relation_signature sig1; + sig1.append(sigA); + sig1.append(sigB); + + expr_ref fml1 = mk_join(t1, t2, cols1, cols2); + fml1 = mk_project(sig1, fml1, rm_cols); + fml1 = ground(t, fml1); + expr_ref fml2(m); + t.to_formula(fml2); + fml2 = ground(t, fml2); + check_equiv("join_project", fml1, fml2); + } + + expr_ref check_relation_plugin::mk_join( + relation_base const& t1, relation_base const& t2, + unsigned_vector const& cols1, unsigned_vector const& cols2) { + ast_manager& m = get_ast_manager(); + expr_ref fml1(m), fml2(m), fml3(m); + + relation_signature const& sig1 = t1.get_signature(); + relation_signature const& sig2 = t2.get_signature(); + var_ref var1(m), var2(m); + t1.to_formula(fml1); + t2.to_formula(fml2); + var_subst sub(m, false); + expr_ref_vector vars(m); + for (unsigned i = 0; i < sig2.size(); ++i) { + vars.push_back(m.mk_var(i + sig1.size(), sig2[i])); + } + sub(fml2, vars.size(), vars.c_ptr(), fml2); + fml1 = m.mk_and(fml1, fml2); + for (unsigned i = 0; i < cols1.size(); ++i) { + unsigned v1 = cols1[i]; + unsigned v2 = cols2[i]; + var1 = m.mk_var(v1, sig1[v1]); + var2 = m.mk_var(v2 + sig1.size(), sig2[v2]); + fml1 = m.mk_and(m.mk_eq(var1, var2), fml1); + } + return fml1; + } + + void check_relation_plugin::verify_permutation( relation_base const& src, relation_base const& dst, unsigned_vector const& cycle) { @@ -293,41 +380,13 @@ namespace datalog { check_equiv("permutation", fml1, fml2); } - void check_relation_plugin::verify_join(relation_base const& t1, relation_base const& t2, relation_base const& t, - unsigned sz, unsigned const* cols1, unsigned const* cols2) { + void check_relation_plugin::verify_join( + relation_base const& t1, relation_base const& t2, relation_base const& t, + unsigned_vector const& cols1, unsigned_vector const& cols2) { ast_manager& m = get_ast_manager(); - expr_ref fml1(m), fml2(m), fml3(m); - - relation_signature const& sig1 = t1.get_signature(); - relation_signature const& sig2 = t2.get_signature(); - relation_signature const& sig = t.get_signature(); - var_ref var1(m), var2(m); - t1.to_formula(fml1); - t2.to_formula(fml2); - t.to_formula(fml3); - var_subst sub(m, false); - expr_ref_vector vars(m); - for (unsigned i = 0; i < sig2.size(); ++i) { - vars.push_back(m.mk_var(i + sig1.size(), sig2[i])); - } - sub(fml2, vars.size(), vars.c_ptr(), fml2); - fml1 = m.mk_and(fml1, fml2); - for (unsigned i = 0; i < sz; ++i) { - unsigned v1 = cols1[i]; - unsigned v2 = cols2[i]; - var1 = m.mk_var(v1, sig1[v1]); - var2 = m.mk_var(v2 + sig1.size(), sig2[v2]); - fml1 = m.mk_and(m.mk_eq(var1, var2), fml1); - } - vars.reset(); - for (unsigned i = 0; i < sig.size(); ++i) { - std::stringstream strm; - strm << "x" << i; - vars.push_back(m.mk_const(symbol(strm.str().c_str()), sig[i])); - } - sub(fml1, vars.size(), vars.c_ptr(), fml1); - sub(fml3, vars.size(), vars.c_ptr(), fml3); - check_equiv("join", fml1, fml3); + expr_ref fml1 = ground(t, mk_join(t1, t2, cols1, cols2)); + expr_ref fml2 = ground(t); + check_equiv("join", fml1, fml2); } void check_relation_plugin::verify_filter(expr* fml0, relation_base const& t, expr* cond) { diff --git a/src/muz/rel/check_relation.h b/src/muz/rel/check_relation.h index 580ce3dd4..8000a2c72 100644 --- a/src/muz/rel/check_relation.h +++ b/src/muz/rel/check_relation.h @@ -62,6 +62,7 @@ namespace datalog { friend class check_relation; class join_fn; + class join_project_fn; class project_fn; class union_fn; class rename_fn; @@ -78,6 +79,15 @@ namespace datalog { static check_relation* get(relation_base* r); static check_relation const & get(relation_base const& r); expr_ref ground(relation_base const& rb, expr* fml) const; + expr_ref ground(relation_base const& rb) const; + + expr_ref mk_project( + relation_signature const& sig, + expr* fml, unsigned_vector const& removed_cols); + + expr_ref mk_join( + relation_base const& t1, relation_base const& t2, + unsigned_vector const& cols1, unsigned_vector const& cols2); public: check_relation_plugin(relation_manager& rm); ~check_relation_plugin(); @@ -89,6 +99,10 @@ namespace datalog { virtual relation_base * mk_full(func_decl* p, const relation_signature & s); virtual relation_join_fn * mk_join_fn(const relation_base & t1, const relation_base & t2, unsigned col_cnt, const unsigned * cols1, const unsigned * cols2); + virtual relation_join_fn * mk_join_project_fn( + const relation_base & t1, const relation_base & t2, + unsigned col_cnt, const unsigned * cols1, const unsigned * cols2, + unsigned removed_col_cnt, const unsigned * removed_cols); virtual relation_transformer_fn * mk_project_fn(const relation_base & t, unsigned col_cnt, const unsigned * removed_cols); virtual relation_transformer_fn * mk_rename_fn(const relation_base & t, unsigned permutation_cycle_len, @@ -111,7 +125,8 @@ namespace datalog { unsigned removed_col_cnt, const unsigned * removed_cols); void verify_join(relation_base const& t1, relation_base const& t2, relation_base const& t, - unsigned sz, unsigned const* cols1, unsigned const* cols2); + unsigned_vector const& cols1, unsigned_vector const& cols2); + void verify_filter(expr* fml0, relation_base const& t, expr* cond); @@ -136,6 +151,10 @@ namespace datalog { relation_base const& src, relation_base const& dst, app* cond, unsigned_vector const& removed_cols); + void verify_join_project( + relation_base const& t1, relation_base const& t2, relation_base const& t, + unsigned_vector const& cols1, unsigned_vector const& cols2, unsigned_vector const& rm_cols); + void check_equiv(char const* objective, expr* f1, expr* f2); void check_contains(char const* objective, expr* f1, expr* f2); diff --git a/src/test/udoc_relation.cpp b/src/test/udoc_relation.cpp index 3fc727d14..24ccb34b0 100644 --- a/src/test/udoc_relation.cpp +++ b/src/test/udoc_relation.cpp @@ -456,7 +456,7 @@ public: join_fn = p.mk_join_fn(*t1, *t2, jc1.size(), jc1.c_ptr(), jc2.c_ptr()); t = (*join_fn)(*t1, *t2); - cr.verify_join(*t1, *t2, *t, jc1.size(), jc1.c_ptr(), jc2.c_ptr()); + cr.verify_join(*t1, *t2, *t, jc1, jc2); t1->display(std::cout); t2->display(std::cout); t->display(std::cout);