3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 00:26:38 +00:00

Fix mk_slice, add mk_extract/mk_concat

This commit is contained in:
Jakob Rath 2023-06-16 11:48:01 +02:00
parent 982170e6e0
commit 8a50467ba8
4 changed files with 134 additions and 126 deletions

View file

@ -33,7 +33,7 @@ namespace polysat {
case trail_item::add_var: undo_add_var(); break;
case trail_item::alloc_slice: undo_alloc_slice(); break;
case trail_item::split_slice: undo_split_slice(); break;
case trail_item::merge_class: undo_merge_class(); break;
case trail_item::merge_base: undo_merge_base(); break;
default: UNREACHABLE();
}
m_trail.pop_back();
@ -117,7 +117,7 @@ namespace polysat {
}
}
void slicing::merge(slice s1, slice s2) {
void slicing::merge_base(slice s1, slice s2) {
SASSERT_EQ(width(s1), width(s2));
SASSERT(!has_sub(s1));
SASSERT(!has_sub(s2));
@ -137,11 +137,11 @@ namespace polysat {
// otherwise the classes should have been merged already
SASSERT(m_slice2var[r2] != m_slice2var[r1]);
}
m_trail.push_back(trail_item::merge_class);
m_trail.push_back(trail_item::merge_base);
m_merge_trail.push_back(r1);
}
void slicing::undo_merge_class() {
void slicing::undo_merge_base() {
slice r1 = m_merge_trail.back();
m_merge_trail.pop_back();
slice r2 = m_find[r1];
@ -161,11 +161,21 @@ namespace polysat {
slice y = ys.back();
xs.pop_back();
ys.pop_back();
if (has_sub(x)) {
find_base(x, xs);
x = xs.back();
xs.pop_back();
}
if (has_sub(y)) {
find_base(y, ys);
y = ys.back();
ys.pop_back();
}
SASSERT(!has_sub(x));
SASSERT(!has_sub(y));
if (width(x) == width(y)) {
// LOG("Match " << x << " and " << y);
merge(x, y);
merge_base(x, y);
}
else if (width(x) > width(y)) {
// need to split x according to y
@ -190,6 +200,15 @@ namespace polysat {
merge(xs, tmp);
}
void slicing::merge(slice x, slice y) {
if (!has_sub(x) && !has_sub(y))
return merge_base(x, y);
slice_vector tmpx, tmpy;
tmpx.push_back(x);
tmpy.push_back(y);
merge(tmpx, tmpy);
}
void slicing::find_base(slice src, slice_vector& out_base) const {
// splits are only stored for the representative
SASSERT_EQ(src, find(src));
@ -213,12 +232,18 @@ namespace polysat {
SASSERT(todo.empty());
}
void slicing::mk_slice(slice src, unsigned const hi, unsigned const lo, slice_vector& out_base, bool output_full_src) {
void slicing::mk_slice(slice src, unsigned const hi, unsigned const lo, slice_vector& out, bool output_full_src, bool output_base) {
SASSERT(hi >= lo);
SASSERT_EQ(src, find(src)); // splits are only stored for the representative
SASSERT(width(src) > hi); // extracted range must be fully contained inside the src slice
auto output_slice = [this, output_base, &out](slice s) {
if (output_base)
find_base(s, out);
else
out.push_back(s);
};
if (lo == 0 && width(src) - 1 == hi) {
find_base(src, out_base);
output_slice(src);
return;
}
if (has_sub(src)) {
@ -226,23 +251,23 @@ namespace polysat {
unsigned const cut = m_slice_cut[src];
if (lo >= cut + 1) {
// target slice falls into upper subslice
mk_slice(find_sub_hi(src), hi - cut - 1, lo - cut - 1, out_base);
mk_slice(find_sub_hi(src), hi - cut - 1, lo - cut - 1, out, output_full_src, output_base);
if (output_full_src)
out_base.push_back(find_sub_lo(src));
output_slice(find_sub_lo(src));
return;
}
else if (cut >= hi) {
// target slice falls into lower subslice
if (output_full_src)
out_base.push_back(find_sub_hi(src));
mk_slice(find_sub_lo(src), hi, lo, out_base);
output_slice(find_sub_hi(src));
mk_slice(find_sub_lo(src), hi, lo, out, output_full_src, output_base);
return;
}
else {
SASSERT(hi > cut && cut >= lo);
// desired range spans over the cutpoint, so we get multiple slices in the result
mk_slice(find_sub_hi(src), hi - cut - 1, 0, out_base);
mk_slice(find_sub_lo(src), cut, lo, out_base);
mk_slice(find_sub_hi(src), hi - cut - 1, 0, out, output_full_src, output_base);
mk_slice(find_sub_lo(src), cut, lo, out, output_full_src, output_base);
return;
}
}
@ -250,41 +275,42 @@ namespace polysat {
// [src.width-1, 0] has no subdivision yet
if (width(src) - 1 > hi) {
split(src, hi);
SASSERT(!has_sub(find_sub_hi(src)));
if (output_full_src)
out_base.push_back(find_sub_hi(src));
mk_slice(find_sub_lo(src), hi, lo, out_base); // recursive call to take care of case lo > 0
out.push_back(find_sub_hi(src));
mk_slice(find_sub_lo(src), hi, lo, out, output_full_src, output_base); // recursive call to take care of case lo > 0
return;
}
else {
SASSERT(lo > 0);
split(src, lo - 1);
out_base.push_back(find_sub_hi(src));
out.push_back(find_sub_hi(src));
SASSERT(!has_sub(find_sub_lo(src)));
if (output_full_src)
out_base.push_back(find_sub_lo(src));
out.push_back(find_sub_lo(src));
return;
}
}
UNREACHABLE();
}
pvar slicing::mk_extract_var(pvar src, unsigned hi, unsigned lo) {
pvar slicing::mk_slice_extract(slice src, unsigned hi, unsigned lo) {
slice_vector slices;
mk_slice(var2slice(src), hi, lo, slices);
// src[hi:lo] is the concatenation of the returned slices
// TODO: for each slice, set_extract
#if 0
extract_key key{src, hi, lo};
auto it = m_extracted.find_iterator(key);
if (it != m_extracted.end())
return it->m_value;
pvar v = s.add_var(hi - lo);
set_extract(v, src, hi, lo);
mk_slice(src, hi, lo, slices, false, true);
if (slices.size() == 1) {
slice s = slices[0];
if (slice2var(s) != null_var)
return slice2var(s);
}
pvar v = m_solver.add_var(hi - lo + 1);
merge(slices, var2slice(v));
return v;
#endif
}
#if 0
pvar slicing::mk_extract_var(pvar src, unsigned hi, unsigned lo) {
return mk_slice_extract(var2slice(src), hi, lo);
}
pdd slicing::mk_extract(pvar src, unsigned hi, unsigned lo) {
return m_solver.var(mk_extract_var(src, hi, lo));
}
@ -293,55 +319,27 @@ namespace polysat {
if (!lo) {
// TODO: we could push the extract down into variables of the term instead of introducing a name.
}
return m_solver.var(mk_slice_extract(pdd2slice(p), hi, lo));
}
slicing::slice slicing::pdd2slice(pdd const& p) {
pvar const v = m_solver.m_names.mk_name(p);
return mk_extract(v, hi, lo);
return var2slice(v);
}
pdd slicing::mk_concat(pdd const& p, pdd const& q) {
#if 0
// v := p ++ q (new variable of size |p| + |q|)
// v[:|q|] = p
// v[|q|:] = q
unsigned const p_sz = p.power_of_2();
unsigned const q_sz = q.power_of_2();
unsigned const v_sz = p_sz + q_sz;
// TODO: lookup to see if we can reuse a variable
// either:
// - table of concats
// - check for variable with v[:|q|] = p and v[|q|:] = q in extract table (probably better)
pvar const v = s.add_var(v_sz);
// TODO: probably wrong to use names for p, q.
// we should rather check if there's already an extraction for v[...] and reuse that variable.
pvar const p_name = s.m_names.mk_name(p);
pvar const q_name = s.m_names.mk_name(q);
set_extract(p_name, v, v_sz, q_sz);
set_extract(q_name, v, q_sz, 0);
#endif
NOT_IMPLEMENTED_YET();
}
#endif
void slicing::set_extract(pvar v, pvar src, unsigned hi, unsigned lo) {
#if 0
SASSERT(!is_extract(v));
SASSERT(lo < hi && hi <= s.size(src));
SASSERT_EQ(hi - lo + 1, s.size(v));
SASSERT(src < v);
SASSERT(!m_extracted.contains(extract_key{src, hi, lo}));
#if 0 // try without this first
if (is_extract(src)) {
// y = (x[k:m])[h:l] = x[h+m:l+m]
unsigned const offset = m_lo[src];
set_extract(m_src[src], hi + offset, lo + offset);
return;
}
#endif
m_extracted.insert({src, hi, lo}, v);
m_src[v] = src;
m_hi[v] = hi;
m_lo[v] = lo;
#endif
pvar const v = m_solver.add_var(v_sz);
slice_vector tmp;
tmp.push_back(pdd2slice(p));
tmp.push_back(pdd2slice(q));
merge(tmp, var2slice(v));
return m_solver.var(v);
}
void slicing::propagate(pvar v) {
@ -353,8 +351,14 @@ namespace polysat {
out << "v" << v << ":";
base.reset();
find_base(var2slice(v), base);
for (slice s : base)
// unsigned hi = width(var2slice(v)) - 1;
for (slice s : base) {
// unsigned w = width(s);
// unsigned lo = hi - w + 1;
// out << " s" << s << "_[" << hi << ":" << lo << "]";
// hi -= w;
display(out << " ", s);
}
out << "\n";
}
return out;

View file

@ -25,6 +25,7 @@ Notation:
--*/
#pragma once
#include "math/polysat/types.h"
#include <optional>
namespace polysat {
@ -34,19 +35,9 @@ namespace polysat {
friend class test_slicing;
// solver& m_solver;
solver& m_solver;
#if 0
/// If y := x[h:l], then m_src[y] = x, m_hi[y] = h, m_lo[y] = l.
/// Otherwise m_src[y] = null_var.
///
/// Invariants:
/// m_src[y] != null_var ==> m_src[y] < y (at least as long as we always introduce new variables for extract terms.)
/// m_lo[y] <= m_hi[y]
unsigned_vector m_src;
unsigned_vector m_hi;
unsigned_vector m_lo;
struct extract_key {
pvar src;
unsigned hi;
@ -98,11 +89,12 @@ namespace polysat {
/// Split slice s into s[|s|-1:cut+1] and s[cut:0]
void split(slice s, unsigned cut);
/// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... + s_n
/// Retrieve base slices s_1,...,s_n such that src == s_1 ++ ... ++ s_n
void find_base(slice src, slice_vector& out_base) const;
// Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n
// If output_full_src is true, returns the new base for src, i.e., src == s_1 ++ ... ++ s_n
void mk_slice(slice src, unsigned hi, unsigned lo, slice_vector& out_base, bool output_full_src = false);
/// Retrieve (or create) base slices s_1,...,s_n such that src[hi:lo] == s_1 ++ ... ++ s_n.
/// If output_full_src is true, return the new base for src, i.e., src == s_1 ++ ... ++ s_n.
/// If output_base is false, return coarsest intermediate slices instead of only base slices.
void mk_slice(slice src, unsigned hi, unsigned lo, slice_vector& out, bool output_full_src = false, bool output_base = true);
/// Find representative
slice find(slice s) const;
@ -112,67 +104,62 @@ namespace polysat {
slice find_sub_lo(slice s) const;
// Merge equivalence classes of two base slices
void merge(slice s1, slice s2);
void merge_base(slice s1, slice s2);
// Merge equality x_1 ++ ... ++ x_n == y_1 ++ ... ++ y_k
//
// Precondition:
// - sequence of base slices (equal total width)
// - sequence of slices with equal total width
// - ordered from msb to lsb
void merge(slice_vector& xs, slice_vector& ys);
void merge(slice_vector& xs, slice y);
void set_extract(pvar v, pvar src, unsigned hi_bit, unsigned lo_bit);
void merge(slice x, slice y);
enum class trail_item {
add_var,
alloc_slice,
split_slice,
merge_class,
merge_base,
};
svector<trail_item> m_trail;
slice_vector m_split_trail;
slice_vector m_merge_trail;
slice_vector m_split_trail;
slice_vector m_merge_trail;
unsigned_vector m_scopes;
void undo_add_var();
void undo_alloc_slice();
void undo_split_slice();
void undo_merge_class();
void undo_merge_base();
mutable slice_vector m_tmp1;
// get slice equivalent to the given pdd (may introduce new variable)
slice pdd2slice(pdd const& p);
/** Get variable representing src[hi:lo] */
pvar mk_slice_extract(slice src, unsigned hi, unsigned lo);
public:
// slicing(solver& s): m_solver(s) {}
slicing(solver& s): m_solver(s) {}
void push_scope();
void pop_scope(unsigned num_scopes = 1);
void add_var(unsigned bit_width);
// bool is_extract(pvar v) const { return m_src[v] != null_var; }
/** Get variable representing x[hi:lo] */
pvar mk_extract_var(pvar x, unsigned hi, unsigned lo);
// /** Create expression for x[hi:lo] */
// pdd mk_extract(pvar x, unsigned hi, unsigned lo);
/** Create expression for x[hi:lo] */
pdd mk_extract(pvar x, unsigned hi, unsigned lo);
// /** Create expression for p[hi:lo] */
// pdd mk_extract(pdd const& p, unsigned hi, unsigned lo);
/** Create expression for p[hi:lo] */
pdd mk_extract(pdd const& p, unsigned hi, unsigned lo);
// /** Create expression for p ++ q */
// pdd mk_concat(pdd const& p, pdd const& q);
/** Create expression for p ++ q */
pdd mk_concat(pdd const& p, pdd const& q);
// propagate:
// - value assignments

View file

@ -42,7 +42,7 @@ namespace polysat {
m_free_pvars(m_activity),
m_constraints(*this),
m_names(*this),
// m_slicing(*this),
m_slicing(*this),
m_search(*this) {
}

View file

@ -60,29 +60,45 @@ namespace polysat {
slicing& sl = s.sl();
pvar x = s.add_var(8);
pvar y = s.add_var(8);
pvar a = s.add_var(5);
pvar b = s.add_var(6);
slicing::slice_vector x_7_3;
sl.mk_slice(sl.var2slice(x), 7, 3, x_7_3);
slicing::slice_vector a_4_0;
sl.mk_slice(sl.var2slice(a), 4, 0, a_4_0);
sl.merge(x_7_3, a_4_0);
pvar a = sl.mk_extract_var(x, 7, 3);
std::cout << sl << "\n";
slicing::slice_vector x_base;
sl.find_base(sl.var2slice(x), x_base);
slicing::slice_vector y_base;
sl.find_base(sl.var2slice(y), y_base);
sl.merge(x_base, y_base);
sl.merge(sl.var2slice(x), sl.var2slice(y));
std::cout << sl << "\n";
slicing::slice_vector y_5_0;
sl.mk_slice(sl.var2slice(y), 5, 0, y_5_0);
sl.merge(y_5_0, sl.var2slice(b));
pvar b = sl.mk_extract_var(y, 5, 0);
std::cout << sl << "\n";
}
// x[7:3] = a
// y[5:0] = b
// x[5:0] = c
// x[5:4] ++ y[3:0] = d
// x = y
//
// How easily can we find b=c and b=d?
static void test3() {
std::cout << __func__ << "\n";
scoped_solver_slicing s;
slicing& sl = s.sl();
pvar x = s.add_var(8);
pvar y = s.add_var(8);
std::cout << sl << "\n";
pvar a = sl.mk_extract_var(x, 7, 3);
std::cout << "v" << a << " := v" << x << "[7:3]\n" << sl << "\n";
pvar b = sl.mk_extract_var(y, 5, 0);
std::cout << "v" << b << " := v" << y << "[5:0]\n" << sl << "\n";
pvar c = sl.mk_extract_var(x, 5, 0);
std::cout << "v" << c << " := v" << x << "[5:0]\n" << sl << "\n";
pdd d = sl.mk_concat(sl.mk_extract(x, 5, 4), sl.mk_extract(y, 3, 0));
std::cout << d << " := v" << x << "[5:4] ++ v" << y << "[3:0]\n" << sl << "\n";
sl.merge(sl.var2slice(x), sl.var2slice(y));
std::cout << "v" << x << " = v" << y << "\n" << sl << "\n";
}
};
}
@ -92,5 +108,6 @@ void tst_slicing() {
using namespace polysat;
test_slicing::test1();
test_slicing::test2();
test_slicing::test3();
std::cout << "ok\n";
}