3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-24 09:35:32 +00:00

updates to solver interface and adding some saturation rules

This commit is contained in:
Nikolaj Bjorner 2023-12-17 18:16:47 -08:00
parent 172d0ea685
commit 21791f12bf
9 changed files with 178 additions and 206 deletions

View file

@ -304,23 +304,23 @@ namespace polysat {
}
}
dependency core::get_dependency(constraint_id idx) const {
auto [sc, d, value] = m_constraint_index[idx.id];
SASSERT(value != l_undef);
return value == l_false ? ~d : d;
}
dependency_vector core::get_dependencies(constraint_id_vector const& cc) {
dependency_vector result;
for (auto idx : cc) {
auto [sc, d, value] = m_constraint_index[idx.id];
SASSERT(value != l_undef);
result.push_back(value == l_false ? ~d : d);
}
for (auto idx : cc)
result.push_back(get_dependency(idx));
return result;
}
dependency_vector core::get_dependencies(std::initializer_list<constraint_id> const& cc) {
dependency_vector result;
for (auto idx : cc) {
auto [sc, d, value] = m_constraint_index[idx.id];
SASSERT(value != l_undef);
result.push_back(value == l_false ? ~d : d);
}
for (auto idx : cc)
result.push_back(get_dependency(idx));
return result;
}
@ -421,8 +421,12 @@ namespace polysat {
assign_eh(idx, false, 0);
}
void core::add_clause(char const* name, core_vector const& cs, bool is_redundant) {
s.add_polysat_clause(name, cs, is_redundant);
bool core::add_clause(char const* name, core_vector const& cs, bool is_redundant) {
for (auto e : cs)
if (std::holds_alternative<signed_constraint>(e) && eval(*std::get_if<signed_constraint>(&e)) == l_true)
return false;
return s.add_polysat_clause(name, cs, is_redundant);
}
signed_constraint core::get_constraint(constraint_id idx) {
@ -434,8 +438,7 @@ namespace polysat {
}
lbool core::eval(constraint_id id) {
auto sc = get_constraint(id);
return sc.eval(m_assignment);
return get_constraint(id).eval(m_assignment);
}
}

View file

@ -54,7 +54,7 @@ namespace polysat {
constraints m_constraints;
assignment m_assignment;
unsigned m_qhead = 0, m_vqhead = 0;
svector<constraint_id> m_prop_queue;
constraint_id_vector m_prop_queue;
svector<constraint_info> m_constraint_index; // index of constraints
constraint_id_vector m_unsat_core;
@ -62,7 +62,7 @@ namespace polysat {
// attributes associated with variables
vector<pdd> m_vars; // for each variable a pdd
vector<rational> m_values; // current value of assigned variable
svector<constraint_id> m_justification; // justification for assignment
constraint_id_vector m_justification; // justification for assignment
activity m_activity; // activity of variables
var_queue<activity> m_var_queue; // priority queue of variables to assign
vector<unsigned_vector> m_watch; // watch lists for variables for constraints on m_prop_queue where they occur
@ -136,7 +136,7 @@ namespace polysat {
* In other words, the clause represents the formula /\ d_i -> \/ sc_j
* Where d_i are logical interpretations of dependencies and sc_j are signed constraints.
*/
void add_clause(char const* name, core_vector const& cs, bool is_redundant);
bool add_clause(char const* name, core_vector const& cs, bool is_redundant);
pvar add_var(unsigned sz);
pdd var(pvar p) { return m_vars[p]; }
@ -152,6 +152,8 @@ namespace polysat {
*/
signed_constraint get_constraint(constraint_id id);
constraint_id_vector const& unsat_core() const { return m_unsat_core; }
constraint_id_vector const& assigned_constraints() const { return m_prop_queue; }
dependency get_dependency(constraint_id idx) const;
lbool eval(constraint_id id);
bool propagate(signed_constraint const& sc, constraint_id_vector const& ids) { return s.propagate(sc, get_dependencies(ids)); }
bool propagate(signed_constraint const& sc, std::initializer_list<constraint_id> const& ids) { return s.propagate(sc, get_dependencies(ids)); }

View file

@ -31,6 +31,18 @@ namespace polysat {
}
dependency inequality::dep() const {
return c.get_dependency(id());
}
bool inequality::is_l_v(pdd const& v, signed_constraint const& sc) {
return sc.is_ule() && v == (sc.sign() ? sc.to_ule().rhs() : sc.to_ule().lhs());
}
bool inequality::is_g_v(pdd const& v, signed_constraint const& sc) {
return sc.is_ule() && v == (sc.sign() ? sc.to_ule().lhs() : sc.to_ule().rhs());
}
#if 0

View file

@ -34,6 +34,9 @@ namespace polysat {
dst.reset(src.manager());
dst = src;
}
// p := coeff*x*y where coeff_x = coeff*x, x a variable
bool is_coeffxY(pdd const& coeff_x, pdd const& p, pdd& y) const { throw default_exception("nyi"); }
public:
static inequality from_ule(core& c, constraint_id id);
@ -41,15 +44,18 @@ namespace polysat {
pdd const& rhs() const { return m_rhs; }
bool is_strict() const { return m_src.is_negative(); }
constraint_id id() const { return m_id; }
dependency dep() const;
signed_constraint as_signed_constraint() const { return m_src; }
operator signed_constraint() const { return m_src; }
// c := lhs ~ v
// where ~ is < or <=
bool is_l_v(pvar v) const { return rhs() == c.var(v); }
bool is_l_v(pvar v) const { return rhs() == c.var(v); }
static bool is_l_v(pdd const& v, signed_constraint const& sc);
// c := v ~ rhs
bool is_g_v(pvar v) const { return lhs() == c.var(v); }
static bool is_g_v(pdd const& v, signed_constraint const& sc);
// c := x ~ Y
bool is_x_l_Y(pvar x, pdd& y) const { return is_g_v(x) && (set(y, rhs()), true); }
@ -87,11 +93,11 @@ namespace polysat {
bool is_AxB_diseq_0(pvar x, pdd& a, pdd& b, pdd& y) const;
// c := Y*X ~ z*X
bool is_YX_l_zX(pvar z, pdd& x, pdd& y) const;
bool is_YX_l_zX(pvar z, pdd& x, pdd& y) const { return is_xY(z, rhs(), x) && is_coeffxY(x, lhs(), y); }
bool verify_YX_l_zX(pvar z, pdd const& x, pdd const& y) const;
// c := xY <= xZ
bool is_xY_l_xZ(pvar x, pdd& y, pdd& z) const;
bool is_xY_l_xZ(pvar x, pdd& y, pdd& z) const { return is_xY(x, lhs(), y) && is_xY(x, rhs(), z); }
/**
* Match xy = x * Y

View file

@ -56,6 +56,7 @@ namespace polysat {
if (c.size(v) != i.lhs().power_of_2())
return;
propagate_infer_equality(v, i);
try_ugt_x(v, i);
return;
}
@ -65,27 +66,122 @@ namespace polysat {
m_propagated = true;
}
/**
* p <= q, q <= p => p - q = 0
*/
void saturation::propagate_infer_equality(pvar x, inequality const& a_l_b) {
set_rule("[x] p <= q, q <= p => p - q = 0");
if (a_l_b.is_strict())
return;
if (a_l_b.lhs().degree(x) == 0 && a_l_b.rhs().degree(x) == 0)
return;
void saturation::add_clause(char const* name, core_vector const& cs, bool is_redundant) {
if (c.add_clause(name, cs, is_redundant))
m_propagated = true;
}
bool saturation::match_core(std::function<bool(signed_constraint const& sc)> const& p, constraint_id& id_out) {
for (auto id : c.unsat_core()) {
auto sc = c.get_constraint(id);
if (!sc.is_ule())
continue;
auto i = inequality::from_ule(c, id);
if (i.lhs() == a_l_b.rhs() && i.rhs() == a_l_b.lhs() && !i.is_strict()) {
c.propagate(c.eq(i.lhs(), i.rhs()), { id, a_l_b.id() });
return;
if (p(sc)) {
id_out = id;
return true;
}
}
}
return false;
}
bool saturation::match_constraints(std::function<bool(signed_constraint const& sc)> const& p, constraint_id& id_out) {
for (auto id : c.assigned_constraints()) {
auto sc = c.get_constraint(id);
if (p(sc)) {
id_out = id;
return true;
}
}
return false;
}
signed_constraint saturation::ineq(bool is_strict, pdd const& x, pdd const& y) {
return is_strict ? c.cs().ult(x, y) : c.cs().ule(x, y);
}
/**
* p <= q, q <= p => p = q
*/
void saturation::propagate_infer_equality(pvar x, inequality const& i) {
set_rule("[x] p <= q, q <= p => p - q = 0");
if (i.is_strict())
return;
if (i.lhs().degree(x) == 0 && i.rhs().degree(x) == 0)
return;
constraint_id id;
if (!match_core([&](auto const& sc) { return sc.is_ule() && !sc.sign() && sc.to_ule().lhs() == i.rhs() && sc.to_ule().rhs() == i.lhs(); }, id))
return;
c.propagate(c.eq(i.lhs(), i.rhs()), { id, i.id() });
}
/**
* Implement the inferences
* [x] yx < zx ==> Ω*(x,y) \/ y < z
* [x] yx <= zx ==> Ω*(x,y) \/ y <= z \/ x = 0
*/
void saturation::try_ugt_x(pvar v, inequality const& i) {
pdd x = c.var(v);
pdd y = x;
pdd z = x;
auto& C = c.cs();
if (!i.is_xY_l_xZ(v, y, z))
return;
auto ovfl = C.umul_ovfl(x, y);
if (i.is_strict())
add_clause("[x] yx < zx ==> Ω*(x,y) \\/ y < z", { i.dep(), ovfl, C.ult(y, z)}, false);
else
add_clause("[x] yx <= zx ==> Ω*(x,y) \\/ y <= z \\/ x = 0", { i.dep(), ovfl, C.eq(x), C.ule(y, z) }, false);
}
/**
* [y] z' <= y /\ yx <= zx ==> Ω*(x,y) \/ z'x <= zx
* [y] z' <= y /\ yx < zx ==> Ω*(x,y) \/ z'x < zx
* [y] z' < y /\ yx <= zx ==> Ω*(x,y) \/ z'x <= zx
* [y] z' < y /\ yx < zx ==> Ω*(x,y) \/ z'x < zx
* [y] z' < y /\ yx < zx ==> Ω*(x,y) \/ z'x + 1 < zx (TODO?)
* [y] z' < y /\ yx < zx ==> Ω*(x,y) \/ (z' + 1)x < zx (TODO?)
*/
void saturation::try_ugt_y(pvar v, inequality const& i) {
auto y = c.var(v);
pdd x = y;
pdd z = y;
auto& C = c.cs();
constraint_id id;
if (!i.is_Xy_l_XZ(v, x, z))
return;
if (!match_constraints([&](auto const& sc) { return inequality::is_l_v(y, sc); }, id))
return;
auto j = inequality::from_ule(c, id);
pdd const& z_prime = i.lhs();
bool is_strict = i.is_strict() || j.is_strict();
add_clause("[y] z' <= y & yx <= zx", { i, j, C.umul_ovfl(x, y), ineq(is_strict, z_prime * x, z * x) }, false);
}
/**
* [z] z <= y' /\ yx <= zx ==> Ω*(x,y') \/ yx <= y'x
* [z] z <= y' /\ yx < zx ==> Ω*(x,y') \/ yx < y'x
* [z] z < y' /\ yx <= zx ==> Ω*(x,y') \/ yx <= y'x
* [z] z < y' /\ yx < zx ==> Ω*(x,y') \/ yx < y'x
* [z] z < y' /\ yx < zx ==> Ω*(x,y') \/ yx+1 < y'x (TODO?)
* [z] z < y' /\ yx < zx ==> Ω*(x,y') \/ (y+1)x < y'x (TODO?)
*/
void saturation::try_ugt_z(pvar v, inequality const& i) {
auto z = c.var(v);
pdd y = z;
pdd x = z;
constraint_id id;
if (!i.is_YX_l_zX(v, x, y))
return;
if (!match_constraints([&](auto const& sc) { return inequality::is_g_v(z, sc); }, id))
return;
auto j = inequality::from_ule(c, id);
auto y_prime = j.rhs();
bool is_strict = i.is_strict() || j.is_strict();
add_clause("[z] z <= y' && yx <= zx", { i, j, c.umul_ovfl(x, y_prime), ineq(is_strict, y * x, y_prime * x) }, false);
}
/**
* Determine whether values of x * y is non-overflowing.
*/
@ -95,6 +191,10 @@ namespace polysat {
return c.try_eval(x, x_val) && c.try_eval(y, y_val) && x_val * y_val < bound;
}
bool saturation::is_non_overflow(pdd const& x, pdd const& y, signed_constraint& sc) {
return is_non_overflow(x, y) && (sc = c.umul_ovfl(x, y), true);
}
#if 0
@ -322,66 +422,6 @@ namespace polysat {
return false;
}
signed_constraint saturation::ineq(bool is_strict, pdd const& lhs, pdd const& rhs) {
if (is_strict)
return s.ult(lhs, rhs);
else
return s.ule(lhs, rhs);
}
bool saturation::propagate(pvar v, conflict& core, inequality const& crit, signed_constraint c) {
return propagate(v, core, crit.as_signed_constraint(), c);
}
bool saturation::propagate(pvar v, conflict& core, signed_constraint crit, signed_constraint c) {
m_lemma.insert(~crit);
return propagate(v, core, c);
}
bool saturation::propagate(pvar v, conflict& core, signed_constraint c) {
if (is_forced_true(c))
return false;
SASSERT(all_of(m_lemma, [this](sat::literal lit) { return is_forced_false(s.lit2cnstr(lit)); }));
m_lemma.insert(c);
m_lemma.set_name(m_rule);
core.add_lemma(m_lemma.build());
log_lemma(v, core);
return true;
}
bool saturation::add_conflict(pvar v, conflict& core, inequality const& crit1, signed_constraint c) {
return add_conflict(v, core, crit1, crit1, c);
}
bool saturation::add_conflict(pvar v, conflict& core, inequality const& _crit1, inequality const& _crit2, signed_constraint const c) {
auto crit1 = _crit1.as_signed_constraint();
auto crit2 = _crit2.as_signed_constraint();
m_lemma.insert(~crit1);
if (crit1 != crit2)
m_lemma.insert(~crit2);
LOG("critical " << m_rule << " " << crit1);
LOG("consequent " << c << " value: " << c.bvalue(s) << " is-false: " << c.is_currently_false(s));
SASSERT(all_of(m_lemma, [this](sat::literal lit) { return s.m_bvars.value(lit) == l_false; }));
// Ensure lemma is a conflict lemma
if (!is_forced_false(c))
return false;
// Constraint c is already on the search stack, so the lemma will not derive anything new.
if (c.bvalue(s) == l_true)
return false;
m_lemma.insert_eval(c);
m_lemma.set_name(m_rule);
core.add_lemma(m_lemma.build());
log_lemma(v, core);
return true;
}
bool saturation::is_non_overflow(pdd const& x, pdd const& y, signed_constraint& c) {
if (is_non_overflow(x, y)) {
@ -461,115 +501,12 @@ namespace polysat {
return c.bvalue(s) == l_true || c.is_currently_true(s);
}
/**
* Implement the inferences
* [x] yx < zx ==> Ω*(x,y) \/ y < z
* [x] yx <= zx ==> Ω*(x,y) \/ y <= z \/ x = 0
*/
bool saturation::try_ugt_x(pvar v, conflict& core, inequality const& xy_l_xz) {
set_rule("[x] yx <= zx");
pdd x = s.var(v);
pdd y = x;
pdd z = x;
signed_constraint non_ovfl;
if (!is_xY_l_xZ(v, xy_l_xz, y, z))
return false;
if (!xy_l_xz.is_strict() && s.is_assigned(v) && s.get_value(v).is_zero())
return false;
if (!is_non_overflow(x, y, non_ovfl))
return false;
m_lemma.reset();
m_lemma.insert_eval(~non_ovfl);
if (!xy_l_xz.is_strict())
m_lemma.insert_eval(s.eq(x));
return add_conflict(v, core, xy_l_xz, ineq(xy_l_xz.is_strict(), y, z));
}
/**
* [y] z' <= y /\ yx <= zx ==> Ω*(x,y) \/ z'x <= zx
* [y] z' <= y /\ yx < zx ==> Ω*(x,y) \/ z'x < zx
* [y] z' < y /\ yx <= zx ==> Ω*(x,y) \/ z'x <= zx
* [y] z' < y /\ yx < zx ==> Ω*(x,y) \/ z'x < zx
* [y] z' < y /\ yx < zx ==> Ω*(x,y) \/ z'x + 1 < zx (TODO?)
* [y] z' < y /\ yx < zx ==> Ω*(x,y) \/ (z' + 1)x < zx (TODO?)
*/
bool saturation::try_ugt_y(pvar v, conflict& core, inequality const& yx_l_zx) {
set_rule("[y] z' <= y & yx <= zx");
auto& m = s.var2pdd(v);
pdd x = m.zero();
pdd z = m.zero();
if (!is_Xy_l_XZ(v, yx_l_zx, x, z))
return false;
for (auto si : s.m_search) {
if (!si.is_boolean())
continue;
if (si.is_resolved())
continue;
auto d = s.lit2cnstr(si.lit());
if (!d->is_ule())
continue;
auto l_y = inequality::from_ule(d);
if (is_l_v(v, l_y) && try_ugt_y(v, core, l_y, yx_l_zx, x, z))
return true;
}
return false;
}
bool saturation::try_ugt_y(pvar v, conflict& core, inequality const& l_y, inequality const& yx_l_zx, pdd const& x, pdd const& z) {
SASSERT(is_l_v(v, l_y));
SASSERT(verify_Xy_l_XZ(v, yx_l_zx, x, z));
pdd const y = s.var(v);
signed_constraint non_ovfl;
if (!is_non_overflow(x, y, non_ovfl))
return false;
pdd const& z_prime = l_y.lhs();
m_lemma.reset();
m_lemma.insert_eval(~non_ovfl);
return add_conflict(v, core, l_y, yx_l_zx, ineq(yx_l_zx.is_strict(), z_prime * x, z * x));
}
/**
* [z] z <= y' /\ yx <= zx ==> Ω*(x,y') \/ yx <= y'x
* [z] z <= y' /\ yx < zx ==> Ω*(x,y') \/ yx < y'x
* [z] z < y' /\ yx <= zx ==> Ω*(x,y') \/ yx <= y'x
* [z] z < y' /\ yx < zx ==> Ω*(x,y') \/ yx < y'x
* [z] z < y' /\ yx < zx ==> Ω*(x,y') \/ yx+1 < y'x (TODO?)
* [z] z < y' /\ yx < zx ==> Ω*(x,y') \/ (y+1)x < y'x (TODO?)
*/
bool saturation::try_ugt_z(pvar z, conflict& core, inequality const& yx_l_zx) {
set_rule("[z] z <= y' && yx <= zx");
auto& m = s.var2pdd(z);
pdd y = m.zero();
pdd x = m.zero();
if (!is_YX_l_zX(z, yx_l_zx, x, y))
return false;
for (auto si : s.m_search) {
if (!si.is_boolean())
continue;
if (si.is_resolved())
continue;
auto d = s.lit2cnstr(si.lit());
if (!d->is_ule())
continue;
auto z_l_y = inequality::from_ule(d);
if (is_g_v(z, z_l_y) && try_ugt_z(z, core, z_l_y, yx_l_zx, x, y))
return true;
}
return false;
}
bool saturation::try_ugt_z(pvar z, conflict& core, inequality const& z_l_y, inequality const& yx_l_zx, pdd const& x, pdd const& y) {
SASSERT(is_g_v(z, z_l_y));
SASSERT(verify_YX_l_zX(z, yx_l_zx, x, y));
pdd const& y_prime = z_l_y.rhs();
signed_constraint non_ovfl;
if (!is_non_overflow(x, y_prime, non_ovfl))
return false;
m_lemma.reset();
m_lemma.insert_eval(~non_ovfl);
return add_conflict(z, core, yx_l_zx, z_l_y, ineq(yx_l_zx.is_strict(), y * x, y_prime * x));
}
/**
* [x] y <= ax /\ x <= z (non-overflow case)

View file

@ -30,15 +30,23 @@ namespace polysat {
void set_rule(char const* r) { m_rule = r; }
void propagate(signed_constraint const& sc, std::initializer_list<constraint_id> const& premises);
void add_clause(char const* name, core_vector const& cs, bool is_redundant);
bool match_core(std::function<bool(signed_constraint const& sc)> const& p, constraint_id& id);
bool match_constraints(std::function<bool(signed_constraint const& sc)> const& p, constraint_id& id);
// a * b does not overflow
bool is_non_overflow(pdd const& a, pdd const& b);
bool is_non_overflow(pdd const& x, pdd const& y, signed_constraint& c);
// p := coeff*x*y where coeff_x = coeff*x, x a variable
bool is_coeffxY(pdd const& coeff_x, pdd const& p, pdd& y);
void propagate_infer_equality(pvar x, inequality const& a_l_b);
void try_ugt_x(pvar v, inequality const& i);
void try_ugt_y(pvar v, inequality const& i);
void try_ugt_z(pvar z, inequality const& i);
signed_constraint ineq(bool is_strict, pdd const& x, pdd const& y);
#if 0
parity_tracker m_parity_tracker;
@ -47,7 +55,7 @@ namespace polysat {
bool is_non_overflow(pdd const& x, pdd const& y, signed_constraint& c);
signed_constraint ineq(bool strict, pdd const& lhs, pdd const& rhs);
void log_lemma(pvar v, conflict& core);

View file

@ -96,7 +96,7 @@ namespace polysat {
virtual void add_eq_literal(pvar v, rational const& val) = 0;
virtual void set_conflict(dependency_vector const& core) = 0;
virtual void set_lemma(core_vector const& aux_core, dependency_vector const& core) = 0;
virtual void add_polysat_clause(char const* name, core_vector cs, bool redundant) = 0;
virtual bool add_polysat_clause(char const* name, core_vector cs, bool redundant) = 0;
virtual bool propagate(signed_constraint sc, dependency_vector const& deps) = 0;
virtual void propagate(dependency const& d, bool sign, dependency_vector const& deps) = 0;
virtual trail_stack& trail() = 0;

View file

@ -278,7 +278,7 @@ namespace polysat {
return ctx.get_trail_stack();
}
void solver::add_polysat_clause(char const* name, core_vector cs, bool is_redundant) {
bool solver::add_polysat_clause(char const* name, core_vector cs, bool is_redundant) {
sat::literal_vector lits;
for (auto e : cs) {
if (std::holds_alternative<dependency>(e)) {
@ -297,7 +297,11 @@ namespace polysat {
else if (std::holds_alternative<signed_constraint>(e))
lits.push_back(ctx.mk_literal(constraint2expr(*std::get_if<signed_constraint>(&e))));
}
for (auto lit : lits)
if (s().value(lit) == l_true)
return false;
s().add_clause(lits.size(), lits.data(), sat::status::th(is_redundant, get_id(), nullptr));
return true;
}
void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) {

View file

@ -168,7 +168,7 @@ namespace polysat {
bool inconsistent() const override;
void get_bitvector_suffixes(pvar v, pvar_vector& out) override;
void get_fixed_bits(pvar v, svector<justified_fixed_bits>& fixed_bits) override;
void add_polysat_clause(char const* name, core_vector cs, bool redundant) override;
bool add_polysat_clause(char const* name, core_vector cs, bool redundant) override;
std::pair<sat::literal_vector, euf::enode_pair_vector> explain_deps(dependency_vector const& deps);