3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2026-05-22 01:49:36 +00:00
* reorg sls

* sls

* na

* split into base and plugin

* move sat_params to params directory, add op_def repair options

* move sat_ddfw to sls, initiate sls-bv-plugin

* porting bv-sls

* adding basic plugin

* na

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add sls-sms solver

* bv updates

* updated dependencies

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* updated dependencies

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use portable ptr-initializer

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* move definitions to cpp

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use template<> syntax

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix compiler errors for gcc

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* Bump docker/build-push-action from 6.0.0 to 6.1.0 (#7265)

Bumps [docker/build-push-action](https://github.com/docker/build-push-action) from 6.0.0 to 6.1.0.
- [Release notes](https://github.com/docker/build-push-action/releases)
- [Commits](https://github.com/docker/build-push-action/compare/v6.0.0...v6.1.0)

---
updated-dependencies:
- dependency-name: docker/build-push-action
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* set clean shutdown for local search and re-enable local search when it parallelizes with PB solver

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* Bump docker/build-push-action from 6.1.0 to 6.2.0 (#7269)

Bumps [docker/build-push-action](https://github.com/docker/build-push-action) from 6.1.0 to 6.2.0.
- [Release notes](https://github.com/docker/build-push-action/releases)
- [Commits](https://github.com/docker/build-push-action/compare/v6.1.0...v6.2.0)

---
updated-dependencies:
- dependency-name: docker/build-push-action
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Fix a comment for Z3_solver_from_string (#7271)

Z3_solver_from_string accepts a string buffer with solver
assertions, not a string buffer with filename.

* trigger the build with a comment change

Signed-off-by: Lev Nachmanson <levnach@hotmail.com>

* remove macro distinction #7270

* fix #7268

* kludge to address #7232, probably superseeded by planned revision to setup/pypi

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add new ema invariant (#7288)

* Bump docker/build-push-action from 6.2.0 to 6.3.0 (#7280)

Bumps [docker/build-push-action](https://github.com/docker/build-push-action) from 6.2.0 to 6.3.0.
- [Release notes](https://github.com/docker/build-push-action/releases)
- [Commits](https://github.com/docker/build-push-action/compare/v6.2.0...v6.3.0)

---
updated-dependencies:
- dependency-name: docker/build-push-action
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* merge

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix unit test build

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* remove shared attribute

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* remove stale files

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix build of unit test

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes and rename sls-cc to sls-euf-plugin

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* na

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* testing / debugging arithmetic

* updates to repair logic, mainly arithmetic

* fixes to sls

* evolve sls arith

* bugfixes in sls-arith

* fix typo

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* bug fixes

* Update sls_test.cpp

* fixes

* fixes

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix build

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* refactor basic plugin and clause generation

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes to ite and other

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* updates

* update

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix division by 0

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* disable fail restart

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* disable tabu when using reset moves

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* update sls_test

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add factoring

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes to semantics

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* re-add tabu override

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* generalize factoring

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix bug

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* remove restart

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* disable tabu in fallback modes

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* localize impact of factoring

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* delay factoring

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* flatten products

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* perform lookahead update + nested mul

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* disable nested mul

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* disable nested mul, use non-lookahead

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* make reset updates recursive

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* include linear moves

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* include 5% reset probability

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* separate linear update

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* separate linear update remove 20% threshold

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* remove linear opt

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* enable multiplier expansion, enable linear move

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use unit coefficients for muls

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* disable non-tabu version of find_nl_moves

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* remove coefficient from multiplication definition

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* reorg monomials

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add smt params to path

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* avoid negative reward

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use reward as proxy for score

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use reward as proxy for score

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use exponential decay with breaks

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use std::pow

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes to bv

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes to fixed

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixup repairs

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* reserve for multiplication

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixing repair

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* include bounds checks in set random

* na

* fixes to mul

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix mul inverse

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes to handling signed operators

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* logging and fixes

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* gcm

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* peli

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* Add .env to gitignore to prevent environment files from being tracked

* Add m_num_pelis counter to stats in sls_context

* Remove m_num_pelis member from stats struct in sls_context

* Enhance bv_sls_eval with improved repair and logging, refine is_bv_predicate in sls_bv_plugin

* Remove verbose logging in register_term function of sls_basic_plugin and fix formatting in sls_context

* Rename source files for consistency in `src/ast/sls` directory

* Refactor bv_sls files to sls_bv with namespace and class name adjustments

* Remove typename from member declarations in bv_fixed class

* fixing conca

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* Add initial implementation of bit-vector SLS evaluation module in bv_sls_eval.cpp

* Remove bv_sls_eval.cpp as part of code cleanup and refactoring

* Refactor alignment of member variables in bv_plugin of sls namespace

* Rename SLS engine related files to reflect their specific use for bit-vectors

* Refactor SLS engine and evaluator components for bit-vector specifics and adjust memory manager alignment

* Enhance bv_eval with use_current, lookahead strategies, and randomization improvements in SLS module

* Refactor verbose logging and fix logic in range adjustment functions in sls bv modules

* Remove commented verbose output in sls_bv_plugin.cpp during repair process

* Add early return after setting fixed subterms in sls_bv_fixed.cpp

* Remove redundant return statement in sls_bv_fixed.cpp

* fixes to new value propagation

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* Refactor sls bv evaluation and fix logic checks for bit operations

* Add array plugin support and update bv_eval in ast_sls module

* Add array, model value, and user sort plugins to SLS module with enhancements in array propagation logic

* Refactor array_plugin in sls to improve handling of select expressions with multiple arguments

* Enhance array plugin with early termination and propagation verification, and improve euf and user sort plugins with propagation adjustments and debugging enhancements

* Add support for handling 'distinct' expressions in SLS context and user sort plugin

* Remove model value and user sort plugins from SLS theory

* replace user plugin by euf plugin

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* remove extra file

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* Refactor handling of term registration and enhance distinct handling in sls_euf_plugin

* Add TODO list for enhancements in sls_euf_plugin.cpp

* add incremental mode

* updated package

* fix sls build

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* break sls build

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix build

* break build again

* fix build

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixing incremental

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* avoid units

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixup handling of disequality propagation

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fx

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* recover shift-weight loop

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* alternate

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* throttle save model

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* allow for alternating

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix test for new signature of flip

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* bug fixes

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* restore use of value_hash

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* adding dt plugin

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* adt

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* dt updates

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* added cycle detection

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* updated sls-datatype

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* Refactor context management, improve datatype handling, and enhance logging in sls plugins.

* axiomatize dt

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add missing factory plugins to model

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixup finite domain search

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixup finite domain search

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* redo dfs

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixing model construction for underspecified operators

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes to occurs check

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixup interpretation building

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* saturate worklist

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* delay distinct axiom

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* adding model-based sls for datatatypes

* update the interface in sls_solver to transfer phase between SAT and SLS

* add value transfer option

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* rename aux functions

* Track shared variables using a unit set

* debugging parallel integration

* fix dirty flag setting

* update log level

* add plugin to smt_context, factor out sls_smt_plugin functionality.

* bug fixes

* fixes

* use common infrastructure for sls-smt

* fix build

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix build

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* remove declaration of context

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add virtual destructor

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* build warnings

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* reorder inclusion order to define smt_context before theory_sls

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* change namespace for single threaded

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* check delayed eqs before nla

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use independent completion flag for sls to avoid conflating with genuine cancelation

* validate sls-arith lemmas

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* bugfixes

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add intblast to legacy SMT solver

* fixup model generation for theory_intblast

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* na

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* mk_value needs to accept more cases where integer expression doesn't evalate

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use th-axioms to track origins of assertions

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add missing operator handling for bitwise operators

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add missing operator handling for bitwise operators

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* normalizing inequality

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* add virtual destructor

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* rework elim_unconstrained

* fix non-termination

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use glue as computed without adjustment

* update model generation to fix model bug

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fixes to model construction

* remove package and package lock

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* fix build warning

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

* use original gai

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>

---------

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Sergey Bronnikov <estetus@gmail.com>
Co-authored-by: Lev Nachmanson <levnach@hotmail.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
This commit is contained in:
Nikolaj Bjorner 2024-11-02 12:32:48 -07:00 committed by GitHub
parent ecdfab81a6
commit 91dc02d862
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
120 changed files with 11172 additions and 4148 deletions

View file

@ -1,14 +1,25 @@
z3_add_component(ast_sls
SOURCES
bvsls_opt_engine.cpp
bv_sls.cpp
bv_sls_eval.cpp
bv_sls_fixed.cpp
bv_sls_terms.cpp
sls_engine.cpp
sls_valuation.cpp
sat_ddfw.cpp
sls_arith_base.cpp
sls_arith_plugin.cpp
sls_array_plugin.cpp
sls_basic_plugin.cpp
sls_bv_engine.cpp
sls_bv_eval.cpp
sls_bv_fixed.cpp
sls_bv_plugin.cpp
sls_bv_terms.cpp
sls_bv_valuation.cpp
sls_context.cpp
sls_datatype_plugin.cpp
sls_euf_plugin.cpp
sls_smt_plugin.cpp
sls_smt_solver.cpp
COMPONENT_DEPENDENCIES
ast
euf
converters
normal_forms
)

View file

@ -1,364 +0,0 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
bv_sls.cpp
Abstract:
A Stochastic Local Search (SLS) engine
Uses invertibility conditions,
interval annotations
don't care annotations
Author:
Nikolaj Bjorner (nbjorner) 2024-02-07
--*/
#include "ast/sls/bv_sls.h"
#include "ast/ast_pp.h"
#include "ast/ast_ll_pp.h"
#include "params/sls_params.hpp"
namespace bv {
sls::sls(ast_manager& m, params_ref const& p):
m(m),
bv(m),
m_terms(m),
m_eval(m),
m_engine(m, p)
{
updt_params(p);
}
void sls::init() {
m_terms.init();
}
void sls::init_eval(std::function<bool(expr*, unsigned)>& eval) {
m_eval.init_eval(m_terms.assertions(), eval);
m_eval.tighten_range(m_terms.assertions());
init_repair();
}
void sls::init_repair() {
m_repair_down = UINT_MAX;
m_repair_up.reset();
m_repair_roots.reset();
for (auto* e : m_terms.assertions()) {
if (!m_eval.bval0(e)) {
m_eval.set(e, true);
m_repair_roots.insert(e->get_id());
}
}
for (auto* t : m_terms.terms()) {
if (t && !m_eval.re_eval_is_correct(t))
m_repair_roots.insert(t->get_id());
}
}
void sls::set_model() {
if (!m_set_model)
return;
if (m_repair_roots.size() >= m_min_repair_size)
return;
m_min_repair_size = m_repair_roots.size();
IF_VERBOSE(2, verbose_stream() << "(sls-update-model :num-unsat " << m_min_repair_size << ")\n");
m_set_model(*get_model());
}
void sls::init_repair_goal(app* t) {
m_eval.init_eval(t);
}
void sls::init_repair_candidates() {
m_to_repair.reset();
ptr_vector<expr> todo;
expr_fast_mark1 mark;
for (auto index : m_repair_roots)
todo.push_back(m_terms.term(index));
for (unsigned i = 0; i < todo.size(); ++i) {
expr* e = todo[i];
if (mark.is_marked(e))
continue;
mark.mark(e);
if (!is_app(e))
continue;
for (expr* arg : *to_app(e))
todo.push_back(arg);
if (is_uninterp_const(e))
m_to_repair.insert(e->get_id());
}
}
void sls::reinit_eval() {
init_repair_candidates();
if (m_to_repair.empty())
return;
// refresh the best model so far to a callback
set_model();
// add fresh units, if any
bool new_assertion = false;
while (m_get_unit) {
auto e = m_get_unit();
if (!e)
break;
new_assertion = true;
assert_expr(e);
}
if (new_assertion)
init();
std::function<bool(expr*, unsigned)> eval = [&](expr* e, unsigned i) {
unsigned id = e->get_id();
bool keep = !m_to_repair.contains(id);
if (m.is_bool(e)) {
if (m_eval.is_fixed0(e) || keep)
return m_eval.bval0(e);
if (m_engine_init) {
auto const& z = m_engine.get_value(e);
return rational(z).get_bit(0);
}
}
else if (bv.is_bv(e)) {
auto& w = m_eval.wval(e);
if (w.fixed.get(i) || keep)
return w.get_bit(i);
if (m_engine_init) {
auto const& z = m_engine.get_value(e);
return rational(z).get_bit(i);
}
}
return m_rand() % 2 == 0;
};
m_eval.init_eval(m_terms.assertions(), eval);
init_repair();
// m_engine_init = false;
}
std::pair<bool, app*> sls::next_to_repair() {
app* e = nullptr;
if (m_repair_down != UINT_MAX) {
e = m_terms.term(m_repair_down);
m_repair_down = UINT_MAX;
return { true, e };
}
if (!m_repair_up.empty()) {
unsigned index = m_repair_up.elem_at(m_rand(m_repair_up.size()));
m_repair_up.remove(index);
e = m_terms.term(index);
return { false, e };
}
while (!m_repair_roots.empty()) {
unsigned index = m_repair_roots.elem_at(m_rand(m_repair_roots.size()));
e = m_terms.term(index);
if (m_terms.is_assertion(e) && !m_eval.bval1(e)) {
SASSERT(m_eval.bval0(e));
return { true, e };
}
if (!m_eval.re_eval_is_correct(e)) {
init_repair_goal(e);
return { true, e };
}
m_repair_roots.remove(index);
}
return { false, nullptr };
}
lbool sls::search1() {
// init and init_eval were invoked
unsigned n = 0;
for (; n < m_config.m_max_repairs && m.inc(); ++n) {
auto [down, e] = next_to_repair();
if (!e)
return l_true;
IF_VERBOSE(20, trace_repair(down, e));
++m_stats.m_moves;
if (down)
try_repair_down(e);
else
try_repair_up(e);
}
return l_undef;
}
lbool sls::search2() {
lbool res = l_undef;
if (m_stats.m_restarts == 0)
res = m_engine(),
m_engine_init = true;
else if (m_stats.m_restarts % 1000 == 0)
res = m_engine.search_loop(),
m_engine_init = true;
if (res != l_undef)
m_engine_model = true;
return res;
}
lbool sls::operator()() {
lbool res = l_undef;
m_stats.reset();
m_stats.m_restarts = 0;
m_engine_model = false;
m_engine_init = false;
do {
res = search1();
if (res != l_undef)
break;
trace();
//res = search2();
if (res != l_undef)
break;
reinit_eval();
}
while (m.inc() && m_stats.m_restarts++ < m_config.m_max_restarts);
return res;
}
void sls::try_repair_down(app* e) {
unsigned n = e->get_num_args();
if (n == 0) {
m_eval.commit_eval(e);
for (auto p : m_terms.parents(e))
m_repair_up.insert(p->get_id());
return;
}
if (n == 2) {
auto d1 = get_depth(e->get_arg(0));
auto d2 = get_depth(e->get_arg(1));
unsigned s = m_rand(d1 + d2 + 2);
if (s <= d1 && m_eval.try_repair(e, 0)) {
set_repair_down(e->get_arg(0));
return;
}
if (m_eval.try_repair(e, 1)) {
set_repair_down(e->get_arg(1));
return;
}
if (m_eval.try_repair(e, 0)) {
set_repair_down(e->get_arg(0));
return;
}
}
else {
unsigned s = m_rand(n);
for (unsigned i = 0; i < n; ++i) {
auto j = (i + s) % n;
if (m_eval.try_repair(e, j)) {
set_repair_down(e->get_arg(j));
return;
}
}
}
IF_VERBOSE(3, verbose_stream() << "init-repair " << mk_bounded_pp(e, m) << "\n");
// repair was not successful, so reset the state to find a different way to repair
init_repair();
}
void sls::try_repair_up(app* e) {
if (m_terms.is_assertion(e))
m_repair_roots.insert(e->get_id());
else if (!m_eval.repair_up(e)) {
IF_VERBOSE(2, verbose_stream() << "repair-up "; trace_repair(true, e));
if (m_rand(10) != 0) {
m_eval.set_random(e);
m_repair_roots.insert(e->get_id());
}
else
init_repair();
}
else {
if (!m_eval.eval_is_correct(e)) {
verbose_stream() << "incorrect eval #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n";
}
SASSERT(m_eval.eval_is_correct(e));
for (auto p : m_terms.parents(e))
m_repair_up.insert(p->get_id());
}
}
model_ref sls::get_model() {
if (m_engine_model)
return m_engine.get_model();
model_ref mdl = alloc(model, m);
auto& terms = m_eval.sort_assertions(m_terms.assertions());
for (expr* e : terms) {
if (!is_uninterp_const(e))
continue;
auto f = to_app(e)->get_decl();
auto v = m_eval.get_value(to_app(e));
if (v)
mdl->register_decl(f, v);
}
terms.reset();
return mdl;
}
std::ostream& sls::display(std::ostream& out) {
auto& terms = m_eval.sort_assertions(m_terms.assertions());
for (expr* e : terms) {
out << e->get_id() << ": " << mk_bounded_pp(e, m, 1) << " ";
if (m_eval.is_fixed0(e))
out << "f ";
if (m_repair_up.contains(e->get_id()))
out << "u ";
if (m_repair_roots.contains(e->get_id()))
out << "r ";
m_eval.display_value(out, e);
out << "\n";
}
terms.reset();
return out;
}
void sls::updt_params(params_ref const& _p) {
sls_params p(_p);
m_config.m_max_restarts = p.max_restarts();
m_config.m_max_repairs = p.max_repairs();
m_rand.set_seed(p.random_seed());
m_terms.updt_params(_p);
params_ref q = _p;
q.set_uint("max_restarts", 10);
m_engine.updt_params(q);
}
std::ostream& sls::trace_repair(bool down, expr* e) {
verbose_stream() << (down ? "d #" : "u #")
<< e->get_id() << ": "
<< mk_bounded_pp(e, m, 1) << " ";
m_eval.display_value(verbose_stream(), e) << "\n";
return verbose_stream();
}
void sls::trace() {
IF_VERBOSE(2, verbose_stream()
<< "(bvsls :restarts " << m_stats.m_restarts
<< " :repair-up " << m_repair_up.size()
<< " :repair-roots " << m_repair_roots.size() << ")\n");
}
}

View file

@ -1,129 +0,0 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
bv_sls.h
Abstract:
A Stochastic Local Search (SLS) engine
Author:
Nikolaj Bjorner (nbjorner) 2024-02-07
--*/
#pragma once
#include "util/lbool.h"
#include "util/params.h"
#include "util/scoped_ptr_vector.h"
#include "util/uint_set.h"
#include "ast/ast.h"
#include "ast/sls/sls_stats.h"
#include "ast/sls/sls_powers.h"
#include "ast/sls/sls_valuation.h"
#include "ast/sls/bv_sls_terms.h"
#include "ast/sls/bv_sls_eval.h"
#include "ast/sls/sls_engine.h"
#include "ast/bv_decl_plugin.h"
#include "model/model.h"
namespace bv {
class sls {
struct config {
unsigned m_max_restarts = 1000;
unsigned m_max_repairs = 1000;
};
ast_manager& m;
bv_util bv;
sls_terms m_terms;
sls_eval m_eval;
sls_stats m_stats;
indexed_uint_set m_repair_up, m_repair_roots;
unsigned m_repair_down = UINT_MAX;
ptr_vector<expr> m_todo;
random_gen m_rand;
config m_config;
sls_engine m_engine;
bool m_engine_model = false;
bool m_engine_init = false;
std::function<expr_ref()> m_get_unit;
std::function<void(model& mdl)> m_set_model;
unsigned m_min_repair_size = UINT_MAX;
std::pair<bool, app*> next_to_repair();
void init_repair_goal(app* e);
void set_model();
void try_repair_down(app* e);
void try_repair_up(app* e);
void set_repair_down(expr* e) { m_repair_down = e->get_id(); }
lbool search1();
lbool search2();
void reinit_eval();
void init_repair();
void trace();
std::ostream& trace_repair(bool down, expr* e);
indexed_uint_set m_to_repair;
void init_repair_candidates();
public:
sls(ast_manager& m, params_ref const& p);
/**
* Add constraints
*/
void assert_expr(expr* e) { m_terms.assert_expr(e); m_engine.assert_expr(e); }
/*
* Invoke init after all expressions are asserted.
*/
void init();
/**
* Invoke init_eval to initialize, or re-initialize, values of
* uninterpreted constants.
*/
void init_eval(std::function<bool(expr*, unsigned)>& eval);
/**
* add callback to retrieve new units
*/
void init_unit(std::function<expr_ref()> get_unit) { m_get_unit = get_unit; }
/**
* Add callback to set model
*/
void set_model(std::function<void(model& mdl)> f) { m_set_model = f; }
/**
* Run (bounded) local search to find feasible assignments.
*/
lbool operator()();
void updt_params(params_ref const& p);
void collect_statistics(statistics& st) const { m_stats.collect_statistics(st); m_engine.collect_statistics(st); }
void reset_statistics() { m_stats.reset(); m_engine.reset_statistics(); }
unsigned get_num_moves() { return m_stats.m_moves + m_engine.get_stats().m_moves; }
std::ostream& display(std::ostream& out);
/**
* Retrieve valuation
*/
sls_valuation const& wval(expr* e) const { return m_eval.wval(e); }
model_ref get_model();
void cancel() { m.limit().cancel(); }
};
}

View file

@ -1,229 +0,0 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
bv_sls.cpp
Abstract:
A Stochastic Local Search (SLS) engine
Uses invertibility conditions,
interval annotations
don't care annotations
Author:
Nikolaj Bjorner (nbjorner) 2024-02-07
--*/
#include "ast/ast_ll_pp.h"
#include "ast/sls/bv_sls.h"
#include "ast/rewriter/th_rewriter.h"
namespace bv {
sls_terms::sls_terms(ast_manager& m):
m(m),
bv(m),
m_rewriter(m),
m_assertions(m),
m_pinned(m),
m_translated(m),
m_terms(m){}
void sls_terms::assert_expr(expr* e) {
m_assertions.push_back(ensure_binary(e));
}
expr* sls_terms::ensure_binary(expr* e) {
expr* top = e;
m_pinned.push_back(e);
m_todo.push_back(e);
{
expr_fast_mark1 mark;
for (unsigned i = 0; i < m_todo.size(); ++i) {
expr* e = m_todo[i];
if (!is_app(e))
continue;
if (m_translated.get(e->get_id(), nullptr))
continue;
if (mark.is_marked(e))
continue;
mark.mark(e);
for (auto arg : *to_app(e))
m_todo.push_back(arg);
}
}
std::stable_sort(m_todo.begin(), m_todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); });
for (expr* e : m_todo)
ensure_binary_core(e);
m_todo.reset();
return m_translated.get(top->get_id());
}
void sls_terms::ensure_binary_core(expr* e) {
if (m_translated.get(e->get_id(), nullptr))
return;
app* a = to_app(e);
auto arg = [&](unsigned i) {
return m_translated.get(a->get_arg(i)->get_id());
};
unsigned num_args = a->get_num_args();
expr_ref r(m);
#define FOLD_OP(oper) \
r = arg(0); \
for (unsigned i = 1; i < num_args; ++i)\
r = oper(r, arg(i)); \
if (m.is_and(e)) {
FOLD_OP(m.mk_and);
}
else if (m.is_or(e)) {
FOLD_OP(m.mk_or);
}
else if (m.is_xor(e)) {
FOLD_OP(m.mk_xor);
}
else if (bv.is_bv_and(e)) {
FOLD_OP(bv.mk_bv_and);
}
else if (bv.is_bv_or(e)) {
FOLD_OP(bv.mk_bv_or);
}
else if (bv.is_bv_xor(e)) {
FOLD_OP(bv.mk_bv_xor);
}
else if (bv.is_bv_add(e)) {
FOLD_OP(bv.mk_bv_add);
}
else if (bv.is_bv_mul(e)) {
FOLD_OP(bv.mk_bv_mul);
}
else if (bv.is_concat(e)) {
FOLD_OP(bv.mk_concat);
}
else if (m.is_distinct(e)) {
expr_ref_vector es(m);
for (unsigned i = 0; i < num_args; ++i)
for (unsigned j = i + 1; j < num_args; ++j)
es.push_back(m.mk_not(m.mk_eq(arg(i), arg(j))));
r = m.mk_and(es);
}
else if (bv.is_bv_sdiv(e) || bv.is_bv_sdiv0(e) || bv.is_bv_sdivi(e)) {
r = mk_sdiv(arg(0), arg(1));
}
else if (bv.is_bv_smod(e) || bv.is_bv_smod0(e) || bv.is_bv_smodi(e)) {
r = mk_smod(arg(0), arg(1));
}
else if (bv.is_bv_srem(e) || bv.is_bv_srem0(e) || bv.is_bv_sremi(e)) {
r = mk_srem(arg(0), arg(1));
}
else {
for (unsigned i = 0; i < num_args; ++i)
m_args.push_back(arg(i));
r = m.mk_app(a->get_decl(), num_args, m_args.data());
m_args.reset();
}
m_translated.setx(e->get_id(), r);
}
expr_ref sls_terms::mk_sdiv(expr* x, expr* y) {
// d = udiv(abs(x), abs(y))
// y = 0, x >= 0 -> -1
// y = 0, x < 0 -> 1
// x = 0, y != 0 -> 0
// x > 0, y < 0 -> -d
// x < 0, y > 0 -> -d
// x > 0, y > 0 -> d
// x < 0, y < 0 -> d
unsigned sz = bv.get_bv_size(x);
rational N = rational::power_of_two(sz);
expr_ref z(bv.mk_zero(sz), m);
expr* signx = bv.mk_ule(bv.mk_numeral(N / 2, sz), x);
expr* signy = bv.mk_ule(bv.mk_numeral(N / 2, sz), y);
expr* absx = m.mk_ite(signx, bv.mk_bv_neg(x), x);
expr* absy = m.mk_ite(signy, bv.mk_bv_neg(y), y);
expr* d = bv.mk_bv_udiv(absx, absy);
expr_ref r(m.mk_ite(m.mk_eq(signx, signy), d, bv.mk_bv_neg(d)), m);
r = m.mk_ite(m.mk_eq(z, y),
m.mk_ite(signx, bv.mk_one(sz), bv.mk_numeral(N - 1, sz)),
m.mk_ite(m.mk_eq(x, z), z, r));
m_rewriter(r);
return r;
}
expr_ref sls_terms::mk_smod(expr* x, expr* y) {
// u := umod(abs(x), abs(y))
// u = 0 -> 0
// y = 0 -> x
// x < 0, y < 0 -> -u
// x < 0, y >= 0 -> y - u
// x >= 0, y < 0 -> y + u
// x >= 0, y >= 0 -> u
unsigned sz = bv.get_bv_size(x);
expr_ref z(bv.mk_zero(sz), m);
expr_ref abs_x(m.mk_ite(bv.mk_sle(z, x), x, bv.mk_bv_neg(x)), m);
expr_ref abs_y(m.mk_ite(bv.mk_sle(z, y), y, bv.mk_bv_neg(y)), m);
expr_ref u(bv.mk_bv_urem(abs_x, abs_y), m);
expr_ref r(m);
r = m.mk_ite(m.mk_eq(u, z), z,
m.mk_ite(m.mk_eq(y, z), x,
m.mk_ite(m.mk_and(bv.mk_sle(z, x), bv.mk_sle(z, x)), u,
m.mk_ite(bv.mk_sle(z, x), bv.mk_bv_add(y, u),
m.mk_ite(bv.mk_sle(z, y), bv.mk_bv_sub(y, u), bv.mk_bv_neg(u))))));
m_rewriter(r);
return r;
}
expr_ref sls_terms::mk_srem(expr* x, expr* y) {
// y = 0 -> x
// else x - sdiv(x, y) * y
expr_ref r(m);
r = m.mk_ite(m.mk_eq(y, bv.mk_zero(bv.get_bv_size(x))),
x, bv.mk_bv_sub(x, bv.mk_bv_mul(y, mk_sdiv(x, y))));
m_rewriter(r);
return r;
}
void sls_terms::init() {
// populate terms
expr_fast_mark1 mark;
for (expr* e : m_assertions)
m_todo.push_back(e);
while (!m_todo.empty()) {
expr* e = m_todo.back();
m_todo.pop_back();
if (mark.is_marked(e) || !is_app(e))
continue;
mark.mark(e);
m_terms.setx(e->get_id(), to_app(e));
for (expr* arg : *to_app(e))
m_todo.push_back(arg);
}
// populate parents
m_parents.reset();
m_parents.reserve(m_terms.size());
for (expr* e : m_terms) {
if (!e || !is_app(e))
continue;
for (expr* arg : *to_app(e))
m_parents[arg->get_id()].push_back(e);
}
m_assertion_set.reset();
for (auto a : m_assertions)
m_assertion_set.insert(a->get_id());
}
void sls_terms::updt_params(params_ref const& p) {
params_ref q = p;
q.set_bool("flat", false);
m_rewriter.updt_params(q);
}
}

View file

@ -1,79 +0,0 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
bv_sls_terms.h
Abstract:
A Stochastic Local Search (SLS) engine
Author:
Nikolaj Bjorner (nbjorner) 2024-02-07
--*/
#pragma once
#include "util/lbool.h"
#include "util/params.h"
#include "util/scoped_ptr_vector.h"
#include "util/uint_set.h"
#include "ast/ast.h"
#include "ast/rewriter/th_rewriter.h"
#include "ast/sls/sls_stats.h"
#include "ast/sls/sls_powers.h"
#include "ast/sls/sls_valuation.h"
#include "ast/bv_decl_plugin.h"
namespace bv {
class sls_terms {
ast_manager& m;
bv_util bv;
th_rewriter m_rewriter;
ptr_vector<expr> m_todo, m_args;
expr_ref_vector m_assertions, m_pinned, m_translated;
app_ref_vector m_terms;
vector<ptr_vector<expr>> m_parents;
tracked_uint_set m_assertion_set;
expr* ensure_binary(expr* e);
void ensure_binary_core(expr* e);
expr_ref mk_sdiv(expr* x, expr* y);
expr_ref mk_smod(expr* x, expr* y);
expr_ref mk_srem(expr* x, expr* y);
public:
sls_terms(ast_manager& m);
void updt_params(params_ref const& p);
/**
* Add constraints
*/
void assert_expr(expr* e);
/**
* Initialize structures: assertions, parents, terms
*/
void init();
/**
* Accessors.
*/
ptr_vector<expr> const& parents(expr* e) const { return m_parents[e->get_id()]; }
expr_ref_vector const& assertions() const { return m_assertions; }
app* term(unsigned id) const { return m_terms.get(id); }
app_ref_vector const& terms() const { return m_terms; }
bool is_assertion(expr* e) const { return m_assertion_set.contains(e->get_id()); }
};
}

View file

@ -18,7 +18,7 @@ Notes:
--*/
#pragma once
#include "ast/sls/sls_engine.h"
#include "ast/sls/sls_bv_engine.h"
class bvsls_opt_engine : public sls_engine {
sls_tracker & m_hard_tracker;

684
src/ast/sls/sat_ddfw.cpp Normal file
View file

@ -0,0 +1,684 @@
/*++
Copyright (c) 2019 Microsoft Corporation
Module Name:
sat_ddfw.cpp
Abstract:
DDFW Local search module for clauses
Author:
Nikolaj Bjorner, Marijn Heule 2019-4-23
Notes:
http://www.ict.griffith.edu.au/~johnt/publications/CP2006raouf.pdf
Todo:
- rephase strategy
- experiment with backoff schemes for restarts
- parallel sync
--*/
#include "util/luby.h"
#include "util/trace.h"
#include "ast/sls/sat_ddfw.h"
#include "params/sat_params.hpp"
namespace sat {
ddfw::~ddfw() {
}
lbool ddfw::check(unsigned sz, literal const* assumptions) {
init(sz, assumptions);
if (m_plugin)
check_with_plugin();
else
check_without_plugin();
remove_assumptions();
log();
return m_min_sz == 0 ? l_true : l_undef;
}
void ddfw::check_without_plugin() {
while (m_limit.inc() && m_min_sz > 0) {
if (should_reinit_weights()) do_reinit_weights();
else if (do_flip());
else if (should_restart()) do_restart();
else if (m_parallel_sync && m_parallel_sync());
else shift_weights();
}
}
void ddfw::check_with_plugin() {
m_plugin->init_search();
unsigned steps = 0;
if (m_min_sz <= m_unsat.size())
save_best_values();
try {
while (m_min_sz > 0 && m_limit.inc()) {
if (should_reinit_weights()) do_reinit_weights();
else if (steps % 5000 == 0) shift_weights(), m_plugin->on_rescale();
else if (should_restart()) do_restart(), m_plugin->on_restart();
else if (do_flip());
else shift_weights(), m_plugin->on_rescale();
//verbose_stream() << "steps: " << steps << " min_sz: " << m_min_sz << " unsat: " << m_unsat.size() << "\n";
++steps;
}
}
catch (z3_exception& ex) {
IF_VERBOSE(0, verbose_stream() << "Exception: " << ex.msg() << "\n");
throw;
}
m_plugin->finish_search();
}
void ddfw::log() {
double sec = m_stopwatch.get_current_seconds();
double kflips_per_sec = sec > 0 ? (m_flips - m_last_flips) / (1000.0 * sec) : 0.0;
if (m_last_flips == 0) {
IF_VERBOSE(1, verbose_stream() << "(sat.ddfw :unsat :models :kflips/sec :flips :restarts :reinits :unsat_vars :shifts";
verbose_stream() << ")\n");
}
IF_VERBOSE(1, verbose_stream() << "(sat.ddfw "
<< std::setw(07) << m_min_sz
<< std::setw(07) << m_models.size()
<< std::setw(10) << kflips_per_sec
<< std::setw(10) << m_flips
<< std::setw(10) << m_restart_count
<< std::setw(11) << m_reinit_count
<< std::setw(13) << m_unsat_vars.size()
<< std::setw(9) << m_shifts;
verbose_stream() << ")\n");
m_stopwatch.start();
m_last_flips = m_flips;
}
bool ddfw::do_flip() {
double reward = 0;
bool_var v = pick_var(reward);
//verbose_stream() << "flip " << v << " " << reward << "\n";
return apply_flip(v, reward);
}
bool ddfw::apply_flip(bool_var v, double reward) {
if (v == null_bool_var)
return false;
if (reward > 0 || (reward == 0 && m_rand(100) <= m_config.m_use_reward_zero_pct)) {
flip(v);
if (m_unsat.size() <= m_min_sz)
save_best_values();
return true;
}
return false;
}
bool_var ddfw::pick_var(double& r) {
double sum_pos = 0;
unsigned n = 1;
bool_var v0 = null_bool_var;
for (bool_var v : m_unsat_vars) {
r = reward(v);
if (r > 0.0)
sum_pos += score(r);
else if (r == 0.0 && sum_pos == 0 && (m_rand() % (n++)) == 0)
v0 = v;
}
if (sum_pos > 0) {
double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos;
for (bool_var v : m_unsat_vars) {
r = reward(v);
if (r > 0) {
lim_pos -= score(r);
if (lim_pos <= 0)
return v;
}
}
}
r = 0;
if (v0 != null_bool_var)
return v0;
if (m_unsat_vars.empty())
return null_bool_var;
return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size()));
}
void ddfw::add(unsigned n, literal const* c) {
unsigned idx = m_clauses.size();
m_clauses.push_back(clause_info(n, c, m_config.m_init_clause_weight));
if (n > 2)
++m_num_non_binary_clauses;
for (literal lit : m_clauses.back().m_clause) {
m_use_list.reserve(2*(lit.var()+1));
m_vars.reserve(lit.var()+1);
m_use_list[lit.index()].push_back(idx);
}
}
sat::bool_var ddfw::add_var() {
auto v = m_vars.size();
m_vars.reserve(v + 1);
return v;
}
void ddfw::reserve_vars(unsigned n) {
m_vars.reserve(n);
}
/**
* Remove the last clause that was added
*/
void ddfw::del() {
auto& info = m_clauses.back();
for (literal lit : info.m_clause)
m_use_list[lit.index()].pop_back();
m_clauses.pop_back();
if (m_unsat.contains(m_clauses.size()))
m_unsat.remove(m_clauses.size());
}
void ddfw::add_assumptions() {
for (unsigned i = 0; i < m_assumptions.size(); ++i)
add(1, m_assumptions.data() + i);
}
void ddfw::remove_assumptions() {
if (m_assumptions.empty())
return;
for (unsigned i = 0; i < m_assumptions.size(); ++i)
del();
init(0, nullptr);
}
void ddfw::init(unsigned sz, literal const* assumptions) {
m_assumptions.reset();
m_assumptions.append(sz, assumptions);
add_assumptions();
for (unsigned v = 0; v < num_vars(); ++v) {
value(v) = (m_rand() % 2) == 0; // m_use_list[lit.index()].size() >= m_use_list[nlit.index()].size();
}
if (!flatten_use_list())
init_clause_data();
m_reinit_count = 0;
m_reinit_next = m_config.m_reinit_base;
m_restart_count = 0;
m_restart_next = m_config.m_restart_base*2;
m_min_sz = m_clauses.size();
m_flips = 0;
m_last_flips = 0;
m_shifts = 0;
m_stopwatch.start();
}
void ddfw::reinit() {
add_assumptions();
flatten_use_list();
}
bool ddfw::flatten_use_list() {
if (num_vars() == m_use_list_vars && m_clauses.size() == m_use_list_clauses)
return false;
m_use_list_vars = num_vars();
m_use_list_clauses = m_clauses.size();
m_use_list_index.reset();
m_flat_use_list.reset();
for (auto const& ul : m_use_list) {
m_use_list_index.push_back(m_flat_use_list.size());
m_flat_use_list.append(ul);
}
m_use_list_index.push_back(m_flat_use_list.size());
init_clause_data();
return true;
}
void ddfw::flip(bool_var v) {
++m_flips;
literal lit = literal(v, !value(v));
literal nlit = ~lit;
SASSERT(is_true(lit));
for (unsigned cls_idx : use_list(lit)) {
clause_info& ci = m_clauses[cls_idx];
ci.del(lit);
double w = ci.m_weight;
// cls becomes false: flip any variable in clause to receive reward w
switch (ci.m_num_trues) {
case 0: {
#if 0
if (ci.m_clause.size() == 1)
verbose_stream() << "flipping unit clause " << ci << "\n";
#endif
m_unsat.insert_fresh(cls_idx);
auto const& c = get_clause(cls_idx);
for (literal l : c) {
inc_reward(l, w);
inc_make(l);
}
inc_reward(lit, w);
break;
}
case 1:
dec_reward(to_literal(ci.m_trues), w);
break;
default:
break;
}
}
for (unsigned cls_idx : use_list(nlit)) {
clause_info& ci = m_clauses[cls_idx];
double w = ci.m_weight;
// the clause used to have a single true (pivot) literal, now it has two.
// Then the previous pivot is no longer penalized for flipping.
switch (ci.m_num_trues) {
case 0: {
m_unsat.remove(cls_idx);
auto const& c = get_clause(cls_idx);
for (literal l : c) {
dec_reward(l, w);
dec_make(l);
}
dec_reward(nlit, w);
break;
}
case 1:
inc_reward(to_literal(ci.m_trues), w);
break;
default:
break;
}
ci.add(nlit);
}
value(v) = !value(v);
update_reward_avg(v);
}
bool ddfw::should_reinit_weights() {
return m_flips >= m_reinit_next;
}
void ddfw::do_reinit_weights() {
log();
if (m_reinit_count % 2 == 0) {
for (auto& ci : m_clauses)
ci.m_weight += 1;
}
else {
for (auto& ci : m_clauses)
if (ci.is_true())
ci.m_weight = m_config.m_init_clause_weight;
else
ci.m_weight = m_config.m_init_clause_weight + 1;
}
init_clause_data();
++m_reinit_count;
m_reinit_next += m_reinit_count * m_config.m_reinit_base;
}
void ddfw::init_clause_data() {
for (unsigned v = 0; v < num_vars(); ++v) {
make_count(v) = 0;
reward(v) = 0;
}
m_unsat_vars.reset();
m_unsat.reset();
unsigned sz = m_clauses.size();
for (unsigned i = 0; i < sz; ++i) {
auto& ci = m_clauses[i];
auto const& c = get_clause(i);
ci.m_trues = 0;
ci.m_num_trues = 0;
for (literal lit : c)
if (is_true(lit))
ci.add(lit);
switch (ci.m_num_trues) {
case 0:
for (literal lit : c) {
inc_reward(lit, ci.m_weight);
inc_make(lit);
}
m_unsat.insert_fresh(i);
break;
case 1:
dec_reward(to_literal(ci.m_trues), ci.m_weight);
break;
default:
break;
}
}
if (m_unsat.size() < m_min_sz)
save_best_values();
}
bool ddfw::should_restart() {
return m_flips >= m_restart_next;
}
void ddfw::do_restart() {
reinit_values();
init_clause_data();
m_restart_next += m_config.m_restart_base*get_luby(++m_restart_count);
}
/**
\brief the higher the bias, the lower the probability to deviate from the value of the bias
during a restart.
bias = 0 -> flip truth value with 50%
|bias| = 1 -> toss coin with 25% probability
|bias| = 2 -> toss coin with 12.5% probability
etc
*/
void ddfw::reinit_values() {
for (unsigned i = 0; i < num_vars(); ++i) {
int b = bias(i);
if (0 == (m_rand() % (1 + abs(b))))
value(i) = (m_rand() % 2) == 0;
else
value(i) = bias(i) > 0;
}
}
void ddfw::save_priorities() {
m_probs.reset();
for (unsigned v = 0; v < num_vars(); ++v)
m_probs.push_back(-m_vars[v].m_reward_avg);
}
void ddfw::save_model() {
m_model.reserve(num_vars());
for (unsigned i = 0; i < num_vars(); ++i)
m_model[i] = to_lbool(value(i));
save_priorities();
if (m_plugin)
m_plugin->on_save_model();
}
void ddfw::save_best_values() {
if (m_save_best_values)
return;
if (m_plugin && !m_unsat.empty())
return;
flet<bool> _save_best_values(m_save_best_values, true);
bool do_save_model = ((m_unsat.size() < m_min_sz || m_unsat.empty()) &&
((m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11)));
if (do_save_model)
save_model();
if (m_unsat.size() < m_min_sz) {
m_models.reset();
m_min_sz = m_unsat.size();
}
unsigned h = value_hash();
unsigned occs = 0;
bool contains = m_models.find(h, occs);
if (!contains) {
for (unsigned v = 0; v < num_vars(); ++v)
bias(v) += value(v) ? 1 : -1;
if (m_models.size() > m_config.m_max_num_models)
m_models.erase(m_models.begin()->m_key);
}
m_models.insert(h, occs + 1);
if (occs > 100) {
m_restart_next = m_flips;
m_models.erase(h);
}
}
unsigned ddfw::value_hash() const {
unsigned s0 = 0, s1 = 0;
for (auto const& vi : m_vars) {
s0 += vi.m_value;
s1 += s0;
}
return s1;
}
/**
\brief Filter on whether to select a satisfied clause
1. with some probability prefer higher weight to lesser weight.
2. take into account number of trues ?
3. select multiple clauses instead of just one per clause in unsat.
*/
bool ddfw::select_clause(double max_weight, clause_info const& cn, unsigned& n) {
if (cn.m_num_trues == 0 || cn.m_weight + 1e-5 < max_weight)
return false;
if (cn.m_weight > max_weight) {
n = 2;
return true;
}
return (m_rand() % (n++)) == 0;
}
unsigned ddfw::select_max_same_sign(unsigned cf_idx) {
auto& ci = m_clauses[cf_idx];
unsigned cl = UINT_MAX; // clause pointer to same sign, max weight satisfied clause.
auto const& c = ci.m_clause;
double max_weight = m_init_weight;
unsigned n = 1;
for (literal lit : c) {
for (unsigned cn_idx : use_list(lit)) {
auto& cn = m_clauses[cn_idx];
if (select_clause(max_weight, cn, n)) {
cl = cn_idx;
max_weight = cn.m_weight;
}
}
}
return cl;
}
void ddfw::transfer_weight(unsigned from, unsigned to, double w) {
auto& cf = m_clauses[to];
auto& cn = m_clauses[from];
if (cn.m_weight < w)
return;
cf.m_weight += w;
cn.m_weight -= w;
for (literal lit : get_clause(to))
inc_reward(lit, w);
if (cn.m_num_trues == 1)
inc_reward(to_literal(cn.m_trues), w);
}
unsigned ddfw::select_random_true_clause() {
unsigned num_clauses = m_clauses.size();
for (unsigned i = 0; i < num_clauses; ++i) {
unsigned idx = (m_rand() * m_rand()) % num_clauses;
auto & cn = m_clauses[idx];
if (cn.is_true() && cn.m_weight >= m_init_weight)
return idx;
}
unsigned n = 0, idx = UINT_MAX;
for (unsigned i = 0; i < num_clauses; ++i) {
auto& cn = m_clauses[i];
if (cn.is_true() && cn.m_weight >= m_init_weight && (m_rand() % (++n)) == 0)
idx = i;
}
return idx;
}
// 1% chance to disregard neighbor
inline bool ddfw::disregard_neighbor() {
return false; // rand() % 1000 == 0;
}
double ddfw::calculate_transfer_weight(double w) {
return (w > m_init_weight) ? m_init_weight : 1;
}
void ddfw::shift_weights() {
++m_shifts;
bool shifted = false;
for (unsigned to_idx : m_unsat) {
SASSERT(!m_clauses[to_idx].is_true());
unsigned from_idx = select_max_same_sign(to_idx);
if (from_idx == UINT_MAX || disregard_neighbor())
from_idx = select_random_true_clause();
if (from_idx == UINT_MAX)
continue;
shifted = true;
auto & cn = m_clauses[from_idx];
SASSERT(cn.is_true());
double w = calculate_transfer_weight(cn.m_weight);
transfer_weight(from_idx, to_idx, w);
}
//verbose_stream() << m_shifts << " " << m_flips << " " << shifted << "\n";
if (!shifted && m_restart_next > m_flips)
m_restart_next = m_flips + (m_restart_next - m_flips) / 2;
// DEBUG_CODE(invariant(););
}
// apply unit propagation.
void ddfw::simplify() {
verbose_stream() << "simplify\n";
sat::literal_vector units;
uint_set unit_set;
for (unsigned i = 0; i < m_clauses.size(); ++i) {
auto& ci = m_clauses[i];
if (ci.m_clause.size() != 1)
continue;
auto lit = ci.m_clause[0];
units.push_back(lit);
unit_set.insert(lit.index());
m_use_list[lit.index()].reset();
m_use_list[lit.index()].push_back(i);
}
auto is_unit = [&](sat::literal lit) {
return unit_set.contains(lit.index());
};
sat::literal_vector new_clause;
for (unsigned i = 0; i < units.size(); ++i) {
auto lit = units[i];
for (auto cidx : m_use_list[(~lit).index()]) {
auto& ci = m_clauses[cidx];
if (ci.m_clause.size() == 1)
continue;
new_clause.reset();
for (auto l : ci) {
if (!is_unit(~l))
new_clause.push_back(l);
}
if (new_clause.size() == 1) {
verbose_stream() << "new unit " << lit << " " << ci << " -> " << new_clause << "\n";
}
m_clauses[cidx] = sat::clause_info(new_clause.size(), new_clause.data(), m_config.m_init_clause_weight);
if (new_clause.size() == 1) {
units.push_back(new_clause[0]);
unit_set.insert(new_clause[0].index());
}
}
}
for (auto unit : units)
m_use_list[(~unit).index()].reset();
}
std::ostream& ddfw::display(std::ostream& out) const {
unsigned num_cls = m_clauses.size();
for (unsigned i = 0; i < num_cls; ++i) {
out << get_clause(i) << " nt: ";
auto const& ci = m_clauses[i];
out << ci.m_num_trues << " w: " << ci.m_weight << "\n";
}
for (unsigned v = 0; v < num_vars(); ++v)
out << (is_true(literal(v, false)) ? "" : "-") << v << " rw: " << get_reward(v) << "\n";
out << "unsat vars: ";
for (bool_var v : m_unsat_vars)
out << v << " ";
out << "\n";
return out;
}
void ddfw::invariant() {
// every variable in unsat vars is in a false clause.
for (bool_var v : m_unsat_vars) {
bool found = false;
for (unsigned cl : m_unsat) {
for (literal lit : get_clause(cl)) {
if (lit.var() == v) { found = true; break; }
}
if (found) break;
}
if (!found) IF_VERBOSE(0, verbose_stream() << "unsat var not found: " << v << "\n"; );
VERIFY(found);
}
for (unsigned v = 0; v < num_vars(); ++v) {
double v_reward = 0;
literal lit(v, !value(v));
for (unsigned j : m_use_list[lit.index()]) {
clause_info const& ci = m_clauses[j];
if (ci.m_num_trues == 1) {
SASSERT(lit == to_literal(ci.m_trues));
v_reward -= ci.m_weight;
}
}
for (unsigned j : m_use_list[(~lit).index()]) {
clause_info const& ci = m_clauses[j];
if (ci.m_num_trues == 0) {
v_reward += ci.m_weight;
}
}
IF_VERBOSE(0, if (v_reward != reward(v)) verbose_stream() << v << " " << v_reward << " " << reward(v) << "\n");
// SASSERT(reward(v) == v_reward);
}
DEBUG_CODE(
for (auto const& ci : m_clauses) {
SASSERT(ci.m_weight > 0);
}
for (unsigned i = 0; i < m_clauses.size(); ++i) {
bool found = false;
for (literal lit : get_clause(i)) {
if (is_true(lit)) found = true;
}
SASSERT(found == !m_unsat.contains(i));
}
// every variable in a false clause is in unsat vars
for (unsigned cl : m_unsat) {
for (literal lit : get_clause(cl)) {
SASSERT(m_unsat_vars.contains(lit.var()));
}
});
}
void ddfw::updt_params(params_ref const& _p) {
sat_params p(_p);
m_config.m_init_clause_weight = p.ddfw_init_clause_weight();
m_config.m_use_reward_zero_pct = p.ddfw_use_reward_pct();
m_config.m_reinit_base = p.ddfw_reinit_base();
m_config.m_restart_base = p.ddfw_restart_base();
}
void ddfw::collect_statistics(statistics& st) const {
st.update("sls-ddfw-flips", (double)m_flips);
st.update("sls-ddfw-restarts", m_restart_count);
st.update("sls-ddfw-reinits", m_reinit_count);
st.update("sls-ddfw-shifts", (double)m_shifts);
}
void ddfw::reset_statistics() {
m_flips = 0;
m_restart_count = 0;
m_reinit_count = 0;
m_shifts = 0;
}
}

281
src/ast/sls/sat_ddfw.h Normal file
View file

@ -0,0 +1,281 @@
/*++
Copyright (c) 2019 Microsoft Corporation
Module Name:
sat_ddfw.h
Abstract:
DDFW Local search module for clauses
Author:
Nikolaj Bjorner, Marijn Heule 2019-4-23
Notes:
http://www.ict.griffith.edu.au/~johnt/publications/CP2006raouf.pdf
--*/
#pragma once
#include "util/uint_set.h"
#include "util/rlimit.h"
#include "util/params.h"
#include "util/ema.h"
#include "util/sat_sls.h"
#include "util/map.h"
#include "util/sat_literal.h"
#include "util/statistics.h"
#include "util/stopwatch.h"
namespace sat {
class local_search_plugin {
public:
virtual ~local_search_plugin() {}
virtual void init_search() = 0;
virtual void finish_search() = 0;
virtual void on_rescale() = 0;
virtual void on_save_model() = 0;
virtual void on_restart() = 0;
};
class ddfw {
friend class ddfw_wrapper;
protected:
struct config {
config() { reset(); }
unsigned m_use_reward_zero_pct;
unsigned m_init_clause_weight;
unsigned m_max_num_models;
unsigned m_restart_base;
unsigned m_reinit_base;
unsigned m_parsync_base;
double m_itau;
void reset() {
m_init_clause_weight = 8;
m_use_reward_zero_pct = 15;
m_max_num_models = (1 << 10);
m_restart_base = 100333;
m_reinit_base = 10000;
m_parsync_base = 333333;
m_itau = 0.5;
}
};
struct var_info {
var_info() {}
bool m_value = false;
double m_reward = 0;
double m_last_reward = 0;
unsigned m_make_count = 0;
int m_bias = 0;
ema m_reward_avg = 1e-5;
};
config m_config;
reslimit m_limit;
vector<clause_info> m_clauses;
literal_vector m_assumptions;
svector<var_info> m_vars; // var -> info
svector<double> m_probs; // var -> probability of flipping
svector<double> m_scores; // reward -> score
svector<lbool> m_model; // var -> best assignment
unsigned m_init_weight = 2;
vector<unsigned_vector> m_use_list;
unsigned_vector m_flat_use_list;
unsigned_vector m_use_list_index;
unsigned m_use_list_vars = 0, m_use_list_clauses = 0;
indexed_uint_set m_unsat;
indexed_uint_set m_unsat_vars; // set of variables that are in unsat clauses
random_gen m_rand;
uint64_t m_last_flips_for_shift = 0;
unsigned m_num_non_binary_clauses = 0;
unsigned m_restart_count = 0, m_reinit_count = 0;
uint64_t m_restart_next = 0, m_reinit_next = 0;
uint64_t m_flips = 0, m_last_flips = 0, m_shifts = 0;
unsigned m_min_sz = UINT_MAX;
u_map<unsigned> m_models;
stopwatch m_stopwatch;
unsigned_vector m_num_models;
bool m_save_best_values = false;
scoped_ptr<local_search_plugin> m_plugin = nullptr;
std::function<bool(void)> m_parallel_sync;
bool flatten_use_list();
/**
* TBD: map reward value to a score, possibly through an exponential function, such as
* exp(-tau/r), where tau > 0
*/
inline double score(double r) { return r; }
inline unsigned& make_count(bool_var v) { return m_vars[v].m_make_count; }
inline bool& value(bool_var v) { return m_vars[v].m_value; }
inline bool value(bool_var v) const { return m_vars[v].m_value; }
inline double& reward(bool_var v) { return m_vars[v].m_reward; }
unsigned value_hash() const;
inline bool is_true(literal lit) const { return value(lit.var()) != lit.sign(); }
inline sat::literal_vector const& get_clause(unsigned idx) const { return m_clauses[idx].m_clause; }
inline double get_weight(unsigned idx) const { return m_clauses[idx].m_weight; }
inline bool is_true(unsigned idx) const { return m_clauses[idx].is_true(); }
void update_reward_avg(bool_var v) { m_vars[v].m_reward_avg.update(reward(v)); }
unsigned select_max_same_sign(unsigned cf_idx);
inline void inc_make(literal lit) {
bool_var v = lit.var();
if (make_count(v)++ == 0) m_unsat_vars.insert_fresh(v);
}
inline void dec_make(literal lit) {
bool_var v = lit.var();
if (--make_count(v) == 0) m_unsat_vars.remove(v);
}
inline void inc_reward(literal lit, double w) { reward(lit.var()) += w; }
inline void dec_reward(literal lit, double w) { reward(lit.var()) -= w; }
void check_with_plugin();
void check_without_plugin();
// flip activity
bool do_flip();
bool_var pick_var(double& reward);
bool apply_flip(bool_var v, double reward);
void save_best_values();
void save_model();
void save_priorities();
// shift activity
void shift_weights();
inline double calculate_transfer_weight(double w);
// reinitialize weights activity
bool should_reinit_weights();
void do_reinit_weights();
inline bool select_clause(double max_weight, clause_info const& cn, unsigned& n);
// restart activity
bool should_restart();
void do_restart();
void reinit_values();
unsigned select_random_true_clause();
void log();
void init(unsigned sz, literal const* assumptions);
void init_clause_data();
void invariant();
void del();
void add_assumptions();
inline void transfer_weight(unsigned from, unsigned to, double w);
inline bool disregard_neighbor();
public:
ddfw() {}
~ddfw();
void set_plugin(local_search_plugin* p) { m_plugin = p; }
lbool check(unsigned sz, literal const* assumptions);
void updt_params(params_ref const& p);
svector<lbool> const& get_model() const { return m_model; }
reslimit& rlimit() { return m_limit; }
void set_seed(unsigned n) { m_rand.set_seed(n); }
bool get_value(bool_var v) const { return value(v); }
std::ostream& display(std::ostream& out) const;
// for parallel integration
unsigned num_non_binary_clauses() const { return m_num_non_binary_clauses; }
void collect_statistics(statistics& st) const;
void reset_statistics();
double get_priority(bool_var v) const { return m_probs[v]; }
// access clause information and state of Boolean search
indexed_uint_set& unsat_set() { return m_unsat; }
indexed_uint_set const& unsat_set() const { return m_unsat; }
vector<clause_info> const& clauses() const { return m_clauses; }
clause_info& get_clause_info(unsigned idx) { return m_clauses[idx]; }
clause_info const& get_clause_info(unsigned idx) const { return m_clauses[idx]; }
void remove_assumptions();
void flip(bool_var v);
inline double get_reward(bool_var v) const { return m_vars[v].m_reward; }
double get_reward_avg(bool_var v) const { return m_vars[v].m_reward_avg; }
inline int& bias(bool_var v) { return m_vars[v].m_bias; }
void reserve_vars(unsigned n);
void add(unsigned sz, literal const* c);
sat::bool_var add_var();
void reinit();
void force_restart() { m_restart_next = m_flips; }
inline unsigned num_vars() const { return m_vars.size(); }
void simplify();
ptr_iterator<unsigned> use_list(literal lit) {
flatten_use_list();
unsigned i = lit.index();
auto const* b = m_flat_use_list.data() + m_use_list_index[i];
auto const* e = m_flat_use_list.data() + m_use_list_index[i + 1];
return { b, e };
}
};
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,292 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
sls_arith_base.h
Abstract:
Theory plugin for arithmetic local search
Author:
Nikolaj Bjorner (nbjorner) 2020-09-08
--*/
#pragma once
#include "util/obj_pair_set.h"
#include "util/checked_int64.h"
#include "util/optional.h"
#include "ast/ast_trail.h"
#include "ast/arith_decl_plugin.h"
#include "ast/sls/sls_context.h"
namespace sls {
using theory_var = int;
// local search portion for arithmetic
template<typename num_t>
class arith_base : public plugin {
enum class ineq_kind { EQ, LE, LT};
enum class var_sort { INT, REAL };
struct bound { bool is_strict = false; num_t value; };
typedef unsigned var_t;
typedef unsigned atom_t;
struct config {
double cb = 2.85;
unsigned L = 20;
unsigned t = 45;
unsigned max_no_improve = 500000;
double sp = 0.0003;
};
struct stats {
unsigned m_num_steps = 0;
};
public:
struct linear_term {
vector<std::pair<num_t, var_t>> m_args;
num_t m_coeff{ 0 };
};
struct nonlinear_coeff {
var_t v; // variable or multiplier containing x
num_t coeff; // coeff of v in inequality
unsigned p; // power
};
typedef svector<std::pair<unsigned, unsigned>> monomial_t;
// encode args <= bound, args = bound, args < bound
struct ineq : public linear_term {
vector<std::pair<var_t, vector<nonlinear_coeff>>> m_nonlinear;
vector<monomial_t> m_monomials;
ineq_kind m_op = ineq_kind::LE;
num_t m_args_value;
bool m_is_linear = true;
bool is_true() const;
std::ostream& display(std::ostream& out) const;
};
private:
class var_info {
num_t m_range{ 100000000 };
num_t m_update_value{ 0 };
unsigned m_update_timestamp = 0;
public:
var_info(expr* e, var_sort k): m_expr(e), m_sort(k) {}
expr* m_expr;
num_t m_value{ 0 };
num_t m_best_value{ 0 };
var_sort m_sort;
arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP;
unsigned m_def_idx = UINT_MAX;
vector<std::pair<num_t, sat::bool_var>> m_bool_vars;
unsigned_vector m_muls;
unsigned_vector m_adds;
optional<bound> m_lo, m_hi;
// retrieve temporary value during an update.
void set_update_value(num_t const& v, unsigned timestamp) {
m_update_value = v;
m_update_timestamp = timestamp;
}
num_t const& get_update_value(unsigned ts) const {
return ts == m_update_timestamp ? m_update_value : m_value;
}
bool in_range(num_t const& n) const {
if (-m_range < n && n < m_range)
return true;
if (m_lo && !m_hi)
return n < m_lo->value + m_range;
if (!m_lo && m_hi)
return n > m_hi->value - m_range;
return false;
}
unsigned m_tabu_pos = 0, m_tabu_neg = 0;
unsigned m_last_pos = 0, m_last_neg = 0;
bool is_tabu(unsigned step, num_t const& delta) {
return (delta > 0 ? m_tabu_pos : m_tabu_neg) > step;
}
void set_step(unsigned step, unsigned tabu_step, num_t const& delta) {
if (delta > 0)
m_tabu_pos = tabu_step, m_last_pos = step;
else
m_tabu_neg = tabu_step, m_last_neg = step;
}
};
struct mul_def {
unsigned m_var;
monomial_t m_monomial;
};
struct add_def : public linear_term {
unsigned m_var;
};
struct op_def {
unsigned m_var = UINT_MAX;
arith_op_kind m_op = LAST_ARITH_OP;
unsigned m_arg1, m_arg2;
};
struct var_change {
unsigned m_var;
num_t m_delta;
double m_score;
};
stats m_stats;
config m_config;
scoped_ptr_vector<ineq> m_bool_vars;
vector<var_info> m_vars;
vector<mul_def> m_muls;
vector<add_def> m_adds;
vector<op_def> m_ops;
unsigned_vector m_expr2var;
svector<double> m_probs;
bool m_dscore_mode = false;
vector<var_change> m_updates;
var_t m_last_var = 0;
sat::literal m_last_literal = sat::null_literal;
num_t m_last_delta { 0 };
bool m_use_tabu = true;
unsigned m_updates_max_size = 45;
arith_util a;
svector<double> m_prob_break;
void invariant();
void invariant(ineq const& i);
unsigned get_num_vars() const { return m_vars.size(); }
bool eval_is_correct(var_t v);
bool repair_mul(mul_def const& md);
bool repair_add(add_def const& ad);
bool repair_mod(op_def const& od);
bool repair_idiv(op_def const& od);
bool repair_div(op_def const& od);
bool repair_rem(op_def const& od);
bool repair_power(op_def const& od);
bool repair_abs(op_def const& od);
bool repair_to_int(op_def const& od);
bool repair_to_real(op_def const& od);
bool repair(sat::literal lit);
bool in_bounds(var_t v, num_t const& value);
bool is_fixed(var_t v);
bool is_linear(var_t x, vector<nonlinear_coeff> const& nlc, num_t& b);
bool is_quadratic(var_t x, vector<nonlinear_coeff> const& nlc, num_t& a, num_t& b);
num_t mul_value_without(var_t m, var_t x);
void add_update(var_t v, num_t delta);
bool is_permitted_update(var_t v, num_t const& delta, num_t& delta_out);
unsigned m_update_timestamp = 0;
svector<var_t> m_update_trail;
bool check_update(var_t v, num_t new_value);
void apply_checked_update();
num_t value1(var_t v);
vector<num_t> m_factors;
vector<num_t> const& factor(num_t n);
num_t root_of(unsigned n, num_t a);
num_t power_of(num_t a, unsigned k);
struct monomial_elem {
num_t other_product;
var_t v;
unsigned p; // power
};
// double reward(sat::literal lit);
bool sign(sat::bool_var v) const { return !ctx.is_true(sat::literal(v, false)); }
ineq* atom(sat::bool_var bv) const { return m_bool_vars.get(bv, nullptr); }
num_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); }
num_t dtt(bool sign, num_t const& args_value, ineq const& ineq) const;
num_t dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const;
num_t dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& delta) const;
num_t dts(unsigned cl, var_t v, num_t const& new_value) const;
num_t compute_dts(unsigned cl) const;
bool is_mul(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_MUL; }
bool is_add(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_ADD; }
mul_def const& get_mul(var_t v) const { SASSERT(is_mul(v)); return m_muls[m_vars[v].m_def_idx]; }
add_def const& get_add(var_t v) const { SASSERT(is_add(v)); return m_adds[m_vars[v].m_def_idx]; }
bool update(var_t v, num_t const& new_value);
bool apply_update();
bool find_nl_moves(sat::literal lit);
bool find_lin_moves(sat::literal lit);
bool find_reset_moves(sat::literal lit);
void add_reset_update(var_t v);
void find_linear_moves(ineq const& i, var_t x, num_t const& coeff, num_t const& sum);
void find_quadratic_moves(ineq const& i, var_t x, num_t const& a, num_t const& b, num_t const& sum);
double compute_score(var_t x, num_t const& delta);
void save_best_values();
var_t mk_var(expr* e);
var_t mk_term(expr* e);
var_t mk_op(arith_op_kind k, expr* e, expr* x, expr* y);
void add_arg(linear_term& term, num_t const& c, var_t v);
void add_args(linear_term& term, expr* e, num_t const& sign);
ineq& new_ineq(ineq_kind op, num_t const& bound);
void init_ineq(sat::bool_var bv, ineq& i);
num_t divide(var_t v, num_t const& delta, num_t const& coeff);
num_t divide_floor(var_t v, num_t const& a, num_t const& b);
num_t divide_ceil(var_t v, num_t const& a, num_t const& b);
void init_bool_var_assignment(sat::bool_var v);
bool is_int(var_t v) const { return m_vars[v].m_sort == var_sort::INT; }
num_t value(var_t v) const { return m_vars[v].m_value; }
num_t const& get_update_value(var_t v) const { return m_vars[v].get_update_value(m_update_timestamp); }
bool is_num(expr* e, num_t& i);
expr_ref from_num(sort* s, num_t const& n);
void check_ineqs();
void init_bool_var(sat::bool_var bv);
void initialize_unit(sat::literal lit);
void add_le(var_t v, num_t const& n);
void add_ge(var_t v, num_t const& n);
void add_lt(var_t v, num_t const& n);
void add_gt(var_t v, num_t const& n);
std::ostream& display(std::ostream& out, var_t v) const;
std::ostream& display(std::ostream& out, add_def const& ad) const;
std::ostream& display(std::ostream& out, mul_def const& md) const;
public:
arith_base(context& ctx);
~arith_base() override {}
void register_term(expr* e) override;
bool set_value(expr* e, expr* v) override;
expr_ref get_value(expr* e) override;
void initialize() override;
void propagate_literal(sat::literal lit) override;
bool propagate() override;
void repair_up(app* e) override;
bool repair_down(app* e) override;
void repair_literal(sat::literal lit) override;
bool is_sat() override;
void on_rescale() override;
void on_restart() override;
std::ostream& display(std::ostream& out) const override;
void collect_statistics(statistics& st) const override;
void reset_statistics() override;
};
inline std::ostream& operator<<(std::ostream& out, typename arith_base<checked_int64<true>>::ineq const& ineq) {
return ineq.display(out);
}
inline std::ostream& operator<<(std::ostream& out, typename arith_base<rational>::ineq const& ineq) {
return ineq.display(out);
}
}

View file

@ -0,0 +1,131 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
sls_arith_plugin.cpp
Abstract:
Local search dispatch for NIA
Author:
Nikolaj Bjorner (nbjorner) 2023-02-07
--*/
#include "ast/sls/sls_arith_plugin.h"
#include "ast/ast_ll_pp.h"
namespace sls {
#define WITH_FALLBACK(_fn_) \
if (m_arith64) { \
try {\
return m_arith64->_fn_;\
}\
catch (overflow_exception&) {\
throw;\
init_backup();\
}\
}\
return m_arith->_fn_; \
#define APPLY_BOTH(_fn_) \
if (m_arith64) { \
try {\
m_arith64->_fn_;\
}\
catch (overflow_exception&) {\
throw;\
init_backup();\
}\
}\
m_arith->_fn_; \
arith_plugin::arith_plugin(context& ctx) :
plugin(ctx), m_shared(ctx.get_manager()) {
m_arith64 = alloc(arith_base<checked_int64<true>>, ctx);
m_arith = alloc(arith_base<rational>, ctx);
m_arith64 = nullptr;
if (m_arith)
m_fid = m_arith->fid();
else
m_fid = m_arith64->fid();
}
void arith_plugin::init_backup() {
m_arith64 = nullptr;
}
void arith_plugin::register_term(expr* e) {
APPLY_BOTH(register_term(e));
}
expr_ref arith_plugin::get_value(expr* e) {
WITH_FALLBACK(get_value(e));
}
void arith_plugin::initialize() {
APPLY_BOTH(initialize());
}
void arith_plugin::propagate_literal(sat::literal lit) {
WITH_FALLBACK(propagate_literal(lit));
}
bool arith_plugin::propagate() {
WITH_FALLBACK(propagate());
}
bool arith_plugin::is_sat() {
WITH_FALLBACK(is_sat());
}
void arith_plugin::on_rescale() {
APPLY_BOTH(on_rescale());
}
void arith_plugin::on_restart() {
WITH_FALLBACK(on_restart());
}
std::ostream& arith_plugin::display(std::ostream& out) const {
if (m_arith64)
return m_arith64->display(out);
else
return m_arith->display(out);
}
bool arith_plugin::repair_down(app* e) {
WITH_FALLBACK(repair_down(e));
}
void arith_plugin::repair_up(app* e) {
WITH_FALLBACK(repair_up(e));
}
void arith_plugin::repair_literal(sat::literal lit) {
WITH_FALLBACK(repair_literal(lit));
}
bool arith_plugin::set_value(expr* e, expr* v) {
WITH_FALLBACK(set_value(e, v));
}
void arith_plugin::collect_statistics(statistics& st) const {
if (m_arith64)
m_arith64->collect_statistics(st);
else
m_arith->collect_statistics(st);
}
void arith_plugin::reset_statistics() {
if (m_arith)
m_arith->reset_statistics();
if (m_arith64)
m_arith64->reset_statistics();
}
}

View file

@ -0,0 +1,52 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
sls_arith_plugin.h
Abstract:
Theory plugin for arithmetic local search
Author:
Nikolaj Bjorner (nbjorner) 2024-07-05
--*/
#pragma once
#include "ast/sls/sls_context.h"
#include "ast/sls/sls_arith_base.h"
namespace sls {
class arith_plugin : public plugin {
scoped_ptr<arith_base<checked_int64<true>>> m_arith64;
scoped_ptr<arith_base<rational>> m_arith;
expr_ref_vector m_shared;
void init_backup();
public:
arith_plugin(context& ctx);
~arith_plugin() override {}
void register_term(expr* e) override;
expr_ref get_value(expr* e) override;
void initialize() override;
void propagate_literal(sat::literal lit) override;
bool propagate() override;
bool repair_down(app* e) override;
void repair_up(app* e) override;
void repair_literal(sat::literal lit) override;
bool is_sat() override;
void on_rescale() override;
void on_restart() override;
std::ostream& display(std::ostream& out) const override;
bool set_value(expr* e, expr* v) override;
void collect_statistics(statistics& st) const override;
void reset_statistics() override;
};
}

View file

@ -0,0 +1,277 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
sls_array_plugin.cpp
Abstract:
Theory plugin for arrays local search
Author:
Nikolaj Bjorner (nbjorner) 2024-07-06
--*/
#include "ast/sls/sls_array_plugin.h"
#include "ast/ast_ll_pp.h"
#include "ast/ast_pp.h"
namespace sls {
array_plugin::array_plugin(context& ctx):
plugin(ctx),
a(m)
{
m_fid = a.get_family_id();
}
bool array_plugin::is_sat() {
if (!m_has_arrays)
return true;
m_g = alloc(euf::egraph, m);
m_kv = nullptr;
init_egraph(*m_g);
saturate_store(*m_g);
return true;
}
// b ~ a[i -> v]
// ensure b[i] ~ v
// ensure b[j] ~ a[j] for j != i
void array_plugin::saturate_store(euf::egraph& g) {
unsigned sz = 0;
while (sz < g.nodes().size()) {
sz = g.nodes().size();
for (unsigned i = 0; i < sz; ++i) {
auto n = g.nodes()[i];
if (!a.is_store(n->get_expr()))
continue;
force_store_axiom1(g, n);
for (auto p : euf::enode_parents(n->get_root()))
if (a.is_select(p->get_expr()))
force_store_axiom2_down(g, n, p);
auto arr = n->get_arg(0);
for (auto p : euf::enode_parents(arr->get_root()))
if (a.is_select(p->get_expr()))
force_store_axiom2_up(g, n, p);
}
}
display(verbose_stream() << "saturated\n");
}
euf::enode* array_plugin::mk_select(euf::egraph& g, euf::enode* b, euf::enode* sel) {
auto arity = get_array_arity(b->get_sort());
ptr_buffer<expr> args;
ptr_buffer<euf::enode> eargs;
args.push_back(b->get_expr());
eargs.push_back(b);
for (unsigned i = 1; i <= arity; ++i) {
auto idx = sel->get_arg(i);
eargs.push_back(idx);
args.push_back(idx->get_expr());
}
expr_ref esel(a.mk_select(args), m);
auto n = g.find(esel);
return n ? n : g.mk(esel, 0, eargs.size(), eargs.data());
}
// ensure a[i->v][i] = v exists in the e-graph
void array_plugin::force_store_axiom1(euf::egraph& g, euf::enode* n) {
SASSERT(a.is_store(n->get_expr()));
auto val = n->get_arg(n->num_args() - 1);
auto nsel = mk_select(g, n, n);
if (are_distinct(nsel, val))
add_store_axiom1(n->get_app());
else {
g.merge(nsel, val, nullptr);
VERIFY(g.propagate());
}
}
// i /~ j, b ~ a[i->v], b[j] occurs -> a[j] = b[j]
void array_plugin::force_store_axiom2_down(euf::egraph& g, euf::enode* sto, euf::enode* sel) {
SASSERT(a.is_store(sto->get_expr()));
SASSERT(a.is_select(sel->get_expr()));
if (sel->get_arg(0)->get_root() != sto->get_root())
return;
if (eq_args(sto, sel))
return;
auto nsel = mk_select(g, sto->get_arg(0), sel);
if (are_distinct(nsel, sel))
add_store_axiom2(sto->get_app(), sel->get_app());
else {
g.merge(nsel, sel, nullptr);
VERIFY(g.propagate());
}
}
// a ~ b, i /~ j, b[j] occurs -> a[i -> v][j] = b[j]
void array_plugin::force_store_axiom2_up(euf::egraph& g, euf::enode* sto, euf::enode* sel) {
SASSERT(a.is_store(sto->get_expr()));
SASSERT(a.is_select(sel->get_expr()));
if (sel->get_arg(0)->get_root() != sto->get_arg(0)->get_root())
return;
if (eq_args(sto, sel))
return;
auto nsel = mk_select(g, sto, sel);
if (are_distinct(nsel, sel))
add_store_axiom2(sto->get_app(), sel->get_app());
else {
g.merge(nsel, sel, nullptr);
VERIFY(g.propagate());
}
}
bool array_plugin::are_distinct(euf::enode* a, euf::enode* b) {
a = a->get_root();
b = b->get_root();
return a->interpreted() && b->interpreted() && a != b; // TODO work with nested arrays?
}
bool array_plugin::eq_args(euf::enode* sto, euf::enode* sel) {
SASSERT(a.is_store(sto->get_expr()));
SASSERT(a.is_select(sel->get_expr()));
unsigned arity = get_array_arity(sto->get_sort());
for (unsigned i = 1; i < arity; ++i) {
if (sto->get_arg(i)->get_root() != sel->get_arg(i)->get_root())
return false;
}
return true;
}
void array_plugin::add_store_axiom1(app* sto) {
if (!m_add_conflicts)
return;
ptr_vector<expr> args;
args.push_back(sto);
for (unsigned i = 1; i < sto->get_num_args() - 1; ++i)
args.push_back(sto->get_arg(i));
expr_ref sel(a.mk_select(args), m);
expr_ref eq(m.mk_eq(sel, to_app(sto)->get_arg(sto->get_num_args() - 1)), m);
verbose_stream() << "add store axiom 1 " << mk_bounded_pp(sto, m) << "\n";
ctx.add_clause(eq);
}
void array_plugin::add_store_axiom2(app* sto, app* sel) {
if (!m_add_conflicts)
return;
ptr_vector<expr> args1, args2;
args1.push_back(sto);
args2.push_back(sto->get_arg(0));
for (unsigned i = 1; i < sel->get_num_args() - 1; ++i) {
args1.push_back(sel->get_arg(i));
args2.push_back(sel->get_arg(i));
}
expr_ref sel1(a.mk_select(args1), m);
expr_ref sel2(a.mk_select(args2), m);
expr_ref eq(m.mk_eq(sel1, sel2), m);
expr_ref_vector ors(m);
ors.push_back(eq);
for (unsigned i = 1; i < sel->get_num_args() - 1; ++i)
ors.push_back(m.mk_eq(sel->get_arg(i), sto->get_arg(i)));
verbose_stream() << "add store axiom 2 " << mk_bounded_pp(sto, m) << " " << mk_bounded_pp(sel, m) << "\n";
ctx.add_clause(m.mk_or(ors));
}
void array_plugin::init_egraph(euf::egraph& g) {
ptr_vector<euf::enode> args;
for (auto t : ctx.subterms()) {
args.reset();
if (is_app(t))
for (auto* arg : *to_app(t))
args.push_back(g.find(arg));
euf::enode* n1, * n2;
n1 = g.find(t);
n1 = n1 ? n1 : g.mk(t, 0, args.size(), args.data());
if (a.is_array(t))
continue;
auto v = ctx.get_value(t);
verbose_stream() << "init " << mk_bounded_pp(t, m) << " := " << mk_bounded_pp(v, m) << "\n";
n2 = g.find(v);
n2 = n2 ? n2: g.mk(v, 0, 0, nullptr);
g.merge(n1, n2, nullptr);
}
for (auto lit : ctx.root_literals()) {
if (!ctx.is_true(lit) || lit.sign())
continue;
auto e = ctx.atom(lit.var());
expr* x, * y;
if (e && m.is_eq(e, x, y))
g.merge(g.find(x), g.find(y), nullptr);
}
display(verbose_stream());
}
void array_plugin::init_kv(euf::egraph& g, kv& kv) {
for (auto n : g.nodes()) {
if (!n->is_root() || !a.is_array(n->get_expr()))
continue;
kv.insert(n, select2value());
for (auto p : euf::enode_parents(n)) {
if (!a.is_select(p->get_expr()))
continue;
if (p->get_arg(0)->get_root() != n->get_root())
continue;
auto val = p->get_root();
kv[n].insert(select_args(p), val);
}
}
display(verbose_stream());
}
expr_ref array_plugin::get_value(expr* e) {
SASSERT(a.is_array(e));
if (!m_g) {
m_g = alloc(euf::egraph, m);
init_egraph(*m_g);
flet<bool> _strong(m_add_conflicts, false);
saturate_store(*m_g);
}
if (!m_kv) {
m_kv = alloc(kv);
init_kv(*m_g, *m_kv);
}
auto& kv = *m_kv;
auto n = m_g->find(e)->get_root();
expr_ref r(n->get_expr(), m);
for (auto [k, v] : kv[n]) {
ptr_vector<expr> args;
args.push_back(r);
args.push_back(k.sel->get_arg(1)->get_expr());
args.push_back(v->get_expr());
r = a.mk_store(args);
}
return r;
}
std::ostream& array_plugin::display(std::ostream& out) const {
if (m_g)
m_g->display(out);
if (m_kv) {
for (auto& [n, kvs] : *m_kv) {
out << m_g->pp(n) << " -> {";
char const* sp = "";
for (auto& [k, v] : kvs) {
out << sp;
for (unsigned i = 1; i < k.sel->num_args(); ++i)
out << m_g->pp(k.sel->get_arg(i)->get_root()) << " ";
out << "-> " << m_g->pp(v);
sp = " ";
}
out << "}\n";
}
}
return out;
}
}

View file

@ -0,0 +1,90 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
sls_array_plugin.h
Abstract:
Theory plugin for arrays local search
Author:
Nikolaj Bjorner (nbjorner) 2024-07-06
--*/
#pragma once
#include "ast/sls/sls_context.h"
#include "ast/array_decl_plugin.h"
#include "ast/euf/euf_egraph.h"
namespace sls {
class array_plugin : public plugin {
struct select_args {
euf::enode* sel = nullptr;
select_args(euf::enode* s) : sel(s) {}
select_args() {}
};
struct select_args_hash {
unsigned operator()(select_args const& a) const {
unsigned h = 0;
for (unsigned i = 1; i < a.sel->num_args(); ++i)
h ^= a.sel->get_arg(i)->get_root()->hash();
return h;
};
};
struct select_args_eq {
bool operator()(select_args const& a, select_args const& b) const {
SASSERT(a.sel->num_args() == b.sel->num_args());
for (unsigned i = 1; i < a.sel->num_args(); ++i)
if (a.sel->get_arg(i)->get_root() != b.sel->get_arg(i)->get_root())
return false;
return true;
}
};
typedef map<select_args, euf::enode*, select_args_hash, select_args_eq> select2value;
typedef obj_map<euf::enode, select2value> kv;
array_util a;
scoped_ptr<euf::egraph> m_g;
scoped_ptr<kv> m_kv;
bool m_add_conflicts = true;
bool m_has_arrays = false;
void init_egraph(euf::egraph& g);
void init_kv(euf::egraph& g, kv& kv);
void saturate_store(euf::egraph& g);
void force_store_axiom1(euf::egraph& g, euf::enode* n);
void force_store_axiom2_down(euf::egraph& g, euf::enode* sto, euf::enode* sel);
void force_store_axiom2_up(euf::egraph& g, euf::enode* sto, euf::enode* sel);
void add_store_axiom1(app* sto);
void add_store_axiom2(app* sto, app* sel);
bool are_distinct(euf::enode* a, euf::enode* b);
bool eq_args(euf::enode* sto, euf::enode* sel);
euf::enode* mk_select(euf::egraph& g, euf::enode* b, euf::enode* sel);
public:
array_plugin(context& ctx);
~array_plugin() override {}
void register_term(expr* e) override { if (a.is_array(e->get_sort())) m_has_arrays = true; }
expr_ref get_value(expr* e) override;
void initialize() override { m_g = nullptr; }
void propagate_literal(sat::literal lit) override { m_g = nullptr; }
bool propagate() override { return false; }
bool repair_down(app* e) override { return true; }
void repair_up(app* e) override {}
void repair_literal(sat::literal lit) override { m_g = nullptr; }
bool is_sat() override;
void on_rescale() override {}
void on_restart() override {}
std::ostream& display(std::ostream& out) const override;
bool set_value(expr* e, expr* v) override { return false; }
void collect_statistics(statistics& st) const override {}
void reset_statistics() override {}
};
}

View file

@ -0,0 +1,210 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_basic_plugin.cpp
Abstract:
Local search dispatch for Boolean connectives
Author:
Nikolaj Bjorner (nbjorner) 2024-07-07
--*/
#include "ast/sls/sls_basic_plugin.h"
#include "ast/ast_ll_pp.h"
#include "ast/ast_pp.h"
#include "ast/ast_util.h"
namespace sls {
expr_ref basic_plugin::get_value(expr* e) {
return expr_ref(m.mk_bool_val(bval0(e)), m);
}
bool basic_plugin::is_basic(expr* e) const {
if (!e || !is_app(e))
return false;
if (m.is_ite(e) && !m.is_bool(e) && false)
return true;
if (m.is_xor(e) && to_app(e)->get_num_args() != 2)
return true;
if (m.is_distinct(e))
return true;
return false;
}
void basic_plugin::propagate_literal(sat::literal lit) {
}
void basic_plugin::register_term(expr* e) {
expr* c, * th, * el;
if (m.is_ite(e, c, th, el) && !m.is_bool(e)) {
ctx.add_clause(m.mk_or(mk_not(m, c), m.mk_eq(e, th)));
ctx.add_clause(m.mk_or(c, m.mk_eq(e, el)));
}
}
void basic_plugin::initialize() {
}
bool basic_plugin::propagate() {
return false;
}
bool basic_plugin::is_sat() {
return true;
}
std::ostream& basic_plugin::display(std::ostream& out) const {
return out;
}
bool basic_plugin::set_value(expr* e, expr* v) {
if (!m.is_bool(e))
return false;
SASSERT(m.is_true(v) || m.is_false(v));
return set_value(e, m.is_true(v));
}
expr_ref basic_plugin::eval_ite(app* e) {
expr* c, * th, * el;
VERIFY(m.is_ite(e, c, th, el));
if (bval0(c))
return ctx.get_value(th);
else
return ctx.get_value(el);
}
expr_ref basic_plugin::eval_distinct(app* e) {
for (unsigned i = 0; i < e->get_num_args(); ++i) {
for (unsigned j = i + 1; j < e->get_num_args(); ++j) {
if (bval0(e->get_arg(i)) == bval0(e->get_arg(j)))
return expr_ref(m.mk_false(), m);
}
}
return expr_ref(m.mk_true(), m);
}
expr_ref basic_plugin::eval_xor(app* e) {
bool b = false;
for (expr* arg : *e)
b ^= bval0(arg);
return expr_ref(m.mk_bool_val(b), m);
}
bool basic_plugin::bval0(expr* e) const {
SASSERT(m.is_bool(e));
return ctx.is_true(ctx.mk_literal(e));
}
bool basic_plugin::try_repair(app* e, unsigned i) {
switch (e->get_decl_kind()) {
case OP_XOR:
return try_repair_xor(e, i);
case OP_ITE:
return try_repair_ite(e, i);
case OP_DISTINCT:
return try_repair_distinct(e, i);
default:
return true;
}
}
bool basic_plugin::try_repair_xor(app* e, unsigned i) {
auto child = e->get_arg(i);
bool bv = false;
for (unsigned j = 0; j < e->get_num_args(); ++j)
if (j != i)
bv ^= bval0(e->get_arg(j));
bool ev = bval0(e);
return set_value(child, ev != bv);
}
bool basic_plugin::try_repair_ite(app* e, unsigned i) {
if (m.is_bool(e))
return true;
auto child = e->get_arg(i);
auto cond = e->get_arg(0);
bool c = bval0(cond);
if (i == 0) {
auto eval = ctx.get_value(e);
auto eval1 = ctx.get_value(e->get_arg(1));
auto eval2 = ctx.get_value(e->get_arg(2));
if (eval == eval1 && eval == eval2)
return true;
if (eval == eval1)
return set_value(cond, true);
if (eval == eval2)
return set_value(cond, false);
return false;
}
if (c != (i == 1))
return false;
if (m.is_value(child))
return false;
bool r = ctx.set_value(child, ctx.get_value(e));
verbose_stream() << "repair-ite-down " << mk_bounded_pp(e, m) << " @ " << mk_bounded_pp(child, m) << " := " << ctx.get_value(e) << " success " << r << "\n";
return r;
}
void basic_plugin::repair_up(app* e) {
expr* c, * th, * el;
expr_ref val(m);
if (!is_basic(e))
return;
if (m.is_ite(e, c, th, el) && !m.is_bool(e))
val = eval_ite(e);
else if (m.is_xor(e))
val = eval_xor(e);
else if (m.is_distinct(e))
val = eval_distinct(e);
else
return;
verbose_stream() << "repair-up " << mk_bounded_pp(e, m) << " " << val << "\n";
if (!ctx.set_value(e, val))
ctx.new_value_eh(e);
}
void basic_plugin::repair_literal(sat::literal lit) {
}
bool basic_plugin::repair_down(app* e) {
if (!is_basic(e))
return true;
if (m.is_xor(e) && eval_xor(e) == ctx.get_value(e))
return true;
if (m.is_ite(e) && eval_ite(e) == ctx.get_value(e))
return true;
if (m.is_distinct(e) && eval_distinct(e) == ctx.get_value(e))
return true;
verbose_stream() << "basic repair down " << mk_bounded_pp(e, m) << "\n";
unsigned n = e->get_num_args();
unsigned s = ctx.rand(n);
for (unsigned i = 0; i < n; ++i) {
auto j = (i + s) % n;
if (try_repair(e, j))
return true;
}
return false;
}
bool basic_plugin::try_repair_distinct(app* e, unsigned i) {
NOT_IMPLEMENTED_YET();
return false;
}
bool basic_plugin::set_value(expr* e, bool b) {
auto lit = ctx.mk_literal(e);
if (ctx.is_true(lit) != b) {
ctx.flip(lit.var());
ctx.new_value_eh(e);
}
return true;
}
}

View file

@ -0,0 +1,58 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
sls_basic_plugin.h
Author:
Nikolaj Bjorner (nbjorner) 2024-07-05
--*/
#pragma once
#include "ast/sls/sls_context.h"
namespace sls {
class basic_plugin : public plugin {
expr_mark m_axiomatized;
bool is_basic(expr* e) const;
bool bval0(expr* e) const;
bool try_repair(app* e, unsigned i);
bool try_repair_xor(app* e, unsigned i);
bool try_repair_ite(app* e, unsigned i);
bool try_repair_distinct(app* e, unsigned i);
bool set_value(expr* e, bool b);
expr_ref eval_ite(app* e);
expr_ref eval_distinct(app* e);
expr_ref eval_xor(app* e);
public:
basic_plugin(context& ctx) :
plugin(ctx) {
m_fid = basic_family_id;
}
~basic_plugin() override {}
void register_term(expr* e) override;
expr_ref get_value(expr* e) override;
void initialize() override;
void propagate_literal(sat::literal lit) override;
bool propagate() override;
bool repair_down(app* e) override;
void repair_up(app* e) override;
void repair_literal(sat::literal lit) override;
bool is_sat() override;
void on_rescale() override {}
void on_restart() override {}
std::ostream& display(std::ostream& out) const override;
bool set_value(expr* e, expr* v) override;
void collect_statistics(statistics& st) const override {}
void reset_statistics() override {}
};
}

View file

@ -26,7 +26,7 @@ Notes:
#include "util/luby.h"
#include "params/sls_params.hpp"
#include "ast/sls/sls_engine.h"
#include "ast/sls/sls_bv_engine.h"
sls_engine::sls_engine(ast_manager & m, params_ref const & p) :

View file

@ -23,8 +23,8 @@ Notes:
#include "ast/converters/model_converter.h"
#include "ast/sls/sls_stats.h"
#include "ast/sls/sls_tracker.h"
#include "ast/sls/sls_evaluator.h"
#include "ast/sls/sls_bv_tracker.h"
#include "ast/sls/sls_bv_evaluator.h"
class sls_engine {

File diff suppressed because it is too large Load diff

View file

@ -17,79 +17,81 @@ Author:
#pragma once
#include "ast/ast.h"
#include "ast/sls/sls_valuation.h"
#include "ast/sls/bv_sls_fixed.h"
#include "ast/sls/sls_bv_valuation.h"
#include "ast/sls/sls_bv_fixed.h"
#include "ast/sls/sls_context.h"
#include "ast/bv_decl_plugin.h"
namespace bv {
namespace sls {
class sls_fixed;
class bv_terms;
class sls_eval_plugin {
public:
virtual ~sls_eval_plugin() {}
};
class sls_eval {
using bvect = sls::bvect;
class bv_eval {
struct config {
unsigned m_prob_randomize_extract = 50;
};
friend class sls_fixed;
friend class sls::bv_fixed;
friend class sls_test;
ast_manager& m;
sls::context& ctx;
sls::bv_terms& terms;
bv_util bv;
sls_fixed m_fix;
sls::bv_fixed m_fix;
mutable mpn_manager mpn;
ptr_vector<expr> m_todo;
random_gen m_rand;
config m_config;
bool_vector m_fixed;
scoped_ptr_vector<sls_eval_plugin> m_plugins;
scoped_ptr_vector<sls::bv_valuation> m_values; // expr-id -> bv valuation
scoped_ptr_vector<sls_valuation> m_values; // expr-id -> bv valuation
bool_vector m_eval; // expr-id -> boolean valuation
bool_vector m_fixed; // expr-id -> is Boolean fixed
mutable bvect m_tmp, m_tmp2, m_tmp3, m_tmp4, m_zero, m_one, m_minus_one;
mutable bvect m_tmp, m_tmp2, m_tmp3, m_tmp4, m_mul_tmp, m_zero, m_one, m_minus_one;
bvect m_a, m_b, m_nextb, m_nexta, m_aux;
using bvval = sls_valuation;
using bvval = sls::bv_valuation;
void init_eval_basic(app* e);
void init_eval_bv(app* e);
ptr_vector<expr> m_restore;
vector<ptr_vector<expr>> m_update_stack;
expr_mark m_on_restore;
void insert_update_stack(expr* e);
bool insert_update(expr* e);
double lookahead(expr* e, bvect const& new_value);
void restore_lookahead();
/**
* Register e as a bit-vector.
* Return true if not already registered, false if already registered.
*/
bool add_bit_vector(app* e);
sls_valuation* alloc_valuation(app* e);
void add_bit_vector(app* e);
sls::bv_valuation* alloc_valuation(app* e);
bool bval1_basic(app* e) const;
bool bval1_bv(app* e) const;
bool bval1_bv(app* e, bool use_current) const;
bool bval1_tmp(app* e) const;
void fold_oper(bvect& out, app* e, unsigned i, std::function<void(bvect&, bvval const&)> const& f);
/**
* Repair operations
*/
bool try_repair_basic(app* e, unsigned i);
bool try_repair_bv(app * e, unsigned i);
bool try_repair_and_or(app* e, unsigned i);
bool try_repair_not(app* e);
bool try_repair_eq(app* e, unsigned i);
bool try_repair_xor(app* e, unsigned i);
bool try_repair_ite(app* e, unsigned i);
bool try_repair_implies(app* e, unsigned i);
bool try_repair_band(bvect const& e, bvval& a, bvval const& b);
bool try_repair_band(app* t, unsigned i);
bool try_repair_bor(bvect const& e, bvval& a, bvval const& b);
bool try_repair_bor(app* t, unsigned i);
bool try_repair_add(bvect const& e, bvval& a, bvval const& b);
bool try_repair_add(app* t, unsigned i);
bool try_repair_sub(bvect const& e, bvval& a, bvval& b, unsigned i);
bool try_repair_mul(bvect const& e, bvval& a, bvval const& b);
bool try_repair_mul(bvect const& e, bvval& a, bvect const& b);
bool try_repair_bxor(bvect const& e, bvval& a, bvval const& b);
bool try_repair_bxor(app* t, unsigned i);
bool try_repair_bnot(bvect const& e, bvval& a);
bool try_repair_bneg(bvect const& e, bvval& a);
bool try_repair_ule(bool e, bvval& a, bvval const& b);
@ -116,11 +118,14 @@ namespace bv {
bool try_repair_umul_ovfl(bool e, bvval& a, bvval& b, unsigned i);
bool try_repair_zero_ext(bvect const& e, bvval& a);
bool try_repair_sign_ext(bvect const& e, bvval& a);
bool try_repair_concat(bvect const& e, bvval& a, bvval& b, unsigned i);
bool try_repair_concat(app* e, unsigned i);
bool try_repair_extract(bvect const& e, bvval& a, unsigned lo);
bool try_repair_comp(bvect const& e, bvval& a, bvval& b, unsigned i);
bool try_repair_eq(bool is_true, bvval& a, bvval const& b);
void add_p2_1(bvval const& a, bvect& t) const;
bool try_repair_eq(app* e, unsigned i);
bool try_repair_eq_lookahead(app* e);
bool try_repair_int2bv(bvect const& e, expr* arg);
void add_p2_1(bvval const& a, bool use_current, bvect& t) const;
bool add_overflow_on_fixed(bvval const& a, bvect const& t);
bool mul_overflow_on_fixed(bvval const& a, bvect const& t);
@ -130,66 +135,58 @@ namespace bv {
digit_t random_bits();
bool random_bool() { return m_rand() % 2 == 0; }
sls_valuation& wval(app* e, unsigned i) { return wval(e->get_arg(i)); }
sls::bv_valuation& wval(app* e, unsigned i) { return wval(e->get_arg(i)); }
void eval(app* e, sls_valuation& val) const;
void eval(app* e, sls::bv_valuation& val) const;
bvect const& eval_value(app* e) const { return wval(e).eval; }
bvect const& assign_value(app* e) const { return wval(e).bits(); }
/**
* Retrieve evaluation based on immediate children.
*/
bool can_eval1(app* e) const;
void commit_eval(expr* p, app* e);
public:
sls_eval(ast_manager& m);
bv_eval(sls::bv_terms& terms, sls::context& ctx);
void init_eval(expr_ref_vector const& es, std::function<bool(expr*, unsigned)> const& eval);
void init() { m_fix.init(); }
void tighten_range(expr_ref_vector const& es) { m_fix.init(es); }
ptr_vector<expr>& sort_assertions(expr_ref_vector const& es);
void register_term(expr* e);
/**
* Retrieve evaluation based on cache.
* bval - Boolean values
* wval - Word (bit-vector) values
*/
bool bval0(expr* e) const { return m_eval[e->get_id()]; }
*/
sls_valuation& wval(expr* e) const;
sls::bv_valuation& wval(expr* e) const;
void set(expr* e, sls::bv_valuation const& val);
bool is_fixed0(expr* e) const { return m_fixed.get(e->get_id(), false); }
/**
* Retrieve evaluation based on immediate children.
*/
bool bval1(app* e) const;
bool can_eval1(app* e) const;
sls_valuation& eval(app* e) const;
void commit_eval(app* e);
void init_eval(app* e);
sls::bv_valuation& eval(app* e) const;
void set_random(app* e);
bool eval_is_correct(app* e);
bool re_eval_is_correct(app* e);
bool is_uninterpreted(app* e) const;
expr_ref get_value(app* e);
/**
* Override evaluaton.
bool bval0(expr* e) const { return ctx.is_true(e); }
bool bval1(app* e) const;
/*
* Try to invert value of child to repair value assignment of parent.
*/
void set(expr* e, bool b) {
m_eval[e->get_id()] = b;
}
/*
* Try to invert value of child to repair value assignment of parent.
*/
bool try_repair(app* e, unsigned i);
bool repair_down(app* e, unsigned i);
/*
* Propagate repair up to parent
@ -197,8 +194,8 @@ namespace bv {
bool repair_up(expr* e);
std::ostream& display(std::ostream& out, expr_ref_vector const& es);
std::ostream& display(std::ostream& out) const;
std::ostream& display_value(std::ostream& out, expr* e);
std::ostream& display_value(std::ostream& out, expr* e) const;
};
}

View file

@ -22,7 +22,7 @@ Notes:
#include "model/model_evaluator.h"
#include "ast/sls/sls_powers.h"
#include "ast/sls/sls_tracker.h"
#include "ast/sls/sls_bv_tracker.h"
class sls_evaluator {
ast_manager & m_manager;

View file

@ -13,56 +13,52 @@ Author:
#include "ast/ast_pp.h"
#include "ast/ast_ll_pp.h"
#include "ast/sls/bv_sls_fixed.h"
#include "ast/sls/bv_sls_eval.h"
#include "ast/sls/sls_bv_fixed.h"
#include "ast/sls/sls_bv_terms.h"
#include "ast/sls/sls_bv_eval.h"
namespace bv {
namespace sls {
sls_fixed::sls_fixed(sls_eval& ev):
bv_fixed::bv_fixed(bv_eval& ev, bv_terms& terms, sls::context& ctx):
ev(ev),
terms(terms),
m(ev.m),
bv(ev.bv)
bv(ev.bv),
ctx(ctx)
{}
void sls_fixed::init(expr_ref_vector const& es) {
ev.sort_assertions(es);
for (expr* e : ev.m_todo) {
if (!is_app(e))
void bv_fixed::init() {
for (auto e : ctx.subterms())
set_fixed(e);
//ctx.display(verbose_stream());
for (auto lit : ctx.unit_literals()) {
auto a = ctx.atom(lit.var());
if (!a)
continue;
app* a = to_app(e);
ev.m_fixed.setx(a->get_id(), is_fixed1(a), false);
if (a->get_family_id() == basic_family_id)
init_fixed_basic(a);
else if (a->get_family_id() == bv.get_family_id())
init_fixed_bv(a);
else
;
if (is_app(a))
init_range(to_app(a), lit.sign());
ev.m_fixed.setx(a->get_id(), true, false);
}
init_ranges(es);
ev.m_todo.reset();
//ctx.display(verbose_stream());
for (auto e : ctx.subterms())
propagate_range_up(e);
//ctx.display(verbose_stream());
}
void sls_fixed::init_ranges(expr_ref_vector const& es) {
for (expr* e : es) {
bool sign = m.is_not(e, e);
if (is_app(e))
init_range(to_app(e), sign);
}
for (expr* e : ev.m_todo)
propagate_range_up(e);
}
void sls_fixed::propagate_range_up(expr* e) {
void bv_fixed::propagate_range_up(expr* e) {
expr* t, * s;
rational v;
if (bv.is_concat(e, t, s)) {
auto& vals = wval(s);
auto& vals = ev.wval(s);
if (vals.lo() != vals.hi() && (vals.lo() < vals.hi() || vals.hi() == 0))
// lo <= e
add_range(e, vals.lo(), rational::zero(), false);
auto valt = wval(t);
auto valt = ev.wval(t);
if (valt.lo() != valt.hi() && (valt.lo() < valt.hi() || valt.hi() == 0)) {
// (2^|s|) * lo <= e < (2^|s|) * hi
auto p = rational::power_of_two(bv.get_bv_size(s));
@ -70,12 +66,12 @@ namespace bv {
}
}
else if (bv.is_bv_add(e, s, t) && bv.is_numeral(s, v)) {
auto& val = wval(t);
auto& val = ev.wval(t);
if (val.lo() != val.hi())
add_range(e, v + val.lo(), v + val.hi(), false);
}
else if (bv.is_bv_add(e, t, s) && bv.is_numeral(s, v)) {
auto& val = wval(t);
auto& val = ev.wval(t);
if (val.lo() != val.hi())
add_range(e, v + val.lo(), v + val.hi(), false);
}
@ -83,7 +79,7 @@ namespace bv {
// x in [lo, hi[ => -x in [-hi + 1, -lo + 1[
else if (bv.is_bv_mul(e, s, t) && bv.is_numeral(s, v) &&
v + 1 == rational::power_of_two(bv.get_bv_size(e))) {
auto& val = wval(t);
auto& val = ev.wval(t);
if (val.lo() != val.hi())
add_range(e, -val.hi() + 1, - val.lo() + 1, false);
}
@ -91,7 +87,7 @@ namespace bv {
// s <=s t <=> s + K <= t + K, K = 2^{bw-1}
bool sls_fixed::init_range(app* e, bool sign) {
bool bv_fixed::init_range(app* e, bool sign) {
expr* s, * t, * x, * y;
rational a, b;
unsigned idx;
@ -149,7 +145,7 @@ namespace bv {
return true;
}
else if (bv.is_bit2bool(e, s, idx)) {
auto& val = wval(s);
auto& val = ev.wval(s);
val.try_set_bit(idx, !sign);
val.fixed.set(idx, true);
val.tighten_range();
@ -159,17 +155,17 @@ namespace bv {
return false;
}
bool sls_fixed::init_eq(expr* t, rational const& a, bool sign) {
bool bv_fixed::init_eq(expr* t, rational const& a, bool sign) {
unsigned lo, hi;
rational b(0);
// verbose_stream() << mk_bounded_pp(t, m) << " == " << a << "\n";
expr* s = nullptr;
if (sign)
if (sign && true)
// 1 <= t - a
init_range(nullptr, rational(1), t, -a, false);
else
if (!sign)
// t - a <= 0
init_range(t, -a, nullptr, rational::zero(), false);
if (!sign && bv.is_bv_not(t, s)) {
for (unsigned i = 0; i < bv.get_bv_size(s); ++i)
if (!a.get_bit(i))
@ -187,20 +183,21 @@ namespace bv {
}
if (bv.is_extract(t, lo, hi, s)) {
if (hi == lo) {
sign = sign ? a == 1 : a == 0;
auto& val = wval(s);
if (val.try_set_bit(lo, !sign))
val.fixed.set(lo, true);
auto sign1 = sign ? a == 1 : a == 0;
auto& val = ev.wval(s);
if (val.try_set_bit(lo, !sign1))
val.fixed.set(lo, true);
val.tighten_range();
}
else if (!sign) {
auto& val = wval(s);
auto& val = ev.wval(s);
for (unsigned i = lo; i <= hi; ++i)
if (val.try_set_bit(i, a.get_bit(i - lo)))
val.fixed.set(i, true);
val.tighten_range();
// verbose_stream() << lo << " " << hi << " " << val << " := " << a << "\n";
}
}
if (!sign && hi + 1 == bv.get_bv_size(s)) {
// s < 2^lo * (a + 1)
@ -223,7 +220,7 @@ namespace bv {
// a < x + b <=> ! (x + b <= a) <=> x not in [-a, b - a [ <=> x in [b - a, -a [ a != -1
// x + a < x + b <=> ! (x + b <= x + a) <=> x in [-a, -b [ a != b
//
bool sls_fixed::init_range(expr* x, rational const& a, expr* y, rational const& b, bool sign) {
bool bv_fixed::init_range(expr* x, rational const& a, expr* y, rational const& b, bool sign) {
if (!x && !y)
return false;
if (!x)
@ -235,8 +232,8 @@ namespace bv {
return false;
}
bool sls_fixed::add_range(expr* e, rational lo, rational hi, bool sign) {
auto& v = wval(e);
bool bv_fixed::add_range(expr* e, rational lo, rational hi, bool sign) {
auto& v = ev.wval(e);
lo = mod(lo, rational::power_of_two(bv.get_bv_size(e)));
hi = mod(hi, rational::power_of_two(bv.get_bv_size(e)));
if (lo == hi)
@ -262,7 +259,7 @@ namespace bv {
return true;
}
void sls_fixed::get_offset(expr* e, expr*& x, rational& offset) {
void bv_fixed::get_offset(expr* e, expr*& x, rational& offset) {
expr* s, * t;
x = e;
offset = 0;
@ -285,177 +282,173 @@ namespace bv {
x = nullptr;
}
sls_valuation& sls_fixed::wval(expr* e) {
return ev.wval(e);
}
void sls_fixed::init_fixed_basic(app* e) {
if (bv.is_bv(e) && m.is_ite(e)) {
auto& val = wval(e);
auto& val_th = wval(e->get_arg(1));
auto& val_el = wval(e->get_arg(2));
for (unsigned i = 0; i < val.nw; ++i)
val.fixed[i] = val_el.fixed[i] & val_th.fixed[i] & ~(val_el.bits(i) ^ val_th.bits(i));
}
}
void sls_fixed::init_fixed_bv(app* e) {
if (bv.is_bv(e))
set_fixed_bw(e);
}
bool sls_fixed::is_fixed1(app* e) const {
bool bv_fixed::is_fixed1(app* e) const {
if (is_uninterp(e))
return false;
if (e->get_family_id() == basic_family_id)
return is_fixed1_basic(e);
return all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); });
}
bool sls_fixed::is_fixed1_basic(app* e) const {
switch (e->get_decl_kind()) {
case OP_TRUE:
case OP_FALSE:
return true;
case OP_AND:
return any_of(*e, [&](expr* arg) { return ev.is_fixed0(arg) && !ev.bval0(e); });
case OP_OR:
return any_of(*e, [&](expr* arg) { return ev.is_fixed0(arg) && ev.bval0(e); });
default:
return all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); });
}
}
void sls_fixed::set_fixed_bw(app* e) {
SASSERT(bv.is_bv(e));
SASSERT(e->get_family_id() == bv.get_fid());
auto& v = ev.wval(e);
if (all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); })) {
for (unsigned i = 0; i < v.bw; ++i)
v.fixed.set(i, true);
void bv_fixed::set_fixed(expr* _e) {
if (!is_app(_e))
return;
auto e = to_app(_e);
if (e->get_family_id() == bv.get_family_id() && all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); })) {
if (bv.is_bv(e)) {
auto& v = ev.wval(e);
for (unsigned i = 0; i < v.bw; ++i)
v.fixed.set(i, true);
}
ev.m_fixed.setx(e->get_id(), true, false);
return;
}
if (!bv.is_bv(e))
return;
auto& v = ev.wval(e);
if (m.is_ite(e)) {
auto& val_th = ev.wval(e->get_arg(1));
auto& val_el = ev.wval(e->get_arg(2));
for (unsigned i = 0; i < v.nw; ++i)
v.fixed[i] = val_el.fixed[i] & val_th.fixed[i] & ~(val_el.bits(i) ^ val_th.bits(i));
return;
}
if (e->get_family_id() != bv.get_fid())
return;
switch (e->get_decl_kind()) {
case OP_BAND: {
auto& a = wval(e->get_arg(0));
auto& b = wval(e->get_arg(1));
// (a.fixed & b.fixed) | (a.fixed & ~a.bits) | (b.fixed & ~b.bits)
for (unsigned i = 0; i < a.nw; ++i)
v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & ~a.bits(i)) | (b.fixed[i] & ~b.bits(i));
if (e->get_num_args() == 2) {
auto& a = ev.wval(e->get_arg(0));
auto& b = ev.wval(e->get_arg(1));
// (a.fixed & b.fixed) | (a.fixed & ~a.bits) | (b.fixed & ~b.bits)
for (unsigned i = 0; i < a.nw; ++i)
v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & ~a.bits(i)) | (b.fixed[i] & ~b.bits(i));
}
break;
}
case OP_BOR: {
auto& a = wval(e->get_arg(0));
auto& b = wval(e->get_arg(1));
// (a.fixed & b.fixed) | (a.fixed & a.bits) | (b.fixed & b.bits)
for (unsigned i = 0; i < a.nw; ++i)
v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & a.bits(i)) | (b.fixed[i] & b.bits(i));
if (e->get_num_args() == 2) {
auto& a = ev.wval(e->get_arg(0));
auto& b = ev.wval(e->get_arg(1));
// (a.fixed & b.fixed) | (a.fixed & a.bits) | (b.fixed & b.bits)
for (unsigned i = 0; i < a.nw; ++i)
v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & a.bits(i)) | (b.fixed[i] & b.bits(i));
}
break;
}
case OP_BXOR: {
auto& a = wval(e->get_arg(0));
auto& b = wval(e->get_arg(1));
for (unsigned i = 0; i < a.nw; ++i)
v.fixed[i] = a.fixed[i] & b.fixed[i];
if (e->get_num_args() == 2) {
auto& a = ev.wval(e->get_arg(0));
auto& b = ev.wval(e->get_arg(1));
for (unsigned i = 0; i < a.nw; ++i)
v.fixed[i] = a.fixed[i] & b.fixed[i];
}
break;
}
case OP_BNOT: {
auto& a = wval(e->get_arg(0));
auto& a = ev.wval(e->get_arg(0));
for (unsigned i = 0; i < a.nw; ++i)
v.fixed[i] = a.fixed[i];
break;
}
case OP_BADD: {
auto& a = wval(e->get_arg(0));
auto& b = wval(e->get_arg(1));
bool pfixed = true;
for (unsigned i = 0; i < v.bw; ++i) {
if (pfixed && a.fixed.get(i) && b.fixed.get(i))
v.fixed.set(i, true);
else if (!pfixed && a.fixed.get(i) && b.fixed.get(i) &&
!a.get_bit(i) && !b.get_bit(i)) {
pfixed = true;
v.fixed.set(i, false);
}
else {
pfixed = false;
v.fixed.set(i, false);
for (unsigned j = 0; pfixed && j < e->get_num_args(); ++j) {
auto& a = ev.wval(e->get_arg(j));
pfixed &= a.fixed.get(i);
}
v.fixed.set(i, pfixed);
}
break;
}
case OP_BMUL: {
auto& a = wval(e->get_arg(0));
auto& b = wval(e->get_arg(1));
unsigned j = 0, k = 0, zj = 0, zk = 0, hzj = 0, hzk = 0;
// i'th bit depends on bits j + k = i
// if the first j, resp k bits are 0, the bits j + k are 0
for (; j < v.bw; ++j)
if (!a.fixed.get(j))
break;
for (; k < v.bw; ++k)
if (!b.fixed.get(k))
break;
for (; zj < v.bw; ++zj)
if (!a.fixed.get(zj) || a.get_bit(zj))
break;
for (; zk < v.bw; ++zk)
if (!b.fixed.get(zk) || b.get_bit(zk))
break;
for (; hzj < v.bw; ++hzj)
if (!a.fixed.get(v.bw - hzj - 1) || a.get_bit(v.bw - hzj - 1))
break;
for (; hzk < v.bw; ++hzk)
if (!b.fixed.get(v.bw - hzk - 1) || b.get_bit(v.bw - hzk - 1))
break;
if (e->get_num_args() == 2) {
SASSERT(e->get_num_args() == 2);
auto& a = ev.wval(e->get_arg(0));
auto& b = ev.wval(e->get_arg(1));
unsigned j = 0, k = 0, zj = 0, zk = 0, hzj = 0, hzk = 0;
// i'th bit depends on bits j + k = i
// if the first j, resp k bits are 0, the bits j + k are 0
for (; j < v.bw; ++j)
if (!a.fixed.get(j))
break;
for (; k < v.bw; ++k)
if (!b.fixed.get(k))
break;
for (; zj < v.bw; ++zj)
if (!a.fixed.get(zj) || a.get_bit(zj))
break;
for (; zk < v.bw; ++zk)
if (!b.fixed.get(zk) || b.get_bit(zk))
break;
for (; hzj < v.bw; ++hzj)
if (!a.fixed.get(v.bw - hzj - 1) || a.get_bit(v.bw - hzj - 1))
break;
for (; hzk < v.bw; ++hzk)
if (!b.fixed.get(v.bw - hzk - 1) || b.get_bit(v.bw - hzk - 1))
break;
if (j > 0 && k > 0) {
for (unsigned i = 0; i < std::min(k, j); ++i) {
SASSERT(!v.get_bit(i));
v.fixed.set(i, true);
if (j > 0 && k > 0) {
for (unsigned i = 0; i < std::min(k, j); ++i) {
SASSERT(!v.get_bit(i));
v.fixed.set(i, true);
}
}
// lower zj + jk bits are 0
if (zk > 0 || zj > 0) {
for (unsigned i = 0; i < zk + zj; ++i) {
SASSERT(!v.get_bit(i));
v.fixed.set(i, true);
}
}
// upper bits are 0, if enough high order bits of a, b are 0.
// TODO - buggy
if (false && hzj < v.bw && hzk < v.bw && hzj + hzk > v.bw) {
hzj = v.bw - hzj;
hzk = v.bw - hzk;
for (unsigned i = hzj + hzk - 1; i < v.bw; ++i) {
SASSERT(!v.get_bit(i));
v.fixed.set(i, true);
}
}
}
// lower zj + jk bits are 0
if (zk > 0 || zj > 0) {
for (unsigned i = 0; i < zk + zj; ++i) {
SASSERT(!v.get_bit(i));
v.fixed.set(i, true);
else {
bool pfixed = true;
for (unsigned i = 0; i < v.bw; ++i) {
for (unsigned j = 0; pfixed && j < e->get_num_args(); ++j) {
auto& a = ev.wval(e->get_arg(j));
pfixed &= a.fixed.get(i);
}
v.fixed.set(i, pfixed);
}
}
// upper bits are 0, if enough high order bits of a, b are 0.
// TODO - buggy
if (false && hzj < v.bw && hzk < v.bw && hzj + hzk > v.bw) {
hzj = v.bw - hzj;
hzk = v.bw - hzk;
for (unsigned i = hzj + hzk - 1; i < v.bw; ++i) {
SASSERT(!v.get_bit(i));
v.fixed.set(i, true);
}
}
break;
}
case OP_CONCAT: {
auto& a = wval(e->get_arg(0));
auto& b = wval(e->get_arg(1));
for (unsigned i = 0; i < b.bw; ++i)
v.fixed.set(i, b.fixed.get(i));
for (unsigned i = 0; i < a.bw; ++i)
v.fixed.set(i + b.bw, a.fixed.get(i));
unsigned bw = 0;
for (unsigned i = e->get_num_args(); i-- > 0; ) {
auto& a = ev.wval(e->get_arg(i));
for (unsigned j = 0; j < a.bw; ++j)
v.fixed.set(bw + j, a.fixed.get(j));
bw += a.bw;
}
break;
}
case OP_EXTRACT: {
expr* child;
unsigned lo, hi;
VERIFY(bv.is_extract(e, lo, hi, child));
auto& a = wval(child);
auto& a = ev.wval(child);
for (unsigned i = lo; i <= hi; ++i)
v.fixed.set(i - lo, a.fixed.get(i));
break;
}
case OP_BNEG: {
auto& a = wval(e->get_arg(0));
auto& a = ev.wval(e->get_arg(0));
bool pfixed = true;
for (unsigned i = 0; i < v.bw; ++i) {
if (pfixed && a.fixed.get(i))

View file

@ -17,19 +17,23 @@ Author:
#pragma once
#include "ast/ast.h"
#include "ast/sls/sls_valuation.h"
#include "ast/sls/sls_bv_valuation.h"
#include "ast/sls/sls_context.h"
#include "ast/bv_decl_plugin.h"
namespace bv {
class sls_eval;
namespace sls {
class bv_terms;
class bv_eval;
class sls_fixed {
sls_eval& ev;
ast_manager& m;
bv_util& bv;
class bv_fixed {
bv_eval& ev;
bv_terms& terms;
ast_manager& m;
bv_util& bv;
sls::context& ctx;
void init_ranges(expr_ref_vector const& es);
bool init_range(app* e, bool sign);
void propagate_range_up(expr* e);
bool init_range(expr* x, rational const& a, expr* y, rational const& b, bool sign);
@ -37,19 +41,11 @@ namespace bv {
bool init_eq(expr* e, rational const& v, bool sign);
bool add_range(expr* e, rational lo, rational hi, bool sign);
void init_fixed_basic(app* e);
void init_fixed_bv(app* e);
bool is_fixed1(app* e) const;
bool is_fixed1_basic(app* e) const;
void set_fixed_bw(app* e);
sls_valuation& wval(expr* e);
void set_fixed(expr* e);
public:
sls_fixed(sls_eval& ev);
void init(expr_ref_vector const& es);
bv_fixed(bv_eval& ev, bv_terms& terms, sls::context& ctx);
void init();
};
}

View file

@ -0,0 +1,206 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_bv_plugin.cpp
Abstract:
Theory plugin for bit-vector local search
Author:
Nikolaj Bjorner (nbjorner) 2024-07-06
--*/
#include "ast/sls/sls_bv_plugin.h"
#include "ast/ast_ll_pp.h"
#include "ast/ast_pp.h"
namespace sls {
bv_plugin::bv_plugin(context& ctx):
plugin(ctx),
bv(m),
m_terms(ctx),
m_eval(m_terms, ctx) {
m_fid = bv.get_family_id();
}
void bv_plugin::register_term(expr* e) {
m_terms.register_term(e);
m_eval.register_term(e);
}
expr_ref bv_plugin::get_value(expr* e) {
SASSERT(bv.is_bv(e));
auto const & val = m_eval.wval(e);
return expr_ref(bv.mk_numeral(val.get_value(), e->get_sort()), m);
}
bool bv_plugin::is_bv_predicate(expr* e) {
if (!e || !is_app(e))
return false;
auto a = to_app(e);
if (a->get_family_id() == bv.get_family_id())
return true;
if (m.is_eq(e) && bv.is_bv(a->get_arg(0)))
return true;
return false;
}
void bv_plugin::propagate_literal(sat::literal lit) {
SASSERT(ctx.is_true(lit));
auto e = ctx.atom(lit.var());
if (!is_bv_predicate(e))
return;
auto a = to_app(e);
if (!m_eval.eval_is_correct(a)) {
IF_VERBOSE(20, verbose_stream() << "repair " << lit << " " << mk_bounded_pp(e, m) << "\n");
ctx.new_value_eh(e);
}
}
bool bv_plugin::propagate() {
auto& axioms = m_terms.axioms();
if (!axioms.empty()) {
for (auto* e : axioms)
ctx.add_constraint(e);
axioms.reset();
return true;
}
return false;
}
void bv_plugin::initialize() {
if (!m_initialized) {
m_eval.init();
m_initialized = true;
}
}
void bv_plugin::init_bool_var_assignment(sat::bool_var v) {
auto a = ctx.atom(v);
if (!a || !is_app(a))
return;
if (to_app(a)->get_family_id() != bv.get_family_id())
return;
bool is_true = m_eval.bval1(to_app(a));
if (is_true != ctx.is_true(v))
ctx.flip(v);
}
bool bv_plugin::is_sat() {
bool is_sat = true;
for (auto t : ctx.subterms())
if (is_app(t) && bv.is_bv(t) && to_app(t)->get_family_id() == bv.get_fid() && !m_eval.eval_is_correct(to_app(t))) {
ctx.new_value_eh(t);
is_sat = false;
}
return is_sat;
}
std::ostream& bv_plugin::display(std::ostream& out) const {
return m_eval.display(out);
}
bool bv_plugin::set_value(expr* e, expr* v) {
if (!bv.is_bv(e))
return false;
rational val;
VERIFY(bv.is_numeral(v, val));
auto& w = m_eval.eval(to_app(e));
w.set_value(w.eval, val);
return w.commit_eval();
}
bool bv_plugin::repair_down(app* e) {
unsigned n = e->get_num_args();
bool status = true;
if (n == 0 || m_eval.is_uninterpreted(e) || m_eval.eval_is_correct(e))
goto done;
if (n == 2) {
auto d1 = get_depth(e->get_arg(0));
auto d2 = get_depth(e->get_arg(1));
unsigned s = ctx.rand(d1 + d2 + 2);
if (s <= d1 && m_eval.repair_down(e, 0))
goto done;
if (m_eval.repair_down(e, 1))
goto done;
if (m_eval.repair_down(e, 0))
goto done;
}
else {
unsigned s = ctx.rand(n);
for (unsigned i = 0; i < n; ++i) {
auto j = (i + s) % n;
if (m_eval.repair_down(e, j))
goto done;
}
}
status = false;
done:
log(e, false, status);
return status;
}
void bv_plugin::repair_up(app* e) {
if (m_eval.repair_up(e)) {
if (!m_eval.eval_is_correct(e)) {
verbose_stream() << "Incorrect eval #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n";
}
log(e, true, true);
SASSERT(m_eval.eval_is_correct(e));
if (m.is_bool(e)) {
if (ctx.is_true(e) != m_eval.bval1(e))
ctx.flip(ctx.atom2bool_var(e));
}
}
else if (bv.is_bv(e)) {
log(e, true, false);
IF_VERBOSE(5, verbose_stream() << "repair-up "; trace_repair(true, e));
auto& v = m_eval.wval(e);
m_eval.set_random(e);
ctx.new_value_eh(e);
}
else
log(e, true, false);
}
void bv_plugin::repair_literal(sat::literal lit) {
SASSERT(ctx.is_true(lit));
auto e = ctx.atom(lit.var());
if (!is_bv_predicate(e))
return;
auto a = to_app(e);
if (!m_eval.eval_is_correct(a))
ctx.flip(lit.var());
}
std::ostream& bv_plugin::trace_repair(bool down, expr* e) {
verbose_stream() << (down ? "d #" : "u #")
<< e->get_id() << ": "
<< mk_bounded_pp(e, m, 1) << " ";
return m_eval.display_value(verbose_stream(), e) << "\n";
}
void bv_plugin::trace() {
IF_VERBOSE(2, verbose_stream()
<< "(bvsls :restarts " << m_stats.m_restarts << ")\n");
}
void bv_plugin::log(expr* e, bool up_down, bool success) {
IF_VERBOSE(11, verbose_stream() << mk_bounded_pp(e, m) << " " << (up_down?"u":"d") << " " << (success ? "S" : "F");
if (bv.is_bv(e)) verbose_stream() << " " << m_eval.wval(e);
verbose_stream() << "\n");
}
}

View file

@ -0,0 +1,62 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
sls_bv_plugin.h
Abstract:
Theory plugin for bit-vector local search
Author:
Nikolaj Bjorner (nbjorner) 2024-07-06
--*/
#pragma once
#include "ast/sls/sls_context.h"
#include "ast/bv_decl_plugin.h"
#include "ast/sls/sls_bv_terms.h"
#include "ast/sls/sls_bv_eval.h"
namespace sls {
class bv_plugin : public plugin {
bv_util bv;
bv_terms m_terms;
bv_eval m_eval;
bv::sls_stats m_stats;
bool m_initialized = false;
void init_bool_var_assignment(sat::bool_var v);
std::ostream& trace_repair(bool down, expr* e);
void trace();
bool can_propagate();
bool is_bv_predicate(expr* e);
void log(expr* e, bool up_down, bool success);
public:
bv_plugin(context& ctx);
~bv_plugin() override {}
void register_term(expr* e) override;
expr_ref get_value(expr* e) override;
void initialize() override;
void propagate_literal(sat::literal lit) override;
bool propagate() override;
bool repair_down(app* e) override;
void repair_up(app* e) override;
void repair_literal(sat::literal lit) override;
bool is_sat() override;
void on_rescale() override {}
void on_restart() override {}
std::ostream& display(std::ostream& out) const override;
bool set_value(expr* e, expr* v) override;
void collect_statistics(statistics& st) const override {}
void reset_statistics() override {}
};
}

View file

@ -0,0 +1,143 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
bv_sls_terms.cpp
Abstract:
normalize bit-vector expressions to use only binary operators.
Author:
Nikolaj Bjorner (nbjorner) 2024-02-07
--*/
#include "ast/ast_ll_pp.h"
#include "ast/sls/sls_bv_terms.h"
#include "ast/rewriter/bool_rewriter.h"
#include "ast/rewriter/bv_rewriter.h"
namespace sls {
bv_terms::bv_terms(sls::context& ctx):
m(ctx.get_manager()),
bv(m),
m_axioms(m) {}
void bv_terms::register_term(expr* e) {
auto r = ensure_binary(e);
if (r != e)
m_axioms.push_back(m.mk_eq(e, r));
register_uninterp(e);
}
expr_ref bv_terms::ensure_binary(expr* e) {
expr* x, * y;
expr_ref r(m);
if (bv.is_bv_sdiv(e, x, y) || bv.is_bv_sdiv0(e, x, y) || bv.is_bv_sdivi(e, x, y))
r = mk_sdiv(x, y);
else if (bv.is_bv_smod(e, x, y) || bv.is_bv_smod0(e, x, y) || bv.is_bv_smodi(e, x, y))
r = mk_smod(x, y);
else if (bv.is_bv_srem(e, x, y) || bv.is_bv_srem0(e, x, y) || bv.is_bv_sremi(e, x, y))
r = mk_srem(x, y);
else
r = e;
return r;
}
expr_ref bv_terms::mk_sdiv(expr* x, expr* y) {
// d = udiv(abs(x), abs(y))
// y = 0, x >= 0 -> -1
// y = 0, x < 0 -> 1
// x = 0, y != 0 -> 0
// x > 0, y < 0 -> -d
// x < 0, y > 0 -> -d
// x > 0, y > 0 -> d
// x < 0, y < 0 -> d
bool_rewriter br(m);
bv_rewriter bvr(m);
unsigned sz = bv.get_bv_size(x);
rational N = rational::power_of_two(sz);
expr_ref z(bv.mk_zero(sz), m);
expr_ref o(bv.mk_one(sz), m);
expr_ref n1(bv.mk_numeral(N - 1, sz), m);
expr_ref signx = bvr.mk_ule(bv.mk_numeral(N / 2, sz), x);
expr_ref signy = bvr.mk_ule(bv.mk_numeral(N / 2, sz), y);
expr_ref absx = br.mk_ite(signx, bvr.mk_bv_neg(x), x);
expr_ref absy = br.mk_ite(signy, bvr.mk_bv_neg(y), y);
expr_ref d = expr_ref(bv.mk_bv_udiv(absx, absy), m);
expr_ref r = br.mk_ite(br.mk_eq(signx, signy), d, bvr.mk_bv_neg(d));
r = br.mk_ite(br.mk_eq(z, y),
br.mk_ite(signx, o, n1),
br.mk_ite(br.mk_eq(x, z), z, r));
return r;
}
expr_ref bv_terms::mk_smod(expr* x, expr* y) {
// u := umod(abs(x), abs(y))
// u = 0 -> 0
// y = 0 -> x
// x < 0, y < 0 -> -u
// x < 0, y >= 0 -> y - u
// x >= 0, y < 0 -> y + u
// x >= 0, y >= 0 -> u
bool_rewriter br(m);
bv_rewriter bvr(m);
unsigned sz = bv.get_bv_size(x);
expr_ref z(bv.mk_zero(sz), m);
expr_ref abs_x = br.mk_ite(bvr.mk_sle(z, x), x, bvr.mk_bv_neg(x));
expr_ref abs_y = br.mk_ite(bvr.mk_sle(z, y), y, bvr.mk_bv_neg(y));
expr_ref u = bvr.mk_bv_urem(abs_x, abs_y);
expr_ref r(m);
r = br.mk_ite(br.mk_eq(u, z), z,
br.mk_ite(br.mk_eq(y, z), x,
br.mk_ite(br.mk_and(bvr.mk_sle(z, x), bvr.mk_sle(z, x)), u,
br.mk_ite(bvr.mk_sle(z, x), bvr.mk_bv_add(y, u),
br.mk_ite(bv.mk_sle(z, y), bvr.mk_bv_sub(y, u), bvr.mk_bv_neg(u))))));
return r;
}
expr_ref bv_terms::mk_srem(expr* x, expr* y) {
// y = 0 -> x
// else x - sdiv(x, y) * y
expr_ref r(m);
bool_rewriter br(m);
bv_rewriter bvr(m);
expr_ref z(bv.mk_zero(bv.get_bv_size(x)), m);
r = br.mk_ite(br.mk_eq(y, z), x, bvr.mk_bv_sub(x, bvr.mk_bv_mul(y, mk_sdiv(x, y))));
return r;
}
void bv_terms::register_uninterp(expr* e) {
if (!m.is_bool(e))
return;
expr* x, *y;
if (m.is_eq(e, x, y) && bv.is_bv(x))
;
else if (is_app(e) && to_app(e)->get_family_id() == bv.get_fid())
;
else
return;
m_uninterp_occurs.reserve(e->get_id() + 1);
auto& occs = m_uninterp_occurs[e->get_id()];
ptr_vector<expr> todo;
todo.append(to_app(e)->get_num_args(), to_app(e)->get_args());
expr_mark marked;
for (unsigned i = 0; i < todo.size(); ++i) {
e = todo[i];
if (marked.is_marked(e))
continue;
marked.mark(e);
if (is_app(e) && to_app(e)->get_family_id() == bv.get_fid()) {
for (expr* arg : *to_app(e))
todo.push_back(arg);
}
else if (bv.is_bv(e))
occs.push_back(e);
}
}
}

View file

@ -0,0 +1,54 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
bv_sls_terms.h
Abstract:
A Stochastic Local Search (SLS) engine
Author:
Nikolaj Bjorner (nbjorner) 2024-02-07
--*/
#pragma once
#include "util/lbool.h"
#include "util/scoped_ptr_vector.h"
#include "util/uint_set.h"
#include "ast/ast.h"
#include "ast/bv_decl_plugin.h"
#include "ast/sls/sls_stats.h"
#include "ast/sls/sls_powers.h"
#include "ast/sls/sls_bv_valuation.h"
#include "ast/sls/sls_context.h"
namespace sls {
class bv_terms {
ast_manager& m;
bv_util bv;
expr_ref_vector m_axioms;
vector<ptr_vector<expr>> m_uninterp_occurs;
expr_ref ensure_binary(expr* e);
expr_ref mk_sdiv(expr* x, expr* y);
expr_ref mk_smod(expr* x, expr* y);
expr_ref mk_srem(expr* x, expr* y);
void register_uninterp(expr* e);
public:
bv_terms(sls::context& ctx);
void register_term(expr* e);
expr_ref_vector& axioms() { return m_axioms; }
ptr_vector<expr> const& uninterp_occurs(expr* e) { m_uninterp_occurs.reserve(e->get_id() + 1); return m_uninterp_occurs[e->get_id()]; }
};
}

View file

@ -18,9 +18,9 @@ Author:
--*/
#include "ast/sls/sls_valuation.h"
#include "ast/sls/sls_bv_valuation.h"
namespace bv {
namespace sls {
void bvect::set_bw(unsigned bw) {
this->bw = bw;
@ -138,6 +138,7 @@ namespace bv {
set_bw(a.bw);
SASSERT(a.bw == b.bw);
unsigned shift = b.to_nat(b.bw);
if (shift == 0)
a.copy_to(a.nw, *this);
else if (shift >= a.bw)
@ -148,7 +149,7 @@ namespace bv {
return *this;
}
sls_valuation::sls_valuation(unsigned bw) {
bv_valuation::bv_valuation(unsigned bw) {
set_bw(bw);
m_lo.set_bw(bw);
m_hi.set_bw(bw);
@ -162,7 +163,7 @@ namespace bv {
fixed[nw - 1] = ~mask;
}
void sls_valuation::set_bw(unsigned b) {
void bv_valuation::set_bw(unsigned b) {
bw = b;
nw = (bw + sizeof(digit_t) * 8 - 1) / (8 * sizeof(digit_t));
mask = (1 << (bw % (8 * sizeof(digit_t)))) - 1;
@ -170,7 +171,7 @@ namespace bv {
mask = ~(digit_t)0;
}
bool sls_valuation::commit_eval() {
bool bv_valuation::commit_eval() {
for (unsigned i = 0; i < nw; ++i)
if (0 != (fixed[i] & (m_bits[i] ^ eval[i])))
return false;
@ -180,11 +181,12 @@ namespace bv {
for (unsigned i = 0; i < nw; ++i)
m_bits[i] = eval[i];
SASSERT(well_formed());
return true;
}
bool sls_valuation::in_range(bvect const& bits) const {
bool bv_valuation::in_range(bvect const& bits) const {
mpn_manager m;
auto c = m.compare(m_lo.data(), nw, m_hi.data(), nw);
SASSERT(!has_overflow(bits));
@ -207,7 +209,7 @@ namespace bv {
// largest dst <= src and dst is feasible
//
bool sls_valuation::get_at_most(bvect const& src, bvect& dst) const {
bool bv_valuation::get_at_most(bvect const& src, bvect& dst) const {
SASSERT(!has_overflow(src));
src.copy_to(nw, dst);
sup_feasible(dst);
@ -227,7 +229,7 @@ namespace bv {
//
// smallest dst >= src and dst is feasible with respect to this.
bool sls_valuation::get_at_least(bvect const& src, bvect& dst) const {
bool bv_valuation::get_at_least(bvect const& src, bvect& dst) const {
SASSERT(!has_overflow(src));
src.copy_to(nw, dst);
dst.set_bw(bw);
@ -244,34 +246,38 @@ namespace bv {
return true;
}
bool sls_valuation::set_random_at_most(bvect const& src, random_gen& r) {
bool bv_valuation::set_random_at_most(bvect const& src, random_gen& r) {
m_tmp.set_bw(bw);
//verbose_stream() << "set_random_at_most " << src << "\n";
if (!get_at_most(src, m_tmp))
return false;
if (is_zero(m_tmp) || (0 != r(10)))
return try_set(m_tmp);
if (is_zero(m_tmp) && (0 != r(2)))
return try_set(m_tmp) && m_tmp <= src;
// random value below tmp
set_random_below(m_tmp, r);
//verbose_stream() << "can set " << m_tmp << " " << can_set(m_tmp) << "\n";
return (can_set(m_tmp) || get_at_most(src, m_tmp)) && try_set(m_tmp);
return (can_set(m_tmp) || get_at_most(src, m_tmp)) && m_tmp <= src && try_set(m_tmp);
}
bool sls_valuation::set_random_at_least(bvect const& src, random_gen& r) {
bool bv_valuation::set_random_at_least(bvect const& src, random_gen& r) {
m_tmp.set_bw(bw);
if (!get_at_least(src, m_tmp))
return false;
if (is_ones(m_tmp) || (0 != r(10)))
if (is_ones(m_tmp) && (0 != r(10)))
return try_set(m_tmp);
// random value at least tmp
set_random_above(m_tmp, r);
return (can_set(m_tmp) || get_at_least(src, m_tmp)) && try_set(m_tmp);
return (can_set(m_tmp) || get_at_least(src, m_tmp)) && src <= m_tmp && try_set(m_tmp);
}
bool sls_valuation::set_random_in_range(bvect const& lo, bvect const& hi, random_gen& r) {
bool bv_valuation::set_random_in_range(bvect const& lo, bvect const& hi, random_gen& r) {
bvect& tmp = m_tmp;
if (0 == r(2)) {
if (!get_at_least(lo, tmp))
@ -279,14 +285,10 @@ namespace bv {
SASSERT(in_range(tmp));
if (hi < tmp)
return false;
if (is_ones(tmp) || (0 == r() % 2))
return try_set(tmp);
set_random_above(tmp, r);
round_down(tmp, [&](bvect const& t) { return hi >= t && in_range(t); });
if (in_range(tmp) && lo <= tmp && hi >= tmp)
return try_set(tmp);
return get_at_least(lo, tmp) && hi >= tmp && try_set(tmp);
if (in_range(tmp) || get_at_least(lo, tmp))
return lo <= tmp && tmp <= hi && try_set(tmp);
}
else {
if (!get_at_most(hi, tmp))
@ -294,37 +296,35 @@ namespace bv {
SASSERT(in_range(tmp));
if (lo > tmp)
return false;
if (is_zero(tmp) || (0 == r() % 2))
return try_set(tmp);
set_random_below(tmp, r);
round_up(tmp, [&](bvect const& t) { return lo <= t && in_range(t); });
if (in_range(tmp) && lo <= tmp && hi >= tmp)
return try_set(tmp);
return get_at_most(hi, tmp) && lo <= tmp && try_set(tmp);
if (in_range(tmp) || get_at_most(hi, tmp))
return lo <= tmp && tmp <= hi && try_set(tmp);
}
return false;
}
void sls_valuation::round_down(bvect& dst, std::function<bool(bvect const&)> const& is_feasible) {
void bv_valuation::round_down(bvect& dst, std::function<bool(bvect const&)> const& is_feasible) {
for (unsigned i = bw; !is_feasible(dst) && i-- > 0; )
if (!fixed.get(i) && dst.get(i))
dst.set(i, false);
repair_sign_bits(dst);
}
void sls_valuation::round_up(bvect& dst, std::function<bool(bvect const&)> const& is_feasible) {
void bv_valuation::round_up(bvect& dst, std::function<bool(bvect const&)> const& is_feasible) {
for (unsigned i = 0; !is_feasible(dst) && i < bw; ++i)
if (!fixed.get(i) && !dst.get(i))
dst.set(i, true);
repair_sign_bits(dst);
}
void sls_valuation::set_random_above(bvect& dst, random_gen& r) {
void bv_valuation::set_random_above(bvect& dst, random_gen& r) {
for (unsigned i = 0; i < nw; ++i)
dst[i] = dst[i] | (random_bits(r) & ~fixed[i]);
repair_sign_bits(dst);
}
void sls_valuation::set_random_below(bvect& dst, random_gen& r) {
void bv_valuation::set_random_below(bvect& dst, random_gen& r) {
if (is_zero(dst))
return;
unsigned n = 0, idx = UINT_MAX;
@ -341,7 +341,7 @@ namespace bv {
repair_sign_bits(dst);
}
bool sls_valuation::set_repair(bool try_down, bvect& dst) {
bool bv_valuation::set_repair(bool try_down, bvect& dst) {
for (unsigned i = 0; i < nw; ++i)
dst[i] = (~fixed[i] & dst[i]) | (fixed[i] & m_bits[i]);
clear_overflow_bits(dst);
@ -358,7 +358,7 @@ namespace bv {
dst.set(i, false);
for (unsigned i = 0; i < bw && dst < m_lo && !in_range(dst); ++i)
if (!fixed.get(i) && !dst.get(i))
dst.set(i, true);
dst.set(i, true);
}
else {
for (unsigned i = 0; !in_range(dst) && i < bw; ++i)
@ -377,7 +377,7 @@ namespace bv {
return repaired;
}
void sls_valuation::min_feasible(bvect& out) const {
void bv_valuation::min_feasible(bvect& out) const {
if (m_lo < m_hi)
m_lo.copy_to(nw, out);
else {
@ -388,7 +388,7 @@ namespace bv {
SASSERT(!has_overflow(out));
}
void sls_valuation::max_feasible(bvect& out) const {
void bv_valuation::max_feasible(bvect& out) const {
if (m_lo < m_hi) {
m_hi.copy_to(nw, out);
sub1(out);
@ -401,7 +401,7 @@ namespace bv {
SASSERT(!has_overflow(out));
}
unsigned sls_valuation::msb(bvect const& src) const {
unsigned bv_valuation::msb(bvect const& src) const {
SASSERT(!has_overflow(src));
for (unsigned i = nw; i-- > 0; )
if (src[i] != 0)
@ -409,7 +409,7 @@ namespace bv {
return bw;
}
unsigned sls_valuation::clz(bvect const& src) const {
unsigned bv_valuation::clz(bvect const& src) const {
SASSERT(!has_overflow(src));
unsigned i = bw;
for (; i-- > 0; )
@ -419,36 +419,64 @@ namespace bv {
}
void sls_valuation::set_value(bvect& bits, rational const& n) {
void bv_valuation::set_value(bvect& bits, rational const& n) {
for (unsigned i = 0; i < bw; ++i)
bits.set(i, n.get_bit(i));
clear_overflow_bits(bits);
}
void sls_valuation::get(bvect& dst) const {
void bv_valuation::get(bvect& dst) const {
m_bits.copy_to(nw, dst);
}
digit_t sls_valuation::random_bits(random_gen& rand) {
digit_t bv_valuation::random_bits(random_gen& rand) {
digit_t r = 0;
for (digit_t i = 0; i < sizeof(digit_t); ++i)
r ^= rand() << (8 * i);
return r;
}
void sls_valuation::get_variant(bvect& dst, random_gen& r) const {
void bv_valuation::get_variant(bvect& dst, random_gen& r) const {
for (unsigned i = 0; i < nw; ++i)
dst[i] = (random_bits(r) & ~fixed[i]) | (fixed[i] & m_bits[i]);
repair_sign_bits(dst);
clear_overflow_bits(dst);
}
bool sls_valuation::set_random(random_gen& r) {
bool bv_valuation::set_random(random_gen& r) {
get_variant(m_tmp, r);
return set_repair(r(2) == 0, m_tmp);
repair_sign_bits(m_tmp);
if (in_range(m_tmp)) {
set(eval, m_tmp);
return true;
}
for (unsigned i = 0; i < nw; ++i)
m_tmp[i] = random_bits(r);
clear_overflow_bits(m_tmp);
// find a random offset within [lo, hi[
SASSERT(m_lo != m_hi);
set_sub(eval, m_hi, m_lo);
for (unsigned i = bw; i-- > 0 && m_tmp >= eval; )
m_tmp.set(i, false);
// set eval back to m_bits. It was garbage.
set(eval, m_bits);
// tmp := lo + tmp is within [lo, hi[
set_add(m_tmp, m_tmp, m_lo);
// respect fixed bits
for (unsigned i = 0; i < bw; ++i)
if (fixed.get(i))
m_tmp.set(i, m_bits.get(i));
// decrease tmp until it is in range again
for (unsigned i = bw; i-- > 0 && !in_range(m_tmp); )
if (!fixed.get(i))
m_tmp.set(i, false);
repair_sign_bits(m_tmp);
return try_set(m_tmp);
}
void sls_valuation::repair_sign_bits(bvect& dst) const {
void bv_valuation::repair_sign_bits(bvect& dst) const {
if (m_signed_prefix == 0)
return;
bool sign = m_signed_prefix == bw ? dst.get(bw - 1) : dst.get(bw - m_signed_prefix - 1);
@ -474,7 +502,7 @@ namespace bv {
// 0 = (new_bits ^ bits) & fixedf
// also check that new_bits are in range
//
bool sls_valuation::can_set(bvect const& new_bits) const {
bool bv_valuation::can_set(bvect const& new_bits) const {
SASSERT(!has_overflow(new_bits));
for (unsigned i = 0; i < nw; ++i)
if (0 != ((new_bits[i] ^ m_bits[i]) & fixed[i]))
@ -482,28 +510,28 @@ namespace bv {
return in_range(new_bits);
}
unsigned sls_valuation::to_nat(unsigned max_n) const {
unsigned bv_valuation::to_nat(unsigned max_n) const {
bvect const& d = m_bits;
SASSERT(!has_overflow(d));
return d.to_nat(max_n);
}
void sls_valuation::shift_right(bvect& out, unsigned shift) const {
void bv_valuation::shift_right(bvect& out, unsigned shift) const {
SASSERT(shift < bw);
for (unsigned i = 0; i < bw; ++i)
out.set(i, i + shift < bw ? m_bits.get(i + shift) : false);
out.set(i, i + shift < bw ? out.get(i + shift) : false);
SASSERT(well_formed());
}
void sls_valuation::add_range(rational l, rational h) {
void bv_valuation::add_range(rational l, rational h) {
l = mod(l, rational::power_of_two(bw));
h = mod(h, rational::power_of_two(bw));
if (h == l)
return;
// verbose_stream() << *this << " " << l << " " << h << " --> ";
//verbose_stream() << *this << " lo " << l << " hi " << h << " --> ";
if (m_lo == m_hi) {
set_value(m_lo, l);
@ -555,7 +583,7 @@ namespace bv {
// update bits based on ranges
//
unsigned sls_valuation::diff_index(bvect const& a) const {
unsigned bv_valuation::diff_index(bvect const& a) const {
unsigned index = 0;
for (unsigned i = nw; i-- > 0; ) {
auto diff = fixed[i] & (m_bits[i] ^ a[i]);
@ -565,55 +593,87 @@ namespace bv {
return index;
}
void sls_valuation::inf_feasible(bvect& a) const {
// The least a' >= a, such that the fixed bits in bits agree with a'.
// 0 if there is no such a'.
void bv_valuation::inf_feasible(bvect& a) const {
unsigned lo_index = diff_index(a);
if (lo_index != 0) {
lo_index--;
SASSERT(a.get(lo_index) != m_bits.get(lo_index));
SASSERT(fixed.get(lo_index));
for (unsigned i = 0; i <= lo_index; ++i) {
if (!fixed.get(i))
a.set(i, false);
else if (fixed.get(i))
a.set(i, m_bits.get(i));
}
if (!a.get(lo_index)) {
for (unsigned i = lo_index + 1; i < bw; ++i)
if (!fixed.get(i) && !a.get(i)) {
a.set(i, true);
break;
}
}
}
}
void sls_valuation::sup_feasible(bvect& a) const {
unsigned hi_index = diff_index(a);
if (hi_index != 0) {
hi_index--;
SASSERT(a.get(hi_index) != m_bits.get(hi_index));
SASSERT(fixed.get(hi_index));
for (unsigned i = 0; i <= hi_index; ++i) {
if (!fixed.get(i))
a.set(i, true);
else if (fixed.get(i))
a.set(i, m_bits.get(i));
}
if (a.get(hi_index)) {
for (unsigned i = hi_index + 1; i < bw; ++i)
if (!fixed.get(i) && a.get(i)) {
a.set(i, false);
break;
}
}
}
}
void sls_valuation::tighten_range() {
if (m_lo == m_hi)
if (lo_index == 0)
return;
--lo_index;
// decrement a[lo_index:0] maximally
SASSERT(a.get(lo_index) != m_bits.get(lo_index));
SASSERT(fixed.get(lo_index));
for (unsigned i = 0; i <= lo_index; ++i) {
if (!fixed.get(i))
a.set(i, false);
else if (fixed.get(i))
a.set(i, m_bits.get(i));
}
// the previous value of a[lo_index] was 0.
// a[lo_index:0] was incremented, so no need to adjust bits a[:lo_index+1]
if (a.get(lo_index))
return;
// find the minimal increment within a[:lo_index+1]
for (unsigned i = lo_index + 1; i < bw; ++i) {
if (!fixed.get(i) && !a.get(i)) {
a.set(i, true);
return;
}
}
// there is no feasiable value a' >= a, so find the least
// feasiable value a' >= 0.
for (unsigned i = 0; i < bw; ++i)
if (!fixed.get(i))
a.set(i, false);
}
// The greatest a' <= a, such that the fixed bits in bits agree with a'.
// the greatest a' <= -1 if there is no such a'.
void bv_valuation::sup_feasible(bvect& a) const {
unsigned hi_index = diff_index(a);
if (hi_index == 0)
return;
--hi_index;
SASSERT(a.get(hi_index) != m_bits.get(hi_index));
SASSERT(fixed.get(hi_index));
// increment a[hi_index:0] maximally
for (unsigned i = 0; i <= hi_index; ++i) {
if (!fixed.get(i))
a.set(i, true);
else if (fixed.get(i))
a.set(i, m_bits.get(i));
}
// If a[hi_index:0] was decremented, then no need to adjust bits a[:hi_index+1]
if (!a.get(hi_index))
return;
// find the minimal decrement within a[:hi_index+1]
for (unsigned i = hi_index + 1; i < bw; ++i) {
if (!fixed.get(i) && a.get(i)) {
a.set(i, false);
return;
}
}
// a[hi_index:0] was incremented, but a[:hi_index+1] cannot be decremented.
// maximize a[:hi_index+1] to model wrap around behavior.
for (unsigned i = hi_index + 1; i < bw; ++i)
if (!fixed.get(i))
a.set(i, true);
}
void bv_valuation::tighten_range() {
// verbose_stream() << "tighten " << m_lo << " " << m_hi << " " << m_bits << "\n";
if (m_lo == m_hi)
return;
inf_feasible(m_lo);
@ -625,59 +685,8 @@ namespace bv {
add1(hi1);
hi1.copy_to(nw, m_hi);
/*
unsigned lo_index = 0, hi_index = 0;
for (unsigned i = nw; i-- > 0; ) {
auto lo_diff = (fixed[i] & (m_bits[i] ^ m_lo[i]));
if (lo_diff != 0 && lo_index == 0)
lo_index = 1 + i * 8 * sizeof(digit_t) + log2(lo_diff);
auto hi_diff = (fixed[i] & (m_bits[i] ^ hi1[i]));
if (hi_diff != 0 && hi_index == 0)
hi_index = 1 + i * 8 * sizeof(digit_t) + log2(hi_diff);
}
if (lo_index != 0) {
lo_index--;
SASSERT(m_lo.get(lo_index) != m_bits.get(lo_index));
SASSERT(fixed.get(lo_index));
for (unsigned i = 0; i <= lo_index; ++i) {
if (!fixed.get(i))
m_lo.set(i, false);
else if (fixed.get(i))
m_lo.set(i, m_bits.get(i));
}
if (!m_bits.get(lo_index)) {
for (unsigned i = lo_index + 1; i < bw; ++i)
if (!fixed.get(i) && !m_lo.get(i)) {
m_lo.set(i, true);
break;
}
}
}
if (hi_index != 0) {
hi_index--;
SASSERT(hi1.get(hi_index) != m_bits.get(hi_index));
SASSERT(fixed.get(hi_index));
for (unsigned i = 0; i <= hi_index; ++i) {
if (!fixed.get(i))
hi1.set(i, true);
else if (fixed.get(i))
hi1.set(i, m_bits.get(i));
}
if (m_bits.get(hi_index)) {
for (unsigned i = hi_index + 1; i < bw; ++i)
if (!fixed.get(i) && hi1.get(i)) {
hi1.set(i, false);
break;
}
}
add1(hi1);
hi1.copy_to(nw, m_hi);
}
*/
if (has_range() && !in_range(m_bits))
m_bits = m_lo;
m_lo.copy_to(nw, m_bits);
if (mod(lo() + 1, rational::power_of_two(bw)) == hi())
for (unsigned i = 0; i < nw; ++i)
@ -687,16 +696,17 @@ namespace bv {
if (hi() < rational::power_of_two(i))
fixed.set(i, true);
// verbose_stream() << "post tighten " << m_lo << " " << m_hi << " " << m_bits << "\n";
SASSERT(well_formed());
}
void sls_valuation::set_sub(bvect& out, bvect const& a, bvect const& b) const {
void bv_valuation::set_sub(bvect& out, bvect const& a, bvect const& b) const {
digit_t c;
mpn_manager().sub(a.data(), nw, b.data(), nw, out.data(), &c);
clear_overflow_bits(out);
}
bool sls_valuation::set_add(bvect& out, bvect const& a, bvect const& b) const {
bool bv_valuation::set_add(bvect& out, bvect const& a, bvect const& b) const {
digit_t c;
mpn_manager().add(a.data(), nw, b.data(), nw, out.data(), nw + 1, &c);
bool ovfl = out[nw] != 0 || has_overflow(out);
@ -704,7 +714,9 @@ namespace bv {
return ovfl;
}
bool sls_valuation::set_mul(bvect& out, bvect const& a, bvect const& b, bool check_overflow) const {
bool bv_valuation::set_mul(bvect& out, bvect const& a, bvect const& b, bool check_overflow) const {
out.reserve(2 * nw);
SASSERT(out.size() >= 2 * nw);
mpn_manager().mul(a.data(), nw, b.data(), nw, out.data());
bool ovfl = false;
if (check_overflow) {
@ -716,7 +728,7 @@ namespace bv {
return ovfl;
}
bool sls_valuation::is_power_of2(bvect const& src) const {
bool bv_valuation::is_power_of2(bvect const& src) const {
unsigned c = 0;
for (unsigned i = 0; i < nw; ++i)
c += get_num_1bits(src[i]);

View file

@ -3,7 +3,7 @@ Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_valuation.h
sls_bv_valuation.h
Abstract:
@ -20,12 +20,10 @@ Author:
#include "util/params.h"
#include "util/scoped_ptr_vector.h"
#include "util/uint_set.h"
#include "ast/ast.h"
#include "ast/sls/sls_stats.h"
#include "ast/sls/sls_powers.h"
#include "ast/bv_decl_plugin.h"
#include "util/mpz.h"
#include "util/rational.h"
namespace bv {
namespace sls {
class bvect : public svector<digit_t> {
public:
@ -106,7 +104,7 @@ namespace bv {
inline bool operator!=(bvect const& a, bvect const& b) { return !(a == b); }
std::ostream& operator<<(std::ostream& out, bvect const& v);
class sls_valuation {
class bv_valuation {
protected:
bvect m_bits;
bvect m_lo, m_hi; // range assignment to bit-vector, as wrap-around interval
@ -124,8 +122,8 @@ namespace bv {
bvect fixed; // bit assignment and don't care bit
bvect eval; // current evaluation
sls_valuation(unsigned bw);
bv_valuation(unsigned bw);
void set_bw(unsigned bw);
void set_signed(unsigned prefix) { m_signed_prefix = prefix; }
@ -134,7 +132,9 @@ namespace bv {
digit_t bits(unsigned i) const { return m_bits[i]; }
bvect const& bits() const { return m_bits; }
bvect const& tmp_bits(bool use_current) const { return use_current ? m_bits : m_tmp; }
bool commit_eval();
bool is_fixed() const { for (unsigned i = bw; i-- > 0; ) if (!fixed.get(i)) return false; return true; }
bool get_bit(unsigned i) const { return m_bits.get(i); }
bool try_set_bit(unsigned i, bool b) {
@ -166,6 +166,9 @@ namespace bv {
bool has_range() const { return m_lo != m_hi; }
void tighten_range();
void save_value() { m_bits.copy_to(nw, m_tmp); }
void restore_value() { m_tmp.copy_to(nw, m_bits); }
void clear_overflow_bits(bvect& bits) const {
SASSERT(nw > 0);
bits[nw - 1] &= mask;
@ -175,7 +178,7 @@ namespace bv {
bool in_range(bvect const& bits) const;
bool can_set(bvect const& bits) const;
bool eq(sls_valuation const& other) const { return eq(other.m_bits); }
bool eq(bv_valuation const& other) const { return eq(other.m_bits); }
bool eq(bvect const& other) const { return other == m_bits; }
bool is_zero() const { return is_zero(m_bits); }
@ -342,6 +345,6 @@ namespace bv {
};
inline std::ostream& operator<<(std::ostream& out, sls_valuation const& v) { return v.display(out); }
inline std::ostream& operator<<(std::ostream& out, bv_valuation const& v) { return v.display(out); }
}

654
src/ast/sls/sls_context.cpp Normal file
View file

@ -0,0 +1,654 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
smt_sls.cpp
Abstract:
A Stochastic Local Search (SLS) Context.
Author:
Nikolaj Bjorner (nbjorner) 2024-06-24
--*/
#include "ast/sls/sls_context.h"
#include "ast/sls/sls_euf_plugin.h"
#include "ast/sls/sls_arith_plugin.h"
#include "ast/sls/sls_array_plugin.h"
#include "ast/sls/sls_bv_plugin.h"
#include "ast/sls/sls_basic_plugin.h"
#include "ast/sls/sls_datatype_plugin.h"
#include "ast/ast_ll_pp.h"
#include "ast/ast_pp.h"
#include "smt/params/smt_params_helper.hpp"
namespace sls {
plugin::plugin(context& c):
ctx(c),
m(c.get_manager()) {
}
context::context(ast_manager& m, sat_solver_context& s) :
m(m), s(s), m_atoms(m), m_allterms(m),
m_gd(*this),
m_ld(*this),
m_repair_down(m.get_num_asts(), m_gd),
m_repair_up(m.get_num_asts(), m_ld),
m_constraint_trail(m),
m_todo(m) {
}
void context::updt_params(params_ref const& p) {
smt_params_helper smtp(p);
m_rand.set_seed(smtp.random_seed());
m_params.append(p);
}
void context::register_plugin(plugin* p) {
m_plugins.reserve(p->fid() + 1);
m_plugins.set(p->fid(), p);
}
void context::ensure_plugin(family_id fid) {
if (m_plugins.get(fid, nullptr))
return;
else if (fid == arith_family_id)
register_plugin(alloc(arith_plugin, *this));
else if (fid == user_sort_family_id)
register_plugin(alloc(euf_plugin, *this));
else if (fid == basic_family_id)
register_plugin(alloc(basic_plugin, *this));
else if (fid == bv_util(m).get_family_id())
register_plugin(alloc(bv_plugin, *this));
else if (fid == array_util(m).get_family_id())
register_plugin(alloc(array_plugin, *this));
else if (fid == datatype_util(m).get_family_id())
register_plugin(alloc(datatype_plugin, *this));
else if (fid == null_family_id)
;
else
verbose_stream() << "did not find plugin for " << fid << "\n";
}
scoped_ptr<euf::egraph>& context::egraph() {
return euf().egraph();
}
euf_plugin& context::euf() {
auto fid = user_sort_family_id;
auto p = m_plugins.get(fid, nullptr);
if (!p) {
p = alloc(euf_plugin, *this);
register_plugin(p);
}
return *dynamic_cast<euf_plugin*>(p);
}
void context::ensure_plugin(expr* e) {
auto fid = get_fid(e);
ensure_plugin(fid);
fid = e->get_sort()->get_family_id();
ensure_plugin(fid);
}
void context::register_atom(sat::bool_var v, expr* e) {
m_atoms.setx(v, e);
m_atom2bool_var.setx(e->get_id(), v, sat::null_bool_var);
}
void context::on_restart() {
for (auto p : m_plugins)
if (p)
p->on_restart();
}
lbool context::check() {
//
// initialize data-structures if not done before.
// identify minimal feasible assignment to literals.
// sub-expressions within assignment are relevant.
// Use timestamps to make it incremental.
//
init();
while (unsat().empty() && m.inc()) {
propagate_boolean_assignment();
// verbose_stream() << "propagate " << unsat().size() << " " << m_new_constraint << "\n";
if (m_new_constraint || !unsat().empty())
return l_undef;
if (all_of(m_plugins, [&](auto* p) { return !p || p->is_sat(); })) {
values2model();
return l_true;
}
}
return l_undef;
}
void context::values2model() {
model_ref mdl = alloc(model, m);
expr_ref_vector args(m);
for (expr* e : subterms())
if (is_uninterp_const(e))
mdl->register_decl(to_app(e)->get_decl(), get_value(e));
for (expr* e : subterms()) {
if (!is_app(e))
continue;
auto f = to_app(e)->get_decl();
if (!include_func_interp(f))
continue;
auto v = get_value(e);
auto fi = mdl->get_func_interp(f);
if (!fi) {
fi = alloc(func_interp, m, f->get_arity());
mdl->register_decl(f, fi);
}
args.reset();
for (expr* arg : *to_app(e)) {
args.push_back(get_value(arg));
SASSERT(args.back());
}
SASSERT(f->get_arity() == args.size());
if (!fi->get_entry(args.data()))
fi->insert_new_entry(args.data(), v);
}
s.on_model(mdl);
// verbose_stream() << *mdl << "\n";
TRACE("sls", display(tout));
}
void context::propagate_boolean_assignment() {
reinit_relevant();
for (auto p : m_plugins)
if (p)
p->start_propagation();
for (sat::literal lit : root_literals())
propagate_literal(lit);
if (m_new_constraint)
return;
while (!m_new_constraint && m.inc() && (!m_repair_up.empty() || !m_repair_down.empty())) {
while (!m_repair_down.empty() && !m_new_constraint && m.inc()) {
auto id = m_repair_down.erase_min();
expr* e = term(id);
TRACE("sls", tout << "repair down " << mk_bounded_pp(e, m) << "\n");
if (is_app(e)) {
auto p = m_plugins.get(get_fid(e), nullptr);
++m_stats.m_num_repair_down;
if (p && !p->repair_down(to_app(e)) && !m_repair_up.contains(e->get_id())) {
IF_VERBOSE(3, verbose_stream() << "revert repair: " << mk_bounded_pp(e, m) << "\n");
m_repair_up.insert(e->get_id());
}
}
}
while (!m_repair_up.empty() && !m_new_constraint && m.inc()) {
auto id = m_repair_up.erase_min();
expr* e = term(id);
++m_stats.m_num_repair_up;
TRACE("sls", tout << "repair up " << mk_bounded_pp(e, m) << "\n");
if (is_app(e)) {
auto p = m_plugins.get(get_fid(e), nullptr);
if (p)
p->repair_up(to_app(e));
}
}
}
repair_literals();
// propagate "final checks"
bool propagated = true;
while (propagated && !m_new_constraint) {
propagated = false;
for (auto p : m_plugins)
propagated |= p && !m_new_constraint && p->propagate();
}
}
void context::repair_literals() {
for (sat::bool_var v = 0; v < s.num_vars() && !m_new_constraint; ++v) {
auto a = atom(v);
if (!a)
continue;
sat::literal lit(v, !is_true(v));
auto p = m_plugins.get(get_fid(a), nullptr);
if (p)
p->repair_literal(lit);
}
}
family_id context::get_fid(expr* e) const {
if (!is_app(e))
return user_sort_family_id;
family_id fid = to_app(e)->get_family_id();
if (m.is_eq(e))
fid = to_app(e)->get_arg(0)->get_sort()->get_family_id();
if (m.is_distinct(e))
fid = to_app(e)->get_arg(0)->get_sort()->get_family_id();
if ((fid == null_family_id && to_app(e)->get_num_args() > 0) || fid == model_value_family_id)
fid = user_sort_family_id;
return fid;
}
void context::propagate_literal(sat::literal lit) {
if (!is_true(lit))
return;
auto a = atom(lit.var());
if (!a)
return;
family_id fid = get_fid(a);
auto p = m_plugins.get(fid, nullptr);
if (p)
p->propagate_literal(lit);
if (!is_true(lit)) {
m_new_constraint = true;
}
}
bool context::is_true(expr* e) {
SASSERT(m.is_bool(e));
auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var);
if (v != sat::null_bool_var)
return m.is_true(m_plugins[basic_family_id]->get_value(e));
else
return is_true(v);
}
bool context::is_fixed(expr* e) {
// is this a Boolean literal that is a unit?
return false;
}
expr_ref context::get_value(expr* e) {
sort* s = e->get_sort();
auto fid = s->get_family_id();
auto p = m_plugins.get(fid, nullptr);
if (p)
return p->get_value(e);
verbose_stream() << fid << " " << m.get_family_name(fid) << " " << mk_pp(e, m) << "\n";
UNREACHABLE();
return expr_ref(e, m);
}
bool context::set_value(expr * e, expr * v) {
return any_of(m_plugins, [&](auto p) { return p && p->set_value(e, v); });
}
bool context::is_relevant(expr* e) {
unsigned id = e->get_id();
if (m_relevant.contains(id))
return true;
if (m_visited.contains(id))
return false;
m_visited.insert(id);
if (m_parents.size() <= id)
verbose_stream() << "not in map " << mk_bounded_pp(e, m) << "\n";
for (auto p : m_parents[id]) {
if (is_relevant(p)) {
m_relevant.insert(id);
return true;
}
}
return false;
}
void context::add_constraint(expr* e) {
if (m_constraint_ids.contains(e->get_id()))
return;
m_constraint_ids.insert(e->get_id());
m_constraint_trail.push_back(e);
add_clause(e);
m_new_constraint = true;
++m_stats.m_num_constraints;
}
void context::add_clause(expr* f) {
expr_ref _e(f, m);
expr* g, * h, * k;
sat::literal_vector clause;
if (m.is_true(f))
return;
if (m.is_not(f, g) && m.is_not(g, g)) {
add_clause(g);
return;
}
bool sign = m.is_not(f, f);
if (!sign && m.is_or(f)) {
clause.reset();
for (auto arg : *to_app(f))
clause.push_back(mk_literal(arg));
s.add_clause(clause.size(), clause.data());
}
else if (!sign && m.is_and(f)) {
for (auto arg : *to_app(f))
add_clause(arg);
}
else if (sign && m.is_or(f)) {
for (auto arg : *to_app(f)) {
expr_ref fml(m.mk_not(arg), m);
add_clause(fml);
}
}
else if (!sign && m.is_implies(f, g, h)) {
clause.reset();
clause.push_back(~mk_literal(g));
clause.push_back(mk_literal(h));
s.add_clause(clause.size(), clause.data());
}
else if (sign && m.is_implies(f, g, h)) {
expr_ref fml(m.mk_not(h), m);
add_clause(fml);
add_clause(g);
}
else if (sign && m.is_and(f)) {
clause.reset();
for (auto arg : *to_app(f))
clause.push_back(~mk_literal(arg));
s.add_clause(clause.size(), clause.data());
}
else if (m.is_iff(f, g, h)) {
auto lit1 = mk_literal(g);
auto lit2 = mk_literal(h);
sat::literal cls1[2] = { sign ? lit1 : ~lit1, lit2 };
sat::literal cls2[2] = { sign ? ~lit1 : lit1, ~lit2 };
s.add_clause(2, cls1);
s.add_clause(2, cls2);
}
else if (m.is_ite(f, g, h, k)) {
auto lit1 = mk_literal(g);
auto lit2 = mk_literal(h);
auto lit3 = mk_literal(k);
// (g -> h) & (~g -> k)
// (g & h) | (~g & k)
// negated: (g -> ~h) & (g -> ~k)
sat::literal cls1[2] = { ~lit1, sign ? ~lit2 : lit2 };
sat::literal cls2[2] = { lit1, sign ? ~lit3 : lit3 };
s.add_clause(2, cls1);
s.add_clause(2, cls2);
}
else {
sat::literal lit = mk_literal(f);
if (sign)
lit.neg();
s.add_clause(1, &lit);
}
}
void context::add_clause(sat::literal_vector const& lits) {
s.add_clause(lits.size(), lits.data());
m_new_constraint = true;
++m_stats.m_num_constraints;
}
sat::literal context::mk_literal() {
sat::bool_var v = s.add_var();
return sat::literal(v, false);
}
sat::literal context::mk_literal(expr* e) {
expr_ref _e(e, m);
sat::literal lit;
bool neg = false;
expr* a, * b, * c;
while (m.is_not(e, e))
neg = !neg;
auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var);
if (v != sat::null_bool_var)
return sat::literal(v, neg);
sat::literal_vector clause;
lit = mk_literal();
register_atom(lit.var(), e);
if (m.is_true(e)) {
clause.push_back(lit);
s.add_clause(clause.size(), clause.data());
}
else if (m.is_false(e)) {
clause.push_back(~lit);
s.add_clause(clause.size(), clause.data());
}
else if (m.is_and(e)) {
for (expr* arg : *to_app(e)) {
auto lit2 = mk_literal(arg);
clause.push_back(~lit2);
sat::literal lits[2] = { ~lit, lit2 };
s.add_clause(2, lits);
}
clause.push_back(lit);
s.add_clause(clause.size(), clause.data());
}
else if (m.is_or(e)) {
for (expr* arg : *to_app(e)) {
auto lit2 = mk_literal(arg);
clause.push_back(lit2);
sat::literal lits[2] = { lit, ~lit2 };
s.add_clause(2, lits);
}
clause.push_back(~lit);
s.add_clause(clause.size(), clause.data());
}
else if (m.is_iff(e, a, b) || m.is_xor(e, a, b)) {
auto lit1 = mk_literal(a);
auto lit2 = mk_literal(b);
if (m.is_xor(e))
lit2.neg();
sat::literal cls1[3] = { ~lit, ~lit1, lit2 };
sat::literal cls2[3] = { ~lit, lit1, ~lit2 };
sat::literal cls3[3] = { lit, lit1, lit2 };
sat::literal cls4[3] = { lit, ~lit1, ~lit2 };
s.add_clause(3, cls1);
s.add_clause(3, cls2);
s.add_clause(3, cls3);
s.add_clause(3, cls4);
}
else if (m.is_ite(e, a, b, c)) {
auto lit1 = mk_literal(a);
auto lit2 = mk_literal(b);
auto lit3 = mk_literal(c);
sat::literal cls1[3] = { ~lit, ~lit1, lit2 };
sat::literal cls2[3] = { ~lit, lit1, lit3 };
sat::literal cls3[3] = { lit, ~lit1, ~lit2 };
sat::literal cls4[3] = { lit, lit1, ~lit3 };
s.add_clause(3, cls1);
s.add_clause(3, cls2);
s.add_clause(3, cls3);
s.add_clause(3, cls4);
}
else
register_terms(e);
return neg ? ~lit : lit;
}
void context::init() {
m_new_constraint = false;
if (m_initialized)
return;
m_initialized = true;
m_unit_literals.reset();
m_unit_indices.reset();
for (auto const& clause : s.clauses())
if (clause.m_clause.size() == 1)
m_unit_literals.push_back(clause.m_clause[0]);
for (sat::literal lit : m_unit_literals)
m_unit_indices.insert(lit.index());
IF_VERBOSE(3, verbose_stream() << "UNITS " << m_unit_literals << "\n");
for (unsigned i = 0; i < m_atoms.size(); ++i)
if (m_atoms.get(i))
register_terms(m_atoms.get(i));
for (auto p : m_plugins)
if (p)
p->initialize();
}
void context::register_terms(expr* e) {
auto is_visited = [&](expr* e) {
return nullptr != m_allterms.get(e->get_id(), nullptr);
};
auto visit = [&](expr* e) {
m_allterms.setx(e->get_id(), e);
ensure_plugin(e);
register_term(e);
};
if (is_visited(e))
return;
m_subterms.reset();
m_todo.push_back(e);
if (m_todo.size() > 1)
return;
while (!m_todo.empty()) {
expr* e = m_todo.back();
if (is_visited(e))
m_todo.pop_back();
else if (is_app(e)) {
if (all_of(*to_app(e), [&](expr* arg) { return is_visited(arg); })) {
expr_ref _e(e, m);
m_todo.pop_back();
m_parents.reserve(to_app(e)->get_id() + 1);
for (expr* arg : *to_app(e)) {
m_parents.reserve(arg->get_id() + 1);
m_parents[arg->get_id()].push_back(e);
}
if (m.is_bool(e))
mk_literal(e);
visit(e);
}
else {
for (expr* arg : *to_app(e))
m_todo.push_back(arg);
}
}
else {
expr_ref _e(e, m);
m_todo.pop_back();
visit(e);
}
}
}
void context::new_value_eh(expr* e) {
DEBUG_CODE(
if (m.is_bool(e)) {
auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var);
if (v != sat::null_bool_var) {
SASSERT(m.is_true(get_value(e)) == is_true(v));
}
}
);
m_repair_down.reserve(e->get_id() + 1);
m_repair_up.reserve(e->get_id() + 1);
if (!term(e->get_id()))
verbose_stream() << "no term " << mk_bounded_pp(e, m) << "\n";
SASSERT(e == term(e->get_id()));
if (!m_repair_down.contains(e->get_id()))
m_repair_down.insert(e->get_id());
for (auto p : parents(e)) {
auto pid = p->get_id();
m_repair_up.reserve(pid + 1);
m_repair_down.reserve(pid + 1);
if (!m_repair_up.contains(pid))
m_repair_up.insert(pid);
}
}
void context::register_term(expr* e) {
for (auto p : m_plugins)
if (p)
p->register_term(e);
}
ptr_vector<expr> const& context::subterms() {
if (!m_subterms.empty())
return m_subterms;
for (auto e : m_allterms)
if (e)
m_subterms.push_back(e);
std::stable_sort(m_subterms.begin(), m_subterms.end(),
[](expr* a, expr* b) { return get_depth(a) < get_depth(b); });
return m_subterms;
}
void context::reinit_relevant() {
m_relevant.reset();
m_visited.reset();
m_root_literals.reset();
for (auto const& clause : s.clauses()) {
bool has_relevant = false;
unsigned n = 0;
sat::literal selected_lit = sat::null_literal;
for (auto lit : clause) {
auto atm = m_atoms.get(lit.var(), nullptr);
if (!atm)
continue;
auto a = atm->get_id();
if (!is_true(lit))
continue;
if (m_relevant.contains(a)) {
has_relevant = true;
break;
}
if (m_rand() % ++n == 0)
selected_lit = lit;
}
if (!has_relevant && selected_lit != sat::null_literal) {
m_relevant.insert(m_atoms[selected_lit.var()]->get_id());
m_root_literals.push_back(selected_lit);
}
}
shuffle(m_root_literals.size(), m_root_literals.data(), m_rand);
}
std::ostream& context::display(std::ostream& out) const {
for (auto id : m_repair_down)
out << "d " << mk_bounded_pp(term(id), m) << "\n";
for (auto id : m_repair_up)
out << "u " << mk_bounded_pp(term(id), m) << "\n";
for (unsigned v = 0; v < m_atoms.size(); ++v) {
auto e = m_atoms[v];
if (e)
out << v << ": " << mk_bounded_pp(e, m) << " := " << (is_true(v)?"T":"F") << "\n";
}
for (auto p : m_plugins)
if (p)
p->display(out);
return out;
}
void context::collect_statistics(statistics& st) const {
for (auto p : m_plugins)
if (p)
p->collect_statistics(st);
st.update("sls-repair-down", m_stats.m_num_repair_down);
st.update("sls-repair-up", m_stats.m_num_repair_up);
st.update("sls-constraints", m_stats.m_num_constraints);
}
void context::reset_statistics() {
for (auto p : m_plugins)
if (p)
p->reset_statistics();
m_stats.reset();
}
}

212
src/ast/sls/sls_context.h Normal file
View file

@ -0,0 +1,212 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_context.h
Abstract:
A Stochastic Local Search (SLS) Context.
Author:
Nikolaj Bjorner (nbjorner) 2024-06-24
--*/
#pragma once
#include "util/sat_literal.h"
#include "util/sat_sls.h"
#include "util/statistics.h"
#include "ast/ast.h"
#include "ast/euf/euf_egraph.h"
#include "model/model.h"
#include "util/scoped_ptr_vector.h"
#include "util/obj_hashtable.h"
#include "util/heap.h"
namespace sls {
class context;
class euf_plugin;
class plugin {
protected:
context& ctx;
ast_manager& m;
family_id m_fid;
public:
plugin(context& c);
virtual ~plugin() {}
virtual family_id fid() { return m_fid; }
virtual void register_term(expr* e) = 0;
virtual expr_ref get_value(expr* e) = 0;
virtual void initialize() = 0;
virtual void start_propagation() {};
virtual bool propagate() = 0;
virtual void propagate_literal(sat::literal lit) = 0;
virtual void repair_literal(sat::literal lit) = 0;
virtual bool repair_down(app* e) = 0;
virtual void repair_up(app* e) = 0;
virtual bool is_sat() = 0;
virtual void on_rescale() {};
virtual void on_restart() {};
virtual std::ostream& display(std::ostream& out) const = 0;
virtual bool set_value(expr* e, expr* v) = 0;
virtual void collect_statistics(statistics& st) const = 0;
virtual void reset_statistics() = 0;
virtual bool include_func_interp(func_decl* f) const { return false; }
};
using clause = ptr_iterator<sat::literal>;
class sat_solver_context {
public:
virtual ~sat_solver_context() {}
virtual vector<sat::clause_info> const& clauses() const = 0;
virtual sat::clause_info const& get_clause(unsigned idx) const = 0;
virtual ptr_iterator<unsigned> get_use_list(sat::literal lit) = 0;
virtual void flip(sat::bool_var v) = 0;
virtual double reward(sat::bool_var v) = 0;
virtual double get_weigth(unsigned clause_idx) = 0;
virtual bool is_true(sat::literal lit) = 0;
virtual unsigned num_vars() const = 0;
virtual indexed_uint_set const& unsat() const = 0;
virtual void on_model(model_ref& mdl) = 0;
virtual sat::bool_var add_var() = 0;
virtual void add_clause(unsigned n, sat::literal const* lits) = 0;
virtual void force_restart() = 0;
virtual std::ostream& display(std::ostream& out) = 0;
};
class context {
struct greater_depth {
context& c;
greater_depth(context& c) : c(c) {}
bool operator()(unsigned x, unsigned y) const {
return get_depth(c.term(x)) > get_depth(c.term(y));
}
};
struct less_depth {
context& c;
less_depth(context& c) : c(c) {}
bool operator()(unsigned x, unsigned y) const {
return get_depth(c.term(x)) < get_depth(c.term(y));
}
};
struct stats {
unsigned m_num_repair_down = 0;
unsigned m_num_repair_up = 0;
unsigned m_num_constraints = 0;
void reset() { memset(this, 0, sizeof(*this)); }
};
ast_manager& m;
sat_solver_context& s;
scoped_ptr_vector<plugin> m_plugins;
indexed_uint_set m_relevant, m_visited;
expr_ref_vector m_atoms;
unsigned_vector m_atom2bool_var;
params_ref m_params;
vector<ptr_vector<expr>> m_parents;
sat::literal_vector m_root_literals, m_unit_literals;
indexed_uint_set m_unit_indices;
random_gen m_rand;
bool m_initialized = false;
bool m_new_constraint = false;
bool m_dirty = false;
expr_ref_vector m_allterms;
ptr_vector<expr> m_subterms;
greater_depth m_gd;
less_depth m_ld;
heap<greater_depth> m_repair_down;
heap<less_depth> m_repair_up;
uint_set m_constraint_ids;
expr_ref_vector m_constraint_trail;
stats m_stats;
void register_plugin(plugin* p);
void init();
expr_ref_vector m_todo;
void register_terms(expr* e);
void register_term(expr* e);
void propagate_boolean_assignment();
void propagate_literal(sat::literal lit);
void repair_literals();
void values2model();
void ensure_plugin(expr* e);
void ensure_plugin(family_id fid);
family_id get_fid(expr* e) const;
sat::literal mk_literal();
public:
context(ast_manager& m, sat_solver_context& s);
// Between SAT/SMT solver and context.
void register_atom(sat::bool_var v, expr* e);
lbool check();
void on_restart();
void updt_params(params_ref const& p);
params_ref const& get_params() const { return m_params; }
// expose sat_solver to plugins
vector<sat::clause_info> const& clauses() const { return s.clauses(); }
sat::clause_info const& get_clause(unsigned idx) const { return s.get_clause(idx); }
ptr_iterator<unsigned> get_use_list(sat::literal lit) { return s.get_use_list(lit); }
double get_weight(unsigned clause_idx) { return s.get_weigth(clause_idx); }
unsigned num_bool_vars() const { return s.num_vars(); }
bool is_true(sat::literal lit) { return s.is_true(lit); }
bool is_true(sat::bool_var v) const { return s.is_true(sat::literal(v, false)); }
expr* atom(sat::bool_var v) { return m_atoms.get(v, nullptr); }
expr* term(unsigned id) const { return m_allterms.get(id); }
sat::bool_var atom2bool_var(expr* e) const { return m_atom2bool_var.get(e->get_id(), sat::null_bool_var); }
sat::literal mk_literal(expr* e);
void add_clause(expr* f);
void add_clause(sat::literal_vector const& lits);
void flip(sat::bool_var v) { s.flip(v); }
double reward(sat::bool_var v) { return s.reward(v); }
indexed_uint_set const& unsat() const { return s.unsat(); }
unsigned rand() { return m_rand(); }
unsigned rand(unsigned n) { return m_rand(n); }
sat::literal_vector const& root_literals() const { return m_root_literals; }
sat::literal_vector const& unit_literals() const { return m_unit_literals; }
bool is_unit(sat::literal lit) const { return m_unit_indices.contains(lit.index()); }
void reinit_relevant();
void force_restart() { s.force_restart(); }
bool include_func_interp(func_decl* f) const { return any_of(m_plugins, [&](plugin* p) { return p && p->include_func_interp(f); }); }
ptr_vector<expr> const& parents(expr* e) {
m_parents.reserve(e->get_id() + 1);
return m_parents[e->get_id()];
}
// Between plugin solvers
expr_ref get_value(expr* e);
bool set_value(expr* e, expr* v);
void new_value_eh(expr* e);
bool is_true(expr* e);
bool is_fixed(expr* e);
bool is_relevant(expr* e);
void add_constraint(expr* e);
ptr_vector<expr> const& subterms();
ast_manager& get_manager() { return m; }
std::ostream& display(std::ostream& out) const;
std::ostream& display_all(std::ostream& out) const { return s.display(out); }
scoped_ptr<euf::egraph>& egraph();
euf_plugin& euf();
void collect_statistics(statistics& st) const;
void reset_statistics();
};
}

View file

@ -0,0 +1,956 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_datatype_plugin.cpp
Abstract:
Algebraic Datatypes for SLS
Author:
Nikolaj Bjorner (nbjorner) 2024-10-14
Notes:
Eager reduction to EUF:
is-c(c(t)) for each c(t) in T
acc_i(c(t_i)) = t_i for each c(..t_i..) in T
is-c(t) => t = c(...acc_j(t)..) for each acc_j(t) in T
sum_i is-c_i(t) = 1
is-c(t) <=> c = t for each 0-ary constructor c
is-c(t) <=> t = c(acc_1(t)..acc_n(t))
s = acc(...(acc(t)) => s != t if t is recursive
or_i t = t_i if t is a finite sort with terms t_i
s := acc(t) => s < t in P
a := s = acc(t), a is a unit => s < t in P
a := s = acc(t), a in Atoms => (a => s < t) in P
s << t if there is a path P with conditions L.
L => s != t
This disregards if acc is applied to non-matching constructor.
In this case we rely on that the interpretation of acc can be
forced.
If this is incorrect, include is-c(t) assumptions in path axioms.
Is P sufficient? Should we just consider all possible paths of depth at most k to be safe?
Example:
C(acc(t)) == C(s)
triggers equation acc(t) = s, but the equation is implicit, so acc(t) and s are not directly
connected.
Even, the axioms extracted from P don't consider transitivity of =.
So the can-be-equal alias approximation is too strong.
We therefore add an occurs check during propagation and lazily add missed axioms.
Model-repair based:
1. Initialize uninterpreted datatype nodes to hold arbitrary values.
2. Initialize datatype nodes by induced evaluation.
3. Atomic constraints are of the form for datatype terms
x = y, x = t, x != y, x != t; s = t, s != t
violated x = y: x <- eval(y), y <- eval(x) or x, y <- fresh
violated x = t: x <- eval(t), repair t using the shape of x
violated x != y: x <- fresh, y <- fresh
violated x != t: x <- fresh, subterm y of t: y <- fresh
acc(x) = t: eval(x) = c(u, v) acc(c(u,v)) = u -> repair(u = t)
acc(x) = t: eval(x) does not match acc -> acc(x)
has a fixed interpretation, so repair over t instead, or update interpretation of x
uses:
model::get_fresh_value(s)
model::get_some_value(s)
--*/
#include "ast/sls/sls_datatype_plugin.h"
#include "ast/sls/sls_euf_plugin.h"
#include "ast/ast_pp.h"
#include "params/sls_params.hpp"
namespace sls {
datatype_plugin::datatype_plugin(context& c):
plugin(c),
euf(c.euf()),
g(c.egraph()),
dt(m),
m_axioms(m),
m_values(m),
m_eval(m) {
m_fid = dt.get_family_id();
}
datatype_plugin::~datatype_plugin() {}
void datatype_plugin::collect_path_axioms() {
expr* t = nullptr, *z = nullptr;
for (auto s : ctx.subterms()) {
if (dt.is_accessor(s, t) && dt.is_recursive(t) && dt.is_recursive(s))
add_edge(s, t, m.mk_app(dt.get_constructor_is(dt.get_accessor_constructor(to_app(s)->get_decl())), t));
if (dt.is_constructor(s) && dt.is_recursive(s)) {
for (auto arg : *to_app(s))
add_edge(arg, s, nullptr);
}
}
expr* x = nullptr, *y = nullptr;
for (sat::bool_var v = 0; v < ctx.num_bool_vars(); ++v) {
expr* e = ctx.atom(v);
if (!e)
continue;
if (!m.is_eq(e, x, y))
continue;
if (!dt.is_recursive(x))
continue;
sat::literal lp(v, false), ln(v, true);
if (dt.is_accessor(x, z) && dt.is_recursive(z)) {
if (ctx.is_unit(lp))
add_edge(y, z, nullptr);
else if (ctx.is_unit(ln))
;
else
add_edge(y, z, e);
}
if (dt.is_accessor(y, z) && dt.is_recursive(z)) {
if (ctx.is_unit(lp))
add_edge(x, z, m.mk_app(dt.get_constructor_is(dt.get_accessor_constructor(to_app(y)->get_decl())), z));
else if (ctx.is_unit(ln))
;
else
add_edge(x, z, e);
}
}
add_path_axioms();
}
void datatype_plugin::add_edge(expr* child, expr* parent, expr* cond) {
m_parents.insert_if_not_there(child, vector<parent_t>()).push_back({parent, expr_ref(cond, m)});
TRACE("dt", tout << mk_bounded_pp(child, m) << " <- " << mk_bounded_pp(parent, m) << " " << mk_bounded_pp(cond, m) << "\n");
}
void datatype_plugin::add_path_axioms() {
ptr_vector<expr> path;
sat::literal_vector lits;
for (auto [child, parents] : m_parents) {
path.reset();
lits.reset();
path.push_back(child);
add_path_axioms(path, lits, parents);
}
}
void datatype_plugin::add_path_axioms(ptr_vector<expr>& children, sat::literal_vector& lits, vector<parent_t> const& parents) {
for (auto const& [parent, cond] : parents) {
if (cond)
lits.push_back(~ctx.mk_literal(cond));
if (children.contains(parent)) {
// only assert loop clauses for proper loops
if (parent == children[0])
ctx.add_clause(lits);
if (cond)
lits.pop_back();
continue;
}
if (children[0]->get_sort() == parent->get_sort()) {
lits.push_back(~ctx.mk_literal(m.mk_eq(children[0], parent)));
TRACE("dt", for (auto lit : lits) tout << (lit.sign() ? "~": "") << mk_pp(ctx.atom(lit.var()), m) << "\n";);
ctx.add_clause(lits);
lits.pop_back();
}
auto child = children.back();
if (m_parents.contains(child)) {
children.push_back(parent);
auto& parents2 = m_parents[child];
add_path_axioms(children, lits, parents2);
children.pop_back();
}
if (cond)
lits.pop_back();
}
}
void datatype_plugin::add_axioms() {
expr_ref_vector axioms(m);
for (auto t : ctx.subterms()) {
auto s = t->get_sort();
if (dt.is_datatype(s))
m_dts.insert_if_not_there(s, ptr_vector<expr>()).push_back(t);
if (!is_app(t))
continue;
auto ta = to_app(t);
auto f = ta->get_decl();
if (dt.is_constructor(t)) {
auto r = dt.get_constructor_is(f);
m_axioms.push_back(m.mk_app(r, t));
auto& acc = *dt.get_constructor_accessors(f);
for (unsigned i = 0; i < ta->get_num_args(); ++i) {
auto ti = ta->get_arg(i);
m_axioms.push_back(m.mk_eq(ti, m.mk_app(acc[i], t)));
}
auto& cns = *dt.get_datatype_constructors(s);
for (auto c : cns) {
if (c != f) {
auto r2 = dt.get_constructor_is(c);
m_axioms.push_back(m.mk_not(m.mk_app(r2, t)));
}
}
continue;
}
if (dt.is_recognizer0(f)) {
auto u = ta->get_arg(0);
auto c = dt.get_recognizer_constructor(f);
m_axioms.push_back(m.mk_iff(t, m.mk_app(dt.get_constructor_is(c), u)));
}
if (dt.is_update_field(t)) {
NOT_IMPLEMENTED_YET();
}
if (dt.is_datatype(s)) {
auto& cns = *dt.get_datatype_constructors(s);
expr_ref_vector ors(m);
for (auto c : cns) {
auto r = dt.get_constructor_is(c);
ors.push_back(m.mk_app(r, t));
}
m_axioms.push_back(m.mk_or(ors));
#if 0
// expanded lazily
// EUF already handles enumeration datatype case.
for (unsigned i = 0; i < cns.size(); ++i) {
auto r1 = dt.get_constructor_is(cns[i]);
for (unsigned j = i + 1; j < cns.size(); ++j) {
auto r2 = dt.get_constructor_is(cns[j]);
m_axioms.push_back(m.mk_or(m.mk_not(m.mk_app(r1, t)), m.mk_not(m.mk_app(r2, t))));
}
}
#endif
for (auto c : cns) {
auto r = dt.get_constructor_is(c);
auto& acc = *dt.get_constructor_accessors(c);
expr_ref_vector args(m);
for (auto a : acc)
args.push_back(m.mk_app(a, t));
m_axioms.push_back(m.mk_iff(m.mk_app(r, t), m.mk_eq(t, m.mk_app(c, args))));
}
}
}
//collect_path_axioms();
TRACE("dt", for (auto a : m_axioms) tout << mk_pp(a, m) << "\n";);
for (auto a : m_axioms)
ctx.add_constraint(a);
}
void datatype_plugin::initialize() {
sls_params sp(ctx.get_params());
m_axiomatic_mode = sp.dt_axiomatic();
if (m_axiomatic_mode)
add_axioms();
}
expr_ref datatype_plugin::get_value(expr* e) {
if (!dt.is_datatype(e))
return expr_ref(m);
if (m_axiomatic_mode) {
init_values();
return expr_ref(m_values.get(g->find(e)->get_root_id()), m);
}
return expr_ref(m_eval.get(e->get_id()), m);
}
void datatype_plugin::init_values() {
if (!m_values.empty())
return;
TRACE("dt", g->display(tout));
m_model = alloc(model, m);
// retrieve e-graph from sls_euf_solver: add bridge in sls_context to share e-graph
SASSERT(g);
// build top_sort<euf::enode> similar to dt_solver.cpp
top_sort<euf::enode> deps;
for (auto* n : g->nodes())
if (n->is_root())
add_dep(n, deps);
auto trace_assignment = [&](std::ostream& out, euf::enode* n) {
for (auto sib : euf::enode_class(n))
out << g->bpp(sib) << " ";
out << " <- " << mk_bounded_pp(m_values.get(n->get_id()), m) << "\n";
};
deps.topological_sort();
expr_ref_vector args(m);
euf::enode_vector leaves, worklist;
obj_map<euf::enode, euf::enode_vector> leaf2root;
// walk topological sort in order of leaves to roots, attaching values to nodes.
for (euf::enode* n : deps.top_sorted()) {
SASSERT(n->is_root());
unsigned id = n->get_id();
if (m_values.get(id, nullptr))
continue;
expr* e = n->get_expr();
m_values.reserve(id + 1);
if (!dt.is_datatype(e))
continue;
euf::enode* con = get_constructor(n);
if (!con) {
leaves.push_back(n);
continue;
}
auto f = con->get_decl();
args.reset();
bool has_null = false;
for (auto arg : euf::enode_args(con)) {
if (dt.is_datatype(arg->get_sort())) {
auto val_arg = m_values.get(arg->get_root_id());
if (!val_arg)
has_null = true;
leaf2root.insert_if_not_there(arg->get_root(), euf::enode_vector()).push_back(n);
args.push_back(val_arg);
}
else
args.push_back(ctx.get_value(arg->get_expr()));
}
if (!has_null) {
m_values.setx(id, m.mk_app(f, args));
m_model->register_value(m_values.get(id));
TRACE("dt", tout << "Set interpretation "; trace_assignment(tout, n););
}
}
TRACE("dt",
for (euf::enode* n : deps.top_sorted()) {
tout << g->bpp(n) << ": ";
tout << g->bpp(get_constructor(n)) << " :: ";
auto s = deps.get_dep(n);
if (s) {
tout << " -> ";
for (auto t : *s)
tout << g->bpp(t) << " ";
}
tout << "\n";
}
);
auto process_workitem = [&](euf::enode* n) {
if (!leaf2root.contains(n))
return true;
bool all_processed = true;
for (auto p : leaf2root[n]) {
if (m_values.get(p->get_id(), nullptr))
continue;
auto con = get_constructor(p);
SASSERT(con);
auto f = con->get_decl();
args.reset();
bool has_missing = false;
for (auto arg : euf::enode_args(con)) {
if (dt.is_datatype(arg->get_sort())) {
auto arg_val = m_values.get(arg->get_root_id());
if (!arg_val)
has_missing = true;
args.push_back(arg_val);
}
else
args.push_back(ctx.get_value(arg->get_expr()));
}
if (has_missing) {
all_processed = false;
continue;
}
worklist.push_back(p);
SASSERT(all_of(args, [&](expr* e) { return e != nullptr; }));
m_values.setx(p->get_id(), m.mk_app(f, args));
TRACE("dt", tout << "Patched interpretation "; trace_assignment(tout, p););
m_model->register_value(m_values.get(p->get_id()));
}
return all_processed;
};
auto process_worklist = [&](euf::enode_vector& worklist) {
unsigned j = 0, sz = worklist.size();
for (unsigned i = 0; i < worklist.size(); ++i)
if (!process_workitem(worklist[i]))
worklist[j++] = worklist[i];
worklist.shrink(j);
return j < sz;
};
// attach fresh values to each leaf, walk up parents to assign them values.
while (!leaves.empty()) {
auto n = leaves.back();
leaves.pop_back();
SASSERT(!get_constructor(n));
auto v = m_model->get_fresh_value(n->get_sort());
if (!v)
v = m_model->get_some_value(n->get_sort());
SASSERT(v);
unsigned id = n->get_id();
m_values.setx(id, v);
TRACE("dt", tout << "Fresh interpretation "; trace_assignment(tout, n););
worklist.reset();
worklist.push_back(n);
while (process_worklist(worklist))
;
}
}
void datatype_plugin::add_dep(euf::enode* n, top_sort<euf::enode>& dep) {
if (!dt.is_datatype(n->get_expr()))
return;
euf::enode* con = get_constructor(n);
TRACE("dt", tout << g->bpp(n) << " con: " << g->bpp(con) << "\n";);
if (!con)
dep.insert(n, nullptr);
else if (con->num_args() == 0)
dep.insert(n, nullptr);
else
for (euf::enode* arg : euf::enode_args(con))
dep.add(n, arg->get_root());
}
void datatype_plugin::start_propagation() {
m_values.reset();
m_model = nullptr;
}
euf::enode* datatype_plugin::get_constructor(euf::enode* n) const {
for (auto sib : euf::enode_class(n))
if (dt.is_constructor(sib->get_expr()))
return sib;
return nullptr;
}
bool datatype_plugin::propagate() {
enum color_t { white, grey, black };
svector<color_t> color;
ptr_vector<euf::enode> stack;
obj_map<sort, ptr_vector<expr>> sorts;
auto set_conflict = [&](euf::enode* n) {
expr_ref_vector diseqs(m);
while (true) {
auto n2 = stack.back();
auto con2 = get_constructor(n2);
if (n2 != con2)
diseqs.push_back(m.mk_not(m.mk_eq(n2->get_expr(), con2->get_expr())));
if (n2->get_root() == n->get_root()) {
if (n != n2)
diseqs.push_back(m.mk_not(m.mk_eq(n->get_expr(), n2->get_expr())));
break;
}
stack.pop_back();
}
IF_VERBOSE(1, verbose_stream() << "cycle\n"; for (auto e : diseqs) verbose_stream() << mk_pp(e, m) << "\n";);
ctx.add_constraint(m.mk_or(diseqs));
++m_stats.m_num_occurs;
};
for (auto n : g->nodes()) {
if (!n->is_root())
continue;
euf::enode* con = nullptr;
for (auto sib : euf::enode_class(n)) {
if (dt.is_constructor(sib->get_expr())) {
if (!con)
con = sib;
if (con && con->get_decl() != sib->get_decl()) {
ctx.add_constraint(m.mk_not(m.mk_eq(con->get_expr(), sib->get_expr())));
++m_stats.m_num_occurs;
}
}
}
}
for (auto n : g->nodes()) {
if (!n->is_root())
continue;
expr* e = n->get_expr();
if (!dt.is_datatype(e))
continue;
if (!ctx.is_relevant(e))
continue;
sort* s = e->get_sort();
sorts.insert_if_not_there(s, ptr_vector<expr>()).push_back(e);
auto c = color.get(e->get_id(), white);
SASSERT(c != grey);
if (c == black)
continue;
// dfs traversal of enodes, starting with n,
// with outgoing edges the arguments of con, where con
// is a node in the same congruence class as n that is a constructor.
// For every cycle accumulate a conflict.
stack.push_back(n);
while (!stack.empty()) {
n = stack.back();
unsigned id = n->get_root_id();
c = color.get(id, white);
euf::enode* con;
switch (c) {
case black:
stack.pop_back();
break;
case grey:
case white:
color.setx(id, grey, white);
con = get_constructor(n);
if (!con)
goto done_with_node;
for (auto child : euf::enode_args(con)) {
auto c2 = color.get(child->get_root_id(), white);
switch (c2) {
case black:
break;
case grey:
set_conflict(child);
return true;
case white:
stack.push_back(child);
goto node_pushed;
}
}
done_with_node:
color[id] = black;
stack.pop_back();
node_pushed:
break;
}
}
}
for (auto const& [s, elems] : sorts) {
auto sz = s->get_num_elements();
if (!sz.is_finite() || sz.size() >= elems.size())
continue;
ctx.add_constraint(m.mk_not(m.mk_distinct((unsigned)sz.size() + 1, elems.data())));
}
return false;
}
bool datatype_plugin::include_func_interp(func_decl* f) const {
if (!dt.is_accessor(f))
return false;
func_decl* con_decl = dt.get_accessor_constructor(f);
for (euf::enode* app : g->enodes_of(f)) {
euf::enode* con = get_constructor(app->get_arg(0));
if (con && con->get_decl() != con_decl)
return true;
}
return false;
}
std::ostream& datatype_plugin::display(std::ostream& out) const {
for (auto a : m_axioms)
out << mk_bounded_pp(a, m, 3) << "\n";
return out;
}
void datatype_plugin::propagate_literal(sat::literal lit) {
if (m_axiomatic_mode)
euf.propagate_literal(lit);
else
propagate_literal_model_building(lit);
}
void datatype_plugin::propagate_literal_model_building(sat::literal lit) {
if (!ctx.is_true(lit))
return;
auto a = ctx.atom(lit.var());
if (!a || !is_app(a))
return;
repair_down(to_app(a));
}
bool datatype_plugin::is_sat() { return true; }
void datatype_plugin::register_term(expr* e) {
expr* t = nullptr;
if (dt.is_accessor(e, t)) {
auto f = to_app(e)->get_decl();
m_occurs.insert_if_not_there(f, expr_set()).insert(e);
m_eval_accessor.insert_if_not_there(f, obj_map<expr, expr*>());
}
}
bool datatype_plugin::repair_down(app* e) {
expr* t, * s;
auto v0 = eval0(e);
auto v1 = eval1(e);
if (v0 == v1)
return true;
IF_VERBOSE(2, verbose_stream() << "dt-repair-down " << mk_bounded_pp(e, m) << " " << v0 << " <- " << v1 << "\n");
if (dt.is_constructor(e))
repair_down_constructor(e, v0, v1);
else if (dt.is_accessor(e, t))
repair_down_accessor(e, t, v0);
else if (dt.is_recognizer(e, t))
repair_down_recognizer(e, t);
else if (m.is_eq(e, s, t))
repair_down_eq(e, s, t);
else if (m.is_distinct(e))
repair_down_distinct(e);
else {
UNREACHABLE();
}
return false;
}
//
// C(t) <- C(s) then repair t <- s
// C(t) <- D(s) then fail the repair.
//
void datatype_plugin::repair_down_constructor(app* e, expr* v0, expr* v1) {
SASSERT(dt.is_constructor(v0));
SASSERT(dt.is_constructor(v1));
SASSERT(e->get_decl() == to_app(v1)->get_decl());
if (e->get_decl() == to_app(v0)->get_decl()) {
for (unsigned i = 0; i < e->get_num_args(); ++i) {
auto w0 = to_app(v0)->get_arg(i);
auto w1 = to_app(v1)->get_arg(i);
if (w0 == w1)
continue;
expr* arg = e->get_arg(i);
set_eval0(arg, w0);
ctx.new_value_eh(arg);
}
}
}
//
// A_D(t) <- s, val(t) = D(..s'..) then update val(t) to agree with s
// A_D(t) <- s, val(t) = C(..) then set t to D(...s...)
// , eval(val(A_D(t))) = s' then update eval(val(A_D,(t))) to s'
void datatype_plugin::repair_down_accessor(app* e, expr* t, expr* v0) {
auto f = e->get_decl();
auto c = dt.get_accessor_constructor(f);
auto val_t = eval0(t);
SASSERT(dt.is_constructor(val_t));
expr_ref_vector args(m);
auto const& accs = *dt.get_constructor_accessors(c);
unsigned i;
for (i = 0; i < accs.size(); ++i) {
if (accs[i] == f)
break;
}
SASSERT(i < accs.size());
if (to_app(val_t)->get_decl() == c) {
if (to_app(val_t)->get_arg(i) == v0)
return;
args.append(accs.size(), to_app(val_t)->get_args());
args[i] = v0;
expr* new_val_t = m.mk_app(c, args);
set_eval0(t, new_val_t);
ctx.new_value_eh(t);
return;
}
if (ctx.rand(5) != 0) {
update_eval_accessor(e, val_t, v0);
return;
}
for (unsigned j = 0; j < accs.size(); ++j) {
if (i == j)
args[i] = v0;
else
args[j] = m_model->get_some_value(accs[j]->get_range());
}
expr* new_val_t = m.mk_app(c, args);
set_eval0(t, new_val_t);
ctx.new_value_eh(t);
}
void datatype_plugin::repair_down_recognizer(app* e, expr* t) {
auto bv = ctx.atom2bool_var(e);
auto is_true = ctx.is_true(bv);
auto c = dt.get_recognizer_constructor(e->get_decl());
auto val_t = eval0(t);
auto const& cons = *dt.get_datatype_constructors(t->get_sort());
auto set_to_instance = [&](func_decl* c) {
auto const& accs = *dt.get_constructor_accessors(c);
expr_ref_vector args(m);
for (auto a : accs)
args.push_back(m_model->get_some_value(a->get_range()));
set_eval0(t, m.mk_app(c, args));
ctx.new_value_eh(t);
};
auto different_constructor = [&](func_decl* c) {
unsigned i = 0;
func_decl* c_new = nullptr;
for (auto c2 : cons)
if (c2 != c && ctx.rand(++i) == 0)
c_new = c2;
return c_new;
};
SASSERT(dt.is_constructor(val_t));
if (c == to_app(val_t)->get_decl() && is_true)
return;
if (c != to_app(val_t)->get_decl() && !is_true)
return;
if (ctx.rand(10) == 0)
ctx.flip(bv);
else if (is_true)
set_to_instance(c);
else if (cons.size() == 1)
ctx.flip(bv);
else
set_to_instance(different_constructor(c));
}
void datatype_plugin::repair_down_eq(app* e, expr* s, expr* t) {
auto bv = ctx.atom2bool_var(e);
auto is_true = ctx.is_true(bv);
auto vs = eval0(s);
auto vt = eval0(t);
if (is_true && vs == vt)
return;
if (!is_true && vs != vt)
return;
if (is_true) {
auto coin = ctx.rand(5);
if (coin <= 1) {
set_eval0(s, vt);
ctx.new_value_eh(s);
return;
}
if (coin <= 3) {
set_eval0(t, vs);
ctx.new_value_eh(t);
}
if (true) {
auto new_v = m_model->get_some_value(s->get_sort());
set_eval0(s, new_v);
set_eval0(t, new_v);
ctx.new_value_eh(s);
ctx.new_value_eh(t);
return;
}
}
auto coin = ctx.rand(10);
if (coin <= 4) {
auto new_v = m_model->get_some_value(s->get_sort());
set_eval0(s, new_v);
ctx.new_value_eh(s);
return;
}
if (coin <= 9) {
auto new_v = m_model->get_some_value(s->get_sort());
set_eval0(t, new_v);
ctx.new_value_eh(t);
return;
}
}
void datatype_plugin::repair_down_distinct(app* e) {
auto bv = ctx.atom2bool_var(e);
auto is_true = ctx.is_true(bv);
unsigned sz = e->get_num_args();
for (unsigned i = 0; i < sz; ++i) {
auto val1 = eval0(e->get_arg(i));
for (unsigned j = i + 1; j < sz; ++j) {
auto val2 = eval0(e->get_arg(j));
if (val1 != val2)
continue;
if (!is_true)
return;
if (ctx.rand(2) == 0)
std::swap(i, j);
auto new_v = m_model->get_some_value(e->get_arg(i)->get_sort());
set_eval0(e->get_arg(i), new_v);
ctx.new_value_eh(e->get_arg(i));
return;
}
}
if (is_true)
return;
if (sz == 1) {
ctx.flip(bv);
return;
}
unsigned i = ctx.rand(sz);
unsigned j = ctx.rand(sz-1);
if (j == i)
++j;
if (ctx.rand(2) == 0)
std::swap(i, j);
set_eval0(e->get_arg(i), eval0(e->get_arg(j)));
}
void datatype_plugin::repair_up(app* e) {
IF_VERBOSE(2, verbose_stream() << "dt-repair-up " << mk_bounded_pp(e, m) << "\n");
expr* t;
auto v0 = eval0(e);
auto v1 = eval1(e);
if (v0 == v1)
return;
if (dt.is_constructor(e))
set_eval0(e, v1);
else if (m.is_bool(e))
ctx.flip(ctx.atom2bool_var(e));
else if (dt.is_accessor(e, t))
repair_up_accessor(e, t, v1);
else {
UNREACHABLE();
}
}
void datatype_plugin::repair_up_accessor(app* e, expr* t, expr* v1) {
auto v_t = eval0(t);
auto f = e->get_decl();
SASSERT(dt.is_constructor(v_t));
auto c = dt.get_accessor_constructor(f);
if (to_app(v_t)->get_decl() != c)
update_eval_accessor(e, v_t, v1);
set_eval0(e, v1);
}
expr_ref datatype_plugin::eval1(expr* e) {
expr* s = nullptr, * t = nullptr;
if (m.is_eq(e, s, t))
return expr_ref(m.mk_bool_val(eval0rec(s) == eval0rec(t)), m);
if (m.is_distinct(e)) {
expr_ref_vector args(m);
for (auto arg : *to_app(e))
args.push_back(eval0(arg));
bool d = true;
for (unsigned i = 0; i < args.size(); ++i)
for (unsigned j = i + 1; i < args.size(); ++j)
d &= args.get(i) != args.get(j);
return expr_ref(m.mk_bool_val(d), m);
}
if (dt.is_accessor(e, t)) {
auto f = to_app(e)->get_decl();
auto v = eval0rec(t);
return eval_accessor(f, v);
}
if (dt.is_constructor(e)) {
expr_ref_vector args(m);
for (auto arg : *to_app(e))
args.push_back(eval0rec(arg));
return expr_ref(m.mk_app(to_app(e)->get_decl(), args), m);
}
if (dt.is_recognizer(e, t)) {
auto v = eval0rec(t);
SASSERT(dt.is_constructor(v));
auto c = dt.get_recognizer_constructor(to_app(e)->get_decl());
return expr_ref(m.mk_bool_val(c == to_app(v)->get_decl()), m);
}
return eval0(e);
}
expr_ref datatype_plugin::eval0rec(expr* e) {
auto v = m_eval.get(e->get_id(), nullptr);
if (v)
return expr_ref(v, m);
if (!is_app(e) || to_app(e)->get_family_id() != m_fid)
return ctx.get_value(e);
auto w = eval1(e);
m_eval.set(e->get_id(), w);
return w;
}
expr_ref datatype_plugin::eval_accessor(func_decl* f, expr* t) {
auto& t2val = m_eval_accessor[f];
if (!t2val.contains(t)) {
auto val = m_model->get_some_value(f->get_range());
m.inc_ref(t);
m.inc_ref(val);
}
return expr_ref(t2val[t], m);
}
void datatype_plugin::update_eval_accessor(app* e, expr* t, expr* value) {
func_decl* f = e->get_decl();
auto& t2val = m_eval_accessor[f];
expr* old_value = nullptr;
t2val.find(t, old_value);
if (old_value == value)
;
else if (old_value) {
t2val[t] = value;
m.inc_ref(value);
m.dec_ref(old_value);
}
else {
m.inc_ref(t);
m.inc_ref(value);
t2val.insert(t, value);
}
for (expr* b : m_occurs[f]) {
if (b == e)
continue;
expr* a;
VERIFY(dt.is_accessor(b, a));
auto v_a = eval0(a);
if (v_a.get() == t) {
set_eval0(b, value);
ctx.new_value_eh(b);
}
}
}
void datatype_plugin::del_eval_accessor() {
ptr_vector<expr> kv;
for (auto& [f, t2val] : m_eval_accessor)
for (auto& [k, val] : t2val)
kv.push_back(k), kv.push_back(val);
for (auto k : kv)
m.dec_ref(k);
}
expr_ref datatype_plugin::eval0(expr* n) {
if (!dt.is_datatype(n->get_sort()))
return ctx.get_value(n);
auto v = m_eval.get(n->get_id(), nullptr);
if (v)
return expr_ref(v, m);
set_eval0(n, m_model->get_some_value(n->get_sort()));
return expr_ref(m_eval.get(n->get_id()), m);
}
void datatype_plugin::set_eval0(expr* e, expr* value) {
if (dt.is_datatype(e->get_sort()))
m_eval[e->get_id()] = value;
else
ctx.set_value(e, value);
}
expr_ref datatype_plugin::eval0(euf::enode* n) {
return eval0(n->get_root()->get_expr());
}
void datatype_plugin::collect_statistics(statistics& st) const {
st.update("sls-dt-axioms", m_axioms.size());
st.update("sls-dt-occurs-conflicts", m_stats.m_num_occurs);
}
void datatype_plugin::reset_statistics() {}
}

View file

@ -0,0 +1,107 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_datatype_plugin.h
Abstract:
Algebraic Datatypes for SLS
Author:
Nikolaj Bjorner (nbjorner) 2024-10-14
--*/
#pragma once
#include "ast/sls/sls_context.h"
#include "ast/datatype_decl_plugin.h"
#include "util/top_sort.h"
namespace sls {
class euf_plugin;
class datatype_plugin : public plugin {
struct stats {
unsigned m_num_occurs = 0;
void reset() { memset(this, 0, sizeof(*this)); }
};
struct parent_t {
expr* parent;
expr_ref condition;
};
euf_plugin& euf;
scoped_ptr<euf::egraph>& g;
obj_map<sort, ptr_vector<expr>> m_dts;
obj_map<expr, vector<parent_t>> m_parents;
bool m_axiomatic_mode = true;
mutable datatype_util dt;
expr_ref_vector m_axioms, m_values, m_eval;
model_ref m_model;
stats m_stats;
void collect_path_axioms();
void add_edge(expr* child, expr* parent, expr* cond);
void add_path_axioms();
void add_path_axioms(ptr_vector<expr>& children, sat::literal_vector& lits, vector<parent_t> const& parents);
void add_axioms();
void init_values();
void add_dep(euf::enode* n, top_sort<euf::enode>& dep);
euf::enode* get_constructor(euf::enode* n) const;
// f -> v_t -> val
// e = A(t)
// val(t) <- val
//
typedef obj_hashtable<expr> expr_set;
obj_map<func_decl, obj_map<expr, expr*>> m_eval_accessor;
obj_map<func_decl, expr_set> m_occurs;
expr_ref eval1(expr* e);
expr_ref eval0(euf::enode* n);
expr_ref eval0(expr* n);
expr_ref eval0rec(expr* n);
expr_ref eval_accessor(func_decl* f, expr* t);
void update_eval_accessor(app* e, expr* t, expr* value);
void del_eval_accessor();
void set_eval0(expr* e, expr* val);
void repair_down_constructor(app* e, expr* v0, expr* v1);
void repair_down_accessor(app* e, expr* t, expr* v1);
void repair_down_recognizer(app* e, expr* t);
void repair_down_eq(app* e, expr* s, expr* t);
void repair_down_distinct(app* e);
void repair_up_accessor(app* e, expr* t, expr* v0);
void propagate_literal_model_building(sat::literal lit);
public:
datatype_plugin(context& c);
~datatype_plugin() override;
family_id fid() override { return m_fid; }
expr_ref get_value(expr* e) override;
void initialize() override;
void start_propagation() override;
void propagate_literal(sat::literal lit) override;
bool propagate() override;
bool is_sat() override;
void register_term(expr* e) override;
bool set_value(expr* e, expr* v) override { return false; }
void repair_literal(sat::literal lit) override {}
bool include_func_interp(func_decl* f) const override;
bool repair_down(app* e) override;
void repair_up(app* e) override;
std::ostream& display(std::ostream& out) const override;
void collect_statistics(statistics& st) const override;
void reset_statistics() override;
};
}

View file

@ -0,0 +1,489 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_euf_plugin.cpp
Abstract:
Congruence Closure for SLS
Author:
Nikolaj Bjorner (nbjorner) 2024-06-24
Todo:
- try incremental CC with backtracking for changing assignments
- try determining plateau moves.
- try generally a model rotation move.
--*/
#include "ast/sls/sls_euf_plugin.h"
#include "ast/ast_ll_pp.h"
#include "ast/ast_pp.h"
#include "params/sls_params.hpp"
namespace sls {
euf_plugin::euf_plugin(context& c):
plugin(c),
m_values(8U, value_hash(*this), value_eq(*this)) {
m_fid = user_sort_family_id;
}
euf_plugin::~euf_plugin() {}
void euf_plugin::initialize() {
sls_params sp(ctx.get_params());
m_incremental_mode = sp.euf_incremental();
m_incremental = 1 == m_incremental_mode;
IF_VERBOSE(2, verbose_stream() << "sls.euf: incremental " << m_incremental_mode << "\n");
}
void euf_plugin::start_propagation() {
if (m_incremental_mode == 2)
m_incremental = !m_incremental;
m_g = alloc(euf::egraph, m);
std::function<void(std::ostream&, void*)> dj = [&](std::ostream& out, void* j) {
out << "lit " << to_literal(reinterpret_cast<size_t*>(j));
};
m_g->set_display_justification(dj);
init_egraph(*m_g, !m_incremental);
}
void euf_plugin::register_term(expr* e) {
if (!is_app(e))
return;
if (!is_uninterp(e))
return;
app* a = to_app(e);
if (a->get_num_args() == 0)
return;
auto f = a->get_decl();
if (!m_app.contains(f))
m_app.insert(f, ptr_vector<app>());
m_app[f].push_back(a);
}
unsigned euf_plugin::value_hash::operator()(app* t) const {
unsigned r = 0;
for (auto arg : *t)
r *= 3, r += cc.ctx.get_value(arg)->hash();
return r;
}
bool euf_plugin::value_eq::operator()(app* a, app* b) const {
SASSERT(a->get_num_args() == b->get_num_args());
for (unsigned i = a->get_num_args(); i-- > 0; )
if (cc.ctx.get_value(a->get_arg(i)) != cc.ctx.get_value(b->get_arg(i)))
return false;
return true;
}
void euf_plugin::propagate_literal_incremental(sat::literal lit) {
m_replay_stack.push_back(lit);
replay();
}
sat::literal euf_plugin::resolve_conflict() {
auto& g = *m_g;
SASSERT(g.inconsistent());
++m_stats.m_num_conflicts;
unsigned n = 0;
sat::literal_vector lits;
sat::literal flit = sat::null_literal;
ptr_vector<size_t> explain;
g.begin_explain();
g.explain<size_t>(explain, nullptr);
g.end_explain();
double reward = -1;
TRACE("enf",
for (auto p : explain) {
sat::literal l = to_literal(p);
tout << l << " " << mk_pp(ctx.atom(l.var()), m) << " " << ctx.is_unit(l) << "\n";
});
for (auto p : explain) {
sat::literal l = to_literal(p);
CTRACE("euf", !ctx.is_true(l), tout << "not true " << l << "\n"; ctx.display(tout););
SASSERT(ctx.is_true(l));
if (ctx.is_unit(l))
continue;
if (!lits.contains(~l))
lits.push_back(~l);
if (ctx.reward(l.var()) > reward)
n = 0, reward = ctx.reward(l.var());
if (ctx.rand(++n) == 0)
flit = l;
}
// flip the last literal on the replay stack
IF_VERBOSE(10, verbose_stream() << "sls.euf - flip " << flit << "\n");
ctx.add_clause(lits);
return flit;
}
void euf_plugin::resolve() {
auto& g = *m_g;
if (!g.inconsistent())
return;
auto flit = resolve_conflict();
sat::literal slit;
if (flit == sat::null_literal)
return;
do {
slit = m_stack.back();
g.pop(1);
m_replay_stack.push_back(slit);
m_stack.pop_back();
}
while (slit != flit);
ctx.flip(flit.var());
m_replay_stack.back().neg();
}
void euf_plugin::replay() {
while (!m_replay_stack.empty()) {
auto l = m_replay_stack.back();
m_replay_stack.pop_back();
propagate_literal_incremental_step(l);
if (m_g->inconsistent())
resolve();
}
}
void euf_plugin::propagate_literal_incremental_step(sat::literal lit) {
SASSERT(ctx.is_true(lit));
auto e = ctx.atom(lit.var());
expr* x, * y;
auto& g = *m_g;
if (!e)
return;
TRACE("euf", tout << "propagate " << lit << "\n");
m_stack.push_back(lit);
g.push();
if (m.is_eq(e, x, y)) {
if (lit.sign())
g.new_diseq(g.find(e), to_ptr(lit));
else
g.merge(g.find(x), g.find(y), to_ptr(lit));
g.merge(g.find(e), g.find(m.mk_bool_val(!lit.sign())), to_ptr(lit));
}
else if (!lit.sign() && m.is_distinct(e)) {
auto n = to_app(e)->get_num_args();
for (unsigned i = 0; i < n; ++i) {
expr* a = to_app(e)->get_arg(i);
for (unsigned j = i + 1; j < n; ++j) {
auto b = to_app(e)->get_arg(j);
expr_ref eq(m.mk_eq(a, b), m);
auto c = g.find(eq);
if (!c) {
euf::enode* args[2] = { g.find(a), g.find(b) };
c = g.mk(eq, 0, 2, args);
}
g.new_diseq(c, to_ptr(lit));
g.merge(c, g.find(m.mk_false()), to_ptr(lit));
}
}
}
// else if (m.is_bool(e) && is_app(e) && to_app(e)->get_family_id() == basic_family_id)
// ;
else {
auto a = g.find(e);
auto b = g.find(m.mk_bool_val(!lit.sign()));
g.merge(a, b, to_ptr(lit));
}
g.propagate();
}
void euf_plugin::propagate_literal(sat::literal lit) {
if (m_incremental)
propagate_literal_incremental(lit);
else
propagate_literal_non_incremental(lit);
}
void euf_plugin::propagate_literal_non_incremental(sat::literal lit) {
SASSERT(ctx.is_true(lit));
auto e = ctx.atom(lit.var());
expr* x, * y;
if (!e)
return;
auto block = [&](euf::enode* a, euf::enode* b) {
TRACE("euf", tout << "block " << m_g->bpp(a) << " != " << m_g->bpp(b) << "\n");
if (a->get_root() != b->get_root())
return;
ptr_vector<size_t> explain;
m_g->explain_eq<size_t>(explain, nullptr, a, b);
m_g->end_explain();
unsigned n = 1;
sat::literal_vector lits;
sat::literal flit = sat::null_literal;
if (!ctx.is_unit(lit)) {
flit = lit;
lits.push_back(~lit);
}
for (auto p : explain) {
sat::literal l = to_literal(p);
if (!ctx.is_true(l))
return;
if (ctx.is_unit(l))
continue;
lits.push_back(~l);
if (ctx.rand(++n) == 0)
flit = l;
}
ctx.add_clause(lits);
++m_stats.m_num_conflicts;
if (flit != sat::null_literal)
ctx.flip(flit.var());
};
if (lit.sign() && m.is_eq(e, x, y))
block(m_g->find(x), m_g->find(y));
else if (!lit.sign() && m.is_distinct(e)) {
auto n = to_app(e)->get_num_args();
for (unsigned i = 0; i < n; ++i) {
auto a = m_g->find(to_app(e)->get_arg(i));
for (unsigned j = i + 1; j < n; ++j) {
auto b = m_g->find(to_app(e)->get_arg(j));
block(a, b);
}
}
}
else if (lit.sign()) {
auto a = m_g->find(e);
auto b = m_g->find(m.mk_true());
block(a, b);
}
}
void euf_plugin::init_egraph(euf::egraph& g, bool merge_eqs) {
ptr_vector<euf::enode> args;
m_stack.reset();
for (auto t : ctx.subterms()) {
args.reset();
if (is_app(t))
for (auto* arg : *to_app(t))
args.push_back(g.find(arg));
g.mk(t, 0, args.size(), args.data());
}
if (!g.find(m.mk_true()))
g.mk(m.mk_true(), 0, 0, nullptr);
if (!g.find(m.mk_false()))
g.mk(m.mk_false(), 0, 0, nullptr);
// merge all equalities
// check for conflict with disequalities during propagation
if (merge_eqs) {
TRACE("euf", tout << "root literals " << ctx.root_literals() << "\n");
for (auto lit : ctx.root_literals()) {
if (!ctx.is_true(lit))
lit.neg();
auto e = ctx.atom(lit.var());
expr* x, * y;
if (e && m.is_eq(e, x, y) && !lit.sign())
g.merge(g.find(x), g.find(y), to_ptr(lit));
else if (!lit.sign())
g.merge(g.find(e), g.find(m.mk_true()), to_ptr(lit));
}
g.propagate();
if (g.inconsistent())
resolve_conflict();
}
typedef obj_map<sort, unsigned> map1;
typedef obj_map<euf::enode, expr*> map2;
m_num_elems = alloc(map1);
m_root2value = alloc(map2);
m_pinned = alloc(expr_ref_vector, m);
for (auto n : g.nodes()) {
if (n->is_root() && is_user_sort(n->get_sort())) {
// verbose_stream() << "init root " << g.pp(n) << "\n";
unsigned num = 0;
m_num_elems->find(n->get_sort(), num);
expr* v = m.mk_model_value(num, n->get_sort());
m_pinned->push_back(v);
m_root2value->insert(n, v);
m_num_elems->insert(n->get_sort(), num + 1);
}
}
}
expr_ref euf_plugin::get_value(expr* e) {
if (m.is_model_value(e))
return expr_ref(e, m);
if (!m_g) {
m_g = alloc(euf::egraph, m);
init_egraph(*m_g, true);
}
auto n = m_g->find(e)->get_root();
VERIFY(m_root2value->find(n, e));
return expr_ref(e, m);
}
bool euf_plugin::include_func_interp(func_decl* f) const {
return is_uninterp(f) && f->get_arity() > 0;
}
bool euf_plugin::is_sat() {
for (auto& [f, ts] : m_app) {
if (ts.size() <= 1)
continue;
m_values.reset();
for (auto* t : ts) {
app* u;
if (!ctx.is_relevant(t))
continue;
if (m_values.find(t, u)) {
if (ctx.get_value(t) != ctx.get_value(u))
return false;
}
else
m_values.insert(t);
}
}
// validate_model();
return true;
}
void euf_plugin::validate_model() {
auto& g = *m_g;
for (auto lit : ctx.root_literals()) {
euf::enode* a, * b;
if (!ctx.is_true(lit))
continue;
auto e = ctx.atom(lit.var());
if (!e)
continue;
if (!ctx.is_relevant(e))
continue;
if (m.is_distinct(e))
continue;
if (m.is_eq(e)) {
a = g.find(to_app(e)->get_arg(0));
b = g.find(to_app(e)->get_arg(1));
}
if (lit.sign() && m.is_eq(e)) {
if (a->get_root() == b->get_root()) {
IF_VERBOSE(0, verbose_stream() << "not disequal " << lit << " " << mk_pp(e, m) << "\n");
ctx.display(verbose_stream());
UNREACHABLE();
}
}
else if (!lit.sign() && m.is_eq(e)) {
if (a->get_root() != b->get_root()) {
IF_VERBOSE(0, verbose_stream() << "not equal " << lit << " " << mk_pp(e, m) << "\n");
//UNREACHABLE();
}
}
else if (to_app(e)->get_family_id() != basic_family_id && lit.sign() && g.find(e)->get_root() != g.find(m.mk_false())->get_root()) {
IF_VERBOSE(0, verbose_stream() << "not alse " << lit << " " << mk_pp(e, m) << "\n");
//UNREACHABLE();
}
else if (to_app(e)->get_family_id() != basic_family_id && !lit.sign() && g.find(e)->get_root() != g.find(m.mk_true())->get_root()) {
IF_VERBOSE(0, verbose_stream() << "not true " << lit << " " << mk_pp(e, m) << "\n");
//UNREACHABLE();
}
}
}
bool euf_plugin::propagate() {
bool new_constraint = false;
for (auto & [f, ts] : m_app) {
if (ts.size() <= 1)
continue;
m_values.reset();
for (auto * t : ts) {
app* u;
if (!ctx.is_relevant(t))
continue;
if (m_values.find(t, u)) {
if (ctx.get_value(t) == ctx.get_value(u))
continue;
expr_ref_vector ors(m);
for (unsigned i = t->get_num_args(); i-- > 0; )
ors.push_back(m.mk_not(m.mk_eq(t->get_arg(i), u->get_arg(i))));
ors.push_back(m.mk_eq(t, u));
#if 0
verbose_stream() << "conflict: " << mk_bounded_pp(t, m) << " != " << mk_bounded_pp(u, m) << "\n";
verbose_stream() << "value " << ctx.get_value(t) << " != " << ctx.get_value(u) << "\n";
for (unsigned i = t->get_num_args(); i-- > 0; )
verbose_stream() << ctx.get_value(t->get_arg(i)) << " == " << ctx.get_value(u->get_arg(i)) << "\n";
#endif
expr_ref fml(m.mk_or(ors), m);
ctx.add_constraint(fml);
new_constraint = true;
}
else
m_values.insert(t);
}
}
for (auto lit : ctx.root_literals()) {
if (!ctx.is_true(lit))
continue;
auto e = ctx.atom(lit.var());
if (lit.sign() && e && m.is_distinct(e)) {
auto n = to_app(e)->get_num_args();
expr_ref_vector eqs(m);
for (unsigned i = 0; i < n; ++i) {
auto a = m_g->find(to_app(e)->get_arg(i));
for (unsigned j = i + 1; j < n; ++j) {
auto b = m_g->find(to_app(e)->get_arg(j));
if (a->get_root() == b->get_root())
goto done_distinct;
eqs.push_back(m.mk_eq(a->get_expr(), b->get_expr()));
}
}
// distinct(a, b, c) or a = b or a = c or b = c
eqs.push_back(e);
ctx.add_constraint(m.mk_or(eqs));
new_constraint = true;
done_distinct:
;
}
}
return new_constraint;
}
std::ostream& euf_plugin::display(std::ostream& out) const {
if (m_g)
m_g->display(out);
for (auto& [f, ts] : m_app) {
for (auto* t : ts)
out << mk_bounded_pp(t, m) << "\n";
out << "\n";
}
return out;
}
void euf_plugin::collect_statistics(statistics& st) const {
st.update("sls-euf-conflict", m_stats.m_num_conflicts);
}
void euf_plugin::reset_statistics() {
m_stats.reset();
}
}

View file

@ -0,0 +1,96 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_euf_plugin.h
Abstract:
Congruence Closure for SLS
Author:
Nikolaj Bjorner (nbjorner) 2024-06-24
--*/
#pragma once
#include "util/hashtable.h"
#include "ast/sls/sls_context.h"
#include "ast/euf/euf_egraph.h"
namespace sls {
class euf_plugin : public plugin {
struct stats {
unsigned m_num_conflicts = 0;
void reset() { memset(this, 0, sizeof(*this)); }
};
obj_map<func_decl, ptr_vector<app>> m_app;
struct value_hash {
euf_plugin& cc;
value_hash(euf_plugin& cc) : cc(cc) {}
unsigned operator()(app* t) const;
};
struct value_eq {
euf_plugin& cc;
value_eq(euf_plugin& cc) : cc(cc) {}
bool operator()(app* a, app* b) const;
};
hashtable<app*, value_hash, value_eq> m_values;
bool m_incremental = false;
unsigned m_incremental_mode = 0;
stats m_stats;
scoped_ptr<euf::egraph> m_g;
scoped_ptr<obj_map<sort, unsigned>> m_num_elems;
scoped_ptr<obj_map<euf::enode, expr*>> m_root2value;
scoped_ptr<expr_ref_vector> m_pinned;
void init_egraph(euf::egraph& g, bool merge_eqs);
sat::literal_vector m_stack, m_replay_stack;
void propagate_literal_incremental(sat::literal lit);
void propagate_literal_incremental_step(sat::literal lit);
void resolve();
sat::literal resolve_conflict();
void replay();
void propagate_literal_non_incremental(sat::literal lit);
bool is_user_sort(sort* s) { return s->get_family_id() == user_sort_family_id; }
size_t* to_ptr(sat::literal l) { return reinterpret_cast<size_t*>((size_t)(l.index() << 4)); };
sat::literal to_literal(size_t* p) { return sat::to_literal(static_cast<unsigned>(reinterpret_cast<size_t>(p) >> 4)); };
void validate_model();
public:
euf_plugin(context& c);
~euf_plugin() override;
expr_ref get_value(expr* e) override;
void initialize() override;
void start_propagation() override;
void propagate_literal(sat::literal lit) override;
bool propagate() override;
bool is_sat() override;
void register_term(expr* e) override;
std::ostream& display(std::ostream& out) const override;
bool set_value(expr* e, expr* v) override { return false; }
bool include_func_interp(func_decl* f) const override;
void repair_up(app* e) override {}
bool repair_down(app* e) override { return false; }
void repair_literal(sat::literal lit) override {}
void collect_statistics(statistics& st) const override;
void reset_statistics() override;
scoped_ptr<euf::egraph>& egraph() { return m_g; }
};
}

View file

@ -0,0 +1,315 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_smt_plugin.cpp
Abstract:
A Stochastic Local Search (SLS) Plugin.
Author:
Nikolaj Bjorner (nbjorner) 2024-07-10
--*/
#include "ast/sls/sls_smt_plugin.h"
#include "ast/for_each_expr.h"
#include "ast/bv_decl_plugin.h"
namespace sls {
smt_plugin::smt_plugin(smt_context& ctx) :
ctx(ctx),
m(ctx.get_manager()),
m_sls(),
m_sync(),
m_smt2sync_tr(m, m_sync),
m_smt2sls_tr(m, m_sls),
m_sync_uninterp(m_sync),
m_sls_uninterp(m_sls),
m_sync_values(m_sync),
m_context(m_sls, *this)
{
}
smt_plugin::~smt_plugin() {
SASSERT(!m_ddfw);
}
void smt_plugin::check(expr_ref_vector const& fmls, vector <sat::literal_vector> const& clauses) {
SASSERT(!m_ddfw);
// set up state for local search theory_sls here
m_result = l_undef;
m_completed = false;
m_units.reset();
m_has_units = false;
m_sls_model = nullptr;
m_ddfw = alloc(sat::ddfw);
m_ddfw->set_plugin(this);
m_ddfw->updt_params(ctx.get_params());
for (auto const& clause : clauses) {
m_ddfw->add(clause.size(), clause.data());
for (auto lit : clause)
add_shared_var(lit.var(), lit.var());
}
for (auto v : m_shared_bool_vars) {
expr* e = ctx.bool_var2expr(v);
if (!e)
continue;
m_context.register_atom(v, m_smt2sls_tr(e));
for (auto t : subterms::all(expr_ref(e, m)))
add_shared_term(t);
}
for (auto fml : fmls)
m_context.add_constraint(m_smt2sls_tr(fml));
for (unsigned v = 0; v < ctx.get_num_bool_vars(); ++v) {
expr* e = ctx.bool_var2expr(v);
if (!e)
continue;
expr_ref sls_e(m_sls);
sls_e = m_smt2sls_tr(e);
auto w = m_context.atom2bool_var(sls_e);
if (w == sat::null_bool_var)
continue;
add_shared_var(v, w);
for (auto t : subterms::all(expr_ref(e, m)))
add_shared_term(t);
}
m_thread = std::thread([this]() { run(); });
}
void smt_plugin::run() {
if (!m_ddfw)
return;
m_result = m_ddfw->check(0, nullptr);
m_ddfw->collect_statistics(m_st);
IF_VERBOSE(1, verbose_stream() << "sls-result " << m_result << "\n");
m_completed = true;
}
void smt_plugin::finalize(model_ref& mdl, ::statistics& st) {
auto* d = m_ddfw;
if (!d)
return;
bool canceled = !m_completed;
IF_VERBOSE(3, verbose_stream() << "finalize\n");
if (!m_completed)
d->rlimit().cancel();
if (m_thread.joinable())
m_thread.join();
SASSERT(m_completed);
st.copy(m_st);
mdl = nullptr;
if (m_result == l_true && m_sls_model) {
ast_translation tr(m_sls, m);
mdl = m_sls_model->translate(tr);
TRACE("sls", tout << "model: " << *m_sls_model << "\n";);
if (!canceled)
ctx.set_finished();
}
m_ddfw = nullptr;
// m_ddfw owns the pointer to smt_plugin and destructs it.
dealloc(d);
}
std::ostream& smt_plugin::display(std::ostream& out) {
m_ddfw->display(out);
m_context.display(out);
return out;
}
bool smt_plugin::is_shared(sat::literal lit) {
auto w = m_smt_bool_var2sls_bool_var.get(lit.var(), sat::null_bool_var);
if (w != sat::null_bool_var)
return true;
auto e = ctx.bool_var2expr(lit.var());
expr* t = nullptr;
if (!e)
return false;
bv_util bv(m);
if (bv.is_bit2bool(e, t) && m_shared_terms.contains(t->get_id())) {
verbose_stream() << "shared bit2bool " << mk_bounded_pp(e, ctx.get_manager()) << "\n";
return true;
}
// if arith.is_le(e, s, t) && t is a numeral, s is shared-term....
return false;
}
void smt_plugin::add_shared_var(sat::bool_var v, sat::bool_var w) {
m_smt_bool_var2sls_bool_var.setx(v, w, sat::null_bool_var);
m_sls_bool_var2smt_bool_var.setx(w, v, sat::null_bool_var);
m_sls_phase.reserve(v + 1);
m_sat_phase.reserve(v + 1);
m_rewards.reserve(v + 1);
m_shared_bool_vars.insert(v);
}
void smt_plugin::add_unit(sat::literal lit) {
if (!is_shared(lit))
return;
std::lock_guard<std::mutex> lock(m_mutex);
m_units.push_back(lit);
m_has_units = true;
}
void smt_plugin::import_phase_from_smt() {
if (m_has_new_sat_phase)
return;
m_has_new_sat_phase = true;
IF_VERBOSE(3, verbose_stream() << "new SMT -> SLS phase\n");
ctx.set_has_new_best_phase(false);
std::lock_guard<std::mutex> lock(m_mutex);
for (auto v : m_shared_bool_vars)
m_sat_phase[v] = ctx.get_best_phase(v);
}
bool smt_plugin::export_to_sls() {
bool updated = false;
if (export_units_to_sls())
updated = true;
if (export_phase_to_sls())
updated = true;
return updated;
}
bool smt_plugin::export_phase_to_sls() {
if (!m_has_new_sat_phase)
return false;
std::lock_guard<std::mutex> lock(m_mutex);
IF_VERBOSE(3, verbose_stream() << "SMT -> SLS phase\n");
for (auto v : m_shared_bool_vars) {
auto w = m_smt_bool_var2sls_bool_var[v];
if (m_sat_phase[v] != is_true(sat::literal(w, false)))
flip(w);
m_ddfw->bias(w) = m_sat_phase[v] ? 1 : -1;
}
m_has_new_sat_phase = false;
return true;
}
bool smt_plugin::export_units_to_sls() {
if (!m_has_units)
return false;
std::lock_guard<std::mutex> lock(m_mutex);
IF_VERBOSE(2, verbose_stream() << "SMT -> SLS units " << m_units << "\n");
for (auto lit : m_units) {
auto v = lit.var();
if (m_shared_bool_vars.contains(v)) {
auto w = m_smt_bool_var2sls_bool_var[v];
sat::literal sls_lit(w, lit.sign());
IF_VERBOSE(10, verbose_stream() << "unit " << sls_lit << "\n");
m_ddfw->add(1, &sls_lit);
}
else {
IF_VERBOSE(0, verbose_stream() << "value restriction " << lit << " "
<< mk_bounded_pp(ctx.bool_var2expr(lit.var()), m) << "\n");
}
}
m_has_units = false;
m_units.reset();
return true;
}
void smt_plugin::export_from_sls() {
if (unsat().size() > m_min_unsat_size)
return;
m_min_unsat_size = unsat().size();
std::lock_guard<std::mutex> lock(m_mutex);
for (auto v : m_shared_bool_vars) {
auto w = m_smt_bool_var2sls_bool_var[v];
m_rewards[v] = m_ddfw->get_reward_avg(w);
//verbose_stream() << v << " " << w << "\n";
VERIFY(m_ddfw->get_model().size() > w);
VERIFY(m_sls_phase.size() > v);
m_sls_phase[v] = l_true == m_ddfw->get_model()[w];
m_has_new_sls_phase = true;
}
// export_values_from_sls();
}
void smt_plugin::export_values_from_sls() {
IF_VERBOSE(3, verbose_stream() << "import values from sls\n");
std::lock_guard<std::mutex> lock(m_mutex);
for (auto const& [t, t_sync] : m_sls2sync_uninterp) {
expr_ref val_t = m_context.get_value(t_sync);
m_sync_values.set(t_sync->get_id(), val_t.get());
}
m_has_new_sls_values = true;
}
void smt_plugin::import_from_sls() {
export_activity_to_smt();
export_values_to_smt();
export_phase_to_smt();
}
void smt_plugin::export_activity_to_smt() {
}
void smt_plugin::export_values_to_smt() {
if (!m_has_new_sls_values)
return;
IF_VERBOSE(3, verbose_stream() << "SLS -> SMT values\n");
std::lock_guard<std::mutex> lock(m_mutex);
ast_translation tr(m_sync, m);
for (auto const& [t, t_sync] : m_smt2sync_uninterp) {
expr* sync_val = m_sync_values.get(t_sync->get_id(), nullptr);
if (!sync_val)
continue;
expr_ref val(tr(sync_val), m);
ctx.initialize_value(t, val);
}
m_has_new_sls_values = false;
}
void smt_plugin::export_phase_to_smt() {
if (!m_has_new_sls_phase)
return;
std::lock_guard<std::mutex> lock(m_mutex);
IF_VERBOSE(3, verbose_stream() << "SLS -> SMT phase\n");
for (auto v : m_shared_bool_vars) {
auto w = m_smt_bool_var2sls_bool_var[v];
ctx.force_phase(sat::literal(w, m_sls_phase[v]));
}
m_has_new_sls_phase = false;
}
void smt_plugin::add_shared_term(expr* t) {
m_shared_terms.insert(t->get_id());
if (is_uninterp(t))
add_uninterp(t);
}
void smt_plugin::add_uninterp(expr* smt_t) {
auto sync_t = m_smt2sync_tr(smt_t);
auto sls_t = m_smt2sls_tr(smt_t);
m_sync_uninterp.push_back(sync_t);
m_sls_uninterp.push_back(sls_t);
m_smt2sync_uninterp.insert(smt_t, sync_t);
m_sls2sync_uninterp.insert(sls_t, sync_t);
}
void smt_plugin::on_save_model() {
TRACE("sls", display(tout));
while (unsat().empty()) {
m_context.check();
if (!m_new_clause_added)
break;
m_ddfw->reinit();
m_new_clause_added = false;
}
// export_from_sls();
}
}

View file

@ -0,0 +1,158 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_smt_plugin.h
Abstract:
A Stochastic Local Search (SLS) Plugin.
Author:
Nikolaj Bjorner (nbjorner) 2024-07-10
--*/
#pragma once
#include "ast/sls/sls_context.h"
#include "ast/sls/sat_ddfw.h"
#include "util/statistics.h"
#include <thread>
#include <mutex>
namespace sls {
class smt_context {
public:
virtual ~smt_context() {}
virtual ast_manager& get_manager() = 0;
virtual params_ref get_params() = 0;
virtual void initialize_value(expr* t, expr* v) = 0;
virtual void force_phase(sat::literal lit) = 0;
virtual void set_has_new_best_phase(bool b) = 0;
virtual bool get_best_phase(sat::bool_var v) = 0;
virtual expr* bool_var2expr(sat::bool_var v) = 0;
virtual void set_finished() = 0;
virtual unsigned get_num_bool_vars() const = 0;
};
//
// m is accessed by the main thread
// m_sls is accessed by the sls thread
// m_sync is accessed by both
//
class smt_plugin : public sat::local_search_plugin, public sat_solver_context {
smt_context& ctx;
ast_manager& m;
ast_manager m_sls;
ast_manager m_sync;
ast_translation m_smt2sync_tr, m_smt2sls_tr;
expr_ref_vector m_sync_uninterp, m_sls_uninterp;
expr_ref_vector m_sync_values;
sat::ddfw* m_ddfw = nullptr;
sls::context m_context;
std::atomic<lbool> m_result;
std::atomic<bool> m_completed, m_has_units;
std::thread m_thread;
std::mutex m_mutex;
sat::literal_vector m_units;
model_ref m_sls_model;
::statistics m_st;
bool m_new_clause_added = false;
unsigned m_min_unsat_size = UINT_MAX;
obj_map<expr, expr*> m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp
obj_map<expr, expr*> m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp
std::atomic<bool> m_has_new_sls_values = false;
uint_set m_shared_bool_vars, m_shared_terms;
svector<bool> m_sat_phase;
std::atomic<bool> m_has_new_sat_phase = false;
std::atomic<bool> m_has_new_sls_phase = false;
svector<bool> m_sls_phase;
svector<double> m_rewards;
svector<sat::bool_var> m_smt_bool_var2sls_bool_var, m_sls_bool_var2smt_bool_var;
bool is_shared(sat::literal lit);
void run();
void add_shared_term(expr* t);
void add_uninterp(expr* smt_t);
void add_shared_var(sat::bool_var v, sat::bool_var w);
void import_phase_from_smt();
void import_values_from_sls();
void export_values_from_sls();
void import_activity_from_sls();
bool export_phase_to_sls();
bool export_units_to_sls();
void export_values_to_smt();
void export_activity_to_smt();
void export_phase_to_smt();
void export_from_sls();
friend class sat::ddfw;
~smt_plugin();
public:
smt_plugin(smt_context& ctx);
// interface to calling solver:
void check(expr_ref_vector const& fmls, vector <sat::literal_vector> const& clauses);
void finalize(model_ref& md, ::statistics& st);
void updt_params(params_ref& p) {}
std::ostream& display(std::ostream& out) override;
bool export_to_sls();
void import_from_sls();
bool completed() { return m_completed; }
void add_unit(sat::literal lit);
// local_search_plugin:
void on_restart() override {
if (export_to_sls())
m_ddfw->reinit();
}
void on_save_model() override;
void on_model(model_ref& mdl) override {
IF_VERBOSE(3, verbose_stream() << "on-model " << "\n");
m_sls_model = mdl;
}
void init_search() override {}
void finish_search() override {}
void on_rescale() override {}
// sat_solver_context:
vector<sat::clause_info> const& clauses() const override { return m_ddfw->clauses(); }
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); }
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); }
void flip(sat::bool_var v) override {
m_ddfw->flip(v);
}
double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); }
double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; }
bool is_true(sat::literal lit) override {
return m_ddfw->get_value(lit.var()) != lit.sign();
}
unsigned num_vars() const override { return m_ddfw->num_vars(); }
indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); }
sat::bool_var add_var() override {
return m_ddfw->add_var();
}
void add_clause(unsigned n, sat::literal const* lits) override {
m_ddfw->add(n, lits);
m_new_clause_added = true;
}
void force_restart() override { m_ddfw->force_restart(); }
};
}

View file

@ -0,0 +1,171 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_smt_solver.cpp
Abstract:
A Stochastic Local Search (SLS) Solver.
Author:
Nikolaj Bjorner (nbjorner) 2024-07-10
--*/
#include "ast/sls/sls_context.h"
#include "ast/sls/sat_ddfw.h"
#include "ast/sls/sls_smt_solver.h"
#include "ast/ast_ll_pp.h"
namespace sls {
class smt_solver::solver_ctx : public sat::local_search_plugin, public sls::sat_solver_context {
ast_manager& m;
sat::ddfw& m_ddfw;
context m_context;
bool m_dirty = false;
bool m_new_constraint = false;
model_ref m_model;
obj_map<expr, sat::literal> m_expr2lit;
public:
solver_ctx(ast_manager& m, sat::ddfw& d) :
m(m), m_ddfw(d), m_context(m, *this) {
m_ddfw.set_plugin(this);
m.limit().push_child(&m_ddfw.rlimit());
}
~solver_ctx() override {
m.limit().pop_child(&m_ddfw.rlimit());
}
void init_search() override {}
void finish_search() override {}
void on_rescale() override {}
void on_restart() override {
m_context.on_restart();
}
bool m_on_save_model = false;
void on_save_model() override {
if (m_on_save_model)
return;
flet<bool> _on_save_model(m_on_save_model, true);
CTRACE("sls", unsat().empty(), display(tout));
while (unsat().empty()) {
m_context.check();
if (!m_new_constraint)
break;
TRACE("sls", display(tout));
//m_ddfw.simplify();
m_ddfw.reinit();
m_new_constraint = false;
}
}
void on_model(model_ref& mdl) override {
IF_VERBOSE(1, verbose_stream() << "on-model " << "\n");
m_model = mdl;
}
void register_atom(sat::bool_var v, expr* e) {
m_context.register_atom(v, e);
}
std::ostream& display(std::ostream& out) override {
m_ddfw.display(out);
m_context.display(out);
return out;
}
vector<sat::clause_info> const& clauses() const override { return m_ddfw.clauses(); }
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); }
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); }
void flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; m_ddfw.flip(v); }
double reward(sat::bool_var v) override { return m_ddfw.get_reward(v); }
double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; }
bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); }
unsigned num_vars() const override { return m_ddfw.num_vars(); }
indexed_uint_set const& unsat() const override { return m_ddfw.unsat_set(); }
sat::bool_var add_var() override { m_dirty = true; return m_ddfw.add_var(); }
void add_clause(expr* f) { m_context.add_clause(f); }
void force_restart() override { m_ddfw.force_restart(); }
void add_clause(unsigned n, sat::literal const* lits) override {
m_ddfw.add(n, lits);
m_new_constraint = true;
}
sat::literal mk_literal() {
sat::bool_var v = add_var();
return sat::literal(v, false);
}
model_ref get_model() { return m_model; }
void collect_statistics(statistics& st) {
m_ddfw.collect_statistics(st);
m_context.collect_statistics(st);
}
void reset_statistics() {
m_ddfw.reset_statistics();
m_context.reset_statistics();
}
void updt_params(params_ref const& p) {
m_ddfw.updt_params(p);
m_context.updt_params(p);
}
};
smt_solver::smt_solver(ast_manager& m, params_ref const& p):
m(m),
m_solver_ctx(alloc(solver_ctx, m, m_ddfw)),
m_assertions(m) {
m_solver_ctx->updt_params(p);
}
smt_solver::~smt_solver() {
}
void smt_solver::assert_expr(expr* e) {
if (m.is_and(e)) {
for (expr* arg : *to_app(e))
assert_expr(arg);
}
else
m_assertions.push_back(e);
}
lbool smt_solver::check() {
for (auto f : m_assertions)
m_solver_ctx->add_clause(f);
IF_VERBOSE(10, m_solver_ctx->display(verbose_stream()));
return m_ddfw.check(0, nullptr);
}
model_ref smt_solver::get_model() {
return m_solver_ctx->get_model();
}
std::ostream& smt_solver::display(std::ostream& out) {
return m_solver_ctx->display(out);
}
void smt_solver::collect_statistics(statistics& st) {
m_solver_ctx->collect_statistics(st);
}
void smt_solver::reset_statistics() {
m_solver_ctx->reset_statistics();
}
}

View file

@ -0,0 +1,44 @@
/*++
Copyright (c) 2024 Microsoft Corporation
Module Name:
sls_smt_solver.h
Abstract:
A Stochastic Local Search (SLS) Solver.
Author:
Nikolaj Bjorner (nbjorner) 2024-07-10
--*/
#pragma once
#include "ast/sls/sls_context.h"
#include "ast/sls/sat_ddfw.h"
namespace sls {
class smt_solver {
ast_manager& m;
class solver_ctx;
sat::ddfw m_ddfw;
solver_ctx* m_solver_ctx = nullptr;
expr_ref_vector m_assertions;
statistics m_st;
public:
smt_solver(ast_manager& m, params_ref const& p);
~smt_solver();
void assert_expr(expr* e);
lbool check();
model_ref get_model();
void updt_params(params_ref& p) {}
void collect_statistics(statistics& st);
std::ostream& display(std::ostream& out);
void reset_statistics();
};
}