From 1bd73d4635cfe6ffd5b8d1c4ee952ea959b30b27 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 15 Aug 2023 11:12:55 -0700 Subject: [PATCH] initial stab at mbi Signed-off-by: Nikolaj Bjorner --- src/qe/qe_mbi.cpp | 407 ++++++++++++++++++----------------- src/qe/qe_mbi.h | 36 ++-- src/sat/smt/synth_solver.cpp | 74 +++++-- src/sat/smt/synth_solver.h | 2 +- 4 files changed, 292 insertions(+), 227 deletions(-) diff --git a/src/qe/qe_mbi.cpp b/src/qe/qe_mbi.cpp index 48a7928be..b1dc58999 100644 --- a/src/qe/qe_mbi.cpp +++ b/src/qe/qe_mbi.cpp @@ -164,80 +164,25 @@ namespace qe { } // ------------------------------- - // uflia_mbi + // uflia_project - struct uflia_mbi::is_atom_proc { - ast_manager& m; - expr_ref_vector& m_atoms; - obj_hashtable& m_atom_set; + /** + * \brief Order arithmetical variables: + * sort arithmetical terms, such that deepest terms are first. + */ + void uflia_project::order_avars(app_ref_vector& avars) { - is_atom_proc(expr_ref_vector& atoms, obj_hashtable& atom_set): - m(atoms.m()), m_atoms(atoms), m_atom_set(atom_set) {} - - void operator()(app* a) { - if (m_atom_set.contains(a)) { - // continue - } - else if (m.is_eq(a) && !m.is_iff(a)) { - m_atoms.push_back(a); - m_atom_set.insert(a); - } - else if (m.is_bool(a) && a->get_family_id() != m.get_basic_family_id()) { - m_atoms.push_back(a); - m_atom_set.insert(a); - } - } - void operator()(expr*) {} - }; - - uflia_mbi::uflia_mbi(solver* s, solver* sNot): - mbi_plugin(s->get_manager()), - m_atoms(m), - m_fmls(m), - m_solver(s), - m_dual_solver(sNot) { - params_ref p; - p.set_bool("core.minimize", true); - m_solver->updt_params(p); - m_dual_solver->updt_params(p); - m_solver->get_assertions(m_fmls); - collect_atoms(m_fmls); + // sort avars based on depth + std::function compare_depth = + [](app* x, app* y) { + return + (x->get_depth() > y->get_depth()) || + (x->get_depth() == y->get_depth() && x->get_id() > y->get_id()); + }; + std::sort(avars.data(), avars.data() + avars.size(), compare_depth); + TRACE("qe", tout << "avars:" << avars << "\n";); } - void uflia_mbi::collect_atoms(expr_ref_vector const& fmls) { - expr_fast_mark1 marks; - is_atom_proc proc(m_atoms, m_atom_set); - for (expr* e : fmls) { - quick_for_each_expr(proc, marks, e); - } - } - - bool uflia_mbi::get_literals(model_ref& mdl, expr_ref_vector& lits) { - lits.reset(); - IF_VERBOSE(10, verbose_stream() << "atoms: " << m_atoms << "\n"); - for (expr* e : m_atoms) { - if (mdl->is_true(e)) - lits.push_back(e); - else if (mdl->is_false(e)) - lits.push_back(m.mk_not(e)); - } - TRACE("qe", tout << "atoms from model: " << lits << "\n";); - solver_ref dual = m_dual_solver->translate(m, m_dual_solver->get_params()); - dual->assert_expr(mk_not(mk_and(m_fmls))); - lbool r = dual->check_sat(lits); - TRACE("qe", dual->display(tout << "dual result " << r << "\n");); - if (l_false == r) { - // use the dual solver to find a 'small' implicant - lits.reset(); - dual->get_unsat_core(lits); - return true; - } - else { - return false; - } - } - - /** * \brief A subterm is an arithmetic variable if: * 1. it is not shared. @@ -246,7 +191,7 @@ namespace qe { * * The result is ordered using deepest term first. */ - app_ref_vector uflia_mbi::get_arith_vars(expr_ref_vector const& lits) { + app_ref_vector uflia_project::get_arith_vars(expr_ref_vector const& lits) { app_ref_vector avars(m); bool_vector seen; arith_util a(m); @@ -273,7 +218,7 @@ namespace qe { For these cases we apply model refinement to the literals: non-shared sub-expressions are replaced by model values. */ - void uflia_mbi::fix_non_shared(model& mdl, expr_ref_vector& lits) { + void uflia_project::fix_non_shared(model& mdl, expr_ref_vector& lits) { th_rewriter rewrite(m); expr_ref_vector trail(m); obj_map cache; @@ -325,7 +270,7 @@ namespace qe { lits[i] = cache[lits.get(i)]; } - vector uflia_mbi::arith_project(model_ref& mdl, app_ref_vector& avars, expr_ref_vector& lits) { + vector uflia_project::arith_project(model_ref& mdl, app_ref_vector& avars, expr_ref_vector& lits) { mbp::arith_project_plugin ap(m); ap.set_check_purified(false); vector defs; @@ -336,6 +281,196 @@ namespace qe { return defs; } + void uflia_project::split_arith(expr_ref_vector const& lits, + expr_ref_vector& alits, + expr_ref_vector& uflits) { + arith_util a(m); + for (expr* lit : lits) { + expr* atom = lit, *x = nullptr, *y = nullptr; + m.is_not(lit, atom); + if (m.is_eq(atom, x, y)) { + if (a.is_int_real(x)) { + alits.push_back(lit); + } + uflits.push_back(lit); + } + else if (a.is_arith_expr(atom)) { + alits.push_back(lit); + } + else { + uflits.push_back(lit); + } + } + TRACE("qe", + tout << "alits: " << alits << "\n"; + tout << "uflits: " << uflits << "\n";); + } + + /** + \brief add difference certificates to formula. + */ + void uflia_project::add_dcert(model_ref& mdl, expr_ref_vector& lits) { + mbp::term_graph tg(m); + add_arith_dcert(*mdl.get(), lits); + func_decl_ref_vector shared(m_shared_trail); + tg.set_vars(shared, false); + lits.append(tg.dcert(*mdl.get(), lits)); + TRACE("qe", tout << "project: " << lits << "\n";); + } + + /** + Add disequalities between functions that appear in arithmetic context. + */ + void uflia_project::add_arith_dcert(model& mdl, expr_ref_vector& lits) { + obj_map> apps; + arith_util a(m); + for (expr* e : subterms::ground(lits)) { + if (a.is_int_real(e) && is_uninterp(e) && to_app(e)->get_num_args() > 0) { + func_decl* f = to_app(e)->get_decl(); + apps.insert_if_not_there(f, ptr_vector()).push_back(to_app(e)); + } + } + for (auto const& kv : apps) { + ptr_vector const& es = kv.m_value; + expr_ref_vector values(m); + for (expr* e : kv.m_value) values.push_back(mdl(e)); + for (unsigned i = 0; i < es.size(); ++i) { + expr* v1 = values.get(i); + for (unsigned j = i + 1; j < es.size(); ++j) { + expr* v2 = values.get(j); + if (v1 != v2) { + add_arith_dcert(mdl, lits, es[i], es[j]); + } + } + } + } + } + + void uflia_project::add_arith_dcert(model& mdl, expr_ref_vector& lits, app* a, app* b) { + arith_util arith(m); + SASSERT(a->get_decl() == b->get_decl()); + for (unsigned i = a->get_num_args(); i-- > 0; ) { + expr* arg1 = a->get_arg(i), *arg2 = b->get_arg(i); + if (arith.is_int_real(arg1) && mdl(arg1) != mdl(arg2)) { + lits.push_back(m.mk_not(m.mk_eq(arg1, arg2))); + return; + } + } + } + + /** + * \brief project private symbols. + */ + void uflia_project::project_euf(model_ref& mdl, expr_ref_vector& lits) { + mbp::term_graph tg(m); + func_decl_ref_vector shared(m_shared_trail); + tg.set_vars(shared, false); + tg.add_lits(lits); + lits.reset(); + lits.append(tg.project(*mdl.get())); + TRACE("qe", tout << "project: " << lits << "\n";); + } + + vector uflia_project::project_solve(model_ref& mdl, expr_ref_vector& lits) { + TRACE("qe", tout << "project literals: " << lits << "\n" << *mdl << "\n"); + + add_dcert(mdl, lits); + expr_ref_vector alits(m), uflits(m); + split_arith(lits, alits, uflits); + auto avars = get_arith_vars(lits); + vector defs = arith_project(mdl, avars, alits); + for (auto const& d : defs) uflits.push_back(m.mk_eq(d.var, d.term)); + TRACE("qe", tout << "uflits: " << uflits << "\n";); + project_euf(mdl, uflits); + lits.reset(); + lits.append(alits); + lits.append(uflits); + IF_VERBOSE(10, verbose_stream() << "projection : " << lits << "\n"); + TRACE("qe", + tout << "projection: " << lits << "\n"; + tout << "avars: " << avars << "\n"; + tout << "alits: " << lits << "\n"; + tout << "uflits: " << uflits << "\n";); + return defs; + } + + + + // ------------------------------- + // uflia_mbi + + struct uflia_mbi::is_atom_proc { + ast_manager& m; + expr_ref_vector& m_atoms; + obj_hashtable& m_atom_set; + + is_atom_proc(expr_ref_vector& atoms, obj_hashtable& atom_set): + m(atoms.m()), m_atoms(atoms), m_atom_set(atom_set) {} + + void operator()(app* a) { + if (m_atom_set.contains(a)) { + // continue + } + else if (m.is_eq(a) && !m.is_iff(a)) { + m_atoms.push_back(a); + m_atom_set.insert(a); + } + else if (m.is_bool(a) && a->get_family_id() != m.get_basic_family_id()) { + m_atoms.push_back(a); + m_atom_set.insert(a); + } + } + void operator()(expr*) {} + }; + + uflia_mbi::uflia_mbi(solver* s, solver* sNot): + uflia_project(s->get_manager()), + m_atoms(m), + m_fmls(m), + m_solver(s), + m_dual_solver(sNot) { + params_ref p; + p.set_bool("core.minimize", true); + m_solver->updt_params(p); + m_dual_solver->updt_params(p); + m_solver->get_assertions(m_fmls); + collect_atoms(m_fmls); + } + + void uflia_mbi::collect_atoms(expr_ref_vector const& fmls) { + expr_fast_mark1 marks; + is_atom_proc proc(m_atoms, m_atom_set); + for (expr* e : fmls) { + quick_for_each_expr(proc, marks, e); + } + } + + bool uflia_mbi::get_literals(model_ref& mdl, expr_ref_vector& lits) { + lits.reset(); + IF_VERBOSE(10, verbose_stream() << "atoms: " << m_atoms << "\n"); + for (expr* e : m_atoms) { + if (mdl->is_true(e)) + lits.push_back(e); + else if (mdl->is_false(e)) + lits.push_back(m.mk_not(e)); + } + TRACE("qe", tout << "atoms from model: " << lits << "\n";); + solver_ref dual = m_dual_solver->translate(m, m_dual_solver->get_params()); + dual->assert_expr(mk_not(mk_and(m_fmls))); + lbool r = dual->check_sat(lits); + TRACE("qe", dual->display(tout << "dual result " << r << "\n");); + if (l_false == r) { + // use the dual solver to find a 'small' implicant + lits.reset(); + dual->get_unsat_core(lits); + return true; + } + else { + return false; + } + } + + mbi_result uflia_mbi::operator()(expr_ref_vector& lits, model_ref& mdl) { lbool r = m_solver->check_sat(lits); @@ -367,137 +502,13 @@ namespace qe { \brief main projection routine */ void uflia_mbi::project(model_ref& mdl, expr_ref_vector& lits) { - TRACE("qe", - tout << "project literals: " << lits << "\n" << *mdl << "\n"; - tout << m_solver->get_assertions() << "\n";); - - add_dcert(mdl, lits); - expr_ref_vector alits(m), uflits(m); - split_arith(lits, alits, uflits); - auto avars = get_arith_vars(lits); - vector defs = arith_project(mdl, avars, alits); - for (auto const& d : defs) uflits.push_back(m.mk_eq(d.var, d.term)); - TRACE("qe", tout << "uflits: " << uflits << "\n";); - project_euf(mdl, uflits); - lits.reset(); - lits.append(alits); - lits.append(uflits); - IF_VERBOSE(10, verbose_stream() << "projection : " << lits << "\n"); - TRACE("qe", - tout << "projection: " << lits << "\n"; - tout << "avars: " << avars << "\n"; - tout << "alits: " << lits << "\n"; - tout << "uflits: " << uflits << "\n";); - } - - void uflia_mbi::split_arith(expr_ref_vector const& lits, - expr_ref_vector& alits, - expr_ref_vector& uflits) { - arith_util a(m); - for (expr* lit : lits) { - expr* atom = lit, *x = nullptr, *y = nullptr; - m.is_not(lit, atom); - if (m.is_eq(atom, x, y)) { - if (a.is_int_real(x)) { - alits.push_back(lit); - } - uflits.push_back(lit); - } - else if (a.is_arith_expr(atom)) { - alits.push_back(lit); - } - else { - uflits.push_back(lit); - } - } - TRACE("qe", - tout << "alits: " << alits << "\n"; - tout << "uflits: " << uflits << "\n";); + project_solve(mdl, lits); } - /** - \brief add difference certificates to formula. - */ - void uflia_mbi::add_dcert(model_ref& mdl, expr_ref_vector& lits) { - mbp::term_graph tg(m); - add_arith_dcert(*mdl.get(), lits); - func_decl_ref_vector shared(m_shared_trail); - tg.set_vars(shared, false); - lits.append(tg.dcert(*mdl.get(), lits)); - TRACE("qe", tout << "project: " << lits << "\n";); - } - /** - Add disequalities between functions that appear in arithmetic context. - */ - void uflia_mbi::add_arith_dcert(model& mdl, expr_ref_vector& lits) { - obj_map> apps; - arith_util a(m); - for (expr* e : subterms::ground(lits)) { - if (a.is_int_real(e) && is_uninterp(e) && to_app(e)->get_num_args() > 0) { - func_decl* f = to_app(e)->get_decl(); - apps.insert_if_not_there(f, ptr_vector()).push_back(to_app(e)); - } - } - for (auto const& kv : apps) { - ptr_vector const& es = kv.m_value; - expr_ref_vector values(m); - for (expr* e : kv.m_value) values.push_back(mdl(e)); - for (unsigned i = 0; i < es.size(); ++i) { - expr* v1 = values.get(i); - for (unsigned j = i + 1; j < es.size(); ++j) { - expr* v2 = values.get(j); - if (v1 != v2) { - add_arith_dcert(mdl, lits, es[i], es[j]); - } - } - } - } - } - void uflia_mbi::add_arith_dcert(model& mdl, expr_ref_vector& lits, app* a, app* b) { - arith_util arith(m); - SASSERT(a->get_decl() == b->get_decl()); - for (unsigned i = a->get_num_args(); i-- > 0; ) { - expr* arg1 = a->get_arg(i), *arg2 = b->get_arg(i); - if (arith.is_int_real(arg1) && mdl(arg1) != mdl(arg2)) { - lits.push_back(m.mk_not(m.mk_eq(arg1, arg2))); - return; - } - } - } - - /** - * \brief project private symbols. - */ - void uflia_mbi::project_euf(model_ref& mdl, expr_ref_vector& lits) { - mbp::term_graph tg(m); - func_decl_ref_vector shared(m_shared_trail); - tg.set_vars(shared, false); - tg.add_lits(lits); - lits.reset(); - lits.append(tg.project(*mdl.get())); - TRACE("qe", tout << "project: " << lits << "\n";); - } - - /** - * \brief Order arithmetical variables: - * sort arithmetical terms, such that deepest terms are first. - */ - void uflia_mbi::order_avars(app_ref_vector& avars) { - - // sort avars based on depth - std::function compare_depth = - [](app* x, app* y) { - return - (x->get_depth() > y->get_depth()) || - (x->get_depth() == y->get_depth() && x->get_id() > y->get_id()); - }; - std::sort(avars.data(), avars.data() + avars.size(), compare_depth); - TRACE("qe", tout << "avars:" << avars << "\n";); - } void uflia_mbi::block(expr_ref_vector const& lits) { expr_ref clause(mk_not(mk_and(lits)), m); diff --git a/src/qe/qe_mbi.h b/src/qe/qe_mbi.h index 93f7df88d..9d28d4e38 100644 --- a/src/qe/qe_mbi.h +++ b/src/qe/qe_mbi.h @@ -115,7 +115,29 @@ namespace qe { void block(expr_ref_vector const& lits) override; }; - class uflia_mbi : public mbi_plugin { + class uflia_project : public mbi_plugin { + protected: + void order_avars(app_ref_vector& avars); + app_ref_vector get_arith_vars(expr_ref_vector const& lits); + void fix_non_shared(model& mdl, expr_ref_vector& lits); + vector<::mbp::def> arith_project(model_ref& mdl, app_ref_vector& avars, expr_ref_vector& lits); + void add_dcert(model_ref& mdl, expr_ref_vector& lits); + void add_arith_dcert(model& mdl, expr_ref_vector& lits); + void add_arith_dcert(model& mdl, expr_ref_vector& lits, app* a, app* b); + void project_euf(model_ref& mdl, expr_ref_vector& lits); + void split_arith(expr_ref_vector const& lits, + expr_ref_vector& alits, + expr_ref_vector& uflits); + public: + uflia_project(ast_manager& m): mbi_plugin(m) {} + + vector<::mbp::def> project_solve(model_ref& mdl, expr_ref_vector& lits); + void block(expr_ref_vector const& lits) override {} + mbi_result operator()(expr_ref_vector& lits, model_ref& mdl) override { return mbi_result::mbi_undef; } + + }; + + class uflia_mbi : public uflia_project { expr_ref_vector m_atoms; obj_hashtable m_atom_set; expr_ref_vector m_fmls; @@ -125,18 +147,8 @@ namespace qe { bool get_literals(model_ref& mdl, expr_ref_vector& lits); void collect_atoms(expr_ref_vector const& fmls); - void order_avars(app_ref_vector& avars); - void add_dcert(model_ref& mdl, expr_ref_vector& lits); - void add_arith_dcert(model& mdl, expr_ref_vector& lits); - void add_arith_dcert(model& mdl, expr_ref_vector& lits, app* a, app* b); - app_ref_vector get_arith_vars(expr_ref_vector const& lits); - vector<::mbp::def> arith_project(model_ref& mdl, app_ref_vector& avars, expr_ref_vector& lits); - void project_euf(model_ref& mdl, expr_ref_vector& lits); - void split_arith(expr_ref_vector const& lits, - expr_ref_vector& alits, - expr_ref_vector& uflits); - void fix_non_shared(model& mdl, expr_ref_vector& lits); + public: uflia_mbi(solver* s, solver* emptySolver); mbi_result operator()(expr_ref_vector& lits, model_ref& mdl) override; diff --git a/src/sat/smt/synth_solver.cpp b/src/sat/smt/synth_solver.cpp index d9b6f2fd7..b29832d48 100644 --- a/src/sat/smt/synth_solver.cpp +++ b/src/sat/smt/synth_solver.cpp @@ -19,9 +19,7 @@ Author: #include "ast/rewriter/th_rewriter.h" #include "sat/smt/synth_solver.h" #include "sat/smt/euf_solver.h" -#include "qe/mbp/mbp_term_graph.h" -#include "qe/mbp/mbp_arith.h" -#include "qe/mbp/mbp_arrays.h" +#include "qe/qe_mbi.h" namespace synth { @@ -36,11 +34,12 @@ namespace synth { solver::~solver() {} + + bool solver::is_output(expr * e) const { + return any_of(m_synth, [&](synth_objective const& a) { return a.output() == e; }); + } + bool solver::contains_uncomputable(expr* e) { - - auto is_output = [&](expr* e) { - return any_of(m_synth, [&](synth_objective const& a) { return a.output() == e; }); - }; return any_of(subterms::all(expr_ref(e, m)), [&](expr* a) { return (is_app(a) && m_uncomputable.contains(to_app(a)->get_decl())) || is_output(a); }); } @@ -353,12 +352,11 @@ namespace synth { compute_rep(); for (synth_objective const& e : m_synth) { - auto lit = synthesize(e); - if (lit == sat::null_literal) + expr_ref sol = compute_solution(e); + if (!sol) return false; - clause.push_back(~lit); + IF_VERBOSE(0, verbose_stream() << sol << "\n"); } - add_clause(clause); expr_ref cond = compute_condition(); add_unit(~mk_literal(cond)); IF_VERBOSE(0, verbose_stream() << "if " << cond << "\n"); @@ -412,13 +410,57 @@ namespace synth { arith_util a(m); if (!a.is_int_real(obj.output())) return false; + model_ref mdl = alloc(model, m); + ctx.update_model(mdl, false); verbose_stream() << "int-real-objective\n"; + verbose_stream() << *mdl << "\n"; + + expr_ref_vector lits(m), core(m); + for (unsigned i = 0; i < s().trail_size(); ++i) { + sat::literal l = s().trail_literal(i); + if (!ctx.is_relevant(l)) + continue; + expr_ref e = literal2expr(l); + if (e) + lits.push_back(e); + } + verbose_stream() << lits << "\n"; + + sat::no_drat_params no_drat_params; + ref<::solver> solver = mk_smt2_solver(m, no_drat_params, symbol::null); + solver->assert_expr(m.mk_not(m.mk_and(m_spec))); + lbool r = solver->check_sat(lits); + if (r != l_false) + return false; + solver->get_unsat_core(core); + verbose_stream() << "core " << core << "\n"; + + qe::uflia_project proj(m); + auto& egraph = ctx.get_egraph(); + func_decl_ref_vector shared(m); + ast_mark visited; + for (auto* n : egraph.nodes()) + if (is_app(n->get_expr()) && !is_output(n->get_expr()) && !m_uncomputable.contains(n->get_decl()) && !visited.is_marked(n->get_decl())) { + visited.mark(n->get_decl(), true); + shared.push_back(n->get_decl()); + } + verbose_stream() << "shared " << shared << "\n"; + proj.set_shared(shared); + auto defs = proj.project_solve(mdl, core); + + for (auto const& d : defs) { + verbose_stream() << d.var << " := " << d.term << "\n"; + if (d.var == obj.output()) { + obj.set_solution(d.term); + ctx.push(synth_objective::unset_solution(obj)); + return true; + } + } #if 0 - // 1 retrieve a model - // 1.5 - difference cert? - // 1.6 - split arith? - // 2 retrieve literal dependencies - // 3 split_arith, arith_vars, rpoejct, project_euf, + // - retrieve literal dependencies + // - difference cert? + // - split arith? + // - split_arith, arith_vars, rpoejct, project_euf, // produce projection add_dcert(mdl, lits); diff --git a/src/sat/smt/synth_solver.h b/src/sat/smt/synth_solver.h index d0f3e66d7..acfc5ad12 100644 --- a/src/sat/smt/synth_solver.h +++ b/src/sat/smt/synth_solver.h @@ -62,7 +62,7 @@ namespace synth { }; - + bool is_output(expr* e) const; sat::literal synthesize(synth_objective const& synth_objective); void add_uncomputable(app* e); void add_synth_objective(synth_objective const & e);