diff --git a/src/muz/rel/udoc_relation.cpp b/src/muz/rel/udoc_relation.cpp index ea04e4263..28c9bb9a4 100644 --- a/src/muz/rel/udoc_relation.cpp +++ b/src/muz/rel/udoc_relation.cpp @@ -1057,20 +1057,50 @@ namespace datalog { } class udoc_plugin::join_project_fn : convenient_relation_join_project_fn { - udoc_plugin::join_fn m_joiner; + udoc_plugin::join_fn m_joiner; + union_find_default_ctx union_ctx; + bit_vector m_to_delete; + subset_ints m_equalities; + unsigned_vector m_roots; + public: join_project_fn( udoc_relation const& t1, udoc_relation const& t2, unsigned col_cnt, const unsigned * cols1, const unsigned * cols2, - unsigned removed_col_cnt, unsigned const* removed_cols) + unsigned removed_col_cnt, unsigned const* rm_cols) : convenient_relation_join_project_fn( t1.get_signature(), t2.get_signature(), col_cnt, cols1, cols2, - removed_col_cnt, removed_cols), - m_joiner(t1.get_plugin(), t1, t2, col_cnt, cols1, cols2) + removed_col_cnt, rm_cols), + m_joiner(t1.get_plugin(), t1, t2, col_cnt, cols1, cols2), + m_equalities(union_ctx) { + udoc_plugin& p = t1.get_plugin(); + udoc_relation* res = get(p.mk_empty(get_result_signature())); + unsigned num_bits = res->get_num_bits(); + unsigned num_bits1 = t1.get_num_bits(); + unsigned_vector removed_cols(removed_col_cnt, rm_cols); + unsigned_vector expcols1(col_cnt, cols1); + unsigned_vector expcols2(col_cnt, cols2); + res->expand_column_vector(removed_cols); + t1.expand_column_vector(expcols1); + t2.expand_column_vector(expcols2); + m_to_delete.resize(num_bits, false); + for (unsigned i = 0; i < num_bits; ++i) { + m_equalities.mk_var(); + } + for (unsigned i = 0; i < removed_cols.size(); ++i) { + m_to_delete.set(removed_cols[i], true); + } + for (unsigned i = 0; i < expcols1.size(); ++i) { + m_equalities.merge(expcols1[i], expcols2[i] + num_bits1); + } + m_roots.append(expcols1); + res->deallocate(); } + + // TBD: replace this by "join" given below. virtual relation_base* operator()(relation_base const& t1, relation_base const& t2) { udoc_relation *joined = get(m_joiner(t1, t2)); relation_base* result = 0; @@ -1084,6 +1114,57 @@ namespace datalog { joined->deallocate(); return result; } + private: + + udoc_relation* join(udoc_relation const& t1, udoc_relation const& t2) { + relation_signature prod_signature; + prod_signature.append(t1.get_signature()); + prod_signature.append(t2.get_signature()); + udoc prod; + udoc const& d1 = t1.get_udoc(); + udoc const& d2 = t2.get_udoc(); + doc_manager& dm1 = t1.get_dm(); + udoc_plugin& p = t1.get_plugin(); + doc_manager& dm_prod = p.dm(prod_signature); + udoc_relation* result = get(p.mk_empty(get_result_signature())); + udoc& res = result->get_udoc(); + doc_manager& dm_res = result->get_dm(); + for (unsigned i = 0; i < d1.size(); ++i) { + for (unsigned j = 0; j < d2.size(); ++j) { + prod.push_back(xprod(dm_prod, dm1, d1[i], d2[j])); + } + } + prod.merge(dm_prod, m_roots, m_equalities, m_to_delete); + for (unsigned i = 0; i < prod.size(); ++i) { + res.insert(dm_res, dm_prod.project(dm_res, m_to_delete.size(), m_to_delete, prod[i])); + } + prod.reset(dm_prod); + return result; + } + doc* xprod(doc_manager& dm, doc_manager& dm1, doc const& d1, doc const& d2) { + tbv_manager& tbm = dm.tbvm(); + doc_ref d(dm); + tbv_ref t(tbm); + d = dm.allocateX(); + tbv& pos = d->pos(); + utbv& neg = d->neg(); + unsigned mid = dm1.num_tbits(); + unsigned hi = dm.num_tbits(); + tbm.set(pos,d1.pos(), mid-1, 0); + tbm.set(pos,d2.pos(), hi-1, mid); + for (unsigned i = 0; i < d1.neg().size(); ++i) { + t = tbm.allocateX(); + tbm.set(*t, d1.neg()[i], mid-1, 0); + neg.push_back(t.detach()); + } + for (unsigned i = 0; i < d2.neg().size(); ++i) { + t = tbm.allocateX(); + tbm.set(*t, d2.neg()[i], hi-1, mid); + neg.push_back(t.detach()); + } + return d.detach(); + } + }; relation_join_fn * udoc_plugin::mk_join_project_fn( @@ -1138,13 +1219,12 @@ namespace datalog { m_is_subtract(false), m_is_aliased(true) { SASSERT(joined_col_cnt > 0); - if (joined_col_cnt == r.get_signature().size()) { - m_is_subtract = true; - svector found(joined_col_cnt, false); - for (unsigned i = 0; m_is_subtract && i < joined_col_cnt; ++i) { - m_is_subtract = !found[t_cols[i]] && (t_cols[i] == neg_cols[i]); - found[t_cols[i]] = true; - } + m_is_subtract = (joined_col_cnt == r.get_signature().size()); + m_is_subtract &= (joined_col_cnt == neg.get_signature().size()); + svector found(joined_col_cnt, false); + for (unsigned i = 0; m_is_subtract && i < joined_col_cnt; ++i) { + m_is_subtract = !found[t_cols[i]] && (t_cols[i] == neg_cols[i]); + found[t_cols[i]] = true; } r.expand_column_vector(m_t_cols); neg.expand_column_vector(m_neg_cols);