3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-15 05:18:44 +00:00

updated check_relation test for join-project

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2014-10-06 13:05:53 -07:00
parent 893d51eae8
commit 7ef311acd3
3 changed files with 131 additions and 53 deletions

View file

@ -41,6 +41,12 @@ namespace datalog {
return get_plugin().ground(*this, fml);
}
expr_ref check_relation_plugin::ground(relation_base const& dst) const {
expr_ref fml(m);
dst.to_formula(fml);
return ground(dst, fml);
}
expr_ref check_relation_plugin::ground(relation_base const& dst, expr* fml) const {
relation_signature const& sig = dst.get_signature();
var_subst sub(m, false);
@ -187,7 +193,7 @@ namespace datalog {
check_relation const& t2 = get(r2);
check_relation_plugin& p = t1.get_plugin();
relation_base* r = (*m_join)(t1.rb(), t2.rb());
p.verify_join(r1, r2, *r, m_cols1.size(), m_cols1.c_ptr(), m_cols2.c_ptr());
p.verify_join(r1, r2, *r, m_cols1, m_cols2);
return alloc(check_relation, p, r->get_signature(), r);
}
};
@ -199,6 +205,38 @@ namespace datalog {
return j?alloc(join_fn, j, t1.get_signature(), t2.get_signature(), col_cnt, cols1, cols2):0;
}
class check_relation_plugin::join_project_fn : public convenient_relation_join_project_fn {
scoped_ptr<relation_join_fn> m_join;
public:
join_project_fn(
relation_join_fn* j,
const relation_signature & o1_sig, const relation_signature & o2_sig, unsigned col_cnt,
const unsigned * cols1, const unsigned * cols2,
unsigned removed_col_cnt, const unsigned* removed_cols)
: convenient_join_project_fn(o1_sig, o2_sig, col_cnt, cols1, cols2,
removed_col_cnt, removed_cols), m_join(j)
{}
virtual ~join_project_fn() {}
virtual relation_base * operator()(const relation_base & r1, const relation_base & r2) {
check_relation const& t1 = get(r1);
check_relation const& t2 = get(r2);
check_relation_plugin& p = t1.get_plugin();
relation_base* r = (*m_join)(t1.rb(), t2.rb());
p.verify_join_project(r1, r2, *r, m_cols1, m_cols2, m_removed_cols);
return alloc(check_relation, p, r->get_signature(), r);
}
};
relation_join_fn * check_relation_plugin::mk_join_project_fn(
const relation_base & t1, const relation_base & t2,
unsigned col_cnt, const unsigned * cols1, const unsigned * cols2,
unsigned removed_col_cnt, const unsigned * removed_cols) {
relation_join_fn* j = m_base->mk_join_project_fn(get(t1).rb(), get(t2).rb(), col_cnt, cols1, cols2,
removed_col_cnt, removed_cols);
return j?alloc(join_project_fn, j, t1.get_signature(), t2.get_signature(), col_cnt, cols1, cols2,
removed_col_cnt, removed_cols):0;
}
void check_relation_plugin::verify_filter_project(
relation_base const& src, relation_base const& dst,
app* cond, unsigned_vector const& removed_cols) {
@ -222,39 +260,88 @@ namespace datalog {
relation_base const& src, expr* f1,
relation_base const& dst, expr* f2,
unsigned_vector const& removed_cols) {
expr_ref fml1 = ground(dst, mk_project(src.get_signature(), f1, removed_cols));
expr_ref fml2 = ground(dst, f2);
check_equiv("project", fml1, fml2);
}
expr_ref check_relation_plugin::mk_project(
relation_signature const& sig,
expr* fml, unsigned_vector const& removed_cols) {
expr_ref fml1(m);
expr_ref fml2(m);
expr_ref_vector vars1(m), vars2(m);
ptr_vector<sort> bound;
svector<symbol> names;
relation_signature const& sig1 = src.get_signature();
relation_signature const& sig2 = dst.get_signature();
for (unsigned i = 0; i < sig2.size(); ++i) {
vars2.push_back(m.mk_const(symbol(i), sig2[i]));
}
for (unsigned i = 0, j = 0, k = 0; i < sig1.size(); ++i) {
if (j < removed_cols.size() && removed_cols[j] == i) {
expr_ref_vector vars(m);
unsigned rm_cnt = removed_cols.size();
for (unsigned i = 0, j = 0, k = 0; i < sig.size(); ++i) {
if (j < rm_cnt && removed_cols[j] == i) {
std::ostringstream strm;
strm << "x" << j;
bound.push_back(sig1[i]);
bound.push_back(sig[i]);
names.push_back(symbol(strm.str().c_str()));
vars1.push_back(m.mk_var(j, sig1[i]));
vars.push_back(m.mk_var(j, sig[i]));
++j;
}
else {
vars1.push_back(vars2[k].get());
SASSERT(m.get_sort(vars2[k].get()) == sig1[i]);
vars.push_back(m.mk_var(k + rm_cnt, sig[i]));
++k;
}
}
var_subst sub(m, false);
sub(f1, vars1.size(), vars1.c_ptr(), fml1);
sub(f2, vars2.size(), vars2.c_ptr(), fml2);
sub(fml, vars.size(), vars.c_ptr(), fml1);
bound.reverse();
fml1 = m.mk_exists(bound.size(), bound.c_ptr(), names.c_ptr(), fml1);
check_equiv("project", fml1, fml2);
return fml1;
}
void check_relation_plugin::verify_join_project(
relation_base const& t1, relation_base const& t2, relation_base const& t,
unsigned_vector const& cols1, unsigned_vector const& cols2, unsigned_vector const& rm_cols) {
ast_manager& m = get_ast_manager();
relation_signature const& sig2 = t.get_signature();
relation_signature const& sigA = t1.get_signature();
relation_signature const& sigB = t2.get_signature();
relation_signature sig1;
sig1.append(sigA);
sig1.append(sigB);
expr_ref fml1 = mk_join(t1, t2, cols1, cols2);
fml1 = mk_project(sig1, fml1, rm_cols);
fml1 = ground(t, fml1);
expr_ref fml2(m);
t.to_formula(fml2);
fml2 = ground(t, fml2);
check_equiv("join_project", fml1, fml2);
}
expr_ref check_relation_plugin::mk_join(
relation_base const& t1, relation_base const& t2,
unsigned_vector const& cols1, unsigned_vector const& cols2) {
ast_manager& m = get_ast_manager();
expr_ref fml1(m), fml2(m), fml3(m);
relation_signature const& sig1 = t1.get_signature();
relation_signature const& sig2 = t2.get_signature();
var_ref var1(m), var2(m);
t1.to_formula(fml1);
t2.to_formula(fml2);
var_subst sub(m, false);
expr_ref_vector vars(m);
for (unsigned i = 0; i < sig2.size(); ++i) {
vars.push_back(m.mk_var(i + sig1.size(), sig2[i]));
}
sub(fml2, vars.size(), vars.c_ptr(), fml2);
fml1 = m.mk_and(fml1, fml2);
for (unsigned i = 0; i < cols1.size(); ++i) {
unsigned v1 = cols1[i];
unsigned v2 = cols2[i];
var1 = m.mk_var(v1, sig1[v1]);
var2 = m.mk_var(v2 + sig1.size(), sig2[v2]);
fml1 = m.mk_and(m.mk_eq(var1, var2), fml1);
}
return fml1;
}
void check_relation_plugin::verify_permutation(
relation_base const& src, relation_base const& dst,
unsigned_vector const& cycle) {
@ -293,41 +380,13 @@ namespace datalog {
check_equiv("permutation", fml1, fml2);
}
void check_relation_plugin::verify_join(relation_base const& t1, relation_base const& t2, relation_base const& t,
unsigned sz, unsigned const* cols1, unsigned const* cols2) {
void check_relation_plugin::verify_join(
relation_base const& t1, relation_base const& t2, relation_base const& t,
unsigned_vector const& cols1, unsigned_vector const& cols2) {
ast_manager& m = get_ast_manager();
expr_ref fml1(m), fml2(m), fml3(m);
relation_signature const& sig1 = t1.get_signature();
relation_signature const& sig2 = t2.get_signature();
relation_signature const& sig = t.get_signature();
var_ref var1(m), var2(m);
t1.to_formula(fml1);
t2.to_formula(fml2);
t.to_formula(fml3);
var_subst sub(m, false);
expr_ref_vector vars(m);
for (unsigned i = 0; i < sig2.size(); ++i) {
vars.push_back(m.mk_var(i + sig1.size(), sig2[i]));
}
sub(fml2, vars.size(), vars.c_ptr(), fml2);
fml1 = m.mk_and(fml1, fml2);
for (unsigned i = 0; i < sz; ++i) {
unsigned v1 = cols1[i];
unsigned v2 = cols2[i];
var1 = m.mk_var(v1, sig1[v1]);
var2 = m.mk_var(v2 + sig1.size(), sig2[v2]);
fml1 = m.mk_and(m.mk_eq(var1, var2), fml1);
}
vars.reset();
for (unsigned i = 0; i < sig.size(); ++i) {
std::stringstream strm;
strm << "x" << i;
vars.push_back(m.mk_const(symbol(strm.str().c_str()), sig[i]));
}
sub(fml1, vars.size(), vars.c_ptr(), fml1);
sub(fml3, vars.size(), vars.c_ptr(), fml3);
check_equiv("join", fml1, fml3);
expr_ref fml1 = ground(t, mk_join(t1, t2, cols1, cols2));
expr_ref fml2 = ground(t);
check_equiv("join", fml1, fml2);
}
void check_relation_plugin::verify_filter(expr* fml0, relation_base const& t, expr* cond) {

View file

@ -62,6 +62,7 @@ namespace datalog {
friend class check_relation;
class join_fn;
class join_project_fn;
class project_fn;
class union_fn;
class rename_fn;
@ -78,6 +79,15 @@ namespace datalog {
static check_relation* get(relation_base* r);
static check_relation const & get(relation_base const& r);
expr_ref ground(relation_base const& rb, expr* fml) const;
expr_ref ground(relation_base const& rb) const;
expr_ref mk_project(
relation_signature const& sig,
expr* fml, unsigned_vector const& removed_cols);
expr_ref mk_join(
relation_base const& t1, relation_base const& t2,
unsigned_vector const& cols1, unsigned_vector const& cols2);
public:
check_relation_plugin(relation_manager& rm);
~check_relation_plugin();
@ -89,6 +99,10 @@ namespace datalog {
virtual relation_base * mk_full(func_decl* p, const relation_signature & s);
virtual relation_join_fn * mk_join_fn(const relation_base & t1, const relation_base & t2,
unsigned col_cnt, const unsigned * cols1, const unsigned * cols2);
virtual relation_join_fn * mk_join_project_fn(
const relation_base & t1, const relation_base & t2,
unsigned col_cnt, const unsigned * cols1, const unsigned * cols2,
unsigned removed_col_cnt, const unsigned * removed_cols);
virtual relation_transformer_fn * mk_project_fn(const relation_base & t, unsigned col_cnt,
const unsigned * removed_cols);
virtual relation_transformer_fn * mk_rename_fn(const relation_base & t, unsigned permutation_cycle_len,
@ -111,7 +125,8 @@ namespace datalog {
unsigned removed_col_cnt, const unsigned * removed_cols);
void verify_join(relation_base const& t1, relation_base const& t2, relation_base const& t,
unsigned sz, unsigned const* cols1, unsigned const* cols2);
unsigned_vector const& cols1, unsigned_vector const& cols2);
void verify_filter(expr* fml0, relation_base const& t, expr* cond);
@ -136,6 +151,10 @@ namespace datalog {
relation_base const& src, relation_base const& dst,
app* cond, unsigned_vector const& removed_cols);
void verify_join_project(
relation_base const& t1, relation_base const& t2, relation_base const& t,
unsigned_vector const& cols1, unsigned_vector const& cols2, unsigned_vector const& rm_cols);
void check_equiv(char const* objective, expr* f1, expr* f2);
void check_contains(char const* objective, expr* f1, expr* f2);

View file

@ -456,7 +456,7 @@ public:
join_fn = p.mk_join_fn(*t1, *t2, jc1.size(), jc1.c_ptr(), jc2.c_ptr());
t = (*join_fn)(*t1, *t2);
cr.verify_join(*t1, *t2, *t, jc1.size(), jc1.c_ptr(), jc2.c_ptr());
cr.verify_join(*t1, *t2, *t, jc1, jc2);
t1->display(std::cout);
t2->display(std::cout);
t->display(std::cout);