From 1cd95e9db495a10c346a38bca27c3ce05272bd37 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 14 Jul 2024 16:51:06 -0700 Subject: [PATCH] add sls-sms solver --- src/ast/sls/sls_basic_plugin.cpp | 39 +++++---- src/ast/sls/sls_smt_solver.cpp | 146 +++++++++++++++++++++++++++++++ src/ast/sls/sls_smt_solver.h | 46 ++++++++++ 3 files changed, 212 insertions(+), 19 deletions(-) create mode 100644 src/ast/sls/sls_smt_solver.cpp create mode 100644 src/ast/sls/sls_smt_solver.h diff --git a/src/ast/sls/sls_basic_plugin.cpp b/src/ast/sls/sls_basic_plugin.cpp index 278fa55b9..ef462a9cb 100644 --- a/src/ast/sls/sls_basic_plugin.cpp +++ b/src/ast/sls/sls_basic_plugin.cpp @@ -28,7 +28,8 @@ namespace sls { auto a = ctx.atom(lit.var()); if (!a || !is_app(a)) return; - SASSERT(to_app(a)->get_family_id() != basic_family_id); + if (to_app(a)->get_family_id() != basic_family_id) + return; if (bval1(to_app(a)) != bval0(to_app(a))) ctx.new_value_eh(a); } @@ -62,10 +63,7 @@ namespace sls { void basic_plugin::set_value(expr* e, expr* v) { if (!m.is_bool(e)) return; - SASSERT(m.is_bool(v)); SASSERT(m.is_true(v) || m.is_false(v)); - if (bval0(e) != m.is_true(v)) - return; set_value(e, m.is_true(v)); } @@ -198,16 +196,21 @@ namespace sls { } bool basic_plugin::try_repair_ite(app* e, unsigned i) { - auto child = e->get_arg(i); - bool c = bval0(e->get_arg(0)); - if (i == 0) - return set_value(child, !c); - - if (c != (i == 1)) + if (!m.is_bool(e)) return false; - if (m.is_bool(e)) - return set_value(child, bval0(e)); - return false; + auto child = e->get_arg(i); + auto cond = e->get_arg(0); + bool c = bval0(cond); + if (i == 0) { + if (ctx.rand(2) == 0) + return set_value(cond, true) && set_value(e->get_arg(1), bval0(e)); + else + return set_value(cond, false) && set_value(e->get_arg(2), bval0(e)); + } + + if (!set_value(child, bval0(e))) + return false; + return (c == (i == 1)) || set_value(cond, !c); } bool basic_plugin::try_repair_implies(app* e, unsigned i) { @@ -234,12 +237,10 @@ namespace sls { return true; if (i == 1 && bv && !av) return true; - if (i == 0) { - return set_value(child, true) && set_value(sibling, false); - } - if (i == 1) { - return set_value(child, false) && set_value(sibling, true); - } + if (i == 0) + return set_value(child, true) && set_value(sibling, false); + if (i == 1) + return set_value(child, false) && set_value(sibling, true); return false; } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp new file mode 100644 index 000000000..2f287a680 --- /dev/null +++ b/src/ast/sls/sls_smt_solver.cpp @@ -0,0 +1,146 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_smt_solver.cpp + +Abstract: + + A Stochastic Local Search (SLS) Solver. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-10 + +--*/ + +#pragma once +#include "ast/sls/sls_context.h" +#include "ast/sls/sat_ddfw.h" +#include "ast/sls/sls_smt_solver.h" + + +namespace sls { + + class smt_solver::solver_ctx : public sat::local_search_plugin, public sls::sat_solver_context { + ast_manager& m; + sat::ddfw& m_ddfw; + context m_context; + bool m_new_clause_added = false; + model_ref m_model; + public: + solver_ctx(ast_manager& m, sat::ddfw& d) : + m(m), m_ddfw(d), m_context(m, *this) { + m_ddfw.set_plugin(this); + } + + ~solver_ctx() override { + } + + void init_search() override {} + + void finish_search() override {} + + void on_rescale() override {} + + void on_restart() override {} + + void on_save_model() override { + TRACE("sls", display(tout)); + while (unsat().empty()) { + m_context.check(); + if (!m_new_clause_added) + break; + m_ddfw.reinit(); + m_new_clause_added = false; + } + } + + void on_model(model_ref& mdl) override { + IF_VERBOSE(1, verbose_stream() << "on-model " << "\n"); + m_model = mdl; + } + + void register_atom(sat::bool_var v, expr* e) { + m_context.register_atom(v, e); + } + + std::ostream& display(std::ostream& out) { + m_ddfw.display(out); + m_context.display(out); + return out; + } + + vector const& clauses() const override { return m_ddfw.clauses(); } + sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); } + std::initializer_list get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); } + void flip(sat::bool_var v) override { m_ddfw.flip(v); } + double reward(sat::bool_var v) override { return m_ddfw.get_reward(v); } + double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; } + bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); } + unsigned num_vars() const override { return m_ddfw.num_vars(); } + indexed_uint_set const& unsat() const override { return m_ddfw.unsat_set(); } + sat::bool_var add_var() override { return m_ddfw.add_var(); } + void add_clause(unsigned n, sat::literal const* lits) override { + m_ddfw.add(n, lits); + m_new_clause_added = true; + } + model_ref get_model() { return m_model; } + }; + + smt_solver::smt_solver(ast_manager& m, params_ref const& p): + m(m), + m_solver_ctx(alloc(solver_ctx, m, m_ddfw)), + m_assertions(m) { + m_ddfw.updt_params(p); + } + + smt_solver::~smt_solver() { + + } + + void smt_solver::assert_expr(expr* e) { + m_assertions.push_back(e); + } + + lbool smt_solver::check() { + // send clauses to ddfw + // send expression mapping to m_solver_ctx + + sat::literal_vector clause; + for (auto f : m_assertions) { + if (m.is_or(f)) { + clause.reset(); + for (auto arg : *to_app(f)) + clause.push_back(mk_literal(arg)); + m_solver_ctx->add_clause(clause.size(), clause.data()); + } + else { + sat::literal lit = mk_literal(f); + m_solver_ctx->add_clause(1, &lit); + } + } + IF_VERBOSE(10, m_solver_ctx->display(verbose_stream())); + return m_ddfw.check(0, nullptr); + } + + sat::literal smt_solver::mk_literal(expr* e) { + bool neg = m.is_not(e, e); + sat::bool_var v; + if (!m_expr2var.find(e, v)) { + v = m_expr2var.size(); + m_expr2var.insert(e, v); + m_solver_ctx->register_atom(v, e); + } + return sat::literal(v, neg); + } + + model_ref smt_solver::get_model() { + return m_solver_ctx->get_model(); + } + + std::ostream& smt_solver::display(std::ostream& out) { + return m_solver_ctx->display(out); + } +} diff --git a/src/ast/sls/sls_smt_solver.h b/src/ast/sls/sls_smt_solver.h new file mode 100644 index 000000000..ebac65021 --- /dev/null +++ b/src/ast/sls/sls_smt_solver.h @@ -0,0 +1,46 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_smt_solver.h + +Abstract: + + A Stochastic Local Search (SLS) Solver. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-10 + +--*/ + +#pragma once +#include "ast/sls/sls_context.h" +#include "ast/sls/sat_ddfw.h" + + +namespace sls { + + class smt_solver { + ast_manager& m; + class solver_ctx; + sat::ddfw m_ddfw; + solver_ctx* m_solver_ctx = nullptr; + expr_ref_vector m_assertions; + statistics m_st; + obj_map m_expr2var; + + sat::literal mk_literal(expr* e); + public: + smt_solver(ast_manager& m, params_ref const& p); + ~smt_solver(); + void assert_expr(expr* e); + lbool check(); + model_ref get_model(); + void updt_params(params_ref& p) {} + void collect_statistics(statistics& st) { st.copy(m_st); } + std::ostream& display(std::ostream& out); + void reset_statistics() { m_st.reset(); } + }; +}