diff --git a/src/nlsat/nlsat_evaluator.cpp b/src/nlsat/nlsat_evaluator.cpp index 7f353c7db..ddce9fc09 100644 --- a/src/nlsat/nlsat_evaluator.cpp +++ b/src/nlsat/nlsat_evaluator.cpp @@ -488,7 +488,7 @@ namespace nlsat { return sign; } - interval_set_ref infeasible_intervals(ineq_atom * a, bool neg) { + interval_set_ref infeasible_intervals(ineq_atom * a, bool neg, clause const* cls) { sign_table & table = m_sign_table_tmp; table.reset(); unsigned num_ps = a->size(); @@ -543,7 +543,7 @@ namespace nlsat { curr_root_id = table.get_root_id(c-1); } set = m_ism.mk(prev_open, prev_inf, table.get_root(prev_root_id), - curr_open, false, table.get_root(curr_root_id), jst); + curr_open, false, table.get_root(curr_root_id), jst, cls); result = m_ism.mk_union(result, set); prev_sat = true; } @@ -554,7 +554,7 @@ namespace nlsat { if (c == 0) { if (num_cells == 1) { // (-oo, oo) - result = m_ism.mk(true, true, dummy, true, true, dummy, jst); + result = m_ism.mk(true, true, dummy, true, true, dummy, jst, cls); } else { // save -oo as beginning of infeasible interval @@ -583,7 +583,7 @@ namespace nlsat { if (c == num_cells - 1) { // last cell add interval with (prev, oo) set = m_ism.mk(prev_open, prev_inf, table.get_root(prev_root_id), - true, true, dummy, jst); + true, true, dummy, jst, cls); result = m_ism.mk_union(result, set); } } @@ -592,7 +592,7 @@ namespace nlsat { return result; } - interval_set_ref infeasible_intervals(root_atom * a, bool neg) { + interval_set_ref infeasible_intervals(root_atom * a, bool neg, clause const* cls) { atom::kind k = a->get_kind(); unsigned i = a->i(); SASSERT(i > 0); @@ -613,7 +613,7 @@ namespace nlsat { result = m_ism.mk_empty(); } else { - result = m_ism.mk(true, true, dummy, true, true, dummy, jst); // (-oo, oo) + result = m_ism.mk(true, true, dummy, true, true, dummy, jst, cls); // (-oo, oo) } } else { @@ -621,38 +621,38 @@ namespace nlsat { switch (k) { case atom::ROOT_EQ: if (neg) { - result = m_ism.mk(false, false, r_i, false, false, r_i, jst); // [r_i, r_i] + result = m_ism.mk(false, false, r_i, false, false, r_i, jst, cls); // [r_i, r_i] } else { interval_set_ref s1(m_ism), s2(m_ism); - s1 = m_ism.mk(true, true, dummy, true, false, r_i, jst); // (-oo, r_i) - s2 = m_ism.mk(true, false, r_i, true, true, dummy, jst); // (r_i, oo) + s1 = m_ism.mk(true, true, dummy, true, false, r_i, jst, cls); // (-oo, r_i) + s2 = m_ism.mk(true, false, r_i, true, true, dummy, jst, cls); // (r_i, oo) result = m_ism.mk_union(s1, s2); } break; case atom::ROOT_LT: if (neg) - result = m_ism.mk(true, true, dummy, true, false, r_i, jst); // (-oo, r_i) + result = m_ism.mk(true, true, dummy, true, false, r_i, jst, cls); // (-oo, r_i) else - result = m_ism.mk(false, false, r_i, true, true, dummy, jst); // [r_i, oo) + result = m_ism.mk(false, false, r_i, true, true, dummy, jst, cls); // [r_i, oo) break; case atom::ROOT_GT: if (neg) - result = m_ism.mk(true, false, r_i, true, true, dummy, jst); // (r_i, oo) + result = m_ism.mk(true, false, r_i, true, true, dummy, jst, cls); // (r_i, oo) else - result = m_ism.mk(true, true, dummy, false, false, r_i, jst); // (-oo, r_i] + result = m_ism.mk(true, true, dummy, false, false, r_i, jst, cls); // (-oo, r_i] break; case atom::ROOT_LE: if (neg) - result = m_ism.mk(true, true, dummy, false, false, r_i, jst); // (-oo, r_i] + result = m_ism.mk(true, true, dummy, false, false, r_i, jst, cls); // (-oo, r_i] else - result = m_ism.mk(true, false, r_i, true, true, dummy, jst); // (r_i, oo) + result = m_ism.mk(true, false, r_i, true, true, dummy, jst, cls); // (r_i, oo) break; case atom::ROOT_GE: if (neg) - result = m_ism.mk(false, false, r_i, true, true, dummy, jst); // [r_i, oo) + result = m_ism.mk(false, false, r_i, true, true, dummy, jst, cls); // [r_i, oo) else - result = m_ism.mk(true, true, dummy, true, false, r_i, jst); // (-oo, r_i) + result = m_ism.mk(true, true, dummy, true, false, r_i, jst, cls); // (-oo, r_i) break; default: UNREACHABLE(); @@ -663,8 +663,8 @@ namespace nlsat { return result; } - interval_set_ref infeasible_intervals(atom * a, bool neg) { - return a->is_ineq_atom() ? infeasible_intervals(to_ineq_atom(a), neg) : infeasible_intervals(to_root_atom(a), neg); + interval_set_ref infeasible_intervals(atom * a, bool neg, clause const* cls) { + return a->is_ineq_atom() ? infeasible_intervals(to_ineq_atom(a), neg, cls) : infeasible_intervals(to_root_atom(a), neg, cls); } }; @@ -684,8 +684,8 @@ namespace nlsat { return m_imp->eval(a, neg); } - interval_set_ref evaluator::infeasible_intervals(atom * a, bool neg) { - return m_imp->infeasible_intervals(a, neg); + interval_set_ref evaluator::infeasible_intervals(atom * a, bool neg, clause const* cls) { + return m_imp->infeasible_intervals(a, neg, cls); } void evaluator::push() { diff --git a/src/nlsat/nlsat_evaluator.h b/src/nlsat/nlsat_evaluator.h index e43eec80a..7e6be0697 100644 --- a/src/nlsat/nlsat_evaluator.h +++ b/src/nlsat/nlsat_evaluator.h @@ -52,7 +52,7 @@ namespace nlsat { Let x be a->max_var(). Then, the resultant set specifies which values of x falsify the given literal. */ - interval_set_ref infeasible_intervals(atom * a, bool neg); + interval_set_ref infeasible_intervals(atom * a, bool neg, clause const* cls); void push(); void pop(unsigned num_scopes); diff --git a/src/nlsat/nlsat_explain.cpp b/src/nlsat/nlsat_explain.cpp index a93935fb6..7df849b81 100644 --- a/src/nlsat/nlsat_explain.cpp +++ b/src/nlsat/nlsat_explain.cpp @@ -1382,14 +1382,14 @@ namespace nlsat { literal l = core[i]; atom * a = m_atoms[l.var()]; SASSERT(a != 0); - interval_set_ref inf = m_evaluator.infeasible_intervals(a, l.sign()); + interval_set_ref inf = m_evaluator.infeasible_intervals(a, l.sign(), nullptr); r = ism.mk_union(inf, r); if (ism.is_full(r)) { // Done return false; } } - TRACE("nlsat_mininize", tout << "interval set after adding partial core:\n" << r << "\n";); + TRACE("nlsat_minimize", tout << "interval set after adding partial core:\n" << r << "\n";); if (todo.size() == 1) { // Done core.push_back(todo[0]); @@ -1401,7 +1401,7 @@ namespace nlsat { literal l = todo[i]; atom * a = m_atoms[l.var()]; SASSERT(a != 0); - interval_set_ref inf = m_evaluator.infeasible_intervals(a, l.sign()); + interval_set_ref inf = m_evaluator.infeasible_intervals(a, l.sign(), nullptr); r = ism.mk_union(inf, r); if (ism.is_full(r)) { // literal l must be in the core @@ -1425,15 +1425,15 @@ namespace nlsat { todo.reset(); core.reset(); todo.append(num, ls); while (true) { - TRACE("nlsat_mininize", tout << "core minimization:\n"; display(tout, todo); tout << "\nCORE:\n"; display(tout, core);); + TRACE("nlsat_minimize", tout << "core minimization:\n"; display(tout, todo); tout << "\nCORE:\n"; display(tout, core);); if (!minimize_core(todo, core)) break; std::reverse(todo.begin(), todo.end()); - TRACE("nlsat_mininize", tout << "core minimization:\n"; display(tout, todo); tout << "\nCORE:\n"; display(tout, core);); + TRACE("nlsat_minimize", tout << "core minimization:\n"; display(tout, todo); tout << "\nCORE:\n"; display(tout, core);); if (!minimize_core(todo, core)) break; } - TRACE("nlsat_mininize", tout << "core:\n"; display(tout, core);); + TRACE("nlsat_minimize", tout << "core:\n"; display(tout, core);); r.append(core.size(), core.c_ptr()); } diff --git a/src/nlsat/nlsat_interval_set.cpp b/src/nlsat/nlsat_interval_set.cpp index b089204f8..c678f2b23 100644 --- a/src/nlsat/nlsat_interval_set.cpp +++ b/src/nlsat/nlsat_interval_set.cpp @@ -28,6 +28,7 @@ namespace nlsat { unsigned m_lower_inf:1; unsigned m_upper_inf:1; literal m_justification; + clause const* m_clause; anum m_lower; anum m_upper; }; @@ -147,7 +148,7 @@ namespace nlsat { interval_set * interval_set_manager::mk(bool lower_open, bool lower_inf, anum const & lower, bool upper_open, bool upper_inf, anum const & upper, - literal justification) { + literal justification, clause const* cls) { void * mem = m_allocator.allocate(interval_set::get_obj_size(1)); interval_set * new_set = new (mem) interval_set(); new_set->m_num_intervals = 1; @@ -159,6 +160,7 @@ namespace nlsat { i->m_upper_open = upper_open; i->m_upper_inf = upper_inf; i->m_justification = justification; + i->m_clause = cls; if (!lower_inf) m_am.set(i->m_lower, lower); if (!upper_inf) @@ -644,8 +646,9 @@ namespace nlsat { return true; } - void interval_set_manager::get_justifications(interval_set const * s, literal_vector & js) { + void interval_set_manager::get_justifications(interval_set const * s, literal_vector & js, ptr_vector& clauses) { js.reset(); + clauses.reset(); unsigned num = num_intervals(s); for (unsigned i = 0; i < num; i++) { literal l = s->m_intervals[i].m_justification; @@ -654,6 +657,9 @@ namespace nlsat { continue; m_already_visited.setx(lidx, true, false); js.push_back(l); + if (s->m_intervals[i].m_clause) { + clauses.push_back(const_cast(s->m_intervals[i].m_clause)); + } } for (unsigned i = 0; i < num; i++) { literal l = s->m_intervals[i].m_justification; diff --git a/src/nlsat/nlsat_interval_set.h b/src/nlsat/nlsat_interval_set.h index 9cc57faba..1091a1ffd 100644 --- a/src/nlsat/nlsat_interval_set.h +++ b/src/nlsat/nlsat_interval_set.h @@ -47,7 +47,7 @@ namespace nlsat { */ interval_set * mk(bool lower_open, bool lower_inf, anum const & lower, bool upper_open, bool upper_inf, anum const & upper, - literal justification); + literal justification, clause const* cls); /** \brief Return the union of two sets. @@ -91,7 +91,7 @@ namespace nlsat { /** \brief Return a set of literals that justify s. */ - void get_justifications(interval_set const * s, literal_vector & js); + void get_justifications(interval_set const * s, literal_vector & js, ptr_vector& clauses ); std::ostream& display(std::ostream & out, interval_set const * s) const; diff --git a/src/nlsat/nlsat_justification.h b/src/nlsat/nlsat_justification.h index 64e0d3d70..db0abbb0e 100644 --- a/src/nlsat/nlsat_justification.h +++ b/src/nlsat/nlsat_justification.h @@ -38,16 +38,24 @@ namespace nlsat { class lazy_justification { unsigned m_num_literals; - literal m_literals[0]; + unsigned m_num_clauses; + char m_data[0]; + nlsat::clause* const* clauses() const { return (nlsat::clause *const*)(m_data); } public: - static unsigned get_obj_size(unsigned num) { return sizeof(lazy_justification) + sizeof(literal)*num; } - lazy_justification(unsigned num, literal const * lits): - m_num_literals(num) { - memcpy(m_literals, lits, sizeof(literal)*num); + static unsigned get_obj_size(unsigned nl, unsigned nc) { return sizeof(lazy_justification) + sizeof(literal)*nl + sizeof(nlsat::clause*)*nc; } + lazy_justification(unsigned nl, literal const * lits, unsigned nc, nlsat::clause * const* clss): + m_num_literals(nl), + m_num_clauses(nc) { + memcpy(m_data + 0, clss, sizeof(nlsat::clause const*)*nc); + memcpy(m_data + sizeof(nlsat::clause*)*nc, lits, sizeof(literal)*nl); } - unsigned size() const { return m_num_literals; } - literal operator[](unsigned i) const { SASSERT(i < size()); return m_literals[i]; } - literal const * lits() const { return m_literals; } + unsigned num_lits() const { return m_num_literals; } + literal lit(unsigned i) const { SASSERT(i < num_lits()); return lits()[i]; } + literal const * lits() const { return (literal const*)(m_data + m_num_clauses*sizeof(nlsat::clause*)); } + + unsigned num_clauses() const { return m_num_clauses; } + nlsat::clause const& clause(unsigned i) const { SASSERT(i < num_clauses()); return *(clauses()[i]); } + }; class justification { @@ -83,15 +91,15 @@ namespace nlsat { const justification decided_justification(true); inline justification mk_clause_jst(clause const * c) { return justification(const_cast(c)); } - inline justification mk_lazy_jst(small_object_allocator & a, unsigned num, literal const * lits) { - void * mem = a.allocate(lazy_justification::get_obj_size(num)); - return justification(new (mem) lazy_justification(num, lits)); + inline justification mk_lazy_jst(small_object_allocator & a, unsigned nl, literal const * lits, unsigned nc, clause *const* clauses) { + void * mem = a.allocate(lazy_justification::get_obj_size(nl, nc)); + return justification(new (mem) lazy_justification(nl, lits, nc, clauses)); } inline void del_jst(small_object_allocator & a, justification jst) { if (jst.is_lazy()) { lazy_justification * ptr = jst.get_lazy(); - unsigned obj_sz = lazy_justification::get_obj_size(ptr->size()); + unsigned obj_sz = lazy_justification::get_obj_size(ptr->num_lits(), ptr->num_clauses()); a.deallocate(obj_sz, ptr); } } diff --git a/src/nlsat/nlsat_solver.cpp b/src/nlsat/nlsat_solver.cpp index ba14be6aa..1b9c0eb39 100644 --- a/src/nlsat/nlsat_solver.cpp +++ b/src/nlsat/nlsat_solver.cpp @@ -450,7 +450,7 @@ namespace nlsat { m_is_int. push_back(is_int); m_watches. push_back(clause_vector()); m_infeasible.push_back(0); - m_var2eq. push_back(0); + m_var2eq. push_back(nullptr); m_perm. push_back(x); m_inv_perm. push_back(x); SASSERT(m_is_int.size() == m_watches.size()); @@ -948,7 +948,7 @@ namespace nlsat { m_levels[b] = m_scope_lvl; m_justifications[b] = j; save_assign_trail(b); - updt_eq(b); + updt_eq(b, j); TRACE("nlsat_assign", tout << "b" << b << " -> " << m_bvalues[b] << "\n";); } @@ -1047,11 +1047,12 @@ namespace nlsat { \brief assign l to true, because l + (justification of) s is infeasible in RCF in the current interpretation. */ literal_vector core; + ptr_vector clauses; void R_propagate(literal l, interval_set const * s, bool include_l = true) { - m_ism.get_justifications(s, core); + m_ism.get_justifications(s, core, clauses); if (include_l) core.push_back(~l); - assign(l, mk_lazy_jst(m_allocator, core.size(), core.c_ptr())); + assign(l, mk_lazy_jst(m_allocator, core.size(), core.c_ptr(), clauses.size(), clauses.c_ptr())); SASSERT(value(l) == l_true); } @@ -1074,7 +1075,7 @@ namespace nlsat { /** \brief Update m_var2eq mapping. */ - void updt_eq(bool_var b) { + void updt_eq(bool_var b, justification j) { if (!m_simplify_cores) return; if (m_bvalues[b] != l_true) @@ -1082,6 +1083,16 @@ namespace nlsat { atom * a = m_atoms[b]; if (a == nullptr || a->get_kind() != atom::EQ || to_ineq_atom(a)->size() > 1 || to_ineq_atom(a)->is_even(0)) return; + switch (j.get_kind()) { + case justification::CLAUSE: + if (j.get_clause()->assumptions() != nullptr) return; + break; + case justification::LAZY: + if (j.get_lazy()->num_clauses() > 0) return; + break; + default: + break; + } var x = m_xk; SASSERT(a->max_var() == x); SASSERT(x != null_var); @@ -1121,7 +1132,7 @@ namespace nlsat { atom * a = m_atoms[b]; SASSERT(a != nullptr); interval_set_ref curr_set(m_ism); - curr_set = m_evaluator.infeasible_intervals(a, l.sign()); + curr_set = m_evaluator.infeasible_intervals(a, l.sign(), &cls); TRACE("nlsat_inf_set", tout << "infeasible set for literal: "; display(tout, l); tout << "\n"; m_ism.display(tout, curr_set); tout << "\n"; display(tout, cls) << "\n";); if (m_ism.is_empty(curr_set)) { @@ -1553,7 +1564,7 @@ namespace nlsat { void resolve_lazy_justification(bool_var b, lazy_justification const & jst) { TRACE("nlsat_resolve", tout << "resolving lazy_justification for b: " << b << "\n";); - unsigned sz = jst.size(); + unsigned sz = jst.num_lits(); // Dump lemma as Mathematica formula that must be true, // if the current interpretation (really) makes the core in jst infeasible. @@ -1561,15 +1572,15 @@ namespace nlsat { tout << "assignment lemma\n"; literal_vector core; for (unsigned i = 0; i < sz; i++) { - core.push_back(~jst[i]); + core.push_back(~jst.lit(i)); } display_mathematica_lemma(tout, core.size(), core.c_ptr(), true);); m_lazy_clause.reset(); - m_explain(jst.size(), jst.lits(), m_lazy_clause); + m_explain(jst.num_lits(), jst.lits(), m_lazy_clause); for (unsigned i = 0; i < sz; i++) - m_lazy_clause.push_back(~jst[i]); - + m_lazy_clause.push_back(~jst.lit(i)); + // lazy clause is a valid clause TRACE("nlsat_mathematica", display_mathematica_lemma(tout, m_lazy_clause.size(), m_lazy_clause.c_ptr());); TRACE("nlsat_proof_sk", tout << "theory lemma\n"; display_abst(tout, m_lazy_clause.size(), m_lazy_clause.c_ptr()); tout << "\n";); @@ -1593,6 +1604,12 @@ namespace nlsat { } }); resolve_clause(b, m_lazy_clause.size(), m_lazy_clause.c_ptr()); + + for (unsigned i = 0; i < jst.num_clauses(); ++i) { + clause const& c = jst.clause(i); + TRACE("nlsat", display(tout << "adding clause assumptions ", c) << "\n";); + m_lemma_assumptions = m_asm.mk_join(static_cast<_assumption_set>(c.assumptions()), m_lemma_assumptions); + } } /**