3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-04 16:44:07 +00:00

use structured proof hints

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2022-05-28 09:37:41 -07:00
parent 7da9f12521
commit dd46224a1d
10 changed files with 233 additions and 63 deletions

View file

@ -1395,7 +1395,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
"version": "3.8.8"
}
},
"nbformat": 4,

View file

@ -19,8 +19,10 @@ Author:
Notes:
--*/
#include "sat_solver.h"
#include "sat_drat.h"
#include "util/rational.h"
#include "sat/sat_solver.h"
#include "sat/sat_drat.h"
namespace sat {
@ -137,13 +139,13 @@ namespace sat {
}
}
buffer[len++] = '0';
if (st.get_pragma()) {
if (st.get_hint()) {
buffer[len++] = ' ';
buffer[len++] = 'p';
buffer[len++] = ' ';
char const* ps = st.get_pragma();
while (*ps)
buffer[len++] = *ps++;
auto* ps = st.get_hint();
for (auto ch : ps->to_string())
buffer[len++] = ch;
}
buffer[len++] = '\n';
m_out->write(buffer, len);
@ -905,6 +907,106 @@ namespace sat {
if (!st.is_sat())
out << " " << p.th(st.get_th());
return out;
}
}
std::string proof_hint::to_string() const {
std::ostringstream ous;
switch (m_ty) {
case hint_type::null_h:
return std::string();
case hint_type::farkas_h:
ous << "farkas ";
break;
case hint_type::bound_h:
ous << "bound ";
break;
case hint_type::cut_h:
ous << "cut ";
break;
}
for (auto const& [q, l] : m_literals)
ous << rational(q) << " * " << l << " ";
for (auto const& [q, a, b] : m_eqs)
ous << rational(q) << " = " << a << " " << b << " ";
return ous.str();
}
proof_hint proof_hint::from_string(char const* s) {
proof_hint h;
h.m_ty = hint_type::null_h;
if (!s)
return h;
auto ws = [&]() {
while (*s == ' ' || *s == '\n' || *s == '\t')
++s;
};
auto parse_type = [&]() {
if (0 == strncmp(s, "farkas", 6)) {
h.m_ty = hint_type::farkas_h;
s += 6;
return true;
}
if (0 == strncmp(s, "bound", 5)) {
h.m_ty = hint_type::bound_h;
s += 5;
return true;
}
return false;
};
sbuffer<char> buff;
auto parse_coeff = [&]() {
buff.reset();
while (*s && *s != ' ') {
buff.push_back(*s);
++s;
}
buff.push_back(0);
return rational(buff.data());
};
auto parse_literal = [&]() {
rational r = parse_coeff();
if (!r.is_int())
return sat::null_literal;
if (r < 0)
return sat::literal((-r).get_unsigned(), true);
return sat::literal(r.get_unsigned(), false);
};
auto parse_coeff_literal = [&]() {
rational coeff = parse_coeff();
ws();
if (*s == '*') {
++s;
ws();
sat::literal lit = parse_literal();
h.m_literals.push_back(std::make_pair(coeff, lit));
return true;
}
if (*s == '=') {
++s;
ws();
unsigned a = parse_coeff().get_unsigned();
ws();
unsigned b = parse_coeff().get_unsigned();
h.m_eqs.push_back(std::make_tuple(coeff, a, b));
return true;
}
return false;
};
ws();
if (!parse_type())
return h;
ws();
while (*s) {
if (!parse_coeff_literal())
return h;
ws();
}
return h;
}
}

View file

