3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-15 13:28:47 +00:00

Merge pull request #126 from ahorn/minimum

Basic infrastructure for minimum aggregation function
This commit is contained in:
Nikolaj Bjorner 2015-06-11 09:38:39 -07:00
commit 94f8ecb06d
17 changed files with 550 additions and 10 deletions

View file

@ -44,7 +44,8 @@ namespace datalog {
m_num_sym("N"), m_num_sym("N"),
m_lt_sym("<"), m_lt_sym("<"),
m_le_sym("<="), m_le_sym("<="),
m_rule_sym("R") m_rule_sym("R"),
m_min_sym("min")
{ {
} }
@ -490,6 +491,66 @@ namespace datalog {
return m_manager->mk_func_decl(m_clone_sym, 1, &s, s, info); return m_manager->mk_func_decl(m_clone_sym, 1, &s, s, info);
} }
/**
In SMT2 syntax, we can write \c ((_ min R N) v_0 v_1 ... v_k)) where 0 <= N <= k,
R is a relation of sort V_0 x V_1 x ... x V_k and each v_i is a zero-arity function
(also known as a "constant" in SMT2 parlance) whose range is of sort V_i.
Example:
(define-sort number_t () (_ BitVec 2))
(declare-rel numbers (number_t number_t))
(declare-rel is_min (number_t number_t))
(declare-var x number_t)
(declare-var y number_t)
(rule (numbers #b00 #b11))
(rule (numbers #b00 #b01))
(rule (=> (and (numbers x y) ((_ min numbers 1) x y)) (is_min x y)))
This says that we want to find the mininum y grouped by x.
*/
func_decl * dl_decl_plugin::mk_min(decl_kind k, unsigned num_parameters, parameter const * parameters) {
if (num_parameters < 2) {
m_manager->raise_exception("invalid min aggregate definition due to missing parameters");
return 0;
}
parameter const & relation_parameter = parameters[0];
if (!relation_parameter.is_ast() || !is_func_decl(relation_parameter.get_ast())) {
m_manager->raise_exception("invalid min aggregate definition, first parameter is not a function declaration");
return 0;
}
func_decl* f = to_func_decl(relation_parameter.get_ast());
if (!m_manager->is_bool(f->get_range())) {
m_manager->raise_exception("invalid min aggregate definition, first paramater must be a predicate");
return 0;
}
parameter const & min_col_parameter = parameters[1];
if (!min_col_parameter.is_int()) {
m_manager->raise_exception("invalid min aggregate definition, second parameter must be an integer");
return 0;
}
if (min_col_parameter.get_int() < 0) {
m_manager->raise_exception("invalid min aggregate definition, second parameter must be non-negative");
return 0;
}
if ((unsigned)min_col_parameter.get_int() >= f->get_arity()) {
m_manager->raise_exception("invalid min aggregate definition, second parameter exceeds the arity of the relation");
return 0;
}
func_decl_info info(m_family_id, k, num_parameters, parameters);
SASSERT(f->get_info() == 0);
return m_manager->mk_func_decl(m_min_sym, f->get_arity(), f->get_domain(), f->get_range(), info);
}
func_decl * dl_decl_plugin::mk_func_decl( func_decl * dl_decl_plugin::mk_func_decl(
decl_kind k, unsigned num_parameters, parameter const * parameters, decl_kind k, unsigned num_parameters, parameter const * parameters,
unsigned arity, sort * const * domain, sort * range) { unsigned arity, sort * const * domain, sort * range) {
@ -617,6 +678,9 @@ namespace datalog {
break; break;
} }
case OP_DL_MIN:
return mk_min(k, num_parameters, parameters);
default: default:
m_manager->raise_exception("operator not recognized"); m_manager->raise_exception("operator not recognized");
return 0; return 0;
@ -627,7 +691,7 @@ namespace datalog {
} }
void dl_decl_plugin::get_op_names(svector<builtin_name> & op_names, symbol const & logic) { void dl_decl_plugin::get_op_names(svector<builtin_name> & op_names, symbol const & logic) {
op_names.push_back(builtin_name(m_min_sym.bare_str(), OP_DL_MIN));
} }
void dl_decl_plugin::get_sort_names(svector<builtin_name> & sort_names, symbol const & logic) { void dl_decl_plugin::get_sort_names(svector<builtin_name> & sort_names, symbol const & logic) {

View file

@ -50,6 +50,7 @@ namespace datalog {
OP_DL_LT, OP_DL_LT,
OP_DL_REP, OP_DL_REP,
OP_DL_ABS, OP_DL_ABS,
OP_DL_MIN,
LAST_RA_OP LAST_RA_OP
}; };
@ -71,6 +72,7 @@ namespace datalog {
symbol m_lt_sym; symbol m_lt_sym;
symbol m_le_sym; symbol m_le_sym;
symbol m_rule_sym; symbol m_rule_sym;
symbol m_min_sym;
bool check_bounds(char const* msg, unsigned low, unsigned up, unsigned val) const; bool check_bounds(char const* msg, unsigned low, unsigned up, unsigned val) const;
bool check_domain(unsigned low, unsigned up, unsigned val) const; bool check_domain(unsigned low, unsigned up, unsigned val) const;
@ -94,12 +96,69 @@ namespace datalog {
func_decl * mk_compare(decl_kind k, symbol const& sym, sort*const* domain); func_decl * mk_compare(decl_kind k, symbol const& sym, sort*const* domain);
func_decl * mk_clone(sort* r); func_decl * mk_clone(sort* r);
func_decl * mk_rule(unsigned arity); func_decl * mk_rule(unsigned arity);
func_decl * mk_min(decl_kind k, unsigned num_parameters, parameter const * parameters);
sort * mk_finite_sort(unsigned num_params, parameter const* params); sort * mk_finite_sort(unsigned num_params, parameter const* params);
sort * mk_relation_sort(unsigned num_params, parameter const* params); sort * mk_relation_sort(unsigned num_params, parameter const* params);
sort * mk_rule_sort(); sort * mk_rule_sort();
public: public:
/**
Is \c decl a min aggregation function?
*/
static bool is_aggregate(const func_decl* const decl)
{
return decl->get_decl_kind() == OP_DL_MIN;
}
/**
\pre: is_aggregate(aggregate)
\returns function declaration of predicate which is subject to min aggregation function
*/
static func_decl * min_func_decl(const func_decl* const aggregate)
{
SASSERT(is_aggregate(aggregate));
parameter const & relation_parameter = aggregate->get_parameter(0);
return to_func_decl(relation_parameter.get_ast());
}
/**
\pre: is_aggregate(aggregate)
\returns column identifier (starting at zero) which is minimized by aggregation function
*/
static unsigned min_col(const func_decl* const aggregate)
{
SASSERT(is_aggregate(aggregate));
return (unsigned)aggregate->get_parameter(1).get_int();
}
/**
\pre: is_aggregate(aggregate)
\returns column identifiers for the "group by" in the given min aggregation function
*/
static unsigned_vector group_by_cols(const func_decl* const aggregate)
{
SASSERT(is_aggregate(aggregate));
unsigned _min_col = min_col(aggregate);
if (aggregate->get_arity() == 0U)
return unsigned_vector();
unsigned col_num = 0;
unsigned_vector cols(aggregate->get_arity() - 1U);
for (unsigned i = 0; i < cols.size(); ++i, ++col_num)
{
if (col_num == _min_col)
++col_num;
cols[i] = col_num;
}
return cols;
}
dl_decl_plugin(); dl_decl_plugin();
virtual ~dl_decl_plugin() {} virtual ~dl_decl_plugin() {}

View file

@ -621,6 +621,7 @@ void cmd_context::init_manager_core(bool new_manager) {
register_plugin(symbol("seq"), alloc(seq_decl_plugin), logic_has_seq()); register_plugin(symbol("seq"), alloc(seq_decl_plugin), logic_has_seq());
register_plugin(symbol("pb"), alloc(pb_decl_plugin), !has_logic()); register_plugin(symbol("pb"), alloc(pb_decl_plugin), !has_logic());
register_plugin(symbol("fpa"), alloc(fpa_decl_plugin), logic_has_fpa()); register_plugin(symbol("fpa"), alloc(fpa_decl_plugin), logic_has_fpa());
register_plugin(symbol("datalog_relation"), alloc(datalog::dl_decl_plugin), !has_logic());
} }
else { else {
// the manager was created by an external module // the manager was created by an external module

View file

@ -346,6 +346,13 @@ namespace datalog {
bool is_neg_tail(unsigned i) const { SASSERT(i < m_tail_size); return GET_TAG(m_tail[i]) == 1; } bool is_neg_tail(unsigned i) const { SASSERT(i < m_tail_size); return GET_TAG(m_tail[i]) == 1; }
/**
A predicate P(Xj) can be annotated by adding an interpreted predicate of the form ((_ min P N) ...)
where N is the column number that should be used for the min aggregation function.
Such an interpreted predicate is an example for which this function returns true.
*/
bool is_min_tail(unsigned i) const { return dl_decl_plugin::is_aggregate(get_tail(i)->get_decl()); }
/** /**
Check whether predicate p is in the interpreted tail. Check whether predicate p is in the interpreted tail.

View file

@ -400,7 +400,7 @@ namespace datalog {
SASSERT(!is_closed()); //the rule_set is not already closed SASSERT(!is_closed()); //the rule_set is not already closed
m_deps.populate(*this); m_deps.populate(*this);
m_stratifier = alloc(rule_stratifier, m_deps); m_stratifier = alloc(rule_stratifier, m_deps);
if (!stratified_negation()) { if (!stratified_negation() || !check_min()) {
m_stratifier = 0; m_stratifier = 0;
m_deps.reset(); m_deps.reset();
return false; return false;
@ -441,6 +441,49 @@ namespace datalog {
return true; return true;
} }
bool rule_set::check_min() {
// For now, we check the following:
//
// if a min aggregation function occurs in an SCC, is this SCC
// free of any other non-monotonic functions, e.g. negation?
const unsigned NEG_BIT = 1U << 0;
const unsigned MIN_BIT = 1U << 1;
ptr_vector<rule>::const_iterator it = m_rules.c_ptr();
ptr_vector<rule>::const_iterator end = m_rules.c_ptr() + m_rules.size();
unsigned_vector component_status(m_stratifier->get_strats().size());
for (; it != end; it++) {
rule * r = *it;
app * head = r->get_head();
func_decl * head_decl = head->get_decl();
unsigned head_strat = get_predicate_strat(head_decl);
unsigned n = r->get_tail_size();
for (unsigned i = 0; i < n; i++) {
func_decl * tail_decl = r->get_tail(i)->get_decl();
unsigned strat = get_predicate_strat(tail_decl);
if (r->is_neg_tail(i)) {
SASSERT(strat < component_status.size());
component_status[strat] |= NEG_BIT;
}
if (r->is_min_tail(i)) {
SASSERT(strat < component_status.size());
component_status[strat] |= MIN_BIT;
}
}
}
const unsigned CONFLICT = NEG_BIT | MIN_BIT;
for (unsigned k = 0; k < component_status.size(); ++k) {
if (component_status[k] == CONFLICT)
return false;
}
return true;
}
void rule_set::replace_rules(const rule_set & src) { void rule_set::replace_rules(const rule_set & src) {
if (this != &src) { if (this != &src) {
reset(); reset();

View file

@ -179,6 +179,7 @@ namespace datalog {
void compute_deps(); void compute_deps();
void compute_tc_deps(); void compute_tc_deps();
bool stratified_negation(); bool stratified_negation();
bool check_min();
public: public:
rule_set(context & ctx); rule_set(context & ctx);
rule_set(const rule_set & rs); rule_set(const rule_set & rs);

View file

@ -485,4 +485,130 @@ namespace datalog {
brw.mk_or(disjs.size(), disjs.c_ptr(), fml); brw.mk_or(disjs.size(), disjs.c_ptr(), fml);
} }
class table_plugin::min_fn : public table_min_fn{
table_signature m_sig;
const unsigned_vector m_group_by_cols;
const unsigned m_col;
public:
min_fn(const table_signature & t_sig, const unsigned_vector& group_by_cols, const unsigned col)
: m_sig(t_sig),
m_group_by_cols(group_by_cols),
m_col(col) {}
virtual table_base* operator()(table_base const& t) {
//return reference_implementation(t);
return reference_implementation_with_hash(t);
}
private:
/**
Reference implementation with negation:
T1 = join(T, T) by group_cols
T2 = { (t1,t2) in T1 | t1[col] > t2[col] }
T3 = { t1 | (t1,t2) in T2 }
T4 = T \ T3
The point of this reference implementation is to show
that the minimum requires negation (set difference).
This is relevant for fixed point computations.
*/
virtual table_base * reference_implementation(const table_base & t) {
relation_manager & manager = t.get_manager();
table_join_fn * join_fn = manager.mk_join_fn(t, t, m_group_by_cols, m_group_by_cols);
table_base * join_table = (*join_fn)(t, t);
dealloc(join_fn);
table_base::iterator join_table_it = join_table->begin();
table_base::iterator join_table_end = join_table->end();
table_fact row;
table_element i, j;
for (; join_table_it != join_table_end; ++join_table_it) {
join_table_it->get_fact(row);
i = row[m_col];
j = row[t.num_columns() + m_col];
if (i > j) {
continue;
}
join_table->remove_fact(row);
}
unsigned_vector cols(t.num_columns());
for (unsigned k = 0; k < cols.size(); ++k) {
cols[k] = cols.size() + k;
SASSERT(cols[k] < join_table->num_columns());
}
table_transformer_fn * project_fn = manager.mk_project_fn(*join_table, cols);
table_base * gt_table = (*project_fn)(*join_table);
dealloc(project_fn);
join_table->deallocate();
for (unsigned k = 0; k < cols.size(); ++k) {
cols[k] = k;
SASSERT(cols[k] < t.num_columns());
SASSERT(cols[k] < gt_table->num_columns());
}
table_base * result = t.clone();
table_intersection_filter_fn * diff_fn = manager.mk_filter_by_negation_fn(*result, *gt_table, cols, cols);
(*diff_fn)(*result, *gt_table);
dealloc(diff_fn);
gt_table->deallocate();
return result;
}
typedef map < table_fact, table_element, svector_hash_proc<table_element_hash>,
vector_eq_proc<table_fact> > group_map;
// Thanks to Nikolaj who kindly helped with the second reference implementation!
virtual table_base * reference_implementation_with_hash(const table_base & t) {
group_map group;
table_base::iterator it = t.begin();
table_base::iterator end = t.end();
table_fact row, row2;
table_element current_value, min_value;
for (; it != end; ++it) {
it->get_fact(row);
current_value = row[m_col];
group_by(row, row2);
group_map::entry* entry = group.find_core(row2);
if (!entry) {
group.insert(row2, current_value);
}
else if (entry->get_data().m_value > current_value) {
entry->get_data().m_value = current_value;
}
}
table_base* result = t.get_plugin().mk_empty(m_sig);
table_base::iterator it2 = t.begin();
for (; it2 != end; ++it2) {
it2->get_fact(row);
current_value = row[m_col];
group_by(row, row2);
VERIFY(group.find(row2, min_value));
if (min_value == current_value) {
result->add_fact(row);
}
}
return result;
}
void group_by(table_fact const& in, table_fact& out) {
out.reset();
for (unsigned i = 0; i < m_group_by_cols.size(); ++i) {
out.push_back(in[m_group_by_cols[i]]);
}
}
};
table_min_fn * table_plugin::mk_min_fn(const table_base & t,
unsigned_vector & group_by_cols, const unsigned col) {
return alloc(table_plugin::min_fn, t.get_signature(), group_by_cols, col);
}
} }

View file

@ -192,6 +192,29 @@ namespace datalog {
virtual base_object * operator()(const base_object & t1, const base_object & t2) = 0; virtual base_object * operator()(const base_object & t1, const base_object & t2) = 0;
}; };
/**
\brief Aggregate minimum value
Informally, we want to group rows in a table \c t by \c group_by_cols and
return the minimum value in column \c col among each group.
Let \c t be a table with N columns.
Let \c group_by_cols be a set of column identifers for table \c t such that |group_by_cols| < N.
Let \c col be a column identifier for table \c t such that \c col is not in \c group_by_cols.
Let R_col be a set of rows in table \c t such that, for all rows r_i, r_j in R_col
and column identifiers k in \c group_by_cols, r_i[k] = r_j[k].
For each R_col, we want to restrict R_col to those rows whose value in column \c col is minimal.
min_fn(R, group_by_cols, col) =
{ row in R | forall row' in R . row'[group_by_cols] = row[group_by_cols] => row'[col] >= row[col] }
*/
class min_fn : public base_fn {
public:
virtual base_object * operator()(const base_object & t) = 0;
};
class transformer_fn : public base_fn { class transformer_fn : public base_fn {
public: public:
virtual base_object * operator()(const base_object & t) = 0; virtual base_object * operator()(const base_object & t) = 0;
@ -856,6 +879,7 @@ namespace datalog {
typedef table_infrastructure::base_fn base_table_fn; typedef table_infrastructure::base_fn base_table_fn;
typedef table_infrastructure::join_fn table_join_fn; typedef table_infrastructure::join_fn table_join_fn;
typedef table_infrastructure::min_fn table_min_fn;
typedef table_infrastructure::transformer_fn table_transformer_fn; typedef table_infrastructure::transformer_fn table_transformer_fn;
typedef table_infrastructure::union_fn table_union_fn; typedef table_infrastructure::union_fn table_union_fn;
typedef table_infrastructure::mutator_fn table_mutator_fn; typedef table_infrastructure::mutator_fn table_mutator_fn;
@ -1020,6 +1044,7 @@ namespace datalog {
class table_plugin : public table_infrastructure::plugin_object { class table_plugin : public table_infrastructure::plugin_object {
friend class relation_manager; friend class relation_manager;
class min_fn;
protected: protected:
table_plugin(symbol const& n, relation_manager & manager) : plugin_object(n, manager) {} table_plugin(symbol const& n, relation_manager & manager) : plugin_object(n, manager) {}
public: public:
@ -1027,6 +1052,9 @@ namespace datalog {
virtual bool can_handle_signature(const table_signature & s) { return s.functional_columns()==0; } virtual bool can_handle_signature(const table_signature & s) { return s.functional_columns()==0; }
protected: protected:
virtual table_min_fn * mk_min_fn(const table_base & t,
unsigned_vector & group_by_cols, const unsigned col);
/** /**
If the returned value is non-zero, the returned object must take ownership of \c mapper. If the returned value is non-zero, the returned object must take ownership of \c mapper.
Otherwise \c mapper must remain unmodified. Otherwise \c mapper must remain unmodified.

View file

@ -73,6 +73,12 @@ namespace datalog {
vars.get_cols2(), removed_cols.size(), removed_cols.c_ptr(), result)); vars.get_cols2(), removed_cols.size(), removed_cols.c_ptr(), result));
} }
void compiler::make_min(reg_idx source, reg_idx & target, const unsigned_vector & group_by_cols,
const unsigned min_col, instruction_block & acc) {
target = get_register(m_reg_signatures[source], true, source);
acc.push_back(instruction::mk_min(source, target, group_by_cols, min_col));
}
void compiler::make_filter_interpreted_and_project(reg_idx src, app_ref & cond, void compiler::make_filter_interpreted_and_project(reg_idx src, app_ref & cond,
const unsigned_vector & removed_cols, reg_idx & result, bool reuse, instruction_block & acc) { const unsigned_vector & removed_cols, reg_idx & result, bool reuse, instruction_block & acc) {
SASSERT(!removed_cols.empty()); SASSERT(!removed_cols.empty());
@ -440,6 +446,30 @@ namespace datalog {
get_local_indexes_for_projection(t2, counter, t1->get_num_args(), res); get_local_indexes_for_projection(t2, counter, t1->get_num_args(), res);
} }
void compiler::find_min_aggregates(const rule * r, ptr_vector<func_decl>& min_aggregates) {
unsigned ut_len = r->get_uninterpreted_tail_size();
unsigned ft_len = r->get_tail_size(); // full tail
func_decl * aggregate;
for (unsigned tail_index = ut_len; tail_index < ft_len; ++tail_index) {
aggregate = r->get_tail(tail_index)->get_decl();
if (dl_decl_plugin::is_aggregate(aggregate)) {
min_aggregates.push_back(aggregate);
}
}
}
bool compiler::prepare_min_aggregate(const func_decl * decl, const ptr_vector<func_decl>& min_aggregates,
unsigned_vector & group_by_cols, unsigned & min_col) {
for (unsigned i = 0; i < min_aggregates.size(); ++i) {
if (dl_decl_plugin::min_func_decl(min_aggregates[i]) == decl) {
group_by_cols = dl_decl_plugin::group_by_cols(min_aggregates[i]);
min_col = dl_decl_plugin::min_col(min_aggregates[i]);
return true;
}
}
return false;
}
void compiler::compile_rule_evaluation_run(rule * r, reg_idx head_reg, const reg_idx * tail_regs, void compiler::compile_rule_evaluation_run(rule * r, reg_idx head_reg, const reg_idx * tail_regs,
reg_idx delta_reg, bool use_widening, instruction_block & acc) { reg_idx delta_reg, bool use_widening, instruction_block & acc) {
@ -465,6 +495,12 @@ namespace datalog {
// whether to dealloc the previous result // whether to dealloc the previous result
bool dealloc = true; bool dealloc = true;
// setup information for min aggregation
ptr_vector<func_decl> min_aggregates;
find_min_aggregates(r, min_aggregates);
unsigned_vector group_by_cols;
unsigned min_col;
if(pt_len == 2) { if(pt_len == 2) {
reg_idx t1_reg=tail_regs[0]; reg_idx t1_reg=tail_regs[0];
reg_idx t2_reg=tail_regs[1]; reg_idx t2_reg=tail_regs[1];
@ -473,6 +509,14 @@ namespace datalog {
SASSERT(m_reg_signatures[t1_reg].size()==a1->get_num_args()); SASSERT(m_reg_signatures[t1_reg].size()==a1->get_num_args());
SASSERT(m_reg_signatures[t2_reg].size()==a2->get_num_args()); SASSERT(m_reg_signatures[t2_reg].size()==a2->get_num_args());
if (prepare_min_aggregate(a1->get_decl(), min_aggregates, group_by_cols, min_col)) {
make_min(t1_reg, single_res, group_by_cols, min_col, acc);
}
if (prepare_min_aggregate(a2->get_decl(), min_aggregates, group_by_cols, min_col)) {
make_min(t2_reg, single_res, group_by_cols, min_col, acc);
}
variable_intersection a1a2(m_context.get_manager()); variable_intersection a1a2(m_context.get_manager());
a1a2.populate(a1,a2); a1a2.populate(a1,a2);
@ -514,6 +558,10 @@ namespace datalog {
single_res = tail_regs[0]; single_res = tail_regs[0];
dealloc = false; dealloc = false;
if (prepare_min_aggregate(a->get_decl(), min_aggregates, group_by_cols, min_col)) {
make_min(single_res, single_res, group_by_cols, min_col, acc);
}
SASSERT(m_reg_signatures[single_res].size() == a->get_num_args()); SASSERT(m_reg_signatures[single_res].size() == a->get_num_args());
unsigned n=a->get_num_args(); unsigned n=a->get_num_args();
@ -597,7 +645,8 @@ namespace datalog {
unsigned ft_len = r->get_tail_size(); // full tail unsigned ft_len = r->get_tail_size(); // full tail
ptr_vector<expr> tail; ptr_vector<expr> tail;
for (unsigned tail_index = ut_len; tail_index < ft_len; ++tail_index) { for (unsigned tail_index = ut_len; tail_index < ft_len; ++tail_index) {
tail.push_back(r->get_tail(tail_index)); if (!r->is_min_tail(tail_index))
tail.push_back(r->get_tail(tail_index));
} }
expr_ref_vector binding(m); expr_ref_vector binding(m);

View file

@ -120,6 +120,22 @@ namespace datalog {
instruction_observer m_instruction_observer; instruction_observer m_instruction_observer;
expr_free_vars m_free_vars; expr_free_vars m_free_vars;
/**
\brief Finds all the min aggregation functions in the premise of a given rule.
*/
static void find_min_aggregates(const rule * r, ptr_vector<func_decl>& min_aggregates);
/**
\brief Decides whether a predicate is subject to a min aggregation function.
If \c decl is subject to a min aggregation function, the output parameters are written
with the neccessary information.
\returns true if the output paramaters have been written
*/
static bool prepare_min_aggregate(const func_decl * decl, const ptr_vector<func_decl>& min_aggregates,
unsigned_vector & group_by_cols, unsigned & min_col);
/** /**
If true, the union operation on the underlying structure only provides the information If true, the union operation on the underlying structure only provides the information
whether the updated relation has changed or not. In this case we do not get anything whether the updated relation has changed or not. In this case we do not get anything
@ -146,6 +162,8 @@ namespace datalog {
void make_join(reg_idx t1, reg_idx t2, const variable_intersection & vars, reg_idx & result, void make_join(reg_idx t1, reg_idx t2, const variable_intersection & vars, reg_idx & result,
bool reuse_t1, instruction_block & acc); bool reuse_t1, instruction_block & acc);
void make_min(reg_idx source, reg_idx & target, const unsigned_vector & group_by_cols,
const unsigned min_col, instruction_block & acc);
void make_join_project(reg_idx t1, reg_idx t2, const variable_intersection & vars, void make_join_project(reg_idx t1, reg_idx t2, const variable_intersection & vars,
const unsigned_vector & removed_cols, reg_idx & result, bool reuse_t1, instruction_block & acc); const unsigned_vector & removed_cols, reg_idx & result, bool reuse_t1, instruction_block & acc);
void make_filter_interpreted_and_project(reg_idx src, app_ref & cond, void make_filter_interpreted_and_project(reg_idx src, app_ref & cond,

View file

@ -25,6 +25,7 @@ Revision History:
#include"rel_context.h" #include"rel_context.h"
#include"debug.h" #include"debug.h"
#include"warning.h" #include"warning.h"
#include"dl_table_relation.h"
namespace datalog { namespace datalog {
@ -552,7 +553,7 @@ namespace datalog {
if (r.fast_empty()) { if (r.fast_empty()) {
ctx.make_empty(m_reg); ctx.make_empty(m_reg);
} }
TRACE("dl_verbose", r.display(tout <<"post-filter-interpreted:\n");); //TRACE("dl_verbose", r.display(tout <<"post-filter-interpreted:\n"););
return true; return true;
} }
@ -609,7 +610,7 @@ namespace datalog {
if (ctx.reg(m_res)->fast_empty()) { if (ctx.reg(m_res)->fast_empty()) {
ctx.make_empty(m_res); ctx.make_empty(m_res);
} }
TRACE("dl_verbose", reg.display(tout << "post-filter-interpreted-and-project:\n");); //TRACE("dl_verbose", reg.display(tout << "post-filter-interpreted-and-project:\n"););
return true; return true;
} }
@ -883,6 +884,60 @@ namespace datalog {
removed_cols, result); removed_cols, result);
} }
class instr_min : public instruction {
reg_idx m_source_reg;
reg_idx m_target_reg;
unsigned_vector m_group_by_cols;
unsigned m_min_col;
public:
instr_min(reg_idx source_reg, reg_idx target_reg, const unsigned_vector & group_by_cols, unsigned min_col)
: m_source_reg(source_reg),
m_target_reg(target_reg),
m_group_by_cols(group_by_cols),
m_min_col(min_col) {
}
virtual bool perform(execution_context & ctx) {
log_verbose(ctx);
if (!ctx.reg(m_source_reg)) {
ctx.make_empty(m_target_reg);
return true;
}
const relation_base & s = *ctx.reg(m_source_reg);
if (!s.from_table()) {
throw default_exception("relation is not a table %s",
s.get_plugin().get_name().bare_str());
}
++ctx.m_stats.m_min;
const table_relation & tr = static_cast<const table_relation &>(s);
const table_base & source_t = tr.get_table();
relation_manager & r_manager = s.get_manager();
const relation_signature & r_sig = s.get_signature();
table_min_fn * fn = r_manager.mk_min_fn(source_t, m_group_by_cols, m_min_col);
table_base * target_t = (*fn)(source_t);
dealloc(fn);
TRACE("dl",
tout << "% ";
target_t->display(tout);
tout << "\n";);
relation_base * target_r = r_manager.mk_table_relation(r_sig, target_t);
ctx.set_reg(m_target_reg, target_r);
return true;
}
virtual void display_head_impl(execution_context const& ctx, std::ostream & out) const {
out << " MIN AGGR ";
}
virtual void make_annotations(execution_context & ctx) {
}
};
instruction * instruction::mk_min(reg_idx source, reg_idx target, const unsigned_vector & group_by_cols,
const unsigned min_col) {
return alloc(instr_min, source, target, group_by_cols, min_col);
}
class instr_select_equal_and_project : public instruction { class instr_select_equal_and_project : public instruction {
reg_idx m_src; reg_idx m_src;

View file

@ -93,6 +93,7 @@ namespace datalog {
unsigned m_filter_interp_project; unsigned m_filter_interp_project;
unsigned m_filter_id; unsigned m_filter_id;
unsigned m_filter_eq; unsigned m_filter_eq;
unsigned m_min;
stats() { reset(); } stats() { reset(); }
void reset() { memset(this, 0, sizeof(*this)); } void reset() { memset(this, 0, sizeof(*this)); }
}; };
@ -284,6 +285,8 @@ namespace datalog {
static instruction * mk_join_project(reg_idx rel1, reg_idx rel2, unsigned joined_col_cnt, static instruction * mk_join_project(reg_idx rel1, reg_idx rel2, unsigned joined_col_cnt,
const unsigned * cols1, const unsigned * cols2, unsigned removed_col_cnt, const unsigned * cols1, const unsigned * cols2, unsigned removed_col_cnt,
const unsigned * removed_cols, reg_idx result); const unsigned * removed_cols, reg_idx result);
static instruction * mk_min(reg_idx source, reg_idx target, const unsigned_vector & group_by_cols,
const unsigned min_col);
static instruction * mk_rename(reg_idx src, unsigned cycle_len, const unsigned * permutation_cycle, static instruction * mk_rename(reg_idx src, unsigned cycle_len, const unsigned * permutation_cycle,
reg_idx tgt); reg_idx tgt);
static instruction * mk_filter_by_negation(reg_idx tgt, reg_idx neg_rel, unsigned col_cnt, static instruction * mk_filter_by_negation(reg_idx tgt, reg_idx neg_rel, unsigned col_cnt,

View file

@ -354,7 +354,9 @@ namespace datalog {
return product_relation_plugin::get_plugin(*this).mk_empty(s); return product_relation_plugin::get_plugin(*this).mk_empty(s);
} }
/**
The newly created object takes ownership of the \c table object.
*/
relation_base * relation_manager::mk_table_relation(const relation_signature & s, table_base * table) { relation_base * relation_manager::mk_table_relation(const relation_signature & s, table_base * table) {
SASSERT(s.size()==table->get_signature().size()); SASSERT(s.size()==table->get_signature().size());
return get_table_relation_plugin(table->get_plugin()).mk_from_table(s, table); return get_table_relation_plugin(table->get_plugin()).mk_from_table(s, table);
@ -1021,6 +1023,11 @@ namespace datalog {
return res; return res;
} }
table_min_fn * relation_manager::mk_min_fn(const table_base & t,
unsigned_vector & group_by_cols, const unsigned col)
{
return t.get_plugin().mk_min_fn(t, group_by_cols, col);
}
class relation_manager::auxiliary_table_transformer_fn { class relation_manager::auxiliary_table_transformer_fn {
table_fact m_row; table_fact m_row;

View file

@ -251,6 +251,9 @@ namespace datalog {
return mk_join_fn(t1, t2, cols1.size(), cols1.c_ptr(), cols2.c_ptr(), allow_product_relation); return mk_join_fn(t1, t2, cols1.size(), cols1.c_ptr(), cols2.c_ptr(), allow_product_relation);
} }
table_min_fn * mk_min_fn(const table_base & t,
unsigned_vector & group_by_cols, const unsigned col);
/** /**
\brief Return functor that transforms a table into one that lacks columns listed in \brief Return functor that transforms a table into one that lacks columns listed in
\c removed_cols array. \c removed_cols array.

View file

@ -63,6 +63,9 @@ namespace datalog {
return alloc(table_relation, *this, s, t); return alloc(table_relation, *this, s, t);
} }
/**
The newly created object takes ownership of the \c t object.
*/
relation_base * table_relation_plugin::mk_from_table(const relation_signature & s, table_base * t) { relation_base * table_relation_plugin::mk_from_table(const relation_signature & s, table_base * t) {
if (&t->get_plugin() == &m_table_plugin) if (&t->get_plugin() == &m_table_plugin)
return alloc(table_relation, *this, s, t); return alloc(table_relation, *this, s, t);

View file

@ -292,17 +292,23 @@ namespace datalog {
void rel_context::transform_rules() { void rel_context::transform_rules() {
rule_transformer transf(m_context); rule_transformer transf(m_context);
#ifdef _MIN_DONE_
transf.register_plugin(alloc(mk_coi_filter, m_context)); transf.register_plugin(alloc(mk_coi_filter, m_context));
#endif
transf.register_plugin(alloc(mk_filter_rules, m_context)); transf.register_plugin(alloc(mk_filter_rules, m_context));
transf.register_plugin(alloc(mk_simple_joins, m_context)); transf.register_plugin(alloc(mk_simple_joins, m_context));
if (m_context.unbound_compressor()) { if (m_context.unbound_compressor()) {
transf.register_plugin(alloc(mk_unbound_compressor, m_context)); transf.register_plugin(alloc(mk_unbound_compressor, m_context));
} }
#ifdef _MIN_DONE_
if (m_context.similarity_compressor()) { if (m_context.similarity_compressor()) {
transf.register_plugin(alloc(mk_similarity_compressor, m_context)); transf.register_plugin(alloc(mk_similarity_compressor, m_context));
} }
#endif
transf.register_plugin(alloc(mk_partial_equivalence_transformer, m_context)); transf.register_plugin(alloc(mk_partial_equivalence_transformer, m_context));
#ifdef _MIN_DONE_
transf.register_plugin(alloc(mk_rule_inliner, m_context)); transf.register_plugin(alloc(mk_rule_inliner, m_context));
#endif
transf.register_plugin(alloc(mk_interp_tail_simplifier, m_context)); transf.register_plugin(alloc(mk_interp_tail_simplifier, m_context));
transf.register_plugin(alloc(mk_separate_negated_tails, m_context)); transf.register_plugin(alloc(mk_separate_negated_tails, m_context));

View file

@ -1,10 +1,8 @@
/*++ /*++
Copyright (c) 2015 Microsoft Corporation Copyright (c) 2015 Microsoft Corporation
--*/ --*/
#if defined(_WINDOWS) || defined(_CYGWIN)
#ifdef _WINDOWS
#include "dl_context.h" #include "dl_context.h"
#include "dl_table.h" #include "dl_table.h"
#include "dl_register_engine.h" #include "dl_register_engine.h"
@ -97,9 +95,78 @@ void test_dl_bitvector_table() {
test_table(mk_bv_table); test_table(mk_bv_table);
} }
void test_table_min() {
std::cout << "----- test_table_min -----\n";
datalog::table_signature sig;
sig.push_back(2);
sig.push_back(4);
sig.push_back(8);
smt_params params;
ast_manager ast_m;
datalog::register_engine re;
datalog::context ctx(ast_m, re, params);
datalog::relation_manager & m = ctx.get_rel_context()->get_rmanager();
m.register_plugin(alloc(datalog::bitvector_table_plugin, m));
datalog::table_base* tbl = mk_bv_table(m, sig);
datalog::table_base& table = *tbl;
datalog::table_fact row, row1, row2, row3;
row.push_back(1);
row.push_back(2);
row.push_back(5);
// Group (1,2,*)
row1 = row;
row[2] = 6;
row2 = row;
row[2] = 5;
row3 = row;
table.add_fact(row1);
table.add_fact(row2);
table.add_fact(row3);
// Group (1,3,*)
row[1] = 3;
row1 = row;
row[2] = 7;
row2 = row;
row[2] = 4;
row3 = row;
table.add_fact(row1);
table.add_fact(row2);
table.add_fact(row3);
table.display(std::cout);
unsigned_vector group_by(2);
group_by[0] = 0;
group_by[1] = 1;
datalog::table_min_fn * min_fn = m.mk_min_fn(table, group_by, 2);
datalog::table_base * min_tbl = (*min_fn)(table);
min_tbl->display(std::cout);
row[1] = 2;
row[2] = 5;
SASSERT(min_tbl->contains_fact(row));
row[1] = 3;
row[2] = 4;
SASSERT(min_tbl->contains_fact(row));
dealloc(min_fn);
min_tbl->deallocate();
tbl->deallocate();
}
void tst_dl_table() { void tst_dl_table() {
test_dl_bitvector_table(); test_dl_bitvector_table();
test_table_min();
} }
#else #else
void tst_dl_table() { void tst_dl_table() {