diff --git a/src/api/api_ast.cpp b/src/api/api_ast.cpp index 293983c9a..c96328bf8 100644 --- a/src/api/api_ast.cpp +++ b/src/api/api_ast.cpp @@ -31,7 +31,7 @@ Revision History: #include"ast_smt2_pp.h" #include"th_rewriter.h" #include"var_subst.h" -#include"expr_substitution.h" +#include"expr_safe_replace.h" #include"pp.h" #include"scoped_ctrl_c.h" #include"cancel_eh.h" @@ -786,17 +786,12 @@ extern "C" { RETURN_Z3(of_expr(0)); } } - - expr_substitution subst(m); + expr_safe_replace subst(m); for (unsigned i = 0; i < num_exprs; i++) { subst.insert(from[i], to[i]); } - th_rewriter m_rw(m); - m_rw.set_substitution(&subst); - expr_ref new_a(m); - proof_ref pr(m); - m_rw(a, new_a, pr); + subst(a, new_a); mk_c(c)->save_ast_trail(new_a); r = new_a.get(); RETURN_Z3(of_expr(r)); diff --git a/src/muz/rel/dl_mk_simple_joins.cpp b/src/muz/rel/dl_mk_simple_joins.cpp index 4b9ce582a..c2214dad9 100644 --- a/src/muz/rel/dl_mk_simple_joins.cpp +++ b/src/muz/rel/dl_mk_simple_joins.cpp @@ -47,9 +47,9 @@ namespace datalog { being notified about it, it will surely see the decrease from length 3 to 2 which the threshold for rule being counted in this counter. */ - unsigned m_consumers; - bool m_stratified; - unsigned m_src_stratum; + unsigned m_consumers; + bool m_stratified; + unsigned m_src_stratum; public: var_idx_set m_all_nonlocal_vars; rule_vector m_rules; @@ -57,16 +57,13 @@ namespace datalog { pair_info() : m_consumers(0), m_stratified(true), m_src_stratum(0) {} bool can_be_joined() const { - return m_consumers>0; + return m_consumers > 0; } cost get_cost() const { - /*if(m_instantiated) { - return std::numeric_limits::min(); - }*/ - SASSERT(m_consumers>0); + SASSERT(m_consumers > 0); cost amortized = m_total_cost/m_consumers; - if(m_stratified) { + if (m_stratified) { return amortized * ( (amortized>0) ? (1/16.0f) : 16.0f); } else { @@ -81,19 +78,20 @@ namespace datalog { by the time of a call to this function */ void add_rule(join_planner & pl, app * t1, app * t2, rule * r, - const var_idx_set & non_local_vars_normalized) { - if(m_rules.empty()) { - m_total_cost = pl.compute_cost(t1, t2); + const var_idx_set & non_local_vars_normalized, + const var_idx_set & non_local_vars) { + if (m_rules.empty()) { + m_total_cost = pl.compute_cost(t1, t2, non_local_vars); m_src_stratum = std::max(pl.get_stratum(t1->get_decl()), pl.get_stratum(t2->get_decl())); } m_rules.push_back(r); - if(pl.m_rules_content.find_core(r)->get_data().m_value.size()>2) { + if (pl.m_rules_content.find(r).size()>2) { m_consumers++; } - if(m_stratified) { + if (m_stratified) { unsigned head_stratum = pl.get_stratum(r->get_decl()); SASSERT(head_stratum>=m_src_stratum); - if(head_stratum==m_src_stratum) { + if (head_stratum==m_src_stratum) { m_stratified = false; } } @@ -105,7 +103,7 @@ namespace datalog { */ bool remove_rule(rule * r, unsigned original_length) { TRUSTME( remove_from_vector(m_rules, r) ); - if(original_length>2) { + if (original_length>2) { SASSERT(m_consumers>0); m_consumers--; } @@ -165,7 +163,7 @@ namespace datalog { SASSERT(is_var(t->get_arg(i))); var * v = to_var(t->get_arg(i)); unsigned var_idx = v->get_idx(); - if(result[res_ofs-var_idx]==0) { + if (result[res_ofs-var_idx]==0) { result[res_ofs-var_idx]=m.mk_var(next_var, v->get_sort()); next_var++; } @@ -174,7 +172,7 @@ namespace datalog { void get_normalizer(app * t1, app * t2, expr_ref_vector & result) const { SASSERT(result.empty()); - if(t1->get_num_args()==0 && t2->get_num_args()==0) { + if (t1->get_num_args()==0 && t2->get_num_args()==0) { return; //nothing to normalize } SASSERT(!t1->is_ground() || !t2->is_ground()); @@ -186,14 +184,14 @@ namespace datalog { var_idx_set::iterator ovend = orig_var_set.end(); for(; ovit!=ovend; ++ovit) { unsigned var_idx = *ovit; - if(var_idx>max_var_idx) { + if (var_idx>max_var_idx) { max_var_idx = var_idx; } } } - if(t1->get_decl()!=t2->get_decl()) { - if(t1->get_decl()->get_id()get_decl()->get_id()) { + if (t1->get_decl()!=t2->get_decl()) { + if (t1->get_decl()->get_id()get_decl()->get_id()) { std::swap(t1, t2); } } @@ -207,9 +205,9 @@ namespace datalog { //so the only literals which appear in pairs are the ones that contain only variables. var * v1 = to_var(t1->get_arg(i)); var * v2 = to_var(t2->get_arg(i)); - if(v1->get_sort()!=v2->get_sort()) { + if (v1->get_sort()!=v2->get_sort()) { //different sorts mean we can distinguish the two terms - if(v1->get_sort()->get_id()get_sort()->get_id()) { + if (v1->get_sort()->get_id()get_sort()->get_id()) { std::swap(t1, t2); } break; @@ -221,9 +219,9 @@ namespace datalog { SASSERT(norm1[v1_idx]==-1); SASSERT(norm2[v2_idx]==-1); - if(norm2[v1_idx]!=norm1[v2_idx]) { + if (norm2[v1_idx]!=norm1[v2_idx]) { //now we can distinguish the two terms - if(norm2[v1_idx]t2n) { + if (t1n>t2n) { std::swap(t1n, t2n); } m_pinned.push_back(t1n); @@ -274,12 +272,10 @@ namespace datalog { by the time of a call to this function */ void register_pair(app * t1, app * t2, rule * r, const var_idx_set & non_local_vars) { - TRACE("dl", tout << mk_pp(t1, m) << " " << mk_pp(t2, m) << "\n"; - r->display(m_context, tout); tout << "\n";); SASSERT(t1!=t2); cost_map::entry * e = m_costs.insert_if_not_there2(get_key(t1, t2), 0); pair_info * & ptr_inf = e->get_data().m_value; - if(ptr_inf==0) { + if (ptr_inf==0) { ptr_inf = alloc(pair_info); } pair_info & inf = *ptr_inf; @@ -288,25 +284,30 @@ namespace datalog { get_normalizer(t1, t2, normalizer); unsigned norm_ofs = normalizer.size()-1; var_idx_set normalized_vars; - var_idx_set::iterator vit = non_local_vars.begin(); + var_idx_set::iterator vit = non_local_vars.begin(); var_idx_set::iterator vend = non_local_vars.end(); for(; vit!=vend; ++vit) { unsigned norm_var = to_var(normalizer.get(norm_ofs-*vit))->get_idx(); normalized_vars.insert(norm_var); } - inf.add_rule(*this, t1, t2, r, normalized_vars); + inf.add_rule(*this, t1, t2, r, normalized_vars, non_local_vars); + TRACE("dl", tout << mk_pp(t1, m) << " " << mk_pp(t2, m) << " "; + vit = non_local_vars.begin(); + for (; vit != vend; ++vit) tout << *vit << " "; + tout << "\n"; + r->display(m_context, tout); + if (inf.can_be_joined()) tout << "cost: " << inf.get_cost() << "\n";); + } pair_info & get_pair(app_pair key) const { - cost_map::entry * e = m_costs.find_core(key); - SASSERT(e); - return *e->get_data().m_value; + return *m_costs.find(key); } void remove_rule_from_pair(app_pair key, rule * r, unsigned original_len) { pair_info * ptr = &get_pair(key); - if(ptr->remove_rule(r, original_len)) { + if (ptr->remove_rule(r, original_len)) { SASSERT(ptr->m_rules.empty()); m_costs.remove(key); dealloc(ptr); @@ -349,7 +350,7 @@ namespace datalog { unsigned n=t->get_num_args(); for(unsigned i=0; iget_arg(i)); - if(v->get_idx()==var_idx) { + if (v->get_idx()==var_idx) { args.push_back(v); domain.push_back(m.get_sort(v)); return true; @@ -375,7 +376,7 @@ namespace datalog { unsigned var_idx=*ovit; bool found=extract_argument_info(var_idx, t1, args, domain); - if(!found) { + if (!found) { found=extract_argument_info(var_idx, t2, args, domain); } SASSERT(found); @@ -389,7 +390,7 @@ namespace datalog { func_decl* parent_head = one_parent->get_decl(); const char * one_parent_name = parent_head->get_name().bare_str(); std::string parent_name; - if(inf.m_rules.size()>1) { + if (inf.m_rules.size()>1) { parent_name = one_parent_name + std::string("_and_") + to_string(inf.m_rules.size()-1); } else { @@ -443,7 +444,7 @@ namespace datalog { } //remove edges between surviving tails and removed tails for(unsigned i=0; i & rule_content = m_rules_content.find_core(r)->get_data().m_value; + ptr_vector & rule_content = m_rules_content.find(r); unsigned len = rule_content.size(); - if(len==1) { + if (len==1) { return; } @@ -515,16 +516,16 @@ namespace datalog { ptr_vector added_tails; for(unsigned i1=0; i1get_decl()!=t1_pred) { + if (rt1->get_decl()!=t1_pred) { continue; } unsigned i2start = (t1_pred==t2_pred) ? (i1+1) : 0; for(unsigned i2=i2start; i2get_decl()!=t2_pred) { + if (i1==i2 || rt2->get_decl()!=t2_pred) { continue; } - if(get_key(rt1, rt2)!=pair_key) { + if (get_key(rt1, rt2)!=pair_key) { continue; } expr_ref_vector normalizer(m); @@ -558,7 +559,7 @@ namespace datalog { relation_sort sort = pred->get_domain(arg_index); return static_cast(m_context.get_sort_size_estimate(sort)); //unsigned sz; - //if(!m_context.get_sort_size(sort, sz)) { + //if (!m_context.get_sort_size(sort, sz)) { // sz=UINT_MAX; //} //return static_cast(sz); @@ -576,15 +577,15 @@ namespace datalog { return cost(1); } relation_manager& rm = rel->get_rmanager(); - if( (m_context.saturation_was_run() && rm.try_get_relation(pred)) + if ( (m_context.saturation_was_run() && rm.try_get_relation(pred)) || rm.is_saturated(pred)) { SASSERT(rm.try_get_relation(pred)); //if it is saturated, it should exist unsigned rel_size_int = rel->get_relation(pred).get_size_estimate_rows(); - if(rel_size_int!=0) { + if (rel_size_int!=0) { cost rel_size = static_cast(rel_size_int); cost curr_size = rel_size; for(unsigned i=0; iget_arg(i))) { + if (!is_var(t->get_arg(i))) { curr_size /= get_domain_size(pred, i); } } @@ -593,40 +594,58 @@ namespace datalog { } cost res = 1; for(unsigned i=0; iget_arg(i))) { + if (is_var(t->get_arg(i))) { res *= get_domain_size(pred, i); } } return res; } - cost compute_cost(app * t1, app * t2) const { + cost compute_cost(app * t1, app * t2, const var_idx_set & non_local_vars) const { func_decl * t1_pred = t1->get_decl(); func_decl * t2_pred = t2->get_decl(); cost inters_size = 1; variable_intersection vi(m_context.get_manager()); vi.populate(t1, t2); unsigned n = vi.size(); + // remove contributions from joined columns. for(unsigned i=0; iget_arg(arg_index1))); + if (non_local_vars.contains(to_var(t1->get_arg(arg_index1))->get_idx())) { + inters_size *= get_domain_size(t1_pred, arg_index1); + } //joined arguments must have the same domain SASSERT(get_domain_size(t1_pred, arg_index1)==get_domain_size(t2_pred, arg_index2)); } - cost res = estimate_size(t1)*estimate_size(t2)/(inters_size*inters_size); + // remove contributions from projected columns. + for (unsigned i = 0; i < t1->get_num_args(); ++i) { + if (is_var(t1->get_arg(i)) && + !non_local_vars.contains(to_var(t1->get_arg(i))->get_idx())) { + inters_size *= get_domain_size(t1_pred, i); + } + } + for (unsigned i = 0; i < t2->get_num_args(); ++i) { + if (is_var(t2->get_arg(i)) && + !non_local_vars.contains(to_var(t2->get_arg(i))->get_idx())) { + inters_size *= get_domain_size(t2_pred, i); + } + } + + cost res = estimate_size(t1)*estimate_size(t2)/ inters_size; // (inters_size*inters_size); //cost res = -inters_size; /*unsigned t1_strat = get_stratum(t1_pred); SASSERT(t1_strat<=m_head_stratum); - if(t1_strat0) { + if (res>0) { res /= 2; } else { @@ -653,17 +672,17 @@ namespace datalog { for(; it!=end; ++it) { app_pair key = it->m_key; pair_info & inf = *it->m_value; - if(!inf.can_be_joined()) { + if (!inf.can_be_joined()) { continue; } cost c = inf.get_cost(); - if(!found || cm_key; ptr_vector content = rcit->m_value; SASSERT(content.size()<=2); - if(content.size()==orig_r->get_positive_tail_size()) { + if (content.size()==orig_r->get_positive_tail_size()) { //rule did not change result->add_rule(orig_r); continue; @@ -728,7 +747,7 @@ namespace datalog { rule_set * mk_simple_joins::operator()(rule_set const & source) { rule_set rs_aux_copy(m_context); rs_aux_copy.replace_rules(source); - if(!rs_aux_copy.is_closed()) { + if (!rs_aux_copy.is_closed()) { rs_aux_copy.close(); } diff --git a/src/muz/transforms/dl_mk_array_blast.cpp b/src/muz/transforms/dl_mk_array_blast.cpp index 776a2da5b..93e31ff23 100644 --- a/src/muz/transforms/dl_mk_array_blast.cpp +++ b/src/muz/transforms/dl_mk_array_blast.cpp @@ -31,7 +31,6 @@ namespace datalog { rm(ctx.get_rule_manager()), m_rewriter(m, m_params), m_simplifier(ctx), - m_sub(m), m_next_var(0) { m_params.set_bool("expand_select_store",true); m_rewriter.updt_params(m_params); @@ -82,7 +81,6 @@ namespace datalog { return false; } if (v) { - m_sub.insert(e, v); m_defs.insert(e, to_var(v)); } else { @@ -92,71 +90,113 @@ namespace datalog { m_next_var = vars.size() + 1; } v = m.mk_var(m_next_var, m.get_sort(e)); - m_sub.insert(e, v); m_defs.insert(e, v); ++m_next_var; } return true; } + + bool mk_array_blast::is_select_eq_var(expr* e, app*& s, var*& v) const { + expr* x, *y; + if (m.is_eq(e, x, y) || m.is_iff(e, x, y)) { + if (a.is_select(y)) { + std::swap(x,y); + } + if (a.is_select(x) && is_var(y)) { + s = to_app(x); + v = to_var(y); + return true; + } + } + return false; + } + bool mk_array_blast::ackermanize(rule const& r, expr_ref& body, expr_ref& head) { - expr_ref_vector conjs(m); + expr_ref_vector conjs(m), trail(m); qe::flatten_and(body, conjs); m_defs.reset(); - m_sub.reset(); m_next_var = 0; ptr_vector todo; - todo.push_back(head); + obj_map cache; + ptr_vector args; + app_ref e1(m); + app* s; + var* v; + for (unsigned i = 0; i < conjs.size(); ++i) { expr* e = conjs[i].get(); - expr* x, *y; - if (m.is_eq(e, x, y) || m.is_iff(e, x, y)) { - if (a.is_select(y)) { - std::swap(x,y); - } - if (a.is_select(x) && is_var(y)) { - if (!insert_def(r, to_app(x), to_var(y))) { - return false; - } - } + if (is_select_eq_var(e, s, v)) { + todo.append(s->get_num_args(), s->get_args()); } - if (a.is_select(e) && !insert_def(r, to_app(e), 0)) { - return false; + else { + todo.push_back(e); } - todo.push_back(e); } - // now make sure to cover all occurrences. - ast_mark mark; while (!todo.empty()) { expr* e = todo.back(); - todo.pop_back(); - if (mark.is_marked(e)) { + if (cache.contains(e)) { + todo.pop_back(); continue; } - mark.mark(e, true); if (is_var(e)) { + cache.insert(e, e); + todo.pop_back(); continue; } if (!is_app(e)) { return false; } app* ap = to_app(e); - if (a.is_select(ap) && !m_defs.contains(ap)) { - if (!insert_def(r, ap, 0)) { - return false; + bool valid = true; + args.reset(); + for (unsigned i = 0; i < ap->get_num_args(); ++i) { + expr* arg; + if (cache.find(ap->get_arg(i), arg)) { + args.push_back(arg); + } + else { + todo.push_back(ap->get_arg(i)); + valid = false; } } - if (a.is_select(e)) { - get_select_args(e, todo); - continue; + if (valid) { + todo.pop_back(); + e1 = m.mk_app(ap->get_decl(), args.size(), args.c_ptr()); + trail.push_back(e1); + if (a.is_select(ap)) { + if (m_defs.find(e1, v)) { + cache.insert(e, v); + } + else if (!insert_def(r, e1, 0)) { + return false; + } + else { + cache.insert(e, m_defs.find(e1)); + } + } + else { + cache.insert(e, e1); + } + } + } + for (unsigned i = 0; i < conjs.size(); ++i) { + expr* e = conjs[i].get(); + if (is_select_eq_var(e, s, v)) { + args.reset(); + for (unsigned j = 0; j < s->get_num_args(); ++j) { + args.push_back(cache.find(s->get_arg(j))); + } + e1 = m.mk_app(s->get_decl(), args.size(), args.c_ptr()); + if (!m_defs.contains(e1) && !insert_def(r, e1, v)) { + return false; + } + conjs[i] = m.mk_eq(v, m_defs.find(e1)); } - for (unsigned i = 0; i < ap->get_num_args(); ++i) { - todo.push_back(ap->get_arg(i)); + else { + conjs[i] = cache.find(e); } } - m_sub(body); - m_sub(head); - conjs.reset(); // perform the Ackermann reduction by creating implications // i1 = i2 => val1 = val2 for each equality pair: @@ -171,6 +211,7 @@ namespace datalog { for (; it2 != end; ++it2) { app* a2 = it2->m_key; var* v2 = it2->m_value; + TRACE("dl", tout << mk_pp(a1, m) << " " << mk_pp(a2, m) << "\n";); if (get_select(a1) != get_select(a2)) { continue; } @@ -184,10 +225,7 @@ namespace datalog { conjs.push_back(m.mk_implies(m.mk_and(eqs.size(), eqs.c_ptr()), m.mk_eq(v1, v2))); } } - if (!conjs.empty()) { - conjs.push_back(body); - body = m.mk_and(conjs.size(), conjs.c_ptr()); - } + body = m.mk_and(conjs.size(), conjs.c_ptr()); m_rewriter(body); return true; } diff --git a/src/muz/transforms/dl_mk_array_blast.h b/src/muz/transforms/dl_mk_array_blast.h index f4b685b7a..c96573848 100644 --- a/src/muz/transforms/dl_mk_array_blast.h +++ b/src/muz/transforms/dl_mk_array_blast.h @@ -44,7 +44,6 @@ namespace datalog { mk_interp_tail_simplifier m_simplifier; defs_t m_defs; - expr_safe_replace m_sub; unsigned m_next_var; bool blast(rule& r, rule_set& new_rules); @@ -59,6 +58,8 @@ namespace datalog { bool insert_def(rule const& r, app* e, var* v); + bool is_select_eq_var(expr* e, app*& s, var*& v) const; + public: /** \brief Create rule transformer that removes array stores and selects by ackermannization. diff --git a/src/test/expr_substitution.cpp b/src/test/expr_substitution.cpp new file mode 100644 index 000000000..f83bde97d --- /dev/null +++ b/src/test/expr_substitution.cpp @@ -0,0 +1,56 @@ +#include "expr_substitution.h" +#include "smt_params.h" +#include "substitution.h" +#include "unifier.h" +#include "bv_decl_plugin.h" +#include "ast_pp.h" +#include "arith_decl_plugin.h" +#include "reg_decl_plugins.h" +#include "th_rewriter.h" + +expr* mk_bv_xor(bv_util& bv, expr* a, expr* b) { + expr* args[2]; + args[0] = a; + args[1] = b; + return bv.mk_bv_xor(2, args); +} + +expr* mk_bv_and(bv_util& bv, expr* a, expr* b) { + expr* args[2]; + args[0] = a; + args[1] = b; + ast_manager& m = bv.get_manager(); + return m.mk_app(bv.get_family_id(), OP_BAND, 2, args); +} + +void tst_expr_substitution() { + memory::initialize(0); + ast_manager m; + reg_decl_plugins(m); + bv_util bv(m); + + expr_ref a(m), b(m), c(m), d(m); + expr_ref x(m); + expr_ref new_a(m); + proof_ref pr(m); + x = m.mk_const(symbol("x"), bv.mk_sort(8)); + a = mk_bv_and(bv, mk_bv_xor(bv, x,bv.mk_numeral(8,8)), mk_bv_xor(bv,x,x)); + b = x; + c = bv.mk_bv_sub(x, bv.mk_numeral(4, 8)); + + expr_substitution subst(m); + th_rewriter rw(m); + + // normalizing c does not help. + rw(c, d, pr); + subst.insert(b, d); + + rw.set_substitution(&subst); + + + enable_trace("th_rewriter_step"); + rw(a, new_a, pr); + + std::cout << mk_pp(new_a, m) << "\n"; + +} diff --git a/src/test/main.cpp b/src/test/main.cpp index 333456369..bc7e04124 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -215,6 +215,7 @@ int main(int argc, char ** argv) { TST(rcf); TST(polynorm); TST(qe_arith); + TST(expr_substitution); } void initialize_mam() {}