3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-20 04:43:39 +00:00

more ddnf

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2014-08-21 23:48:36 -07:00
parent eaabae3219
commit 3d0cb6a5e9

View file

@ -24,6 +24,7 @@ Revision History:
#include "dl_rule_set.h" #include "dl_rule_set.h"
#include "dl_context.h" #include "dl_context.h"
#include "scoped_proof.h" #include "scoped_proof.h"
#include "bv_decl_plugin.h"
namespace datalog { namespace datalog {
@ -43,6 +44,7 @@ namespace datalog {
resize(n, val); resize(n, val);
} }
tbv(uint64 val, unsigned n) : bit_vector(2*n) { tbv(uint64 val, unsigned n) : bit_vector(2*n) {
resize(n, BIT_x);
for (unsigned bit = n; bit > 0;) { for (unsigned bit = n; bit > 0;) {
--bit; --bit;
if (val & (1ULL << bit)) { if (val & (1ULL << bit)) {
@ -53,6 +55,14 @@ namespace datalog {
} }
} }
tbv(uint64 v, unsigned sz, unsigned hi, unsigned lo) : bit_vector(2*sz) {
resize(sz, BIT_x);
SASSERT(64 >= sz && sz > hi && hi >= lo);
for (unsigned i = 0; i < hi - lo + 1; ++i) {
set(lo + i, (v & (1ULL << i))?BIT_1:BIT_0);
}
}
tbv(rational const& v, unsigned n) : bit_vector(2*n) { tbv(rational const& v, unsigned n) : bit_vector(2*n) {
if (v.is_uint64() && n <= 64) { if (v.is_uint64() && n <= 64) {
tbv tmp(v.get_uint64(), n); tbv tmp(v.get_uint64(), n);
@ -60,6 +70,7 @@ namespace datalog {
return; return;
} }
resize(n, BIT_x);
for (unsigned bit = n; bit > 0; ) { for (unsigned bit = n; bit > 0; ) {
--bit; --bit;
if (bitwise_and(v, rational::power_of_two(bit)).is_zero()) { if (bitwise_and(v, rational::power_of_two(bit)).is_zero()) {
@ -127,6 +138,19 @@ namespace datalog {
} }
} }
struct eq {
bool operator()(tbv const& d1, tbv const& d2) const {
return d1 == d2;
}
};
struct hash {
unsigned operator()(tbv const& d) const {
return d.get_hash();
}
};
friend bool intersect(tbv const& a, tbv const& b, tbv& result); friend bool intersect(tbv const& a, tbv const& b, tbv& result);
private: private:
@ -142,6 +166,11 @@ namespace datalog {
} }
}; };
std::ostream& operator<<(std::ostream& out, tbv const& t) {
t.display(out);
return out;
}
bool intersect(tbv const& a, tbv const& b, tbv& result) { bool intersect(tbv const& a, tbv const& b, tbv& result) {
result = a; result = a;
result &= b; result &= b;
@ -156,6 +185,7 @@ namespace datalog {
vector<tbv> m_negs; vector<tbv> m_negs;
public: public:
dot() {} dot() {}
dot(tbv const& pos): m_pos(pos) {}
dot(tbv const& pos, vector<tbv> const& negs): dot(tbv const& pos, vector<tbv> const& negs):
m_pos(pos), m_negs(negs) { m_pos(pos), m_negs(negs) {
DEBUG_CODE( DEBUG_CODE(
@ -277,7 +307,7 @@ namespace datalog {
ddnf_node* add_neg(dot const& d) { m_neg.push_back(d); return this; } ddnf_node* add_neg(dot const& d) { m_neg.push_back(d); return this; }
void display(std::ostream& out) const { void display(std::ostream& out) const {
out << "node["; out << "node[" << get_id() << ": ";
m_tbv.display(out); m_tbv.display(out);
for (unsigned i = 0; i < m_children.size(); ++i) { for (unsigned i = 0; i < m_children.size(); ++i) {
out << " " << m_children[i]->get_id(); out << " " << m_children[i]->get_id();
@ -318,23 +348,32 @@ namespace datalog {
} }
void insert(dot const& d) { void insert(dot const& d) {
SASSERT(d.size() == m_num_bits); SASSERT(d.num_bits() == m_num_bits);
SASSERT(!m_internalized); SASSERT(!m_internalized);
if (m_dots.contains(d)) return; if (m_dots.contains(d)) return;
ddnf_nodes* ns = alloc(ddnf_nodes); ddnf_nodes* ns = alloc(ddnf_nodes);
m_tables.push_back(ns); m_tables.push_back(ns);
m_dots.insert(d, ns); m_dots.insert(d, ns);
insert(d.pos())->add_pos(d); insert_tbv(d.pos())->add_pos(d);
for (unsigned i = 0; i < d.size(); ++i) { for (unsigned i = 0; i < d.size(); ++i) {
insert(d.neg(i))->add_neg(d); insert_tbv(d.neg(i))->add_neg(d);
} }
} }
void insert(tbv const& t) {
insert(dot(t));
}
ddnf_nodes const& lookup(dot const& d) { ddnf_nodes const& lookup(dot const& d) {
internalize(); internalize();
return *m_dots.find(d); return *m_dots.find(d);
} }
ddnf_nodes const& lookup(tbv const& t) {
internalize();
return *m_dots.find(dot(t));
}
void display(std::ostream& out) const { void display(std::ostream& out) const {
for (unsigned i = 0; i < m_noderefs.size(); ++i) { for (unsigned i = 0; i < m_noderefs.size(); ++i) {
m_noderefs[i]->display(out); m_noderefs[i]->display(out);
@ -344,7 +383,18 @@ namespace datalog {
private: private:
ddnf_node* insert(tbv const& t) {
ddnf_node* find(tbv const& t) {
ddnf_node dummy(*this, t, 0);
return *(m_nodes.find(&dummy));
}
bool contains(tbv const& t) {
ddnf_node dummy(*this, t, 0);
return m_nodes.contains(&dummy);
}
ddnf_node* insert_tbv(tbv const& t) {
vector<tbv> new_tbvs; vector<tbv> new_tbvs;
new_tbvs.push_back(t); new_tbvs.push_back(t);
for (unsigned i = 0; i < new_tbvs.size(); ++i) { for (unsigned i = 0; i < new_tbvs.size(); ++i) {
@ -358,15 +408,6 @@ namespace datalog {
return find(t); return find(t);
} }
ddnf_node* find(tbv const& t) {
ddnf_node dummy(*this, t, 0);
return *(m_nodes.find(&dummy));
}
bool contains(tbv const& t) {
ddnf_node dummy(*this, t, 0);
return m_nodes.contains(&dummy);
}
void insert(ddnf_node& root, ddnf_node* new_n, vector<tbv>& new_intersections) { void insert(ddnf_node& root, ddnf_node* new_n, vector<tbv>& new_intersections) {
tbv const& new_tbv = new_n->get_tbv(); tbv const& new_tbv = new_n->get_tbv();
@ -487,6 +528,10 @@ namespace datalog {
m->insert(d); m->insert(d);
} }
void insert(tbv const& t) {
insert(dot(t));
}
ddnf_nodes const& lookup(dot const& d) const { ddnf_nodes const& lookup(dot const& d) const {
return m_mgrs.find(d.num_bits())->lookup(d); return m_mgrs.find(d.num_bits())->lookup(d);
} }
@ -510,17 +555,20 @@ namespace datalog {
context& m_ctx; context& m_ctx;
ast_manager& m; ast_manager& m;
rule_manager& rm; rule_manager& rm;
bv_util bv;
volatile bool m_cancel; volatile bool m_cancel;
ptr_vector<expr> m_todo; ptr_vector<expr> m_todo;
ast_mark m_visited1, m_visited2; ast_mark m_visited1, m_visited2;
ddnfs m_ddnfs; ddnfs m_ddnfs;
stats m_stats; stats m_stats;
obj_map<expr, tbv> m_cache;
public: public:
imp(context& ctx): imp(context& ctx):
m_ctx(ctx), m_ctx(ctx),
m(ctx.get_manager()), m(ctx.get_manager()),
rm(ctx.get_rule_manager()), rm(ctx.get_rule_manager()),
bv(m),
m_cancel(false) m_cancel(false)
{ {
} }
@ -529,10 +577,11 @@ namespace datalog {
lbool query(expr* query) { lbool query(expr* query) {
m_ctx.ensure_opened(); m_ctx.ensure_opened();
if (!can_handle_rules()) { if (!process_rules()) {
return l_undef; return l_undef;
} }
IF_VERBOSE(0, verbose_stream() << "rules are OK\n";); IF_VERBOSE(0, verbose_stream() << "rules are OK\n";);
IF_VERBOSE(0, m_ddnfs.display(verbose_stream()););
return run(); return run();
} }
@ -573,28 +622,29 @@ namespace datalog {
return l_undef; return l_undef;
} }
bool can_handle_rules() { bool process_rules() {
m_visited1.reset(); m_visited1.reset();
m_todo.reset(); m_todo.reset();
m_cache.reset();
rule_set const& rules = m_ctx.get_rules(); rule_set const& rules = m_ctx.get_rules();
datalog::rule_set::iterator it = rules.begin(); datalog::rule_set::iterator it = rules.begin();
datalog::rule_set::iterator end = rules.end(); datalog::rule_set::iterator end = rules.end();
for (; it != end; ++it) { for (; it != end; ++it) {
if (!can_handle_rule(**it)) { if (!process_rule(**it)) {
return false; return false;
} }
} }
return true; return true;
} }
bool can_handle_rule(rule const& r) { bool process_rule(rule const& r) {
// all predicates are monadic. // all predicates are monadic.
unsigned utsz = r.get_uninterpreted_tail_size(); unsigned utsz = r.get_uninterpreted_tail_size();
unsigned sz = r.get_tail_size(); unsigned sz = r.get_tail_size();
for (unsigned i = utsz; i < sz; ++i) { for (unsigned i = utsz; i < sz; ++i) {
m_todo.push_back(r.get_tail(i)); m_todo.push_back(r.get_tail(i));
} }
if (check_monadic()) { if (process_todo()) {
return true; return true;
} }
else { else {
@ -603,8 +653,7 @@ namespace datalog {
} }
} }
bool check_monadic() { bool process_todo() {
expr* e1, *e2;
while (!m_todo.empty()) { while (!m_todo.empty()) {
expr* e = m_todo.back(); expr* e = m_todo.back();
m_todo.pop_back(); m_todo.pop_back();
@ -626,21 +675,10 @@ namespace datalog {
m_todo.append(to_app(e)->get_num_args(), to_app(e)->get_args()); m_todo.append(to_app(e)->get_num_args(), to_app(e)->get_args());
continue; continue;
} }
if (m.is_eq(e, e1, e2)) {
if (is_var(e1) && is_ground(e2)) {
continue;
}
if (is_var(e2) && is_ground(e1)) {
continue;
}
if (is_var(e1) && is_var(e2)) {
continue;
}
}
if (is_ground(e)) { if (is_ground(e)) {
continue; continue;
} }
if (is_unary(e)) { if (process_atomic(e)) {
continue; continue;
} }
IF_VERBOSE(0, verbose_stream() << "Could not handle: " << mk_pp(e, m) << "\n";); IF_VERBOSE(0, verbose_stream() << "Could not handle: " << mk_pp(e, m) << "\n";);
@ -649,40 +687,84 @@ namespace datalog {
return true; return true;
} }
bool is_unary(expr* e) { bool process_atomic(expr* e) {
var* v = 0; expr* e1, *e2, *e3;
m_visited2.reset(); unsigned lo, hi;
unsigned sz = m_todo.size();
m_todo.push_back(e); if (m.is_eq(e, e1, e2) && bv.is_bv(e1)) {
while (m_todo.size() > sz) { if (is_var(e1) && is_ground(e2)) {
expr* e = m_todo.back(); return process_eq(e, to_var(e1), bv.get_bv_size(e1)-1, 0, e2);
m_todo.pop_back(); }
if (m_visited2.is_marked(e)) { if (is_var(e2) && is_ground(e1)) {
continue; return process_eq(e, to_var(e2), bv.get_bv_size(e2)-1, 0, e1);
}
if (bv.is_extract(e1, lo, hi, e3) && is_var(e3) && is_ground(e2)) {
return process_eq(e, to_var(e3), hi, lo, e2);
}
if (bv.is_extract(e2, lo, hi, e3) && is_var(e3) && is_ground(e1)) {
return process_eq(e, to_var(e3), hi, lo, e1);
}
if (is_var(e1) && is_var(e2)) {
std::cout << mk_pp(e, m) << "\n";
return true;
}
} }
m_visited2.mark(e, true);
if (is_var(e)) {
if (v && v != e) {
return false; return false;
} }
v = to_var(e);
} bool process_eq(expr* e, var* v, unsigned hi, unsigned lo, expr* c) {
else if (is_app(e)) { rational val;
m_todo.append(to_app(e)->get_num_args(), to_app(e)->get_args()); unsigned sz_c;
} unsigned sz_v = bv.get_bv_size(v);
else { if (!bv.is_numeral(c, val, sz_c)) {
return false; return false;
} }
if (!val.is_uint64()) {
return false;
} }
// v[hi:lo] = val
tbv tbv(val.get_uint64(), sz_v, hi, lo);
m_ddnfs.insert(tbv);
m_cache.insert(e, tbv);
std::cout << mk_pp(v, m) << " " << lo << " " << hi << " " << v << " " << tbv << "\n";
return true; return true;
} }
void compile(expr* e) { void compile(expr* phi) {
// TBD: // TBD:
// compiles monadic predicates into dots.
// saves the mapping from expr |-> dot // for each v
// such that atomic sub-formula can be expressed // associate a set of ddnf nodes that they can
// as a set of ddnf_nodes // take.
// - for each v, find the number of nodes associated with
// bit-width of v.
// - associate bit-vector for such nodes (the ids are consecutive).
// - compile formula into cross-product of such ranges.
// - disjunction requires special casing (which is not typical case)
// - negation over a tbv is the complemment of the set associated with
// the tbv.
// extract(hi, lo, v) == k
// |->
// tbv (set of nodes associated with tbv)
// compile(not (phi))
// |-> complement of ddnf nodes associated with phi
// compile (phi1 & phi2)
// |-> intersection of ddnf nodes associated with phi1, phi2
// compile (phi | phi2)
// |-> union
// v1 == v2
// no-op !!!
// extract(hi1, lo1, v) == extract(h2, lo2, v)
// |-> TBD
//
} }
}; };