diff --git a/src/muz/base/dl_context.cpp b/src/muz/base/dl_context.cpp index f3088fa04..529028606 100644 --- a/src/muz/base/dl_context.cpp +++ b/src/muz/base/dl_context.cpp @@ -233,6 +233,7 @@ namespace datalog { m_engine_type(LAST_ENGINE), m_cancel(false) { re.set_context(this); + m_generate_proof_trace = m_params->generate_proof_trace(); } context::~context() { @@ -271,7 +272,7 @@ namespace datalog { } - bool context::generate_proof_trace() const { return m_params->generate_proof_trace(); } + bool context::generate_proof_trace() const { return m_generate_proof_trace; } bool context::output_profile() const { return m_params->datalog_output_profile(); } bool context::output_tuples() const { return m_params->datalog_print_tuples(); } bool context::use_map_names() const { return m_params->datalog_use_map_names(); } @@ -489,10 +490,8 @@ namespace datalog { ++m_rule_fmls_head; } rule_set::iterator it = m_rule_set.begin(), end = m_rule_set.end(); - rule_ref r(m_rule_manager); for (; it != end; ++it) { - r = *it; - check_rule(r); + check_rule(*(*it)); } } @@ -579,35 +578,35 @@ namespace datalog { m_engine->add_cover(level, pred, property); } - void context::check_uninterpreted_free(rule_ref& r) { + void context::check_uninterpreted_free(rule& r) { func_decl* f = 0; - if (r->has_uninterpreted_non_predicates(m, f)) { + if (r.has_uninterpreted_non_predicates(m, f)) { std::stringstream stm; stm << "Uninterpreted '" << f->get_name() << "' in "; - r->display(*this, stm); + r.display(*this, stm); throw default_exception(stm.str()); } } - void context::check_quantifier_free(rule_ref& r) { - if (r->has_quantifiers()) { + void context::check_quantifier_free(rule& r) { + if (r.has_quantifiers()) { std::stringstream stm; stm << "cannot process quantifiers in rule "; - r->display(*this, stm); + r.display(*this, stm); throw default_exception(stm.str()); } } - void context::check_existential_tail(rule_ref& r) { - unsigned ut_size = r->get_uninterpreted_tail_size(); - unsigned t_size = r->get_tail_size(); + void context::check_existential_tail(rule& r) { + unsigned ut_size = r.get_uninterpreted_tail_size(); + unsigned t_size = r.get_tail_size(); - TRACE("dl", r->display_smt2(get_manager(), tout); tout << "\n";); + TRACE("dl", r.display_smt2(get_manager(), tout); tout << "\n";); for (unsigned i = ut_size; i < t_size; ++i) { - app* t = r->get_tail(i); + app* t = r.get_tail(i); TRACE("dl", tout << "checking: " << mk_ismt2_pp(t, get_manager()) << "\n";); if (m_check_pred(t)) { std::ostringstream out; @@ -617,14 +616,14 @@ namespace datalog { } } - void context::check_positive_predicates(rule_ref& r) { + void context::check_positive_predicates(rule& r) { ast_mark visited; ptr_vector todo, tocheck; - unsigned ut_size = r->get_uninterpreted_tail_size(); - unsigned t_size = r->get_tail_size(); + unsigned ut_size = r.get_uninterpreted_tail_size(); + unsigned t_size = r.get_tail_size(); for (unsigned i = 0; i < ut_size; ++i) { - if (r->is_neg_tail(i)) { - tocheck.push_back(r->get_tail(i)); + if (r.is_neg_tail(i)) { + tocheck.push_back(r.get_tail(i)); } } ast_manager& m = get_manager(); @@ -632,7 +631,7 @@ namespace datalog { check_pred check_pred(contains_p, get_manager()); for (unsigned i = ut_size; i < t_size; ++i) { - todo.push_back(r->get_tail(i)); + todo.push_back(r.get_tail(i)); } while (!todo.empty()) { expr* e = todo.back(), *e1, *e2; @@ -670,14 +669,14 @@ namespace datalog { if (check_pred(e)) { std::ostringstream out; out << "recursive predicate " << mk_ismt2_pp(e, get_manager()) << " occurs nested in body"; - r->display(*this, out << "\n"); + r.display(*this, out << "\n"); throw default_exception(out.str()); } } } - void context::check_rule(rule_ref& r) { + void context::check_rule(rule& r) { switch(get_engine()) { case DATALOG_ENGINE: check_quantifier_free(r); @@ -719,8 +718,8 @@ namespace datalog { UNREACHABLE(); break; } - if (generate_proof_trace() && !r->get_proof()) { - m_rule_manager.mk_rule_asserted_proof(*r.get()); + if (generate_proof_trace() && !r.get_proof()) { + m_rule_manager.mk_rule_asserted_proof(r); } } @@ -847,6 +846,7 @@ namespace datalog { void context::updt_params(params_ref const& p) { m_params_ref.copy(p); if (m_engine.get()) m_engine->updt_params(); + m_generate_proof_trace = m_params->generate_proof_trace(); } expr_ref context::get_background_assertion() { @@ -908,6 +908,9 @@ namespace datalog { }; void context::configure_engine() { + if (m_engine_type != LAST_ENGINE) { + return; + } symbol e = m_params->engine(); if (e == symbol("datalog")) { @@ -969,8 +972,7 @@ namespace datalog { rule_set::iterator it = m_rule_set.begin(), end = m_rule_set.end(); rule_ref r(m_rule_manager); for (; it != end; ++it) { - r = *it; - check_rule(r); + check_rule(*(*it)); } } #endif diff --git a/src/muz/base/dl_context.h b/src/muz/base/dl_context.h index b7bcbb5fe..66addc37c 100644 --- a/src/muz/base/dl_context.h +++ b/src/muz/base/dl_context.h @@ -171,6 +171,7 @@ namespace datalog { smt_params & m_fparams; params_ref m_params_ref; fixedpoint_params* m_params; + bool m_generate_proof_trace; dl_decl_util m_decl_util; th_rewriter m_rewriter; var_subst m_var_subst; @@ -416,7 +417,7 @@ namespace datalog { /** \brief Check if rule is well-formed according to engine. */ - void check_rule(rule_ref& r); + void check_rule(rule& r); /** \brief Return true if facts to \c pred can be added using the \c add_table_fact() function. @@ -562,10 +563,10 @@ namespace datalog { void ensure_engine(); - void check_quantifier_free(rule_ref& r); - void check_uninterpreted_free(rule_ref& r); - void check_existential_tail(rule_ref& r); - void check_positive_predicates(rule_ref& r); + void check_quantifier_free(rule& r); + void check_uninterpreted_free(rule& r); + void check_existential_tail(rule& r); + void check_positive_predicates(rule& r); // auxilary functions for SMT2 pretty-printer. void declare_vars(expr_ref_vector& rules, mk_fresh_name& mk_fresh, std::ostream& out); diff --git a/src/muz/ddnf/ddnf.cpp b/src/muz/ddnf/ddnf.cpp index 9a47976b2..6d5dd1c80 100644 --- a/src/muz/ddnf/ddnf.cpp +++ b/src/muz/ddnf/ddnf.cpp @@ -513,13 +513,13 @@ namespace datalog { lbool query(expr* query) { m_ctx.ensure_opened(); - rm.mk_query(query, m_ctx.get_rules()); - + rule_set& old_rules = m_ctx.get_rules(); + rm.mk_query(query, old_rules); rule_set new_rules(m_ctx); - if (!pre_process_rules()) { + if (!pre_process_rules(old_rules)) { return l_undef; } - if (!compile_rules1(new_rules)) { + if (!compile_rules1(old_rules, new_rules)) { return l_undef; } IF_VERBOSE(2, m_ddnfs.display(verbose_stream());); @@ -564,12 +564,11 @@ namespace datalog { return pr; } - bool pre_process_rules() { + bool pre_process_rules(rule_set const& rules) { m_visited1.reset(); m_todo.reset(); m_cache.reset(); m_expr2tbv.reset(); - rule_set const& rules = m_ctx.get_rules(); datalog::rule_set::iterator it = rules.begin(); datalog::rule_set::iterator end = rules.end(); for (; it != end; ++it) { @@ -700,20 +699,19 @@ namespace datalog { return m_inner_ctx.rel_query(heads.size(), heads.c_ptr()); } - bool compile_rules1(rule_set& new_rules) { - rule_set const& rules = m_ctx.get_rules(); + bool compile_rules1(rule_set const& rules, rule_set& new_rules) { datalog::rule_set::iterator it = rules.begin(); datalog::rule_set::iterator end = rules.end(); unsigned idx = 0; for (; it != end; ++idx, ++it) { - if (!compile_rule1(**it, new_rules)) { + if (!compile_rule1(**it, rules, new_rules)) { return false; } } return true; } - bool compile_rule1(rule& r, rule_set& new_rules) { + bool compile_rule1(rule& r, rule_set const& old_rules, rule_set& new_rules) { app_ref head(m), pred(m); app_ref_vector body(m); expr_ref tmp(m); @@ -728,10 +726,10 @@ namespace datalog { compile_expr(r.get_tail(i), tmp); body.push_back(to_app(tmp)); } - rule* r_new = rm.mk(head, body.size(), body.c_ptr(), 0, r.name(), true); + rule* r_new = rm.mk(head, body.size(), body.c_ptr(), 0, r.name(), false); new_rules.add_rule(r_new); IF_VERBOSE(2, r_new->display(m_ctx, verbose_stream());); - if (m_ctx.get_rules().is_output_predicate(r.get_decl())) { + if (old_rules.is_output_predicate(r.get_decl())) { new_rules.set_output_predicate(r_new->get_decl()); } return true; @@ -767,7 +765,8 @@ namespace datalog { result = to_var(r); } else { - result = m.mk_var(v->get_id(), compile_sort(v->get_sort())); + unsigned idx = v->get_idx(); + result = m.mk_var(idx, compile_sort(v->get_sort())); insert_cache(v, result); } } diff --git a/src/muz/rel/dl_vector_relation.h b/src/muz/rel/dl_vector_relation.h index 114f4ca43..6c55e7b6d 100644 --- a/src/muz/rel/dl_vector_relation.h +++ b/src/muz/rel/dl_vector_relation.h @@ -42,8 +42,6 @@ namespace datalog { union_find_default_ctx m_ctx; union_find<>* m_eqs; - friend class vector_relation_plugin; - public: vector_relation(relation_plugin& p, relation_signature const& s, bool is_empty, T const& t = T()): relation_base(p, s), @@ -107,9 +105,10 @@ namespace datalog { display_index(i, (*m_elems)[i], out); } else { - out << i << " = " << find(i) << "\n"; + out << i << " = " << find(i) << " "; } } + out << "\n"; } diff --git a/src/muz/rel/product_set.cpp b/src/muz/rel/product_set.cpp new file mode 100644 index 000000000..2812c56e2 --- /dev/null +++ b/src/muz/rel/product_set.cpp @@ -0,0 +1,510 @@ +/*++ +Copyright (c) 2014 Microsoft Corporation + +Module Name: + + product_set.cpp + +Abstract: + + Product set. + +Author: + + Nikolaj Bjorner (nbjorner) 2014-08-23 + +Revision History: + +--*/ + +#include "product_set.h" +#include "bv_decl_plugin.h" +#include "dl_relation_manager.h" +#include "bool_rewriter.h" + +namespace datalog { + + product_set::product_set( + product_set_plugin& p, relation_signature const& s, + bool is_empty, T const& t): + vector_relation(p, s, is_empty, t), m_refs(0) { + for (unsigned i = 0; i < s.size(); ++i) { + (*this)[i] = bit_vector(p.set_size(s[i])); + } + } + + + unsigned product_set::get_hash() const { + unsigned hash = 0; + for (unsigned i = 0; i < get_signature().size(); ++i) { + hash ^= (*this)[i].get_hash(); + } + return hash; + } + + bool product_set::operator==(product_set const& p) const { + for (unsigned i = 0; i < get_signature().size(); ++i) { + if ((*this)[i] != p[i]) return false; + } + return true; + } + + bool product_set::contains(product_set const& p) const { + for (unsigned i = 0; i < get_signature().size(); ++i) { + if ((*this)[i].contains(p[i])) return false; + } + return true; + } + + void product_set::add_fact(const relation_fact & f) { + UNREACHABLE(); + } + bool product_set::contains_fact(const relation_fact & f) const { + return false; + } + relation_base * product_set::clone() const { + UNREACHABLE(); + return 0; + } + relation_base * product_set::complement(func_decl*) const { + UNREACHABLE(); + return 0; + } + void product_set::to_formula(expr_ref& fml) const { + ast_manager& m = fml.get_manager(); + bv_util bv(m); + expr_ref_vector conjs(m), disjs(m); + relation_signature const& sig = get_signature(); + if (m_empty) { + fml = m.mk_false(); + return; + } + for (unsigned i = 0; i < sig.size(); ++i) { + sort* ty = sig[i]; + expr_ref var(m.mk_var(i, ty), m); + unsigned j = find(i); + if (i != j) { + conjs.push_back(m.mk_eq(var, m.mk_var(j, sig[j]))); + continue; + } + T const& t = (*this)[i]; + disjs.reset(); + for (j = 0; j < t.size(); ++j) { + if (t.get(j)) { + disjs.push_back(m.mk_eq(var, bv.mk_numeral(rational(j), ty))); + } + } + if (disjs.empty()) { + UNREACHABLE(); + fml = m.mk_false(); + return; + } + if (disjs.size() == 1) { + conjs.push_back(disjs[0].get()); + } + else { + conjs.push_back(m.mk_or(disjs.size(), disjs.c_ptr())); + } + } + bool_rewriter br(m); + br.mk_and(conjs.size(), conjs.c_ptr(), fml); + } + void product_set::display_index(unsigned i, const T& t, std::ostream& out) const { + out << i << ":"; + t.display(out); + } + bool product_set::mk_intersect(unsigned idx, T const& t) { + if (!m_empty) { + (*this)[idx] &= t; + m_empty = is_empty(idx, (*this)[idx]); + } + return !m_empty; + } + + // product_set_relation + + + product_set_relation::product_set_relation(product_set_plugin& p, relation_signature const& s): + relation_base(p, s) { + } + + product_set_relation::~product_set_relation() { + product_sets::iterator it = m_elems.begin(), end = m_elems.end(); + for (; it != end; ++it) { + (*it)->dec_ref(); + } + } + + void product_set_relation::add_fact(const relation_fact & f) { + ast_manager& m = get_plugin().get_ast_manager(); + bv_util bv(m); + rational v; + unsigned bv_size; + product_set* s = alloc(product_set, get_plugin(), get_signature(), false); + for (unsigned i = 0; i < f.size(); ++i) { + VERIFY(bv.is_numeral(f[i], v, bv_size)); + SASSERT(v.is_unsigned()); + (*s)[i] = bit_vector(get_plugin().set_size(m.get_sort(f[i]))); + (*s)[i].set(v.get_unsigned(), true); + } + s->display(std::cout << "fact"); + if (m_elems.contains(s)) { + dealloc(s); + } + else { + s->inc_ref(); + m_elems.insert(s); + } + + } + bool product_set_relation::contains_fact(const relation_fact & f) const { + std::cout << "contains fact\n"; + NOT_IMPLEMENTED_YET(); + return false; + } + product_set_relation * product_set_relation::clone() const { + product_set_relation* r = alloc(product_set_relation, get_plugin(), get_signature()); + product_sets::iterator it = m_elems.begin(), end = m_elems.end(); + for (; it != end; ++it) { + // TBD: have to copy because other operations are destructive. + (*it)->inc_ref(); + r->m_elems.insert(*it); + } + return r; + } + product_set_relation * product_set_relation::complement(func_decl*) const { + std::cout << "complement\n"; + NOT_IMPLEMENTED_YET(); + return 0; + } + void product_set_relation::to_formula(expr_ref& fml) const { + product_sets::iterator it = m_elems.begin(), end = m_elems.end(); + ast_manager& m = get_plugin().get_manager().get_context().get_manager(); + expr_ref_vector disjs(m); + for (; it != end; ++it) { + (*it)->to_formula(fml); + disjs.push_back(fml); + } + fml = m.mk_or(disjs.size(), disjs.c_ptr()); + } + product_set_plugin& product_set_relation::get_plugin() const { + return static_cast(relation_base::get_plugin()); + } + + void product_set_relation::display(std::ostream& out) const { + product_sets::iterator it = m_elems.begin(), end = m_elems.end(); + for (; it != end; ++it) { + (*it)->display(out); + } + } + + // product_set_plugin + + product_set_plugin::product_set_plugin(relation_manager& rm): + relation_plugin(product_set_plugin::get_name(), rm) { + } + + bool product_set_plugin::can_handle_signature(const relation_signature & sig) { + bv_util bv(get_manager().get_context().get_manager()); + for (unsigned i = 0; i < sig.size(); ++i) { + if (!bv.is_bv_sort(sig[i])) return false; + } + return true; + } + + product_set_relation& product_set_plugin::get(relation_base& r) { + return dynamic_cast(r); + } + product_set_relation* product_set_plugin::get(relation_base* r) { + return r?dynamic_cast(r):0; + } + product_set_relation const & product_set_plugin::get(relation_base const& r) { + return dynamic_cast(r); + } + relation_base * product_set_plugin::mk_empty(const relation_signature & s) { + return alloc(product_set_relation, *this, s); + } + relation_base * product_set_plugin::mk_full(func_decl* p, const relation_signature & sig) { + product_set_relation* result = alloc(product_set_relation, *this, sig); + product_set* s = alloc(product_set, *this, sig, false); + s->inc_ref(); + for (unsigned i = 0; i < sig.size(); ++i) { + bit_vector& t = (*s)[i]; + t = bit_vector(set_size(sig[i])); + for (unsigned j = 0; j < t.size(); ++j) { + t.set(j, true); + } + } + result->m_elems.insert(s); + return result; + } + product_set* product_set_plugin::insert(product_set* s, product_set_relation* r) { + if (s->empty()) { + s->reset(); + } + else if (r->m_elems.contains(s)) { + s->reset(); + } + else { + s->inc_ref(); + r->m_elems.insert(s); + s = alloc(product_set, *this, r->get_signature(), false); + } + return s; + } + + unsigned product_set_plugin::set_size(sort* ty) { + bv_util bv(get_ast_manager()); + unsigned bv_size = bv.get_bv_size(ty); + SASSERT(bv_size <= 16); + if (bv_size > 16) { + throw default_exception("bit-vector based sets are not suitable for this domain size"); + } + return 1 << bv_size; + } + + class product_set_plugin::join_fn : public convenient_relation_join_fn { + public: + join_fn(const relation_signature & o1_sig, const relation_signature & o2_sig, unsigned col_cnt, + const unsigned * cols1, const unsigned * cols2) + : convenient_relation_join_fn(o1_sig, o2_sig, col_cnt, cols1, cols2){ + } + + virtual relation_base * operator()(const relation_base & _r1, const relation_base & _r2) { + product_set_relation const& r1 = get(_r1); + product_set_relation const& r2 = get(_r2); + product_set_plugin& p = r1.get_plugin(); + relation_signature const& sig = get_result_signature(); + product_set_relation * result = alloc(product_set_relation, p, sig); + product_set* s = alloc(product_set, p, sig, false); + product_sets::iterator it1 = r1.m_elems.begin(), end1 = r1.m_elems.end(); + for (; it1 != end1; ++it1) { + product_sets::iterator it2 = r2.m_elems.begin(), end2 = r2.m_elems.end(); + for (; it2 != end2; ++it2) { + s->mk_join(*(*it1), *(*it2), m_cols1.size(), m_cols1.c_ptr(), m_cols2.c_ptr()); + s = p.insert(s, result); + } + } + dealloc(s); + return result; + } + }; + relation_join_fn * product_set_plugin::mk_join_fn( + const relation_base & r1, const relation_base & r2, + unsigned col_cnt, const unsigned * cols1, const unsigned * cols2) { + if (!check_kind(r1) || !check_kind(r2)) { + return 0; + } + return alloc(join_fn, r1.get_signature(), r2.get_signature(), col_cnt, cols1, cols2); + } + + class product_set_plugin::project_fn : public convenient_relation_project_fn { + public: + project_fn(const relation_signature & orig_sig, unsigned removed_col_cnt, + const unsigned * removed_cols) + : convenient_relation_project_fn(orig_sig, removed_col_cnt, removed_cols) { + } + + virtual relation_base * operator()(const relation_base & _r) { + product_set_relation const& r = get(_r); + product_set_plugin& p = r.get_plugin(); + relation_signature const& sig = get_result_signature(); + product_set_relation* result = alloc(product_set_relation, p, sig); + product_set* s = alloc(product_set, p, sig, false); + product_sets::iterator it = r.m_elems.begin(), end = r.m_elems.end(); + for (; it != end; ++it) { + s->mk_project(*(*it), m_removed_cols.size(), m_removed_cols.c_ptr()); + s = p.insert(s, result); + } + dealloc(s); + return result; + } + }; + relation_transformer_fn * product_set_plugin::mk_project_fn( + const relation_base & r, unsigned col_cnt, + const unsigned * removed_cols) { + if (check_kind(r)) { + return alloc(project_fn, r.get_signature(), col_cnt, removed_cols); + } + else { + return 0; + } + } + class product_set_plugin::rename_fn : public convenient_relation_rename_fn { + public: + rename_fn(const relation_signature & orig_sig, unsigned cycle_len, const unsigned * cycle) + : convenient_relation_rename_fn(orig_sig, cycle_len, cycle) { + } + + virtual relation_base * operator()(const relation_base & _r) { + product_set_relation const& r = get(_r); + product_set_plugin& p = r.get_plugin(); + relation_signature const& sig = get_result_signature(); + product_set_relation* result = alloc(product_set_relation, p, sig); + product_set* s = alloc(product_set, p, sig, false); + product_sets::iterator it = r.m_elems.begin(), end = r.m_elems.end(); + for (; it != end; ++it) { + s->mk_rename(*(*it), m_cycle.size(), m_cycle.c_ptr()); + s = p.insert(s, result); + } + dealloc(s); + return result; + } + }; + + relation_transformer_fn * product_set_plugin::mk_rename_fn(const relation_base & r, + unsigned cycle_len, const unsigned * permutation_cycle) { + if (check_kind(r)) { + return alloc(rename_fn, r.get_signature(), cycle_len, permutation_cycle); + } + else { + return 0; + } + } + + class product_set_plugin::union_fn : public relation_union_fn { + public: + union_fn() {} + + virtual void operator()(relation_base & _r, const relation_base & _src, relation_base * _delta) { + + TRACE("dl", _r.display(tout << "dst:\n"); _src.display(tout << "src:\n");); + + product_set_relation& r = get(_r); + product_set_relation const& src = get(_src); + product_set_relation* d = get(_delta); + product_sets::iterator it = src.m_elems.begin(), end = src.m_elems.end(); + for (; it != end; ++it) { + product_set* ps = *it; + if (!r.m_elems.contains(ps)) { + ps->inc_ref(); + r.m_elems.insert(ps); + if (d) { + ps->inc_ref(); + d->m_elems.insert(ps); + } + } + } + } + }; + relation_union_fn * product_set_plugin::mk_union_fn( + const relation_base & tgt, const relation_base & src, + const relation_base * delta) { + if (!check_kind(tgt) || !check_kind(src) || (delta && !check_kind(*delta))) { + return 0; + } + return alloc(union_fn); + } + relation_union_fn * product_set_plugin::mk_widen_fn( + const relation_base & tgt, const relation_base & src, + const relation_base * delta) { + return mk_union_fn(tgt, src, delta); + } + + + class product_set_plugin::filter_identical_fn : public relation_mutator_fn { + unsigned_vector m_identical_cols; + public: + filter_identical_fn(unsigned col_cnt, const unsigned * identical_cols) + : m_identical_cols(col_cnt, identical_cols) {} + + virtual void operator()(relation_base & _r) { + product_set_relation & r = get(_r); + product_set_plugin& p = r.get_plugin(); + ptr_vector elems; + product_sets::iterator it = r.m_elems.begin(), end = r.m_elems.end(); + for (; it != end; ++it) { + elems.push_back(*it); + } + r.m_elems.reset(); + for (unsigned i = 0; i < elems.size(); ++i) { + product_set* s = elems[i]; + if (equate(*s)) { + r.m_elems.insert(s); + } + else { + s->dec_ref(); + } + } + } + private: + bool equate(product_set& dst) { + for (unsigned i = 1; !dst.empty() && i < m_identical_cols.size(); ++i) { + unsigned c1 = m_identical_cols[0]; + unsigned c2 = m_identical_cols[i]; + dst.equate(c1, c2); + } + return !dst.empty(); + } + }; + relation_mutator_fn * product_set_plugin::mk_filter_identical_fn( + const relation_base & t, unsigned col_cnt, const unsigned * identical_cols) { + return check_kind(t)?alloc(filter_identical_fn, col_cnt, identical_cols):0; + } + + class product_set_plugin::filter_equal_fn : public relation_mutator_fn { + unsigned m_col; + bit_vector m_value; + public: + filter_equal_fn(product_set_plugin& p, const relation_element & value, unsigned col, bool is_eq) + : m_col(col) { + ast_manager& m = p.get_ast_manager(); + // m.get_context().get_manager() + bv_util bv(m); + rational v; + unsigned bv_size; + unsigned sz = p.set_size(m.get_sort(value)); + VERIFY(bv.is_numeral(value, v, bv_size)); + SASSERT(v.is_unsigned()); + unsigned w = v.get_unsigned(); + SASSERT(w < sz); + m_value = bit_vector(sz); + if (is_eq) { + m_value.set(w, true); + } + else { + for (unsigned i = 0; i < sz; ++i) { + m_value.set(i, i != w); + } + } + } + + virtual void operator()(relation_base & _r) { + product_set_relation & r = get(_r); + product_set_plugin & p = r.get_plugin(); + + ptr_vector elems; + product_sets::iterator it = r.m_elems.begin(), end = r.m_elems.end(); + for (; it != end; ++it) { + elems.push_back(*it); + } + r.m_elems.reset(); + for (unsigned i = 0; i < elems.size(); ++i) { + product_set* s = elems[i]; + + if (s->mk_intersect(m_col, m_value)) { + r.m_elems.insert(s); + } + else { + s->dec_ref(); + } + } + } + }; + + relation_mutator_fn * product_set_plugin::mk_filter_equal_fn(const relation_base & r, + const relation_element & value, unsigned col) { + return check_kind(r)?alloc(filter_equal_fn, *this, value, col, true):0; + } + + relation_mutator_fn * product_set_plugin::mk_filter_interpreted_fn( + const relation_base & t, app * condition) { + ast_manager& m =get_manager().get_context().get_manager(); + std::cout << "filter interpreted '" << mk_pp(condition, m) << "'\n"; + NOT_IMPLEMENTED_YET(); + return 0; + } + + +}; + diff --git a/src/muz/rel/product_set.h b/src/muz/rel/product_set.h new file mode 100644 index 000000000..670ec1cbf --- /dev/null +++ b/src/muz/rel/product_set.h @@ -0,0 +1,187 @@ +/*++ +Copyright (c) 2014 Microsoft Corporation + +Module Name: + + product_set.h + +Abstract: + + Product set relation. + A product set is a tuple of sets. + The meaning of a product set is the set of + elements in the cross-product. + A product set relation is a set of product sets, + and the meaning of this relation is the union of + all elements from the products. + It is to be used when computing over product sets is + (much) cheaper than over the space of tuples. + +Author: + + Nikolaj Bjorner (nbjorner) 2014-08-23 + +Revision History: + +--*/ +#ifndef _DL_PRODUCT_SET__H_ +#define _DL_PRODUCT_SET__H_ + +#include "util.h" +#include "bit_vector.h" +#include "dl_base.h" +#include "dl_vector_relation.h" + +namespace datalog { + + class product_set_plugin; + + class product_set : public vector_relation { + typedef bit_vector T; + unsigned m_refs; + public: + product_set(product_set_plugin& p, relation_signature const& s, bool is_empty, T const& t = T()); + + virtual ~product_set() {} + unsigned get_hash() const; + bool operator==(product_set const& p) const; + bool contains(product_set const& p) const; + + void inc_ref() { ++m_refs; } + void dec_ref() { --m_refs; if (0 == m_refs) dealloc(this); } + unsigned ref_count() const { return m_refs; } + + struct eq { + bool operator()(product_set const* s1, product_set const* s2) const { + return *s1 == *s2; + } + }; + + struct hash { + unsigned operator()(product_set const* s) const { + return s->get_hash(); + } + }; + + virtual void add_fact(const relation_fact & f); + virtual bool contains_fact(const relation_fact & f) const; + virtual relation_base * clone() const; + virtual relation_base * complement(func_decl*) const; + virtual void to_formula(expr_ref& fml) const; + + bool mk_intersect(unsigned idx, T const& t); + + private: + virtual void display_index(unsigned i, const T&, std::ostream& out) const; + virtual T mk_intersect(T const& t1, T const& t2, bool& _is_empty) const { + T result(t1); + result &= t2; + _is_empty = is_empty(0, result); + return result; + } + + virtual T mk_widen(T const& t1, T const& t2) const { + UNREACHABLE(); + return t1; + } + + virtual T mk_unite(T const& t1, T const& t2) const { + UNREACHABLE(); + return t1; + } + + virtual bool is_subset_of(T const& t1, T const& t2) const { + return t2.contains(t1); + } + + virtual bool is_full(T const& t) const { + for (unsigned j = 0; j < t.size(); ++j) { + if (!t.get(j)) return false; + } + return true; + } + + virtual bool is_empty(unsigned i, T const& t) const { + for (unsigned j = 0; j < t.size(); ++j) { + if (t.get(j)) return false; + } + return true; + } + + virtual void mk_rename_elem(T& t, unsigned col_cnt, unsigned const* cycle) { + // no-op. + } + + virtual T mk_eq(union_find<> const& old_eqs, union_find<> const& neq_eqs, T const& t) const { + UNREACHABLE(); + return t; + } + + + }; + + typedef ptr_hashtable product_sets; + + class product_set_relation : public relation_base { + friend class product_set_plugin; + product_sets m_elems; + public: + product_set_relation(product_set_plugin& p, relation_signature const& s); + virtual ~product_set_relation(); + virtual void add_fact(const relation_fact & f); + virtual bool contains_fact(const relation_fact & f) const; + virtual product_set_relation * clone() const; + virtual product_set_relation * complement(func_decl*) const; + virtual void to_formula(expr_ref& fml) const; + product_set_plugin& get_plugin() const; + virtual bool empty() const { return m_elems.empty(); } + virtual void display(std::ostream& out) const; + + virtual bool is_precise() const { return true; } + }; + + class product_set_plugin : public relation_plugin { + friend class product_set_relation; + class join_fn; + class project_fn; + class union_fn; + class rename_fn; + class filter_equal_fn; + class filter_identical_fn; + class filter_interpreted_fn; + class filter_by_negation_fn; + + public: + product_set_plugin(relation_manager& rm); + virtual bool can_handle_signature(const relation_signature & s); + static symbol get_name() { return symbol("product_set"); } + virtual relation_base * mk_empty(const relation_signature & s); + virtual relation_base * mk_full(func_decl* p, const relation_signature & s); + virtual relation_join_fn * mk_join_fn(const relation_base & t1, const relation_base & t2, + unsigned col_cnt, const unsigned * cols1, const unsigned * cols2); + virtual relation_transformer_fn * mk_project_fn(const relation_base & t, unsigned col_cnt, + const unsigned * removed_cols); + virtual relation_transformer_fn * mk_rename_fn(const relation_base & t, unsigned permutation_cycle_len, + const unsigned * permutation_cycle); + virtual relation_union_fn * mk_union_fn(const relation_base & tgt, const relation_base & src, + const relation_base * delta); + virtual relation_union_fn * mk_widen_fn(const relation_base & tgt, const relation_base & src, + const relation_base * delta); + virtual relation_mutator_fn * mk_filter_identical_fn(const relation_base & t, unsigned col_cnt, + const unsigned * identical_cols); + virtual relation_mutator_fn * mk_filter_equal_fn(const relation_base & t, const relation_element & value, + unsigned col); + virtual relation_mutator_fn * mk_filter_interpreted_fn(const relation_base & t, app * condition); + + unsigned set_size(sort* ty); + + private: + static product_set_relation& get(relation_base& r); + static product_set_relation* get(relation_base* r); + static product_set_relation const & get(relation_base const& r); + product_set* insert(product_set* s, product_set_relation* r); + }; + +}; + +#endif diff --git a/src/muz/rel/rel_context.cpp b/src/muz/rel/rel_context.cpp index 1a3d2cdae..742930b5c 100644 --- a/src/muz/rel/rel_context.cpp +++ b/src/muz/rel/rel_context.cpp @@ -31,6 +31,7 @@ Revision History: #include"dl_interval_relation.h" #include"karr_relation.h" #include"dl_finite_product_relation.h" +#include"product_set.h" #include"dl_lazy_table.h" #include"dl_sparse_table.h" #include"dl_table.h" @@ -112,6 +113,7 @@ namespace datalog { rm.register_plugin(alloc(bound_relation_plugin, rm)); rm.register_plugin(alloc(interval_relation_plugin, rm)); rm.register_plugin(alloc(karr_relation_plugin, rm)); + rm.register_plugin(alloc(product_set_plugin, rm)); } rel_context::~rel_context() {