3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-23 00:55:31 +00:00

local changes

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2017-09-25 16:37:15 -07:00
parent 82922d92f7
commit ced2029ae9
12 changed files with 69 additions and 64 deletions

View file

@ -43,7 +43,6 @@ struct pb2bv_rewriter::imp {
struct card2bv_rewriter {
typedef expr* literal;
typedef ptr_vector<expr> literal_vector;
sorting_network_config m_cfg;
psort_nw<card2bv_rewriter> m_sort;
ast_manager& m;
imp& m_imp;
@ -572,7 +571,7 @@ struct pb2bv_rewriter::imp {
public:
card2bv_rewriter(imp& i, ast_manager& m):
m_sort(*this, m_cfg),
m_sort(*this),
m(m),
m_imp(i),
au(m),
@ -786,7 +785,7 @@ struct pb2bv_rewriter::imp {
void pb_totalizer(bool f) {
m_pb_totalizer = f;
}
void set_at_most1(sorting_network_encoding enc) { m_cfg.m_encoding = enc; }
void set_at_most1(sorting_network_encoding enc) { m_sort.cfg().m_encoding = enc; }
};
@ -852,7 +851,7 @@ struct pb2bv_rewriter::imp {
sorting_network_encoding atmost1_encoding() const {
symbol enc = m_params.get_sym("atmost1_encoding", enc);
symbol enc = m_params.get_sym("atmost1_encoding", symbol());
if (enc == symbol()) {
enc = gparams::get_module("sat").get_sym("atmost1_encoding", symbol());
}

View file

@ -32,13 +32,12 @@ namespace opt {
public:
typedef expr* literal;
typedef ptr_vector<expr> literal_vector;
sorting_network_config m_cfg;
psort_nw<sortmax> m_sort;
expr_ref_vector m_trail;
func_decl_ref_vector m_fresh;
ref<filter_model_converter> m_filter;
sortmax(maxsat_context& c, weights_t& ws, expr_ref_vector const& soft):
maxsmt_solver_base(c, ws, soft), m_sort(*this, m_cfg), m_trail(m), m_fresh(m) {}
maxsmt_solver_base(c, ws, soft), m_sort(*this), m_trail(m), m_fresh(m) {}
virtual ~sortmax() {}

View file

@ -1619,11 +1619,14 @@ namespace sat {
double ba_solver::get_reward(card const& c, literal_occs_fun& literal_occs) const {
unsigned k = c.k(), slack = 0;
double to_add = 0;
bool do_add = get_config().m_lookahead_reward == heule_schur_reward;
double to_add = do_add ? 0: 1;
for (literal l : c) {
switch (value(l)) {
case l_true: --k; if (k == 0) return 0; break;
case l_undef: to_add += literal_occs(l); ++slack; break;
case l_true: --k; if (k == 0) return 0;
case l_undef:
if (do_add) to_add += literal_occs(l);
++slack; break;
case l_false: break;
}
}
@ -1633,14 +1636,19 @@ namespace sat {
double ba_solver::get_reward(pb const& c, literal_occs_fun& occs) const {
unsigned k = c.k(), slack = 0;
double to_add = 0;
bool do_add = get_config().m_lookahead_reward == heule_schur_reward;
double to_add = do_add ? 0 : 1;
double undefs = 0;
for (wliteral wl : c) {
literal l = wl.second;
unsigned w = wl.first;
switch (value(l)) {
case l_true: if (k <= w) return 0; k -= w; break;
case l_undef: to_add += occs(l); ++undefs; slack += w; break; // TBD multiplier factor on this
case l_true: if (k <= w) return 0;
case l_undef:
if (do_add) to_add += occs(l);
++undefs;
slack += w;
break; // TBD multiplier factor on this
case l_false: break;
}
}

View file

@ -87,7 +87,22 @@ namespace sat {
m_lookahead_simplify = p.lookahead_simplify();
m_lookahead_cube = p.lookahead_cube();
m_lookahead_search = p.lookahead_search();
m_lookahead_reward = p.lookahead_reward();
if (p.lookahead_reward() == symbol("hs")) {
m_lookahead_reward = heule_schur_reward;
}
else if (p.lookahead_reward() == symbol("heuleu")) {
m_lookahead_reward = heule_unit_reward;
}
else if (p.lookahead_reward() == symbol("ternary")) {
m_lookahead_reward = ternary_reward;
}
else if (p.lookahead_reward() == symbol("unit")) {
m_lookahead_reward = unit_literal_reward;
}
else {
throw sat_param_exception("invalid reward type supplied: accepted heuristics are 'ternary', 'heuleu', 'unit' or 'heule_schur'");
}
m_lookahead_cube_fraction = p.lookahead_cube_fraction();
m_lookahead_cube_cutoff = p.lookahead_cube_cutoff();

View file

@ -57,6 +57,13 @@ namespace sat {
PB_TOTALIZER
};
enum reward_t {
ternary_reward,
unit_literal_reward,
heule_schur_reward,
heule_unit_reward
};
struct config {
unsigned long long m_max_memory;
phase_selection m_phase;
@ -78,7 +85,7 @@ namespace sat {
bool m_lookahead_cube;
unsigned m_lookahead_cube_cutoff;
double m_lookahead_cube_fraction;
symbol m_lookahead_reward;
reward_t m_lookahead_reward;
unsigned m_simplify_mult1;
double m_simplify_mult2;

View file

@ -1299,7 +1299,7 @@ namespace sat {
case watched::EXT_CONSTRAINT: {
SASSERT(m_s.m_ext);
bool keep = m_s.m_ext->propagate(l, it->get_ext_constraint_idx());
if (m_search_mode == lookahead_mode::lookahead1) {
if (m_search_mode == lookahead_mode::lookahead1 && !m_inconsistent) {
lookahead_literal_occs_fun literal_occs_fn(*this);
m_lookahead_reward += m_s.m_ext->get_reward(l, it->get_ext_constraint_idx(), literal_occs_fn);
}
@ -2065,21 +2065,7 @@ namespace sat {
}
void lookahead::init_config() {
if (m_s.m_config.m_lookahead_reward == symbol("hs")) {
m_config.m_reward_type = heule_schur_reward;
}
else if (m_s.m_config.m_lookahead_reward == symbol("heuleu")) {
m_config.m_reward_type = heule_unit_reward;
}
else if (m_s.m_config.m_lookahead_reward == symbol("ternary")) {
m_config.m_reward_type = ternary_reward;
}
else if (m_s.m_config.m_lookahead_reward == symbol("unit")) {
m_config.m_reward_type = unit_literal_reward;
}
else {
warning_msg("Reward type not recognized");
}
m_config.m_reward_type = m_s.m_config.m_lookahead_reward;
m_config.m_cube_cutoff = m_s.m_config.m_lookahead_cube_cutoff;
m_config.m_cube_fraction = m_s.m_config.m_lookahead_cube_fraction;
}

View file

@ -66,13 +66,6 @@ namespace sat {
friend class ccc;
friend class ba_solver;
enum reward_t {
ternary_reward,
unit_literal_reward,
heule_schur_reward,
heule_unit_reward
};
struct config {
double m_dl_success;
double m_alpha;

View file

@ -1060,10 +1060,12 @@ namespace sat {
it.next();
}
ext_constraint_list const& ext_list = s.m_ext_use_list.get(~l);
for (ext_constraint_idx idx : ext_list) {
if (!s.s.m_ext->is_blocked(l, idx)) {
return false;
if (s.s.m_ext) {
ext_constraint_list const& ext_list = s.m_ext_use_list.get(~l);
for (ext_constraint_idx idx : ext_list) {
if (!s.s.m_ext->is_blocked(l, idx)) {
return false;
}
}
}
return true;

View file

@ -1441,8 +1441,7 @@ namespace smt {
theory_pb_params p;
theory_pb th(ctx.get_manager(), p);
psort_expr ps(ctx, th);
sorting_network_config cfg;
psort_nw<psort_expr> sort(ps, cfg);
psort_nw<psort_expr> sort(ps);
return sort.ge(false, k, n, xs);
}
@ -1578,8 +1577,7 @@ namespace smt {
psort_expr ps(ctx, *this);
sorting_network_config cfg;
psort_nw<psort_expr> sortnw(ps, cfg);
psort_nw<psort_expr> sortnw(ps);
sortnw.m_stats.reset();
if (ctx.get_assignment(thl) == l_true &&

View file

@ -85,6 +85,7 @@ public:
}
virtual lbool check_sat_core(unsigned num_assumptions, expr * const * assumptions) {
m_solver->updt_params(m_params);
return m_solver->check_sat(num_assumptions, assumptions);
}

View file

@ -191,9 +191,9 @@ static void test_eq1(unsigned n, sorting_network_encoding enc) {
}
smt_params fp;
smt::kernel solver(m, fp);
sorting_network_config cfg;
cfg.m_encoding = enc;
psort_nw<ast_ext2> sn(ext, cfg);
psort_nw<ast_ext2> sn(ext);
sn.cfg().m_encoding = enc;
expr_ref result1(m), result2(m);
// equality:
@ -237,9 +237,8 @@ static void test_sorting_eq(unsigned n, unsigned k, sorting_network_encoding enc
}
smt_params fp;
smt::kernel solver(m, fp);
sorting_network_config cfg;
cfg.m_encoding = enc;
psort_nw<ast_ext2> sn(ext, cfg);
psort_nw<ast_ext2> sn(ext);
sn.cfg().m_encoding = enc;
expr_ref result(m);
// equality:
@ -288,9 +287,8 @@ static void test_sorting_le(unsigned n, unsigned k, sorting_network_encoding enc
}
smt_params fp;
smt::kernel solver(m, fp);
sorting_network_config cfg;
cfg.m_encoding = enc;
psort_nw<ast_ext2> sn(ext, cfg);
psort_nw<ast_ext2> sn(ext);
sn.cfg().m_encoding = enc;
expr_ref result(m);
// B <= k
std::cout << "le " << k << "\n";
@ -337,9 +335,8 @@ void test_sorting_ge(unsigned n, unsigned k, sorting_network_encoding enc) {
}
smt_params fp;
smt::kernel solver(m, fp);
sorting_network_config cfg;
cfg.m_encoding = enc;
psort_nw<ast_ext2> sn(ext, cfg);
psort_nw<ast_ext2> sn(ext);
sn.cfg().m_encoding = enc;
expr_ref result(m);
// k <= B
std::cout << "ge " << k << "\n";
@ -402,9 +399,8 @@ void test_at_most_1(unsigned n, bool full, sorting_network_encoding enc) {
}
ast_ext2 ext(m);
sorting_network_config cfg;
cfg.m_encoding = enc;
psort_nw<ast_ext2> sn(ext, cfg);
psort_nw<ast_ext2> sn(ext);
sn.cfg().m_encoding = enc;
expr_ref result1(m), result2(m);
result1 = sn.le(full, 1, in.size(), in.c_ptr());
result2 = naive_at_most1(in);
@ -481,9 +477,8 @@ static void test_at_most1(sorting_network_encoding enc) {
in[4] = in[3].get();
ast_ext2 ext(m);
sorting_network_config cfg;
cfg.m_encoding = enc;
psort_nw<ast_ext2> sn(ext, cfg);
psort_nw<ast_ext2> sn(ext);
sn.cfg().m_encoding = enc;
expr_ref result(m);
result = sn.le(true, 1, in.size(), in.c_ptr());
std::cout << result << "\n";

View file

@ -208,7 +208,9 @@ Notes:
}
};
psort_nw(psort_expr& c, sorting_network_config const& cfg): ctx(c), m_cfg(cfg) {}
psort_nw(psort_expr& c): ctx(c) {}
sorting_network_config& cfg() { return m_cfg; }
literal ge(bool full, unsigned k, unsigned n, literal const* xs) {
if (k > n) {