3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-07-20 03:12:03 +00:00
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2018-06-27 08:49:52 -07:00
parent 915983821b
commit 5762be2a0f
2 changed files with 180 additions and 174 deletions

View file

@ -1041,6 +1041,10 @@ class smt2_printer {
} }
void process(expr * n, format_ref & r) { void process(expr * n, format_ref & r) {
if (!n) {
r = mk_string(m(), "null");
return;
}
reset_stacks(); reset_stacks();
SASSERT(&(r.get_manager()) == &(fm())); SASSERT(&(r.get_manager()) == &(fm()));
m_soccs(n); m_soccs(n);
@ -1126,6 +1130,10 @@ public:
} }
void operator()(func_decl * f, format_ref & r, char const* cmd) { void operator()(func_decl * f, format_ref & r, char const* cmd) {
if (!f) {
r = mk_string(m(), "null");
return;
}
unsigned arity = f->get_arity(); unsigned arity = f->get_arity();
unsigned len; unsigned len;
format * fname = m_env.pp_fdecl_name(f, len); format * fname = m_env.pp_fdecl_name(f, len);
@ -1189,7 +1197,7 @@ void mk_smt2_format(unsigned sz, expr * const* es, smt2_pp_environment & env, pa
format_ref_vector fmts(fm(m)); format_ref_vector fmts(fm(m));
for (unsigned i = 0; i < sz; ++i) { for (unsigned i = 0; i < sz; ++i) {
format_ref fr(fm(m)); format_ref fr(fm(m));
pr(es[i], num_vars, var_prefix, fr, var_names); pr(es[i], num_vars, var_prefix, fr, var_names);
fmts.push_back(fr); fmts.push_back(fr);
} }
@ -1198,6 +1206,7 @@ void mk_smt2_format(unsigned sz, expr * const* es, smt2_pp_environment & env, pa
std::ostream & ast_smt2_pp(std::ostream & out, expr * n, smt2_pp_environment & env, params_ref const & p, unsigned indent, std::ostream & ast_smt2_pp(std::ostream & out, expr * n, smt2_pp_environment & env, params_ref const & p, unsigned indent,
unsigned num_vars, char const * var_prefix) { unsigned num_vars, char const * var_prefix) {
if (!n) return out << "null";
ast_manager & m = env.get_manager(); ast_manager & m = env.get_manager();
format_ref r(fm(m)); format_ref r(fm(m));
sbuffer<symbol> var_names; sbuffer<symbol> var_names;
@ -1209,6 +1218,7 @@ std::ostream & ast_smt2_pp(std::ostream & out, expr * n, smt2_pp_environment & e
} }
std::ostream & ast_smt2_pp(std::ostream & out, sort * s, smt2_pp_environment & env, params_ref const & p, unsigned indent) { std::ostream & ast_smt2_pp(std::ostream & out, sort * s, smt2_pp_environment & env, params_ref const & p, unsigned indent) {
if (s == nullptr) return out << "null";
ast_manager & m = env.get_manager(); ast_manager & m = env.get_manager();
format_ref r(fm(m)); format_ref r(fm(m));
sbuffer<symbol> var_names; sbuffer<symbol> var_names;
@ -1220,6 +1230,7 @@ std::ostream & ast_smt2_pp(std::ostream & out, sort * s, smt2_pp_environment & e
} }
std::ostream & ast_smt2_pp(std::ostream & out, func_decl * f, smt2_pp_environment & env, params_ref const & p, unsigned indent, char const* cmd) { std::ostream & ast_smt2_pp(std::ostream & out, func_decl * f, smt2_pp_environment & env, params_ref const & p, unsigned indent, char const* cmd) {
if (!f) return out << "null";
ast_manager & m = env.get_manager(); ast_manager & m = env.get_manager();
format_ref r(fm(m)); format_ref r(fm(m));
sbuffer<symbol> var_names; sbuffer<symbol> var_names;
@ -1231,6 +1242,7 @@ std::ostream & ast_smt2_pp(std::ostream & out, func_decl * f, smt2_pp_environmen
} }
std::ostream & ast_smt2_pp(std::ostream & out, func_decl * f, expr* e, smt2_pp_environment & env, params_ref const & p, unsigned indent, char const* cmd) { std::ostream & ast_smt2_pp(std::ostream & out, func_decl * f, expr* e, smt2_pp_environment & env, params_ref const & p, unsigned indent, char const* cmd) {
if (!f) return out << "null";
ast_manager & m = env.get_manager(); ast_manager & m = env.get_manager();
format_ref r(fm(m)); format_ref r(fm(m));
sbuffer<symbol> var_names; sbuffer<symbol> var_names;

View file

@ -64,7 +64,7 @@ namespace datalog {
SASSERT(m_consumers > 0); SASSERT(m_consumers > 0);
cost amortized = m_total_cost/m_consumers; cost amortized = m_total_cost/m_consumers;
if (m_stratified) { if (m_stratified) {
return amortized * ( (amortized>0) ? (1/16.0f) : 16.0f); return amortized * ( (amortized > 0) ? (1/16.0f) : 16.0f);
} }
else { else {
return amortized; return amortized;
@ -85,26 +85,26 @@ namespace datalog {
m_src_stratum = std::max(pl.get_stratum(t1->get_decl()), pl.get_stratum(t2->get_decl())); m_src_stratum = std::max(pl.get_stratum(t1->get_decl()), pl.get_stratum(t2->get_decl()));
} }
m_rules.push_back(r); m_rules.push_back(r);
if (pl.m_rules_content.find(r).size()>2) { if (pl.m_rules_content.find(r).size() > 2) {
m_consumers++; m_consumers++;
} }
if (m_stratified) { if (m_stratified) {
unsigned head_stratum = pl.get_stratum(r->get_decl()); unsigned head_stratum = pl.get_stratum(r->get_decl());
SASSERT(head_stratum>=m_src_stratum); SASSERT(head_stratum >= m_src_stratum);
if (head_stratum==m_src_stratum) { m_stratified = (head_stratum > m_src_stratum);
m_stratified = false;
}
} }
idx_set_union(m_all_nonlocal_vars, non_local_vars_normalized); idx_set_union(m_all_nonlocal_vars, non_local_vars_normalized);
TRACE("dl", tout << "all-nonlocal: " << m_all_nonlocal_vars << "\n";);
} }
/** /**
\brief Remove rule from the pair record. Return true if no rules remain \brief Remove rule from the pair record. Return true if no rules remain
in the pair, and so it should be removed. in the pair, and so it should be removed.
*/ */
bool remove_rule(rule * r, unsigned original_length) { bool remove_rule(rule * r, unsigned original_length) {
VERIFY( remove_from_vector(m_rules, r) ); VERIFY( remove_from_vector(m_rules, r) );
if (original_length>2) { if (original_length > 2) {
SASSERT(m_consumers>0); SASSERT(m_consumers > 0);
m_consumers--; m_consumers--;
} }
SASSERT(!m_rules.empty() || m_consumers==0); SASSERT(!m_rules.empty() || m_consumers==0);
@ -146,70 +146,60 @@ namespace datalog {
{ {
} }
~join_planner() ~join_planner() {
{ for (auto & kv : m_costs) {
cost_map::iterator it = m_costs.begin(); dealloc(kv.m_value);
cost_map::iterator end = m_costs.end();
for (; it != end; ++it) {
dealloc(it->m_value);
} }
m_costs.reset(); m_costs.reset();
} }
private: private:
void get_normalizer(app * t, unsigned & next_var, expr_ref_vector & result) const { void get_normalizer(app * t, unsigned & next_var, expr_ref_vector & result) const {
SASSERT(result.size()>0); SASSERT(!result.empty());
unsigned res_ofs = result.size()-1; unsigned res_ofs = result.size()-1;
unsigned n=t->get_num_args(); for (expr* arg : *t) {
for(unsigned i=0; i<n; i++) { unsigned var_idx = to_var(arg)->get_idx();
SASSERT(is_var(t->get_arg(i))); if (!result.get(res_ofs - var_idx)) {
var * v = to_var(t->get_arg(i)); result[res_ofs - var_idx] = m.mk_var(next_var++, m.get_sort(arg));
unsigned var_idx = v->get_idx();
if (result[res_ofs-var_idx]==nullptr) {
result[res_ofs-var_idx]=m.mk_var(next_var, v->get_sort());
next_var++;
} }
} }
} }
void get_normalizer(app * t1, app * t2, expr_ref_vector & result) const { expr_ref_vector get_normalizer(app * t1, app * t2) const {
SASSERT(result.empty()); expr_ref_vector result(m);
if (t1->get_num_args()==0 && t2->get_num_args()==0) { if (t1->get_num_args() == 0 && t2->get_num_args() == 0) {
return; //nothing to normalize return result; //nothing to normalize
} }
SASSERT(!t1->is_ground() || !t2->is_ground()); SASSERT(!t1->is_ground() || !t2->is_ground());
unsigned max_var_idx = 0; unsigned max_var_idx = 0;
{
var_idx_set& orig_var_set = rm.collect_vars(t1, t2); var_idx_set& orig_var_set = rm.collect_vars(t1, t2);
var_idx_set::iterator ovit = orig_var_set.begin(); for (unsigned var_idx : orig_var_set) {
var_idx_set::iterator ovend = orig_var_set.end(); if (var_idx>max_var_idx) {
for(; ovit!=ovend; ++ovit) { max_var_idx = var_idx;
unsigned var_idx = *ovit;
if (var_idx>max_var_idx) {
max_var_idx = var_idx;
}
} }
} }
if (t1->get_decl()!=t2->get_decl()) { if (t1->get_decl() != t2->get_decl()) {
if (t1->get_decl()->get_id()<t2->get_decl()->get_id()) { if (t1->get_decl()->get_id() < t2->get_decl()->get_id()) {
std::swap(t1, t2); std::swap(t1, t2);
} }
} }
else { else {
int_vector norm1(max_var_idx+1, -1); int_vector norm1(max_var_idx + 1, -1);
int_vector norm2(max_var_idx+1, -1); int_vector norm2(max_var_idx + 1, -1);
unsigned n=t1->get_num_args(); unsigned n = t1->get_num_args();
SASSERT(n==t2->get_num_args()); SASSERT(n == t2->get_num_args());
for(unsigned i=0; i<n; i++) { for (unsigned i = 0; i < n; ++i) {
//We assume that the mk_simple_joins transformer is applied after mk_filter_rules, //We assume that the mk_simple_joins transformer is applied after mk_filter_rules,
//so the only literals which appear in pairs are the ones that contain only variables. //so the only literals which appear in pairs are the ones that contain only variables.
var * v1 = to_var(t1->get_arg(i)); var * v1 = to_var(t1->get_arg(i));
var * v2 = to_var(t2->get_arg(i)); var * v2 = to_var(t2->get_arg(i));
if (v1->get_sort()!=v2->get_sort()) { if (v1->get_sort() != v2->get_sort()) {
//different sorts mean we can distinguish the two terms //different sorts mean we can distinguish the two terms
if (v1->get_sort()->get_id()<v2->get_sort()->get_id()) { if (v1->get_sort()->get_id() < v2->get_sort()->get_id()) {
std::swap(t1, t2); std::swap(t1, t2);
} }
break; break;
@ -218,32 +208,33 @@ namespace datalog {
unsigned v2_idx = v2->get_idx(); unsigned v2_idx = v2->get_idx();
//since the rules already went through the mk_filter_rules transformer, //since the rules already went through the mk_filter_rules transformer,
//variables must be linear //variables must be linear
SASSERT(norm1[v1_idx]==-1); SASSERT(norm1[v1_idx] == -1);
SASSERT(norm2[v2_idx]==-1); SASSERT(norm2[v2_idx] == -1);
if (norm2[v1_idx]!=norm1[v2_idx]) { if (norm2[v1_idx] != norm1[v2_idx]) {
//now we can distinguish the two terms //now we can distinguish the two terms
if (norm2[v1_idx]<norm1[v2_idx]) { if (norm2[v1_idx] < norm1[v2_idx]) {
std::swap(t1, t2); std::swap(t1, t2);
} }
break; break;
} }
norm1[v1_idx]=i; norm1[v1_idx] = i;
norm2[v2_idx]=i; norm2[v2_idx] = i;
} }
//if we did not exit the loop prematurely, the two terms are indistinguishable, //if we did not exit the loop prematurely, the two terms are indistinguishable,
//so the order should not matter //so the order should not matter
} }
result.resize(max_var_idx+1, static_cast<expr *>(nullptr)); result.resize(max_var_idx + 1, static_cast<expr *>(nullptr));
unsigned next_var = 0; unsigned next_var = 0;
get_normalizer(t1, next_var, result); get_normalizer(t1, next_var, result);
get_normalizer(t2, next_var, result); get_normalizer(t2, next_var, result);
return result;
} }
app_pair get_key(app * t1, app * t2) { app_pair get_key(app * t1, app * t2) {
expr_ref_vector norm_subst(m); expr_ref_vector norm_subst = get_normalizer(t1, t2);
get_normalizer(t1, t2, norm_subst);
expr_ref t1n_ref(m); expr_ref t1n_ref(m);
expr_ref t2n_ref(m); expr_ref t2n_ref(m);
m_var_subst(t1, norm_subst.size(), norm_subst.c_ptr(), t1n_ref); m_var_subst(t1, norm_subst.size(), norm_subst.c_ptr(), t1n_ref);
@ -256,6 +247,7 @@ namespace datalog {
m_pinned.push_back(t1n); m_pinned.push_back(t1n);
m_pinned.push_back(t2n); m_pinned.push_back(t2n);
TRACE("dl", tout << mk_pp(t1, m) << " " << mk_pp(t2, m) << " |-> " << t1n_ref << " " << t2n_ref << "\n";);
return app_pair(t1n, t2n); return app_pair(t1n, t2n);
} }
@ -267,30 +259,25 @@ namespace datalog {
by the time of a call to this function by the time of a call to this function
*/ */
void register_pair(app * t1, app * t2, rule * r, const var_idx_set & non_local_vars) { void register_pair(app * t1, app * t2, rule * r, const var_idx_set & non_local_vars) {
SASSERT(t1!=t2); SASSERT (t1!=t2);
cost_map::entry * e = m_costs.insert_if_not_there2(get_key(t1, t2), nullptr); cost_map::entry * e = m_costs.insert_if_not_there2(get_key(t1, t2), nullptr);
pair_info * & ptr_inf = e->get_data().m_value; pair_info * & ptr_inf = e->get_data().m_value;
if (ptr_inf==nullptr) { if (ptr_inf == nullptr) {
ptr_inf = alloc(pair_info); ptr_inf = alloc(pair_info);
} }
pair_info & inf = *ptr_inf; pair_info & inf = *ptr_inf;
expr_ref_vector normalizer(m); expr_ref_vector normalizer = get_normalizer(t1, t2);
get_normalizer(t1, t2, normalizer);
unsigned norm_ofs = normalizer.size()-1; unsigned norm_ofs = normalizer.size()-1;
var_idx_set normalized_vars; var_idx_set normalized_vars;
var_idx_set::iterator vit = non_local_vars.begin(); for (auto idx : non_local_vars) {
var_idx_set::iterator vend = non_local_vars.end(); unsigned norm_var = to_var(normalizer.get(norm_ofs - idx))->get_idx();
for(; vit!=vend; ++vit) {
unsigned norm_var = to_var(normalizer.get(norm_ofs-*vit))->get_idx();
normalized_vars.insert(norm_var); normalized_vars.insert(norm_var);
} }
inf.add_rule(*this, t1, t2, r, normalized_vars, non_local_vars); inf.add_rule(*this, t1, t2, r, normalized_vars, non_local_vars);
TRACE("dl", tout << mk_pp(t1, m) << " " << mk_pp(t2, m) << " "; TRACE("dl", tout << mk_pp(t1, m) << " " << mk_pp(t2, m) << " ";
vit = non_local_vars.begin(); tout << non_local_vars << "\n";
for (; vit != vend; ++vit) tout << *vit << " ";
tout << "\n";
r->display(m_context, tout); r->display(m_context, tout);
if (inf.can_be_joined()) tout << "cost: " << inf.get_cost() << "\n";); if (inf.can_be_joined()) tout << "cost: " << inf.get_cost() << "\n";);
@ -298,8 +285,7 @@ namespace datalog {
void remove_rule_from_pair(app_pair key, rule * r, unsigned original_len) { void remove_rule_from_pair(app_pair key, rule * r, unsigned original_len) {
pair_info * ptr = nullptr; pair_info * ptr = nullptr;
if (m_costs.find(key, ptr) && ptr && if (m_costs.find(key, ptr) && ptr && ptr->remove_rule(r, original_len)) {
ptr->remove_rule(r, original_len)) {
SASSERT(ptr->m_rules.empty()); SASSERT(ptr->m_rules.empty());
m_costs.remove(key); m_costs.remove(key);
dealloc(ptr); dealloc(ptr);
@ -309,28 +295,32 @@ namespace datalog {
void register_rule(rule * r) { void register_rule(rule * r) {
rule_counter counter; rule_counter counter;
counter.count_rule_vars(r, 1); counter.count_rule_vars(r, 1);
TRACE("dl", tout << "counter: "; for (auto const& kv: counter) tout << kv.m_key << ": " << kv.m_value << " "; tout << "\n";);
ptr_vector<app> & rule_content = ptr_vector<app> & rule_content =
m_rules_content.insert_if_not_there2(r, ptr_vector<app>())->get_data().m_value; m_rules_content.insert_if_not_there2(r, ptr_vector<app>())->get_data().m_value;
SASSERT(rule_content.empty()); SASSERT(rule_content.empty());
unsigned pos_tail_size=r->get_positive_tail_size(); TRACE("dl", r->display(m_context, tout << "register "););
for(unsigned i=0; i<pos_tail_size; i++) {
unsigned pos_tail_size = r->get_positive_tail_size();
for (unsigned i = 0; i < pos_tail_size; i++) {
rule_content.push_back(r->get_tail(i)); rule_content.push_back(r->get_tail(i));
} }
for(unsigned i=0; i+1 < pos_tail_size; i++) { for (unsigned i=0; i+1 < pos_tail_size; i++) {
app * t1 = r->get_tail(i); app * t1 = r->get_tail(i);
var_idx_set t1_vars = rm.collect_vars(t1); var_idx_set t1_vars = rm.collect_vars(t1);
counter.count_vars(t1, -1); //temporarily remove t1 variables from counter counter.count_vars(t1, -1); //temporarily remove t1 variables from counter
for(unsigned j=i+1; j<pos_tail_size; j++) { for (unsigned j = i+1; j < pos_tail_size; j++) {
app * t2 = r->get_tail(j); app * t2 = r->get_tail(j);
counter.count_vars(t2, -1); //temporarily remove t2 variables from counter counter.count_vars(t2, -1); //temporarily remove t2 variables from counter
var_idx_set scope_vars = rm.collect_vars(t2); var_idx_set t2_vars = rm.collect_vars(t2);
scope_vars |= t1_vars; t2_vars |= t1_vars;
var_idx_set non_local_vars; var_idx_set non_local_vars;
counter.collect_positive(non_local_vars); counter.collect_positive(non_local_vars);
counter.count_vars(t2, 1); //restore t2 variables in counter counter.count_vars(t2, 1); //restore t2 variables in counter
set_intersection(non_local_vars, scope_vars); set_intersection(non_local_vars, t2_vars);
TRACE("dl", tout << "non-local vars: " << non_local_vars << "\n";);
register_pair(t1, t2, r, non_local_vars); register_pair(t1, t2, r, non_local_vars);
} }
counter.count_vars(t1, 1); //restore t1 variables in counter counter.count_vars(t1, 1); //restore t1 variables in counter
@ -338,11 +328,10 @@ namespace datalog {
} }
bool extract_argument_info(unsigned var_idx, app * t, expr_ref_vector & args, bool extract_argument_info(unsigned var_idx, app * t, expr_ref_vector & args,
ptr_vector<sort> & domain) { ptr_vector<sort> & domain) {
unsigned n=t->get_num_args(); for (expr* arg : *t) {
for(unsigned i=0; i<n; i++) { var * v = to_var(arg);
var * v=to_var(t->get_arg(i)); if (v->get_idx() == var_idx) {
if (v->get_idx()==var_idx) {
args.push_back(v); args.push_back(v);
domain.push_back(m.get_sort(v)); domain.push_back(m.get_sort(v));
return true; return true;
@ -354,33 +343,27 @@ namespace datalog {
void join_pair(app_pair pair_key) { void join_pair(app_pair pair_key) {
app * t1 = pair_key.first; app * t1 = pair_key.first;
app * t2 = pair_key.second; app * t2 = pair_key.second;
pair_info* infp = nullptr; pair_info & inf = *m_costs[pair_key];
if (!m_costs.find(pair_key, infp) || !infp) {
UNREACHABLE();
return;
}
pair_info & inf = *infp;
SASSERT(!inf.m_rules.empty()); SASSERT(!inf.m_rules.empty());
var_idx_set & output_vars = inf.m_all_nonlocal_vars; var_idx_set const & output_vars = inf.m_all_nonlocal_vars;
expr_ref_vector args(m); expr_ref_vector args(m);
ptr_vector<sort> domain; ptr_vector<sort> domain;
unsigned arity = output_vars.num_elems(); unsigned arity = output_vars.num_elems();
idx_set::iterator ovit=output_vars.begin(); for (unsigned var_idx : output_vars) {
idx_set::iterator ovend=output_vars.end(); bool found = extract_argument_info(var_idx, t1, args, domain);
//TODO: improve quadratic complexity
for(;ovit!=ovend;++ovit) {
unsigned var_idx=*ovit;
bool found=extract_argument_info(var_idx, t1, args, domain);
if (!found) { if (!found) {
found=extract_argument_info(var_idx, t2, args, domain); found = extract_argument_info(var_idx, t2, args, domain);
} }
SASSERT(found); SASSERT(found);
} }
TRACE("dl",
tout << mk_pp(t1, m) << " " << mk_pp(t2, m) << " arity: " << arity << "\n";
tout << "output: " << output_vars << "\n";
tout << "args: " << args << "\n";);
SASSERT(args.size()==arity); SASSERT(args.size() == arity);
SASSERT(domain.size()==arity); SASSERT(domain.size() == arity);
rule * one_parent = inf.m_rules.back(); rule * one_parent = inf.m_rules.back();
@ -415,8 +398,7 @@ namespace datalog {
rule_hashtable processed_rules; rule_hashtable processed_rules;
rule_vector rules(inf.m_rules); rule_vector rules(inf.m_rules);
for (unsigned i = 0; i < rules.size(); ++i) { for (rule * r : rules) {
rule* r = rules[i];
if (!processed_rules.contains(r)) { if (!processed_rules.contains(r)) {
apply_binary_rule(r, pair_key, head); apply_binary_rule(r, pair_key, head);
processed_rules.insert(r); processed_rules.insert(r);
@ -425,33 +407,34 @@ namespace datalog {
// SASSERT(!m_costs.contains(pair_key)); // SASSERT(!m_costs.contains(pair_key));
} }
void replace_edges(rule * r, const ptr_vector<app> & removed_tails, void replace_edges(rule * r, const app_ref_vector & removed_tails,
const ptr_vector<app> & added_tails0, const ptr_vector<app> & rule_content) { const app_ref_vector & added_tails0, const ptr_vector<app> & rule_content) {
SASSERT(removed_tails.size()>=added_tails0.size()); SASSERT(removed_tails.size()>=added_tails0.size());
unsigned len = rule_content.size(); unsigned len = rule_content.size();
unsigned original_len = len+removed_tails.size()-added_tails0.size(); unsigned original_len = len+removed_tails.size()-added_tails0.size();
ptr_vector<app> added_tails(added_tails0); //we need a copy since we'll be modifying it app_ref_vector added_tails(added_tails0); //we need a copy since we'll be modifying it
TRACE("dl", tout << added_tails << "\n";);
unsigned rt_sz = removed_tails.size(); unsigned rt_sz = removed_tails.size();
//remove edges between removed tails //remove edges between removed tails
for(unsigned i=0; i<rt_sz; i++) { for (unsigned i = 0; i < rt_sz; i++) {
for(unsigned j=i+1; j<rt_sz; j++) { for (unsigned j = i+1; j < rt_sz; j++) {
app_pair pair_key = get_key(removed_tails[i], removed_tails[j]); app_pair pair_key = get_key(removed_tails[i], removed_tails[j]);
remove_rule_from_pair(pair_key, r, original_len); remove_rule_from_pair(pair_key, r, original_len);
} }
} }
//remove edges between surviving tails and removed tails //remove edges between surviving tails and removed tails
for(unsigned i=0; i<len; i++) { for (unsigned i = 0; i < len; i++) {
if (added_tails.contains(rule_content[i])) { if (added_tails.contains(rule_content[i])) {
continue; continue;
} }
for(unsigned ri=0; ri<rt_sz; ri++) { for (unsigned ri = 0; ri < rt_sz; ri++) {
app_pair pair_key = get_key(rule_content[i], removed_tails[ri]); app_pair pair_key = get_key(rule_content[i], removed_tails[ri]);
remove_rule_from_pair(pair_key, r, original_len); remove_rule_from_pair(pair_key, r, original_len);
} }
} }
if (len==1) { if (len == 1) {
return; return;
} }
@ -463,21 +446,23 @@ namespace datalog {
unsigned tail_size=r->get_tail_size(); unsigned tail_size=r->get_tail_size();
unsigned pos_tail_size=r->get_positive_tail_size(); unsigned pos_tail_size=r->get_positive_tail_size();
for(unsigned i=pos_tail_size; i<tail_size; i++) { for (unsigned i=pos_tail_size; i<tail_size; i++) {
counter.count_vars(r->get_tail(i), 1); counter.count_vars(r->get_tail(i), 1);
} }
for(unsigned i=0; i<len; i++) { for (unsigned i=0; i<len; i++) {
counter.count_vars(rule_content[i], 1); counter.count_vars(rule_content[i], 1);
} }
//add edges that contain added tails //add edges that contain added tails
while(!added_tails.empty()) { while (!added_tails.empty()) {
app * a_tail = added_tails.back(); //added tail app * a_tail = added_tails.back(); //added tail
TRACE("dl", tout << "replace edges " << mk_pp(a_tail, m) << "\n";);
var_idx_set a_tail_vars = rm.collect_vars(a_tail); var_idx_set a_tail_vars = rm.collect_vars(a_tail);
counter.count_vars(a_tail, -1); //temporarily remove a_tail variables from counter counter.count_vars(a_tail, -1); //temporarily remove a_tail variables from counter
for(unsigned i=0; i<len; i++) { for (unsigned i = 0; i < len; i++) {
app * o_tail = rule_content[i]; //other tail app * o_tail = rule_content[i]; //other tail
if (added_tails.contains(o_tail)) { if (added_tails.contains(o_tail)) {
//this avoids adding edges between new tails twice //this avoids adding edges between new tails twice
@ -504,63 +489,95 @@ namespace datalog {
app * t2 = pair_key.second; app * t2 = pair_key.second;
ptr_vector<app> & rule_content = m_rules_content.find(r); ptr_vector<app> & rule_content = m_rules_content.find(r);
unsigned len = rule_content.size(); unsigned len = rule_content.size();
if (len==1) { if (len == 1) {
return; return;
} }
pair_info & inf = *m_costs[pair_key];
TRACE("dl",
r->display(m_context, tout << "rule ");
tout << "pair: " << mk_pp(t1, m) << " " << mk_pp(t2, m) << "\n";
tout << mk_pp(t_new, m) << "\n";
tout << "all-non-local: " << inf.m_all_nonlocal_vars << "\n";
for (app* a : rule_content) tout << mk_pp(a, m) << " "; tout << "\n";);
rule_counter counter;
counter.count_rule_vars(r, 1);
func_decl * t1_pred = t1->get_decl(); func_decl * t1_pred = t1->get_decl();
func_decl * t2_pred = t2->get_decl(); func_decl * t2_pred = t2->get_decl();
ptr_vector<app> removed_tails; app_ref_vector removed_tails(m);
ptr_vector<app> added_tails; app_ref_vector added_tails(m);
for(unsigned i1=0; i1<len; i1++) { for (unsigned i1 = 0; i1 < len; i1++) {
app * rt1 = rule_content[i1]; app * rt1 = rule_content[i1];
if (rt1->get_decl()!=t1_pred) { if (rt1->get_decl() != t1_pred) {
continue; continue;
} }
var_idx_set rt1_vars = rm.collect_vars(rt1);
counter.count_vars(rt1, -1);
var_idx_set t1_vars = rm.collect_vars(t1);
unsigned i2start = (t1_pred==t2_pred) ? (i1+1) : 0; unsigned i2start = (t1_pred==t2_pred) ? (i1+1) : 0;
for(unsigned i2=i2start; i2<len; i2++) { for (unsigned i2 = i2start; i2 < len; i2++) {
app * rt2 = rule_content[i2]; app * rt2 = rule_content[i2];
if (i1==i2 || rt2->get_decl()!=t2_pred) { if (i1 == i2 || rt2->get_decl() != t2_pred) {
continue; continue;
} }
if (get_key(rt1, rt2)!=pair_key) { if (get_key(rt1, rt2) != pair_key) {
continue; continue;
} }
expr_ref_vector normalizer(m);
get_normalizer(rt1, rt2, normalizer);
expr_ref_vector denormalizer(m); expr_ref_vector denormalizer(m);
expr_ref_vector normalizer = get_normalizer(rt1, rt2);
reverse_renaming(m, normalizer, denormalizer); reverse_renaming(m, normalizer, denormalizer);
expr_ref new_transf(m); expr_ref new_transf(m);
m_var_subst(t_new, denormalizer.size(), denormalizer.c_ptr(), new_transf); m_var_subst(t_new, denormalizer.size(), denormalizer.c_ptr(), new_transf);
var_idx_set transf_vars = rm.collect_vars(new_transf);
TRACE("dl", tout << mk_pp(rt1, m) << " " << mk_pp(rt2, m) << " -> " << new_transf << "\n";);
counter.count_vars(rt2, -1);
var_idx_set rt2_vars = rm.collect_vars(rt2);
var_idx_set tr_vars = rm.collect_vars(new_transf);
rt2_vars |= rt1_vars;
var_idx_set non_local_vars;
counter.collect_positive(non_local_vars);
set_intersection(non_local_vars, rt2_vars);
counter.count_vars(rt2, +1);
// require that tr_vars contains non_local_vars
TRACE("dl", tout << "non-local : " << non_local_vars << " tr_vars " << tr_vars << " rt12_vars " << rt2_vars << "\n";);
if (!non_local_vars.subset_of(tr_vars)) {
expr_ref_vector normalizer2 = get_normalizer(rt2, rt1);
TRACE("dl", tout << normalizer << "\nnorm\n" << normalizer2 << "\n";);
denormalizer.reset();
reverse_renaming(m, normalizer2, denormalizer);
m_var_subst(t_new, denormalizer.size(), denormalizer.c_ptr(), new_transf);
SASSERT(non_local_vars.subset_of(rm.collect_vars(new_transf)));
TRACE("dl", tout << mk_pp(rt2, m) << " " << mk_pp(rt1, m) << " -> " << new_transf << "\n";);
}
app * new_lit = to_app(new_transf); app * new_lit = to_app(new_transf);
m_pinned.push_back(new_lit); m_pinned.push_back(new_lit);
rule_content[i1]=new_lit; rule_content[i1] = new_lit;
rule_content[i2]=rule_content.back(); rule_content[i2] = rule_content.back();
rule_content.pop_back(); rule_content.pop_back();
len--; //here the bound of both loops changes!!! len--; //here the bound of both loops changes!!!
removed_tails.push_back(rt1); removed_tails.push_back(rt1);
removed_tails.push_back(rt2); removed_tails.push_back(rt2);
added_tails.push_back(new_lit); added_tails.push_back(new_lit);
//this exits the inner loop, the outer one continues in case there will // this exits the inner loop, the outer one continues in case there will
//be other matches // be other matches
break; break;
} }
counter.count_vars(rt1, 1);
} }
SASSERT(!removed_tails.empty()); SASSERT(!removed_tails.empty());
SASSERT(!added_tails.empty()); SASSERT(!added_tails.empty());
m_modified_rules = true; m_modified_rules = true;
TRACE("dl", tout << "replace rule content\n";);
replace_edges(r, removed_tails, added_tails, rule_content); replace_edges(r, removed_tails, added_tails, rule_content);
} }
cost get_domain_size(func_decl * pred, unsigned arg_index) const { cost get_domain_size(func_decl * pred, unsigned arg_index) const {
relation_sort sort = pred->get_domain(arg_index); relation_sort sort = pred->get_domain(arg_index);
return static_cast<cost>(m_context.get_sort_size_estimate(sort)); return static_cast<cost>(m_context.get_sort_size_estimate(sort));
//unsigned sz;
//if (!m_context.get_sort_size(sort, sz)) {
// sz=UINT_MAX;
//}
//return static_cast<cost>(sz);
} }
unsigned get_stratum(func_decl * pred) const { unsigned get_stratum(func_decl * pred) const {
@ -569,7 +586,7 @@ namespace datalog {
cost estimate_size(app * t) const { cost estimate_size(app * t) const {
func_decl * pred = t->get_decl(); func_decl * pred = t->get_decl();
unsigned n=pred->get_arity(); unsigned n = pred->get_arity();
rel_context_base* rel = m_context.get_rel_context(); rel_context_base* rel = m_context.get_rel_context();
if (!rel) { if (!rel) {
return cost(1); return cost(1);
@ -582,7 +599,7 @@ namespace datalog {
if (rel_size_int!=0) { if (rel_size_int!=0) {
cost rel_size = static_cast<cost>(rel_size_int); cost rel_size = static_cast<cost>(rel_size_int);
cost curr_size = rel_size; cost curr_size = rel_size;
for(unsigned i=0; i<n; i++) { for (unsigned i=0; i<n; i++) {
if (!is_var(t->get_arg(i))) { if (!is_var(t->get_arg(i))) {
curr_size /= get_domain_size(pred, i); curr_size /= get_domain_size(pred, i);
} }
@ -591,7 +608,7 @@ namespace datalog {
} }
} }
cost res = 1; cost res = 1;
for(unsigned i=0; i<n; i++) { for (unsigned i=0; i<n; i++) {
if (is_var(t->get_arg(i))) { if (is_var(t->get_arg(i))) {
res *= get_domain_size(pred, i); res *= get_domain_size(pred, i);
} }
@ -607,7 +624,7 @@ namespace datalog {
vi.populate(t1, t2); vi.populate(t1, t2);
unsigned n = vi.size(); unsigned n = vi.size();
// remove contributions from joined columns. // remove contributions from joined columns.
for(unsigned i=0; i<n; i++) { for (unsigned i=0; i<n; i++) {
unsigned arg_index1, arg_index2; unsigned arg_index1, arg_index2;
vi.get(i, arg_index1, arg_index2); vi.get(i, arg_index1, arg_index2);
SASSERT(is_var(t1->get_arg(arg_index1))); SASSERT(is_var(t1->get_arg(arg_index1)));
@ -634,24 +651,6 @@ namespace datalog {
cost res = estimate_size(t1)*estimate_size(t2)/ inters_size; // (inters_size*inters_size); cost res = estimate_size(t1)*estimate_size(t2)/ inters_size; // (inters_size*inters_size);
//cost res = -inters_size; //cost res = -inters_size;
/*unsigned t1_strat = get_stratum(t1_pred);
SASSERT(t1_strat<=m_head_stratum);
if (t1_strat<m_head_stratum) {
unsigned t2_strat = get_stratum(t2_pred);
SASSERT(t2_strat<=m_head_stratum);
if (t2_strat<m_head_stratum) {
//the rule of this predicates would depend on predicates
//in lower stratum than the head, which is a good thing, since
//then the rule code will not need to appear in a loop
if (res>0) {
res /= 2;
}
else {
res *= 2;
}
}
}*/
TRACE("report_costs", TRACE("report_costs",
display_predicate(m_context, t1, tout); display_predicate(m_context, t1, tout);
display_predicate(m_context, t2, tout); display_predicate(m_context, t2, tout);
@ -665,16 +664,14 @@ namespace datalog {
bool found = false; bool found = false;
cost best_cost; cost best_cost;
cost_map::iterator it = m_costs.begin(); for (auto const& kv : m_costs) {
cost_map::iterator end = m_costs.end(); app_pair key = kv.m_key;
for(; it!=end; ++it) { pair_info & inf = *kv.m_value;
app_pair key = it->m_key;
pair_info & inf = *it->m_value;
if (!inf.can_be_joined()) { if (!inf.can_be_joined()) {
continue; continue;
} }
cost c = inf.get_cost(); cost c = inf.get_cost();
if (!found || c<best_cost) { if (!found || c < best_cost) {
found = true; found = true;
best_cost = c; best_cost = c;
best = key; best = key;
@ -683,7 +680,7 @@ namespace datalog {
if (!found) { if (!found) {
return false; return false;
} }
p=best; p = best;
return true; return true;
} }
@ -691,27 +688,24 @@ namespace datalog {
public: public:
rule_set * run(rule_set const & source) { rule_set * run(rule_set const & source) {
unsigned num_rules = source.get_num_rules(); for (rule * r : source) {
for (unsigned i = 0; i < num_rules; i++) { register_rule(r);
register_rule(source.get_rule(i));
} }
app_pair selected; app_pair selected;
while(pick_best_pair(selected)) { while (pick_best_pair(selected)) {
join_pair(selected); join_pair(selected);
} }
if (!m_modified_rules) { if (!m_modified_rules) {
return nullptr; return nullptr;
} }
rule_set * result = alloc(rule_set, m_context); rule_set * result = alloc(rule_set, m_context);
rule_pred_map::iterator rcit = m_rules_content.begin(); for (auto& kv : m_rules_content) {
rule_pred_map::iterator rcend = m_rules_content.end(); rule * orig_r = kv.m_key;
for(; rcit!=rcend; ++rcit) { ptr_vector<app> content = kv.m_value;
rule * orig_r = rcit->m_key; SASSERT(content.size() <= 2);
ptr_vector<app> content = rcit->m_value; if (content.size() == orig_r->get_positive_tail_size()) {
SASSERT(content.size()<=2);
if (content.size()==orig_r->get_positive_tail_size()) {
//rule did not change //rule did not change
result->add_rule(orig_r); result->add_rule(orig_r);
continue; continue;
@ -720,7 +714,7 @@ namespace datalog {
ptr_vector<app> tail(content); ptr_vector<app> tail(content);
svector<bool> negs(tail.size(), false); svector<bool> negs(tail.size(), false);
unsigned or_len = orig_r->get_tail_size(); unsigned or_len = orig_r->get_tail_size();
for(unsigned i=orig_r->get_positive_tail_size(); i<or_len; i++) { for (unsigned i=orig_r->get_positive_tail_size(); i<or_len; i++) {
tail.push_back(orig_r->get_tail(i)); tail.push_back(orig_r->get_tail(i));
negs.push_back(orig_r->is_neg_tail(i)); negs.push_back(orig_r->is_neg_tail(i));
} }