diff --git a/src/ast/ast_smt2_pp.cpp b/src/ast/ast_smt2_pp.cpp index 9917a24af..0c818b848 100644 --- a/src/ast/ast_smt2_pp.cpp +++ b/src/ast/ast_smt2_pp.cpp @@ -1372,6 +1372,12 @@ std::ostream& operator<<(std::ostream& out, expr_ref_vector const& e) { return ast_smt2_pp(out, e.size(), e.c_ptr(), env, p, 0, 0, nullptr); } +std::ostream& operator<<(std::ostream& out, var_ref_vector const& e) { + smt2_pp_environment_dbg env(e.get_manager()); + params_ref p; + return ast_smt2_pp(out, e.size(), (expr*const*)e.c_ptr(), env, p, 0, 0, nullptr); +} + std::ostream& operator<<(std::ostream& out, app_ref_vector const& e) { smt2_pp_environment_dbg env(e.get_manager()); params_ref p; diff --git a/src/ast/ast_smt2_pp.h b/src/ast/ast_smt2_pp.h index dc306c6df..8debbed87 100644 --- a/src/ast/ast_smt2_pp.h +++ b/src/ast/ast_smt2_pp.h @@ -133,6 +133,7 @@ std::ostream& operator<<(std::ostream& out, sort_ref const& e); std::ostream& operator<<(std::ostream& out, expr_ref_vector const& e); std::ostream& operator<<(std::ostream& out, app_ref_vector const& e); +std::ostream& operator<<(std::ostream& out, var_ref_vector const& e); std::ostream& operator<<(std::ostream& out, func_decl_ref_vector const& e); std::ostream& operator<<(std::ostream& out, sort_ref_vector const& e); diff --git a/src/ast/rewriter/var_subst.h b/src/ast/rewriter/var_subst.h index 546f86b92..e4e1b57b9 100644 --- a/src/ast/rewriter/var_subst.h +++ b/src/ast/rewriter/var_subst.h @@ -50,6 +50,8 @@ public: */ expr_ref operator()(expr * n, unsigned num_args, expr * const * args); inline expr_ref operator()(expr * n, expr_ref_vector const& args) { return (*this)(n, args.size(), args.c_ptr()); } + inline expr_ref operator()(expr * n, var_ref_vector const& args) { return (*this)(n, args.size(), (expr*const*)args.c_ptr()); } + inline expr_ref operator()(expr * n, app_ref_vector const& args) { return (*this)(n, args.size(), (expr*const*)args.c_ptr()); } inline expr_ref operator()(expr * n, ptr_vector const& args) { return (*this)(n, args.size(), args.c_ptr()); } void reset() { m_reducer.reset(); } }; diff --git a/src/muz/base/dl_util.cpp b/src/muz/base/dl_util.cpp index 83902846c..27e1e840f 100644 --- a/src/muz/base/dl_util.cpp +++ b/src/muz/base/dl_util.cpp @@ -420,32 +420,31 @@ namespace datalog { - void reverse_renaming(ast_manager & m, const expr_ref_vector & src, expr_ref_vector & tgt) { + void reverse_renaming(const var_ref_vector & src, var_ref_vector & tgt) { + ast_manager& m = src.m(); SASSERT(tgt.empty()); unsigned src_sz = src.size(); - unsigned src_ofs = src_sz-1; + unsigned src_ofs = src_sz - 1; unsigned max_var_idx = 0; for(unsigned i=0; iget_idx(); - if(var_idx>max_var_idx) { - max_var_idx=var_idx; + unsigned var_idx = src[i]->get_idx(); + if (var_idx > max_var_idx) { + max_var_idx = var_idx; } } unsigned tgt_sz = max_var_idx+1; - unsigned tgt_ofs = tgt_sz-1; + unsigned tgt_ofs = tgt_sz - 1; tgt.resize(tgt_sz, nullptr); - for(unsigned i=0; iget_idx(); tgt[tgt_ofs-var_idx] = m.mk_var(i, v->get_sort()); } diff --git a/src/muz/base/dl_util.h b/src/muz/base/dl_util.h index b17bc2417..b7690b380 100644 --- a/src/muz/base/dl_util.h +++ b/src/muz/base/dl_util.h @@ -369,7 +369,7 @@ namespace datalog { proof_converter* mk_skip_proof_converter(); - void reverse_renaming(ast_manager & m, const expr_ref_vector & src, expr_ref_vector & tgt); + void reverse_renaming(const var_ref_vector & src, var_ref_vector & tgt); void print_renaming(const expr_ref_vector & cont, std::ostream & out); diff --git a/src/muz/rel/dl_mk_simple_joins.cpp b/src/muz/rel/dl_mk_simple_joins.cpp index 69330fd46..7aaf5a533 100644 --- a/src/muz/rel/dl_mk_simple_joins.cpp +++ b/src/muz/rel/dl_mk_simple_joins.cpp @@ -136,13 +136,14 @@ namespace datalog { public: join_planner(context & ctx, rule_set & rs_aux_copy) - : m_context(ctx), m(ctx.get_manager()), + : m_context(ctx), + m(ctx.get_manager()), rm(ctx.get_rule_manager()), m_var_subst(ctx.get_var_subst()), m_rs_aux_copy(rs_aux_copy), - m_introduced_rules(ctx.get_rule_manager()), + m_introduced_rules(rm), m_modified_rules(false), - m_pinned(ctx.get_manager()) + m_pinned(m) { } @@ -155,7 +156,7 @@ namespace datalog { private: - void get_normalizer(app * t, unsigned & next_var, expr_ref_vector & result) const { + void get_normalizer(app * t, unsigned & next_var, var_ref_vector & result) const { SASSERT(!result.empty()); unsigned res_ofs = result.size()-1; for (expr* arg : *t) { @@ -166,8 +167,8 @@ namespace datalog { } } - expr_ref_vector get_normalizer(app * t1, app * t2) const { - expr_ref_vector result(m); + var_ref_vector get_normalizer(app * t1, app * t2) const { + var_ref_vector result(m); if (t1->get_num_args() == 0 && t2->get_num_args() == 0) { return result; //nothing to normalize } @@ -225,7 +226,7 @@ namespace datalog { //so the order should not matter } - result.resize(max_var_idx + 1, static_cast(nullptr)); + result.resize(max_var_idx + 1, static_cast(nullptr)); unsigned next_var = 0; get_normalizer(t1, next_var, result); get_normalizer(t2, next_var, result); @@ -234,11 +235,9 @@ namespace datalog { app_pair get_key(app * t1, app * t2) { - expr_ref_vector norm_subst = get_normalizer(t1, t2); - expr_ref t1n_ref(m); - expr_ref t2n_ref(m); - t1n_ref = m_var_subst(t1, norm_subst.size(), norm_subst.c_ptr()); - t2n_ref = m_var_subst(t2, norm_subst.size(), norm_subst.c_ptr()); + var_ref_vector norm_subst = get_normalizer(t1, t2); + expr_ref t1n_ref = m_var_subst(t1, norm_subst); + expr_ref t2n_ref = m_var_subst(t2, norm_subst); app * t1n = to_app(t1n_ref); app * t2n = to_app(t2n_ref); if (t1n->get_id() > t2n->get_id()) { @@ -259,7 +258,7 @@ namespace datalog { by the time of a call to this function */ void register_pair(app * t1, app * t2, rule * r, const var_idx_set & non_local_vars) { - SASSERT (t1!=t2); + SASSERT (t1 != t2); cost_map::entry * e = m_costs.insert_if_not_there2(get_key(t1, t2), nullptr); pair_info * & ptr_inf = e->get_data().m_value; if (ptr_inf == nullptr) { @@ -267,11 +266,11 @@ namespace datalog { } pair_info & inf = *ptr_inf; - expr_ref_vector normalizer = get_normalizer(t1, t2); + var_ref_vector normalizer = get_normalizer(t1, t2); unsigned norm_ofs = normalizer.size()-1; var_idx_set normalized_vars; for (auto idx : non_local_vars) { - unsigned norm_var = to_var(normalizer.get(norm_ofs - idx))->get_idx(); + unsigned norm_var = normalizer.get(norm_ofs - idx)->get_idx(); normalized_vars.insert(norm_var); } @@ -313,6 +312,8 @@ namespace datalog { counter.count_vars(t1, -1); //temporarily remove t1 variables from counter for (unsigned j = i+1; j < pos_tail_size; j++) { app * t2 = r->get_tail(j); + if (t1 == t2) + continue; counter.count_vars(t2, -1); //temporarily remove t2 variables from counter var_idx_set t2_vars = rm.collect_vars(t2); t2_vars |= t1_vars; @@ -383,9 +384,9 @@ namespace datalog { app_ref head(m.mk_app(decl, arity, args.c_ptr()), m); - app * tail[] = {t1, t2}; + app * tail[] = { t1, t2 }; - rule * new_rule = m_context.get_rule_manager().mk(head, 2, tail, nullptr); + rule * new_rule = rm.mk(head, 2, tail, nullptr); //TODO: update accounting so that it can handle multiple parents new_rule->set_accounting_parent_object(m_context, one_parent); @@ -516,7 +517,7 @@ namespace datalog { var_idx_set t1_vars = rm.collect_vars(t1); - unsigned i2start = (t1_pred==t2_pred) ? (i1+1) : 0; + unsigned i2start = (t1_pred == t2_pred) ? (i1+1) : 0; for (unsigned i2 = i2start; i2 < len; i2++) { app * rt2 = rule_content[i2]; if (i1 == i2 || rt2->get_decl() != t2_pred) { @@ -526,11 +527,11 @@ namespace datalog { continue; } - expr_ref_vector denormalizer(m); - expr_ref_vector normalizer = get_normalizer(rt1, rt2); - reverse_renaming(m, normalizer, denormalizer); + var_ref_vector denormalizer(m); + var_ref_vector normalizer = get_normalizer(rt1, rt2); + reverse_renaming(normalizer, denormalizer); expr_ref new_transf(m); - new_transf = m_var_subst(t_new, denormalizer.size(), denormalizer.c_ptr()); + new_transf = m_var_subst(t_new, denormalizer); var_idx_set transf_vars = rm.collect_vars(new_transf); TRACE("dl", tout << mk_pp(rt1, m) << " " << mk_pp(rt2, m) << " -> " << new_transf << "\n";); counter.count_vars(rt2, -1); @@ -544,11 +545,11 @@ namespace datalog { // require that tr_vars contains non_local_vars TRACE("dl", tout << "non-local : " << non_local_vars << " tr_vars " << tr_vars << " rt12_vars " << rt2_vars << "\n";); if (!non_local_vars.subset_of(tr_vars)) { - expr_ref_vector normalizer2 = get_normalizer(rt2, rt1); + var_ref_vector normalizer2 = get_normalizer(rt2, rt1); TRACE("dl", tout << normalizer << "\nnorm\n" << normalizer2 << "\n";); denormalizer.reset(); - reverse_renaming(m, normalizer2, denormalizer); - new_transf = m_var_subst(t_new, denormalizer.size(), denormalizer.c_ptr()); + reverse_renaming(normalizer2, denormalizer); + new_transf = m_var_subst(t_new, denormalizer); SASSERT(non_local_vars.subset_of(rm.collect_vars(new_transf))); TRACE("dl", tout << mk_pp(rt2, m) << " " << mk_pp(rt1, m) << " -> " << new_transf << "\n";); } @@ -574,9 +575,13 @@ namespace datalog { replace_edges(r, removed_tails, added_tails, rule_content); } - cost get_domain_size(func_decl * pred, unsigned arg_index) const { - relation_sort sort = pred->get_domain(arg_index); - return static_cast(m_context.get_sort_size_estimate(sort)); + + cost get_domain_size(expr* e) const { + return get_domain_size(m.get_sort(e)); + } + + cost get_domain_size(sort* s) const { + return static_cast(m_context.get_sort_size_estimate(s)); } unsigned get_stratum(func_decl * pred) const { @@ -584,40 +589,34 @@ namespace datalog { } cost estimate_size(app * t) const { - func_decl * pred = t->get_decl(); - unsigned n = pred->get_arity(); rel_context_base* rel = m_context.get_rel_context(); if (!rel) { return cost(1); } relation_manager& rm = rel->get_rmanager(); - if ( (m_context.saturation_was_run() && rm.try_get_relation(pred)) - || rm.is_saturated(pred)) { + func_decl * pred = t->get_decl(); + if ( (m_context.saturation_was_run() && rm.try_get_relation(pred)) || rm.is_saturated(pred)) { SASSERT(rm.try_get_relation(pred)); //if it is saturated, it should exist unsigned rel_size_int = rel->get_relation(pred).get_size_estimate_rows(); if (rel_size_int != 0) { - cost rel_size = static_cast(rel_size_int); - cost curr_size = rel_size; - for (unsigned i = 0; i < n; i++) { - if (!is_var(t->get_arg(i))) { - curr_size /= get_domain_size(pred, i); + cost curr_size = static_cast(rel_size_int); + for (expr* arg : *t) { + if (!is_var(arg)) { + curr_size /= get_domain_size(arg); } } return curr_size; } } cost res = 1; - for (unsigned i = 0; i < n; i++) { - if (is_var(t->get_arg(i))) { - res *= get_domain_size(pred, i); - } + for (expr* arg : *t) { + if (is_var(arg)) + res *= get_domain_size(arg); } return res; } cost compute_cost(app * t1, app * t2, const var_idx_set & non_local_vars) const { - func_decl * t1_pred = t1->get_decl(); - func_decl * t2_pred = t2->get_decl(); cost inters_size = 1; variable_intersection vi(m_context.get_manager()); vi.populate(t1, t2); @@ -626,29 +625,27 @@ namespace datalog { for (unsigned i = 0; i < n; i++) { unsigned arg_index1, arg_index2; vi.get(i, arg_index1, arg_index2); - SASSERT(is_var(t1->get_arg(arg_index1))); - if (non_local_vars.contains(to_var(t1->get_arg(arg_index1))->get_idx())) { - inters_size *= get_domain_size(t1_pred, arg_index1); + expr* arg = t1->get_arg(arg_index1); + SASSERT(is_var(arg)); + if (non_local_vars.contains(to_var(arg)->get_idx())) { + inters_size *= get_domain_size(arg); } - //joined arguments must have the same domain - SASSERT(get_domain_size(t1_pred, arg_index1)==get_domain_size(t2_pred, arg_index2)); + // joined arguments must have the same domain + SASSERT(get_domain_size(arg) == get_domain_size(t2->get_arg(arg_index2))); } // remove contributions from projected columns. - for (unsigned i = 0; i < t1->get_num_args(); ++i) { - if (is_var(t1->get_arg(i)) && - !non_local_vars.contains(to_var(t1->get_arg(i))->get_idx())) { - inters_size *= get_domain_size(t1_pred, i); + for (expr* arg : *t1) { + if (is_var(arg) && !non_local_vars.contains(to_var(arg)->get_idx())) { + inters_size *= get_domain_size(arg); } } - for (unsigned i = 0; i < t2->get_num_args(); ++i) { - if (is_var(t2->get_arg(i)) && - !non_local_vars.contains(to_var(t2->get_arg(i))->get_idx())) { - inters_size *= get_domain_size(t2_pred, i); + for (expr* arg : *t2) { + if (is_var(arg) && !non_local_vars.contains(to_var(arg)->get_idx())) { + inters_size *= get_domain_size(arg); } } - cost res = estimate_size(t1)*estimate_size(t2)/ inters_size; // (inters_size*inters_size); - //cost res = -inters_size; + cost res = (estimate_size(t1) * estimate_size(t2)) / inters_size; TRACE("report_costs", display_predicate(m_context, t1, tout); @@ -699,10 +696,10 @@ namespace datalog { if (!m_modified_rules) { return nullptr; } - rule_set * result = alloc(rule_set, m_context); + scoped_ptr result = alloc(rule_set, m_context); for (auto& kv : m_rules_content) { rule * orig_r = kv.m_key; - ptr_vector content = kv.m_value; + ptr_vector const& content = kv.m_value; SASSERT(content.size() <= 2); if (content.size() == orig_r->get_positive_tail_size()) { //rule did not change @@ -713,25 +710,25 @@ namespace datalog { ptr_vector tail(content); bool_vector negs(tail.size(), false); unsigned or_len = orig_r->get_tail_size(); - for (unsigned i=orig_r->get_positive_tail_size(); i < or_len; i++) { + for (unsigned i = orig_r->get_positive_tail_size(); i < or_len; i++) { tail.push_back(orig_r->get_tail(i)); negs.push_back(orig_r->is_neg_tail(i)); } - rule * new_rule = m_context.get_rule_manager().mk(orig_r->get_head(), tail.size(), tail.c_ptr(), + rule * new_rule = rm.mk(orig_r->get_head(), tail.size(), tail.c_ptr(), negs.c_ptr(), orig_r->name()); new_rule->set_accounting_parent_object(m_context, orig_r); - m_context.get_rule_manager().mk_rule_rewrite_proof(*orig_r, *new_rule); + rm.mk_rule_rewrite_proof(*orig_r, *new_rule); result->add_rule(new_rule); } - while (!m_introduced_rules.empty()) { - result->add_rule(m_introduced_rules.back()); - m_context.get_rule_manager().mk_rule_asserted_proof(*m_introduced_rules.back()); - m_introduced_rules.pop_back(); + for (rule* r : m_introduced_rules) { + result->add_rule(r); + rm.mk_rule_asserted_proof(*r); } + m_introduced_rules.reset(); result->inherit_predicates(source); - return result; + return result.detach(); } }; @@ -741,9 +738,7 @@ namespace datalog { if (!rs_aux_copy.is_closed()) { rs_aux_copy.close(); } - join_planner planner(m_context, rs_aux_copy); - return planner.run(source); }