3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 17:44:08 +00:00
ensure substitutions are applied to eliminate internal variables from results
This commit is contained in:
Nikolaj Bjorner 2022-10-20 13:14:54 -07:00
parent 5976978062
commit edad727cd5
3 changed files with 90 additions and 17 deletions

View file

@ -113,10 +113,16 @@ namespace opt {
return result;
}
model_based_opt::def model_based_opt::def::substitute(unsigned v, def const& other) const {
def result;
/**
a1*x1 + a2*x2 + a3*x3 + coeff1 / c1
x2 |-> b1*x1 + b4*x4 + ceoff2 / c2
------------------------------------------------------------------------
(a1*x1 + a2*((b1*x1 + b4*x4 + coeff2) / c2) + a3*x3 + coeff1) / c1
------------------------------------------------------------------------
(c2*a1*x1 + a2*b1*x1 + a2*b4*x4 + c2*a3*x3 + c2*coeff1 + coeff2) / c1*c2
*/
void model_based_opt::def::substitute(unsigned v, def const& other) {
vector<var> const& vs1 = m_vars;
// vector<var> const& vs2 = other.m_vars;
rational coeff(0);
for (auto const& [id, c] : vs1) {
if (id == v) {
@ -124,13 +130,46 @@ namespace opt {
break;
}
}
if (coeff == 0) {
return *this;
}
if (coeff == 0)
return;
NOT_IMPLEMENTED_YET();
result.normalize();
return result;
rational c1 = m_div;
rational c2 = other.m_div;
vector<var> const& vs2 = other.m_vars;
vector<var> vs;
unsigned i = 0, j = 0;
while (i < vs1.size() || j < vs2.size()) {
unsigned v1 = UINT_MAX, v2 = UINT_MAX;
if (i < vs1.size()) v1 = vs1[i].m_id;
if (j < vs2.size()) v2 = vs2[j].m_id;
if (v1 == v)
++i;
else if (v1 == v2) {
vs.push_back(vs1[i]);
vs.back().m_coeff *= c2;
vs.back().m_coeff += coeff * vs2[j].m_coeff;
++i; ++j;
if (vs.back().m_coeff.is_zero())
vs.pop_back();
}
else if (v1 < v2) {
vs.push_back(vs1[i]);
vs.back().m_coeff *= c2;
++i;
}
else {
vs.push_back(vs2[j]);
vs.back().m_coeff *= coeff;
++j;
}
}
m_div *= other.m_div;
m_coeff *= c2;
m_coeff += coeff*other.m_coeff;
m_vars.reset();
m_vars.append(vs);
normalize();
}
model_based_opt::def model_based_opt::def::operator/(rational const& r) const {
@ -1436,15 +1475,20 @@ namespace opt {
}
for (unsigned v : vs)
project(v, false);
for (unsigned v : vs) {
def v_def = project(v, false);
if (compute_def)
eliminate(v, v_def);
}
// project internal variables.
def z_def = project(z, compute_def);
def y_def = project(y, compute_def); // may depend on z
if (compute_def) {
z_def.substitute(y, y_def);
eliminate(y, y_def);
eliminate(z, z_def);
result = (y_def * K) + z_def;
m_var2value[x] = eval(result);
@ -1650,14 +1694,20 @@ namespace opt {
TRACE("opt", display(tout << "solved v" << x << "\n"));
return result;
}
void model_based_opt::eliminate(unsigned v, def const& new_def) {
for (auto & d : m_result)
d.substitute(v, new_def);
}
vector<model_based_opt::def> model_based_opt::project(unsigned num_vars, unsigned const* vars, bool compute_def) {
vector<def> result;
m_result.reset();
for (unsigned i = 0; i < num_vars; ++i) {
result.push_back(project(vars[i], compute_def));
m_result.push_back(project(vars[i], compute_def));
eliminate(vars[i], m_result.back());
TRACE("opt", display(tout << "After projecting: v" << vars[i] << "\n"););
}
return result;
return m_result;
}
}

View file

@ -86,7 +86,7 @@ namespace opt {
def operator/(rational const& n) const;
def operator*(rational const& n) const;
def operator+(rational const& n) const;
def substitute(unsigned v, def const& other) const;
void substitute(unsigned v, def const& other);
void normalize();
};
@ -101,6 +101,9 @@ namespace opt {
unsigned_vector m_lub, m_glb, m_divides, m_mod, m_div;
unsigned_vector m_above, m_below;
unsigned_vector m_retired_rows;
vector<model_based_opt::def> m_result;
void eliminate(unsigned v, def const& d);
bool invariant();
bool invariant(unsigned index, row const& r);

View file

@ -392,9 +392,29 @@ static void test11() {
}
static void test12() {
opt::model_based_opt::def d1, d2, d3, d4;
typedef opt::model_based_opt::var var;
d1.m_vars.push_back(var(1, rational(4)));
d1.m_vars.push_back(var(2, rational(3)));
d1.m_vars.push_back(var(3, rational(5)));
d1.m_coeff = rational(8);
d1.m_div = rational(7);
std::cout << d1 << "\n";
d2.m_vars.push_back(var(3, rational(2)));
d2.m_vars.push_back(var(4, rational(2)));
d2.m_div = rational(3);
d2.m_coeff = rational(5);
std::cout << d2 << "\n";
d1.substitute(2, d2);
std::cout << d1 << "\n";
}
// test with mix of upper and lower bounds
void tst_model_based_opt() {
test12();
return;
test10();
check_random_ineqs();
test1();