diff --git a/examples/userPropagator/CMakeLists.txt b/examples/userPropagator/CMakeLists.txt index 9ed916d46..384d257dc 100644 --- a/examples/userPropagator/CMakeLists.txt +++ b/examples/userPropagator/CMakeLists.txt @@ -24,7 +24,15 @@ message(STATUS "Z3_FOUND: ${Z3_FOUND}") message(STATUS "Found Z3 ${Z3_VERSION_STRING}") message(STATUS "Z3_DIR: ${Z3_DIR}") -add_executable(user_propagator_example example.cpp) +add_executable(user_propagator_example + example.cpp + common.h + user_propagator.h + user_propagator_with_theory.h + user_propagator_subquery_maximisation.h + user_propagator_internal_maximisation.h + user_propagator_created_maximisation.h) + target_include_directories(user_propagator_example PRIVATE ${Z3_CXX_INCLUDE_DIRS}) target_link_libraries(user_propagator_example PRIVATE ${Z3_LIBRARIES}) diff --git a/examples/userPropagator/common.h b/examples/userPropagator/common.h new file mode 100644 index 000000000..3c9f3b299 --- /dev/null +++ b/examples/userPropagator/common.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "z3++.h" + +using std::to_string; + +#define SIZE(x) std::extent::value + +// #define VERBOSE // Log events +#ifdef VERBOSE +#define WriteEmptyLine std::cout << std::endl +#define WriteLine(x) std::cout << (x) << std::endl +#define Write(x) std::cout << x +#else +#define WriteEmptyLine +#define WriteLine(x) +#define Write(x) +#endif + +int log2i(unsigned n) { + if (n <= 0) { + return 0; + } + if (n <= 2) { + return 1; + } + unsigned l = 1; + int i = 0; + while (l < n) { + l <<= 1; + i++; + } + return i; +} + +typedef std::vector simple_model; + +// For putting z3 expressions in hash-tables +namespace std { + + template<> + struct hash { + std::size_t operator()(const simple_model &m) const { + size_t hash = 0; + for (unsigned i = 0; i < m.size(); i++) { + hash *= m.size(); + hash += m[i]; + } + return hash; + } + }; + + template<> + struct hash { + std::size_t operator()(const z3::expr &k) const { + return k.hash(); + } + }; + + // Do not use Z3's == operator in the hash table + template<> + struct equal_to { + bool operator()(const z3::expr &lhs, const z3::expr &rhs) const { + return z3::eq(lhs, rhs); + } + }; +} \ No newline at end of file diff --git a/examples/userPropagator/example.cpp b/examples/userPropagator/example.cpp index 1b6888798..9c3e7cf3e 100644 --- a/examples/userPropagator/example.cpp +++ b/examples/userPropagator/example.cpp @@ -1,13 +1,8 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "z3++.h" +#include "user_propagator.h" +#include "user_propagator_with_theory.h" +#include "user_propagator_subquery_maximisation.h" +#include "user_propagator_internal_maximisation.h" +#include "user_propagator_created_maximisation.h" /** * The program solves the n-queens problem (number of solutions) with 4 different approaches @@ -20,238 +15,57 @@ */ using namespace std::chrono; -using std::to_string; -#define QUEEN #define REPETITIONS 5 -#define SIZE(x) std::extent::value +#define MIN_BOARD 4 +#define MAX_BOARD1 12 +#define MAX_BOARD2 12 -#ifdef LOG -#define WriteEmptyLine std::cout << std::endl -#define WriteLine(x) std::cout << (x) << std::endl -#define Write(x) std::cout << x -#else -#define WriteEmptyLine -#define WriteLine(x) -#define Write(x) -#endif - -typedef std::vector model; - -struct model_hash_function { - std::size_t operator()(const model &m) const { - size_t hash = 0; - for (unsigned i = 0; i < m.size(); i++) { - hash *= m.size(); - hash += m[i]; - } - return hash; - } -}; - -namespace std { - - template<> - struct hash { - std::size_t operator()(const z3::expr &k) const { - return k.hash(); - } - }; -} - -// Do not use Z3's == operator in the hash table -namespace std { - - template<> - struct equal_to { - bool operator()(const z3::expr &lhs, const z3::expr &rhs) const { - return z3::eq(lhs, rhs); - } - }; -} - -class user_propagator : public z3::user_propagator_base { - -protected: - - unsigned board; - std::unordered_map& id_mapping; - model currentModel; - std::unordered_set modelSet; - std::vector fixedValues; - std::stack fixedCnt; - - int solutionId = 1; - -public: - - int getModelCount() const { - return solutionId - 1; - } - - void final() final { - z3::expr_vector conflicting(fixedValues[0].ctx()); - for (auto&& v : fixedValues) - conflicting.push_back(v); - this->conflict(conflicting); - if (modelSet.find(currentModel) != modelSet.end()) { - WriteLine("Got already computed model"); - return; - } - Write("Model #" << solutionId << ":\n"); - solutionId++; -#ifdef LOG - for (unsigned i = 0; i < fixedValues.size(); i++) { - unsigned id = fixedValues[i]; - WriteLine("q" + to_string(id_mapping[id]) + " = " + to_string(currentModel[id])); - } -#endif - modelSet.insert(currentModel); - WriteEmptyLine; - } - - static unsigned bvToInt(z3::expr e) { - return (unsigned)e.get_numeral_int(); - } - - void fixed(z3::expr const &ast, z3::expr const &value) override { - fixedValues.push_back(ast); - unsigned valueBv = bvToInt(value); - currentModel[id_mapping[ast]] = valueBv; - } - - user_propagator(z3::solver *s, std::unordered_map& idMapping, unsigned board) - : user_propagator_base(s), board(board), id_mapping(idMapping), currentModel(board, (unsigned)-1) { - - this->register_fixed(); - this->register_final(); - } - - ~user_propagator() = default; - - void push() override { - fixedCnt.push((unsigned) fixedValues.size()); - } - - void pop(unsigned num_scopes) override { - for (unsigned i = 0; i < num_scopes; i++) { - unsigned lastCnt = fixedCnt.top(); - fixedCnt.pop(); - for (auto j = fixedValues.size(); j > lastCnt; j--) { - currentModel[fixedValues[j - 1]] = (unsigned)-1; - } - fixedValues.erase(fixedValues.cbegin() + lastCnt, fixedValues.cend()); - } - } - - user_propagator_base *fresh(Z3_context) override { - return this; - } -}; - -class user_propagator_with_theory : public user_propagator { - -public: - - void fixed(z3::expr const &ast, z3::expr const &value) override { - unsigned queenId = id_mapping[ast]; - unsigned queenPos = bvToInt(value); - - if (queenPos >= board) { - z3::expr_vector conflicting(ast.ctx()); - conflicting.push_back(ast); - this->conflict(conflicting); - return; - } - - for (z3::expr fixed : fixedValues) { - unsigned otherId = id_mapping[fixed]; - unsigned otherPos = currentModel[fixed]; - - if (queenPos == otherPos) { - z3::expr_vector conflicting(ast.ctx()); - conflicting.push_back(ast); - conflicting.push_back(fixed); - this->conflict(conflicting); - continue; - } -#ifdef QUEEN - int diffY = abs((int)queenId - (int)otherId); - int diffX = abs((int)queenPos - (int)otherPos); - if (diffX == diffY) { - z3::expr_vector conflicting(ast.ctx()); - conflicting.push_back(ast); - conflicting.push_back(fixed); - this->conflict(conflicting); - } -#endif - } - - fixedValues.push_back(ast); - currentModel[id_mapping[ast]] = queenPos; - } - - user_propagator_with_theory(z3::solver *s, std::unordered_map& idMapping, unsigned board) - : user_propagator(s, idMapping, board) {} -}; - -int log2i(unsigned n) { - if (n <= 0) { - return 0; - } - if (n <= 2) { - return 1; - } - unsigned l = 1; - int i = 0; - while (l < n) { - l <<= 1; - i++; - } - return i; -} - -std::vector createQueens(z3::context &context, unsigned num) { - std::vector queens; - int bits = log2i(num) + 1 /*to detect potential overflow in the diagonal*/; +z3::expr_vector createQueens(z3::context &context, unsigned num, int bits, std::string prefix) { + z3::expr_vector queens(context); for (unsigned i = 0; i < num; i++) { - queens.push_back(context.bv_const((std::string("q") + to_string(i)).c_str(), bits)); + queens.push_back(context.bv_const((prefix + "q" + to_string(i)).c_str(), bits)); } return queens; } -void createConstraints(z3::context &context, z3::solver &solver, const std::vector &queens) { +z3::expr_vector createQueens(z3::context &context, unsigned num) { + return createQueens(context, num, log2i(num) + 1, ""); +} + +z3::expr createConstraints(z3::context &context, const z3::expr_vector &queens) { + z3::expr_vector assertions(context); for (unsigned i = 0; i < queens.size(); i++) { // assert column range - solver.add(z3::uge(queens[i], 0)); - solver.add(z3::ule(queens[i], (int) (queens.size() - 1))); + assertions.push_back(z3::uge(queens[i], 0)); + assertions.push_back(z3::ule(queens[i], (int) (queens.size() - 1))); } z3::expr_vector distinct(context); - for (const z3::expr &queen : queens) { + for (const z3::expr &queen: queens) { distinct.push_back(queen); } - solver.add(z3::distinct(distinct)); + assertions.push_back(z3::distinct(distinct)); -#ifdef QUEEN for (unsigned i = 0; i < queens.size(); i++) { for (unsigned j = i + 1; j < queens.size(); j++) { - solver.add((int)(j - i) != (queens[j] - queens[i])); - solver.add((int)(j - i) != (queens[i] - queens[j])); + assertions.push_back((int) (j - i) != (queens[j] - queens[i])); + assertions.push_back((int) (j - i) != (queens[i] - queens[j])); } } -#endif + + return z3::mk_and(assertions); } int test01(unsigned num, bool simple) { z3::context context; z3::solver solver(context, !simple ? Z3_mk_solver(context) : Z3_mk_simple_solver(context)); - std::vector queens = createQueens(context, num); + z3::expr_vector queens = createQueens(context, num); - createConstraints(context, solver, queens); + solver.add(createConstraints(context, queens)); int solutionId = 1; @@ -292,7 +106,7 @@ inline int test1(unsigned num) { int test23(unsigned num, bool withTheory) { z3::context context; - z3::solver solver(context, Z3_mk_simple_solver(context)); + z3::solver solver(context, z3::solver::simple()); std::unordered_map idMapping; user_propagator *propagator; @@ -303,7 +117,7 @@ int test23(unsigned num, bool withTheory) { propagator = new user_propagator_with_theory(&solver, idMapping, num); } - std::vector queens = createQueens(context, num); + z3::expr_vector queens = createQueens(context, num); for (unsigned i = 0; i < queens.size(); i++) { propagator->add(queens[i]); @@ -311,7 +125,7 @@ int test23(unsigned num, bool withTheory) { } if (!withTheory) { - createConstraints(context, solver, queens); + solver.add(createConstraints(context, queens)); } solver.check(); @@ -328,50 +142,246 @@ inline int test3(unsigned num) { return test23(num, true); } +int test4(unsigned num) { + z3::context context; + z3::solver solver(context, z3::solver::simple()); + + z3::expr_vector queens1 = createQueens(context, num, log2i(num * num), ""); // square to avoid overflow during summation + + z3::expr valid1 = createConstraints(context, queens1); + + z3::expr_vector queens2 = createQueens(context, num, log2i(num * num), "forall_"); + + z3::expr valid2 = createConstraints(context, queens2); + + z3::expr manhattanSum1 = context.bv_val(0, queens1[0].get_sort().bv_size()); + z3::expr manhattanSum2 = context.bv_val(0, queens2[0].get_sort().bv_size()); + + for (int i = 1; i < queens1.size(); i++) { + manhattanSum1 = manhattanSum1 + z3::ite(z3::uge(queens1[i], queens1[i - 1]), queens1[i] - queens1[i - 1], queens1[i - 1] - queens1[i]); + manhattanSum2 = manhattanSum2 + z3::ite(z3::uge(queens2[i], queens2[i - 1]), queens2[i] - queens2[i - 1], queens2[i - 1] - queens2[i]); + } + + + solver.add(valid1 && z3::forall(queens2, z3::implies(valid2, manhattanSum1 >= manhattanSum2))); + + solver.check(); + z3::model model = solver.get_model(); + + int max = 0; + + int prev, curr; + curr = model.eval(queens1[0]).get_numeral_int(); + + for (unsigned i = 1; i < num; i++) { + prev = curr; + curr = model.eval(queens1[i]).get_numeral_int(); + max += abs(curr - prev); + } + + return max; +} + +int test5(unsigned num) { + z3::context context; + z3::solver solver(context, z3::solver::simple()); + std::unordered_map idMapping; + + z3::expr_vector queens = createQueens(context, num, log2i(num * num), ""); + + solver.add(createConstraints(context, queens)); + + user_propagator_subquery_maximisation propagator(&solver, idMapping, num, queens); + + for (unsigned i = 0; i < queens.size(); i++) { + propagator.add(queens[i]); + idMapping[queens[i]] = i; + } + + solver.check(); + z3::model model = solver.get_model(); + + int max = 0; + + int prev, curr; + curr = model.eval(queens[0]).get_numeral_int(); + for (unsigned i = 1; i < num; i++) { + prev = curr; + curr = model.eval(queens[i]).get_numeral_int(); + max += abs(curr - prev); + } + + return max; +} + +int test6(unsigned num) { + z3::context context; + z3::solver solver(context, z3::solver::simple()); + std::unordered_map idMapping; + + z3::expr_vector queens = createQueens(context, num, log2i(num * num), ""); + + solver.add(createConstraints(context, queens)); + + user_propagator_internal_maximisation propagator(&solver, idMapping, num, queens); + + for (unsigned i = 0; i < queens.size(); i++) { + propagator.add(queens[i]); + idMapping[queens[i]] = i; + } + + solver.check(); + return propagator.best; +} + +int test7(unsigned num) { + z3::context context; + z3::solver solver(context, z3::solver::simple()); + + z3::expr_vector queens1 = createQueens(context, num, log2i(num * num), ""); + z3::expr_vector queens2 = createQueens(context, num, log2i(num * num), "forall_"); + + z3::expr manhattanSum1 = context.bv_val(0, queens1[0].get_sort().bv_size()); + z3::expr manhattanSum2 = context.bv_val(0, queens2[0].get_sort().bv_size()); + + for (int i = 1; i < queens1.size(); i++) { + manhattanSum1 = manhattanSum1 + z3::ite(z3::uge(queens1[i], queens1[i - 1]), queens1[i] - queens1[i - 1], queens1[i - 1] - queens1[i]); + manhattanSum2 = manhattanSum2 + z3::ite(z3::uge(queens2[i], queens2[i - 1]), queens2[i] - queens2[i - 1], queens2[i - 1] - queens2[i]); + } + + z3::sort_vector domain(context); + for (int i = 0; i < queens1.size(); i++) { + domain.push_back(queens1[i].get_sort()); + } + z3::func_decl validFunc = context.user_propagate_function(context.str_symbol("valid"), domain, context.bool_sort()); + + solver.add(validFunc(queens1) && z3::forall(queens2, z3::implies(validFunc(queens2), manhattanSum1 >= manhattanSum2))); + user_propagator_created_maximisation propagator(&solver, num); + + solver.check(); + z3::model model = solver.get_model(); + + int max = 0; + + int prev, curr; + curr = model.eval(queens1[0]).get_numeral_int(); + + for (unsigned i = 1; i < num; i++) { + prev = curr; + curr = model.eval(queens1[i]).get_numeral_int(); + max += abs(curr - prev); + } + + return max; +} + int main() { - for (int num = 4; num <= 11; num++) { + for (int num = MIN_BOARD; num <= MAX_BOARD1; num++) { + + std::cout << "num = " << num << ":\n" << std::endl; + + unsigned seed = (unsigned) high_resolution_clock::now().time_since_epoch().count(); + const char *testName[] = + { + "BV + Blocking clauses (Default solver)", + "BV + Blocking clauses (Simple solver)", + "BV + Adding conflicts", + "Custom theory + conflicts", + }; + int permutation[4] = {0, 1, 2, 3,}; + double timeResults[REPETITIONS * SIZE(permutation)]; + + for (int rep = 0; rep < REPETITIONS; rep++) { + // Execute strategies in a randomised order + std::shuffle(&permutation[0], &permutation[SIZE(permutation) - 1], std::default_random_engine(seed)); + + for (int i : permutation) { + int modelCount = -1; + + auto now1 = high_resolution_clock::now(); + + switch (i) { + case 0: + modelCount = test0(num); + break; + case 1: + modelCount = test1(num); + break; + case 2: + modelCount = test2(num); + break; + case 3: + modelCount = test3(num); + break; + default: + WriteLine("Unknown case"); + break; + } + auto now2 = high_resolution_clock::now(); + duration ms = now2 - now1; + std::cout << testName[i] << " took " << ms.count() << "ms (" << modelCount << " models)" << std::endl; + timeResults[rep * SIZE(permutation) + i] = ms.count(); + WriteLine("-------------"); + } + } + + std::cout << "\n" << std::endl; + + for (unsigned i = 0; i < SIZE(permutation); i++) { + std::cout << testName[i]; + double sum = 0; + for (int j = 0; j < REPETITIONS; j++) { + std::cout << " " << timeResults[j * SIZE(permutation) + i] << "ms"; + sum += timeResults[j * SIZE(permutation) + i]; + } + std::cout << " | avg: " << sum / REPETITIONS << "ms" << std::endl; + } + + std::cout << std::endl; + } + + z3::set_param("smt.ematching", "false"); + z3::set_param("smt.mbqi", "true"); + + std::cout << "\nMaximal distance:" << std::endl; + + for (int num = MIN_BOARD; num <= MAX_BOARD2; num++) { std::cout << "num = " << num << ":\n" << std::endl; unsigned seed = (unsigned) high_resolution_clock::now().time_since_epoch().count(); const char *testName[] = { - "BV + Blocking clauses (Default solver)", - "BV + Blocking clauses (Simple solver)", - "BV + Adding conflicts", - "Custom theory + conflicts", - }; - int permutation[4] = - { - 0, - 1, - 2, - 3, + "Ordinary/Direct Encoding", + "SubQuery in final", + "Assert Smaller in final", + "created", }; + int permutation[4] = {0, 1, 2, 3,}; double timeResults[REPETITIONS * SIZE(permutation)]; for (int rep = 0; rep < REPETITIONS; rep++) { // Execute strategies in a randomised order std::shuffle(&permutation[0], &permutation[SIZE(permutation) - 1], std::default_random_engine(seed)); - for (int i : permutation) { - int modelCount = -1; + for (int i: permutation) { + int max = -1; auto now1 = high_resolution_clock::now(); - switch (i) { - case 0: - modelCount = test0(num); + switch (i + 4) { + case 4: + max = test4(num); break; - case 1: - modelCount = test1(num); + case 5: + max = test5(num); break; - case 2: - modelCount = test2(num); + case 6: + max = test6(num); break; - case 3: - modelCount = test3(num); + case 7: + max = test7(num); break; default: WriteLine("Unknown case"); @@ -379,7 +389,7 @@ int main() { } auto now2 = high_resolution_clock::now(); duration ms = now2 - now1; - std::cout << testName[i] << " took " << ms.count() << "ms (" << modelCount << " models)" << std::endl; + std::cout << testName[i] << " took " << ms.count() << "ms. Max: " << max << std::endl; timeResults[rep * SIZE(permutation) + i] = ms.count(); WriteLine("-------------"); } @@ -399,4 +409,4 @@ int main() { std::cout << std::endl; } -} +} \ No newline at end of file diff --git a/examples/userPropagator/example.pdf b/examples/userPropagator/example.pdf index 2e802259c..eaf3c9952 100644 Binary files a/examples/userPropagator/example.pdf and b/examples/userPropagator/example.pdf differ diff --git a/examples/userPropagator/user_propagator.h b/examples/userPropagator/user_propagator.h new file mode 100644 index 000000000..6c12ee2f3 --- /dev/null +++ b/examples/userPropagator/user_propagator.h @@ -0,0 +1,87 @@ +#pragma once + +#include "common.h" + +class user_propagator : public z3::user_propagator_base { + +protected: + + unsigned board; + std::unordered_map &queenToY; + simple_model currentModel; + std::unordered_set modelSet; + z3::expr_vector fixedValues; + std::stack fixedCnt; + + int solutionNr = 1; + +public: + + int getModelCount() const { + return solutionNr - 1; + } + + void final() override { + this->conflict(fixedValues); + if (modelSet.find(currentModel) != modelSet.end()) { + WriteLine("Got already computed model"); + return; + } + Write("Model #" << solutionNr << ":\n"); + solutionNr++; +#ifdef VERBOSE + for (unsigned i = 0; i < fixedValues.size(); i++) { + z3::expr fixed = fixedValues[i]; + WriteLine("q" + to_string(queenToY[fixed]) + " = " + to_string(currentModel[queenToY[fixed]])); + } +#endif + modelSet.insert(currentModel); + WriteEmptyLine; + } + + static unsigned bvToInt(z3::expr const &e) { + return (unsigned) e.get_numeral_int(); + } + + void fixed(z3::expr const &ast, z3::expr const &value) override { + fixedValues.push_back(ast); + unsigned valueBv = bvToInt(value); + currentModel[queenToY[ast]] = valueBv; + } + + user_propagator(z3::context &c, std::unordered_map &queenToY, unsigned board) + : user_propagator_base(c), board(board), queenToY(queenToY), fixedValues(c), currentModel(board, (unsigned) -1) { + + this->register_fixed(); + this->register_final(); + } + + user_propagator(z3::solver *s, std::unordered_map &idMapping, unsigned board) + : user_propagator_base(s), board(board), queenToY(idMapping), fixedValues(s->ctx()), currentModel(board, (unsigned) -1) { + + this->register_fixed(); + this->register_final(); + } + + ~user_propagator() = default; + + void push() override { + fixedCnt.push((unsigned) fixedValues.size()); + } + + void pop(unsigned num_scopes) override { + for (unsigned i = 0; i < num_scopes; i++) { + unsigned lastCnt = fixedCnt.top(); + fixedCnt.pop(); + // Remove fixed values from model + for (unsigned j = fixedValues.size(); j > lastCnt; j--) { + currentModel[queenToY[fixedValues[j - 1]]] = (unsigned) -1; + } + fixedValues.resize(lastCnt); + } + } + + user_propagator_base *fresh(z3::context &) override { + return this; + } +}; \ No newline at end of file diff --git a/examples/userPropagator/user_propagator_created_maximisation.h b/examples/userPropagator/user_propagator_created_maximisation.h new file mode 100644 index 000000000..7ef93b8fe --- /dev/null +++ b/examples/userPropagator/user_propagator_created_maximisation.h @@ -0,0 +1,338 @@ +#pragma once + +#include "common.h" + +class user_propagator_created_maximisation : public z3::user_propagator_base { + + + std::unordered_map argToFcts; + std::unordered_map fctToArgs; + + std::unordered_map currentModel; + z3::expr_vector fixedValues; + std::vector fixedCnt; + + user_propagator_created_maximisation* childPropagator = nullptr; + user_propagator_created_maximisation* parentPropagator = nullptr; + + int board; + int nesting; // Just for logging (0 ... main solver; 1 ... sub-solver) + +public: + + user_propagator_created_maximisation(z3::context &c, user_propagator_created_maximisation* parentPropagator, unsigned board, int nesting) : + z3::user_propagator_base(c), fixedValues(c), parentPropagator(parentPropagator), board(board), nesting(nesting) { + + this->register_fixed(); + this->register_final(); + this->register_created(); + } + + user_propagator_created_maximisation(z3::solver *s, unsigned board) : + z3::user_propagator_base(s), fixedValues(s->ctx()), board(board), nesting(0) { + + this->register_fixed(); + this->register_final(); + this->register_created(); + } + + ~user_propagator_created_maximisation() { + delete childPropagator; + } + + void final() override { + WriteLine("Final (" + to_string(nesting) + ")"); + } + + void push() override { + WriteLine("Push (" + to_string(nesting) + ")"); + fixedCnt.push_back((unsigned) fixedValues.size()); + } + + void pop(unsigned num_scopes) override { + WriteLine("Pop (" + to_string(nesting) + ")"); + for (unsigned i = 0; i < num_scopes; i++) { + unsigned lastCnt = fixedCnt.back(); + fixedCnt.pop_back(); + for (auto j = fixedValues.size(); j > lastCnt; j--) { + currentModel.erase(fixedValues[j - 1]); + } + fixedValues.resize(lastCnt); + } + } + + void checkValidPlacement(std::vector &conflicts, const z3::expr &fct, const z3::expr_vector &args, const std::vector &argValues, int pos) { + unsigned queenId = pos; + unsigned queenPos = argValues[pos]; + z3::expr queenPosExpr = args[pos]; + + if (queenPos >= board) { + z3::expr_vector conflicting(ctx()); + conflicting.push_back(fct); + conflicting.push_back(queenPosExpr); + conflicts.push_back(conflicting); + return; + } + + for (unsigned otherId = 0; otherId < argValues.size(); otherId++) { + if (otherId == pos) + continue; + + unsigned otherPos = argValues[otherId]; + z3::expr otherPosExpr = args[otherId]; + + if (otherPos == (unsigned)-1) + continue; // We apparently do not have this value + + if (queenPos == otherPos) { + z3::expr_vector conflicting(ctx()); + conflicting.push_back(fct); + conflicting.push_back(queenPosExpr); + conflicting.push_back(otherPosExpr); + conflicts.push_back(conflicting); + } + int diffY = abs((int) queenId - (int) otherId); + int diffX = abs((int) queenPos - (int) otherPos); + if (diffX == diffY) { + z3::expr_vector conflicting(ctx()); + conflicting.push_back(fct); + conflicting.push_back(queenPosExpr); + conflicting.push_back(otherPosExpr); + conflicts.push_back(conflicting); + } + } + } + + unsigned getValues(const z3::expr &fct, std::vector &argValues) const { + z3::expr_vector args = fctToArgs.at(fct); + unsigned fixed = 0; + for (const z3::expr &arg: args) { + if (currentModel.contains(arg)) { + argValues.push_back(currentModel.at(arg)); + fixed++; + } + else + argValues.push_back((unsigned) -1); // no value so far + } + return fixed; + } + + + user_propagator_base *fresh(z3::context &ctx) override { + WriteLine("Fresh context"); + childPropagator = new user_propagator_created_maximisation(ctx, this, board, nesting + 1); + return childPropagator; + } + + void fixed(const z3::expr &expr, const z3::expr &value) override { + // Could be optimized! + WriteLine("Fixed (" + to_string(nesting) + ") " + expr.to_string() + " to " + value.to_string()); + unsigned v = value.is_true() ? 1 : (value.is_false() ? 0 : value.get_numeral_uint()); + currentModel[expr] = v; + fixedValues.push_back(expr); + + z3::expr_vector effectedFcts(ctx()); + bool fixedFct = fctToArgs.contains(expr); + + if (fixedFct) { + // fixed the value of a function + effectedFcts.push_back(expr); + } + else { + // fixed the value of a function's argument + effectedFcts = argToFcts.at(expr); + } + + for (const z3::expr& fct : effectedFcts) { + if (!currentModel.contains(fct)) + // we do not know yet whether to expect a valid or invalid placement + continue; + + std::vector values; + unsigned fixedArgsCnt = getValues(fct, values); + bool fctValue = currentModel[fct]; + z3::expr_vector args = fctToArgs.at(fct); + + if (!fctValue) { + // expect invalid placement ... + if (fixedArgsCnt != board) + // we expect an invalid placement, but not all queen positions have been placed yet + return; + std::vector conflicts; + for (unsigned i = 0; i < args.size(); i++) { + if (values[i] != (unsigned)-1) + checkValidPlacement(conflicts, expr, args, values, i); + } + + if (conflicts.empty()) { + // ... but we got a valid one + z3::expr_vector conflicting(ctx()); + conflicting.push_back(fct); + for (const z3::expr &arg: args) { + if (!arg.is_numeral()) + conflicting.push_back(arg); + } + this->conflict(conflicting); + } + else { + // ... and everything is fine; we have at least one conflict + } + } + else { + // expect valid placement ... + std::vector conflicts; + if (fixedFct){ + for (unsigned i = 0; i < args.size(); i++) { + if (values[i] != (unsigned)-1) // check all set queens + checkValidPlacement(conflicts, expr, args, values, i); + } + } + else { + for (unsigned i = 0; i < args.size(); i++) { + if (z3::eq(args[i], expr)) // only check newly fixed values + checkValidPlacement(conflicts, fct, args, values, i); + } + } + if (conflicts.size() > 0) { + // ... but we got an invalid one + for (const z3::expr_vector &conflicting: conflicts) + this->conflict(conflicting); + } + else { + // ... and everything is fine; no conflict + } + } + } + } + +// void fixed(const z3::expr &expr, const z3::expr &value) override { +// WriteLine("Fixed (" + to_string(nesting) + ") " + expr.to_string() + " to " + value.to_string()); +// unsigned v = value.is_true() ? 1 : (value.is_false() ? 0 : value.get_numeral_uint()); +// currentModel[expr] = v; +// fixedValues.push_back(expr); +// +// if (fctToArgs.contains(expr)) { +// // fixed the value of a function +// +// std::vector values; +// unsigned fixedArgsCnt = getValues(expr, values); +// +// if (!v && fixedArgsCnt != board) +// // we expect an invalid placement, but not all queen positions have been placed yet +// return; +// +// z3::expr_vector args = fctToArgs.at(expr); +// +// std::vector conflicts; +// for (unsigned i = 0; i < args.size(); i++) { +// if (values[i] != (unsigned)-1) +// checkValidPlacement(conflicts, expr, args, values, i); +// } +// if (v) { +// //we expected a valid queen placement +// if (conflicts.size() > 0) { +// // ... but we got an invalid one +// for (const z3::expr_vector &conflicting: conflicts) +// this->conflict(conflicting); +// } +// else { +// // everything fine; no conflict +// } +// } +// else { +// // we expect an invalid queen placement +// if (conflicts.empty()) { +// // ... but we got a valid one +// z3::expr_vector conflicting(ctx()); +// conflicting.push_back(expr); +// for (const z3::expr &arg: args) { +// if (!arg.is_numeral()) +// conflicting.push_back(arg); +// } +// this->conflict(conflicting); +// } +// else { +// // everything fine; we have at least one conflict +// } +// } +// } +// else { +// // fixed the value of a function argument +// +// z3::expr_vector effectedFcts = argToFcts.at(expr); +// +// for (const z3::expr& fct : effectedFcts) { +// if (!currentModel.contains(fct)) +// // we do not know yet whether to expect a valid or invalid placement +// continue; +// +// std::vector values; +// unsigned fixedArgsCnt = getValues(fct, values); +// bool fctValue = currentModel[fct]; +// z3::expr_vector args = fctToArgs.at(fct); +// +// if (!fctValue) { +// // expect invalid placement +// if (fixedArgsCnt != board) +// // we expect an invalid placement, but not all queen positions have been placed yet +// return; +// std::vector conflicts; +// for (unsigned i = 0; i < args.size(); i++) { +// if (values[i] != (unsigned)-1) +// checkValidPlacement(conflicts, expr, args, values, i); +// } +// +// if (conflicts.empty()) { +// // ... but we got a valid one +// z3::expr_vector conflicting(ctx()); +// conflicting.push_back(fct); +// for (const z3::expr &arg: args) { +// if (!arg.is_numeral()) +// conflicting.push_back(arg); +// } +// this->conflict(conflicting); +// } +// else { +// // everything fine; we have at least one conflict +// } +// } +// else { +// // expect valid placement +// std::vector conflicts; +// for (unsigned i = 0; i < args.size(); i++) { +// if (z3::eq(args[i], expr)) // only check newly fixed values +// checkValidPlacement(conflicts, fct, args, values, i); +// } +// if (conflicts.size() > 0) { +// // ... but we got an invalid one +// for (const z3::expr_vector &conflicting: conflicts) +// this->conflict(conflicting); +// } +// else { +// // everything fine; no conflict +// } +// } +// } +// } +// } + + void created(const z3::expr &func) override { + WriteLine("Created (" + to_string(nesting) + "): " + func.to_string()); + z3::expr_vector args = func.args(); + for (unsigned i = 0; i < args.size(); i++) { + z3::expr arg = args[i]; + + if (!arg.is_numeral()) { + WriteLine("Registered " + arg.to_string()); + this->add(arg); + } + else { + currentModel[arg] = arg.get_numeral_uint(); + // Skip registering as argument is a fixed BV; + } + + argToFcts.try_emplace(arg, ctx()).first->second.push_back(func); + } + fctToArgs.emplace(std::make_pair(func, args)); + } +}; \ No newline at end of file diff --git a/examples/userPropagator/user_propagator_internal_maximisation.h b/examples/userPropagator/user_propagator_internal_maximisation.h new file mode 100644 index 000000000..7a22270eb --- /dev/null +++ b/examples/userPropagator/user_propagator_internal_maximisation.h @@ -0,0 +1,30 @@ +#pragma once + +#include "user_propagator_with_theory.h" + +class user_propagator_internal_maximisation : public user_propagator_with_theory { + + z3::expr manhattanSum; + +public: + + int best = -1; + + user_propagator_internal_maximisation(z3::solver *s, std::unordered_map &idMapping, unsigned board, z3::expr_vector queens) + : user_propagator_with_theory(s, idMapping, board), + manhattanSum(s->ctx().bv_val(0, queens[0].get_sort().bv_size())) { + for (int i = 1; i < queens.size(); i++) { + manhattanSum = manhattanSum + z3::ite(z3::uge(queens[i], queens[i - 1]), queens[i] - queens[i - 1], queens[i - 1] - queens[i]); + } + } + + void final() override { + + int current = 0; + for (unsigned i = 1; i < board; i++) { + current += abs((signed) currentModel[i] - (signed) currentModel[i - 1]); + } + best = std::max(current, best); + this->propagate(z3::expr_vector(ctx()), z3::ugt(manhattanSum, best)); + } +}; \ No newline at end of file diff --git a/examples/userPropagator/user_propagator_subquery_maximisation.h b/examples/userPropagator/user_propagator_subquery_maximisation.h new file mode 100644 index 000000000..47382c435 --- /dev/null +++ b/examples/userPropagator/user_propagator_subquery_maximisation.h @@ -0,0 +1,51 @@ +#pragma once + +#include "user_propagator.h" + +class user_propagator_subquery_maximisation : public user_propagator { + + z3::expr assertion; + z3::expr_vector queens; + z3::expr manhattanSum; + +public: + + user_propagator_subquery_maximisation(z3::solver *s, std::unordered_map &idMapping, unsigned board, z3::expr_vector queens) + : user_propagator(s, idMapping, board), + assertion(mk_and(s->assertions())), + queens(queens), manhattanSum(s->ctx().bv_val(0, queens[0].get_sort().bv_size())) { + + for (int i = 1; i < queens.size(); i++) { + manhattanSum = manhattanSum + z3::ite(z3::uge(queens[i], queens[i - 1]), queens[i] - queens[i - 1], queens[i - 1] - queens[i]); + } + } + + void final() override { + + int max1 = 0; + for (unsigned i = 1; i < board; i++) { + max1 += abs((signed) currentModel[i] - (signed) currentModel[i - 1]); + } + z3::expr_vector vec(ctx()); + + int max2 = 0; + z3::solver subquery(ctx(), z3::solver::simple()); + + subquery.add(assertion); + subquery.add(z3::ugt(manhattanSum, max1)); + if (subquery.check() == z3::unsat) + return; // model is already maximal + + z3::model counterExample = subquery.get_model(); + + int prev, curr = -1; + + for (int i = 0; i < queens.size(); i++) { + prev = curr; + curr = counterExample.eval(queens[i]).get_numeral_int(); + if (i == 0) continue; + max2 += abs(curr - prev); + } + this->propagate(vec, z3::uge(manhattanSum, max2)); + } +}; \ No newline at end of file diff --git a/examples/userPropagator/user_propagator_with_theory.h b/examples/userPropagator/user_propagator_with_theory.h new file mode 100644 index 000000000..cd3e6f273 --- /dev/null +++ b/examples/userPropagator/user_propagator_with_theory.h @@ -0,0 +1,50 @@ +#pragma once + +#include "user_propagator.h" + +class user_propagator_with_theory : public user_propagator { + +public: + + user_propagator_with_theory(z3::context &c, std::unordered_map &idMapping, unsigned board) + : user_propagator(c, idMapping, board) {} + + user_propagator_with_theory(z3::solver *s, std::unordered_map &idMapping, unsigned board) + : user_propagator(s, idMapping, board) {} + + void fixed(z3::expr const &ast, z3::expr const &value) override { + unsigned queenId = queenToY[ast]; + unsigned queenPos = bvToInt(value); + + if (queenPos >= board) { + z3::expr_vector conflicting(ast.ctx()); + conflicting.push_back(ast); + this->conflict(conflicting); + return; + } + + for (const z3::expr &fixed: fixedValues) { + unsigned otherId = queenToY[fixed]; + unsigned otherPos = currentModel[queenToY[fixed]]; + + if (queenPos == otherPos) { + z3::expr_vector conflicting(ast.ctx()); + conflicting.push_back(ast); + conflicting.push_back(fixed); + this->conflict(conflicting); + continue; + } + int diffY = abs((int) queenId - (int) otherId); + int diffX = abs((int) queenPos - (int) otherPos); + if (diffX == diffY) { + z3::expr_vector conflicting(ast.ctx()); + conflicting.push_back(ast); + conflicting.push_back(fixed); + this->conflict(conflicting); + } + } + + fixedValues.push_back(ast); + currentModel[queenToY[ast]] = queenPos; + } +}; \ No newline at end of file