diff --git a/src/opt/CMakeLists.txt b/src/opt/CMakeLists.txt index 9c20b7d2d..21075d88c 100644 --- a/src/opt/CMakeLists.txt +++ b/src/opt/CMakeLists.txt @@ -14,6 +14,7 @@ z3_add_component(opt opt_solver.cpp pb_sls.cpp sortmax.cpp + totalizer.cpp wmax.cpp COMPONENT_DEPENDENCIES sat_solver diff --git a/src/opt/totalizer.cpp b/src/opt/totalizer.cpp new file mode 100644 index 000000000..fee66e5d7 --- /dev/null +++ b/src/opt/totalizer.cpp @@ -0,0 +1,122 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + totalizer.cpp + +Abstract: + + Incremental totalizer for at least constraints + +Author: + + Nikolaj Bjorner (nbjorner) 2022-06-27 + +--*/ + +#include "opt/totalizer.h" +#include "ast/ast_util.h" +#include "ast/ast_pp.h" +#include + +namespace opt { + + + void totalizer::ensure_bound(node* n, unsigned k) { + auto& lits = n->m_literals; + if (k > lits.size()) + return; + auto* l = n->m_left; + auto* r = n->m_right; + if (l) + ensure_bound(l, k); + if (r) + ensure_bound(r, k); + + for (unsigned i = k; i > 0 && !lits.get(i - 1); --i) { + if (l->m_literals.size() + r->m_literals.size() < i) { + lits[i - 1] = m.mk_false(); + continue; + } + + expr* c = m.mk_fresh_const("c", m.mk_bool_sort()); + lits[i - 1] = c; + + // >= 3 + // r[2] => >= 3 + // l[0] & r[1] => >= 3 + // l[1] & r[0] => >= 3 + // l[2] => >= 3 + + for (unsigned j1 = 0; j1 <= i; ++j1) { + unsigned j2 = i - j1; + if (j1 > l->m_literals.size()) + continue; + if (j2 > r->m_literals.size()) + continue; + expr_ref_vector clause(m); + if (0 < j1) { + expr* a = l->m_literals.get(j1 - 1); + clause.push_back(mk_not(m, a)); + } + if (0 < j2) { + expr* b = r->m_literals.get(j2 - 1); + clause.push_back(mk_not(m, b)); + } + if (clause.empty()) + continue; + clause.push_back(c); + m_clauses.push_back(clause); + } + } + } + + totalizer::totalizer(expr_ref_vector const& literals): + m(literals.m()), + m_literals(literals), + m_tree(nullptr) { + ptr_vector trees; + for (expr* e : literals) { + expr_ref_vector ls(m); + ls.push_back(e); + trees.push_back(alloc(node, ls)); + } + for (unsigned i = 0; i + 1 < trees.size(); i += 2) { + node* left = trees[i]; + node* right = trees[i + 1]; + expr_ref_vector ls(m); + ls.resize(left->m_literals.size() + right->m_literals.size()); + node* n = alloc(node, ls); + n->m_left = left; + n->m_right = right; + trees.push_back(n); + } + m_tree = trees.back(); + } + + totalizer::~totalizer() { + ptr_vector trees; + trees.push_back(m_tree); + while (!trees.empty()) { + node* n = trees.back(); + trees.pop_back(); + if (n->m_left) + trees.push_back(n->m_left); + if (n->m_right) + trees.push_back(n->m_right); + dealloc(n); + } + } + + expr* totalizer::at_least(unsigned k) { + if (k == 0) + return m.mk_true(); + if (m_tree->m_literals.size() < k) + return m.mk_false(); + SASSERT(1 <= k && k <= m_tree->m_literals.size()); + ensure_bound(m_tree, k); + return m_tree->m_literals.get(k - 1); + } + +} diff --git a/src/opt/totalizer.h b/src/opt/totalizer.h new file mode 100644 index 000000000..e68ac81b5 --- /dev/null +++ b/src/opt/totalizer.h @@ -0,0 +1,44 @@ +/*++ +Copyright (c) 2022 Microsoft Corporation + +Module Name: + + totalizer.h + +Abstract: + + Incremental totalizer for at least constraints + +Author: + + Nikolaj Bjorner (nbjorner) 2022-06-27 + +--*/ + +#pragma once +#include "ast/ast.h" + +namespace opt { + + class totalizer { + struct node { + node* m_left = nullptr; + node* m_right = nullptr; + expr_ref_vector m_literals; + node(expr_ref_vector& l): m_literals(l) {} + }; + + ast_manager& m; + expr_ref_vector m_literals; + node* m_tree; + vector m_clauses; + + void ensure_bound(node* n, unsigned k); + + public: + totalizer(expr_ref_vector const& literals); + ~totalizer(); + expr* at_least(unsigned k); + vector& clauses() { return m_clauses; } + }; +} diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 500cb4258..f959e9bd5 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -120,6 +120,7 @@ add_executable(test-z3 theory_pb.cpp timeout.cpp total_order.cpp + totalizer.cpp trigo.cpp udoc_relation.cpp uint_set.cpp diff --git a/src/test/main.cpp b/src/test/main.cpp index 6272c2dee..f9e4e0815 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -263,4 +263,5 @@ int main(int argc, char ** argv) { TST(solver_pool); //TST_ARGV(hs); TST(finder); + TST(totalizer); } diff --git a/src/test/totalizer.cpp b/src/test/totalizer.cpp new file mode 100644 index 000000000..13cebd7c7 --- /dev/null +++ b/src/test/totalizer.cpp @@ -0,0 +1,25 @@ +#include "opt/totalizer.h" +#include "ast/ast_pp.h" +#include "ast/reg_decl_plugins.h" +#include + +void tst_totalizer() { + std::cout << "totalizer\n"; + ast_manager m; + reg_decl_plugins(m); + expr_ref_vector lits(m); + for (unsigned i = 0; i < 5; ++i) + lits.push_back(m.mk_fresh_const("a", m.mk_bool_sort())); + opt::totalizer tot(lits); + + for (unsigned i = 0; i <= 6; ++i) { + std::cout << "at least " << i << " "; + expr* am = tot.at_least(i); + std::cout << mk_pp(am, m) << "\n"; + } + for (auto& clause : tot.clauses()) { + for (auto * l : clause) + std::cout << mk_pp(l, m) << " "; + std::cout << "\n"; + } +}