From 8a49002f6089d59abba370f029dda2a9180ba15f Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 25 Aug 2024 18:33:01 -0700 Subject: [PATCH] reorg monomials Signed-off-by: Nikolaj Bjorner --- src/ast/sls/sls_arith_base.cpp | 34 ++++++++++++++++++++++++++++------ src/ast/sls/sls_arith_base.h | 6 +++++- src/ast/sls/sls_context.cpp | 12 ++++++++++++ src/ast/sls/sls_context.h | 3 +++ src/ast/sls/sls_smt_solver.cpp | 12 ++++++++++-- 5 files changed, 58 insertions(+), 9 deletions(-) diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 6fab41a98..4dd6b62e6 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -63,8 +63,23 @@ namespace sls { template std::ostream& arith_base::ineq::display(std::ostream& out) const { bool first = true; - for (auto const& [c, v] : this->m_args) - out << (first ? "" : " + ") << c << " * v" << v, first = false; + unsigned j = 0; + for (auto const& [c, v] : this->m_args) { + out << (first ? (c > 0 ? "" : "-") : (c > 0 ? " + " : " - ")); + bool first2 = abs(c) == 1; + if (abs(c) != 1) + out << abs(c); + auto const& m = this->m_monomials[j]; + + for (auto [w, p] : m) { + out << (first2 ? "" : " * ") << "v" << w; + if (p > 1) + out << "^" << p; + first2 = false; + } + first = false; + ++j; + } if (this->m_coeff != 0) out << " + " << this->m_coeff; switch (m_op) { @@ -78,15 +93,17 @@ namespace sls { out << " < " << 0 << "(" << m_args_value << ")"; break; } +#if 0 for (auto const& [x, nl] : this->m_nonlinear) { if (nl.size() == 1 && nl[0].v == x) continue; for (auto const& [v, c, p] : nl) { out << " v" << x; if (p > 1) out << "^" << p; - out << " in " << c << " * v" << v; + out << " in v" << v; } } +#endif return out; } @@ -1058,6 +1075,14 @@ namespace sls { i.m_args[k++] = i.m_args[j]; } i.m_args.shrink(k); + i.m_monomials.reserve(k); + for (unsigned j = 0; j < i.m_args.size(); ++j) { + auto const& [c, v] = i.m_args[j]; + if (is_mul(v)) + i.m_monomials[j].append(get_mul(v).m_monomial); + else + i.m_monomials[j].push_back({ v, 1 }); + } // compute the value of the linear term, and accumulate non-linear sub-terms i.m_args_value = i.m_coeff; for (auto const& [coeff, v] : i.m_args) { @@ -1896,9 +1921,6 @@ namespace sls { template void arith_base::on_restart() { - for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) - init_bool_var_assignment(v); - check_ineqs(); } template diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index f42d3b318..7fed38786 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -58,9 +58,13 @@ namespace sls { num_t coeff; // coeff of v in inequality unsigned p; // power }; + + typedef svector> monomial_t; + // encode args <= bound, args = bound, args < bound struct ineq : public linear_term { vector>> m_nonlinear; + vector m_monomials; ineq_kind m_op = ineq_kind::LE; num_t m_args_value; bool m_is_linear = true; @@ -120,7 +124,7 @@ namespace sls { struct mul_def { unsigned m_var; - svector> m_monomial; + monomial_t m_monomial; }; struct add_def : public linear_term { diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index fdf364ed5..62c206812 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -22,6 +22,7 @@ Author: #include "ast/sls/sls_basic_plugin.h" #include "ast/ast_ll_pp.h" #include "ast/ast_pp.h" +#include "smt/params/smt_params_helper.hpp" namespace sls { @@ -42,6 +43,11 @@ namespace sls { register_plugin(alloc(basic_plugin, *this)); } + void context::updt_params(params_ref const& p) { + smt_params_helper smtp(p); + m_rand.set_seed(smtp.random_seed()); + } + void context::register_plugin(plugin* p) { m_plugins.reserve(p->fid() + 1); m_plugins.set(p->fid(), p); @@ -51,6 +57,12 @@ namespace sls { m_atoms.setx(v, e); m_atom2bool_var.setx(e->get_id(), v, sat::null_bool_var); } + + void context::on_restart() { + for (auto p : m_plugins) + if (p) + p->on_restart(); + } lbool context::check() { // diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index 599f25e77..b0afcdde5 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -137,6 +137,9 @@ namespace sls { void register_atom(sat::bool_var v, expr* e); lbool check(); + void on_restart(); + void updt_params(params_ref const& p); + // expose sat_solver to plugins vector const& clauses() const { return s.clauses(); } sat::clause_info const& get_clause(unsigned idx) const { return s.get_clause(idx); } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index 1960b809c..c0e242685 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -46,7 +46,9 @@ namespace sls { void on_rescale() override {} - void on_restart() override {} + void on_restart() override { + m_context.on_restart(); + } bool m_on_save_model = false; void on_save_model() override { @@ -114,13 +116,19 @@ namespace sls { m_ddfw.reset_statistics(); m_context.reset_statistics(); } + + void updt_params(params_ref const& p) { + m_ddfw.updt_params(p); + m_context.updt_params(p); + } }; smt_solver::smt_solver(ast_manager& m, params_ref const& p): m(m), m_solver_ctx(alloc(solver_ctx, m, m_ddfw)), m_assertions(m) { - m_ddfw.updt_params(p); + + m_solver_ctx->updt_params(p); } smt_solver::~smt_solver() {