From 2d01c4d50f77963028ea39b45f89896502495c5e Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 24 Sep 2013 06:41:46 +0300 Subject: [PATCH] update join planner to take projected columns into account Signed-off-by: Nikolaj Bjorner --- src/muz/rel/dl_mk_simple_joins.cpp | 145 ++++++++++++++++------------- 1 file changed, 82 insertions(+), 63 deletions(-) diff --git a/src/muz/rel/dl_mk_simple_joins.cpp b/src/muz/rel/dl_mk_simple_joins.cpp index 4b9ce582a..c2214dad9 100644 --- a/src/muz/rel/dl_mk_simple_joins.cpp +++ b/src/muz/rel/dl_mk_simple_joins.cpp @@ -47,9 +47,9 @@ namespace datalog { being notified about it, it will surely see the decrease from length 3 to 2 which the threshold for rule being counted in this counter. */ - unsigned m_consumers; - bool m_stratified; - unsigned m_src_stratum; + unsigned m_consumers; + bool m_stratified; + unsigned m_src_stratum; public: var_idx_set m_all_nonlocal_vars; rule_vector m_rules; @@ -57,16 +57,13 @@ namespace datalog { pair_info() : m_consumers(0), m_stratified(true), m_src_stratum(0) {} bool can_be_joined() const { - return m_consumers>0; + return m_consumers > 0; } cost get_cost() const { - /*if(m_instantiated) { - return std::numeric_limits::min(); - }*/ - SASSERT(m_consumers>0); + SASSERT(m_consumers > 0); cost amortized = m_total_cost/m_consumers; - if(m_stratified) { + if (m_stratified) { return amortized * ( (amortized>0) ? (1/16.0f) : 16.0f); } else { @@ -81,19 +78,20 @@ namespace datalog { by the time of a call to this function */ void add_rule(join_planner & pl, app * t1, app * t2, rule * r, - const var_idx_set & non_local_vars_normalized) { - if(m_rules.empty()) { - m_total_cost = pl.compute_cost(t1, t2); + const var_idx_set & non_local_vars_normalized, + const var_idx_set & non_local_vars) { + if (m_rules.empty()) { + m_total_cost = pl.compute_cost(t1, t2, non_local_vars); m_src_stratum = std::max(pl.get_stratum(t1->get_decl()), pl.get_stratum(t2->get_decl())); } m_rules.push_back(r); - if(pl.m_rules_content.find_core(r)->get_data().m_value.size()>2) { + if (pl.m_rules_content.find(r).size()>2) { m_consumers++; } - if(m_stratified) { + if (m_stratified) { unsigned head_stratum = pl.get_stratum(r->get_decl()); SASSERT(head_stratum>=m_src_stratum); - if(head_stratum==m_src_stratum) { + if (head_stratum==m_src_stratum) { m_stratified = false; } } @@ -105,7 +103,7 @@ namespace datalog { */ bool remove_rule(rule * r, unsigned original_length) { TRUSTME( remove_from_vector(m_rules, r) ); - if(original_length>2) { + if (original_length>2) { SASSERT(m_consumers>0); m_consumers--; } @@ -165,7 +163,7 @@ namespace datalog { SASSERT(is_var(t->get_arg(i))); var * v = to_var(t->get_arg(i)); unsigned var_idx = v->get_idx(); - if(result[res_ofs-var_idx]==0) { + if (result[res_ofs-var_idx]==0) { result[res_ofs-var_idx]=m.mk_var(next_var, v->get_sort()); next_var++; } @@ -174,7 +172,7 @@ namespace datalog { void get_normalizer(app * t1, app * t2, expr_ref_vector & result) const { SASSERT(result.empty()); - if(t1->get_num_args()==0 && t2->get_num_args()==0) { + if (t1->get_num_args()==0 && t2->get_num_args()==0) { return; //nothing to normalize } SASSERT(!t1->is_ground() || !t2->is_ground()); @@ -186,14 +184,14 @@ namespace datalog { var_idx_set::iterator ovend = orig_var_set.end(); for(; ovit!=ovend; ++ovit) { unsigned var_idx = *ovit; - if(var_idx>max_var_idx) { + if (var_idx>max_var_idx) { max_var_idx = var_idx; } } } - if(t1->get_decl()!=t2->get_decl()) { - if(t1->get_decl()->get_id()get_decl()->get_id()) { + if (t1->get_decl()!=t2->get_decl()) { + if (t1->get_decl()->get_id()get_decl()->get_id()) { std::swap(t1, t2); } } @@ -207,9 +205,9 @@ namespace datalog { //so the only literals which appear in pairs are the ones that contain only variables. var * v1 = to_var(t1->get_arg(i)); var * v2 = to_var(t2->get_arg(i)); - if(v1->get_sort()!=v2->get_sort()) { + if (v1->get_sort()!=v2->get_sort()) { //different sorts mean we can distinguish the two terms - if(v1->get_sort()->get_id()get_sort()->get_id()) { + if (v1->get_sort()->get_id()get_sort()->get_id()) { std::swap(t1, t2); } break; @@ -221,9 +219,9 @@ namespace datalog { SASSERT(norm1[v1_idx]==-1); SASSERT(norm2[v2_idx]==-1); - if(norm2[v1_idx]!=norm1[v2_idx]) { + if (norm2[v1_idx]!=norm1[v2_idx]) { //now we can distinguish the two terms - if(norm2[v1_idx]t2n) { + if (t1n>t2n) { std::swap(t1n, t2n); } m_pinned.push_back(t1n); @@ -274,12 +272,10 @@ namespace datalog { by the time of a call to this function */ void register_pair(app * t1, app * t2, rule * r, const var_idx_set & non_local_vars) { - TRACE("dl", tout << mk_pp(t1, m) << " " << mk_pp(t2, m) << "\n"; - r->display(m_context, tout); tout << "\n";); SASSERT(t1!=t2); cost_map::entry * e = m_costs.insert_if_not_there2(get_key(t1, t2), 0); pair_info * & ptr_inf = e->get_data().m_value; - if(ptr_inf==0) { + if (ptr_inf==0) { ptr_inf = alloc(pair_info); } pair_info & inf = *ptr_inf; @@ -288,25 +284,30 @@ namespace datalog { get_normalizer(t1, t2, normalizer); unsigned norm_ofs = normalizer.size()-1; var_idx_set normalized_vars; - var_idx_set::iterator vit = non_local_vars.begin(); + var_idx_set::iterator vit = non_local_vars.begin(); var_idx_set::iterator vend = non_local_vars.end(); for(; vit!=vend; ++vit) { unsigned norm_var = to_var(normalizer.get(norm_ofs-*vit))->get_idx(); normalized_vars.insert(norm_var); } - inf.add_rule(*this, t1, t2, r, normalized_vars); + inf.add_rule(*this, t1, t2, r, normalized_vars, non_local_vars); + TRACE("dl", tout << mk_pp(t1, m) << " " << mk_pp(t2, m) << " "; + vit = non_local_vars.begin(); + for (; vit != vend; ++vit) tout << *vit << " "; + tout << "\n"; + r->display(m_context, tout); + if (inf.can_be_joined()) tout << "cost: " << inf.get_cost() << "\n";); + } pair_info & get_pair(app_pair key) const { - cost_map::entry * e = m_costs.find_core(key); - SASSERT(e); - return *e->get_data().m_value; + return *m_costs.find(key); } void remove_rule_from_pair(app_pair key, rule * r, unsigned original_len) { pair_info * ptr = &get_pair(key); - if(ptr->remove_rule(r, original_len)) { + if (ptr->remove_rule(r, original_len)) { SASSERT(ptr->m_rules.empty()); m_costs.remove(key); dealloc(ptr); @@ -349,7 +350,7 @@ namespace datalog { unsigned n=t->get_num_args(); for(unsigned i=0; iget_arg(i)); - if(v->get_idx()==var_idx) { + if (v->get_idx()==var_idx) { args.push_back(v); domain.push_back(m.get_sort(v)); return true; @@ -375,7 +376,7 @@ namespace datalog { unsigned var_idx=*ovit; bool found=extract_argument_info(var_idx, t1, args, domain); - if(!found) { + if (!found) { found=extract_argument_info(var_idx, t2, args, domain); } SASSERT(found); @@ -389,7 +390,7 @@ namespace datalog { func_decl* parent_head = one_parent->get_decl(); const char * one_parent_name = parent_head->get_name().bare_str(); std::string parent_name; - if(inf.m_rules.size()>1) { + if (inf.m_rules.size()>1) { parent_name = one_parent_name + std::string("_and_") + to_string(inf.m_rules.size()-1); } else { @@ -443,7 +444,7 @@ namespace datalog { } //remove edges between surviving tails and removed tails for(unsigned i=0; i & rule_content = m_rules_content.find_core(r)->get_data().m_value; + ptr_vector & rule_content = m_rules_content.find(r); unsigned len = rule_content.size(); - if(len==1) { + if (len==1) { return; } @@ -515,16 +516,16 @@ namespace datalog { ptr_vector added_tails; for(unsigned i1=0; i1get_decl()!=t1_pred) { + if (rt1->get_decl()!=t1_pred) { continue; } unsigned i2start = (t1_pred==t2_pred) ? (i1+1) : 0; for(unsigned i2=i2start; i2get_decl()!=t2_pred) { + if (i1==i2 || rt2->get_decl()!=t2_pred) { continue; } - if(get_key(rt1, rt2)!=pair_key) { + if (get_key(rt1, rt2)!=pair_key) { continue; } expr_ref_vector normalizer(m); @@ -558,7 +559,7 @@ namespace datalog { relation_sort sort = pred->get_domain(arg_index); return static_cast(m_context.get_sort_size_estimate(sort)); //unsigned sz; - //if(!m_context.get_sort_size(sort, sz)) { + //if (!m_context.get_sort_size(sort, sz)) { // sz=UINT_MAX; //} //return static_cast(sz); @@ -576,15 +577,15 @@ namespace datalog { return cost(1); } relation_manager& rm = rel->get_rmanager(); - if( (m_context.saturation_was_run() && rm.try_get_relation(pred)) + if ( (m_context.saturation_was_run() && rm.try_get_relation(pred)) || rm.is_saturated(pred)) { SASSERT(rm.try_get_relation(pred)); //if it is saturated, it should exist unsigned rel_size_int = rel->get_relation(pred).get_size_estimate_rows(); - if(rel_size_int!=0) { + if (rel_size_int!=0) { cost rel_size = static_cast(rel_size_int); cost curr_size = rel_size; for(unsigned i=0; iget_arg(i))) { + if (!is_var(t->get_arg(i))) { curr_size /= get_domain_size(pred, i); } } @@ -593,40 +594,58 @@ namespace datalog { } cost res = 1; for(unsigned i=0; iget_arg(i))) { + if (is_var(t->get_arg(i))) { res *= get_domain_size(pred, i); } } return res; } - cost compute_cost(app * t1, app * t2) const { + cost compute_cost(app * t1, app * t2, const var_idx_set & non_local_vars) const { func_decl * t1_pred = t1->get_decl(); func_decl * t2_pred = t2->get_decl(); cost inters_size = 1; variable_intersection vi(m_context.get_manager()); vi.populate(t1, t2); unsigned n = vi.size(); + // remove contributions from joined columns. for(unsigned i=0; iget_arg(arg_index1))); + if (non_local_vars.contains(to_var(t1->get_arg(arg_index1))->get_idx())) { + inters_size *= get_domain_size(t1_pred, arg_index1); + } //joined arguments must have the same domain SASSERT(get_domain_size(t1_pred, arg_index1)==get_domain_size(t2_pred, arg_index2)); } - cost res = estimate_size(t1)*estimate_size(t2)/(inters_size*inters_size); + // remove contributions from projected columns. + for (unsigned i = 0; i < t1->get_num_args(); ++i) { + if (is_var(t1->get_arg(i)) && + !non_local_vars.contains(to_var(t1->get_arg(i))->get_idx())) { + inters_size *= get_domain_size(t1_pred, i); + } + } + for (unsigned i = 0; i < t2->get_num_args(); ++i) { + if (is_var(t2->get_arg(i)) && + !non_local_vars.contains(to_var(t2->get_arg(i))->get_idx())) { + inters_size *= get_domain_size(t2_pred, i); + } + } + + cost res = estimate_size(t1)*estimate_size(t2)/ inters_size; // (inters_size*inters_size); //cost res = -inters_size; /*unsigned t1_strat = get_stratum(t1_pred); SASSERT(t1_strat<=m_head_stratum); - if(t1_strat0) { + if (res>0) { res /= 2; } else { @@ -653,17 +672,17 @@ namespace datalog { for(; it!=end; ++it) { app_pair key = it->m_key; pair_info & inf = *it->m_value; - if(!inf.can_be_joined()) { + if (!inf.can_be_joined()) { continue; } cost c = inf.get_cost(); - if(!found || cm_key; ptr_vector content = rcit->m_value; SASSERT(content.size()<=2); - if(content.size()==orig_r->get_positive_tail_size()) { + if (content.size()==orig_r->get_positive_tail_size()) { //rule did not change result->add_rule(orig_r); continue; @@ -728,7 +747,7 @@ namespace datalog { rule_set * mk_simple_joins::operator()(rule_set const & source) { rule_set rs_aux_copy(m_context); rs_aux_copy.replace_rules(source); - if(!rs_aux_copy.is_closed()) { + if (!rs_aux_copy.is_closed()) { rs_aux_copy.close(); }