mirror of
https://github.com/Z3Prover/z3
synced 2025-04-24 09:35:32 +00:00
Add datalog infrastructure for min aggregation function
This patch adds an instruction to the datalog interpreter and constructs a new AST node for min aggregation functions. The compiler is currently still work in progress and depends on changes made to the handling of simple joins and the preprocessor. Signed-off-by: Alex Horn <t-alexh@microsoft.com>
This commit is contained in:
parent
004bf1471f
commit
140fb7942d
11 changed files with 418 additions and 3 deletions
|
@ -44,7 +44,8 @@ namespace datalog {
|
|||
m_num_sym("N"),
|
||||
m_lt_sym("<"),
|
||||
m_le_sym("<="),
|
||||
m_rule_sym("R")
|
||||
m_rule_sym("R"),
|
||||
m_min_sym("min")
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -490,6 +491,66 @@ namespace datalog {
|
|||
return m_manager->mk_func_decl(m_clone_sym, 1, &s, s, info);
|
||||
}
|
||||
|
||||
/**
|
||||
In SMT2 syntax, we can write \c ((_ min R N) v_0 v_1 ... v_k)) where 0 <= N <= k,
|
||||
R is a relation of sort V_0 x V_1 x ... x V_k and each v_i is a zero-arity function
|
||||
(also known as a "constant" in SMT2 parlance) whose range is of sort V_i.
|
||||
|
||||
Example:
|
||||
|
||||
(define-sort number_t () (_ BitVec 2))
|
||||
(declare-rel numbers (number_t number_t))
|
||||
(declare-rel is_min (number_t number_t))
|
||||
|
||||
(declare-var x number_t)
|
||||
(declare-var y number_t)
|
||||
|
||||
(rule (numbers #b00 #b11))
|
||||
(rule (numbers #b00 #b01))
|
||||
|
||||
(rule (=> (and (numbers x y) ((_ min numbers 1) x y)) (is_min x y)))
|
||||
|
||||
This says that we want to find the mininum y grouped by x.
|
||||
*/
|
||||
func_decl * dl_decl_plugin::mk_min(decl_kind k, unsigned num_parameters, parameter const * parameters) {
|
||||
if (num_parameters < 2) {
|
||||
m_manager->raise_exception("invalid min aggregate definition due to missing parameters");
|
||||
return 0;
|
||||
}
|
||||
|
||||
parameter const & relation_parameter = parameters[0];
|
||||
if (!relation_parameter.is_ast() || !is_func_decl(relation_parameter.get_ast())) {
|
||||
m_manager->raise_exception("invalid min aggregate definition, first parameter is not a function declaration");
|
||||
return 0;
|
||||
}
|
||||
|
||||
func_decl* f = to_func_decl(relation_parameter.get_ast());
|
||||
if (!m_manager->is_bool(f->get_range())) {
|
||||
m_manager->raise_exception("invalid min aggregate definition, first paramater must be a predicate");
|
||||
return 0;
|
||||
}
|
||||
|
||||
parameter const & min_col_parameter = parameters[1];
|
||||
if (!min_col_parameter.is_int()) {
|
||||
m_manager->raise_exception("invalid min aggregate definition, second parameter must be an integer");
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (min_col_parameter.get_int() < 0) {
|
||||
m_manager->raise_exception("invalid min aggregate definition, second parameter must be non-negative");
|
||||
return 0;
|
||||
}
|
||||
|
||||
if ((unsigned)min_col_parameter.get_int() >= f->get_arity()) {
|
||||
m_manager->raise_exception("invalid min aggregate definition, second parameter exceeds the arity of the relation");
|
||||
return 0;
|
||||
}
|
||||
|
||||
func_decl_info info(m_family_id, k, num_parameters, parameters);
|
||||
SASSERT(f->get_info() == 0);
|
||||
return m_manager->mk_func_decl(m_min_sym, f->get_arity(), f->get_domain(), f->get_range(), info);
|
||||
}
|
||||
|
||||
func_decl * dl_decl_plugin::mk_func_decl(
|
||||
decl_kind k, unsigned num_parameters, parameter const * parameters,
|
||||
unsigned arity, sort * const * domain, sort * range) {
|
||||
|
@ -617,6 +678,9 @@ namespace datalog {
|
|||
break;
|
||||
}
|
||||
|
||||
case OP_DL_MIN:
|
||||
return mk_min(k, num_parameters, parameters);
|
||||
|
||||
default:
|
||||
m_manager->raise_exception("operator not recognized");
|
||||
return 0;
|
||||
|
@ -627,7 +691,7 @@ namespace datalog {
|
|||
}
|
||||
|
||||
void dl_decl_plugin::get_op_names(svector<builtin_name> & op_names, symbol const & logic) {
|
||||
|
||||
op_names.push_back(builtin_name(m_min_sym.bare_str(), OP_DL_MIN));
|
||||
}
|
||||
|
||||
void dl_decl_plugin::get_sort_names(svector<builtin_name> & sort_names, symbol const & logic) {
|
||||
|
|
|
@ -50,6 +50,7 @@ namespace datalog {
|
|||
OP_DL_LT,
|
||||
OP_DL_REP,
|
||||
OP_DL_ABS,
|
||||
OP_DL_MIN,
|
||||
LAST_RA_OP
|
||||
};
|
||||
|
||||
|
@ -71,6 +72,7 @@ namespace datalog {
|
|||
symbol m_lt_sym;
|
||||
symbol m_le_sym;
|
||||
symbol m_rule_sym;
|
||||
symbol m_min_sym;
|
||||
|
||||
bool check_bounds(char const* msg, unsigned low, unsigned up, unsigned val) const;
|
||||
bool check_domain(unsigned low, unsigned up, unsigned val) const;
|
||||
|
@ -94,12 +96,69 @@ namespace datalog {
|
|||
func_decl * mk_compare(decl_kind k, symbol const& sym, sort*const* domain);
|
||||
func_decl * mk_clone(sort* r);
|
||||
func_decl * mk_rule(unsigned arity);
|
||||
func_decl * mk_min(decl_kind k, unsigned num_parameters, parameter const * parameters);
|
||||
|
||||
sort * mk_finite_sort(unsigned num_params, parameter const* params);
|
||||
sort * mk_relation_sort(unsigned num_params, parameter const* params);
|
||||
sort * mk_rule_sort();
|
||||
|
||||
public:
|
||||
/**
|
||||
Is \c decl a min aggregation function?
|
||||
*/
|
||||
static bool is_aggregate(const func_decl* const decl)
|
||||
{
|
||||
return decl->get_decl_kind() == OP_DL_MIN;
|
||||
}
|
||||
|
||||
/**
|
||||
\pre: is_aggregate(aggregate)
|
||||
|
||||
\returns function declaration of predicate which is subject to min aggregation function
|
||||
*/
|
||||
static func_decl * min_func_decl(const func_decl* const aggregate)
|
||||
{
|
||||
SASSERT(is_aggregate(aggregate));
|
||||
parameter const & relation_parameter = aggregate->get_parameter(0);
|
||||
return to_func_decl(relation_parameter.get_ast());
|
||||
}
|
||||
|
||||
/**
|
||||
\pre: is_aggregate(aggregate)
|
||||
|
||||
\returns column identifier (starting at zero) which is minimized by aggregation function
|
||||
*/
|
||||
static unsigned min_col(const func_decl* const aggregate)
|
||||
{
|
||||
SASSERT(is_aggregate(aggregate));
|
||||
return (unsigned)aggregate->get_parameter(1).get_int();
|
||||
}
|
||||
|
||||
/**
|
||||
\pre: is_aggregate(aggregate)
|
||||
|
||||
\returns column identifiers for the "group by" in the given min aggregation function
|
||||
*/
|
||||
static unsigned_vector group_by_cols(const func_decl* const aggregate)
|
||||
{
|
||||
SASSERT(is_aggregate(aggregate));
|
||||
unsigned _min_col = min_col(aggregate);
|
||||
if (aggregate->get_arity() == 0U)
|
||||
return unsigned_vector();
|
||||
|
||||
unsigned col_num = 0;
|
||||
unsigned_vector cols(aggregate->get_arity() - 1U);
|
||||
for (unsigned i = 0; i < cols.size(); ++i, ++col_num)
|
||||
{
|
||||
if (col_num == _min_col)
|
||||
++col_num;
|
||||
|
||||
cols[i] = col_num;
|
||||
}
|
||||
|
||||
return cols;
|
||||
}
|
||||
|
||||
dl_decl_plugin();
|
||||
virtual ~dl_decl_plugin() {}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue