From de8faa231f622698b8251968bbf4373c98574615 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 16 Aug 2024 16:48:12 -0700 Subject: [PATCH] fixes to ite and other Signed-off-by: Nikolaj Bjorner --- src/ast/sls/sls_arith_base.cpp | 87 ++++++++++++++++++++++++-------- src/ast/sls/sls_arith_base.h | 1 + src/ast/sls/sls_arith_plugin.cpp | 6 ++- src/ast/sls/sls_basic_plugin.cpp | 7 ++- src/ast/sls/sls_context.cpp | 19 +++++-- src/ast/sls/sls_context.h | 2 +- src/ast/sls/sls_smt_solver.cpp | 30 ++++------- src/tactic/sls/sls_tactic.cpp | 1 + src/util/checked_int64.h | 5 +- 9 files changed, 108 insertions(+), 50 deletions(-) diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 5061cd46b..2fd0ebe33 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -42,6 +42,7 @@ Done: #include "ast/sls/sls_arith_base.h" #include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" namespace sls { @@ -51,7 +52,7 @@ namespace sls { case ineq_kind::LE: return m_args_value <= 0; case ineq_kind::EQ: - return m_args_value== 0; + return m_args_value == 0; default: return m_args_value < 0; } @@ -431,7 +432,7 @@ namespace sls { if (m_last_var == v && m_last_delta == -delta) return false; - if (false && m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) + if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) return false; auto old_value = value(v); @@ -625,7 +626,7 @@ namespace sls { return false; } - IF_VERBOSE(10, display(verbose_stream(), v) << " := " << new_value << "\n"); + // IF_VERBOSE(0, display(verbose_stream(), v) << " := " << new_value << "\n"); @@ -965,10 +966,9 @@ namespace sls { template void arith_base::init_bool_var(sat::bool_var bv) { + expr* e = ctx.atom(bv); if (m_bool_vars.get(bv, nullptr)) return; - expr* e = ctx.atom(bv); - // verbose_stream() << "bool var " << bv << " " << mk_bounded_pp(e, m) << "\n"; if (!e) return; expr* x, * y; @@ -1069,7 +1069,7 @@ namespace sls { // attach i to bv m_bool_vars.set(bv, &i); - } + } template void arith_base::init_bool_var_assignment(sat::bool_var v) { @@ -1209,6 +1209,7 @@ namespace sls { auto const& vi = m_vars[v]; if (vi.m_lo || vi.m_hi) continue; + expr* e = vi.m_expr; if (is_add(v)) { auto const& ad = get_add(v); num_t lo(ad.m_coeff), hi(ad.m_coeff); @@ -1261,10 +1262,11 @@ namespace sls { if (!lo_valid && !hi_valid) break; auto const& wi = m_vars[w]; - if (lo_valid) { - // TODO + if (wi.m_lo && !wi.m_lo->is_strict && wi.m_lo->value >= 0) + lo *= power_of(value(w), p); + else lo_valid = false; - } + if (hi_valid) { // TODO hi_valid = false; @@ -1283,6 +1285,20 @@ namespace sls { add_le(v, hi); } } + expr* c, * th, * el; + if (m.is_ite(e, c, th, el)) { + auto vth = m_expr2var.get(th->get_id(), UINT_MAX); + auto vel = m_expr2var.get(el->get_id(), UINT_MAX); + if (vth == UINT_MAX || vel == UINT_MAX) + continue; + auto const& vith = m_vars[vth]; + auto const& viel = m_vars[vel]; + if (vith.m_lo && viel.m_lo && !vith.m_lo->is_strict && !viel.m_lo->is_strict) + add_ge(v, std::min(vith.m_lo->value, viel.m_lo->value)); + if (vith.m_hi && viel.m_hi && !vith.m_hi->is_strict && !viel.m_hi->is_strict) + add_le(v, std::max(vith.m_hi->value, viel.m_hi->value)); + + } // TBD: can also do with other operators. } } @@ -1296,7 +1312,7 @@ namespace sls { if (ineq->m_args.size() != 1) return; - auto [c, v] = ineq->m_args[0]; + auto [c, v] = ineq->m_args[0]; switch (ineq->m_op) { case ineq_kind::LE: @@ -1396,17 +1412,26 @@ namespace sls { if (old_value == sum) return true; + //display(verbose_stream() << "repair add v" << v << " ", ad) << " " << old_value << " sum " << sum << "\n"; + m_updates.reset(); // display(verbose_stream(), v) << " "; // verbose_stream() << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << old_value << " " << sum << "\n"; - for (auto const& [coeff, w] : coeffs) - add_update(v, divide(w, old_value - sum, coeff)); - + for (auto const& [coeff, w] : coeffs) { + auto delta = divide(w, sum - old_value, coeff); + if (sum == coeff*delta + old_value) + add_update(w, delta); + } if (apply_update()) return eval_is_correct(v); m_updates.reset(); + for (auto const& [coeff, w] : coeffs) { + auto delta = divide(w, sum - old_value, coeff); + if (sum != coeff*delta + old_value) + add_update(w, delta); + } for (auto const& [coeff, w] : coeffs) if (is_mul(w)) { auto const& [w1, c, monomial] = get_mul(w); @@ -1433,7 +1458,6 @@ namespace sls { if (product == val) return true; - m_updates.reset(); if (val == 0) { for (auto [x, p] : monomial) @@ -1459,6 +1483,8 @@ namespace sls { } } + // verbose_stream() << "repair product v" << v << "\n"; + if (apply_update()) return eval_is_correct(v); @@ -1872,6 +1898,23 @@ namespace sls { return true; } + template + std::ostream& arith_base::display(std::ostream& out, mul_def const& md) const { + auto const& [w, coeff, monomial] = md; + bool first = true; + if (coeff != 1) + out << coeff, first = false; + for (auto [v, p] : monomial) { + if (!first) + out << " * "; + out << "v" << v; + if (p > 1) + out << "^" << p; + first = false; + } + return out; + } + template std::ostream& arith_base::display(std::ostream& out, add_def const& ad) const { bool first = true; @@ -1893,7 +1936,9 @@ namespace sls { first = false; out << "v" << w; } - if (ad.m_coeff > 0) + if (ad.m_args.empty()) + out << ad.m_coeff; + else if (ad.m_coeff > 0) out << " + " << ad.m_coeff; else if (ad.m_coeff < 0) out << " - " << -ad.m_coeff; @@ -1921,6 +1966,8 @@ namespace sls { out << mk_bounded_pp(vi.m_expr, m) << " "; if (is_add(v)) display(out << "add: ", get_add(v)) << " "; + if (is_mul(v)) + display(out << "mul: ", get_mul(v)) << " "; if (!vi.m_adds.empty()) { out << " adds: "; @@ -1936,10 +1983,11 @@ namespace sls { out << " "; } - if (!vi.m_bool_vars.empty()) + if (!vi.m_bool_vars.empty()) { out << " bool: "; - for (auto [c, bv] : vi.m_bool_vars) - out << c << "@" << bv << " "; + for (auto [c, bv] : vi.m_bool_vars) + out << c << "@" << bv << " "; + } return out; } @@ -1964,8 +2012,6 @@ namespace sls { out << "\n"; } - for (auto ad : m_adds) - display(out, ad) << "\n"; for (auto od : m_ops) { out << "v" << od.m_var << " := "; @@ -2077,6 +2123,7 @@ namespace sls { out << "v" << ad.m_var << " := "; display(out, ad) << "\n"; } + UNREACHABLE(); } } diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index 5c0988a99..f219c2d8b 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -250,6 +250,7 @@ namespace sls { void add_gt(var_t v, num_t const& n); std::ostream& display(std::ostream& out, var_t v) const; std::ostream& display(std::ostream& out, add_def const& ad) const; + std::ostream& display(std::ostream& out, mul_def const& md) const; public: arith_base(context& ctx); ~arith_base() override {} diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp index 120930c02..f48c85642 100644 --- a/src/ast/sls/sls_arith_plugin.cpp +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -49,7 +49,11 @@ namespace sls { plugin(ctx), m_shared(ctx.get_manager()) { m_arith64 = alloc(arith_base>, ctx); m_arith = alloc(arith_base, ctx); - m_fid = m_arith->fid(); + m_arith64 = nullptr; + if (m_arith) + m_fid = m_arith->fid(); + else + m_fid = m_arith64->fid(); } void arith_plugin::init_backup() { diff --git a/src/ast/sls/sls_basic_plugin.cpp b/src/ast/sls/sls_basic_plugin.cpp index 712ce742a..bb876c44f 100644 --- a/src/ast/sls/sls_basic_plugin.cpp +++ b/src/ast/sls/sls_basic_plugin.cpp @@ -29,7 +29,7 @@ namespace sls { bool basic_plugin::is_basic(expr* e) const { if (!e || !is_app(e)) return false; - if (m.is_ite(e) && !m.is_bool(e)) + if (m.is_ite(e) && !m.is_bool(e) && false) return true; if (m.is_xor(e) && to_app(e)->get_num_args() != 2) return true; @@ -42,6 +42,11 @@ namespace sls { } void basic_plugin::register_term(expr* e) { + expr* c, * th, * el; + if (m.is_ite(e, c, th, el) && !m.is_bool(e)) { + ctx.add_clause(m.mk_or(mk_not(m, c), m.mk_eq(e, th))); + ctx.add_clause(m.mk_or(c, m.mk_eq(e, el))); + } } void basic_plugin::initialize() { diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp index 4d3704db0..c08f231d8 100644 --- a/src/ast/sls/sls_context.cpp +++ b/src/ast/sls/sls_context.cpp @@ -48,7 +48,7 @@ namespace sls { } void context::register_atom(sat::bool_var v, expr* e) { - m_atoms.setx(v, e); + m_atoms.setx(v, e); m_atom2bool_var.setx(e->get_id(), v, sat::null_bool_var); } @@ -225,7 +225,6 @@ namespace sls { void context::add_clause(expr* f) { expr_ref _e(f, m); - verbose_stream() << "add constraint " << _e << "\n"; expr* g, * h, * k; sat::literal_vector clause; if (m.is_not(f, g) && m.is_not(g, g)) { @@ -289,6 +288,7 @@ namespace sls { } sat::literal context::mk_literal(expr* e) { + expr_ref _e(e, m); sat::literal lit; bool neg = false; expr* a, * b, * c; @@ -299,6 +299,7 @@ namespace sls { return sat::literal(v, neg); sat::literal_vector clause; lit = mk_literal(); + register_atom(lit.var(), e); if (m.is_true(e)) { clause.push_back(lit); s.add_clause(clause.size(), clause.data()); @@ -355,9 +356,7 @@ namespace sls { s.add_clause(3, cls4); } else - register_terms(e); - - register_atom(lit.var(), e); + register_terms(e); return neg ? ~lit : lit; } @@ -392,6 +391,8 @@ namespace sls { return; m_subterms.reset(); m_todo.push_back(e); + if (m_todo.size() > 1) + return; while (!m_todo.empty()) { expr* e = m_todo.back(); if (is_visited(e)) @@ -402,6 +403,8 @@ namespace sls { m_parents.reserve(arg->get_id() + 1); m_parents[arg->get_id()].push_back(e); } + if (m.is_bool(e)) + mk_literal(e); register_term(e); visit(e); m_todo.pop_back(); @@ -494,6 +497,12 @@ namespace sls { out << "d " << mk_bounded_pp(term(id), m) << "\n"; for (auto id : m_repair_up) out << "u " << mk_bounded_pp(term(id), m) << "\n"; + for (unsigned v = 0; v < m_atoms.size(); ++v) { + auto e = m_atoms[v]; + if (e) + out << v << ": " << mk_bounded_pp(e, m) << " := " << (is_true(v)?"T":"F") << "\n"; + + } for (auto p : m_plugins) if (p) p->display(out); diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h index 52fc8133c..0a7d7bb5b 100644 --- a/src/ast/sls/sls_context.h +++ b/src/ast/sls/sls_context.h @@ -141,7 +141,7 @@ namespace sls { double get_weight(unsigned clause_idx) { return s.get_weigth(clause_idx); } unsigned num_bool_vars() const { return s.num_vars(); } bool is_true(sat::literal lit) { return s.is_true(lit); } - bool is_true(sat::bool_var v) { return s.is_true(sat::literal(v, false)); } + bool is_true(sat::bool_var v) const { return s.is_true(sat::literal(v, false)); } expr* atom(sat::bool_var v) { return m_atoms.get(v, nullptr); } expr* term(unsigned id) const { return m_allterms.get(id); } sat::bool_var atom2bool_var(expr* e) const { return m_atom2bool_var.get(e->get_id(), sat::null_bool_var); } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index 34dd06e63..d1f99d149 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -28,6 +28,7 @@ namespace sls { sat::ddfw& m_ddfw; context m_context; bool m_dirty = false; + bool m_new_constraint = false; model_ref m_model; obj_map m_expr2lit; public: @@ -55,11 +56,11 @@ namespace sls { TRACE("sls", display(tout)); while (unsat().empty()) { m_context.check(); - if (!m_dirty) + if (!m_new_constraint) break; TRACE("sls", display(tout)); m_ddfw.reinit(); - m_dirty = false; + m_new_constraint = false; } } @@ -87,16 +88,12 @@ namespace sls { bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); } unsigned num_vars() const override { return m_ddfw.num_vars(); } indexed_uint_set const& unsat() const override { return m_ddfw.unsat_set(); } - sat::bool_var add_var() override { m_dirty = true; return m_ddfw.add_var(); } - - - void add_clause(expr* f) { - m_context.add_clause(f); - } + sat::bool_var add_var() override { m_dirty = true; return m_ddfw.add_var(); } + void add_clause(expr* f) { m_context.add_clause(f); } void add_clause(unsigned n, sat::literal const* lits) override { m_ddfw.add(n, lits); - m_dirty = true; + m_new_constraint = true; } sat::literal mk_literal() { @@ -124,8 +121,7 @@ namespace sls { m_ddfw.updt_params(p); } - smt_solver::~smt_solver() { - + smt_solver::~smt_solver() { } void smt_solver::assert_expr(expr* e) { @@ -137,17 +133,11 @@ namespace sls { m_assertions.push_back(e); } - lbool smt_solver::check() { - // send clauses to ddfw - // send expression mapping to m_solver_ctx - + lbool smt_solver::check() { for (auto f : m_assertions) - m_solver_ctx->add_clause(f); - + m_solver_ctx->add_clause(f); IF_VERBOSE(10, m_solver_ctx->display(verbose_stream())); - auto r = m_ddfw.check(0, nullptr); - - return r; + return m_ddfw.check(0, nullptr); } model_ref smt_solver::get_model() { diff --git a/src/tactic/sls/sls_tactic.cpp b/src/tactic/sls/sls_tactic.cpp index a0f5f3b76..fdb4620c8 100644 --- a/src/tactic/sls/sls_tactic.cpp +++ b/src/tactic/sls/sls_tactic.cpp @@ -82,6 +82,7 @@ public: m_sls->collect_statistics(m_st); throw; } + m_sls->collect_statistics(m_st); // report_tactic_progress("Number of flips:", m_sls->get_num_moves()); IF_VERBOSE(10, verbose_stream() << res << "\n"); diff --git a/src/util/checked_int64.h b/src/util/checked_int64.h index 37f7c5526..31ef5bdd6 100644 --- a/src/util/checked_int64.h +++ b/src/util/checked_int64.h @@ -166,10 +166,11 @@ public: uint64_t r = x * y; if ((y != 0 && r / y != x) || r > INT64_MAX) throw overflow_exception(); + int64_t old_value = m_value; m_value = r; - if (m_value < 0 && other.m_value > 0) + if (old_value < 0 && other.m_value > 0) m_value = -m_value; - if (m_value > 0 && other.m_value < 0) + else if (old_value > 0 && other.m_value < 0) m_value = -m_value; } }