@ -28,6 +28,7 @@ Revision History:
#include "util/stopwatch.h"
#include "util/symbol.h"
#include "util/sat_literal.h"
#include "util/rational.h"
class params_ref;
class reslimit;
@ -93,29 +94,44 @@ namespace sat {
};
enum class hint_type {
null_h,
farkas_h,
bound_h,
cut_h
};
struct proof_hint {
hint_type m_ty;
vector<std::pair<rational, literal>> m_literals;
vector<std::tuple<rational, unsigned, unsigned>> m_eqs;
std::string to_string() const;
static proof_hint from_string(char const* s);
};
class status {
public:
enum class st { input, asserted, redundant, deleted };
st m_st;
int m_orig;
char const* m_pragma;
const proof_hint* m_hint;
public:
status(st s, int o, char const* ps = nullptr) : m_st(s), m_orig(o), m_pragma(ps) {};
status(status const& s) : m_st(s.m_st), m_orig(s.m_orig), m_pragma(s.m_pragma) {}
status(status&& s) noexcept { m_st = st::asserted; m_orig = -1; std::swap(m_st, s.m_st); std::swap(m_orig, s.m_orig); std::swap(m_pragma, s.m_pragma); }
status(st s, int o, proof_hint const* ps = nullptr) : m_st(s), m_orig(o), m_hint(ps) {};
status(status const& s) : m_st(s.m_st), m_orig(s.m_orig), m_hint(s.m_hint) {}
status(status&& s) noexcept { m_st = st::asserted; m_orig = -1; std::swap(m_st, s.m_st); std::swap(m_orig, s.m_orig); std::swap(m_hint, s.m_hint); }
status& operator=(status const& other) { m_st = other.m_st; m_orig = other.m_orig; return *this; }
static status redundant() { return status(status::st::redundant, -1); }
static status asserted() { return status(status::st::asserted, -1); }
static status deleted() { return status(status::st::deleted, -1); }
static status input() { return status(status::st::input, -1); }
static status th(bool redundant, int id, char const* ps = nullptr) { return status(redundant ? st::redundant : st::asserted, id, ps); }
static status th(bool redundant, int id, proof_hint const* ps = nullptr) { return status(redundant ? st::redundant : st::asserted, id, ps); }
bool is_input() const { return st::input == m_st; }
bool is_redundant() const { return st::redundant == m_st; }
bool is_asserted() const { return st::asserted == m_st; }
bool is_deleted() const { return st::deleted == m_st; }
char const* get_pragma() const { return m_pragma; }
proof_hint const* get_hint() const { return m_hint; }
bool is_sat() const { return -1 == m_orig; }
int get_th() const { return m_orig; }

View file

@ -233,46 +233,55 @@ namespace arith {
SASSERT(b1.get_var() == b2.get_var());
if (k1 == k2 && kind1 == kind2) return;
SASSERT(k1 != k2 || kind1 != kind2);
char const* bound_params = "farkas 1 1";
auto bin_clause = [&](sat::literal l1, sat::literal l2) {
sat::proof_hint* bound_params = nullptr;
if (ctx.use_drat()) {
bound_params = &m_farkas2;
m_farkas2.m_literals[0] = std::make_pair(rational(1), l1);
m_farkas2.m_literals[1] = std::make_pair(rational(1), l2);
}
add_clause(l1, l2, bound_params);
};
if (kind1 == lp_api::lower_t) {
if (kind2 == lp_api::lower_t) {
if (k2 <= k1)
add_clause(~l1, l2, bound_params);
bin_clause(~l1, l2);
else
add_clause(l1, ~l2, bound_params);
bin_clause(l1, ~l2);
}
else if (k1 <= k2)
// k1 <= k2, k1 <= x or x <= k2
add_clause(l1, l2);
bin_clause(l1, l2);
else {
// k1 > hi_inf, k1 <= x => ~(x <= hi_inf)
add_clause(~l1, ~l2, bound_params);
bin_clause(~l1, ~l2);
if (v_is_int && k1 == k2 + rational(1))
// k1 <= x or x <= k1-1
add_clause(l1, l2, bound_params);
bin_clause(l1, l2);
}
}
else if (kind2 == lp_api::lower_t) {
if (k1 >= k2)
// k1 >= lo_inf, k1 >= x or lo_inf <= x
add_clause(l1, l2, bound_params);
bin_clause(l1, l2);
else {
// k1 < k2, k2 <= x => ~(x <= k1)
add_clause(~l1, ~l2, bound_params);
bin_clause(~l1, ~l2);
if (v_is_int && k1 == k2 - rational(1))
// x <= k1 or k1+l <= x
add_clause(l1, l2, bound_params);
bin_clause(l1, l2);
}
}
else {
// kind1 == A_UPPER, kind2 == A_UPPER
if (k1 >= k2)
// k1 >= k2, x <= k2 => x <= k1
add_clause(l1, ~l2, bound_params);
bin_clause(l1, ~l2);
else
// k1 <= hi_sup , x <= k1 => x <= hi_sup
add_clause(~l1, l2, bound_params);
bin_clause(~l1, l2);
}
}

View file

@ -80,16 +80,38 @@ namespace arith {
if (m_nla) m_nla->collect_statistics(st);
}
char const* solver::bounds_pragma() {
/**
* Assumption:
* A bound literal ax <= b is explained by a set of weighted literals
* r1*(a1*x <= b1) + .... + r_k*(a_k*x <= b_k), where r_i > 0
* such that there is a r >= 1
* (r1*a1+..+r_k*a_k) = r*a, (r1*b1+..+r_k*b_k) <= r*b
*/
sat::proof_hint const* solver::explain(sat::hint_type ty) {
if (!ctx.use_drat())
return nullptr;
m_bounds_pragma.clear();
m_bounds_pragma += "bounds ";
for (sat::literal c : m_core) {
if (c.sign()) m_bounds_pragma += "-";
m_bounds_pragma += std::to_string(c.var());
m_bounds_pragma += " ";
m_bounds_pragma.m_ty = ty;
m_bounds_pragma.m_literals.reset();
m_bounds_pragma.m_eqs.reset();
for (auto ev : m_explanation) {
auto idx = ev.ci();
if (UINT_MAX == idx)
continue;
switch (m_constraint_sources[idx]) {
case inequality_source: {
literal lit = m_inequalities[idx];
m_bounds_pragma.m_literals.push_back({ev.coeff(), lit});
break;
}
case equality_source: {
auto [u, v] = m_equalities[idx];
m_bounds_pragma.m_eqs.push_back({ev.coeff(), u->get_expr_id(), v->get_expr_id()});
break;
}
default:
break;
}
}
return m_bounds_pragma.c_str();
return &m_bounds_pragma;
}
}

View file

@ -39,6 +39,8 @@ namespace arith {
lp().settings().set_random_seed(get_config().m_random_seed);
m_lia = alloc(lp::int_solver, *m_solver.get());
m_farkas2.m_ty = sat::hint_type::farkas_h;
m_farkas2.m_literals.resize(2);
}
solver::~solver() {
@ -195,7 +197,13 @@ namespace arith {
reset_evidence();
m_core.push_back(lit1);
TRACE("arith", tout << lit2 << " <- " << m_core << "\n";);
assign(lit2, m_core, m_eqs, "farkas 1 1");
sat::proof_hint* ph = nullptr;
if (ctx.use_drat()) {
ph = &m_farkas2;
m_farkas2.m_literals[0] = std::make_pair(rational(1), lit1);
m_farkas2.m_literals[1] = std::make_pair(rational(1), lit2);
}
assign(lit2, m_core, m_eqs, ph);
++m_stats.m_bounds_propagations;
}
@ -255,7 +263,7 @@ namespace arith {
TRACE("arith", for (auto lit : m_core) tout << lit << ": " << s().value(lit) << "\n";);
DEBUG_CODE(for (auto lit : m_core) { VERIFY(s().value(lit) == l_true); });
++m_stats.m_bound_propagations1;
assign(lit, m_core, m_eqs, bounds_pragma());
assign(lit, m_core, m_eqs, explain(sat::hint_type::bound_h));
}
if (should_refine_bounds() && first)
@ -370,7 +378,7 @@ namespace arith {
reset_evidence();
m_explanation.clear();
lp().explain_implied_bound(be, m_bp);
assign(bound, m_core, m_eqs, nullptr);
assign(bound, m_core, m_eqs, explain(sat::hint_type::bound_h));
}
@ -1169,7 +1177,7 @@ namespace arith {
app_ref b = mk_bound(m_lia->get_term(), m_lia->get_offset(), !m_lia->is_upper());
IF_VERBOSE(4, verbose_stream() << "cut " << b << "\n");
literal lit = expr2literal(b);
assign(lit, m_core, m_eqs, nullptr);
assign(lit, m_core, m_eqs, explain(sat::hint_type::cut_h));
lia_check = l_false;
break;
}
@ -1191,7 +1199,7 @@ namespace arith {
return lia_check;
}
void solver::assign(literal lit, literal_vector const& core, svector<enode_pair> const& eqs, char const* pma) {
void solver::assign(literal lit, literal_vector const& core, svector<enode_pair> const& eqs, sat::proof_hint const* pma) {
if (core.size() < small_lemma_size() && eqs.empty()) {
m_core2.reset();
for (auto const& c : core)
@ -1238,7 +1246,7 @@ namespace arith {
for (literal& c : m_core)
c.neg();
add_clause(m_core);
add_clause(m_core, explain(sat::hint_type::farkas_h));
}
bool solver::is_infeasible() const {

View file

@ -414,13 +414,14 @@ namespace arith {
void set_conflict();
void set_conflict_or_lemma(literal_vector const& core, bool is_conflict);
void set_evidence(lp::constraint_index idx);
void assign(literal lit, literal_vector const& core, svector<enode_pair> const& eqs, char const* pma);
void assign(literal lit, literal_vector const& core, svector<enode_pair> const& eqs, sat::proof_hint const* pma);
void false_case_of_check_nla(const nla::lemma& l);
void dbg_finalize_model(model& mdl);
std::string m_bounds_pragma;
char const* bounds_pragma();
sat::proof_hint m_bounds_pragma;
sat::proof_hint m_farkas2;
sat::proof_hint const* explain(sat::hint_type ty);
public:

View file

@ -125,7 +125,7 @@ namespace euf {
pop_core(n);
}
sat::status th_euf_solver::mk_status(char const* ps) {
sat::status th_euf_solver::mk_status(sat::proof_hint const* ps) {
return sat::status::th(m_is_redundant, get_id(), ps);
}
@ -149,7 +149,7 @@ namespace euf {
return add_clause(2, lits);
}
bool th_euf_solver::add_clause(sat::literal a, sat::literal b, char const* ps) {
bool th_euf_solver::add_clause(sat::literal a, sat::literal b, sat::proof_hint const* ps) {
sat::literal lits[2] = { a, b };
return add_clause(2, lits, ps);
}
@ -164,7 +164,7 @@ namespace euf {
return add_clause(4, lits);
}
bool th_euf_solver::add_clause(unsigned n, sat::literal* lits, char const* ps) {
bool th_euf_solver::add_clause(unsigned n, sat::literal* lits, sat::proof_hint const* ps) {
bool was_true = false;
for (unsigned i = 0; i < n; ++i)
was_true |= is_true(lits[i]);
@ -226,11 +226,11 @@ namespace euf {
return ctx.s().rand()();
}
size_t th_explain::get_obj_size(unsigned num_lits, unsigned num_eqs, char const* pma) {
return sat::constraint_base::obj_size(sizeof(th_explain) + sizeof(sat::literal) * num_lits + sizeof(enode_pair) * num_eqs + (pma?strlen(pma)+1:1));
size_t th_explain::get_obj_size(unsigned num_lits, unsigned num_eqs, sat::proof_hint const* pma) {
return sat::constraint_base::obj_size(sizeof(th_explain) + sizeof(sat::literal) * num_lits + sizeof(enode_pair) * num_eqs + (pma?pma->to_string().length()+1:1));
}
th_explain::th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& p, char const* pma) {
th_explain::th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& p, sat::proof_hint const* pma) {
m_consequent = c;
m_eq = p;
m_num_literals = n_lits;
@ -246,23 +246,26 @@ namespace euf {
m_eqs[i] = eqs[i];
base_ptr += sizeof(enode_pair) * n_eqs;
m_pragma = reinterpret_cast<char*>(base_ptr);
for (i = 0; pma && pma[i]; ++i)
m_pragma[i] = pma[i];
if (pma) {
std::string s = pma->to_string();
for (i = 0; s[i]; ++i)
m_pragma[i] = s[i];
}
m_pragma[i] = 0;
}
th_explain* th_explain::mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y, char const* pma) {
th_explain* th_explain::mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y, sat::proof_hint const* pma) {
region& r = th.ctx.get_region();
void* mem = r.allocate(get_obj_size(n_lits, n_eqs, pma));
sat::constraint_base::initialize(mem, &th);
return new (sat::constraint_base::ptr2mem(mem)) th_explain(n_lits, lits, n_eqs, eqs, c, enode_pair(x, y));
}
th_explain* th_explain::propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, sat::literal consequent, char const* pma) {
th_explain* th_explain::propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, sat::literal consequent, sat::proof_hint const* pma) {
return mk(th, lits.size(), lits.data(), eqs.size(), eqs.data(), consequent, nullptr, nullptr, pma);
}
th_explain* th_explain::propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, char const* pma) {
th_explain* th_explain::propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, sat::proof_hint const* pma) {
return mk(th, lits.size(), lits.data(), eqs.size(), eqs.data(), sat::null_literal, x, y, pma);
}

View file

@ -143,16 +143,16 @@ namespace euf {
region& get_region();
sat::status mk_status(char const* ps = nullptr);
sat::status mk_status(sat::proof_hint const* ps = nullptr);
bool add_unit(sat::literal lit);
bool add_units(sat::literal_vector const& lits);
bool add_clause(sat::literal lit) { return add_unit(lit); }
bool add_clause(sat::literal a, sat::literal b);
bool add_clause(sat::literal a, sat::literal b, char const* ps);
bool add_clause(sat::literal a, sat::literal b, sat::proof_hint const* ps);
bool add_clause(sat::literal a, sat::literal b, sat::literal c);
bool add_clause(sat::literal a, sat::literal b, sat::literal c, sat::literal d);
bool add_clause(sat::literal_vector const& lits, char const* ps = nullptr) { return add_clause(lits.size(), lits.data(), ps); }
bool add_clause(unsigned n, sat::literal* lits, char const* ps = nullptr);
bool add_clause(sat::literal_vector const& lits, sat::proof_hint const* ps = nullptr) { return add_clause(lits.size(), lits.data(), ps); }
bool add_clause(unsigned n, sat::literal* lits, sat::proof_hint const* ps = nullptr);
void add_equiv(sat::literal a, sat::literal b);
void add_equiv_and(sat::literal a, sat::literal_vector const& bs);
@ -221,9 +221,9 @@ namespace euf {
sat::literal* m_literals;
enode_pair* m_eqs;
char* m_pragma;
static size_t get_obj_size(unsigned num_lits, unsigned num_eqs, char const* pma);
th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& eq, char const* pma = nullptr);
static th_explain* mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y, char const* pma = nullptr);
static size_t get_obj_size(unsigned num_lits, unsigned num_eqs, sat::proof_hint const* pma);
th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& eq, sat::proof_hint const* pma = nullptr);
static th_explain* mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y, sat::proof_hint const* pma = nullptr);
public:
static th_explain* conflict(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs);
@ -234,8 +234,8 @@ namespace euf {
static th_explain* conflict(th_euf_solver& th, sat::literal lit, euf::enode* x, euf::enode* y);
static th_explain* conflict(th_euf_solver& th, euf::enode* x, euf::enode* y);
static th_explain* propagate(th_euf_solver& th, sat::literal lit, euf::enode* x, euf::enode* y);
static th_explain* propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, sat::literal consequent, char const* pma = nullptr);
static th_explain* propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, char const* pma = nullptr);
static th_explain* propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, sat::literal consequent, sat::proof_hint const* pma = nullptr);
static th_explain* propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y, sat::proof_hint const* pma = nullptr);
sat::ext_constraint_idx to_index() const {
return sat::constraint_base::mem2base(this);
@ -270,7 +270,7 @@ namespace euf {
enode_pair eq_consequent() const { return m_eq; }
char const* get_pragma() const { return *m_pragma ? m_pragma : nullptr; }
sat::proof_hint const* get_pragma() const { return nullptr; } //*m_pragma ? m_pragma : nullptr; }
};

View file

@ -95,6 +95,7 @@ namespace sat {
inline bool operator!=(literal const & l1, literal const & l2) { return l1.m_val != l2.m_val; }
inline std::ostream & operator<<(std::ostream & out, sat::literal l) { if (l == sat::null_literal) out << "null"; else out << (l.sign() ? "-" : "") << l.var(); return out; }
typedef svector<literal> literal_vector;
@ -192,3 +193,11 @@ namespace sat {
}
};
namespace std {
inline std::string to_string(sat::literal l) {
if (l.sign()) return "-" + to_string(l.var());
return to_string(l.var());
}
};