diff --git a/src/math/dd/dd_pdd.cpp b/src/math/dd/dd_pdd.cpp index cca2d5e04..291af3b5c 100644 --- a/src/math/dd/dd_pdd.cpp +++ b/src/math/dd/dd_pdd.cpp @@ -165,6 +165,47 @@ namespace dd { return true; } + unsigned pdd_manager::min_parity(PDD p) { + if (m_semantics != mod2N_e) + return 0; + + if (is_val(p)) { + rational v = val(p); + if (v.is_zero()) + return m_power_of_2 + 1; + unsigned r = 0; + while (v.is_even() && v > 0) + r++, v /= 2; + return r; + } + init_mark(); + PDD q = p; + m_todo.push_back(hi(q)); + while (!is_val(q)) { + q = lo(q); + m_todo.push_back(hi(q)); + } + unsigned p2 = val(q).trailing_zeros(); + init_mark(); + while (p2 != 0 && !m_todo.empty()) { + PDD r = m_todo.back(); + m_todo.pop_back(); + if (is_marked(r)) + continue; + set_mark(r); + if (!is_val(r)) { + m_todo.push_back(lo(r)); + m_todo.push_back(hi(r)); + } + else if (val(r).is_zero()) + continue; + else if (val(r).trailing_zeros() < p2) + p2 = val(r).trailing_zeros(); + } + m_todo.reset(); + return p2; + } + pdd pdd_manager::subst_val(pdd const& p, pdd const& s) { return pdd(apply(p.root, s.root, pdd_subst_val_op), this); } @@ -185,7 +226,20 @@ namespace dd { pdd v_val = mk_var(v) + val; return pdd(apply(s.root, v_val.root, pdd_subst_add_op), this); } - + + bool pdd_manager::subst_get(pdd const& s, unsigned v, rational& out_val) { + unsigned level_v = m_var2level[v]; + PDD p = s.root; + while (/* !is_val(p) && */ level(p) > level_v) { + SASSERT(is_val(lo(p))); + p = hi(p); + } + if (!is_val(p) && level(p) == level_v) { + out_val = val(lo(p)); + return true; + } + return false; + } pdd_manager::PDD pdd_manager::apply(PDD arg1, PDD arg2, pdd_op op) { bool first = true; @@ -1154,6 +1208,11 @@ namespace dd { return true; } + /** Return true iff p contains no variables other than v. */ + bool pdd_manager::is_univariate_in(PDD p, unsigned v) { + return (is_val(p) || var(p) == v) && is_univariate(p); + } + /** * Push coefficients of univariate polynomial in order of ascending degree. * Example: a*x^2 + b*x + c ==> [ c, b, a ] @@ -1532,7 +1591,6 @@ namespace dd { } void pdd_manager::gc() { - m_gc_generation++; init_dmark(); m_free_nodes.reset(); SASSERT(well_formed()); @@ -1617,26 +1675,26 @@ namespace dd { std::ostream& pdd_manager::display(std::ostream& out, pdd const& b) { auto mons = to_monomials(b); bool first = true; - for (auto& m : mons) { + for (auto& [a, vs] : mons) { if (!first) out << " "; - if (m.first.is_neg()) + if (a.is_neg()) out << "- "; else if (!first) out << "+ "; first = false; - rational c = abs(m.first); - m.second.reverse(); - if (!c.is_one() || m.second.empty()) { - if (m_semantics == mod2N_e && mod(-c, m_mod2N) < c) - out << -mod(-c, m_mod2N); - else + rational c = abs(a); + vs.reverse(); + if (!c.is_one() || vs.empty()) { + if (m_semantics == mod2N_e) + out << val_pp(*this, c, !vs.empty()); + else out << c; - if (!m.second.empty()) out << "*"; + if (!vs.empty()) out << "*"; } unsigned v_prev = UINT_MAX; unsigned pow = 0; - for (unsigned v : m.second) { + for (unsigned v : vs) { if (v == v_prev) { pow++; continue; @@ -1660,6 +1718,23 @@ namespace dd { return out; } + std::ostream& val_pp::display(std::ostream& out) const { + if (m.get_semantics() != pdd_manager::mod2N_e) + return out << val; + unsigned pow; + if (val.is_power_of_two(pow) && pow > 10) + return out << "2^" << pow; + for (int offset : {-2, -1, 1, 2}) + if (val < m.max_value() && (val - offset).is_power_of_two(pow) && pow > 10 && pow < m.power_of_2()) + return out << lparen() << "2^" << pow << (offset >= 0 ? "+" : "") << offset << rparen(); + rational neg_val = mod(-val, m.two_to_N()); + if (neg_val < val) { // keep this condition so we don't suddenly print negative values where we wouldn't otherwise + if (neg_val.is_power_of_two(pow) && pow > 10) + return out << "-2^" << pow; + } + return out << m.normalize(val); + } + bool pdd_manager::well_formed() { bool ok = true; for (unsigned n : m_free_nodes) { @@ -1737,6 +1812,13 @@ namespace dd { return p.val(); } + rational const& pdd::offset() const { + pdd p = *this; + while (!p.is_val()) + p = p.lo(); + return p.val(); + } + pdd pdd::shl(unsigned n) const { return (*this) * rational::power_of_two(n); } diff --git a/src/math/dd/dd_pdd.h b/src/math/dd/dd_pdd.h index aef0eb6f7..6dee7977f 100644 --- a/src/math/dd/dd_pdd.h +++ b/src/math/dd/dd_pdd.h @@ -10,7 +10,7 @@ Abstract: Poly DD package It is a mild variant of ZDDs. - In PDDs arithmetic is either standard or using mod 2 (over GF2). + In PDDs arithmetic is either standard or using mod 2^n. Non-leaf nodes are of the form x*hi + lo where @@ -208,7 +208,6 @@ namespace dd { rational m_mod2N; unsigned m_power_of_2 = 0; rational m_max_value; - unsigned m_gc_generation = 0; ///< will be incremented on each GC void reset_op_cache(); void init_nodes(unsigned_vector const& l2v); @@ -254,7 +253,9 @@ namespace dd { inline bool is_val(PDD p) const { return m_nodes[p].is_val(); } inline bool is_internal(PDD p) const { return m_nodes[p].is_internal(); } inline bool is_var(PDD p) const { return !is_val(p) && is_zero(lo(p)) && is_one(hi(p)); } + inline bool is_max(PDD p) const { SASSERT(m_semantics == mod2_e || m_semantics == mod2N_e); return is_val(p) && val(p) == max_value(); } bool is_never_zero(PDD p); + unsigned min_parity(PDD p); inline unsigned level(PDD p) const { return m_nodes[p].m_level; } inline unsigned var(PDD p) const { return m_level2var[level(p)]; } inline PDD lo(PDD p) const { return m_nodes[p].m_lo; } @@ -315,6 +316,11 @@ namespace dd { pdd_manager(unsigned num_vars, semantics s = free_e, unsigned power_of_2 = 0); ~pdd_manager(); + pdd_manager(pdd_manager const&) = delete; + pdd_manager(pdd_manager&&) = delete; + pdd_manager& operator=(pdd_manager const&) = delete; + pdd_manager& operator=(pdd_manager&&) = delete; + semantics get_semantics() const { return m_semantics; } void reset(unsigned_vector const& level2var); @@ -343,6 +349,7 @@ namespace dd { pdd subst_val(pdd const& a, unsigned v, rational const& val); pdd subst_val(pdd const& a, pdd const& s); pdd subst_add(pdd const& s, unsigned v, rational const& val); + bool subst_get(pdd const& s, unsigned v, rational& out_val); bool resolve(unsigned v, pdd const& p, pdd const& q, pdd& r); pdd reduce(unsigned v, pdd const& a, pdd const& b); void quot_rem(pdd const& a, pdd const& b, pdd& q, pdd& r); @@ -357,6 +364,7 @@ namespace dd { bool is_monomial(PDD p); bool is_univariate(PDD p); + bool is_univariate_in(PDD p, unsigned v); void get_univariate_coefficients(PDD p, vector& coeff); // create an spoly r if leading monomials of a and b overlap @@ -375,6 +383,8 @@ namespace dd { unsigned power_of_2() const { return m_power_of_2; } rational const& max_value() const { return m_max_value; } rational const& two_to_N() const { return m_mod2N; } + rational normalize(rational const& n) const { return mod(-n, m_mod2N) < n ? -mod(-n, m_mod2N) : n; } + unsigned_vector const& free_vars(pdd const& p); @@ -406,21 +416,26 @@ namespace dd { unsigned var() const { return m.var(root); } rational const& val() const { SASSERT(is_val()); return m.val(root); } rational const& leading_coefficient() const; + rational const& offset() const; bool is_val() const { return m.is_val(root); } bool is_one() const { return m.is_one(root); } bool is_zero() const { return m.is_zero(root); } bool is_linear() const { return m.is_linear(root); } bool is_var() const { return m.is_var(root); } - /** Polynomial is of the form a * x + b for numerals a, b. */ + bool is_max() const { return m.is_max(root); } + /** Polynomial is of the form a * x + b for some numerals a, b. */ bool is_unilinear() const { return !is_val() && lo().is_val() && hi().is_val(); } + /** Polynomial is of the form a * x for some numeral a. */ bool is_unary() const { return !is_val() && lo().is_zero() && hi().is_val(); } bool is_offset() const { return !is_val() && lo().is_val() && hi().is_one(); } bool is_binary() const { return m.is_binary(root); } bool is_monomial() const { return m.is_monomial(root); } bool is_univariate() const { return m.is_univariate(root); } + bool is_univariate_in(unsigned v) const { return m.is_univariate_in(root, v); } void get_univariate_coefficients(vector& coeff) const { m.get_univariate_coefficients(root, coeff); } vector get_univariate_coefficients() const { vector coeff; m.get_univariate_coefficients(root, coeff); return coeff; } bool is_never_zero() const { return m.is_never_zero(root); } + unsigned min_parity() const { return m.min_parity(root); } bool var_is_leaf(unsigned v) const { return m.var_is_leaf(root, v); } pdd operator-() const { return m.minus(*this); } @@ -455,7 +470,8 @@ namespace dd { pdd subst_val0(vector> const& s) const { return m.subst_val0(*this, s); } pdd subst_val(pdd const& s) const { return m.subst_val(*this, s); } pdd subst_val(unsigned v, rational const& val) const { return m.subst_val(*this, v, val); } - pdd subst_add(unsigned var, rational const& val) { return m.subst_add(*this, var, val); } + pdd subst_add(unsigned var, rational const& val) const { return m.subst_add(*this, var, val); } + bool subst_get(unsigned var, rational& out_val) const { return m.subst_get(*this, var, out_val); } /** * \brief substitute variable v by r. @@ -538,6 +554,18 @@ namespace dd { bool operator!=(pdd_iterator const& other) const { return m_nodes != other.m_nodes; } }; + class val_pp { + pdd_manager const& m; + rational const& val; + bool require_parens; + char const* lparen() const { return require_parens ? "(" : ""; } + char const* rparen() const { return require_parens ? ")" : ""; } + public: + val_pp(pdd_manager const& m, rational const& val, bool require_parens = false): m(m), val(val), require_parens(require_parens) {} + std::ostream& display(std::ostream& out) const; + }; + + inline std::ostream& operator<<(std::ostream& out, val_pp const& v) { return v.display(out); } } diff --git a/src/test/pdd.cpp b/src/test/pdd.cpp index a0946d81d..0c9b0f85c 100644 --- a/src/test/pdd.cpp +++ b/src/test/pdd.cpp @@ -571,6 +571,38 @@ public: } } + static void subst_get() { + std::cout << "subst_get\n"; + pdd_manager m(4, pdd_manager::mod2N_e, 32); + + unsigned const va = 0; + unsigned const vb = 1; + unsigned const vc = 2; + unsigned const vd = 3; + + rational val; + pdd s = m.one(); + std::cout << s << "\n"; + VERIFY(!s.subst_get(va, val)); + VERIFY(!s.subst_get(vb, val)); + VERIFY(!s.subst_get(vc, val)); + VERIFY(!s.subst_get(vd, val)); + + s = s.subst_add(va, rational(5)); + std::cout << s << "\n"; + VERIFY(s.subst_get(va, val) && val == 5); + VERIFY(!s.subst_get(vb, val)); + VERIFY(!s.subst_get(vc, val)); + VERIFY(!s.subst_get(vd, val)); + + s = s.subst_add(vc, rational(7)); + std::cout << s << "\n"; + VERIFY(s.subst_get(va, val) && val == 5); + VERIFY(!s.subst_get(vb, val)); + VERIFY(s.subst_get(vc, val) && val == 7); + VERIFY(!s.subst_get(vd, val)); + } + static void univariate() { std::cout << "univariate\n"; pdd_manager m(4, pdd_manager::mod2N_e, 4); @@ -671,6 +703,7 @@ void tst_pdd() { dd::test::binary_resolve(); dd::test::pow(); dd::test::subst_val(); + dd::test::subst_get(); dd::test::univariate(); dd::test::factors(); } diff --git a/src/util/debug.cpp b/src/util/debug.cpp index f97a2b57b..c9ca9fc31 100644 --- a/src/util/debug.cpp +++ b/src/util/debug.cpp @@ -75,32 +75,62 @@ bool is_debug_enabled(const char * tag) { return g_enabled_debug_tags->contains(tag); } +atomic g_default_debug_action(debug_action::ask); + +debug_action get_default_debug_action() { + return g_default_debug_action; +} + +void set_default_debug_action(debug_action a) { + g_default_debug_action = a; +} + +debug_action ask_debug_action(std::istream& in) { + std::cerr << "(C)ontinue, (A)bort, (S)top, (T)hrow exception, Invoke (G)DB\n"; + char result; + bool ok = bool(in >> result); + if (!ok) + exit(ERR_INTERNAL_FATAL); // happens if std::cin is eof or unattached. + switch(result) { + case 'C': + case 'c': + return debug_action::cont; + case 'A': + case 'a': + return debug_action::abort; + case 'S': + case 's': + return debug_action::stop; + case 't': + case 'T': + return debug_action::throw_exception; + case 'G': + case 'g': + return debug_action::invoke_debugger; + default: + std::cerr << "INVALID COMMAND\n"; + return debug_action::ask; + } +} + #if !defined(_WINDOWS) && !defined(NO_Z3_DEBUGGER) void invoke_gdb() { std::string buffer; - int * x = nullptr; + int *x = nullptr; + debug_action a = get_default_debug_action(); for (;;) { - std::cerr << "(C)ontinue, (A)bort, (S)top, (T)hrow exception, Invoke (G)DB\n"; - char result; - bool ok = bool(std::cin >> result); - if (!ok) exit(ERR_INTERNAL_FATAL); // happens if std::cin is eof or unattached. - switch(result) { - case 'C': - case 'c': + switch (a) { + case debug_action::cont: return; - case 'A': - case 'a': + case debug_action::abort: exit(1); - case 'S': - case 's': + case debug_action::stop: // force seg fault... *x = 0; return; - case 't': - case 'T': + case debug_action::throw_exception: throw default_exception("assertion violation"); - case 'G': - case 'g': + case debug_action::invoke_debugger: buffer = "gdb -nw /proc/" + std::to_string(getpid()) + "/exe " + std::to_string(getpid()); std::cerr << "invoking GDB...\n"; if (system(buffer.c_str()) == 0) { @@ -109,12 +139,13 @@ void invoke_gdb() { else { std::cerr << "error starting GDB...\n"; // forcing seg fault. - int * x = nullptr; + int *x = nullptr; *x = 0; } return; + case debug_action::ask: default: - std::cerr << "INVALID COMMAND\n"; + a = ask_debug_action(std::cin); } } } diff --git a/src/util/debug.h b/src/util/debug.h index 795976eac..5f092b181 100644 --- a/src/util/debug.h +++ b/src/util/debug.h @@ -19,10 +19,22 @@ Revision History: #pragma once #include +#include void enable_assertions(bool f); bool assertions_enabled(); +enum class debug_action { + ask, + cont, + abort, + stop, + throw_exception, + invoke_debugger, +}; +debug_action get_default_debug_action(); +void set_default_debug_action(debug_action a); + #include "util/error_codes.h" #include "util/warning.h" diff --git a/src/util/dlist.h b/src/util/dlist.h index 7efe5bb53..e5c95b8cf 100644 --- a/src/util/dlist.h +++ b/src/util/dlist.h @@ -17,20 +17,38 @@ Revision History: --*/ #pragma once +#include +#include "util/debug.h" +#include "util/util.h" +#define DLIST_EXTRA_ASSERTIONS 0 -template +template class dll_iterator; + +template class dll_base { - T* m_next { nullptr }; - T* m_prev { nullptr }; + T* m_next = nullptr; + T* m_prev = nullptr; + +protected: + dll_base() = default; + ~dll_base() = default; + public: + dll_base(dll_base const&) = delete; + dll_base(dll_base&&) = delete; + dll_base& operator=(dll_base const&) = delete; + dll_base& operator=(dll_base&&) = delete; T* prev() { return m_prev; } T* next() { return m_next; } + T const* prev() const { return m_prev; } + T const* next() const { return m_next; } void init(T* t) { m_next = t; m_prev = t; + SASSERT(invariant()); } static T* pop(T*& list) { @@ -41,23 +59,63 @@ public: return head; } - void insert_after(T* elem) { + void insert_after(T* other) { +#if DLIST_EXTRA_ASSERTIONS + SASSERT(other); + SASSERT(invariant()); + SASSERT(other->invariant()); + size_t const old_sz1 = count_if(*static_cast(this), [](T const&) { return true; }); + size_t const old_sz2 = count_if(*other, [](T const&) { return true; }); +#endif + // have: this -> next -> ... + // insert: other -> ... -> other_end + // result: this -> other -> ... -> other_end -> next -> ... T* next = this->m_next; - elem->m_prev = next->m_prev; - elem->m_next = next; - this->m_next = elem; - next->m_prev = elem; + T* other_end = other->m_prev; + this->m_next = other; + other->m_prev = static_cast(this); + other_end->m_next = next; + next->m_prev = other_end; +#if DLIST_EXTRA_ASSERTIONS + SASSERT(invariant()); + SASSERT(other->invariant()); + size_t const new_sz = count_if(*static_cast(this), [](T const&) { return true; }); + SASSERT_EQ(new_sz, old_sz1 + old_sz2); +#endif } - void insert_before(T* elem) { + void insert_before(T* other) { +#if DLIST_EXTRA_ASSERTIONS + SASSERT(other); + SASSERT(invariant()); + SASSERT(other->invariant()); + size_t const old_sz1 = count_if(*static_cast(this), [](T const&) { return true; }); + size_t const old_sz2 = count_if(*other, [](T const&) { return true; }); +#endif + // have: prev -> this -> ... + // insert: other -> ... -> other_end + // result: prev -> other -> ... -> other_end -> this -> ... T* prev = this->m_prev; - elem->m_next = prev->m_next; - elem->m_prev = prev; - prev->m_next = elem; - this->m_prev = elem; + T* other_end = other->m_prev; + prev->m_next = other; + other->m_prev = prev; + other_end->m_next = static_cast(this); + this->m_prev = other_end; +#if DLIST_EXTRA_ASSERTIONS + SASSERT(invariant()); + SASSERT(other->invariant()); + size_t const new_sz = count_if(*static_cast(this), [](T const&) { return true; }); + SASSERT_EQ(new_sz, old_sz1 + old_sz2); +#endif } static void remove_from(T*& list, T* elem) { +#if DLIST_EXTRA_ASSERTIONS + SASSERT(list); + SASSERT(elem); + SASSERT(list->invariant()); + SASSERT(elem->invariant()); +#endif if (list->m_next == list) { SASSERT(elem == list); list = nullptr; @@ -69,6 +127,9 @@ public: auto* prev = elem->m_prev; prev->m_next = next; next->m_prev = prev; +#if DLIST_EXTRA_ASSERTIONS + SASSERT(list->invariant()); +#endif } static void push_to_front(T*& list, T* elem) { @@ -105,11 +166,10 @@ public: return true; } - - static bool contains(T* list, T* elem) { + static bool contains(T const* list, T const* elem) { if (!list) return false; - T* first = list; + T const* first = list; do { if (list == elem) return true; @@ -120,5 +180,60 @@ public: } }; +template +class dll_iterator { + T const* m_elem; + bool m_first; + dll_iterator(T const* elem, bool first): m_elem(elem), m_first(first) { } +public: + static dll_iterator mk_begin(T const* elem) { + // Setting first==(bool)elem makes this also work for elem==nullptr; + // but we can't implement top-level begin/end for pointers because it clashes with the definition for arrays. + return {elem, (bool)elem}; + } + + static dll_iterator mk_end(T const* elem) { + return {elem, false}; + } + + using value_type = T; + using pointer = T const*; + using reference = T const&; + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + + dll_iterator& operator++() { + m_elem = m_elem->next(); + m_first = false; + return *this; + } + + T const& operator*() const { + return *m_elem; + } + + bool operator==(dll_iterator const& other) const { + return m_elem == other.m_elem && m_first == other.m_first; + } + + bool operator!=(dll_iterator const& other) const { + return !operator==(other); + } +}; + +template < typename T + , typename U = std::enable_if_t, T>> // should only match if T actually inherits from dll_base + > +dll_iterator begin(T const& elem) { + return dll_iterator::mk_begin(&elem); +} + +template < typename T + , typename U = std::enable_if_t, T>> // should only match if T actually inherits from dll_base + > +dll_iterator end(T const& elem) +{ + return dll_iterator::mk_end(&elem); +} diff --git a/src/util/map.h b/src/util/map.h index 602c042fb..e9880e0a0 100644 --- a/src/util/map.h +++ b/src/util/map.h @@ -33,6 +33,10 @@ struct _key_data { m_key(k), m_value(v) { } + _key_data(Key const& k, Value&& v): + m_key(k), + m_value(std::move(v)) { + } }; template @@ -106,6 +110,10 @@ public: void insert(key const & k, value const & v) { m_table.insert(key_data(k, v)); } + + void insert(key const& k, value&& v) { + m_table.insert(key_data(k, std::move(v))); + } bool insert_if_not_there_core(key const & k, value const & v, entry *& et) { return m_table.insert_if_not_there_core(key_data(k,v), et); diff --git a/src/util/mpq.h b/src/util/mpq.h index 31ffbeab8..e254ade69 100644 --- a/src/util/mpq.h +++ b/src/util/mpq.h @@ -487,6 +487,8 @@ public: void machine_div_rem(mpz const & a, mpz const & b, mpz & c, mpz & d) { mpz_manager::machine_div_rem(a, b, c, d); } + void machine_div2k(mpz const & a, unsigned k, mpz & c) { mpz_manager::machine_div2k(a, k, c); } + void div(mpz const & a, mpz const & b, mpz & c) { mpz_manager::div(a, b, c); } void rat_div(mpz const & a, mpz const & b, mpq & c) { @@ -513,6 +515,12 @@ public: machine_div(a.m_num, b.m_num, c); } + void machine_idiv2k(mpq const & a, unsigned k, mpq & c) { + SASSERT(is_int(a)); + machine_div2k(a.m_num, k, c.m_num); + reset_denominator(c); + } + void idiv(mpq const & a, mpq const & b, mpq & c) { SASSERT(is_int(a) && is_int(b)); div(a.m_num, b.m_num, c.m_num); diff --git a/src/util/rational.cpp b/src/util/rational.cpp index af3c89ced..54b40ac58 100644 --- a/src/util/rational.cpp +++ b/src/util/rational.cpp @@ -153,3 +153,21 @@ bool rational::mult_inverse(unsigned num_bits, rational & result) const { return true; } +/** + * Compute the smallest multiplicative pseudo-inverse modulo 2^num_bits: + * + * mod(n * n.pseudo_inverse(bits), 2^bits) == 2^k, + * where k is maximal such that 2^k divides n. + * + * Precondition: number is non-zero. + */ +rational rational::pseudo_inverse(unsigned num_bits) const { + rational result; + rational const& n = *this; + SASSERT(!n.is_zero()); // TODO: or we define pseudo-inverse of 0 as 0. + unsigned const k = n.trailing_zeros(); + rational const odd = machine_div2k(n, k); + VERIFY(odd.mult_inverse(num_bits - k, result)); + SASSERT_EQ(mod(n * result, rational::power_of_two(num_bits)), rational::power_of_two(k)); + return result; +} diff --git a/src/util/rational.h b/src/util/rational.h index 4203a54ea..f47fddefe 100644 --- a/src/util/rational.h +++ b/src/util/rational.h @@ -56,6 +56,8 @@ public: explicit rational(char const * v) { m().set(m_val, v); } + explicit rational(unsigned const * v, unsigned sz) { m().set(m_val, sz, v); } + struct i64 {}; rational(int64_t i, i64) { m().set(m_val, i); } @@ -227,6 +229,12 @@ public: return r; } + friend inline rational machine_div2k(rational const & r1, unsigned k) { + rational r; + rational::m().machine_idiv2k(r1.m_val, k, r.m_val); + return r; + } + friend inline rational mod(rational const & r1, rational const & r2) { rational r; rational::m().mod(r1.m_val, r2.m_val, r.m_val); @@ -353,6 +361,7 @@ public: } bool mult_inverse(unsigned num_bits, rational & result) const; + rational pseudo_inverse(unsigned num_bits) const; static rational const & zero() { return m_zero; diff --git a/src/util/tbv.h b/src/util/tbv.h index 2a337be1f..cffdc2460 100644 --- a/src/util/tbv.h +++ b/src/util/tbv.h @@ -27,10 +27,10 @@ Revision History: class tbv; enum tbit { - BIT_z = 0x0, - BIT_0 = 0x1, - BIT_1 = 0x2, - BIT_x = 0x3 + BIT_z = 0x0, // unknown + BIT_0 = 0x1, // for sure 0 + BIT_1 = 0x2, // for sure 1 + BIT_x = 0x3 // don't care }; inline tbit neg(tbit t) { @@ -43,6 +43,7 @@ class tbv_manager { ptr_vector allocated_tbvs; public: tbv_manager(unsigned n): m(2*n) {} + tbv_manager(tbv_manager const& m) = delete; ~tbv_manager(); void reset(); tbv* allocate(); @@ -132,8 +133,9 @@ class tbv_ref { tbv_manager& mgr; tbv* d; public: - tbv_ref(tbv_manager& mgr):mgr(mgr),d(nullptr) {} - tbv_ref(tbv_manager& mgr, tbv* d):mgr(mgr),d(d) {} + tbv_ref(tbv_manager& mgr) : mgr(mgr), d(nullptr) {} + tbv_ref(tbv_manager& mgr, tbv* d) : mgr(mgr), d(d) {} + tbv_ref(tbv_ref&& d) : mgr(d.mgr), d(d.detach()) {} ~tbv_ref() { if (d) mgr.deallocate(d); } @@ -144,8 +146,17 @@ public: } tbv& operator*() { return *d; } tbv* operator->() { return d; } - tbv* get() { return d; } + tbit operator[](unsigned idx) const { return (*d)[idx]; } + tbv* get() const { return d; } tbv* detach() { tbv* result = d; d = nullptr; return result; } + tbv_manager& manager() const { return mgr; } + unsigned num_tbits() const { return mgr.num_tbits(); } }; - +inline std::ostream& operator<<(std::ostream& out, tbv_ref const& c) { + char const* names[] = { "z", "0", "1", "x" }; + for (unsigned i = c.num_tbits(); i-- > 0; ) { + out << names[static_cast(c[i])]; + } + return out; +} diff --git a/src/util/util.h b/src/util/util.h index 121031492..6d4efb671 100644 --- a/src/util/util.h +++ b/src/util/util.h @@ -20,12 +20,14 @@ Revision History: #include "util/debug.h" #include "util/memory_manager.h" -#include -#include -#include -#include +#include +#include +#include +#include #include #include +#include +#include #ifndef SIZE_MAX #define SIZE_MAX std::numeric_limits::max() @@ -410,3 +412,36 @@ inline size_t megabytes_to_bytes(unsigned mb) { r = SIZE_MAX; return r; } + +/** Compact version of std::count */ +template +std::size_t count(Container const& c, Item x) +{ + using std::begin, std::end; // allows begin(c) to also find c.begin() + return std::count(begin(c), end(c), std::forward(x)); +} + +/** Compact version of std::count_if */ +template +std::size_t count_if(Container const& c, Predicate p) +{ + using std::begin, std::end; // allows begin(c) to also find c.begin() + return std::count_if(begin(c), end(c), std::forward(p)); +} + +/** Basic version of https://en.cppreference.com/w/cpp/experimental/scope_exit */ +template +class on_scope_exit final { + Callable m_ef; +public: + on_scope_exit(Callable&& ef) + : m_ef(std::forward(ef)) + { } + ~on_scope_exit() { + m_ef(); + } +}; + +/** Helper type for std::visit, see examples on https://en.cppreference.com/w/cpp/utility/variant/visit */ +template +struct always_false : std::false_type {}; diff --git a/src/util/var_queue.h b/src/util/var_queue.h index 62df77784..7245153ca 100644 --- a/src/util/var_queue.h +++ b/src/util/var_queue.h @@ -89,6 +89,10 @@ public: } return out; } + + using const_iterator = decltype(m_queue)::const_iterator; + const_iterator begin() const { return m_queue.begin(); } + const_iterator end() const { return m_queue.end(); } }; inline std::ostream& operator<<(std::ostream& out, var_queue const& queue) {