diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index 1bfaf2b17..0f5ccdcc3 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -91,6 +91,7 @@ namespace sat { virtual void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r, bool probing) = 0; virtual bool is_extended_binary(ext_justification_idx idx, literal_vector & r) { return false; } virtual void asserted(literal l) {}; + virtual void set_eliminated(bool_var v) {}; virtual check_result check() = 0; virtual lbool resolve_conflict() { return l_undef; } // stores result in sat::solver::m_lemma virtual void push() = 0; diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 301140eb1..6d7dda150 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -328,14 +328,19 @@ namespace sat { } void solver::set_eliminated(bool_var v, bool f) { - if (m_eliminated[v] && !f) + if (m_eliminated[v] == f) + return; + if (!f) reset_var(v, m_external[v], m_decision[v]); + else if (f && m_ext) + m_ext->set_eliminated(v); m_eliminated[v] = f; } clause* solver::mk_clause(unsigned num_lits, literal * lits, sat::status st) { m_model_is_current = false; + for (unsigned i = 0; i < num_lits; i++) VERIFY(!was_eliminated(lits[i])); diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index 9a9f895c9..c2e5c26ec 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -128,6 +128,7 @@ namespace euf { s().set_external(v); s().set_eliminated(v, false); + if (lit.sign()) { v = si.add_bool_var(e); s().set_external(v); @@ -265,17 +266,18 @@ namespace euf { sat::status st = sat::status::th(m_is_redundant, m.get_basic_family_id()); expr* c = nullptr, * th = nullptr, * el = nullptr; if (!m.is_bool(e) && m.is_ite(e, c, th, el)) { - app* a = to_app(e); - expr_ref eq_th = mk_eq(a, th); + expr_ref eq_th = mk_eq(e, th); sat::literal lit_th = mk_literal(eq_th); if (th == el) { s().add_clause(1, &lit_th, st); } else { sat::bool_var v = si.to_bool_var(c); + VERIFY(v != sat::null_bool_var); + VERIFY(s().is_external(v)); SASSERT(v != sat::null_bool_var); - - expr_ref eq_el = mk_eq(a, el); + VERIFY(!s().was_eliminated(v)); + expr_ref eq_el = mk_eq(e, el); sat::literal lit_el = mk_literal(eq_el); literal lits1[2] = { literal(v, true), lit_th }; diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index 9a3eedcb8..4a2d9f40d 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -279,6 +279,11 @@ namespace euf { } } + void solver::set_eliminated(bool_var v) { + si.uncache(literal(v, false)); + si.uncache(literal(v, true)); + } + void solver::asserted(literal l) { expr* e = m_bool_var2expr.get(l.var(), nullptr); if (!e) { diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 112a8a35d..5e5525be9 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -293,6 +293,7 @@ namespace euf { void get_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing); void add_antecedent(enode* a, enode* b); void add_diseq_antecedent(enode* a, enode* b); + void set_eliminated(bool_var v) override; void asserted(literal l) override; sat::check_result check() override; void push() override; diff --git a/src/sat/smt/sat_smt.h b/src/sat/smt/sat_smt.h index 862236649..0ba632f63 100644 --- a/src/sat/smt/sat_smt.h +++ b/src/sat/smt/sat_smt.h @@ -36,6 +36,7 @@ namespace sat { virtual bool_var to_bool_var(expr* e) = 0; virtual bool_var add_bool_var(expr* e) = 0; virtual void cache(app* t, literal l) = 0; + virtual void uncache(literal l) = 0; virtual void push() = 0; virtual void pop(unsigned n) = 0; virtual void set_expr2var_replay(obj_map* r) = 0; diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index b626b3460..d774e30a1 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -60,7 +60,10 @@ struct goal2sat::imp : public sat::sat_internalizer { pb_util pb; svector m_frame_stack; svector m_result_stack; - obj_map m_cache; + obj_map m_app2lit; + u_map m_lit2app; + unsigned_vector m_cache_lim; + ptr_vector m_cache_trail; obj_hashtable m_interface_vars; sat::solver_core & m_solver; atom2bool_var & m_map; @@ -205,12 +208,11 @@ struct goal2sat::imp : public sat::sat_internalizer { sat::bool_var v = m_map.to_bool_var(e); if (v != sat::null_bool_var) return v; - if (is_app(e) && m_cache.find(to_app(e), l) && !l.sign()) + if (is_app(e) && m_app2lit.find(to_app(e), l) && !l.sign()) return l.var(); return sat::null_bool_var; } - void set_expr2var_replay(obj_map* r) override { m_expr2var_replay = r; } @@ -236,8 +238,10 @@ struct goal2sat::imp : public sat::sat_internalizer { unsigned m_num_scopes{ 0 }; void force_push() { - for (; m_num_scopes > 0; --m_num_scopes) + for (; m_num_scopes > 0; --m_num_scopes) { m_map.push(); + m_cache_lim.push_back(m_cache_trail.size()); + } } void push() override { @@ -251,12 +255,37 @@ struct goal2sat::imp : public sat::sat_internalizer { } n -= m_num_scopes; m_num_scopes = 0; - m_cache.reset(); m_map.pop(n); + unsigned k = m_cache_lim[m_cache_lim.size() - n]; + for (; k-- > m_cache_trail.size(); ) { + app* t = m_cache_trail[k]; + sat::literal lit; + if (m_app2lit.find(t, lit)) { + m_app2lit.remove(t); + m_lit2app.remove(lit.index()); + } + } + m_cache_trail.shrink(k); + m_cache_lim.shrink(m_cache_lim.size() - n); } + // remove non-external literals from cache. + void uncache(sat::literal lit) override { + app* t = nullptr; + if (m_lit2app.find(lit.index(), t)) { + m_lit2app.remove(lit.index()); + m_app2lit.remove(t); + } + } + + void cache(app* t, sat::literal l) override { - m_cache.insert(t, l); + SASSERT(m_num_scopes == 0); + SASSERT(!m_app2lit.contains(t)); + SASSERT(!m_lit2app.contains(l.index())); + m_app2lit.insert(t, l); + m_lit2app.insert(l.index(), t); + m_cache_trail.push_back(t); } void convert_atom(expr * t, bool root, bool sign) { @@ -317,7 +346,7 @@ struct goal2sat::imp : public sat::sat_internalizer { bool process_cached(app* t, bool root, bool sign) { sat::literal l = sat::null_literal; - if (!m_cache.find(t, l)) + if (!m_app2lit.find(t, l)) return false; if (sign) l.neg(); @@ -396,7 +425,7 @@ struct goal2sat::imp : public sat::sat_internalizer { SASSERT(num <= m_result_stack.size()); sat::bool_var k = add_var(false, t); sat::literal l(k, false); - m_cache.insert(t, l); + cache(t, l); sat::literal * lits = m_result_stack.end() - num; for (unsigned i = 0; i < num; i++) mk_clause(~lits[i], l); @@ -445,7 +474,7 @@ struct goal2sat::imp : public sat::sat_internalizer { SASSERT(num <= m_result_stack.size()); sat::bool_var k = add_var(false, t); sat::literal l(k, false); - m_cache.insert(t, l); + cache(t, l); sat::literal * lits = m_result_stack.end() - num; // l => /\ lits @@ -497,7 +526,7 @@ struct goal2sat::imp : public sat::sat_internalizer { else { sat::bool_var k = add_var(false, n); sat::literal l(k, false); - m_cache.insert(n, l); + cache(n, l); mk_clause(~l, ~c, t); mk_clause(~l, c, e); mk_clause(l, ~c, ~t); @@ -534,7 +563,7 @@ struct goal2sat::imp : public sat::sat_internalizer { else { sat::bool_var k = add_var(false, t); sat::literal l(k, false); - m_cache.insert(t, l); + cache(t, l); // l <=> (l1 => l2) mk_clause(~l, ~l1, l2); mk_clause(l1, l); @@ -571,7 +600,7 @@ struct goal2sat::imp : public sat::sat_internalizer { mk_clause(l, l1, l2); mk_clause(l, ~l1, ~l2); if (aig()) aig()->add_iff(l, l1, l2); - m_cache.insert(t, m.is_xor(t) ? ~l : l); + cache(t, m.is_xor(t) ? ~l : l); if (sign) l.neg(); m_result_stack.push_back(l); @@ -787,8 +816,10 @@ struct goal2sat::imp : public sat::sat_internalizer { SASSERT(m_result_stack.size() == sz + 1); sat::literal result = m_result_stack.back(); m_result_stack.pop_back(); - if (!result.sign() && m_map.to_bool_var(n) == sat::null_bool_var) + if (!result.sign() && m_map.to_bool_var(n) == sat::null_bool_var) { m_map.insert(n, result.var()); + m_solver.set_external(result.var()); + } return result; } @@ -853,7 +884,7 @@ struct goal2sat::imp : public sat::sat_internalizer { scoped_reset(imp& i) :i(i) {} ~scoped_reset() { i.m_interface_vars.reset(); - i.m_cache.reset(); + i.m_app2lit.reset(); } }; scoped_reset _reset(*this);