3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-15 13:28:47 +00:00

cheap equalities

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
This commit is contained in:
Lev Nachmanson 2020-06-09 17:14:32 -07:00
parent ccc8651800
commit 1587497562
5 changed files with 156 additions and 70 deletions

View file

@ -109,7 +109,6 @@ class lar_solver : public column_namer {
static_matrix<double, double > const & A_d() const; static_matrix<double, double > const & A_d() const;
static bool valid_index(unsigned j) { return static_cast<int>(j) >= 0;} static bool valid_index(unsigned j) { return static_cast<int>(j) >= 0;}
unsigned external_to_column_index(unsigned) const;
const lar_term & get_term(unsigned j) const; const lar_term & get_term(unsigned j) const;
bool row_has_a_big_num(unsigned i) const; bool row_has_a_big_num(unsigned i) const;
// init region // init region
@ -282,6 +281,7 @@ class lar_solver : public column_namer {
void register_normalized_term(const lar_term&, lpvar); void register_normalized_term(const lar_term&, lpvar);
void deregister_normalized_term(const lar_term&); void deregister_normalized_term(const lar_term&);
public: public:
unsigned external_to_column_index(unsigned) const;
bool inside_bounds(lpvar, const impq&) const; bool inside_bounds(lpvar, const impq&) const;
inline void set_column_value(unsigned j, const impq& v) { inline void set_column_value(unsigned j, const impq& v) {
m_mpq_lar_core_solver.m_r_solver.update_x(j, v); m_mpq_lar_core_solver.m_r_solver.update_x(j, v);

View file

@ -9,12 +9,13 @@
namespace lp { namespace lp {
template <typename T> template <typename T>
class lp_bound_propagator { class lp_bound_propagator {
typedef std::pair<int, impq> var_offset; typedef std::pair<int, mpq> var_offset;
typedef pair_hash<int_hash, obj_hash<impq> > var_offset_hash; typedef pair_hash<int_hash, obj_hash<mpq> > var_offset_hash;
typedef map<var_offset, unsigned, var_offset_hash, default_eq<var_offset> > var_offset2row_id; typedef map<var_offset, unsigned, var_offset_hash, default_eq<var_offset> > var_offset2row_id;
typedef std::pair<mpq, bool> value_sort_pair;
var_offset2row_id m_var_offset2row_id; var_offset2row_id m_var_offset2row_id;
struct impq_eq { bool operator()(const impq& a, const impq& b) const {return a == b;}}; struct mpq_eq { bool operator()(const mpq& a, const mpq& b) const {return a == b;}};
// vertex represents a pair (row,x) or (row,y) for an offset row. // vertex represents a pair (row,x) or (row,y) for an offset row.
// The set of all pair are organised in a tree. // The set of all pair are organised in a tree.
@ -25,7 +26,7 @@ class lp_bound_propagator {
unsigned m_row; unsigned m_row;
unsigned m_index_in_row; // in the row unsigned m_index_in_row; // in the row
ptr_vector<vertex> m_children; ptr_vector<vertex> m_children;
impq m_offset; // offset from parent (parent - child = offset) mpq m_offset; // offset from parent (parent - child = offset)
vertex* m_parent; vertex* m_parent;
unsigned m_level; // the distance in hops to the root; unsigned m_level; // the distance in hops to the root;
// it is handy to find the common ancestor // it is handy to find the common ancestor
@ -33,7 +34,7 @@ class lp_bound_propagator {
vertex() {} vertex() {}
vertex(unsigned row, vertex(unsigned row,
unsigned index_in_row, unsigned index_in_row,
const impq & offset) : const mpq & offset) :
m_row(row), m_row(row),
m_index_in_row(index_in_row), m_index_in_row(index_in_row),
m_offset(offset), m_offset(offset),
@ -43,7 +44,7 @@ class lp_bound_propagator {
unsigned row() const { return m_row; } unsigned row() const { return m_row; }
vertex* parent() const { return m_parent; } vertex* parent() const { return m_parent; }
unsigned level() const { return m_level; } unsigned level() const { return m_level; }
const impq& offset() const { return m_offset; } const mpq& offset() const { return m_offset; }
void add_child(vertex* child) { void add_child(vertex* child) {
child->m_parent = this; child->m_parent = this;
m_children.push_back(child); m_children.push_back(child);
@ -61,12 +62,12 @@ class lp_bound_propagator {
hashtable<unsigned, u_hash, u_eq> m_visited_rows; hashtable<unsigned, u_hash, u_eq> m_visited_rows;
hashtable<unsigned, u_hash, u_eq> m_visited_columns; hashtable<unsigned, u_hash, u_eq> m_visited_columns;
vertex* m_root; vertex* m_root;
map<impq, vertex*, obj_hash<impq>, impq_eq> m_offset_to_verts; map<mpq, vertex*, obj_hash<mpq>, mpq_eq> m_offset_to_verts;
// these maps map a column index to the corresponding index in ibounds // these maps map a column index to the corresponding index in ibounds
std::unordered_map<unsigned, unsigned> m_improved_lower_bounds; std::unordered_map<unsigned, unsigned> m_improved_lower_bounds;
std::unordered_map<unsigned, unsigned> m_improved_upper_bounds; std::unordered_map<unsigned, unsigned> m_improved_upper_bounds;
T& m_imp; T& m_imp;
impq m_zero; mpq m_zero;
vector<implied_bound> m_ibounds; vector<implied_bound> m_ibounds;
public: public:
const vector<implied_bound>& ibounds() const { return m_ibounds; } const vector<implied_bound>& ibounds() const { return m_ibounds; }
@ -75,7 +76,7 @@ public:
m_improved_lower_bounds.clear(); m_improved_lower_bounds.clear();
m_ibounds.reset(); m_ibounds.reset();
} }
lp_bound_propagator(T& imp): m_imp(imp), m_zero(impq(0)) {} lp_bound_propagator(T& imp): m_imp(imp), m_zero(mpq(0)) {}
const lar_solver& lp() const { return m_imp.lp(); } const lar_solver& lp() const { return m_imp.lp(); }
column_type get_column_type(unsigned j) const { column_type get_column_type(unsigned j) const {
return m_imp.lp().get_column_type(j); return m_imp.lp().get_column_type(j);
@ -85,9 +86,23 @@ public:
return m_imp.lp().get_lower_bound(j); return m_imp.lp().get_lower_bound(j);
} }
const mpq & get_lower_bound_rational(unsigned j) const {
return m_imp.lp().get_lower_bound(j).x;
}
const impq & get_upper_bound(unsigned j) const { const impq & get_upper_bound(unsigned j) const {
return m_imp.lp().get_upper_bound(j); return m_imp.lp().get_upper_bound(j);
} }
const mpq & get_upper_bound_rational(unsigned j) const {
return m_imp.lp().get_upper_bound(j).x;
}
// require also the zero infinitesemal part
bool column_is_fixed(lpvar j) const {
return lp().column_is_fixed(j) && get_lower_bound(j).y.is_zero();
}
void try_add_bound(mpq const& v, unsigned j, bool is_low, bool coeff_before_j_is_pos, unsigned row_or_term_index, bool strict) { void try_add_bound(mpq const& v, unsigned j, bool is_low, bool coeff_before_j_is_pos, unsigned row_or_term_index, bool strict) {
j = m_imp.lp().column_to_reported_index(j); j = m_imp.lp().column_to_reported_index(j);
@ -130,15 +145,63 @@ public:
m_imp.consume(a, ci); m_imp.consume(a, ci);
} }
bool is_offset_row(unsigned row_index, bool is_offset_row(unsigned r, lpvar & x, lpvar & y, mpq & k) const {
if (r >= lp().row_count())
return false;
x = y = null_lpvar;
for (auto& c : lp().get_row(r)) {
lpvar v = c.var();
if (column_is_fixed(v))
continue;
if (c.coeff().is_one() && x == null_lpvar) {
x = v;
continue;
}
if (c.coeff().is_minus_one() && y == null_lpvar) {
y = v;
continue;
}
return false;
}
if (x == null_lpvar && y == null_lpvar) {
return false;
}
k = mpq(0);
for (const auto& c : lp().get_row(r)) {
if (!column_is_fixed(c.var()))
continue;
k -= c.coeff() * get_lower_bound_rational(c.var());
}
if (y == null_lpvar)
return true;
if (x == null_lpvar) {
std::swap(x, y);
k.neg();
return true;
}
if (/*r.get_base_var() != x &&*/ x > y) {
std::swap(x, y);
k.neg();
}
return true;
}
bool is_offset_row_wrong(unsigned row_index,
unsigned & x_index, unsigned & x_index,
lpvar & y_index, lpvar & y_index,
impq& offset) { mpq& offset) {
if (row_index >= lp().row_count())
return false;
x_index = y_index = UINT_MAX; x_index = y_index = UINT_MAX;
const auto & row = lp().get_row(row_index); const auto & row = lp().get_row(row_index);
for (unsigned k = 0; k < row.size(); k++) { for (unsigned k = 0; k < row.size(); k++) {
const auto& c = row[k]; const auto& c = row[k];
if (lp().column_is_fixed(c.var())) if (column_is_fixed(c.var()))
continue; continue;
if (x_index == UINT_MAX && c.coeff().is_one()) if (x_index == UINT_MAX && c.coeff().is_one())
x_index = k; x_index = k;
@ -147,19 +210,19 @@ public:
else else
return false; return false;
} }
if (x_index == UINT_MAX || y_index == UINT_MAX) if (x_index == UINT_MAX && y_index == UINT_MAX)
return false; return false;
if (lp().column_is_int(row[x_index].var()) != lp().column_is_int(row[y_index].var())) if (lp().column_is_int(row[x_index].var()) != lp().column_is_int(row[y_index].var()))
return false; return false;
offset = impq(0); offset = mpq(0);
for (const auto& c : row) { for (const auto& c : row) {
if (!lp().column_is_fixed(c.var())) if (!column_is_fixed(c.var()))
continue; continue;
offset += c.coeff() * lp().get_lower_bound(c.var()); offset += c.coeff() * get_lower_bound_rational(c.var());
} }
if (offset.is_zero() && if (offset.is_zero() &&
!pair_is_reported_or_congruent(row[x_index].var(), row[y_index].var())) { !is_equal(row[x_index].var(), row[y_index].var())) {
lp::explanation ex; lp::explanation ex;
explain_fixed_in_row(row_index, ex); explain_fixed_in_row(row_index, ex);
add_eq_on_columns(ex, row[x_index].var(), row[y_index].var()); add_eq_on_columns(ex, row[x_index].var(), row[y_index].var());
@ -167,8 +230,8 @@ public:
return true; return true;
} }
bool pair_is_reported_or_congruent(lpvar j, lpvar k) const { bool is_equal(lpvar j, lpvar k) const {
return m_imp.congruent_or_irrelevant(lp().column_to_reported_index(j), lp().column_to_reported_index(k)); return m_imp.is_equal(col_to_imp(j), col_to_imp(k));
} }
void check_for_eq_and_add_to_offset_table(vertex* v) { void check_for_eq_and_add_to_offset_table(vertex* v) {
@ -176,7 +239,7 @@ public:
if (m_offset_to_verts.find(v->offset(), k)) { if (m_offset_to_verts.find(v->offset(), k)) {
if (column(k) != column(v) && if (column(k) != column(v) &&
!pair_is_reported_or_congruent(column(k),column(v))) !is_equal(column(k),column(v)))
report_eq(k, v); report_eq(k, v);
} else { } else {
TRACE("cheap_eq", tout << "registered offset " << v->offset() << " to " << v << "\n";); TRACE("cheap_eq", tout << "registered offset " << v->offset() << " to " << v << "\n";);
@ -229,22 +292,57 @@ public:
This equalities are detected by maintaining a map: This equalities are detected by maintaining a map:
(y, k) -> row_id when a row is of the form x = y + k (y, k) -> row_id when a row is of the form x = y + k
If x = k, then y is null_lpvar
This methods checks whether the given row is an offset row (is_offset_row()) This methods checks whether the given row is an offset row (is_offset_row())
and uses the map to find new equalities if that is the case. and uses the map to find new equalities if that is the case.
Some equalities, those spreading more than two rows, can be missed Some equalities, those spreading more than two rows, can be missed
*/ */
// column to theory_var
unsigned col_to_imp(unsigned j) const {
return lp().local_to_external(lp().column_to_reported_index(j));
}
// theory_var to column
unsigned imp_to_col(unsigned j) const {
return lp().external_to_column_index(j);
}
bool is_int(lpvar j) const {
return lp().column_is_int(j);
}
void cheap_eq_table(unsigned rid) { void cheap_eq_table(unsigned rid) {
TRACE("cheap_eqs", tout << "checking if row " << rid << " can propagate equality.\n"; display_row_info(rid, tout);); TRACE("cheap_eqs", tout << "checking if row " << rid << " can propagate equality.\n"; display_row_info(rid, tout););
unsigned x_o; // x offset unsigned x;
unsigned y_o; // y offset unsigned y;
impq k; mpq k;
if (is_offset_row(rid, x_o, y_o, k)) { if (is_offset_row(rid, x, y, k)) {
SASSERT(x_o != UINT_MAX && y_o != UINT_MAX && x_o != y_o); if (y == null_lpvar) {
const auto& row = lp().get_row(rid); // x is an implied fixed var at k.
lpvar x = row[x_o].var(); value_sort_pair key(k, is_int(x));
lpvar y = row[y_o].var(); int x2;
SASSERT(lp().column_is_int(x) == lp().column_is_int(y)); if (m_imp.m_fixed_var_table.find(key, x2) &&
x2 < static_cast<int>(m_imp.get_num_vars())
&&
lp().column_is_fixed(x2 = imp_to_col(x2)) && // change x2
get_lower_bound_rational(x2) == k &&
// We must check whether x2 is an integer.
// The table m_fixed_var_table is not restored during backtrack. So, it may
// contain invalid (key -> value) pairs.
// So, we must check whether x2 is really equal to k (previous test)
// AND has the same sort of x.
is_int(x) == is_int(x2) &&
!is_equal(x, x2)) {
explanation ex;
constraint_index lc, uc;
lp().get_bound_constraint_witnesses_for_column(x2, lc, uc);
ex.push_back(lc);
ex.push_back(uc);
explain_fixed_in_row(rid, ex);
add_eq_on_columns(ex, x, x2);
}
return;
}
if (k.is_zero()) { if (k.is_zero()) {
explanation ex; explanation ex;
explain_fixed_in_row(rid, ex); explain_fixed_in_row(rid, ex);
@ -258,12 +356,11 @@ public:
// it is the same row. // it is the same row.
return; return;
} }
NOT_IMPLEMENTED_YET(); unsigned x2;
/* unsigned y2;
theory_var x2; mpq k2;
theory_var y2; if (is_offset_row(row_id, x2, y2, k2)) {
numeral k2;
if (r2.get_base_var() != null_theory_var && is_offset_row(r2, x2, y2, k2)) {
bool new_eq = false; bool new_eq = false;
#ifdef _TRACE #ifdef _TRACE
bool swapped = false; bool swapped = false;
@ -271,7 +368,7 @@ public:
if (y == y2 && k == k2) { if (y == y2 && k == k2) {
new_eq = true; new_eq = true;
} }
else if (y2 != null_theory_var) { else if (y2 != null_lpvar) {
#ifdef _TRACE #ifdef _TRACE
swapped = true; swapped = true;
#endif #endif
@ -283,24 +380,23 @@ public:
} }
if (new_eq) { if (new_eq) {
if (!is_equal(x, x2) && is_int_src(x) == is_int_src(x2)) { if (!is_equal(x, x2) && is_int(x) == is_int(x2)) {
SASSERT(y == y2 && k == k2); SASSERT(y == y2 && k == k2);
antecedents ante(*this); explanation ex;
collect_fixed_var_justifications(r, ante); explain_fixed_in_row(rid, ex);
collect_fixed_var_justifications(r2, ante); explain_fixed_in_row(row_id, ex);
TRACE("arith_eq", tout << "propagate eq two rows:\n"; TRACE("arith_eq", tout << "propagate eq two rows:\n";
tout << "swapped: " << swapped << "\n"; tout << "swapped: " << swapped << "\n";
tout << "x : v" << x << "\n"; tout << "x : v" << x << "\n";
tout << "x2 : v" << x2 << "\n"; tout << "x2 : v" << x2 << "\n";
display_row_info(tout, r); display_row_info(rid, tout);
display_row_info(tout, r2);); display_row_info(row_id, tout););
m_stats.m_offset_eqs++; add_eq_on_columns(ex, x, x2);
propagate_eq_to_core(x, x2, ante);
} }
return; return;
} }
}*/ }
// the original row was delete or it is not offset row anymore ===> remove it from table // the original row was deleted or it is not offset row anymore ===> remove it from table
m_var_offset2row_id.erase(key); m_var_offset2row_id.erase(key);
} }
// add new entry // add new entry
@ -423,11 +519,11 @@ public:
TRACE("cheap_eq", tout << "row_index = " << row_index << "\n";); TRACE("cheap_eq", tout << "row_index = " << row_index << "\n";);
clear_for_eq(); clear_for_eq();
unsigned x_index, y_index; unsigned x_index, y_index;
impq offset; mpq offset;
if (!is_offset_row(row_index, x_index, y_index, offset)) if (!is_offset_row_wrong(row_index, x_index, y_index, offset))
return; return;
TRACE("cheap_eq", lp().get_int_solver()->display_row_info(tout, row_index);); TRACE("cheap_eq", lp().get_int_solver()->display_row_info(tout, row_index););
m_root = alloc(vertex, row_index, x_index, impq(0)); m_root = alloc(vertex, row_index, x_index, mpq(0));
vertex* v_y = alloc(vertex, row_index, y_index, offset); vertex* v_y = alloc(vertex, row_index, y_index, offset);
m_root->add_child(v_y); m_root->add_child(v_y);
SASSERT(tree_is_correct()); SASSERT(tree_is_correct());
@ -466,8 +562,8 @@ public:
continue; continue;
m_visited_rows.insert(row_index); m_visited_rows.insert(row_index);
unsigned x_index, y_index; unsigned x_index, y_index;
impq row_offset; mpq row_offset;
if (!is_offset_row(row_index, x_index, y_index, row_offset)) if (!is_offset_row_wrong(row_index, x_index, y_index, row_offset))
continue; continue;
TRACE("cheap_eq", lp().get_int_solver()->display_row_info(tout, row_index);); TRACE("cheap_eq", lp().get_int_solver()->display_row_info(tout, row_index););
// who is it the same column? // who is it the same column?

View file

@ -259,6 +259,8 @@ struct numeric_pair {
bool is_neg() const { return x.is_neg() || (x.is_zero() && y.is_neg());} bool is_neg() const { return x.is_neg() || (x.is_zero() && y.is_neg());}
void neg() { x.neg(); y.neg(); }
std::string to_string() const { std::string to_string() const {
return std::string("(") + T_to_string(x) + ", " + T_to_string(y) + ")"; return std::string("(") + T_to_string(x) + ", " + T_to_string(y) + ")";
} }

View file

@ -223,8 +223,7 @@ namespace smt {
theory_var x; theory_var x;
theory_var y; theory_var y;
numeral k; numeral k;
if (is_offset_row(r, x, y, k)) { if (is_offset_row(r, x, y, k)) {
if (y == null_theory_var) { if (y == null_theory_var) {
// x is an implied fixed var at k. // x is an implied fixed var at k.
value_sort_pair key(k, is_int_src(x)); value_sort_pair key(k, is_int_src(x));

View file

@ -1582,8 +1582,8 @@ public:
} }
m_variable_values[t.index()] = result; m_variable_values[t.index()] = result;
return result; return result;
} }
void init_variable_values() { void init_variable_values() {
reset_variable_values(); reset_variable_values();
if (m.inc() && m_solver.get() && th.get_num_vars() > 0) { if (m.inc() && m_solver.get() && th.get_num_vars() > 0) {
@ -1596,19 +1596,6 @@ public:
m_variable_values.clear(); m_variable_values.clear();
} }
bool congruent_or_irrelevant(lpvar k, lpvar j) {
theory_var kv = lp().local_to_external(k);
if (kv == null_theory_var)
return true;
theory_var jv = lp().local_to_external(j);
if (jv == null_theory_var)
return true;
enode * n0 = get_enode(kv);
enode * n1 = get_enode(jv);
return n0->get_root() == n1->get_root();
}
void random_update() { void random_update() {
if (m_nla) if (m_nla)
return; return;
@ -3203,6 +3190,8 @@ public:
return get_enode(x)->get_root() == get_enode(y)->get_root(); return get_enode(x)->get_root() == get_enode(y)->get_root();
} }
unsigned get_num_vars() const { return th.get_num_vars(); }
void fixed_var_eh(theory_var v1, rational const& bound) { void fixed_var_eh(theory_var v1, rational const& bound) {
// IF_VERBOSE(0, verbose_stream() << "fix " << mk_bounded_pp(get_owner(v1), m) << " " << bound << "\n"); // IF_VERBOSE(0, verbose_stream() << "fix " << mk_bounded_pp(get_owner(v1), m) << " " << bound << "\n");