diff --git a/src/sat/sat_bdd.cpp b/src/sat/sat_bdd.cpp index 4835b8580..36900760e 100644 --- a/src/sat/sat_bdd.cpp +++ b/src/sat/sat_bdd.cpp @@ -94,8 +94,8 @@ namespace sat { bool bdd_manager::check_result(op_entry*& e1, op_entry const* e2, BDD a, BDD b, BDD c) { if (e1 != e2) { + push_entry(e1); if (e2->m_bdd1 == a && e2->m_bdd2 == b && e2->m_op == c) { - push_entry(e1); return true; } e1 = const_cast(e2); @@ -343,12 +343,15 @@ namespace sat { bdd_manager::BDD bdd_manager::mk_quant_rec(unsigned l, BDD b, bdd_op op) { unsigned lvl = level(b); - - if (lvl == l) { - return apply(lo(b), hi(b), op); + BDD r; + if (is_const(b)) { + r = b; + } + else if (lvl == l) { + r = apply(lo(b), hi(b), op); } else if (lvl < l) { - return b; + r = b; } else { BDD a = level2bdd(l); @@ -356,14 +359,16 @@ namespace sat { op_entry * e1 = pop_entry(a, b, q_op); op_entry const* e2 = m_op_cache.insert_if_not_there(e1); if (check_result(e1, e2, a, b, q_op)) - return e2->m_result; - push(mk_quant_rec(l, lo(b), op)); - push(mk_quant_rec(l, hi(b), op)); - BDD r = make_node(lvl, read(2), read(1)); - pop(2); - e1->m_result = r; - return r; + r = e2->m_result; + else { + push(mk_quant_rec(l, lo(b), op)); + push(mk_quant_rec(l, hi(b), op)); + r = make_node(lvl, read(2), read(1)); + pop(2); + e1->m_result = r; + } } + return r; } double bdd_manager::count(bdd const& b, unsigned z) { @@ -495,7 +500,7 @@ namespace sat { return out; } - bdd::bdd(int root, bdd_manager* m): root(root), m(m) { m->inc_ref(root); } + bdd::bdd(unsigned root, bdd_manager* m): root(root), m(m) { m->inc_ref(root); } bdd::bdd(bdd & other): root(other.root), m(other.m) { m->inc_ref(root); } bdd::~bdd() { m->dec_ref(root); } bdd bdd::lo() const { return bdd(m->lo(root), m); } @@ -506,7 +511,7 @@ namespace sat { bdd bdd::operator!() { return m->mk_not(*this); } bdd bdd::operator&&(bdd const& other) { return m->mk_and(*this, other); } bdd bdd::operator||(bdd const& other) { return m->mk_or(*this, other); } - bdd& bdd::operator=(bdd const& other) { int r1 = root; root = other.root; m->inc_ref(root); m->dec_ref(r1); return *this; } + bdd& bdd::operator=(bdd const& other) { unsigned r1 = root; root = other.root; m->inc_ref(root); m->dec_ref(r1); return *this; } std::ostream& bdd::display(std::ostream& out) const { return m->display(out, *this); } std::ostream& operator<<(std::ostream& out, bdd const& b) { return b.display(out); } diff --git a/src/sat/sat_bdd.h b/src/sat/sat_bdd.h index da26c0b36..29d91446e 100644 --- a/src/sat/sat_bdd.h +++ b/src/sat/sat_bdd.h @@ -29,9 +29,9 @@ namespace sat { class bdd { friend class bdd_manager; - int root; + unsigned root; bdd_manager* m; - bdd(int root, bdd_manager* m); + bdd(unsigned root, bdd_manager* m); public: bdd(bdd & other); bdd& operator=(bdd const& other); @@ -57,7 +57,7 @@ namespace sat { class bdd_manager { friend bdd; - typedef int BDD; + typedef unsigned BDD; enum bdd_op { bdd_and_op = 2, @@ -70,7 +70,7 @@ namespace sat { }; struct bdd_node { - bdd_node(unsigned level, int lo, int hi): + bdd_node(unsigned level, BDD lo, BDD hi): m_refcount(0), m_level(level), m_lo(lo), @@ -80,8 +80,8 @@ namespace sat { bdd_node(): m_level(0), m_lo(0), m_hi(0), m_index(0) {} unsigned m_refcount : 10; unsigned m_level : 22; - int m_lo; - int m_hi; + BDD m_lo; + BDD m_hi; unsigned m_index; unsigned hash() const { return mk_mix(m_level, m_lo, m_hi); } }; @@ -176,7 +176,7 @@ namespace sat { inline bool is_true(BDD b) const { return b == true_bdd; } inline bool is_false(BDD b) const { return b == false_bdd; } - inline bool is_const(BDD b) const { return 0 <= b && b <= 1; } + inline bool is_const(BDD b) const { return b <= 1; } inline unsigned level(BDD b) const { return m_nodes[b].m_level; } inline unsigned var(BDD b) const { return m_level2var[level(b)]; } inline BDD lo(BDD b) const { return m_nodes[b].m_lo; } diff --git a/src/test/bdd.cpp b/src/test/bdd.cpp index 319f05bf8..f3f6e6b8d 100644 --- a/src/test/bdd.cpp +++ b/src/test/bdd.cpp @@ -33,9 +33,34 @@ namespace sat { SASSERT(!(v0 && v1) == (!v0 || !v1)); SASSERT(!(v0 || v1) == (!v0 && !v1)); } + + static void test3() { + bdd_manager m(20, 1000); + bdd v0 = m.mk_var(0); + bdd v1 = m.mk_var(1); + bdd v2 = m.mk_var(2); + bdd c1 = (v0 && v1) || v2; + bdd c2 = m.mk_exists(0, c1); + std::cout << c1 << "\n"; + std::cout << c2 << "\n"; + SASSERT(c2 == (v1 || v2)); + c2 = m.mk_exists(1, c1); + SASSERT(c2 == (v0 || v2)); + c2 = m.mk_exists(2, c1); + SASSERT(c2.is_true()); + SASSERT(m.mk_exists(3, c1) == c1); + c1 = (v0 && v1) || (v1 && v2) || (!v0 && !v2); + c2 = m.mk_exists(0, c1); + SASSERT(c2 == (v1 || (v1 && v2) || !v2)); + c2 = m.mk_exists(1, c1); + SASSERT(c2 == (v0 || v2 || (!v0 && !v2))); + c2 = m.mk_exists(2, c1); + SASSERT(c2 == ((v0 && v1) || v1 || !v0)); + } } void tst_bdd() { sat::test1(); sat::test2(); + sat::test3(); }