From e8826bb20f90c21e3cb4f196d4fa67b00d005691 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 30 Aug 2020 09:49:12 -0700 Subject: [PATCH] fix #4651 Signed-off-by: Nikolaj Bjorner --- src/sat/smt/ba_internalize.cpp | 12 +++--- src/sat/smt/ba_solver.cpp | 6 +-- src/sat/smt/ba_solver.h | 6 +-- src/sat/smt/euf_solver.cpp | 75 +++++++++++++++++++++++----------- src/sat/smt/euf_solver.h | 41 +++++++++++-------- src/sat/smt/sat_smt.h | 2 - src/sat/smt/sat_th.h | 7 +--- src/sat/tactic/goal2sat.cpp | 11 ++--- src/smt/theory_seq.cpp | 7 +++- 9 files changed, 98 insertions(+), 69 deletions(-) diff --git a/src/sat/smt/ba_internalize.cpp b/src/sat/smt/ba_internalize.cpp index 30684ce0d..352692ad1 100644 --- a/src/sat/smt/ba_internalize.cpp +++ b/src/sat/smt/ba_internalize.cpp @@ -192,9 +192,9 @@ namespace sat { literal l1(v1, false), l2(v2, false); bool_var v = s().add_var(false); literal l(v, false); - si.mk_clause(~l, l1); - si.mk_clause(~l, l2); - si.mk_clause(~l1, ~l2, l); + s().mk_clause(~l, l1); + s().mk_clause(~l, l2); + s().mk_clause(~l1, ~l2, l); si.cache(t, l); if (sign) l.neg(); return l; @@ -267,9 +267,9 @@ namespace sat { literal l1(v1, false), l2(v2, false); bool_var v = s().add_var(false); literal l(v, false); - si.mk_clause(~l, l1); - si.mk_clause(~l, l2); - si.mk_clause(~l1, ~l2, l); + s().mk_clause(~l, l1); + s().mk_clause(~l, l2); + s().mk_clause(~l1, ~l2, l); si.cache(t, l); if (sign) l.neg(); return l; diff --git a/src/sat/smt/ba_solver.cpp b/src/sat/smt/ba_solver.cpp index bcc674d37..077d59a07 100644 --- a/src/sat/smt/ba_solver.cpp +++ b/src/sat/smt/ba_solver.cpp @@ -3,7 +3,7 @@ Copyright (c) 2017 Microsoft Corporation Module Name: - ba_core.cpp + ba_solver.cpp Abstract: @@ -13,8 +13,6 @@ Author: Nikolaj Bjorner (nbjorner) 2017-01-30 -Revision History: - --*/ #include @@ -1845,8 +1843,6 @@ namespace sat { add_pb_ge(lit, wlits, k, false); } - - /* \brief return true to keep watching literal. */ diff --git a/src/sat/smt/ba_solver.h b/src/sat/smt/ba_solver.h index a3a028831..fbd63af51 100644 --- a/src/sat/smt/ba_solver.h +++ b/src/sat/smt/ba_solver.h @@ -567,9 +567,9 @@ namespace sat { ~ba_solver() override; void set_solver(solver* s) override { m_solver = s; } void set_lookahead(lookahead* l) override { m_lookahead = l; } - void add_at_least(bool_var v, literal_vector const& lits, unsigned k); - void add_pb_ge(bool_var v, svector const& wlits, unsigned k); - void add_xr(literal_vector const& lits); + void add_at_least(bool_var v, literal_vector const& lits, unsigned k); + void add_pb_ge(bool_var v, svector const& wlits, unsigned k); + void add_xr(literal_vector const& lits); bool propagate(literal l, ext_constraint_idx idx) override; lbool resolve_conflict() override; diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index e555990ad..42a69af0e 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -42,11 +42,12 @@ namespace euf { sat::th_solver* solver::get_solver(expr* e) { if (is_app(e)) - return fid2solver(to_app(e)->get_family_id()); + return get_solver(to_app(e)->get_decl()); return nullptr; } - sat::th_solver* solver::fid2solver(family_id fid) { + sat::th_solver* solver::get_solver(func_decl* f) { + family_id fid = f->get_family_id(); if (fid == null_family_id) return nullptr; auto* ext = m_id2solver.get(fid, nullptr); @@ -54,14 +55,17 @@ namespace euf { return ext; pb_util pb(m); if (pb.get_family_id() == fid) { - sat::ba_solver* ba = alloc(sat::ba_solver, m, si); - ba->set_solver(m_solver); - add_solver(pb.get_family_id(), ba); - ba->push_scopes(s().num_scopes()); - return ba; + ext = alloc(sat::ba_solver, m, si); } - - return nullptr; + if (ext) { + ext->set_solver(m_solver); + ext->push_scopes(s().num_scopes()); + add_solver(fid, ext); + } + else { + unhandled_function(f); + } + return ext; } void solver::add_solver(family_id fid, sat::th_solver* th) { @@ -69,6 +73,11 @@ namespace euf { m_id2solver.setx(fid, th, nullptr); } + void solver::unhandled_function(func_decl* f) { + IF_VERBOSE(0, verbose_stream() << mk_pp(f, m) << " not handled\n"); + // TBD: set some state with the unhandled function. + } + bool solver::propagate(literal l, ext_constraint_idx idx) { auto* ext = sat::constraint_base::to_extension(idx); SASSERT(ext != this); @@ -87,25 +96,26 @@ namespace euf { m_explain.reset(); euf::enode* n = nullptr; bool sign = false; - if (j.id() != 0) { - auto p = m_var2node[l.var()]; - n = p.first; - SASSERT(n); - sign = l.sign() != p.second; - } + enode_bool_pair p; // init_ackerman(); - switch (j.id()) { - case 0: + switch (j.kind()) { + case constraint::conflict: SASSERT(m_egraph.inconsistent()); m_egraph.explain(m_explain); break; - case 1: + case constraint::eq: + n = m_var2node[l.var()].first; + SASSERT(n); SASSERT(m_egraph.is_equality(n)); m_egraph.explain_eq(m_explain, n->get_arg(0), n->get_arg(1), n->commutative()); break; - case 2: + case constraint::lit: + p = m_var2node[l.var()]; + n = p.first; + sign = l.sign() != p.second; + SASSERT(n); SASSERT(m.is_bool(n->get_owner())); m_egraph.explain_eq(m_explain, n, (sign ? mk_false() : mk_true()), false); break; @@ -168,10 +178,10 @@ namespace euf { } } - constraint& solver::mk_constraint(constraint*& c, unsigned id) { + constraint& solver::mk_constraint(constraint*& c, constraint::kind_t k) { if (!c) { void* mem = memory::allocate(sat::constraint_base::obj_size(sizeof(constraint))); - c = new (sat::constraint_base::ptr2mem(mem)) constraint(id); + c = new (sat::constraint_base::ptr2mem(mem)) constraint(k); sat::constraint_base::initialize(mem, this); } return *c; @@ -321,7 +331,7 @@ namespace euf { bool solver::is_blocked(literal l, ext_constraint_idx idx) { auto* ext = sat::constraint_base::to_extension(idx); if (ext != this) - return is_blocked(l, idx); + return ext->is_blocked(l, idx); return false; } @@ -345,6 +355,24 @@ namespace euf { return w; } + double solver::get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const { + double r = 0; + for (auto* e : m_solvers) { + r = e->get_reward(l, idx, occs); + if (r != 0) + return r; + } + return r; + } + + bool solver::is_extended_binary(ext_justification_idx idx, literal_vector& r) { + for (auto* e : m_solvers) { + if (e->is_extended_binary(idx, r)) + return true; + } + return false; + } + void solver::init_ackerman() { if (m_ackerman) return; @@ -365,7 +393,7 @@ namespace euf { auto* ext = get_solver(e); if (ext) return ext->internalize(e, sign, root); - std::cout << mk_pp(e, m) << "\n"; + IF_VERBOSE(0, verbose_stream() << "internalize: " << mk_pp(e, m) << "\n"); SASSERT(!si.is_bool_op(e)); sat::scoped_stack _sc(m_stack); unsigned sz = m_stack.size(); @@ -429,7 +457,6 @@ namespace euf { expr* e = n->get_owner(); if (m.is_bool(e)) { sat::bool_var v = si.add_bool_var(e); - std::cout << "attach " << v << "\n"; attach_bool_var(v, false, n); } } diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 26c41abeb..0c2545356 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -34,12 +34,13 @@ namespace euf { typedef sat::bool_var bool_var; class constraint { - unsigned m_id; public: - constraint(unsigned id) : - m_id(id) - {} - unsigned id() const { return m_id; } + enum kind_t { conflict, eq, lit}; + private: + kind_t m_kind; + public: + constraint(kind_t k) : m_kind(k) {} + unsigned kind() const { return m_kind; } static constraint* from_idx(size_t z) { return reinterpret_cast(z); } size_t to_index() const { return sat::constraint_base::mem2base(this); } }; @@ -61,10 +62,10 @@ namespace euf { stats m_stats; sat::solver* m_solver { nullptr }; sat::lookahead* m_lookahead { nullptr }; - ast_manager* m_to_m { nullptr }; - atom2bool_var* m_to_expr2var { nullptr }; - sat::sat_internalizer* m_to_si{ nullptr }; - scoped_ptr m_ackerman; + ast_manager* m_to_m; + atom2bool_var* m_to_expr2var; + sat::sat_internalizer* m_to_si; + scoped_ptr m_ackerman; svector m_var2node; ptr_vector m_explain; @@ -91,11 +92,11 @@ namespace euf { euf::enode* mk_false(); // extensions - sat::th_solver* get_solver(func_decl* f) { return fid2solver(f->get_family_id()); } + sat::th_solver* get_solver(func_decl* f); sat::th_solver* get_solver(expr* e); sat::th_solver* get_solver(sat::bool_var v); - sat::th_solver* fid2solver(family_id fid); void add_solver(family_id fid, sat::th_solver* th); + void unhandled_function(func_decl* f); void init_ackerman(); // model building @@ -109,10 +110,10 @@ namespace euf { void propagate(); void get_antecedents(literal l, constraint& j, literal_vector& r); - constraint& mk_constraint(constraint*& c, unsigned id); - constraint& conflict_constraint() { return mk_constraint(m_conflict, 0); } - constraint& eq_constraint() { return mk_constraint(m_eq, 1); } - constraint& lit_constraint() { return mk_constraint(m_lit, 2); } + constraint& mk_constraint(constraint*& c, constraint::kind_t k); + constraint& conflict_constraint() { return mk_constraint(m_conflict, constraint::conflict); } + constraint& eq_constraint() { return mk_constraint(m_eq, constraint::eq); } + constraint& lit_constraint() { return mk_constraint(m_lit, constraint::lit); } public: solver(ast_manager& m, atom2bool_var& expr2var, sat::sat_internalizer& si, params_ref const& p = params_ref()): @@ -146,11 +147,15 @@ namespace euf { s.m_to_expr2var = &a2b; s.m_to_si = &si; } - ~scoped_set_translate() { s.m_to_m = &s.m; s.m_to_expr2var = &s.m_expr2var; s.m_to_si = &s.si; } + ~scoped_set_translate() { + s.m_to_m = &s.m; + s.m_to_expr2var = &s.m_expr2var; + s.m_to_si = &s.si; + } }; - double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override { return 0; } - bool is_extended_binary(ext_justification_idx idx, literal_vector & r) override { return false; } + double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override; + bool is_extended_binary(ext_justification_idx idx, literal_vector& r) override; bool propagate(literal l, ext_constraint_idx idx) override; void get_antecedents(literal l, ext_justification_idx idx, literal_vector & r) override; void asserted(literal l) override; diff --git a/src/sat/smt/sat_smt.h b/src/sat/smt/sat_smt.h index ec718bf93..37e447ab4 100644 --- a/src/sat/smt/sat_smt.h +++ b/src/sat/smt/sat_smt.h @@ -40,8 +40,6 @@ namespace sat { virtual bool is_bool_op(expr* e) const = 0; virtual literal internalize(expr* e) = 0; virtual bool_var add_bool_var(expr* e) = 0; - virtual void mk_clause(literal a, literal b) = 0; - virtual void mk_clause(literal l1, literal l2, literal l3, bool is_lemma = false) = 0; virtual void cache(app* t, literal l) = 0; }; diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index 220111db5..44bd6b1bb 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -18,7 +18,7 @@ Author: #include "util/top_sort.h" #include "sat/smt/sat_smt.h" -#include "ast/euf/euf_egraph.h" +#include "ast/euf/euf_enode.h" namespace sat { @@ -27,8 +27,6 @@ namespace sat { virtual ~th_internalizer() {} virtual literal internalize(expr* e, bool sign, bool root) = 0; - - }; class th_decompile { @@ -64,8 +62,7 @@ namespace sat { public: virtual ~th_solver() {} - virtual th_solver* fresh(solver* s, ast_manager& m, sat_internalizer& si) = 0; - + virtual th_solver* fresh(solver* s, ast_manager& m, sat_internalizer& si) = 0; }; diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index b9effe74e..54f51695e 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -106,14 +106,14 @@ struct goal2sat::imp : public sat::sat_internalizer { m_solver.add_clause(1, &l, false); } - void mk_clause(sat::literal l1, sat::literal l2) override { + void mk_clause(sat::literal l1, sat::literal l2) { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << "\n";); m_solver.add_clause(l1, l2, false); } - void mk_clause(sat::literal l1, sat::literal l2, sat::literal l3, bool is_lemma = false) override { + void mk_clause(sat::literal l1, sat::literal l2, sat::literal l3) { TRACE("goal2sat", tout << "mk_clause: " << l1 << " " << l2 << " " << l3 << "\n";); - m_solver.add_clause(l1, l2, l3, is_lemma); + m_solver.add_clause(l1, l2, l3, false); } void mk_clause(unsigned num, sat::literal * lits) { @@ -385,8 +385,8 @@ struct goal2sat::imp : public sat::sat_internalizer { mk_clause(l, ~c, ~t); mk_clause(l, c, ~e); if (m_ite_extra) { - mk_clause(~t, ~e, l, false); - mk_clause(t, e, ~l, false); + mk_clause(~t, ~e, l); + mk_clause(t, e, ~l); } if (m_aig) m_aig->add_ite(l, c, t, e); if (sign) @@ -801,6 +801,7 @@ void goal2sat::operator()(goal const & g, params_ref const & p, sat::solver_core dealloc(m_imp); m_imp = nullptr; } + } void goal2sat::get_interpreted_atoms(expr_ref_vector& atoms) { diff --git a/src/smt/theory_seq.cpp b/src/smt/theory_seq.cpp index 921e5b45f..bc5a36bc9 100644 --- a/src/smt/theory_seq.cpp +++ b/src/smt/theory_seq.cpp @@ -1724,9 +1724,11 @@ std::ostream& theory_seq::display_deps(std::ostream& out, literal_vector const& smt2_pp_environment_dbg env(m); params_ref p; for (auto const& eq : eqs) { + if (eq.first->get_root() != eq.second->get_root()) + out << "invalid: "; out << " (= " << mk_bounded_pp(eq.first->get_owner(), m, 2) << "\n " << mk_bounded_pp(eq.second->get_owner(), m, 2) - << ")\n"; + << ")\n"; } for (literal l : lits) { display_lit(out, l) << "\n"; @@ -2908,6 +2910,7 @@ bool theory_seq::propagate_eq(dependency* deps, literal_vector const& _lits, exp } void theory_seq::assign_eh(bool_var v, bool is_true) { + force_push(); expr* e = ctx.bool_var2expr(v); expr* e1 = nullptr, *e2 = nullptr; expr_ref f(m); @@ -3023,6 +3026,7 @@ void theory_seq::assign_eh(bool_var v, bool is_true) { } void theory_seq::new_eq_eh(theory_var v1, theory_var v2) { + force_push(); enode* n1 = get_enode(v1); enode* n2 = get_enode(v2); expr* o1 = n1->get_owner(); @@ -3066,6 +3070,7 @@ void theory_seq::new_eq_eh(dependency* deps, enode* n1, enode* n2) { } void theory_seq::new_diseq_eh(theory_var v1, theory_var v2) { + force_push(); enode* n1 = get_enode(v1); enode* n2 = get_enode(v2); expr_ref e1(n1->get_owner(), m);