diff --git a/src/ast/rewriter/CMakeLists.txt b/src/ast/rewriter/CMakeLists.txt index 9d80fd5ac..9f98ff0ed 100644 --- a/src/ast/rewriter/CMakeLists.txt +++ b/src/ast/rewriter/CMakeLists.txt @@ -19,6 +19,7 @@ z3_add_component(rewriter factor_equivs.cpp factor_rewriter.cpp fpa_rewriter.cpp + hoist_rewriter.cpp inj_axiom.cpp label_rewriter.cpp maximize_ac_sharing.cpp diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 5ea039763..00ce093be 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -34,7 +34,7 @@ Revision History: namespace sat { solver::solver(params_ref const & p, reslimit& l): - m_rlimit(l), + solver_core(l), m_checkpoint_enabled(true), m_config(p), m_par(nullptr), @@ -3328,7 +3328,7 @@ namespace sat { bool_var solver::max_var(bool learned, bool_var v) { m_user_bin_clauses.reset(); - collect_bin_clauses(m_user_bin_clauses, learned); + collect_bin_clauses(m_user_bin_clauses, learned, false); for (unsigned i = 0; i < m_user_bin_clauses.size(); ++i) { literal l1 = m_user_bin_clauses[i].first; literal l2 = m_user_bin_clauses[i].second; @@ -3827,7 +3827,7 @@ namespace sat { max_cliques mc; m_user_bin_clauses.reset(); m_binary_clause_graph.reset(); - collect_bin_clauses(m_user_bin_clauses, true); + collect_bin_clauses(m_user_bin_clauses, true, false); hashtable, default_eq > seen_bc; for (auto const& b : m_user_bin_clauses) { literal l1 = b.first; diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index dd81bac2e..4249da5ba 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -37,6 +37,7 @@ Revision History: #include "sat/sat_drat.h" #include "sat/sat_parallel.h" #include "sat/sat_local_search.h" +#include "sat/sat_solver_core.h" #include "util/params.h" #include "util/statistics.h" #include "util/stopwatch.h" @@ -75,11 +76,10 @@ namespace sat { void collect_statistics(statistics & st) const; }; - class solver { + class solver : public solver_core { public: struct abort_solver {}; protected: - reslimit& m_rlimit; bool m_checkpoint_enabled; config m_config; stats m_stats; @@ -197,12 +197,12 @@ namespace sat { // Misc // // ----------------------- - void updt_params(params_ref const & p); + void updt_params(params_ref const & p) override; static void collect_param_descrs(param_descrs & d); - void collect_statistics(statistics & st) const; + void collect_statistics(statistics & st) const override; void reset_statistics(); - void display_status(std::ostream & out) const; + void display_status(std::ostream & out) const override; /** \brief Copy (non learned) clauses from src to this solver. @@ -217,6 +217,9 @@ namespace sat { // Variable & Clause creation // // ----------------------- + void add_clause(unsigned num_lits, literal * lits, bool learned) override { mk_clause(num_lits, lits, learned); } + bool_var add_var(bool ext) override { return mk_var(ext, true); } + bool_var mk_var(bool ext = false, bool dvar = true); void mk_clause(literal_vector const& lits, bool learned = false) { mk_clause(lits.size(), lits.c_ptr(), learned); } void mk_clause(unsigned num_lits, literal * lits, bool learned = false); @@ -279,29 +282,28 @@ namespace sat { // // ----------------------- public: - bool inconsistent() const { return m_inconsistent; } - unsigned num_vars() const { return m_level.size(); } - unsigned num_clauses() const; + bool inconsistent() const override { return m_inconsistent; } + unsigned num_vars() const override { return m_level.size(); } + unsigned num_clauses() const override; void num_binary(unsigned& given, unsigned& learned) const; unsigned num_restarts() const { return m_restarts; } - bool is_external(bool_var v) const { return m_external[v] != 0; } - bool is_external(literal l) const { return is_external(l.var()); } - void set_external(bool_var v); - void set_non_external(bool_var v); + bool is_external(bool_var v) const override { return m_external[v] != 0; } + void set_external(bool_var v) override; + void set_non_external(bool_var v) override; bool was_eliminated(bool_var v) const { return m_eliminated[v] != 0; } - void set_eliminated(bool_var v, bool f) { m_eliminated[v] = f; } + void set_eliminated(bool_var v, bool f) override { m_eliminated[v] = f; } bool was_eliminated(literal l) const { return was_eliminated(l.var()); } unsigned scope_lvl() const { return m_scope_lvl; } unsigned search_lvl() const { return m_search_lvl; } bool at_search_lvl() const { return m_scope_lvl == m_search_lvl; } - bool at_base_lvl() const { return m_scope_lvl == 0; } + bool at_base_lvl() const override { return m_scope_lvl == 0; } lbool value(literal l) const { return static_cast(m_assignment[l.index()]); } lbool value(bool_var v) const { return static_cast(m_assignment[literal(v, false).index()]); } unsigned lvl(bool_var v) const { return m_level[v]; } unsigned lvl(literal l) const { return m_level[l.var()]; } - unsigned init_trail_size() const { return at_base_lvl() ? m_trail.size() : m_scopes[0].m_trail_lim; } + unsigned init_trail_size() const override { return at_base_lvl() ? m_trail.size() : m_scopes[0].m_trail_lim; } unsigned trail_size() const { return m_trail.size(); } - literal trail_literal(unsigned i) const { return m_trail[i]; } + literal trail_literal(unsigned i) const override { return m_trail[i]; } literal scope_literal(unsigned n) const { return m_trail[m_scopes[n].m_trail_lim]; } void assign(literal l, justification j) { TRACE("sat_assign", tout << l << " previous value: " << value(l) << "\n";); @@ -333,8 +335,8 @@ namespace sat { config const& get_config() const { return m_config; } void set_incremental(bool b) { m_config.m_incremental = b; } bool is_incremental() const { return m_config.m_incremental; } - extension* get_extension() const { return m_ext.get(); } - void set_extension(extension* e); + extension* get_extension() const override { return m_ext.get(); } + void set_extension(extension* e) override; bool set_root(literal l, literal r); void flush_roots(); typedef std::pair bin_clause; @@ -369,13 +371,13 @@ namespace sat { // // ----------------------- public: - lbool check(unsigned num_lits = 0, literal const* lits = nullptr); + lbool check(unsigned num_lits = 0, literal const* lits = nullptr) override; - model const & get_model() const { return m_model; } + model const & get_model() const override { return m_model; } bool model_is_current() const { return m_model_is_current; } - literal_vector const& get_core() const { return m_core; } + literal_vector const& get_core() const override { return m_core; } model_converter const & get_model_converter() const { return m_mc; } - void flush(model_converter& mc) { mc.flush(m_mc); } + void flush(model_converter& mc) override { mc.flush(m_mc); } void set_model(model const& mdl); char const* get_reason_unknown() const { return m_reason_unknown.c_str(); } bool check_clauses(model const& m) const; @@ -545,10 +547,10 @@ namespace sat { bool_var max_var(bool learned, bool_var v); public: - void user_push(); - void user_pop(unsigned num_scopes); - void pop_to_base_level(); - unsigned num_user_scopes() const { return m_user_scope_literals.size(); } + void user_push() override; + void user_pop(unsigned num_scopes) override; + void pop_to_base_level() override; + unsigned num_user_scopes() const override { return m_user_scope_literals.size(); } reslimit& rlimit() { return m_rlimit; } // ----------------------- // @@ -645,8 +647,9 @@ namespace sat { clause * const * begin_learned() const { return m_learned.begin(); } clause * const * end_learned() const { return m_learned.end(); } clause_vector const& learned() const { return m_learned; } - clause_vector const& clauses() const { return m_clauses; } - void collect_bin_clauses(svector & r, bool learned, bool learned_only = false) const; + clause_vector const& clauses() const override { return m_clauses; } + void collect_bin_clauses(svector & r, bool learned, bool learned_only) const override; + // ----------------------- // @@ -654,11 +657,11 @@ namespace sat { // // ----------------------- public: - bool check_invariant() const; + bool check_invariant() const override; void display(std::ostream & out) const; void display_watches(std::ostream & out) const; void display_watches(std::ostream & out, literal lit) const; - void display_dimacs(std::ostream & out) const; + void display_dimacs(std::ostream & out) const override; void display_wcnf(std::ostream & out, unsigned sz, literal const* lits, unsigned const* weights) const; void display_assignment(std::ostream & out) const; std::ostream& display_justification(std::ostream & out, justification const& j) const; diff --git a/src/sat/sat_solver_core.h b/src/sat/sat_solver_core.h new file mode 100644 index 000000000..4f763bd55 --- /dev/null +++ b/src/sat/sat_solver_core.h @@ -0,0 +1,116 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + sat_solver_core.h + +Abstract: + + SAT solver API class. + +Author: + + Nikolaj Bjorner (nbjorner) 2019-02-06 + +Revision History: + +--*/ +#ifndef SAT_SOLVER_CORE_H_ +#define SAT_SOLVER_CORE_H_ + + +#include "sat/sat_types.h" + +namespace sat { + + class solver_core { + protected: + reslimit& m_rlimit; + public: + solver_core(reslimit& l) : m_rlimit(l) {} + ~solver_core() {} + + virtual void pop_to_base_level() {} + virtual bool at_base_lvl() const { return true; } + + // retrieve model if solver return sat + virtual model const & get_model() const = 0; + + // retrieve core from assumptions + virtual literal_vector const& get_core() const = 0; + + // is the state inconsistent? + virtual bool inconsistent() const = 0; + + // number of variables and clauses + virtual unsigned num_vars() const = 0; + virtual unsigned num_clauses() const = 0; + + // check satisfiability + virtual lbool check(unsigned num_lits = 0, literal const* lits = nullptr) = 0; + + // add clauses + virtual void add_clause(unsigned n, literal* lits, bool is_redundant) = 0; + void add_clause(literal l1, literal l2, bool is_redundant) { + literal lits[2] = {l1, l2}; + add_clause(2, lits, is_redundant); + } + void add_clause(literal l1, literal l2, literal l3, bool is_redundant) { + literal lits[3] = {l1, l2, l3}; + add_clause(3, lits, is_redundant); + } + // create boolean variable, tagged as external (= true) or internal (can be eliminated). + virtual bool_var add_var(bool ext) = 0; + + // update parameters + virtual void updt_params(params_ref const& p) {} + + + virtual bool check_invariant() const { return true; } + virtual void display_status(std::ostream& out) const {} + virtual void display_dimacs(std::ostream& out) const {} + + virtual bool is_external(bool_var v) const { return true; } + bool is_external(literal l) const { return is_external(l.var()); } + virtual void set_external(bool_var v) {} + virtual void set_non_external(bool_var v) {} + virtual void set_eliminated(bool_var v, bool f) {} + + // optional support for user-scopes. Not relevant for sat_tactic integration. + // it is only relevant for incremental mode SAT, which isn't wrapped (yet) + virtual void user_push() { throw default_exception("optional API not supported"); } + virtual void user_pop(unsigned num_scopes) {}; + virtual unsigned num_user_scopes() const { return 0;} + + // hooks for extension solver. really just ba_solver atm. + virtual extension* get_extension() const { return nullptr; } + virtual void set_extension(extension* e) { if (e) throw default_exception("optional API not supported"); } + + + // The following methods are used when converting the state from the SAT solver back + // to a set of assertions. + + // retrieve model converter that handles variable elimination and other transformations + virtual void flush(model_converter& mc) {} + + // size of initial trail containing unit clauses + virtual unsigned init_trail_size() const = 0; + + // literal at trail index i + virtual literal trail_literal(unsigned i) const = 0; + + // collect n-ary clauses + virtual clause_vector const& clauses() const = 0; + + // collect binary clauses + typedef std::pair bin_clause; + virtual void collect_bin_clauses(svector & r, bool learned, bool learned_only) const = 0; + + // collect statistics from sat solver + virtual void collect_statistics(statistics & st) const {} + + }; +}; + +#endif diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 419952f99..f020fd2ab 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -57,7 +57,7 @@ struct goal2sat::imp { svector m_result_stack; obj_map m_cache; obj_hashtable m_interface_vars; - sat::solver & m_solver; + sat::solver_core & m_solver; atom2bool_var & m_map; dep2asm_map & m_dep2asm; sat::bool_var m_true; @@ -69,7 +69,7 @@ struct goal2sat::imp { bool m_xor_solver; bool m_is_lemma; - imp(ast_manager & _m, params_ref const & p, sat::solver & s, atom2bool_var & map, dep2asm_map& dep2asm, bool default_external): + imp(ast_manager & _m, params_ref const & p, sat::solver_core & s, atom2bool_var & map, dep2asm_map& dep2asm, bool default_external): m(_m), pb(m), m_ext(nullptr), @@ -97,30 +97,30 @@ struct goal2sat::imp { void mk_clause(sat::literal l) { TRACE("goal2sat", tout << "mk_clause: " << l << "\n";); - m_solver.mk_clause(1, &l); + m_solver.add_clause(1, &l, false); } void set_lemma_mode(bool f) { m_is_lemma = f; } void mk_clause(sat::literal l1, sat::literal l2) { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << "\n";); - m_solver.mk_clause(l1, l2, m_is_lemma); + m_solver.add_clause(l1, l2, m_is_lemma); } void mk_clause(sat::literal l1, sat::literal l2, sat::literal l3) { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << " " << l3 << "\n";); - m_solver.mk_clause(l1, l2, l3, m_is_lemma); + m_solver.add_clause(l1, l2, l3, m_is_lemma); } void mk_clause(unsigned num, sat::literal * lits) { TRACE("goal2sat", tout << "mk_clause: "; for (unsigned i = 0; i < num; i++) tout << lits[i] << " "; tout << "\n";); - m_solver.mk_clause(num, lits, m_is_lemma); + m_solver.add_clause(num, lits, m_is_lemma); } sat::bool_var mk_true() { if (m_true == sat::null_bool_var) { // create fake variable to represent true; - m_true = m_solver.mk_var(false); + m_true = m_solver.add_var(false); mk_clause(sat::literal(m_true, false)); // v is true } return m_true; @@ -139,7 +139,7 @@ struct goal2sat::imp { } else { bool ext = m_default_external || !is_uninterp_const(t) || m_interface_vars.contains(t); - sat::bool_var v = m_solver.mk_var(ext); + sat::bool_var v = m_solver.add_var(ext); m_map.insert(t, v); l = sat::literal(v, sign); TRACE("sat", tout << "new_var: " << v << ": " << mk_ismt2_pp(t, m) << "\n";); @@ -247,7 +247,7 @@ struct goal2sat::imp { } else { SASSERT(num <= m_result_stack.size()); - sat::bool_var k = m_solver.mk_var(); + sat::bool_var k = m_solver.add_var(false); sat::literal l(k, false); m_cache.insert(t, l); sat::literal * lits = m_result_stack.end() - num; @@ -286,7 +286,7 @@ struct goal2sat::imp { } else { SASSERT(num <= m_result_stack.size()); - sat::bool_var k = m_solver.mk_var(); + sat::bool_var k = m_solver.add_var(false); sat::literal l(k, false); m_cache.insert(t, l); // l => /\ lits @@ -329,7 +329,7 @@ struct goal2sat::imp { m_result_stack.reset(); } else { - sat::bool_var k = m_solver.mk_var(); + sat::bool_var k = m_solver.add_var(false); sat::literal l(k, false); m_cache.insert(n, l); mk_clause(~l, ~c, t); @@ -366,7 +366,7 @@ struct goal2sat::imp { m_result_stack.reset(); } else { - sat::bool_var k = m_solver.mk_var(); + sat::bool_var k = m_solver.add_var(false); sat::literal l(k, false); m_cache.insert(t, l); mk_clause(~l, l1, ~l2); @@ -390,7 +390,7 @@ struct goal2sat::imp { return; } sat::literal_vector lits; - sat::bool_var v = m_solver.mk_var(true); + sat::bool_var v = m_solver.add_var(true); lits.push_back(sat::literal(v, true)); convert_pb_args(num, lits); // ensure that = is converted to xor @@ -472,7 +472,7 @@ struct goal2sat::imp { m_ext->add_pb_ge(sat::null_bool_var, wlits, k1); } else { - sat::bool_var v = m_solver.mk_var(true); + sat::bool_var v = m_solver.add_var(true); sat::literal lit(v, sign); m_ext->add_pb_ge(v, wlits, k.get_unsigned()); TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); @@ -503,7 +503,7 @@ struct goal2sat::imp { m_ext->add_pb_ge(sat::null_bool_var, wlits, k1); } else { - sat::bool_var v = m_solver.mk_var(true); + sat::bool_var v = m_solver.add_var(true); sat::literal lit(v, sign); m_ext->add_pb_ge(v, wlits, k.get_unsigned()); TRACE("goal2sat", tout << "root: " << root << " lit: " << lit << "\n";); @@ -518,8 +518,8 @@ struct goal2sat::imp { svector wlits; convert_pb_args(t, wlits); bool base_assert = (root && !sign && m_solver.num_user_scopes() == 0); - sat::bool_var v1 = base_assert ? sat::null_bool_var : m_solver.mk_var(true); - sat::bool_var v2 = base_assert ? sat::null_bool_var : m_solver.mk_var(true); + sat::bool_var v1 = base_assert ? sat::null_bool_var : m_solver.add_var(true); + sat::bool_var v2 = base_assert ? sat::null_bool_var : m_solver.add_var(true); m_ext->add_pb_ge(v1, wlits, k.get_unsigned()); k.neg(); for (wliteral& wl : wlits) { @@ -533,7 +533,7 @@ struct goal2sat::imp { } else { sat::literal l1(v1, false), l2(v2, false); - sat::bool_var v = m_solver.mk_var(); + sat::bool_var v = m_solver.add_var(false); sat::literal l(v, false); mk_clause(~l, l1); mk_clause(~l, l2); @@ -553,7 +553,7 @@ struct goal2sat::imp { m_ext->add_at_least(sat::null_bool_var, lits, k.get_unsigned()); } else { - sat::bool_var v = m_solver.mk_var(true); + sat::bool_var v = m_solver.add_var(true); sat::literal lit(v, false); m_ext->add_at_least(v, lits, k.get_unsigned()); m_cache.insert(t, lit); @@ -575,7 +575,7 @@ struct goal2sat::imp { m_ext->add_at_least(sat::null_bool_var, lits, lits.size() - k.get_unsigned()); } else { - sat::bool_var v = m_solver.mk_var(true); + sat::bool_var v = m_solver.add_var(true); sat::literal lit(v, false); m_ext->add_at_least(v, lits, lits.size() - k.get_unsigned()); m_cache.insert(t, lit); @@ -588,8 +588,8 @@ struct goal2sat::imp { SASSERT(k.is_unsigned()); sat::literal_vector lits; convert_pb_args(t->get_num_args(), lits); - sat::bool_var v1 = (root && !sign) ? sat::null_bool_var : m_solver.mk_var(true); - sat::bool_var v2 = (root && !sign) ? sat::null_bool_var : m_solver.mk_var(true); + sat::bool_var v1 = (root && !sign) ? sat::null_bool_var : m_solver.add_var(true); + sat::bool_var v2 = (root && !sign) ? sat::null_bool_var : m_solver.add_var(true); m_ext->add_at_least(v1, lits, k.get_unsigned()); for (sat::literal& l : lits) { l.neg(); @@ -602,7 +602,7 @@ struct goal2sat::imp { } else { sat::literal l1(v1, false), l2(v2, false); - sat::bool_var v = m_solver.mk_var(); + sat::bool_var v = m_solver.add_var(false); sat::literal l(v, false); mk_clause(~l, l1); mk_clause(~l, l2); @@ -897,7 +897,7 @@ struct goal2sat::scoped_set_imp { }; -void goal2sat::operator()(goal const & g, params_ref const & p, sat::solver & t, atom2bool_var & m, dep2asm_map& dep2asm, bool default_external, bool is_lemma) { +void goal2sat::operator()(goal const & g, params_ref const & p, sat::solver_core & t, atom2bool_var & m, dep2asm_map& dep2asm, bool default_external, bool is_lemma) { imp proc(g.m(), p, t, m, dep2asm, default_external); scoped_set_imp set(this, &proc); proc.set_lemma_mode(is_lemma); @@ -916,7 +916,7 @@ void goal2sat::get_interpreted_atoms(expr_ref_vector& atoms) { sat2goal::mc::mc(ast_manager& m): m(m), m_var2expr(m) {} -void sat2goal::mc::flush_smc(sat::solver& s, atom2bool_var const& map) { +void sat2goal::mc::flush_smc(sat::solver_core& s, atom2bool_var const& map) { s.flush(m_smc); m_var2expr.resize(s.num_vars()); map.mk_var_inv(m_var2expr); @@ -1157,13 +1157,14 @@ struct sat2goal::imp { r.assert_expr(fml); } - void assert_clauses(ref& mc, sat::solver const & s, sat::clause_vector const& clauses, goal & r, bool asserted) { + void assert_clauses(ref& mc, sat::solver_core const & s, sat::clause_vector const& clauses, goal & r, bool asserted) { ptr_buffer lits; + unsigned small_lbd = 3; // s.get_config().m_gc_small_lbd; for (sat::clause* cp : clauses) { checkpoint(); lits.reset(); sat::clause const & c = *cp; - if (asserted || m_learned || c.glue() <= s.get_config().m_gc_small_lbd) { + if (asserted || m_learned || c.glue() <= small_lbd) { for (sat::literal l : c) { lits.push_back(lit2expr(mc, l)); } @@ -1172,11 +1173,11 @@ struct sat2goal::imp { } } - sat::ba_solver* get_ba_solver(sat::solver const& s) { + sat::ba_solver* get_ba_solver(sat::solver_core const& s) { return dynamic_cast(s.get_extension()); } - void operator()(sat::solver & s, atom2bool_var const & map, goal & r, ref & mc) { + void operator()(sat::solver_core & s, atom2bool_var const & map, goal & r, ref & mc) { if (s.at_base_lvl() && s.inconsistent()) { r.assert_expr(m.mk_false()); return; @@ -1196,7 +1197,7 @@ struct sat2goal::imp { // collect binary clauses svector bin_clauses; - s.collect_bin_clauses(bin_clauses, m_learned); + s.collect_bin_clauses(bin_clauses, m_learned, false); for (sat::solver::bin_clause const& bc : bin_clauses) { checkpoint(); r.assert_expr(m.mk_or(lit2expr(mc, bc.first), lit2expr(mc, bc.second))); @@ -1262,7 +1263,7 @@ struct sat2goal::scoped_set_imp { } }; -void sat2goal::operator()(sat::solver & t, atom2bool_var const & m, params_ref const & p, +void sat2goal::operator()(sat::solver_core & t, atom2bool_var const & m, params_ref const & p, goal & g, ref & mc) { imp proc(g.m(), p); scoped_set_imp set(this, &proc); diff --git a/src/sat/tactic/goal2sat.h b/src/sat/tactic/goal2sat.h index 514b65311..78884051e 100644 --- a/src/sat/tactic/goal2sat.h +++ b/src/sat/tactic/goal2sat.h @@ -62,7 +62,7 @@ public: \warning conversion throws a tactic_exception, if it is interrupted (by set_cancel), an unsupported operator is found, or memory consumption limit is reached (set with param :max-memory). */ - void operator()(goal const & g, params_ref const & p, sat::solver & t, atom2bool_var & m, dep2asm_map& dep2asm, bool default_external = false, bool is_lemma = false); + void operator()(goal const & g, params_ref const & p, sat::solver_core & t, atom2bool_var & m, dep2asm_map& dep2asm, bool default_external = false, bool is_lemma = false); void get_interpreted_atoms(expr_ref_vector& atoms); @@ -88,7 +88,7 @@ public: mc(ast_manager& m); ~mc() override {} // flush model converter from SAT solver to this structure. - void flush_smc(sat::solver& s, atom2bool_var const& map); + void flush_smc(sat::solver_core& s, atom2bool_var const& map); void operator()(model_ref& md) override; void operator()(expr_ref& fml) override; model_converter* translate(ast_translation& translator) override; @@ -113,7 +113,7 @@ public: \warning conversion throws a tactic_exception, if it is interrupted (by set_cancel), or memory consumption limit is reached (set with param :max-memory). */ - void operator()(sat::solver & t, atom2bool_var const & m, params_ref const & p, goal & s, ref & mc); + void operator()(sat::solver_core & t, atom2bool_var const & m, params_ref const & p, goal & s, ref & mc); }; diff --git a/src/sat/tactic/sat_tactic.cpp b/src/sat/tactic/sat_tactic.cpp index bd100d620..1f6fe11ad 100644 --- a/src/sat/tactic/sat_tactic.cpp +++ b/src/sat/tactic/sat_tactic.cpp @@ -29,12 +29,12 @@ class sat_tactic : public tactic { ast_manager & m; goal2sat m_goal2sat; sat2goal m_sat2goal; - sat::solver m_solver; + scoped_ptr m_solver; params_ref m_params; imp(ast_manager & _m, params_ref const & p): m(_m), - m_solver(p, m.limit()), + m_solver(alloc(sat::solver, p, m.limit())), m_params(p) { SASSERT(!m.proofs_enabled()); updt_params(p); @@ -51,7 +51,7 @@ class sat_tactic : public tactic { atom2bool_var map(m); obj_map dep2asm; sat::literal_vector assumptions; - m_goal2sat(*g, m_params, m_solver, map, dep2asm); + m_goal2sat(*g, m_params, *m_solver, map, dep2asm); TRACE("sat_solver_unknown", tout << "interpreted_atoms: " << map.interpreted_atoms() << "\n"; for (auto const& kv : map) { if (!is_uninterp_const(kv.m_key)) @@ -60,15 +60,15 @@ class sat_tactic : public tactic { g->reset(); g->m().compact_memory(); - CASSERT("sat_solver", m_solver.check_invariant()); - IF_VERBOSE(TACTIC_VERBOSITY_LVL, m_solver.display_status(verbose_stream());); - TRACE("sat_dimacs", m_solver.display_dimacs(tout);); + CASSERT("sat_solver", m_solver->check_invariant()); + IF_VERBOSE(TACTIC_VERBOSITY_LVL, m_solver->display_status(verbose_stream());); + TRACE("sat_dimacs", m_solver->display_dimacs(tout);); dep2assumptions(dep2asm, assumptions); - lbool r = m_solver.check(assumptions.size(), assumptions.c_ptr()); + lbool r = m_solver->check(assumptions.size(), assumptions.c_ptr()); if (r == l_false) { expr_dependency * lcore = nullptr; if (produce_core) { - sat::literal_vector const& ucore = m_solver.get_core(); + sat::literal_vector const& ucore = m_solver->get_core(); u_map asm2dep; mk_asm2dep(dep2asm, asm2dep); for (unsigned i = 0; i < ucore.size(); ++i) { @@ -83,7 +83,7 @@ class sat_tactic : public tactic { // register model if (produce_models) { model_ref md = alloc(model, m); - sat::model const & ll_m = m_solver.get_model(); + sat::model const & ll_m = m_solver->get_model(); TRACE("sat_tactic", for (unsigned i = 0; i < ll_m.size(); i++) tout << i << ":" << ll_m[i] << " "; tout << "\n";); for (auto const& kv : map) { expr * n = kv.m_key; @@ -109,9 +109,9 @@ class sat_tactic : public tactic { #if 0 IF_VERBOSE(TACTIC_VERBOSITY_LVL, verbose_stream() << "\"formula constrains interpreted atoms, recovering formula from sat solver...\"\n";); #endif - m_solver.pop_to_base_level(); + m_solver->pop_to_base_level(); ref mc; - m_sat2goal(m_solver, map, m_params, *(g.get()), mc); + m_sat2goal(*m_solver, map, m_params, *(g.get()), mc); g->add(mc.get()); } g->inc_depth(); @@ -134,7 +134,7 @@ class sat_tactic : public tactic { } void updt_params(params_ref const& p) { - m_solver.updt_params(p); + m_solver->updt_params(p); } }; @@ -192,10 +192,10 @@ public: scoped_set_imp set(this, &proc); try { proc(g, result); - proc.m_solver.collect_statistics(m_stats); + proc.m_solver->collect_statistics(m_stats); } catch (sat::solver_exception & ex) { - proc.m_solver.collect_statistics(m_stats); + proc.m_solver->collect_statistics(m_stats); throw tactic_exception(ex.msg()); } TRACE("sat_stats", m_stats.display_smt2(tout);); diff --git a/src/shell/dimacs_frontend.cpp b/src/shell/dimacs_frontend.cpp index 114d8daf6..4f032c01e 100644 --- a/src/shell/dimacs_frontend.cpp +++ b/src/shell/dimacs_frontend.cpp @@ -113,7 +113,7 @@ static void track_clauses(sat::solver const& src, sat::clause * const * it = src.begin_clauses(); sat::clause * const * end = src.end_clauses(); svector bin_clauses; - src.collect_bin_clauses(bin_clauses, false); + src.collect_bin_clauses(bin_clauses, false, false); tracking_clauses.reserve(2*src.num_vars() + static_cast(end - it) + bin_clauses.size()); for (sat::bool_var v = 1; v < src.num_vars(); ++v) { diff --git a/src/smt/smt_cg_table.cpp b/src/smt/smt_cg_table.cpp index ad15fd819..b85fed02d 100644 --- a/src/smt/smt_cg_table.cpp +++ b/src/smt/smt_cg_table.cpp @@ -71,10 +71,7 @@ namespace smt { void cg_table::display(std::ostream & out) const { out << "congruence table:\n"; - table::iterator it = m_table.begin(); - table::iterator end = m_table.end(); - for (; it != end; ++it) { - enode * n = *it; + for (enode * n : m_table) { out << mk_pp(n->get_owner(), m_manager) << "\n"; } } @@ -82,10 +79,7 @@ namespace smt { void cg_table::display_compact(std::ostream & out) const { if (!m_table.empty()) { out << "congruence table:\n"; - table::iterator it = m_table.begin(); - table::iterator end = m_table.end(); - for (; it != end; ++it) { - enode * n = *it; + for (enode * n : m_table) { out << "#" << n->get_owner()->get_id() << " "; } out << "\n"; @@ -94,10 +88,7 @@ namespace smt { #ifdef Z3DEBUG bool cg_table::check_invariant() const { - table::iterator it = m_table.begin(); - table::iterator end = m_table.end(); - for (; it != end; ++it) { - enode * n = *it; + for (enode * n : m_table) { CTRACE("cg_table", !contains_ptr(n), tout << "#" << n->get_owner_id() << "\n";); SASSERT(contains_ptr(n)); } @@ -136,9 +127,11 @@ namespace smt { } bool cg_table::cg_eq::operator()(enode * n1, enode * n2) const { - SASSERT(n1->get_num_args() == n2->get_num_args()); SASSERT(n1->get_decl() == n2->get_decl()); unsigned num = n1->get_num_args(); + if (num != n2->get_num_args()) { + return false; + } for (unsigned i = 0; i < num; i++) if (n1->get_arg(i)->get_root() != n2->get_arg(i)->get_root()) return false; @@ -205,10 +198,7 @@ namespace smt { } void cg_table::reset() { - ptr_vector::iterator it = m_tables.begin(); - ptr_vector::iterator end = m_tables.end(); - for (; it != end; ++it) { - void * t = *it; + for (void* t : m_tables) { switch (GET_TAG(t)) { case UNARY: dealloc(UNTAG(unary_table*, t)); @@ -225,10 +215,9 @@ namespace smt { } } m_tables.reset(); - obj_map::iterator it2 = m_func_decl2id.begin(); - obj_map::iterator end2 = m_func_decl2id.end(); - for (; it2 != end2; ++it2) - m_manager.dec_ref(it2->m_key); + for (auto const& kv : m_func_decl2id) { + m_manager.dec_ref(kv.m_key); + } m_func_decl2id.reset(); } diff --git a/src/smt/smt_cg_table.h b/src/smt/smt_cg_table.h index 64c8328d0..4085ccc5f 100644 --- a/src/smt/smt_cg_table.h +++ b/src/smt/smt_cg_table.h @@ -252,6 +252,8 @@ namespace smt { enode_bool_pair insert(enode * n) { // it doesn't make sense to insert a constant. SASSERT(n->get_num_args() > 0); + SASSERT(!m_manager.is_and(n->get_owner())); + SASSERT(!m_manager.is_or(n->get_owner())); enode * n_prime; void * t = get_table(n); switch (static_cast(GET_TAG(t))) { diff --git a/src/tactic/core/solve_eqs_tactic.cpp b/src/tactic/core/solve_eqs_tactic.cpp index 623f83db4..1fc0c1f41 100644 --- a/src/tactic/core/solve_eqs_tactic.cpp +++ b/src/tactic/core/solve_eqs_tactic.cpp @@ -22,6 +22,9 @@ Revision History: #include "ast/ast_util.h" #include "ast/ast_pp.h" #include "ast/pb_decl_plugin.h" +#include "ast/rewriter/th_rewriter.h" +#include "ast/rewriter/rewriter_def.h" +#include "ast/rewriter/hoist_rewriter.h" #include "tactic/goal_shared_occs.h" #include "tactic/tactical.h" #include "tactic/generic_model_converter.h" @@ -574,27 +577,36 @@ class solve_eqs_tactic : public tactic { } else if (m().is_or(f)) { flatten_or(f, args); - //std::cout << "hoist or " << args.size() << "\n"; for (unsigned i = 0; i < args.size(); ++i) { path.push_back(nnf_context(false, args, i)); hoist_nnf(g, args.get(i), path, idx, depth + 1); path.pop_back(); } } - else { - // std::cout << "no hoist " << mk_pp(f, m()) << "\n"; - } } - bool collect_hoist(goal const& g) { - bool change = false; + void collect_hoist(goal const& g) { unsigned size = g.size(); vector path; for (unsigned idx = 0; idx < size; idx++) { checkpoint(); hoist_nnf(g, g.form(idx), path, idx, 0); } - return change; + } + + void distribute_and_or(goal & g) { + unsigned size = g.size(); + hoist_rewriter_star rw(m()); + th_rewriter thrw(m()); + expr_ref tmp(m()), tmp2(m()); + for (unsigned idx = 0; idx < size; idx++) { + checkpoint(); + expr* f = g.form(idx); + thrw(f, tmp); + rw(tmp, tmp2); + g.update(idx, tmp2); + } + } void sort_vars() { @@ -918,6 +930,9 @@ class solve_eqs_tactic : public tactic { m_subst = alloc(expr_substitution, m(), m_produce_unsat_cores, m_produce_proofs); m_norm_subst = alloc(expr_substitution, m(), m_produce_unsat_cores, m_produce_proofs); while (true) { + if (m_context_solve) { + distribute_and_or(*(g.get())); + } collect_num_occs(*g); collect(*g); if (m_context_solve) {