diff --git a/src/smt/theory_card.cpp b/src/smt/theory_card.cpp index 6f54fdd4f..3c4cb21f4 100644 --- a/src/smt/theory_card.cpp +++ b/src/smt/theory_card.cpp @@ -15,11 +15,29 @@ Author: Notes: - - count number of clauses per cardinality constraint. - - when number of conflicts exceeds n^2 or n*log(n), then create a sorting circuit. - where n is the arity of the cardinality constraint. - - extra: do clauses get re-created? keep track of gc status of created clauses. + - Uses cutting plane simplification on 'k' for repeated literals. + In other words, if the gcd of the multiplicity of literals in c3 + is g, then divide through by g and truncate k. + + Example: + ((_ at-most 3) x1 x1 x2 x2) == ((_ at-most 1) x1 x2) + - count number of clauses per cardinality constraint. + + - TBD: when number of conflicts exceeds n^2 or n*log(n), + then create a sorting circuit. + where n is the arity of the cardinality constraint. + + - TBD: do clauses get re-created? keep track of gc + status of created clauses. + + - TBD: add conflict resolution + The idea is that if cardinality constraints c1, c2 + are repeatedly asserted together, then + resolve them into combined cardinality constraint c3 + + c1 /\ c2 -> c3 + --*/ #include "theory_card.h" diff --git a/src/test/sorting_network.cpp b/src/test/sorting_network.cpp index 904bbb970..a6ec9e5e2 100644 --- a/src/test/sorting_network.cpp +++ b/src/test/sorting_network.cpp @@ -2,6 +2,9 @@ #include "sorting_network.h" #include "vector.h" #include "ast.h" +#include "ast_pp.h" +#include "reg_decl_plugins.h" + struct ast_ext { ast_manager& m; @@ -38,23 +41,100 @@ struct unsigned_ext { } }; -void tst_sorting_network() { - svector vec; +static void is_sorted(svector const& v) { + for (unsigned i = 0; i + 1 < v.size(); ++i) { + SASSERT(v[i] <= v[i+1]); + } +} + +static void test_sorting1() { + svector in, out; unsigned_ext uext; - sorting_network sn(uext, vec); + sorting_network sn(uext); - svector in1; - in1.push_back(0); - in1.push_back(1); - in1.push_back(0); - in1.push_back(1); - in1.push_back(1); - in1.push_back(0); + in.push_back(0); + in.push_back(1); + in.push_back(0); + in.push_back(1); + in.push_back(1); + in.push_back(0); - sn(in1); + sn(in, out); - for (unsigned i = 0; i < vec.size(); ++i) { - std::cout << vec[i]; + is_sorted(out); + for (unsigned i = 0; i < out.size(); ++i) { + std::cout << out[i]; } std::cout << "\n"; } + +static void test_sorting2() { + svector in, out; + unsigned_ext uext; + sorting_network sn(uext); + + in.push_back(0); + in.push_back(1); + in.push_back(2); + in.push_back(1); + in.push_back(1); + in.push_back(3); + + sn(in, out); + + is_sorted(out); + + for (unsigned i = 0; i < out.size(); ++i) { + std::cout << out[i]; + } + std::cout << "\n"; +} + +static void test_sorting4_r(unsigned i, svector& in) { + if (i == in.size()) { + svector out; + unsigned_ext uext; + sorting_network sn(uext); + sn(in, out); + is_sorted(out); + std::cout << "sorted\n"; + } + else { + in[i] = 0; + test_sorting4_r(i+1, in); + in[i] = 1; + test_sorting4_r(i+1, in); + } +} + +static void test_sorting4() { + svector in; + in.resize(5); + test_sorting4_r(0, in); +} + +void test_sorting3() { + ast_manager m; + reg_decl_plugins(m); + expr_ref_vector in(m), out(m); + for (unsigned i = 0; i < 7; ++i) { + in.push_back(m.mk_fresh_const("a",m.mk_bool_sort())); + } + for (unsigned i = 0; i < in.size(); ++i) { + std::cout << mk_pp(in[i].get(), m) << "\n"; + } + ast_ext aext(m); + sorting_network sn(aext); + sn(in, out); + std::cout << "size: " << out.size() << "\n"; + for (unsigned i = 0; i < out.size(); ++i) { + std::cout << mk_pp(out[i].get(), m) << "\n"; + } +} + +void tst_sorting_network() { + test_sorting1(); + test_sorting2(); + test_sorting3(); + test_sorting4(); +} diff --git a/src/util/sorting_network.h b/src/util/sorting_network.h index d823cf44c..731bda8a9 100644 --- a/src/util/sorting_network.h +++ b/src/util/sorting_network.h @@ -7,7 +7,7 @@ template class sorting_network { - typename Ext::vector& m_es; + typedef typename Ext::vector vect; Ext& m_ext; svector m_currentv; svector m_nextv; @@ -17,21 +17,21 @@ unsigned& current(unsigned i) { return (*m_current)[i]; } unsigned& next(unsigned i) { return (*m_next)[i]; } - void exchange(unsigned i, unsigned j) { + void exchange(unsigned i, unsigned j, vect& out) { SASSERT(i <= j); if (i < j) { - Ext::T ei = m_es.get(i); - Ext::T ej = m_es.get(j); - m_es.set(i, m_ext.mk_ite(m_ext.mk_le(ei, ej), ei, ej)); - m_es.set(j, m_ext.mk_ite(m_ext.mk_le(ej, ei), ei, ej)); + Ext::T ei = out.get(i); + Ext::T ej = out.get(j); + out.set(i, m_ext.mk_ite(m_ext.mk_le(ei, ej), ei, ej)); + out.set(j, m_ext.mk_ite(m_ext.mk_le(ej, ei), ei, ej)); } } - void sort(unsigned k) { + void sort(unsigned k, vect& out) { SASSERT(is_power_of2(k) && k > 0); if (k == 2) { - for (unsigned i = 0; i < m_es.size()/2; ++i) { - exchange(current(2*i), current(2*i+1)); + for (unsigned i = 0; i < out.size()/2; ++i) { + exchange(current(2*i), current(2*i+1), out); next(2*i) = current(2*i); next(2*i+1) = current(2*i+1); } @@ -39,7 +39,7 @@ } else { - for (unsigned i = 0; i < m_es.size()/k; ++i) { + for (unsigned i = 0; i < out.size()/k; ++i) { unsigned ki = k * i; for (unsigned j = 0; j < k / 2; ++j) { next(ki + j) = current(ki + (2 * j)); @@ -48,8 +48,8 @@ } std::swap(m_current, m_next); - sort(k / 2); - for (unsigned i = 0; i < m_es.size() / k; ++i) { + sort(k / 2, out); + for (unsigned i = 0; i < out.size() / k; ++i) { unsigned ki = k * i; for (unsigned j = 0; j < k / 2; ++j) { next(ki + (2 * j)) = current(ki + j); @@ -57,7 +57,7 @@ } for (unsigned j = 0; j < (k / 2) - 1; ++j) { - exchange(next(ki + (2 * j) + 1), next(ki + (2 * (j + 1)))); + exchange(next(ki + (2 * j) + 1), next(ki + (2 * (j + 1))), out); } } std::swap(m_current, m_next); @@ -69,28 +69,28 @@ } public: - sorting_network(Ext& ext, typename Ext::vector& es): + sorting_network(Ext& ext): m_ext(ext), - m_es(es), m_current(&m_currentv), m_next(&m_nextv) {} - void operator()(typename Ext::vector const& inputs) { - if (inputs.size() <= 1) { + void operator()(vect const& in, vect& out) { + if (in.size() <= 1) { return; } - m_es.reset(); - m_es.append(inputs); - while (!is_power_of2(m_es.size())) { - m_es.push_back(m_ext.mk_default()); + out.reset(); + out.append(in); + while (!is_power_of2(out.size())) { + out.push_back(m_ext.mk_default()); } - for (unsigned i = 0; i < m_es.size(); ++i) { - current(i) = i; + for (unsigned i = 0; i < out.size(); ++i) { + m_currentv.push_back(i); + m_nextv.push_back(i); } unsigned k = 2; - while (k <= m_es.size()) { - sort(k); + while (k <= out.size()) { + sort(k, out); k *= 2; } }