diff --git a/src/math/simplex/model_based_opt.cpp b/src/math/simplex/model_based_opt.cpp index cb3a5be4d..ac7e89d5b 100644 --- a/src/math/simplex/model_based_opt.cpp +++ b/src/math/simplex/model_based_opt.cpp @@ -113,10 +113,16 @@ namespace opt { return result; } - model_based_opt::def model_based_opt::def::substitute(unsigned v, def const& other) const { - def result; + /** + a1*x1 + a2*x2 + a3*x3 + coeff1 / c1 + x2 |-> b1*x1 + b4*x4 + ceoff2 / c2 + ------------------------------------------------------------------------ + (a1*x1 + a2*((b1*x1 + b4*x4 + coeff2) / c2) + a3*x3 + coeff1) / c1 + ------------------------------------------------------------------------ + (c2*a1*x1 + a2*b1*x1 + a2*b4*x4 + c2*a3*x3 + c2*coeff1 + coeff2) / c1*c2 + */ + void model_based_opt::def::substitute(unsigned v, def const& other) { vector const& vs1 = m_vars; - // vector const& vs2 = other.m_vars; rational coeff(0); for (auto const& [id, c] : vs1) { if (id == v) { @@ -124,13 +130,46 @@ namespace opt { break; } } - if (coeff == 0) { - return *this; - } + if (coeff == 0) + return; - NOT_IMPLEMENTED_YET(); - result.normalize(); - return result; + rational c1 = m_div; + rational c2 = other.m_div; + + vector const& vs2 = other.m_vars; + vector vs; + unsigned i = 0, j = 0; + while (i < vs1.size() || j < vs2.size()) { + unsigned v1 = UINT_MAX, v2 = UINT_MAX; + if (i < vs1.size()) v1 = vs1[i].m_id; + if (j < vs2.size()) v2 = vs2[j].m_id; + if (v1 == v) + ++i; + else if (v1 == v2) { + vs.push_back(vs1[i]); + vs.back().m_coeff *= c2; + vs.back().m_coeff += coeff * vs2[j].m_coeff; + ++i; ++j; + if (vs.back().m_coeff.is_zero()) + vs.pop_back(); + } + else if (v1 < v2) { + vs.push_back(vs1[i]); + vs.back().m_coeff *= c2; + ++i; + } + else { + vs.push_back(vs2[j]); + vs.back().m_coeff *= coeff; + ++j; + } + } + m_div *= other.m_div; + m_coeff *= c2; + m_coeff += coeff*other.m_coeff; + m_vars.reset(); + m_vars.append(vs); + normalize(); } model_based_opt::def model_based_opt::def::operator/(rational const& r) const { @@ -1436,15 +1475,20 @@ namespace opt { } - for (unsigned v : vs) - project(v, false); - + for (unsigned v : vs) { + def v_def = project(v, false); + if (compute_def) + eliminate(v, v_def); + } + // project internal variables. def z_def = project(z, compute_def); def y_def = project(y, compute_def); // may depend on z - if (compute_def) { + z_def.substitute(y, y_def); + eliminate(y, y_def); + eliminate(z, z_def); result = (y_def * K) + z_def; m_var2value[x] = eval(result); @@ -1650,14 +1694,20 @@ namespace opt { TRACE("opt", display(tout << "solved v" << x << "\n")); return result; } + + void model_based_opt::eliminate(unsigned v, def const& new_def) { + for (auto & d : m_result) + d.substitute(v, new_def); + } vector model_based_opt::project(unsigned num_vars, unsigned const* vars, bool compute_def) { - vector result; + m_result.reset(); for (unsigned i = 0; i < num_vars; ++i) { - result.push_back(project(vars[i], compute_def)); + m_result.push_back(project(vars[i], compute_def)); + eliminate(vars[i], m_result.back()); TRACE("opt", display(tout << "After projecting: v" << vars[i] << "\n");); } - return result; + return m_result; } } diff --git a/src/math/simplex/model_based_opt.h b/src/math/simplex/model_based_opt.h index 8150f945f..35516283d 100644 --- a/src/math/simplex/model_based_opt.h +++ b/src/math/simplex/model_based_opt.h @@ -86,7 +86,7 @@ namespace opt { def operator/(rational const& n) const; def operator*(rational const& n) const; def operator+(rational const& n) const; - def substitute(unsigned v, def const& other) const; + void substitute(unsigned v, def const& other); void normalize(); }; @@ -101,6 +101,9 @@ namespace opt { unsigned_vector m_lub, m_glb, m_divides, m_mod, m_div; unsigned_vector m_above, m_below; unsigned_vector m_retired_rows; + vector m_result; + + void eliminate(unsigned v, def const& d); bool invariant(); bool invariant(unsigned index, row const& r); diff --git a/src/test/model_based_opt.cpp b/src/test/model_based_opt.cpp index b307f85e4..e2dc74db0 100644 --- a/src/test/model_based_opt.cpp +++ b/src/test/model_based_opt.cpp @@ -392,9 +392,29 @@ static void test11() { } +static void test12() { + opt::model_based_opt::def d1, d2, d3, d4; + typedef opt::model_based_opt::var var; + d1.m_vars.push_back(var(1, rational(4))); + d1.m_vars.push_back(var(2, rational(3))); + d1.m_vars.push_back(var(3, rational(5))); + d1.m_coeff = rational(8); + d1.m_div = rational(7); + std::cout << d1 << "\n"; + d2.m_vars.push_back(var(3, rational(2))); + d2.m_vars.push_back(var(4, rational(2))); + d2.m_div = rational(3); + d2.m_coeff = rational(5); + std::cout << d2 << "\n"; + d1.substitute(2, d2); + std::cout << d1 << "\n"; +} + // test with mix of upper and lower bounds void tst_model_based_opt() { + test12(); + return; test10(); check_random_ineqs(); test1();