From 624907823dbf10e5df874eb72989d6afa807edf8 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 13 Apr 2023 11:19:06 -0700 Subject: [PATCH] add tests for distribution utility and fix loose ends --- src/test/CMakeLists.txt | 1 + src/test/distribution.cpp | 45 +++++++++++++++++++++++++++++++++++++++ src/test/main.cpp | 1 + src/util/distribution.h | 16 ++++++++------ 4 files changed, 57 insertions(+), 6 deletions(-) create mode 100644 src/test/distribution.cpp diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index f959e9bd5..df3010295 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -29,6 +29,7 @@ add_executable(test-z3 datalog_parser.cpp ddnf.cpp diff_logic.cpp + distribution.cpp dl_context.cpp dl_product_relation.cpp dl_query.cpp diff --git a/src/test/distribution.cpp b/src/test/distribution.cpp new file mode 100644 index 000000000..c67757737 --- /dev/null +++ b/src/test/distribution.cpp @@ -0,0 +1,45 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + distribution.cpp + +Abstract: + + Test distribution + +Author: + + Nikolaj Bjorner (nbjorner) 2023-04-13 + + +--*/ +#include "util/distribution.h" +#include + +static void tst1() { + distribution dist(1); + dist.push(1, 3); + dist.push(2, 1); + dist.push(3, 1); + dist.push(4, 1); + + unsigned counts[4] = { 0, 0, 0, 0 }; + for (unsigned i = 0; i < 1000; ++i) + counts[dist.choose()-1]++; + for (unsigned i = 1; i <= 4; ++i) + std::cout << "count " << i << ": " << counts[i-1] << "\n"; + + for (unsigned i = 0; i < 5; ++i) { + std::cout << "enum "; + for (auto j : dist) + std::cout << j << " "; + std::cout << "\n"; + } + +} + +void tst_distribution() { + tst1(); +} diff --git a/src/test/main.cpp b/src/test/main.cpp index f9e4e0815..7cd4b6cf9 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -264,4 +264,5 @@ int main(int argc, char ** argv) { //TST_ARGV(hs); TST(finder); TST(totalizer); + TST(distribution); } diff --git a/src/util/distribution.h b/src/util/distribution.h index de385a08e..0ed63d510 100644 --- a/src/util/distribution.h +++ b/src/util/distribution.h @@ -1,5 +1,5 @@ /*++ -Copyright (c) 2017 Microsoft Corporation +Copyright (c) 2023 Microsoft Corporation Module Name: @@ -18,6 +18,8 @@ Notes: Distribution class works by pushing identifiers with associated scores. After they have been pushed, you can access a random element using choose or you can enumerate the elements in random order, sorted by the score probability. + Only one iterator can be active at a time because the iterator reshuffles the registered elements. + The requirement is not checked or enforced. --*/ #pragma once @@ -32,10 +34,12 @@ class distribution { unsigned choose(unsigned sum) { unsigned s = m_random(sum); + unsigned idx = 0; for (auto const& [j, score] : m_elems) { if (s < score) - return j; + return idx; s -= score; + ++idx; } UNREACHABLE(); return 0; @@ -76,9 +80,8 @@ public: unsigned m_sum = 0; unsigned m_index = 0; void next_index() { - if (0 == m_sz) - return; - m_index = d.choose(m_sum); + if (0 != m_sz) + m_index = d.choose(m_sum); } public: iterator(distribution& d, bool start): d(d), m_sz(start?d.m_elems.size():0), m_sum(d.m_sum) { @@ -88,8 +91,9 @@ public: iterator operator++() { m_sum -= d.m_elems[m_index].second; --m_sz; - std::swap(d.m_elems[m_index], d.m_elems[d.m_elems.size() - 1]); + std::swap(d.m_elems[m_index], d.m_elems[m_sz]); next_index(); + return *this; } bool operator==(iterator const& other) const { return m_sz == other.m_sz; } bool operator!=(iterator const& other) const { return m_sz != other.m_sz; }