From 4a6053b289110ea67437451fdf974b79b36f2d62 Mon Sep 17 00:00:00 2001 From: Clemens Eisenhofer Date: Thu, 5 Jan 2023 18:02:21 +0100 Subject: [PATCH] Missing univariate for pseudo-inverse --- .../polysat/univariate/univariate_solver.cpp | 25 +++++++++++++++++++ .../polysat/univariate/univariate_solver.h | 8 +++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/math/polysat/univariate/univariate_solver.cpp b/src/math/polysat/univariate/univariate_solver.cpp index 73a2702ce..43a1c667a 100644 --- a/src/math/polysat/univariate/univariate_solver.cpp +++ b/src/math/polysat/univariate/univariate_solver.cpp @@ -231,6 +231,16 @@ namespace polysat { add(m.mk_eq(bv->mk_bv_not(mk_poly(in)), mk_poly(out)), sign, dep); } + void add_inv(univariate const& in, univariate const& out, bool sign, dep_t dep) override { + expr* input = mk_poly(in); + expr* output = mk_poly(out); + expr* parity = get_parity(in); + expr* one = bv->mk_numeral(1, bit_width); + + add(m.mk_eq(bv->mk_bv_shl(one, parity), bv->mk_bv_mul(input, output)), false, null_dep); + add(bv->mk_ule(output, bv->mk_bv_sub(bv->mk_bv_shl(one, bv->mk_bv_sub(bv->mk_numeral(bit_width, bit_width), parity)), one)), false, null_dep); // TODO: Depending on whether we want all pseudo-inverses + } + void add_ule_const(rational const& val, bool sign, dep_t dep) override { if (val == 0) add(m.mk_eq(x, mk_poly(val)), sign, dep); @@ -245,6 +255,21 @@ namespace polysat { void add_bit(unsigned idx, bool sign, dep_t dep) override { add(bv->mk_bit2bool(x, idx), sign, dep); } + + expr* get_parity(univariate const& in) override { + expr* v = mk_poly(in); + expr* parity = m.mk_fresh_const("parity", bv->mk_sort(bit_width)); + expr* parity_1 = bv->mk_bv_add(parity, bv->mk_numeral(1, bit_width)); + add(m.mk_ite( + m.mk_eq(v, bv->mk_numeral(0, bit_width)), + m.mk_eq(parity, bv->mk_numeral(bit_width, bit_width)), + m.mk_and( + m.mk_eq(bv->mk_bv_shl(bv->mk_bv_lshr(v, parity), parity), v), + m.mk_not(m.mk_eq(bv->mk_bv_shl(bv->mk_bv_lshr(v, parity_1), parity_1), v)) + ) + ), false, null_dep); + return parity; + } lbool check() override { return s->check_sat(); diff --git a/src/math/polysat/univariate/univariate_solver.h b/src/math/polysat/univariate/univariate_solver.h index a90ebe825..5bfe7324c 100644 --- a/src/math/polysat/univariate/univariate_solver.h +++ b/src/math/polysat/univariate/univariate_solver.h @@ -18,11 +18,8 @@ Author: #pragma once #include +#include "ast/ast.h" #include "util/lbool.h" -#include "util/rational.h" -#include "util/vector.h" -#include "util/util.h" - namespace polysat { @@ -107,6 +104,7 @@ namespace polysat { virtual void add_or(univariate const& in1, univariate const& in2, univariate const& out, bool sign, dep_t dep) = 0; virtual void add_xor(univariate const& in1, univariate const& in2, univariate const& out, bool sign, dep_t dep) = 0; virtual void add_not(univariate const& in, univariate const& out, bool sign, dep_t dep) = 0; + virtual void add_inv(univariate const& in, univariate const& out, bool sign, dep_t dep) = 0; /// Add x <= val or x > val, depending on sign virtual void add_ule_const(rational const& val, bool sign, dep_t dep) = 0; @@ -119,6 +117,8 @@ namespace polysat { virtual void add_bit(unsigned idx, bool sign, dep_t dep) = 0; void add_bit0(unsigned idx, dep_t dep) { add_bit(idx, true, dep); } void add_bit1(unsigned idx, dep_t dep) { add_bit(idx, false, dep); } + + virtual expr* get_parity(univariate const& in) = 0; virtual std::ostream& display(std::ostream& out) const = 0; };