From 41e62fe1738723893ff3daed13bcbfc9fbc6071b Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 7 Sep 2025 13:53:29 -0700 Subject: [PATCH] add search tree template Signed-off-by: Nikolaj Bjorner --- src/test/CMakeLists.txt | 1 + src/test/main.cpp | 1 + src/test/search_tree.cpp | 190 ++++++++++++++++++++++++++++++++ src/util/search_tree.h | 230 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 422 insertions(+) create mode 100644 src/test/search_tree.cpp create mode 100644 src/util/search_tree.h diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 2b356f222..115c4e0f4 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -108,6 +108,7 @@ add_executable(test-z3 sat_user_scope.cpp scoped_timer.cpp scoped_vector.cpp + search_tree.cpp simple_parser.cpp simplex.cpp simplifier.cpp diff --git a/src/test/main.cpp b/src/test/main.cpp index 06ca91fbc..83c393025 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -271,4 +271,5 @@ int main(int argc, char ** argv) { TST(scoped_vector); TST(sls_seq_plugin); TST(ho_matcher); + TST(search_tree); } diff --git a/src/test/search_tree.cpp b/src/test/search_tree.cpp new file mode 100644 index 000000000..25bad2150 --- /dev/null +++ b/src/test/search_tree.cpp @@ -0,0 +1,190 @@ +#include "util/search_tree.h" +#include "util/trace.h" +#include +#include +#include +#include + + +// Initially there are no cubes. +// workers that enter at this stage will receive an empty cube to work on. +// If they succeeed, they return the empty conflict. +// If they fail, they generate two cubes, one with +id and one with -id +// and add them to the cube manager. + +struct literal { + using atom = unsigned; + atom a; + bool sign; + literal(atom a, bool s = false) : a(a), sign(s) {} + literal operator~() const { return literal(a, !sign); } + bool operator==(literal const& other) const { return a == other.a && sign == other.sign; } +}; + +inline std::ostream& operator<<(std::ostream& out, literal lit) { + if (lit.a == UINT_MAX) { + out << "null"; + return out; + } + if (!lit.sign) + out << "-"; + out << lit.a; + return out; +} + +struct literal_config { + using literal = literal; + static bool literal_is_null(literal const& l) { return l.a == UINT_MAX; } + static literal null_literal() { return literal(UINT_MAX); } + static std::ostream& display_literal(std::ostream& out, literal l) { return out << l; } +}; + + +using literal_vector = vector; + +inline std::ostream& operator<<(std::ostream& out, literal_vector const& v) { + out << "["; + for (unsigned i = 0; i < v.size(); ++i) { + if (i > 0) + out << " "; + out << v[i]; + } + out << "]"; + return out; +} + + +class cube_manager { + using node = search_tree::node; + using status = search_tree::status; + using literal = typename literal_config::literal; + std::mutex mutex; + std::condition_variable cv; + search_tree::tree tree; + unsigned num_workers = 0; + std::atomic num_waiting = 0; +public: + cube_manager(unsigned num_workers) : num_workers(num_workers), tree(literal_config::null_literal()) {} + ~cube_manager() {} + + void split(node* n, literal a, literal b) { + std::lock_guard lock(mutex); + IF_VERBOSE(1, verbose_stream() << "adding literal " << a << " and " << b << "\n";); + tree.split(n, a, b); + IF_VERBOSE(1, tree.display(verbose_stream());); + cv.notify_all(); + } + + bool get_cube(node*& n, literal_vector& cube) { + cube.reset(); + std::unique_lock lock(mutex); + node* t = nullptr; + while ((t = tree.activate_node(n)) == nullptr) { + // if all threads have reported they are done, then return false + // otherwise wait for condition variable + IF_VERBOSE(1, verbose_stream() << "waiting... " << "\n";); + if (tree.is_closed()) { + IF_VERBOSE(1, verbose_stream() << "all done\n";); + cv.notify_all(); + return false; + } + cv.wait(lock); + } + n = t; + while (t) { + if (literal_config::literal_is_null(t->get_literal())) + break; + cube.push_back(t->get_literal()); + t = t->parent(); + } +// IF_VERBOSE(1, verbose_stream() << "got cube " << cube << " from " << " " << t->get_status() << "\n";); + return true; + } + + void backtrack(node* n, literal_vector const& core) { + std::lock_guard lock(mutex); + IF_VERBOSE(1, verbose_stream() << "backtrack " << core << "\n"; tree.display(verbose_stream());); + tree.backtrack(n, core); + if (tree.is_closed()) { + IF_VERBOSE(1, verbose_stream() << "all done\n";); + cv.notify_all(); + } + } + +}; +class worker { + unsigned id; + cube_manager& cm; + random_gen m_rand; + + bool solve_cube(const literal_vector& cube) { + // dummy implementation + IF_VERBOSE(0, verbose_stream() << id << " solving " << cube << "\n";); + std::this_thread::sleep_for(std::chrono::milliseconds(50 + m_rand(100))); + // the deeper the cube, the more likely we are to succeed. + // 1 - (9/10)^(|cube|) success probability + if (cube.empty()) + return false; + double prob = m_rand(100); + double threshold = 100.0 * (1.0 - std::pow(9.0 / 10.0, cube.size())); + bool solved = prob < threshold; + IF_VERBOSE(0, verbose_stream() << id << (solved ? " solved " : " failed ") << cube << " " << prob << " " << threshold << "\n";); + return solved; + } + +public: + worker(unsigned id, cube_manager& cm) : id(id), cm(cm), m_rand(id) { + m_rand.set_seed(rand()); // make it random across runs + } + ~worker() {} + void run() { + literal_vector cube; + search_tree::node* n = nullptr; + while (cm.get_cube(n, cube)) { + if (solve_cube(cube)) { + literal_vector core; + for (auto l : cube) + if (m_rand(2) == 0) + core.push_back(l); + cm.backtrack(n, core); + } + else { + unsigned atom = 1 + cube.size() + 1000 * id; + literal lit(atom); + cm.split(n, lit, ~lit); + IF_VERBOSE(1, verbose_stream() << id << " getting new cube\n";); + } + } + } +}; + + +class parallel_cuber { + unsigned num_workers; + std::vector workers; + cube_manager cm; +public: + parallel_cuber(unsigned num_workers) : + num_workers(num_workers), + cm(num_workers) { + } + ~parallel_cuber() {} + + void start() { + for (unsigned i = 0; i < num_workers; ++i) + workers.push_back(new worker(i, cm)); + std::vector threads; + for (auto w : workers) + threads.push_back(std::thread([w]() { w->run(); })); + for (auto& t : threads) + t.join(); + for (auto w : workers) + delete w; + } +}; + + +void tst_search_tree() { + parallel_cuber sp(8); + sp.start(); +} \ No newline at end of file diff --git a/src/util/search_tree.h b/src/util/search_tree.h new file mode 100644 index 000000000..0e7968197 --- /dev/null +++ b/src/util/search_tree.h @@ -0,0 +1,230 @@ +/*++ +Copyright (c) 2025 Microsoft Corporation + +Module Name: + + search_tree.h + +Abstract: + + A binary search tree for managing the search space of a DPLL(T) solver. + It supports splitting on atoms, backtracking on conflicts, and activating nodes. + + Nodes can be in one of three states: open, closed, or active. + - Closed nodes are fully explored (both children are closed). + - Active nodes have no children and are currently being explored. + - Open nodes either have children that are open or are leaves. + + A node can be split if it is active. After splitting, it becomes open and has two open children. + + Backtracking on a conflict closes all nodes below the last node whose atom is in the conflict set. + + Activation searches an open node closest to a seed node. + +Author: + + Ilana Shapiro 2025-9-06 + +--*/ + +#include "util/util.h" +#include "util/vector.h" +#pragma once + +namespace search_tree { + + enum class status { open, closed, active }; + + template + class node { + typedef typename Config::literal literal; + literal m_literal; + node* m_left = nullptr, * m_right = nullptr, * m_parent = nullptr; + status m_status; + public: + node(literal const& l, node* parent) : + m_literal(l), m_parent(parent), m_status(status::open) {} + ~node() { + dealloc(m_left); + dealloc(m_right); + } + + status get_status() const { return m_status; } + void set_status(status s) { m_status = s; } + literal const& get_literal() const { return m_literal; } + void set_literal(literal const& l) { m_literal = l; } + bool literal_is_null() const { return Config::is_null(m_literal); } + void split(literal const& a, literal const& b) { + if (m_status != status::active) + return; + SASSERT(!m_left); + SASSERT(!m_right); + m_left = alloc(node, a, this); + m_right = alloc(node, b, this); + m_status = status::open; + } + + node* left() const { return m_left; } + node* right() const { return m_right; } + node* parent() const { return m_parent; } + + void display(std::ostream& out, unsigned indent) const { + for (unsigned i = 0; i < indent; ++i) + out << " "; + Config::display_literal(out, m_literal); + out << (get_status() == status::open ? " (o)" : get_status() == status::closed ? " (c)" : " (a)"); + out << "\n"; + if (m_left) + m_left->display(out, indent + 2); + if (m_right) + m_right->display(out, indent + 2); + } + }; + + template + class tree { + typedef typename Config::literal literal; + scoped_ptr> m_root = nullptr; + literal m_null_literal; + random_gen m_rand; + + // return an active node in the subtree rooted at n, or nullptr if there is none + // close nodes that are fully explored (whose children are all closed) + node* activate_from_root(node* n) { + if (!n) + return nullptr; + if (n->get_status() != status::open) + return nullptr; + auto left = n->left(); + auto right = n->right(); + if (!left && !right) { + n->set_status(status::active); + return n; + } + node* nodes[2] = { left, right }; + unsigned index = m_rand(2); + auto child = activate_from_root(nodes[index]); + if (child) + return child; + child = activate_from_root(nodes[1 - index]); + if (child) + return child; + if (left && right && left->get_status() == status::closed && right->get_status() == status::closed) + n->set_status(status::closed); + return nullptr; + } + + void close_node(node* n) { + if (!n) + return; + if (n->get_status() == status::closed) + return; + n->set_status(status::closed); + close_node(n->left()); + close_node(n->right()); + } + + public: + + tree(literal const& null_literal) : m_null_literal(null_literal) { + m_root = alloc(node, m_null_literal, nullptr); + m_root->set_status(status::active); + } + + void set_seed(unsigned seed) { + m_rand.set_seed(seed); + } + + // Split current node if it is active. + // After the call, n is open and has two children. + void split(node* n, literal const& a, literal const& b) { + SASSERT(!Config::literal_is_null(a)); + SASSERT(!Config::literal_is_null(b)); + if (n->get_status() == status::active) { + n->split(a, b); + n->set_status(status::open); + } + } + + // conflict is given by a set of atoms. + // they are a subset of atoms on the path from root to n + void backtrack(node* n, vector const& conflict) { + if (conflict.empty()) { + close_node(m_root.get()); + m_root->set_status(status::closed); + return; + } + SASSERT(n != m_root.get()); + // all literals in conflict are on the path from root to n + DEBUG_CODE( + auto on_path = [&](literal const& a) { + node* p = n; + while (p) { + if (p->get_literal() == a) + return true; + p = p->parent(); + } + return false; + }; + SASSERT(all_of(conflict, [&](auto const& a) { return on_path(a); })); + ); + + while (n) { + if (any_of(conflict, [&](auto const& a) { return a == n->get_literal(); })) { + close_node(n); + return; + } + n = n->parent(); + } + UNREACHABLE(); + } + + // return an active node in the tree, or nullptr if there is none + // first check if there is a node to activate under n, + // if not, go up the tree and try to activate a sibling subtree + node* activate_node(node* n) { + if (!n) { + if (m_root->get_status() == status::active) + return m_root.get(); + n = m_root.get(); + } + auto res = activate_from_root(n); + if (res) + return res; + while (n) { + if (n->left() && n->left()->get_status() == status::closed && + n->right() && n->right()->get_status() == status::closed) { + n->set_status(status::closed); + n = n->parent(); + continue; + } + auto p = n->parent(); + if (!p) + return nullptr; + if (n == p->left()) { + res = activate_from_root(p->right()); + if (res) + return res; + } + else { + SASSERT(n == p->right()); + res = activate_from_root(p->left()); + if (res) + return res; + } + n = p; + } + return nullptr; + } + + bool is_closed() const { + return m_root->get_status() == status::closed; + } + + std::ostream& display(std::ostream& out) const { + m_root->display(out, 0); + return out; + } + + }; +}