3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-22 16:45:31 +00:00

move sat_params to params directory, add op_def repair options

This commit is contained in:
Nikolaj Bjorner 2024-07-06 13:26:39 -07:00
parent 3ff60a4af0
commit 833f524887
22 changed files with 298 additions and 58 deletions

View file

@ -365,6 +365,7 @@ public:
MATCH_BINARY(is_div0);
MATCH_BINARY(is_idiv0);
MATCH_BINARY(is_power);
MATCH_BINARY(is_power0);
MATCH_UNARY(is_sin);
MATCH_UNARY(is_asin);

View file

@ -7,7 +7,7 @@ Module Name:
Abstract:
Local search dispatch for NIA
Local search dispatch for arithmetic
Author:
@ -41,10 +41,6 @@ namespace sls {
check_ineqs();
}
template<typename num_t>
void arith_base<num_t>::store_best_values() {
}
// distance to true
template<typename num_t>
num_t arith_base<num_t>::dtt(bool sign, num_t const& args, ineq const& ineq) const {
@ -111,7 +107,7 @@ namespace sls {
template<typename num_t>
num_t arith_base<num_t>::divide(var_t v, num_t const& delta, num_t const& coeff) {
if (m_vars[v].m_kind == var_kind::REAL)
if (m_vars[v].m_sort == var_sort::REAL)
return delta / coeff;
return div(delta + abs(coeff) - 1, coeff);
}
@ -346,7 +342,7 @@ namespace sls {
if (value(v) != sum)
m_vars_to_update.push_back({ v, sum });
}
if (vi.m_add_idx != UINT_MAX || vi.m_mul_idx != UINT_MAX)
if (vi.m_def_idx != UINT_MAX)
// add repair actions for additions and multiplications
m_defs_to_update.push_back(v);
}
@ -364,7 +360,6 @@ namespace sls {
ineq.m_args.push_back({ c, v });
}
bool arith_base<checked_int64<true>>::is_num(expr* e, checked_int64<true>& i) {
rational r;
if (a.is_numeral(e, r)) {
@ -390,10 +385,10 @@ namespace sls {
auto v = m_expr2var.get(e->get_id(), UINT_MAX);
expr* x, * y;
num_t i;
if (v != UINT_MAX)
add_arg(term, coeff, v);
else if (is_num(e, i))
term.m_coeff += coeff * i;
if (v != UINT_MAX)
add_arg(term, coeff, v);
else if (is_num(e, i))
term.m_coeff += coeff * i;
else if (a.is_add(e)) {
for (expr* arg : *to_app(e))
add_args(term, arg, coeff);
@ -424,19 +419,80 @@ namespace sls {
num_t prod(1);
for (auto w : m)
m_vars[w].m_muls.push_back(idx), prod *= value(w);
m_vars[v].m_mul_idx = idx;
m_vars[v].m_def_idx = idx;
m_vars[v].m_op = arith_op_kind::OP_MUL;
m_vars[v].m_value = prod;
add_arg(term, c, v);
break;
}
}
}
else if (a.is_uminus(e, x))
add_args(term, x, -coeff);
else if (is_uninterp(e))
else if (a.is_uminus(e, x))
add_args(term, x, -coeff);
else if (a.is_mod(e, x, y) || a.is_mod0(e, x, y))
add_arg(term, coeff, mk_op(arith_op_kind::OP_MOD, e, x, y));
else if (a.is_idiv(e, x, y) || a.is_idiv0(e, x, y))
add_arg(term, coeff, mk_op(arith_op_kind::OP_IDIV, e, x, y));
else if (a.is_div(e, x, y) || a.is_div0(e, x, y))
add_arg(term, coeff, mk_op(arith_op_kind::OP_DIV, e, x, y));
else if (a.is_rem(e, x, y))
add_arg(term, coeff, mk_op(arith_op_kind::OP_REM, e, x, y));
else if (a.is_power(e, x, y) || a.is_power0(e, x, y))
add_arg(term, coeff, mk_op(arith_op_kind::OP_POWER, e, x, y));
else if (a.is_abs(e, x))
add_arg(term, coeff, mk_op(arith_op_kind::OP_ABS, e, x, x));
else if (a.is_to_int(e, x))
add_arg(term, coeff, mk_op(arith_op_kind::OP_TO_INT, e, x, x));
else if (a.is_to_real(e, x))
add_arg(term, coeff, mk_op(arith_op_kind::OP_TO_REAL, e, x, x));
else if (is_uninterp(e))
add_arg(term, coeff, mk_var(e));
else
else if (a.is_arith_expr(e)) {
NOT_IMPLEMENTED_YET();
}
else {
NOT_IMPLEMENTED_YET();
}
}
template<typename num_t>
typename arith_base<num_t>::var_t arith_base<num_t>::mk_op(arith_op_kind k, expr* e, expr* x, expr* y) {
auto v = mk_var(e);
auto w = mk_term(x);
auto u = mk_term(y);
unsigned idx = m_ops.size();
num_t val;
switch (k) {
case arith_op_kind::OP_MOD:
if (value(v) != 0)
val = mod(value(w), value(v));
break;
case arith_op_kind::OP_REM:
if (value(v) != 0) {
val = value(w);
val %= value(v);
}
break;
case arith_op_kind::OP_IDIV:
if (value(v) != 0)
val = div(value(w), value(v));
break;
case arith_op_kind::OP_DIV:
if (value(v) != 0)
val = value(w) / value(v);
break;
case arith_op_kind::OP_ABS:
val = abs(value(w));
break;
default:
NOT_IMPLEMENTED_YET();
break;
}
m_ops.push_back({v, k, v, w});
m_vars[v].m_def_idx = idx;
m_vars[v].m_op = k;
m_vars[v].m_value = val;
return v;
}
template<typename num_t>
@ -454,18 +510,19 @@ namespace sls {
m_adds.push_back({ t.m_args, t.m_coeff, v });
for (auto const& [c, w] : t.m_args)
m_vars[w].m_adds.push_back(idx), sum += c * value(w);
m_vars[v].m_add_idx = idx;
m_vars[v].m_def_idx = idx;
m_vars[v].m_op = arith_op_kind::OP_ADD;
m_vars[v].m_value = sum;
return v;
}
template<typename num_t>
unsigned arith_base<num_t>::mk_var(expr* e) {
unsigned v = m_expr2var.get(e->get_id(), UINT_MAX);
typename arith_base<num_t>::var_t arith_base<num_t>::mk_var(expr* e) {
var_t v = m_expr2var.get(e->get_id(), UINT_MAX);
if (v == UINT_MAX) {
v = m_vars.size();
m_expr2var.setx(e->get_id(), v, UINT_MAX);
m_vars.push_back(var_info(e, a.is_int(e) ? var_kind::INT : var_kind::REAL));
m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL));
}
return v;
}
@ -504,6 +561,14 @@ namespace sls {
add_args(ineq, y, num_t(-1));
init_ineq(bv, ineq);
}
else if (a.is_is_int(e, x))
{
NOT_IMPLEMENTED_YET();
}
#if 0
else if (a.is_idivides(e, x, y))
NOT_IMPLEMENTED_YET();
#endif
else {
SASSERT(!a.is_arith_expr(e));
}
@ -562,10 +627,42 @@ namespace sls {
auto v = m_defs_to_update.back();
m_defs_to_update.pop_back();
auto const& vi = m_vars[v];
if (vi.m_mul_idx != UINT_MAX)
repair_mul(m_muls[vi.m_mul_idx]);
if (vi.m_add_idx != UINT_MAX)
repair_add(m_adds[vi.m_add_idx]);
switch (vi.m_op) {
case arith_op_kind::LAST_ARITH_OP:
break;
case arith_op_kind::OP_ADD:
repair_add(m_adds[vi.m_def_idx]);
break;
case arith_op_kind::OP_MUL:
repair_mul(m_muls[vi.m_def_idx]);
break;
case arith_op_kind::OP_MOD:
repair_mod(m_ops[vi.m_def_idx]);
break;
case arith_op_kind::OP_REM:
repair_rem(m_ops[vi.m_def_idx]);
break;
case arith_op_kind::OP_POWER:
repair_power(m_ops[vi.m_def_idx]);
break;
case arith_op_kind::OP_IDIV:
repair_idiv(m_ops[vi.m_def_idx]);
break;
case arith_op_kind::OP_DIV:
repair_div(m_ops[vi.m_def_idx]);
break;
case arith_op_kind::OP_ABS:
repair_abs(m_ops[vi.m_def_idx]);
break;
case arith_op_kind::OP_TO_INT:
repair_to_int(m_ops[vi.m_def_idx]);
break;
case arith_op_kind::OP_TO_REAL:
repair_to_real(m_ops[vi.m_def_idx]);
break;
default:
NOT_IMPLEMENTED_YET();
}
}
}
@ -584,7 +681,7 @@ namespace sls {
else {
auto const& [c, w] = coeffs[rand() % coeffs.size()];
num_t delta = sum - val;
bool is_real = m_vars[w].m_kind == var_kind::REAL;
bool is_real = m_vars[w].m_sort == var_sort::REAL;
bool round_down = rand() % 2 == 0;
num_t new_value = value(w) + (is_real ? delta / c : round_down ? div(delta, c) : div(delta + c - 1, c));
update(w, new_value);
@ -627,7 +724,7 @@ namespace sls {
auto w = md.m_monomial[rand() % md.m_monomial.size()];
auto old_value = value(w);
num_t new_value;
if (m_vars[w].m_kind == var_kind::REAL)
if (m_vars[w].m_sort == var_sort::REAL)
new_value = old_value * val / product;
else
new_value = divide(w, old_value * val, product);
@ -650,6 +747,112 @@ namespace sls {
}
}
template<typename num_t>
void arith_base<num_t>::repair_rem(op_def const& od) {
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
auto v2 = value(od.m_arg2);
if (v2 == 0)
return;
IF_VERBOSE(0, verbose_stream() << "todo repair rem");
// bail
v1 %= v2;
update(od.m_var, v1);
}
template<typename num_t>
void arith_base<num_t>::repair_abs(op_def const& od) {
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
if (val < 0)
update(od.m_var, abs(v1));
else if (rand() % 2 == 0)
update(od.m_arg1, val);
else
update(od.m_arg1, -val);
}
template<typename num_t>
void arith_base<num_t>::repair_to_int(op_def const& od) {
NOT_IMPLEMENTED_YET();
}
template<typename num_t>
void arith_base<num_t>::repair_to_real(op_def const& od) {
if (rand() % 20 == 0)
update(od.m_var, value(od.m_arg1));
else
update(od.m_arg1, value(od.m_arg1));
}
template<typename num_t>
void arith_base<num_t>::repair_power(op_def const& od) {
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
auto v2 = value(od.m_arg2);
if (v1 == 0 && v2 == 0)
return;
IF_VERBOSE(0, verbose_stream() << "todo repair ^");
NOT_IMPLEMENTED_YET();
}
template<typename num_t>
void arith_base<num_t>::repair_mod(op_def const& od) {
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
auto v2 = value(od.m_arg2);
// repair first argument
if (val >= 0 && val < v2) {
auto v3 = mod(v1, v2);
if (v3 == val)
return;
// find r, such that mod(v1 + r, v2) = val
// v1 := v1 + val - v3 (+/- v2)
v1 += val - v3;
switch (rand() % 6) {
case 0:
v1 += v2;
break;
case 1:
v1 -= v2;
break;
default:
break;
}
update(od.m_arg1, v1);
return;
}
if (v2 == 0)
return;
// bail
update(od.m_var, mod(v1, v2));
}
template<typename num_t>
void arith_base<num_t>::repair_idiv(op_def const& od) {
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
auto v2 = value(od.m_arg2);
if (v2 == 0)
return;
IF_VERBOSE(0, verbose_stream() << "todo repair div");
// bail
update(od.m_var, div(v1, v2));
}
template<typename num_t>
void arith_base<num_t>::repair_div(op_def const& od) {
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
auto v2 = value(od.m_arg2);
if (v2 == 0)
return;
IF_VERBOSE(0, verbose_stream() << "todo repair /");
// bail
update(od.m_var, v1 / v2);
}
template<typename num_t>
double arith_base<num_t>::reward(sat::literal lit) {
if (m_dscore_mode)
@ -819,6 +1022,10 @@ namespace sls {
out << ad.m_coeff;
out << "\n";
}
for (auto od : m_ops) {
out << "v" << od.m_var << " := ";
out << "v" << od.m_arg1 << " op-" << od.m_op << " v" << od.m_arg2 << "\n";
}
return out;
}

View file

@ -3,7 +3,7 @@ Copyright (c) 2020 Microsoft Corporation
Module Name:
arith_local_search.h
sls_arith_base.h
Abstract:
@ -30,7 +30,7 @@ namespace sls {
template<typename num_t>
class arith_base : public plugin {
enum class ineq_kind { EQ, LE, LT};
enum class var_kind { INT, REAL };
enum class var_sort { INT, REAL };
typedef unsigned var_t;
typedef unsigned atom_t;
@ -86,13 +86,13 @@ namespace sls {
private:
struct var_info {
var_info(expr* e, var_kind k): m_expr(e), m_kind(k) {}
var_info(expr* e, var_sort k): m_expr(e), m_sort(k) {}
expr* m_expr;
num_t m_value{ 0 };
num_t m_best_value{ 0 };
var_kind m_kind;
unsigned m_add_idx = UINT_MAX;
unsigned m_mul_idx = UINT_MAX;
var_sort m_sort;
arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP;
unsigned m_def_idx = UINT_MAX;
vector<std::pair<num_t, sat::bool_var>> m_bool_vars;
unsigned_vector m_muls;
unsigned_vector m_adds;
@ -106,6 +106,12 @@ namespace sls {
struct add_def : public linear_term {
unsigned m_var;
};
struct op_def {
unsigned m_var;
arith_op_kind m_op;
unsigned m_arg1, m_arg2;
};
stats m_stats;
config m_config;
@ -113,16 +119,25 @@ namespace sls {
vector<var_info> m_vars;
vector<mul_def> m_muls;
vector<add_def> m_adds;
vector<op_def> m_ops;
unsigned_vector m_expr2var;
bool m_dscore_mode = false;
arith_util a;
unsigned_vector m_defs_to_update;
vector<std::pair<var_t, num_t>> m_vars_to_update;
unsigned get_num_vars() const { return m_vars.size(); }
void repair_mul(mul_def const& md);
void repair_add(add_def const& ad);
unsigned_vector m_defs_to_update;
vector<std::pair<var_t, num_t>> m_vars_to_update;
void repair_mod(op_def const& od);
void repair_idiv(op_def const& od);
void repair_div(op_def const& od);
void repair_rem(op_def const& od);
void repair_power(op_def const& od);
void repair_abs(op_def const& od);
void repair_to_int(op_def const& od);
void repair_to_real(op_def const& od);
void repair_defs_and_updates();
void repair_defs();
void repair_updates();
@ -149,12 +164,13 @@ namespace sls {
double dtt_reward(sat::literal lit);
double dscore(var_t v, num_t const& new_value) const;
void save_best_values();
void store_best_values();
unsigned mk_var(expr* e);
ineq& new_ineq(ineq_kind op, num_t const& bound);
var_t mk_var(expr* e);
var_t mk_term(expr* e);
var_t mk_op(arith_op_kind k, expr* e, expr* x, expr* y);
void add_arg(linear_term& term, num_t const& c, var_t v);
void add_args(linear_term& term, expr* e, num_t const& sign);
var_t mk_term(expr* e);
ineq& new_ineq(ineq_kind op, num_t const& bound);
void init_ineq(sat::bool_var bv, ineq& i);
num_t divide(var_t v, num_t const& delta, num_t const& coeff);

View file

@ -42,7 +42,7 @@ Notes:
#include "ast/converters/generic_model_converter.h"
#include "ackermannization/ackermannize_bv_tactic.h"
#include "sat/sat_solver/inc_sat_solver.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
#include "opt/opt_context.h"
#include "opt/opt_solver.h"
#include "opt/opt_params.hpp"

View file

@ -20,7 +20,7 @@ Author:
#include "ast/pb_decl_plugin.h"
#include "opt/maxsmt.h"
#include "opt/opt_lns.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
#include <algorithm>
namespace opt {

View file

@ -14,6 +14,7 @@ z3_add_component(params
pattern_inference_params_helper.pyg
poly_rewriter_params.pyg
rewriter_params.pyg
sat_params.pyg
seq_rewriter_params.pyg
sls_params.pyg
solver_params.pyg

View file

@ -43,7 +43,6 @@ z3_add_component(sat
params
PYG_FILES
sat_asymm_branch_params.pyg
sat_params.pyg
sat_scc_params.pyg
sat_simplifier_params.pyg
)

View file

@ -16,9 +16,9 @@ Author:
Revision History:
--*/
#include "params/sat_params.hpp"
#include "sat/sat_config.h"
#include "sat/sat_types.h"
#include "sat/sat_params.hpp"
#include "sat/sat_simplifier_params.hpp"
#include "params/solver_params.hpp"

View file

@ -28,7 +28,7 @@
#include "util/luby.h"
#include "sat/sat_ddfw.h"
#include "sat/sat_solver.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
namespace sat {

View file

@ -25,12 +25,9 @@
#include "util/params.h"
#include "util/ema.h"
#include "util/sat_sls.h"
#include "sat/sat_clause.h"
#include "util/map.h"
#include "sat/sat_types.h"
namespace arith {
class sls;
}
namespace sat {
class solver;

View file

@ -19,7 +19,7 @@ Notes:
#include "sat/sat_local_search.h"
#include "sat/sat_solver.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
#include "util/timer.h"
namespace sat {

View file

@ -39,7 +39,7 @@ Notes:
#include "model/model_v2_pp.h"
#include "model/model_evaluator.h"
#include "sat/sat_solver.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
#include "sat/smt/euf_solver.h"
#include "sat/tactic/goal2sat.h"
#include "sat/tactic/sat2goal.h"

View file

@ -33,7 +33,7 @@ Notes:
#include "model/model_evaluator.h"
#include "sat/sat_solver.h"
#include "solver/simplifier_solver.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
#include "sat/smt/euf_solver.h"
#include "sat/tactic/goal2sat.h"
#include "sat/tactic/sat2goal.h"

View file

@ -21,7 +21,7 @@ Author:
#include "ast/ast_ll_pp.h"
#include "ast/arith_decl_plugin.h"
#include "smt/smt_solver.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
#include "sat/smt/euf_proof_checker.h"
#include "sat/smt/arith_theory_checker.h"
#include "sat/smt/q_theory_checker.h"

View file

@ -44,7 +44,7 @@ Notes:
#include "sat/smt/pb_solver.h"
#include "sat/smt/euf_solver.h"
#include "sat/smt/sat_th.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
#include<sstream>
struct goal2sat::imp : public sat::sat_internalizer {

View file

@ -44,7 +44,7 @@ Notes:
#include "sat/smt/pb_solver.h"
#include "sat/smt/euf_solver.h"
#include "sat/smt/sat_th.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
#include<sstream>
sat2goal::mc::mc(ast_manager& m): m(m), m_var2expr(m) {}

View file

@ -22,7 +22,7 @@ Notes:
#include "sat/tactic/goal2sat.h"
#include "sat/tactic/sat2goal.h"
#include "sat/sat_solver.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
class sat_tactic : public tactic {

View file

@ -23,7 +23,7 @@ Revision History:
#include "util/rlimit.h"
#include "util/gparams.h"
#include "sat/dimacs.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
#include "sat/sat_solver.h"
#include "sat/tactic/goal2sat.h"
#include "sat/tactic/sat2goal.h"

View file

@ -47,7 +47,7 @@ Notes:
#include "solver/parallel_params.hpp"
#include "params/tactic_params.hpp"
#include "parsers/smt2/smt2parser.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
tactic* mk_tactic_for_logic(ast_manager& m, params_ref const& p, symbol const& logic);

View file

@ -17,7 +17,7 @@ Author:
--*/
#include "smt/tactic/smt_tactic_core.h"
#include "sat/tactic/sat_tactic.h"
#include "sat/sat_params.hpp"
#include "params/sat_params.hpp"
#include "solver/solver2tactic.h"
#include "solver/solver.h"

View file

@ -168,6 +168,11 @@ public:
return *this;
}
checked_int64& operator%=(checked_int64 const& other) {
m_value %= other.m_value;
return *this;
}
friend inline checked_int64 abs(checked_int64 const& i) {
return i.abs();
}
@ -286,3 +291,17 @@ inline checked_int64<CHECK> operator/(checked_int64<CHECK> const& a, checked_int
result /= b;
return result;
}
template<bool CHECK>
inline checked_int64<CHECK> mod(checked_int64<CHECK> const& a, checked_int64<CHECK> const& b) {
checked_int64<CHECK> result(a);
result %= b;
if (result < 0) {
if (b > 0)
result += b;
else
result -= b;
}
return result;
}