diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index c6eae7488..09e522db2 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -479,12 +479,14 @@ extern "C" { Z3_ast Z3_API Z3_solver_lookahead(Z3_context c, Z3_solver s, + Z3_ast_vector assumptions, Z3_ast_vector candidates) { Z3_TRY; - LOG_Z3_solver_lookahead(c, s, candidates); + LOG_Z3_solver_lookahead(c, s, assumptions, candidates); ast_manager& m = mk_c(c)->m(); - expr_ref_vector _candidates(m); + expr_ref_vector _candidates(m), _assumptions(m); ast_ref_vector const& __candidates = to_ast_vector_ref(candidates); + ast_ref_vector const& __assumptions = to_ast_vector_ref(assumptions); for (auto & e : __candidates) { if (!is_expr(e)) { SET_ERROR_CODE(Z3_INVALID_USAGE); @@ -492,6 +494,13 @@ extern "C" { } _candidates.push_back(to_expr(e)); } + for (auto & e : __assumptions) { + if (!is_expr(e)) { + SET_ERROR_CODE(Z3_INVALID_USAGE); + return 0; + } + _assumptions.push_back(to_expr(e)); + } expr_ref result(m); unsigned timeout = to_solver(s)->m_params.get_uint("timeout", mk_c(c)->get_timeout()); @@ -504,7 +513,7 @@ extern "C" { scoped_timer timer(timeout, &eh); scoped_rlimit _rlimit(mk_c(c)->m().limit(), rlimit); try { - result = to_solver_ref(s)->lookahead(_candidates); + result = to_solver_ref(s)->lookahead(_assumptions, _candidates); } catch (z3_exception & ex) { mk_c(c)->handle_exception(ex); diff --git a/src/api/dotnet/Solver.cs b/src/api/dotnet/Solver.cs index 6dd19cddf..078f5bc7a 100644 --- a/src/api/dotnet/Solver.cs +++ b/src/api/dotnet/Solver.cs @@ -255,11 +255,13 @@ namespace Microsoft.Z3 /// /// Select a lookahead literal from the set of supplied candidates. /// - public BoolExpr Lookahead(IEnumerable candidates) + public BoolExpr Lookahead(IEnumerable assumptions, IEnumerable candidates) { ASTVector cands = new ASTVector(Context); foreach (var c in candidates) cands.Push(c); - return (BoolExpr)Expr.Create(Context, Native.Z3_solver_lookahead(Context.nCtx, NativeObject, cands.NativeObject)); + ASTVector assums = new ASTVector(Context); + foreach (var c in assumptions) assums.Push(c); + return (BoolExpr)Expr.Create(Context, Native.Z3_solver_lookahead(Context.nCtx, NativeObject, assums.NativeObject, cands.NativeObject)); } /// diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 4ae7c5c22..317b32923 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -6028,10 +6028,10 @@ extern "C" { \brief select a literal from the list of candidate propositional variables to split on. If the candidate list is empty, then the solver chooses a formula based on its internal state. - def_API('Z3_solver_lookahead', AST, (_in(CONTEXT), _in(SOLVER), _in(AST_VECTOR))) + def_API('Z3_solver_lookahead', AST, (_in(CONTEXT), _in(SOLVER), _in(AST_VECTOR), _in(AST_VECTOR))) */ - Z3_ast Z3_API Z3_solver_lookahead(Z3_context c, Z3_solver s, Z3_ast_vector candidates); + Z3_ast Z3_API Z3_solver_lookahead(Z3_context c, Z3_solver s, Z3_ast_vector assumptions, Z3_ast_vector candidates); /** diff --git a/src/opt/opt_solver.h b/src/opt/opt_solver.h index 24d0408fc..cef270abc 100644 --- a/src/opt/opt_solver.h +++ b/src/opt/opt_solver.h @@ -107,7 +107,7 @@ namespace opt { virtual ast_manager& get_manager() const { return m; } virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes); virtual lbool preferred_sat(expr_ref_vector const& asms, vector& cores); - virtual expr_ref lookahead(expr_ref_vector const& candidates) { return expr_ref(m.mk_true(), m); } + virtual expr_ref lookahead(expr_ref_vector const& assumptions, expr_ref_vector const& candidates) { return expr_ref(m.mk_true(), m); } void set_logic(symbol const& logic); smt::theory_var add_objective(app* term); diff --git a/src/sat/sat_lookahead.cpp b/src/sat/sat_lookahead.cpp index 97b4ef073..eeeff2c11 100644 --- a/src/sat/sat_lookahead.cpp +++ b/src/sat/sat_lookahead.cpp @@ -30,6 +30,18 @@ namespace sat { if (p.m_s.m_ext) p.m_s.m_ext->set_lookahead(0); } + lookahead::scoped_assumptions::scoped_assumptions(lookahead& p, literal_vector const& lits): p(p), lits(lits) { + for (auto l : lits) { + p.push(l, p.c_fixed_truth); + } + } + lookahead::scoped_assumptions::~scoped_assumptions() { + for (auto l : lits) { + p.pop(); + } + } + + void lookahead::flip_prefix() { if (m_trail_lim.size() < 64) { uint64 mask = (1ull << m_trail_lim.size()); @@ -1612,7 +1624,7 @@ namespace sat { } - literal lookahead::select_lookahead(bool_var_vector const& vars) { + literal lookahead::select_lookahead(literal_vector const& assumptions, bool_var_vector const& vars) { IF_VERBOSE(1, verbose_stream() << "(sat-select " << vars.size() << ")\n";); scoped_ext _sext(*this); m_search_mode = lookahead_mode::searching; @@ -1623,19 +1635,25 @@ namespace sat { for (auto v : vars) { m_select_lookahead_vars.insert(v); } + + scoped_assumptions _sa(*this, assumptions); literal l = choose(); m_select_lookahead_vars.reset(); - if (inconsistent()) return null_literal; + if (inconsistent()) l = null_literal; +#if 0 // assign unit literals that were found during search for lookahead. - unsigned num_assigned = 0; - for (literal lit : m_trail) { - if (!m_s.was_eliminated(lit.var()) && m_s.value(lit) != l_true) { - m_s.assign(lit, justification()); - ++num_assigned; + if (assumptions.empty()) { + unsigned num_assigned = 0; + for (literal lit : m_trail) { + if (!m_s.was_eliminated(lit.var()) && m_s.value(lit) != l_true) { + m_s.assign(lit, justification()); + ++num_assigned; + } } + IF_VERBOSE(1, verbose_stream() << "(sat-lookahead :units " << num_assigned << ")\n";); } - IF_VERBOSE(1, verbose_stream() << "(sat-lookahead :units " << num_assigned << ")\n";); +#endif return l; } diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index 028e94786..2b8ddc594 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -199,6 +199,14 @@ namespace sat { ~scoped_ext(); }; + class scoped_assumptions { + lookahead& p; + literal_vector lits; + public: + scoped_assumptions(lookahead& p, literal_vector const& lits); + ~scoped_assumptions(); + }; + // ------------------------------------- // prefix updates. I use low order bits. @@ -447,7 +455,7 @@ namespace sat { return search(); } - literal select_lookahead(bool_var_vector const& vars); + literal select_lookahead(literal_vector const& assumptions, bool_var_vector const& vars); /** \brief simplify set of clauses by extracting units from a lookahead at base level. */ diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 1f8f3aa19..77a816660 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -811,14 +811,9 @@ namespace sat { return r; } - literal solver::select_lookahead(bool_var_vector const& vars) { + literal solver::select_lookahead(literal_vector const& assumptions, bool_var_vector const& vars) { lookahead lh(*this); - literal result = lh.select_lookahead(vars); - if (result == null_literal) { - set_conflict(justification()); - } - // extract unit literals from lh - return result; + return lh.select_lookahead(assumptions, vars); } // ----------------------- @@ -851,8 +846,8 @@ namespace sat { } #endif try { - if (inconsistent()) return l_false; init_search(); + if (inconsistent()) return l_false; propagate(false); if (inconsistent()) return l_false; init_assumptions(num_lits, lits); diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 348e1cd22..b6dffe511 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -355,7 +355,7 @@ namespace sat { void set_model(model const& mdl); char const* get_reason_unknown() const { return m_reason_unknown.c_str(); } - literal select_lookahead(bool_var_vector const& vars); + literal select_lookahead(literal_vector const& assumptions, bool_var_vector const& vars); protected: unsigned m_conflicts; diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 3fa9f4c81..9d9c24b8f 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -148,6 +148,8 @@ public: virtual lbool check_sat(unsigned sz, expr * const * assumptions) { m_solver.pop_to_base_level(); + m_core.reset(); + if (m_solver.inconsistent()) return l_false; expr_ref_vector _assumptions(m); obj_map asm2fml; for (unsigned i = 0; i < sz; ++i) { @@ -280,9 +282,10 @@ public: return 0; } - virtual expr_ref lookahead(expr_ref_vector const& candidates) { + virtual expr_ref lookahead(expr_ref_vector const& assumptions, expr_ref_vector const& candidates) { IF_VERBOSE(1, verbose_stream() << "(sat-lookahead " << candidates.size() << ")\n";); sat::bool_var_vector vars; + sat::literal_vector lits; expr_ref_vector lit2expr(m); lit2expr.resize(m_solver.num_vars() * 2); m_map.mk_inv(lit2expr); @@ -292,11 +295,29 @@ public: vars.push_back(v); } } + for (auto c : assumptions) { + SASSERT(is_literal(c)); + sat::bool_var v = sat::null_bool_var; + bool sign = false; + expr* e = c; + while (m.is_not(e, e)) { + sign = !sign; + } + if (is_uninterp_const(e)) { + v = m_map.to_bool_var(e); + } + if (v != sat::null_bool_var) { + lits.push_back(sat::literal(v, sign)); + } + else { + IF_VERBOSE(0, verbose_stream() << "WARNING: could not handle " << mk_pp(c, m) << "\n";); + } + } IF_VERBOSE(1, verbose_stream() << "vars: " << vars.size() << "\n";); if (vars.empty()) { return expr_ref(m.mk_true(), m); } - sat::literal l = m_solver.select_lookahead(vars); + sat::literal l = m_solver.select_lookahead(lits, vars); if (m_solver.inconsistent()) { IF_VERBOSE(1, verbose_stream() << "(sat-lookahead inconsistent)\n";); return expr_ref(m.mk_false(), m); @@ -715,7 +736,7 @@ private: if (asm2fml.contains(e)) { e = asm2fml.find(e); } - m_core.push_back(e); + m_core.push_back(e); } } diff --git a/src/smt/smt_solver.cpp b/src/smt/smt_solver.cpp index 2ca5a11c3..ef81fbfd2 100644 --- a/src/smt/smt_solver.cpp +++ b/src/smt/smt_solver.cpp @@ -219,7 +219,7 @@ namespace smt { return m_context.get_formulas()[idx]; } - virtual expr_ref lookahead(expr_ref_vector const& candidates) { + virtual expr_ref lookahead(expr_ref_vector const& assumptions, expr_ref_vector const& candidates) { ast_manager& m = get_manager(); return expr_ref(m.mk_true(), m); } diff --git a/src/solver/combined_solver.cpp b/src/solver/combined_solver.cpp index 5d6cd232f..b23aabc2d 100644 --- a/src/solver/combined_solver.cpp +++ b/src/solver/combined_solver.cpp @@ -274,8 +274,8 @@ public: return m_solver1->get_num_assumptions() + m_solver2->get_num_assumptions(); } - virtual expr_ref lookahead(expr_ref_vector const& candidates) { - return m_solver1->lookahead(candidates); + virtual expr_ref lookahead(expr_ref_vector const& assumptions, expr_ref_vector const& candidates) { + return m_solver1->lookahead(assumptions, candidates); } virtual expr * get_assumption(unsigned idx) const { diff --git a/src/solver/solver.h b/src/solver/solver.h index 56890f7c0..5346cf4a4 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -176,7 +176,7 @@ public: \brief extract a lookahead candidates for branching. */ - virtual expr_ref lookahead(expr_ref_vector const& candidates) = 0; + virtual expr_ref lookahead(expr_ref_vector const& assumptions, expr_ref_vector const& candidates) = 0; /** \brief extract learned lemmas. diff --git a/src/solver/tactic2solver.cpp b/src/solver/tactic2solver.cpp index f9d5a4b0f..d3d8e59ce 100644 --- a/src/solver/tactic2solver.cpp +++ b/src/solver/tactic2solver.cpp @@ -76,7 +76,7 @@ public: virtual ast_manager& get_manager() const; - virtual expr_ref lookahead(expr_ref_vector const& candidates) { + virtual expr_ref lookahead(expr_ref_vector const& assumptions, expr_ref_vector const& candidates) { ast_manager& m = get_manager(); std::cout << "tactic2solver\n"; return expr_ref(m.mk_true(), m); diff --git a/src/tactic/portfolio/bounded_int2bv_solver.cpp b/src/tactic/portfolio/bounded_int2bv_solver.cpp index aeaa88bc2..746856543 100644 --- a/src/tactic/portfolio/bounded_int2bv_solver.cpp +++ b/src/tactic/portfolio/bounded_int2bv_solver.cpp @@ -149,7 +149,7 @@ public: virtual void set_reason_unknown(char const* msg) { m_solver->set_reason_unknown(msg); } virtual void get_labels(svector & r) { m_solver->get_labels(r); } virtual ast_manager& get_manager() const { return m; } - virtual expr_ref lookahead(expr_ref_vector const& candidates) { flush_assertions(); return m_solver->lookahead(candidates); } + virtual expr_ref lookahead(expr_ref_vector const& assumptions, expr_ref_vector const& candidates) { flush_assertions(); return m_solver->lookahead(assumptions, candidates); } virtual void get_lemmas(expr_ref_vector & lemmas) { flush_assertions(); m_solver->get_lemmas(lemmas); } virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return m_solver->find_mutexes(vars, mutexes); } virtual lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) { diff --git a/src/tactic/portfolio/enum2bv_solver.cpp b/src/tactic/portfolio/enum2bv_solver.cpp index d50dee57c..67be432c2 100644 --- a/src/tactic/portfolio/enum2bv_solver.cpp +++ b/src/tactic/portfolio/enum2bv_solver.cpp @@ -97,7 +97,7 @@ public: virtual void get_labels(svector & r) { m_solver->get_labels(r); } virtual ast_manager& get_manager() const { return m; } virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return m_solver->find_mutexes(vars, mutexes); } - virtual expr_ref lookahead(expr_ref_vector const& candidates) { return m_solver->lookahead(candidates); } + virtual expr_ref lookahead(expr_ref_vector const& assumptions, expr_ref_vector const& candidates) { return m_solver->lookahead(assumptions, candidates); } virtual void get_lemmas(expr_ref_vector & lemmas) { m_solver->get_lemmas(lemmas); } virtual lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) { diff --git a/src/tactic/portfolio/pb2bv_solver.cpp b/src/tactic/portfolio/pb2bv_solver.cpp index 46a5912b0..27e2a5850 100644 --- a/src/tactic/portfolio/pb2bv_solver.cpp +++ b/src/tactic/portfolio/pb2bv_solver.cpp @@ -93,7 +93,7 @@ public: virtual void set_reason_unknown(char const* msg) { m_solver->set_reason_unknown(msg); } virtual void get_labels(svector & r) { m_solver->get_labels(r); } virtual ast_manager& get_manager() const { return m; } - virtual expr_ref lookahead(expr_ref_vector const& candidates) { flush_assertions(); return m_solver->lookahead(candidates); } + virtual expr_ref lookahead(expr_ref_vector const& assumptions, expr_ref_vector const& candidates) { flush_assertions(); return m_solver->lookahead(assumptions, candidates); } virtual void get_lemmas(expr_ref_vector & lemmas) { flush_assertions(); m_solver->get_lemmas(lemmas); } virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return m_solver->find_mutexes(vars, mutexes); } virtual lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) {