diff --git a/src/api/api_solver.cpp b/src/api/api_solver.cpp index 6dc41efb2..c6eae7488 100644 --- a/src/api/api_solver.cpp +++ b/src/api/api_solver.cpp @@ -477,4 +477,60 @@ extern "C" { Z3_CATCH_RETURN(Z3_L_UNDEF); } + Z3_ast Z3_API Z3_solver_lookahead(Z3_context c, + Z3_solver s, + Z3_ast_vector candidates) { + Z3_TRY; + LOG_Z3_solver_lookahead(c, s, candidates); + ast_manager& m = mk_c(c)->m(); + expr_ref_vector _candidates(m); + ast_ref_vector const& __candidates = to_ast_vector_ref(candidates); + for (auto & e : __candidates) { + if (!is_expr(e)) { + SET_ERROR_CODE(Z3_INVALID_USAGE); + return 0; + } + _candidates.push_back(to_expr(e)); + } + + expr_ref result(m); + unsigned timeout = to_solver(s)->m_params.get_uint("timeout", mk_c(c)->get_timeout()); + unsigned rlimit = to_solver(s)->m_params.get_uint("rlimit", mk_c(c)->get_rlimit()); + bool use_ctrl_c = to_solver(s)->m_params.get_bool("ctrl_c", false); + cancel_eh eh(mk_c(c)->m().limit()); + api::context::set_interruptable si(*(mk_c(c)), eh); + { + scoped_ctrl_c ctrlc(eh, false, use_ctrl_c); + scoped_timer timer(timeout, &eh); + scoped_rlimit _rlimit(mk_c(c)->m().limit(), rlimit); + try { + result = to_solver_ref(s)->lookahead(_candidates); + } + catch (z3_exception & ex) { + mk_c(c)->handle_exception(ex); + return 0; + } + } + mk_c(c)->save_ast_trail(result); + RETURN_Z3(of_ast(result)); + Z3_CATCH_RETURN(0); + } + + Z3_ast_vector Z3_API Z3_solver_get_lemmas(Z3_context c, Z3_solver s) { + Z3_TRY; + LOG_Z3_solver_get_lemmas(c, s); + RESET_ERROR_CODE(); + ast_manager& m = mk_c(c)->m(); + init_solver(c, s); + Z3_ast_vector_ref * v = alloc(Z3_ast_vector_ref, *mk_c(c), mk_c(c)->m()); + mk_c(c)->save_object(v); + expr_ref_vector lemmas(m); + to_solver_ref(s)->get_lemmas(lemmas); + for (expr* e : lemmas) { + v->m_ast_vector.push_back(e); + } + RETURN_Z3(of_ast_vector(v)); + Z3_CATCH_RETURN(0); + } + }; diff --git a/src/api/dotnet/Solver.cs b/src/api/dotnet/Solver.cs index dff2677df..6dd19cddf 100644 --- a/src/api/dotnet/Solver.cs +++ b/src/api/dotnet/Solver.cs @@ -252,6 +252,29 @@ namespace Microsoft.Z3 return lboolToStatus(r); } + /// + /// Select a lookahead literal from the set of supplied candidates. + /// + public BoolExpr Lookahead(IEnumerable candidates) + { + ASTVector cands = new ASTVector(Context); + foreach (var c in candidates) cands.Push(c); + return (BoolExpr)Expr.Create(Context, Native.Z3_solver_lookahead(Context.nCtx, NativeObject, cands.NativeObject)); + } + + /// + /// Retrieve set of lemmas that have been inferred by solver. + /// + public BoolExpr[] Lemmas + { + get + { + var r = Native.Z3_solver_get_lemmas(Context.nCtx, NativeObject); + var v = new ASTVector(Context, r); + return v.ToBoolExprArray(); + } + } + /// /// The model of the last Check. /// diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 84a80ddf7..7e93e58ae 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -6216,6 +6216,21 @@ class Solver(Z3PPObject): consequences = [ consequences[i] for i in range(sz) ] return CheckSatResult(r), consequences + def lemmas(self): + """Extract auxiliary lemmas produced by solver""" + return AstVector(Z3_solver_get_lemmas(self.ctx.ref(), self.solver), self.ctx) + + def lookahead(self, candidates = None): + """Get lookahead literal""" + if candidates is None: + candidates = AstVector(None, self.ctx) + elif not isinstance(candidates, AstVector): + _cs = AstVector(None, self.ctx) + for c in candidates: + _asms.push(c) + candidates = _cs + return _to_expr_ref(Z3_solver_lookahead(self.ctx.ref(), self.solver, candidates), self.ctx) + def proof(self): """Return a proof for the last `check()`. Proof construction must be enabled.""" return _to_expr_ref(Z3_solver_get_proof(self.ctx.ref(), self.solver), self.ctx) diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 45065f856..e3bd942bd 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -6023,6 +6023,29 @@ extern "C" { Z3_ast_vector assumptions, Z3_ast_vector variables, Z3_ast_vector consequences); + + /** + \brief select a literal from the list of candidate propositional variables to split on. + If the candidate list is empty, then the solver chooses a formula based on its internal state. + + def_API('Z3_solver_lookahead', AST, (_in(CONTEXT), _in(SOLVER), _in(AST_VECTOR))) + */ + + Z3_ast Z3_API Z3_solver_lookahead(Z3_context c, Z3_solver s, Z3_ast_vector candidates); + + + /** + \brief retrieve lemmas from solver state. Lemmas are auxiliary unit literals, + binary clauses and other learned clauses that are below a minimal glue level. + Lemmas that have been retrieved in a previous call may be suppressed from subsequent + calls. + + def_API('Z3_solver_get_lemmas', AST_VECTOR, (_in(CONTEXT), _in(SOLVER))) + */ + + Z3_ast_vector Z3_API Z3_solver_get_lemmas(Z3_context c, Z3_solver s); + + /** \brief Retrieve the model for the last #Z3_solver_check or #Z3_solver_check_assumptions @@ -6031,6 +6054,7 @@ extern "C" { def_API('Z3_solver_get_model', MODEL, (_in(CONTEXT), _in(SOLVER))) */ + Z3_model Z3_API Z3_solver_get_model(Z3_context c, Z3_solver s); /** diff --git a/src/sat/sat_lookahead.h b/src/sat/sat_lookahead.h index 7f01197f2..8f7a7675f 100644 --- a/src/sat/sat_lookahead.h +++ b/src/sat/sat_lookahead.h @@ -384,6 +384,7 @@ namespace sat { candidate(bool_var v, float r): m_var(v), m_rating(r) {} }; svector m_candidates; + uint_set m_select_lookahead_vars; float get_rating(bool_var v) const { return m_rating[v]; } float get_rating(literal l) const { return get_rating(l.var()); } @@ -468,7 +469,11 @@ namespace sat { for (bool_var const* it = m_freevars.begin(), * end = m_freevars.end(); it != end; ++it) { SASSERT(is_undef(*it)); bool_var x = *it; - if (newbies || active_prefix(x)) { + if (!m_select_lookahead_vars.empty() && m_select_lookahead_vars.contains(x)) { + m_candidates.push_back(candidate(x, m_rating[x])); + sum += m_rating[x]; + } + else if (newbies || active_prefix(x)) { m_candidates.push_back(candidate(x, m_rating[x])); sum += m_rating[x]; } @@ -1853,6 +1858,31 @@ namespace sat { return search(); } + literal select_lookahead(bool_var_vector const& vars) { + m_search_mode = lookahead_mode::searching; + scoped_level _sl(*this, c_fixed_truth); + init(); + if (inconsistent()) return null_literal; + inc_istamp(); + for (auto v : vars) { + m_select_lookahead_vars.insert(v); + } + literal l = choose(); + m_select_lookahead_vars.reset(); + if (inconsistent()) return null_literal; + + // assign unit literals that were found during search for lookahead. + unsigned num_assigned = 0; + for (literal lit : m_trail) { + if (!m_s.was_eliminated(lit.var()) && m_s.value(lit) != l_true) { + m_s.assign(lit, justification()); + ++num_assigned; + } + } + IF_VERBOSE(1, verbose_stream() << "(sat-lookahead :units " << num_assigned << ")\n";); + return l; + } + /** \brief simplify set of clauses by extracting units from a lookahead at base level. */ diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 8095ee12d..fd26bdd70 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -809,6 +809,16 @@ namespace sat { return r; } + literal solver::select_lookahead(bool_var_vector const& vars) { + lookahead lh(*this); + literal result = lh.select_lookahead(vars); + if (result == null_literal) { + set_conflict(justification()); + } + // extract unit literals from lh + return result; + } + // ----------------------- // // Search diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index c512597b0..54ea360e4 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -352,6 +352,8 @@ namespace sat { model_converter const & get_model_converter() const { return m_mc; } void set_model(model const& mdl); + literal select_lookahead(bool_var_vector const& vars); + protected: unsigned m_conflicts; unsigned m_restarts; diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 5016d20cd..225fd8d63 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -255,6 +255,28 @@ public: return 0; } + virtual expr_ref lookahead(expr_ref_vector const& candidates) { + sat::bool_var_vector vars; + u_map var2candidate; + for (auto c : candidates) { + // TBD: check membership + sat::bool_var v = m_map.to_bool_var(c); + SASSERT(v != sat::null_bool_var); + vars.push_back(v); + var2candidate.insert(v, c); + } + sat::literal l = m_solver.select_lookahead(vars); + if (l == sat::null_literal) { + return expr_ref(m.mk_true(), m); + } + expr* e; + if (!var2candidate.find(l.var(), e)) { + // TBD: if candidate set is empty, then do something else. + e = m.mk_true(); + } + return expr_ref(l.sign() ? m.mk_not(e) : e, m); + } + virtual lbool get_consequences_core(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq) { init_preprocess(); TRACE("sat", tout << assumptions << "\n" << vars << "\n";); diff --git a/src/solver/solver.cpp b/src/solver/solver.cpp index 9163cfeda..59b950972 100644 --- a/src/solver/solver.cpp +++ b/src/solver/solver.cpp @@ -162,3 +162,9 @@ bool solver::is_literal(ast_manager& m, expr* e) { return is_uninterp_const(e) || (m.is_not(e, e) && is_uninterp_const(e)); } +expr_ref solver::lookahead(expr_ref_vector const& candidates) { + ast_manager& m = candidates.get_manager(); + return expr_ref(m.mk_true(), m); +} + + diff --git a/src/solver/solver.h b/src/solver/solver.h index 6b9d38f29..51bff08ad 100644 --- a/src/solver/solver.h +++ b/src/solver/solver.h @@ -172,6 +172,17 @@ public: */ virtual lbool preferred_sat(expr_ref_vector const& asms, vector& cores); + /** + \brief extract a lookahead candidates for branching. + */ + + virtual expr_ref lookahead(expr_ref_vector const& candidates); + + /** + \brief extract learned lemmas. + */ + virtual void get_lemmas(expr_ref_vector& lemmas) {} + /** \brief Display the content of this solver. */ diff --git a/src/tactic/portfolio/bounded_int2bv_solver.cpp b/src/tactic/portfolio/bounded_int2bv_solver.cpp index 83693abba..cee8688e5 100644 --- a/src/tactic/portfolio/bounded_int2bv_solver.cpp +++ b/src/tactic/portfolio/bounded_int2bv_solver.cpp @@ -149,6 +149,7 @@ public: virtual void set_reason_unknown(char const* msg) { m_solver->set_reason_unknown(msg); } virtual void get_labels(svector & r) { m_solver->get_labels(r); } virtual ast_manager& get_manager() const { return m; } + virtual expr_ref lookahead(expr_ref_vector& candidates) { flush_assertions(); return m_solver->lookahead(candidates); } virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return m_solver->find_mutexes(vars, mutexes); } virtual lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) { flush_assertions(); diff --git a/src/tactic/portfolio/enum2bv_solver.cpp b/src/tactic/portfolio/enum2bv_solver.cpp index 35601f374..ef7ee6cd8 100644 --- a/src/tactic/portfolio/enum2bv_solver.cpp +++ b/src/tactic/portfolio/enum2bv_solver.cpp @@ -97,6 +97,7 @@ public: virtual void get_labels(svector & r) { m_solver->get_labels(r); } virtual ast_manager& get_manager() const { return m; } virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return m_solver->find_mutexes(vars, mutexes); } + virtual expr_ref lookahead(expr_ref_vector& candidates) { return m_solver->lookahead(candidates); } virtual lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) { datatype_util dt(m); diff --git a/src/tactic/portfolio/pb2bv_solver.cpp b/src/tactic/portfolio/pb2bv_solver.cpp index c8aa82e97..673db03c0 100644 --- a/src/tactic/portfolio/pb2bv_solver.cpp +++ b/src/tactic/portfolio/pb2bv_solver.cpp @@ -93,6 +93,7 @@ public: virtual void set_reason_unknown(char const* msg) { m_solver->set_reason_unknown(msg); } virtual void get_labels(svector & r) { m_solver->get_labels(r); } virtual ast_manager& get_manager() const { return m; } + virtual expr_ref lookahead(expr_ref_vector& candidates) { flush_assertions(); return m_solver->lookahead(candidates); } virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { return m_solver->find_mutexes(vars, mutexes); } virtual lbool get_consequences_core(expr_ref_vector const& asms, expr_ref_vector const& vars, expr_ref_vector& consequences) { flush_assertions();