3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 17:45:32 +00:00

eliminate basic variables from new rows

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2021-05-12 15:58:35 -07:00
parent 62b7719d5a
commit 0d776ecf88
4 changed files with 191 additions and 94 deletions

View file

@ -33,8 +33,9 @@ inline std::ostream& operator<<(std::ostream& out, pp<Numeral> const& p) {
}
template<typename Numeral>
struct mod_interval {
class mod_interval {
bool emp { false };
public:
Numeral lo { 0 };
Numeral hi { 0 };
mod_interval() {}
@ -43,6 +44,8 @@ struct mod_interval {
static mod_interval empty() { mod_interval i(0, 0); i.emp = true; return i; }
bool is_free() const { return !emp && lo == hi; }
bool is_empty() const { return emp; }
void set_free() { lo = hi = 0; emp = false; }
void set_bounds(Numeral const& l, Numeral const& h) { lo = l; hi = h; }
bool contains(Numeral const& n) const;
mod_interval operator&(mod_interval const& other) const;
mod_interval operator+(mod_interval const& other) const;

View file

@ -41,6 +41,14 @@ namespace polysat {
typedef typename matrix::row row;
typedef typename matrix::row_iterator row_iterator;
typedef typename matrix::col_iterator col_iterator;
struct var_eq {
var_t x, y;
row r1, r2;
var_eq(var_t x, var_t y, row const& r1, row const& r2):
x(x), y(y), r1(r1), r2(r2) {}
};
protected:
struct var_lt {
bool operator()(var_t v1, var_t v2) const { return v1 < v2; }
@ -73,9 +81,13 @@ namespace polysat {
m_is_base(false)
{}
var_info& operator&=(mod_interval<numeral> const& range) {
mod_interval<numeral>::operator=(range);
mod_interval<numeral>::operator=(range & *this);
return *this;
}
var_info& operator=(mod_interval<numeral> const& range) {
mod_interval<numeral>::operator=(range);
return *this;
}
};
struct row_info {
@ -85,12 +97,6 @@ namespace polysat {
numeral m_base_coeff;
};
struct var_eq {
var_t x, y;
row r1, r2;
var_eq(var_t x, var_t y, row const& r1, row const& r2):
x(x), y(y), r1(r1), r2(r2) {}
};
struct fix_entry {
var_t x;
@ -128,7 +134,7 @@ namespace polysat {
void set_bounds(var_t v, numeral const& lo, numeral const& hi);
void unset_bounds(var_t v) { m_vars[v].lo = m_vars[v].hi; }
void unset_bounds(var_t v) { m_vars[v].set_free(); }
var_t get_base_var(row const& r) const { return m_rows[r.id()].m_base; }
numeral const& lo(var_t var) const { return m_vars[var].lo; }
@ -136,9 +142,11 @@ namespace polysat {
numeral const& value(var_t var) const { return m_vars[var].m_value; }
void set_max_iterations(unsigned n) { m_max_iterations = n; }
unsigned get_num_vars() const { return m_vars.size(); }
void reset();
void propagate_bounds();
void propagate_eqs();
void reset();
void propagate_bounds();
void propagate_eqs();
vector<var_eq> const& var_eqs() const { return m_var_eqs; }
void reset_eqs() { m_var_eqs.reset(); }
lbool make_feasible();
row add_row(var_t base, unsigned num_vars, var_t const* vars, numeral const* coeffs);
std::ostream& display(std::ostream& out) const;
@ -149,14 +157,10 @@ namespace polysat {
void del_row(var_t base_var);
private:
void gauss_jordan();
void make_basic(var_t v, row const& r);
void update_value_core(var_t v, numeral const& delta);
void ensure_var(var_t v);
void ensure_var(var_t v);
var_t select_smallest_var() { return m_to_patch.empty()?null_var:m_to_patch.erase_min(); }
lbool make_var_feasible(var_t x_i);
@ -171,9 +175,11 @@ namespace polysat {
void new_bound(row const& r, var_t x, mod_interval<numeral> const& range);
void pivot(var_t x_i, var_t x_j, numeral const& b, numeral const& value);
numeral value2delta(var_t v, numeral const& new_value) const;
numeral value2error(var_t v, numeral const& new_value) const;
void update_value(var_t v, numeral const& delta);
bool can_pivot(var_t x_i, numeral const& new_value, numeral const& a_ij, var_t x_j);
bool has_minimal_trailing_zeros(var_t y, numeral const& b);
var_t select_pivot(var_t x_i, numeral const& new_value, numeral& out_b);
var_t select_pivot_core(var_t x, numeral const& new_value, numeral& out_b);
bool in_bounds(var_t v) const { return in_bounds(v, value(v)); }
bool in_bounds(var_t v, numeral const& b) const { return in_bounds(b, m_vars[v]); }
@ -206,17 +212,18 @@ namespace polysat {
void del_row(row const& r);
var_t select_pivot_blands(var_t x, numeral const& new_value, numeral& out_b);
bool can_improve(var_t x, numeral const& new_value, var_t y, numeral const& b);
#if 0
// TBD:
void move_to_bound(var_t x, bool to_lower) {}
var_t select_pivot(var_t x_i, bool is_below, numeral& out_a_ij) { throw nullptr; }
var_t select_pivot_blands(var_t x_i, bool is_below, numeral& out_a_ij) { throw nullptr; }
var_t pick_var_to_leave(var_t x_j, bool is_pos,
numeral& gain, numeral& new_a_ij, bool& inc) { throw nullptr; }
#endif
bool pivot_base_vars();
bool elim_base(var_t v);
bool eliminate_var(
row const& r_y,
row const& r_z,
numeral const& c,
unsigned tz_b,
numeral const& old_value_y);
};

View file

@ -47,6 +47,7 @@ namespace polysat {
m_rows.reset();
m_left_basis.reset();
m_base_vars.reset();
m_var_eqs.reset();
}
template<typename Ext>
@ -116,19 +117,45 @@ namespace polysat {
m_vars[base_var].m_base2row = r.id();
m_vars[base_var].m_is_base = true;
set_base_value(base_var);
// TBD: record when base_coeff does not divide value
add_patch(base_var);
if (!m_base_vars.empty()) {
gauss_jordan();
}
bool elim = pivot_base_vars();
SASSERT(well_formed_row(r));
SASSERT(well_formed());
return r;
}
template<typename Ext>
bool fixplex<Ext>::pivot_base_vars() {
bool ok = true;
for (auto v : m_base_vars)
if (!elim_base(v))
ok = false;
m_base_vars.reset();
return ok;
}
template<typename Ext>
bool fixplex<Ext>::elim_base(var_t v) {
SASSERT(is_base(v));
row r = row(base2row(v));
numeral b = row2base_coeff(r);
unsigned tz_b = m.trailing_zeros(b);
for (auto col : M.col_entries(v)) {
if (r.id() == col.get_row().id())
continue;
numeral c = col.get_row_entry().coeff();
numeral value_v = value(v);
if (!eliminate_var(r, col.get_row(), c, tz_b, value_v))
return false;
}
return true;
}
template<typename Ext>
void fixplex<Ext>::del_row(row const& r) {
var_t var = m_rows[r.id()].m_base;
m_var_eqs.reset();
var_t var = row2base(r);
m_vars[var].m_is_base = false;
m_vars[var].lo = 0;
m_vars[var].hi = 0;
@ -160,7 +187,7 @@ namespace polysat {
}
if (tz == UINT_MAX)
return;
var_t old_base = m_rows[r.id()].m_base;
var_t old_base = row2base(r);
numeral new_value;
var_info& vi = m_vars[old_base];
if (!vi.contains(value(old_base)))
@ -205,33 +232,6 @@ namespace polysat {
}
}
template<typename Ext>
void fixplex<Ext>::gauss_jordan() {
#if 0
while (!m_base_vars.empty()) {
auto v = m_base_vars.back();
auto rid = m_vars[v].m_base2row;
auto const& row = m_rows[rid];
make_basic(v, row);
}
#endif
}
/**
* If v is already a basic variable in preferred_row, skip
* If v is non-basic but basic in a different row, then
* eliminate v from one of the rows.
* If v if non-basic
*/
template<typename Ext>
void fixplex<Ext>::make_basic(var_t v, row const& preferred_row) {
NOT_IMPLEMENTED_YET();
}
/**
* Attempt to improve assigment to make x feasible.
@ -263,11 +263,16 @@ namespace polysat {
pivot(x, y, b, new_value);
// get_offset_eqs(row(base2row(y)));
return l_true;
}
template<typename Ext>
var_t fixplex<Ext>::select_pivot(var_t x, numeral const& new_value, numeral & out_b) {
if (m_bland)
return select_pivot_blands(x, new_value, out_b);
return select_pivot_core(x, new_value, out_b);
}
/**
\brief Select a variable y in the row r defining the base var x,
s.t. y can be used to patch the error in x_i. Return null_var
@ -286,8 +291,8 @@ namespace polysat {
int n = 0;
unsigned best_col_sz = UINT_MAX;
int best_so_far = INT_MAX;
numeral a = m_rows[r.id()].m_base_coeff;
numeral row_value = m_rows[r.id()].m_value + a * new_value;
numeral a = row2base_coeff(r);
numeral row_value = row2value(r) + a * new_value;
numeral delta_y = 0;
numeral delta_best = 0;
bool best_in_bounds = false;
@ -351,6 +356,40 @@ namespace polysat {
return result;
}
template<typename Ext>
var_t fixplex<Ext>::select_pivot_blands(var_t x, numeral const& new_value, numeral & out_b) {
SASSERT(is_base(x));
unsigned max = get_num_vars();
var_t result = max;
row r(base2row(x));
for (auto const& c : M.col_entries(r)) {
var_t y = c.var();
if (x == y || y >= result)
continue;
numeral const & b = c.coeff();
if (can_improve(y, b)) {
out_b = b;
result = y;
}
}
return result < max ? result : null_var;
}
/**
* determine whether setting x := new_value
* allows to change the value of y in a direction
* that reduces or maintains the overall error.
*/
template<typename Ext>
bool fixplex<Ext>::can_improve(var_t x, numeral const& new_x_value, var_t y, numeral const& b) {
row r(base2row(x));
numeral row_value = row2value(r) + row2base_coeff(r) * new_x_value;
numeral new_y_value = solve_for(row_value - b * value(y), b);
if (in_bounds(y, new_y_value))
return true;
return value2error(y, new_y_value) <= value2error(x, value(x));
}
/**
* Compute delta to add to the value, such that value + delta is either lo(v), or hi(v) - 1
* A pre-condition is that value is not in the interval [lo(v),hi(v)[,
@ -367,6 +406,18 @@ namespace polysat {
return hi(v) - value - 1;
}
template<typename Ext>
typename fixplex<Ext>::numeral
fixplex<Ext>::value2error(var_t v, numeral const& value) const {
if (in_bounds(v))
return 0;
SASSERT(lo(v) != hi(v));
if (lo(v) - value < value - hi(v))
return lo(v) - value;
else
return value - hi(v) - 1;
}
/**
* The the bounds of variable v.
@ -376,9 +427,8 @@ namespace polysat {
* - the variable v is queued to patch if v is basic.
*/
template<typename Ext>
void fixplex<Ext>::set_bounds(var_t v, numeral const& lo, numeral const& hi) {
m_vars[v].lo = lo;
m_vars[v].hi = hi;
void fixplex<Ext>::set_bounds(var_t v, numeral const& lo, numeral const& hi) {
m_vars[v] = mod_interval(lo, hi);
if (in_bounds(v))
return;
if (is_base(v))
@ -523,32 +573,61 @@ namespace polysat {
add_patch(y);
SASSERT(well_formed_row(r_x));
unsigned tz1 = m.trailing_zeros(b);
unsigned tz_b = m.trailing_zeros(b);
for (auto col : M.col_entries(y)) {
row r_z = col.get_row();
unsigned rz = r_z.id();
if (rz == rx)
continue;
auto z = row2base(r_z);
auto& row_z = m_rows[rz];
var_info& zI = m_vars[z];
numeral c = col.get_row_entry().coeff();
unsigned tz2 = m.trailing_zeros(c);
SASSERT(tz1 <= tz2);
numeral b1 = b >> tz1;
numeral c1 = 0 - (c >> (tz2 - tz1));
M.mul(r_z, b1);
M.add(r_z, c1, r_x);
row_z.m_value = (b1 * (row_z.m_value - c * old_value_y)) + c1 * row_x.m_value;
row_z.m_base_coeff *= b1;
set_base_value(z);
SASSERT(well_formed_row(r_z));
add_patch(row_z.m_base);
VERIFY(eliminate_var(r_x, r_z, c, tz_b, old_value_y));
add_patch(row2base(r_z));
}
SASSERT(well_formed());
}
/**
* r_y - row where y is base variable
* r_z - row that contains y with z base variable, z != y
* c - coefficient of y in r_z
* tz_b - number of trailing zeros to coefficient of y in r_y
* old_value_y - the value of y used to compute row2value(r_z)
*
* returns true if elimination preserves equivalence (is lossless).
*/
template<typename Ext>
bool fixplex<Ext>::eliminate_var(
row const& r_y,
row const& r_z,
numeral const& c,
unsigned tz_b,
numeral const& old_value_y) {
var_t y = row2base(r_y);
numeral b = row2base_coeff(r_y);
auto z = row2base(r_z);
auto& row_z = m_rows[r_z.id()];
var_info& zI = m_vars[z];
unsigned tz_c = m.trailing_zeros(c);
numeral b1, c1;
if (tz_b <= tz_c) {
b1 = b >> tz_b;
c1 = 0 - (c >> (tz_c - tz_b));
}
else {
b1 = b >> (tz_b - tz_c);
c1 = 0 - (c >> tz_c);
}
M.mul(r_z, b1);
M.add(r_z, c1, r_y);
row_z.m_value = (b1 * (row2value(r_z) - c * old_value_y)) + c1 * row2value(r_y);
row_z.m_base_coeff *= b1;
set_base_value(z);
SASSERT(well_formed_row(r_z));
return tz_b <= tz_c;
}
template<typename Ext>
bool fixplex<Ext>::is_feasible() const {
for (unsigned i = m_vars.size(); i-- > 0; )
@ -574,7 +653,7 @@ namespace polysat {
int fixplex<Ext>::get_num_non_free_dep_vars(var_t x_j, int best_so_far) {
int result = is_non_free(x_j);
for (auto const& col : M.col_entries(x_j)) {
var_t s = m_rows[col.get_row().id()].m_base;
var_t s = row2base(col.get_row());
result += is_non_free(s);
if (result > best_so_far)
return result;
@ -607,14 +686,11 @@ namespace polysat {
template<typename Ext>
var_t fixplex<Ext>::select_error_var(bool least) {
var_t best = null_var;
numeral best_error = 0, curr_error = 0;
numeral best_error = 0;
for (var_t v : m_to_patch) {
if (in_bounds(v))
numeral curr_error = value2error(v, value(v));
if (curr_error == 0)
continue;
if (lo(v) - value(v) < value(v) - hi(v))
curr_error = lo(v) - value(v);
else
curr_error = value(v) - hi(v) - 1;
if ((best == null_var) ||
(least && curr_error < best_error) ||
(!least && curr_error > best_error)) {
@ -804,7 +880,6 @@ namespace polysat {
template<typename Ext>
void fixplex<Ext>::eq_eh(var_t x, var_t y, row const& r1, row const& r2) {
std::cout << "eq " << x << " == " << y << "\n";
m_var_eqs.push_back(var_eq(x, y, r1, r2));
}
@ -863,8 +938,11 @@ namespace polysat {
void fixplex<Ext>::new_bound(row const& r, var_t x, mod_interval<numeral> const& range) {
if (range.is_free())
return;
bool was_fixed = lo(x) + 1 == hi(x);
m_vars[x] &= range;
IF_VERBOSE(0, verbose_stream() << "new-bound v" << x << " " << m_vars[x] << "\n");
if (!was_fixed && lo(x) + 1 == hi(x))
fixed_var_eh(r, x);
}
template<typename Ext>
@ -881,11 +959,15 @@ namespace polysat {
template<typename Ext>
std::ostream& fixplex<Ext>::display_row(std::ostream& out, row const& r, bool values) {
out << r.id() << " := " << pp(row2value(r)) << " : ";
for (auto const& e : M.row_entries(r)) {
var_t v = e.var();
if (e.coeff() != 1)
out << pp(e.coeff()) << " * ";
out << "v" << v << " ";
out << "v" << v;
if (is_base(v))
out << "b";
out << " ";
if (values)
out << pp(value(v)) << " " << m_vars[v] << " ";
}
@ -911,11 +993,11 @@ namespace polysat {
template<typename Ext>
bool fixplex<Ext>::well_formed_row(row const& r) const {
var_t s = m_rows[r.id()].m_base;
var_t s = row2base(r);
VERIFY(base2row(s) == r.id());
VERIFY(m_vars[s].m_is_base);
numeral sum = 0;
numeral base_coeff = m_rows[r.id()].m_base_coeff;
numeral base_coeff = row2base_coeff(r);
for (auto const& e : M.row_entries(r)) {
sum += value(e.var()) * e.coeff();
SASSERT(s != e.var() || base_coeff == e.coeff());
@ -925,7 +1007,7 @@ namespace polysat {
TRACE("polysat", display(tout << "non-well formed row\n"); M.display_row(tout << "row: ", r););
throw default_exception("non-well formed row");
}
SASSERT(sum == m_rows[r.id()].m_value + base_coeff * value(s));
SASSERT(sum == row2value(r) + base_coeff * value(s));
return true;
}
@ -937,4 +1019,7 @@ namespace polysat {
st.update("fixplex num checks", m_stats.m_num_checks);
st.update("fixplex num non-integral", m_num_non_integral);
}
}

View file

@ -110,7 +110,9 @@ namespace polysat {
fp.set_bounds(u, 1, 2);
fp.run();
fp.propagate_eqs();
for (auto e : fp.var_eqs())
std::cout << e.x << " == " << e.y << "\n";
}