diff --git a/src/ast/dl_decl_plugin.cpp b/src/ast/dl_decl_plugin.cpp index badf8a59d..305ac1779 100644 --- a/src/ast/dl_decl_plugin.cpp +++ b/src/ast/dl_decl_plugin.cpp @@ -44,7 +44,8 @@ namespace datalog { m_num_sym("N"), m_lt_sym("<"), m_le_sym("<="), - m_rule_sym("R") + m_rule_sym("R"), + m_min_sym("min") { } @@ -490,6 +491,66 @@ namespace datalog { return m_manager->mk_func_decl(m_clone_sym, 1, &s, s, info); } + /** + In SMT2 syntax, we can write \c ((_ min R N) v_0 v_1 ... v_k)) where 0 <= N <= k, + R is a relation of sort V_0 x V_1 x ... x V_k and each v_i is a zero-arity function + (also known as a "constant" in SMT2 parlance) whose range is of sort V_i. + + Example: + + (define-sort number_t () (_ BitVec 2)) + (declare-rel numbers (number_t number_t)) + (declare-rel is_min (number_t number_t)) + + (declare-var x number_t) + (declare-var y number_t) + + (rule (numbers #b00 #b11)) + (rule (numbers #b00 #b01)) + + (rule (=> (and (numbers x y) ((_ min numbers 1) x y)) (is_min x y))) + + This says that we want to find the mininum y grouped by x. + */ + func_decl * dl_decl_plugin::mk_min(decl_kind k, unsigned num_parameters, parameter const * parameters) { + if (num_parameters < 2) { + m_manager->raise_exception("invalid min aggregate definition due to missing parameters"); + return 0; + } + + parameter const & relation_parameter = parameters[0]; + if (!relation_parameter.is_ast() || !is_func_decl(relation_parameter.get_ast())) { + m_manager->raise_exception("invalid min aggregate definition, first parameter is not a function declaration"); + return 0; + } + + func_decl* f = to_func_decl(relation_parameter.get_ast()); + if (!m_manager->is_bool(f->get_range())) { + m_manager->raise_exception("invalid min aggregate definition, first paramater must be a predicate"); + return 0; + } + + parameter const & min_col_parameter = parameters[1]; + if (!min_col_parameter.is_int()) { + m_manager->raise_exception("invalid min aggregate definition, second parameter must be an integer"); + return 0; + } + + if (min_col_parameter.get_int() < 0) { + m_manager->raise_exception("invalid min aggregate definition, second parameter must be non-negative"); + return 0; + } + + if ((unsigned)min_col_parameter.get_int() >= f->get_arity()) { + m_manager->raise_exception("invalid min aggregate definition, second parameter exceeds the arity of the relation"); + return 0; + } + + func_decl_info info(m_family_id, k, num_parameters, parameters); + SASSERT(f->get_info() == 0); + return m_manager->mk_func_decl(m_min_sym, f->get_arity(), f->get_domain(), f->get_range(), info); + } + func_decl * dl_decl_plugin::mk_func_decl( decl_kind k, unsigned num_parameters, parameter const * parameters, unsigned arity, sort * const * domain, sort * range) { @@ -617,6 +678,9 @@ namespace datalog { break; } + case OP_DL_MIN: + return mk_min(k, num_parameters, parameters); + default: m_manager->raise_exception("operator not recognized"); return 0; @@ -627,7 +691,7 @@ namespace datalog { } void dl_decl_plugin::get_op_names(svector & op_names, symbol const & logic) { - + op_names.push_back(builtin_name(m_min_sym.bare_str(), OP_DL_MIN)); } void dl_decl_plugin::get_sort_names(svector & sort_names, symbol const & logic) { diff --git a/src/ast/dl_decl_plugin.h b/src/ast/dl_decl_plugin.h index 65b00235c..e3bc4dd63 100644 --- a/src/ast/dl_decl_plugin.h +++ b/src/ast/dl_decl_plugin.h @@ -50,6 +50,7 @@ namespace datalog { OP_DL_LT, OP_DL_REP, OP_DL_ABS, + OP_DL_MIN, LAST_RA_OP }; @@ -71,6 +72,7 @@ namespace datalog { symbol m_lt_sym; symbol m_le_sym; symbol m_rule_sym; + symbol m_min_sym; bool check_bounds(char const* msg, unsigned low, unsigned up, unsigned val) const; bool check_domain(unsigned low, unsigned up, unsigned val) const; @@ -94,12 +96,69 @@ namespace datalog { func_decl * mk_compare(decl_kind k, symbol const& sym, sort*const* domain); func_decl * mk_clone(sort* r); func_decl * mk_rule(unsigned arity); + func_decl * mk_min(decl_kind k, unsigned num_parameters, parameter const * parameters); sort * mk_finite_sort(unsigned num_params, parameter const* params); sort * mk_relation_sort(unsigned num_params, parameter const* params); sort * mk_rule_sort(); public: + /** + Is \c decl a min aggregation function? + */ + static bool is_aggregate(const func_decl* const decl) + { + return decl->get_decl_kind() == OP_DL_MIN; + } + + /** + \pre: is_aggregate(aggregate) + + \returns function declaration of predicate which is subject to min aggregation function + */ + static func_decl * min_func_decl(const func_decl* const aggregate) + { + SASSERT(is_aggregate(aggregate)); + parameter const & relation_parameter = aggregate->get_parameter(0); + return to_func_decl(relation_parameter.get_ast()); + } + + /** + \pre: is_aggregate(aggregate) + + \returns column identifier (starting at zero) which is minimized by aggregation function + */ + static unsigned min_col(const func_decl* const aggregate) + { + SASSERT(is_aggregate(aggregate)); + return (unsigned)aggregate->get_parameter(1).get_int(); + } + + /** + \pre: is_aggregate(aggregate) + + \returns column identifiers for the "group by" in the given min aggregation function + */ + static unsigned_vector group_by_cols(const func_decl* const aggregate) + { + SASSERT(is_aggregate(aggregate)); + unsigned _min_col = min_col(aggregate); + if (aggregate->get_arity() == 0U) + return unsigned_vector(); + + unsigned col_num = 0; + unsigned_vector cols(aggregate->get_arity() - 1U); + for (unsigned i = 0; i < cols.size(); ++i, ++col_num) + { + if (col_num == _min_col) + ++col_num; + + cols[i] = col_num; + } + + return cols; + } + dl_decl_plugin(); virtual ~dl_decl_plugin() {} diff --git a/src/ast/fpa/fpa2bv_converter.cpp b/src/ast/fpa/fpa2bv_converter.cpp index 4c199ebc6..baba7b701 100644 --- a/src/ast/fpa/fpa2bv_converter.cpp +++ b/src/ast/fpa/fpa2bv_converter.cpp @@ -93,6 +93,20 @@ void fpa2bv_converter::mk_ite(expr * c, expr * t, expr * f, expr_ref & result) { mk_fp(sgn, e, s, result); } +void fpa2bv_converter::mk_distinct(func_decl * f, unsigned num, expr * const * args, expr_ref & result) { + // Note: in SMT there is only one NaN, so multiple of them are considered + // equal, thus (distinct NaN NaN) is false, even if the two NaNs have + // different bitwise representations (see also mk_eq). + result = m.mk_true(); + for (unsigned i = 0; i < num; i++) { + for (unsigned j = i+1; j < num; j++) { + expr_ref eq(m); + mk_eq(args[i], args[j], eq); + m_simp.mk_and(result, m.mk_not(eq), result); + } + } +} + void fpa2bv_converter::mk_numeral(func_decl * f, unsigned num, expr * const * args, expr_ref & result) { SASSERT(num == 0); SASSERT(f->get_num_parameters() == 1); diff --git a/src/ast/fpa/fpa2bv_converter.h b/src/ast/fpa/fpa2bv_converter.h index 4b3c1a6ca..b0881a364 100644 --- a/src/ast/fpa/fpa2bv_converter.h +++ b/src/ast/fpa/fpa2bv_converter.h @@ -80,6 +80,7 @@ public: void mk_eq(expr * a, expr * b, expr_ref & result); void mk_ite(expr * c, expr * t, expr * f, expr_ref & result); + void mk_distinct(func_decl * f, unsigned num, expr * const * args, expr_ref & result); void mk_rounding_mode(func_decl * f, expr_ref & result); void mk_numeral(func_decl * f, unsigned num, expr * const * args, expr_ref & result); diff --git a/src/ast/fpa/fpa2bv_rewriter.h b/src/ast/fpa/fpa2bv_rewriter.h index ed885a4cc..fa88c227c 100644 --- a/src/ast/fpa/fpa2bv_rewriter.h +++ b/src/ast/fpa/fpa2bv_rewriter.h @@ -103,8 +103,7 @@ struct fpa2bv_rewriter_cfg : public default_rewriter_cfg { } return BR_FAILED; } - - if (m().is_ite(f)) { + else if (m().is_ite(f)) { SASSERT(num == 3); if (m_conv.is_float(args[1])) { m_conv.mk_ite(args[0], args[1], args[2], result); @@ -112,6 +111,14 @@ struct fpa2bv_rewriter_cfg : public default_rewriter_cfg { } return BR_FAILED; } + else if (m().is_distinct(f)) { + sort * ds = f->get_domain()[0]; + if (m_conv.is_float(ds) || m_conv.is_rm(ds)) { + m_conv.mk_distinct(f, num, args, result); + return BR_DONE; + } + return BR_FAILED; + } if (m_conv.is_float_family(f)) { switch (f->get_decl_kind()) { diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index 50afc1735..c8d5e8cab 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -621,6 +621,7 @@ void cmd_context::init_manager_core(bool new_manager) { register_plugin(symbol("seq"), alloc(seq_decl_plugin), logic_has_seq()); register_plugin(symbol("pb"), alloc(pb_decl_plugin), !has_logic()); register_plugin(symbol("fpa"), alloc(fpa_decl_plugin), logic_has_fpa()); + register_plugin(symbol("datalog_relation"), alloc(datalog::dl_decl_plugin), !has_logic()); } else { // the manager was created by an external module diff --git a/src/muz/base/dl_rule.h b/src/muz/base/dl_rule.h index bdca80d0b..468b9f88c 100644 --- a/src/muz/base/dl_rule.h +++ b/src/muz/base/dl_rule.h @@ -346,6 +346,13 @@ namespace datalog { bool is_neg_tail(unsigned i) const { SASSERT(i < m_tail_size); return GET_TAG(m_tail[i]) == 1; } + /** + A predicate P(Xj) can be annotated by adding an interpreted predicate of the form ((_ min P N) ...) + where N is the column number that should be used for the min aggregation function. + Such an interpreted predicate is an example for which this function returns true. + */ + bool is_min_tail(unsigned i) const { return dl_decl_plugin::is_aggregate(get_tail(i)->get_decl()); } + /** Check whether predicate p is in the interpreted tail. diff --git a/src/muz/base/dl_rule_set.cpp b/src/muz/base/dl_rule_set.cpp index ad3b512a3..555b592ef 100644 --- a/src/muz/base/dl_rule_set.cpp +++ b/src/muz/base/dl_rule_set.cpp @@ -400,7 +400,7 @@ namespace datalog { SASSERT(!is_closed()); //the rule_set is not already closed m_deps.populate(*this); m_stratifier = alloc(rule_stratifier, m_deps); - if (!stratified_negation()) { + if (!stratified_negation() || !check_min()) { m_stratifier = 0; m_deps.reset(); return false; @@ -441,6 +441,49 @@ namespace datalog { return true; } + bool rule_set::check_min() { + // For now, we check the following: + // + // if a min aggregation function occurs in an SCC, is this SCC + // free of any other non-monotonic functions, e.g. negation? + const unsigned NEG_BIT = 1U << 0; + const unsigned MIN_BIT = 1U << 1; + + ptr_vector::const_iterator it = m_rules.c_ptr(); + ptr_vector::const_iterator end = m_rules.c_ptr() + m_rules.size(); + unsigned_vector component_status(m_stratifier->get_strats().size()); + + for (; it != end; it++) { + rule * r = *it; + app * head = r->get_head(); + func_decl * head_decl = head->get_decl(); + unsigned head_strat = get_predicate_strat(head_decl); + unsigned n = r->get_tail_size(); + for (unsigned i = 0; i < n; i++) { + func_decl * tail_decl = r->get_tail(i)->get_decl(); + unsigned strat = get_predicate_strat(tail_decl); + + if (r->is_neg_tail(i)) { + SASSERT(strat < component_status.size()); + component_status[strat] |= NEG_BIT; + } + + if (r->is_min_tail(i)) { + SASSERT(strat < component_status.size()); + component_status[strat] |= MIN_BIT; + } + } + } + + const unsigned CONFLICT = NEG_BIT | MIN_BIT; + for (unsigned k = 0; k < component_status.size(); ++k) { + if (component_status[k] == CONFLICT) + return false; + } + + return true; + } + void rule_set::replace_rules(const rule_set & src) { if (this != &src) { reset(); diff --git a/src/muz/base/dl_rule_set.h b/src/muz/base/dl_rule_set.h index e13d92105..a7d09b099 100644 --- a/src/muz/base/dl_rule_set.h +++ b/src/muz/base/dl_rule_set.h @@ -179,6 +179,7 @@ namespace datalog { void compute_deps(); void compute_tc_deps(); bool stratified_negation(); + bool check_min(); public: rule_set(context & ctx); rule_set(const rule_set & rs); diff --git a/src/muz/rel/dl_base.cpp b/src/muz/rel/dl_base.cpp index 95efccce8..6dc7f2f6e 100644 --- a/src/muz/rel/dl_base.cpp +++ b/src/muz/rel/dl_base.cpp @@ -485,4 +485,125 @@ namespace datalog { brw.mk_or(disjs.size(), disjs.c_ptr(), fml); } + class table_plugin::min_fn : public table_min_fn{ + table_signature m_sig; + const unsigned_vector m_group_by_cols; + const unsigned m_col; + public: + min_fn(const table_signature & t_sig, const unsigned_vector& group_by_cols, const unsigned col) + : m_sig(t_sig), + m_group_by_cols(group_by_cols), + m_col(col) {} + + virtual table_base* operator()(table_base const& t) { + //return reference_implementation(t); + return reference_implementation_with_hash(t); + } + + private: + + /** + Reference implementation with negation: + + T1 = join(T, T) by group_cols + T2 = { (t1,t2) in T1 | t1[col] > t2[col] } + T3 = { t1 | (t1,t2) in T2 } + T4 = T \ T3 + + The point of this reference implementation is to show + that the minimum requires negation (set difference). + This is relevant for fixed point computations. + */ + virtual table_base * reference_implementation(const table_base & t) { + relation_manager & manager = t.get_manager(); + scoped_ptr join_fn = manager.mk_join_fn(t, t, m_group_by_cols, m_group_by_cols); + scoped_rel join_table = (*join_fn)(t, t); + + table_base::iterator join_table_it = join_table->begin(); + table_base::iterator join_table_end = join_table->end(); + table_fact row; + + table_element i, j; + + for (; join_table_it != join_table_end; ++join_table_it) { + join_table_it->get_fact(row); + i = row[m_col]; + j = row[t.num_columns() + m_col]; + + if (i > j) { + continue; + } + + join_table->remove_fact(row); + } + + unsigned_vector cols(t.num_columns()); + for (unsigned k = 0; k < cols.size(); ++k) { + cols[k] = cols.size() + k; + SASSERT(cols[k] < join_table->num_columns()); + } + + scoped_ptr project_fn = manager.mk_project_fn(*join_table, cols); + scoped_rel gt_table = (*project_fn)(*join_table); + + for (unsigned k = 0; k < cols.size(); ++k) { + cols[k] = k; + SASSERT(cols[k] < t.num_columns()); + SASSERT(cols[k] < gt_table->num_columns()); + } + + table_base * result = t.clone(); + scoped_ptr diff_fn = manager.mk_filter_by_negation_fn(*result, *gt_table, cols, cols); + (*diff_fn)(*result, *gt_table); + return result; + } + + typedef map < table_fact, table_element, svector_hash_proc, + vector_eq_proc > group_map; + + // Thanks to Nikolaj who kindly helped with the second reference implementation! + virtual table_base * reference_implementation_with_hash(const table_base & t) { + group_map group; + table_base::iterator it = t.begin(); + table_base::iterator end = t.end(); + table_fact row, row2; + table_element current_value, min_value; + for (; it != end; ++it) { + it->get_fact(row); + current_value = row[m_col]; + group_by(row, row2); + group_map::entry* entry = group.find_core(row2); + if (!entry) { + group.insert(row2, current_value); + } + else if (entry->get_data().m_value > current_value) { + entry->get_data().m_value = current_value; + } + } + table_base* result = t.get_plugin().mk_empty(m_sig); + table_base::iterator it2 = t.begin(); + for (; it2 != end; ++it2) { + it2->get_fact(row); + current_value = row[m_col]; + group_by(row, row2); + VERIFY(group.find(row2, min_value)); + if (min_value == current_value) { + result->add_fact(row); + } + } + return result; + } + + void group_by(table_fact const& in, table_fact& out) { + out.reset(); + for (unsigned i = 0; i < m_group_by_cols.size(); ++i) { + out.push_back(in[m_group_by_cols[i]]); + } + } + }; + + table_min_fn * table_plugin::mk_min_fn(const table_base & t, + unsigned_vector & group_by_cols, const unsigned col) { + return alloc(table_plugin::min_fn, t.get_signature(), group_by_cols, col); + } } diff --git a/src/muz/rel/dl_base.h b/src/muz/rel/dl_base.h index 268cc602e..6ab1b2a96 100644 --- a/src/muz/rel/dl_base.h +++ b/src/muz/rel/dl_base.h @@ -192,6 +192,29 @@ namespace datalog { virtual base_object * operator()(const base_object & t1, const base_object & t2) = 0; }; + /** + \brief Aggregate minimum value + + Informally, we want to group rows in a table \c t by \c group_by_cols and + return the minimum value in column \c col among each group. + + Let \c t be a table with N columns. + Let \c group_by_cols be a set of column identifers for table \c t such that |group_by_cols| < N. + Let \c col be a column identifier for table \c t such that \c col is not in \c group_by_cols. + + Let R_col be a set of rows in table \c t such that, for all rows r_i, r_j in R_col + and column identifiers k in \c group_by_cols, r_i[k] = r_j[k]. + + For each R_col, we want to restrict R_col to those rows whose value in column \c col is minimal. + + min_fn(R, group_by_cols, col) = + { row in R | forall row' in R . row'[group_by_cols] = row[group_by_cols] => row'[col] >= row[col] } + */ + class min_fn : public base_fn { + public: + virtual base_object * operator()(const base_object & t) = 0; + }; + class transformer_fn : public base_fn { public: virtual base_object * operator()(const base_object & t) = 0; @@ -856,6 +879,7 @@ namespace datalog { typedef table_infrastructure::base_fn base_table_fn; typedef table_infrastructure::join_fn table_join_fn; + typedef table_infrastructure::min_fn table_min_fn; typedef table_infrastructure::transformer_fn table_transformer_fn; typedef table_infrastructure::union_fn table_union_fn; typedef table_infrastructure::mutator_fn table_mutator_fn; @@ -1020,6 +1044,7 @@ namespace datalog { class table_plugin : public table_infrastructure::plugin_object { friend class relation_manager; + class min_fn; protected: table_plugin(symbol const& n, relation_manager & manager) : plugin_object(n, manager) {} public: @@ -1027,6 +1052,9 @@ namespace datalog { virtual bool can_handle_signature(const table_signature & s) { return s.functional_columns()==0; } protected: + virtual table_min_fn * mk_min_fn(const table_base & t, + unsigned_vector & group_by_cols, const unsigned col); + /** If the returned value is non-zero, the returned object must take ownership of \c mapper. Otherwise \c mapper must remain unmodified. diff --git a/src/muz/rel/dl_compiler.cpp b/src/muz/rel/dl_compiler.cpp index 59ba260a4..c35985e98 100644 --- a/src/muz/rel/dl_compiler.cpp +++ b/src/muz/rel/dl_compiler.cpp @@ -73,6 +73,12 @@ namespace datalog { vars.get_cols2(), removed_cols.size(), removed_cols.c_ptr(), result)); } + void compiler::make_min(reg_idx source, reg_idx & target, const unsigned_vector & group_by_cols, + const unsigned min_col, instruction_block & acc) { + target = get_register(m_reg_signatures[source], true, source); + acc.push_back(instruction::mk_min(source, target, group_by_cols, min_col)); + } + void compiler::make_filter_interpreted_and_project(reg_idx src, app_ref & cond, const unsigned_vector & removed_cols, reg_idx & result, bool reuse, instruction_block & acc) { SASSERT(!removed_cols.empty()); @@ -440,6 +446,30 @@ namespace datalog { get_local_indexes_for_projection(t2, counter, t1->get_num_args(), res); } + void compiler::find_min_aggregates(const rule * r, ptr_vector& min_aggregates) { + unsigned ut_len = r->get_uninterpreted_tail_size(); + unsigned ft_len = r->get_tail_size(); // full tail + func_decl * aggregate; + for (unsigned tail_index = ut_len; tail_index < ft_len; ++tail_index) { + aggregate = r->get_tail(tail_index)->get_decl(); + if (dl_decl_plugin::is_aggregate(aggregate)) { + min_aggregates.push_back(aggregate); + } + } + } + + bool compiler::prepare_min_aggregate(const func_decl * decl, const ptr_vector& min_aggregates, + unsigned_vector & group_by_cols, unsigned & min_col) { + for (unsigned i = 0; i < min_aggregates.size(); ++i) { + if (dl_decl_plugin::min_func_decl(min_aggregates[i]) == decl) { + group_by_cols = dl_decl_plugin::group_by_cols(min_aggregates[i]); + min_col = dl_decl_plugin::min_col(min_aggregates[i]); + return true; + } + } + return false; + } + void compiler::compile_rule_evaluation_run(rule * r, reg_idx head_reg, const reg_idx * tail_regs, reg_idx delta_reg, bool use_widening, instruction_block & acc) { @@ -465,6 +495,12 @@ namespace datalog { // whether to dealloc the previous result bool dealloc = true; + // setup information for min aggregation + ptr_vector min_aggregates; + find_min_aggregates(r, min_aggregates); + unsigned_vector group_by_cols; + unsigned min_col; + if(pt_len == 2) { reg_idx t1_reg=tail_regs[0]; reg_idx t2_reg=tail_regs[1]; @@ -473,6 +509,14 @@ namespace datalog { SASSERT(m_reg_signatures[t1_reg].size()==a1->get_num_args()); SASSERT(m_reg_signatures[t2_reg].size()==a2->get_num_args()); + if (prepare_min_aggregate(a1->get_decl(), min_aggregates, group_by_cols, min_col)) { + make_min(t1_reg, single_res, group_by_cols, min_col, acc); + } + + if (prepare_min_aggregate(a2->get_decl(), min_aggregates, group_by_cols, min_col)) { + make_min(t2_reg, single_res, group_by_cols, min_col, acc); + } + variable_intersection a1a2(m_context.get_manager()); a1a2.populate(a1,a2); @@ -514,6 +558,10 @@ namespace datalog { single_res = tail_regs[0]; dealloc = false; + if (prepare_min_aggregate(a->get_decl(), min_aggregates, group_by_cols, min_col)) { + make_min(single_res, single_res, group_by_cols, min_col, acc); + } + SASSERT(m_reg_signatures[single_res].size() == a->get_num_args()); unsigned n=a->get_num_args(); @@ -597,7 +645,8 @@ namespace datalog { unsigned ft_len = r->get_tail_size(); // full tail ptr_vector tail; for (unsigned tail_index = ut_len; tail_index < ft_len; ++tail_index) { - tail.push_back(r->get_tail(tail_index)); + if (!r->is_min_tail(tail_index)) + tail.push_back(r->get_tail(tail_index)); } expr_ref_vector binding(m); diff --git a/src/muz/rel/dl_compiler.h b/src/muz/rel/dl_compiler.h index 4902b9387..a9e37a8a3 100644 --- a/src/muz/rel/dl_compiler.h +++ b/src/muz/rel/dl_compiler.h @@ -120,6 +120,22 @@ namespace datalog { instruction_observer m_instruction_observer; expr_free_vars m_free_vars; + /** + \brief Finds all the min aggregation functions in the premise of a given rule. + */ + static void find_min_aggregates(const rule * r, ptr_vector& min_aggregates); + + /** + \brief Decides whether a predicate is subject to a min aggregation function. + + If \c decl is subject to a min aggregation function, the output parameters are written + with the neccessary information. + + \returns true if the output paramaters have been written + */ + static bool prepare_min_aggregate(const func_decl * decl, const ptr_vector& min_aggregates, + unsigned_vector & group_by_cols, unsigned & min_col); + /** If true, the union operation on the underlying structure only provides the information whether the updated relation has changed or not. In this case we do not get anything @@ -146,6 +162,8 @@ namespace datalog { void make_join(reg_idx t1, reg_idx t2, const variable_intersection & vars, reg_idx & result, bool reuse_t1, instruction_block & acc); + void make_min(reg_idx source, reg_idx & target, const unsigned_vector & group_by_cols, + const unsigned min_col, instruction_block & acc); void make_join_project(reg_idx t1, reg_idx t2, const variable_intersection & vars, const unsigned_vector & removed_cols, reg_idx & result, bool reuse_t1, instruction_block & acc); void make_filter_interpreted_and_project(reg_idx src, app_ref & cond, diff --git a/src/muz/rel/dl_instruction.cpp b/src/muz/rel/dl_instruction.cpp index 7eb8d4375..f8145b922 100644 --- a/src/muz/rel/dl_instruction.cpp +++ b/src/muz/rel/dl_instruction.cpp @@ -25,6 +25,7 @@ Revision History: #include"rel_context.h" #include"debug.h" #include"warning.h" +#include"dl_table_relation.h" namespace datalog { @@ -552,7 +553,7 @@ namespace datalog { if (r.fast_empty()) { ctx.make_empty(m_reg); } - TRACE("dl_verbose", r.display(tout <<"post-filter-interpreted:\n");); + //TRACE("dl_verbose", r.display(tout <<"post-filter-interpreted:\n");); return true; } @@ -609,7 +610,7 @@ namespace datalog { if (ctx.reg(m_res)->fast_empty()) { ctx.make_empty(m_res); } - TRACE("dl_verbose", reg.display(tout << "post-filter-interpreted-and-project:\n");); + //TRACE("dl_verbose", reg.display(tout << "post-filter-interpreted-and-project:\n");); return true; } @@ -883,6 +884,59 @@ namespace datalog { removed_cols, result); } + class instr_min : public instruction { + reg_idx m_source_reg; + reg_idx m_target_reg; + unsigned_vector m_group_by_cols; + unsigned m_min_col; + public: + instr_min(reg_idx source_reg, reg_idx target_reg, const unsigned_vector & group_by_cols, unsigned min_col) + : m_source_reg(source_reg), + m_target_reg(target_reg), + m_group_by_cols(group_by_cols), + m_min_col(min_col) { + } + virtual bool perform(execution_context & ctx) { + log_verbose(ctx); + if (!ctx.reg(m_source_reg)) { + ctx.make_empty(m_target_reg); + return true; + } + + const relation_base & s = *ctx.reg(m_source_reg); + if (!s.from_table()) { + throw default_exception("relation is not a table %s", + s.get_plugin().get_name().bare_str()); + } + ++ctx.m_stats.m_min; + const table_relation & tr = static_cast(s); + const table_base & source_t = tr.get_table(); + relation_manager & r_manager = s.get_manager(); + + const relation_signature & r_sig = s.get_signature(); + scoped_ptr fn = r_manager.mk_min_fn(source_t, m_group_by_cols, m_min_col); + table_base * target_t = (*fn)(source_t); + + TRACE("dl", + tout << "% "; + target_t->display(tout); + tout << "\n";); + + relation_base * target_r = r_manager.mk_table_relation(r_sig, target_t); + ctx.set_reg(m_target_reg, target_r); + return true; + } + virtual void display_head_impl(execution_context const& ctx, std::ostream & out) const { + out << " MIN AGGR "; + } + virtual void make_annotations(execution_context & ctx) { + } + }; + + instruction * instruction::mk_min(reg_idx source, reg_idx target, const unsigned_vector & group_by_cols, + const unsigned min_col) { + return alloc(instr_min, source, target, group_by_cols, min_col); + } class instr_select_equal_and_project : public instruction { reg_idx m_src; diff --git a/src/muz/rel/dl_instruction.h b/src/muz/rel/dl_instruction.h index 3910f6d0b..a02346b99 100644 --- a/src/muz/rel/dl_instruction.h +++ b/src/muz/rel/dl_instruction.h @@ -93,6 +93,7 @@ namespace datalog { unsigned m_filter_interp_project; unsigned m_filter_id; unsigned m_filter_eq; + unsigned m_min; stats() { reset(); } void reset() { memset(this, 0, sizeof(*this)); } }; @@ -284,6 +285,8 @@ namespace datalog { static instruction * mk_join_project(reg_idx rel1, reg_idx rel2, unsigned joined_col_cnt, const unsigned * cols1, const unsigned * cols2, unsigned removed_col_cnt, const unsigned * removed_cols, reg_idx result); + static instruction * mk_min(reg_idx source, reg_idx target, const unsigned_vector & group_by_cols, + const unsigned min_col); static instruction * mk_rename(reg_idx src, unsigned cycle_len, const unsigned * permutation_cycle, reg_idx tgt); static instruction * mk_filter_by_negation(reg_idx tgt, reg_idx neg_rel, unsigned col_cnt, diff --git a/src/muz/rel/dl_relation_manager.cpp b/src/muz/rel/dl_relation_manager.cpp index 6a9bb7f2a..2b78baf05 100644 --- a/src/muz/rel/dl_relation_manager.cpp +++ b/src/muz/rel/dl_relation_manager.cpp @@ -354,7 +354,9 @@ namespace datalog { return product_relation_plugin::get_plugin(*this).mk_empty(s); } - + /** + The newly created object takes ownership of the \c table object. + */ relation_base * relation_manager::mk_table_relation(const relation_signature & s, table_base * table) { SASSERT(s.size()==table->get_signature().size()); return get_table_relation_plugin(table->get_plugin()).mk_from_table(s, table); @@ -1021,6 +1023,11 @@ namespace datalog { return res; } + table_min_fn * relation_manager::mk_min_fn(const table_base & t, + unsigned_vector & group_by_cols, const unsigned col) + { + return t.get_plugin().mk_min_fn(t, group_by_cols, col); + } class relation_manager::auxiliary_table_transformer_fn { table_fact m_row; diff --git a/src/muz/rel/dl_relation_manager.h b/src/muz/rel/dl_relation_manager.h index 53d7f21e2..f91b7496a 100644 --- a/src/muz/rel/dl_relation_manager.h +++ b/src/muz/rel/dl_relation_manager.h @@ -251,6 +251,9 @@ namespace datalog { return mk_join_fn(t1, t2, cols1.size(), cols1.c_ptr(), cols2.c_ptr(), allow_product_relation); } + table_min_fn * mk_min_fn(const table_base & t, + unsigned_vector & group_by_cols, const unsigned col); + /** \brief Return functor that transforms a table into one that lacks columns listed in \c removed_cols array. diff --git a/src/muz/rel/dl_table_relation.cpp b/src/muz/rel/dl_table_relation.cpp index 364c29367..d42d071aa 100644 --- a/src/muz/rel/dl_table_relation.cpp +++ b/src/muz/rel/dl_table_relation.cpp @@ -63,6 +63,9 @@ namespace datalog { return alloc(table_relation, *this, s, t); } + /** + The newly created object takes ownership of the \c t object. + */ relation_base * table_relation_plugin::mk_from_table(const relation_signature & s, table_base * t) { if (&t->get_plugin() == &m_table_plugin) return alloc(table_relation, *this, s, t); diff --git a/src/muz/rel/rel_context.cpp b/src/muz/rel/rel_context.cpp index c3ffdceae..d2a2f0181 100644 --- a/src/muz/rel/rel_context.cpp +++ b/src/muz/rel/rel_context.cpp @@ -290,19 +290,27 @@ namespace datalog { return res; } +#define _MIN_DONE_ 1 + void rel_context::transform_rules() { rule_transformer transf(m_context); +#ifdef _MIN_DONE_ transf.register_plugin(alloc(mk_coi_filter, m_context)); +#endif transf.register_plugin(alloc(mk_filter_rules, m_context)); transf.register_plugin(alloc(mk_simple_joins, m_context)); if (m_context.unbound_compressor()) { transf.register_plugin(alloc(mk_unbound_compressor, m_context)); } +#ifdef _MIN_DONE_ if (m_context.similarity_compressor()) { transf.register_plugin(alloc(mk_similarity_compressor, m_context)); } +#endif transf.register_plugin(alloc(mk_partial_equivalence_transformer, m_context)); +#ifdef _MIN_DONE_ transf.register_plugin(alloc(mk_rule_inliner, m_context)); +#endif transf.register_plugin(alloc(mk_interp_tail_simplifier, m_context)); transf.register_plugin(alloc(mk_separate_negated_tails, m_context)); diff --git a/src/tactic/fpa/qffp_tactic.cpp b/src/tactic/fpa/qffp_tactic.cpp index fbce1668e..485e837ee 100644 --- a/src/tactic/fpa/qffp_tactic.cpp +++ b/src/tactic/fpa/qffp_tactic.cpp @@ -53,7 +53,7 @@ struct has_fp_to_real_predicate { class has_fp_to_real_probe : public probe { public: virtual result operator()(goal const & g) { - return !test(g); + return test(g); } virtual ~has_fp_to_real_probe() {} diff --git a/src/test/dl_table.cpp b/src/test/dl_table.cpp index 6dc688ede..b14988f11 100644 --- a/src/test/dl_table.cpp +++ b/src/test/dl_table.cpp @@ -1,10 +1,8 @@ - /*++ Copyright (c) 2015 Microsoft Corporation - --*/ +#if defined(_WINDOWS) || defined(_CYGWIN) -#ifdef _WINDOWS #include "dl_context.h" #include "dl_table.h" #include "dl_register_engine.h" @@ -97,9 +95,78 @@ void test_dl_bitvector_table() { test_table(mk_bv_table); } +void test_table_min() { + std::cout << "----- test_table_min -----\n"; + datalog::table_signature sig; + sig.push_back(2); + sig.push_back(4); + sig.push_back(8); + smt_params params; + ast_manager ast_m; + datalog::register_engine re; + datalog::context ctx(ast_m, re, params); + datalog::relation_manager & m = ctx.get_rel_context()->get_rmanager(); + + m.register_plugin(alloc(datalog::bitvector_table_plugin, m)); + + datalog::table_base* tbl = mk_bv_table(m, sig); + datalog::table_base& table = *tbl; + + datalog::table_fact row, row1, row2, row3; + row.push_back(1); + row.push_back(2); + row.push_back(5); + + // Group (1,2,*) + row1 = row; + row[2] = 6; + row2 = row; + row[2] = 5; + row3 = row; + + table.add_fact(row1); + table.add_fact(row2); + table.add_fact(row3); + + // Group (1,3,*) + row[1] = 3; + row1 = row; + row[2] = 7; + row2 = row; + row[2] = 4; + row3 = row; + + table.add_fact(row1); + table.add_fact(row2); + table.add_fact(row3); + + table.display(std::cout); + + unsigned_vector group_by(2); + group_by[0] = 0; + group_by[1] = 1; + + datalog::table_min_fn * min_fn = m.mk_min_fn(table, group_by, 2); + datalog::table_base * min_tbl = (*min_fn)(table); + + min_tbl->display(std::cout); + + row[1] = 2; + row[2] = 5; + SASSERT(min_tbl->contains_fact(row)); + + row[1] = 3; + row[2] = 4; + SASSERT(min_tbl->contains_fact(row)); + + dealloc(min_fn); + min_tbl->deallocate(); + tbl->deallocate(); +} void tst_dl_table() { test_dl_bitvector_table(); + test_table_min(); } #else void tst_dl_table() {