3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-19 04:13:38 +00:00

update join planner to take projected columns into account

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2013-09-24 06:41:46 +03:00
parent 1733af2641
commit 2d01c4d50f

View file

@ -61,9 +61,6 @@ namespace datalog {
} }
cost get_cost() const { cost get_cost() const {
/*if(m_instantiated) {
return std::numeric_limits<cost>::min();
}*/
SASSERT(m_consumers > 0); SASSERT(m_consumers > 0);
cost amortized = m_total_cost/m_consumers; cost amortized = m_total_cost/m_consumers;
if (m_stratified) { if (m_stratified) {
@ -81,13 +78,14 @@ namespace datalog {
by the time of a call to this function by the time of a call to this function
*/ */
void add_rule(join_planner & pl, app * t1, app * t2, rule * r, void add_rule(join_planner & pl, app * t1, app * t2, rule * r,
const var_idx_set & non_local_vars_normalized) { const var_idx_set & non_local_vars_normalized,
const var_idx_set & non_local_vars) {
if (m_rules.empty()) { if (m_rules.empty()) {
m_total_cost = pl.compute_cost(t1, t2); m_total_cost = pl.compute_cost(t1, t2, non_local_vars);
m_src_stratum = std::max(pl.get_stratum(t1->get_decl()), pl.get_stratum(t2->get_decl())); m_src_stratum = std::max(pl.get_stratum(t1->get_decl()), pl.get_stratum(t2->get_decl()));
} }
m_rules.push_back(r); m_rules.push_back(r);
if(pl.m_rules_content.find_core(r)->get_data().m_value.size()>2) { if (pl.m_rules_content.find(r).size()>2) {
m_consumers++; m_consumers++;
} }
if (m_stratified) { if (m_stratified) {
@ -274,8 +272,6 @@ namespace datalog {
by the time of a call to this function by the time of a call to this function
*/ */
void register_pair(app * t1, app * t2, rule * r, const var_idx_set & non_local_vars) { void register_pair(app * t1, app * t2, rule * r, const var_idx_set & non_local_vars) {
TRACE("dl", tout << mk_pp(t1, m) << " " << mk_pp(t2, m) << "\n";
r->display(m_context, tout); tout << "\n";);
SASSERT(t1!=t2); SASSERT(t1!=t2);
cost_map::entry * e = m_costs.insert_if_not_there2(get_key(t1, t2), 0); cost_map::entry * e = m_costs.insert_if_not_there2(get_key(t1, t2), 0);
pair_info * & ptr_inf = e->get_data().m_value; pair_info * & ptr_inf = e->get_data().m_value;
@ -295,13 +291,18 @@ namespace datalog {
normalized_vars.insert(norm_var); normalized_vars.insert(norm_var);
} }
inf.add_rule(*this, t1, t2, r, normalized_vars); inf.add_rule(*this, t1, t2, r, normalized_vars, non_local_vars);
TRACE("dl", tout << mk_pp(t1, m) << " " << mk_pp(t2, m) << " ";
vit = non_local_vars.begin();
for (; vit != vend; ++vit) tout << *vit << " ";
tout << "\n";
r->display(m_context, tout);
if (inf.can_be_joined()) tout << "cost: " << inf.get_cost() << "\n";);
} }
pair_info & get_pair(app_pair key) const { pair_info & get_pair(app_pair key) const {
cost_map::entry * e = m_costs.find_core(key); return *m_costs.find(key);
SASSERT(e);
return *e->get_data().m_value;
} }
void remove_rule_from_pair(app_pair key, rule * r, unsigned original_len) { void remove_rule_from_pair(app_pair key, rule * r, unsigned original_len) {
@ -503,7 +504,7 @@ namespace datalog {
void apply_binary_rule(rule * r, app_pair pair_key, app * t_new) { void apply_binary_rule(rule * r, app_pair pair_key, app * t_new) {
app * t1 = pair_key.first; app * t1 = pair_key.first;
app * t2 = pair_key.second; app * t2 = pair_key.second;
ptr_vector<app> & rule_content = m_rules_content.find_core(r)->get_data().m_value; ptr_vector<app> & rule_content = m_rules_content.find(r);
unsigned len = rule_content.size(); unsigned len = rule_content.size();
if (len==1) { if (len==1) {
return; return;
@ -600,21 +601,39 @@ namespace datalog {
return res; return res;
} }
cost compute_cost(app * t1, app * t2) const { cost compute_cost(app * t1, app * t2, const var_idx_set & non_local_vars) const {
func_decl * t1_pred = t1->get_decl(); func_decl * t1_pred = t1->get_decl();
func_decl * t2_pred = t2->get_decl(); func_decl * t2_pred = t2->get_decl();
cost inters_size = 1; cost inters_size = 1;
variable_intersection vi(m_context.get_manager()); variable_intersection vi(m_context.get_manager());
vi.populate(t1, t2); vi.populate(t1, t2);
unsigned n = vi.size(); unsigned n = vi.size();
// remove contributions from joined columns.
for(unsigned i=0; i<n; i++) { for(unsigned i=0; i<n; i++) {
unsigned arg_index1, arg_index2; unsigned arg_index1, arg_index2;
vi.get(i, arg_index1, arg_index2); vi.get(i, arg_index1, arg_index2);
SASSERT(is_var(t1->get_arg(arg_index1)));
if (non_local_vars.contains(to_var(t1->get_arg(arg_index1))->get_idx())) {
inters_size *= get_domain_size(t1_pred, arg_index1); inters_size *= get_domain_size(t1_pred, arg_index1);
}
//joined arguments must have the same domain //joined arguments must have the same domain
SASSERT(get_domain_size(t1_pred, arg_index1)==get_domain_size(t2_pred, arg_index2)); SASSERT(get_domain_size(t1_pred, arg_index1)==get_domain_size(t2_pred, arg_index2));
} }
cost res = estimate_size(t1)*estimate_size(t2)/(inters_size*inters_size); // remove contributions from projected columns.
for (unsigned i = 0; i < t1->get_num_args(); ++i) {
if (is_var(t1->get_arg(i)) &&
!non_local_vars.contains(to_var(t1->get_arg(i))->get_idx())) {
inters_size *= get_domain_size(t1_pred, i);
}
}
for (unsigned i = 0; i < t2->get_num_args(); ++i) {
if (is_var(t2->get_arg(i)) &&
!non_local_vars.contains(to_var(t2->get_arg(i))->get_idx())) {
inters_size *= get_domain_size(t2_pred, i);
}
}
cost res = estimate_size(t1)*estimate_size(t2)/ inters_size; // (inters_size*inters_size);
//cost res = -inters_size; //cost res = -inters_size;
/*unsigned t1_strat = get_stratum(t1_pred); /*unsigned t1_strat = get_stratum(t1_pred);