/*++
Copyright (c) 2014 Microsoft Corporation

Module Name:

    fpa_solver.cpp

Abstract:

    Floating-Point Theory Plugin

Author:

    Christoph (cwinter) 2014-04-23

Revision History:

    Ported from theory_fpa by nbjorner in 2020.

--*/

#include "sat/smt/fpa_solver.h"
#include "ast/fpa/bv2fpa_converter.h"

namespace fpa {

    // Sets up the FP-to-BV conversion pipeline: a theory rewriter (m_th_rw),
    // the converter that uses it, and the rewriter (m_rw) driving the converter.
    solver::solver(euf::solver& ctx) :
        euf::th_euf_solver(ctx, symbol("fpa"), ctx.get_manager().mk_family_id("fpa")),
        m_th_rw(ctx.get_manager()),
        m_converter(ctx.get_manager(), m_th_rw),
        m_rw(ctx.get_manager(), m_converter, params_ref()),
        m_fpa_util(m_converter.fu()),
        m_bv_util(m_converter.bu()),
        m_arith_util(m_converter.au())
    {
        params_ref p;
        p.set_bool("arith_lhs", true);
        m_th_rw.updt_params(p);
    }

    // Releases the references held by the conversion cache (see convert()).
    solver::~solver() {
        dec_ref_map_key_values(m, m_conversions);
        SASSERT(m_conversions.empty());
    }

    // Returns the bit-level encoding of `e` as produced by m_rw.convert.
    // Results are memoized in m_conversions.
    // A new cache entry takes a reference on both key and value and is
    // registered on the solver trail (insert_ref2_map), presumably so that
    // it is removed again when the scope is popped.
    expr_ref solver::convert(expr* e) {
        expr_ref res(m);
        expr* ccnv;
        TRACE("t_fpa", tout << "converting " << mk_ismt2_pp(e, m) << "\n";);

        if (m_conversions.find(e, ccnv)) {
            res = ccnv;
            TRACE("t_fpa_detail", tout << "cached:" << "\n";
                  tout << mk_ismt2_pp(e, m) << "\n" << " -> " << "\n" << mk_ismt2_pp(res, m) << "\n";);
        }
        else {
            res = m_rw.convert(m_th_rw, e);
            TRACE("t_fpa_detail", tout << "converted; caching:" << "\n";
                  tout << mk_ismt2_pp(e, m) << "\n" << " -> " << "\n" << mk_ismt2_pp(res, m) << "\n";);
            m_conversions.insert(e, res);
            m.inc_ref(e);
            m.inc_ref(res);
            ctx.push(insert_ref2_map<ast_manager, expr, expr>(m, m_conversions, e, res.get()));
        }
        return res;
    }

    // Drains the side constraints the converter accumulated in
    // m_extra_assertions.
    // Each constraint is simplified by the euf rewriter and then by m_th_rw,
    // and a literal is returned for each one.
    sat::literal_vector solver::mk_side_conditions() {
        sat::literal_vector conds;
        expr_ref t(m);
        for (expr* arg : m_converter.m_extra_assertions) {
            ctx.get_rewriter()(arg, t);
            m_th_rw(t);
            conds.push_back(mk_literal(t));
        }
        m_converter.m_extra_assertions.reset();
        return conds;
    }

    // Final check.
    // Returns CR_CONTINUE if there were queued nodes to process (new
    // constraints may have been added), otherwise CR_DONE.
    sat::check_result solver::check() {
        SASSERT(m_converter.m_extra_assertions.empty());
        if (unit_propagate())
            return sat::check_result::CR_CONTINUE;
        SASSERT(m_nodes.size() <= m_nodes_qhead);
        return sat::check_result::CR_DONE;
    }

    // Creates a fresh theory variable for `n` and attaches it to the e-node.
    void solver::attach_new_th_var(enode* n) {
        theory_var v = mk_var(n);
        ctx.attach_th_var(n, this, v);
        TRACE("t_fpa", tout << "new theory var: " << mk_ismt2_pp(n->get_expr(), m) << " := " << v << "\n";);
    }

    // Internalizes a Boolean FP term and returns its literal, negated when
    // `sign` is set.
    // Returns null_literal if visit_rec does not complete.
    sat::literal solver::internalize(expr* e, bool sign, bool root) {
        SASSERT(m.is_bool(e));
        if (!visit_rec(m, e, sign, root))
            return sat::null_literal;
        sat::literal lit = expr2literal(e);
        if (sign)
            lit.neg();
        return lit;
    }

    // Internalizes a (non-root) term.
    void solver::internalize(expr* e) {
        visit_rec(m, e, false, false);
    }

    // A term counts as visited once its e-node is attached to this theory.
    bool solver::visited(expr* e) {
        euf::enode* n = expr2enode(e);
        return n && n->is_attached_to(get_id());
    }

    // Terms outside the fpa family are handed back to the euf solver.
    // FP applications are pushed on the stack for post_visit.
    bool solver::visit(expr* e) {
        if (visited(e))
            return true;
        if (!is_app(e) || to_app(e)->get_family_id() != get_id()) {
            ctx.internalize(e);
            return true;
        }
        m_stack.push_back(sat::eframe(e));
        return false;
    }

    // Creates the e-node (if needed) and the theory variable for `e`.
    // Then queues (node, sign, root) for unit_propagate.
    // The queue push is recorded on the trail.
    bool solver::post_visit(expr* e, bool sign, bool root) {
        euf::enode* n = expr2enode(e);
        SASSERT(!n || !n->is_attached_to(get_id()));
        if (!n)
            n = mk_enode(e, false);
        SASSERT(!n->is_attached_to(get_id()));
        attach_new_th_var(n);
        TRACE("fp", tout << "post: " << mk_bounded_pp(e, m) << "\n";);
        m_nodes.push_back(std::tuple(n, sign, root));
        ctx.push(push_back_trail(m_nodes));
        return true;
    }

    // Called for FP / rounding-mode sorted nodes.
    // Attaches a theory variable (except for ite terms and nodes that already
    // have one) and activates the node.
    void solver::apply_sort_cnstr(enode* n, sort* s) {
        TRACE("t_fpa", tout << "apply sort cnstr for: " << mk_ismt2_pp(n->get_expr(), m) << "\n";);
        SASSERT(s->get_family_id() == get_id());
        SASSERT(m_fpa_util.is_float(s) || m_fpa_util.is_rm(s));
        SASSERT(m_fpa_util.is_float(n->get_expr()) || m_fpa_util.is_rm(n->get_expr()));
        SASSERT(n->get_decl()->get_range() == s);

        if (is_attached_to_var(n))
            return;

        if (m.is_ite(n->get_expr()))
            return;

        attach_new_th_var(n);
        expr* owner = n->get_expr();

        if (m_fpa_util.is_rm(s) && !m_fpa_util.is_bv2rm(owner)) {
            // For every RM term, we need to make sure that its
            // associated 3-bit bit-vector stays within the valid range [0, 4].
            expr_ref valid(m), limit(m);
            limit = m_bv_util.mk_numeral(4, 3);
            valid = m_bv_util.mk_ule(m_converter.wrap(owner), limit);
            add_unit(mk_literal(valid));
        }
        activate(owner);
    }

    // Processes every node queued by post_visit since the last call.
    // The queue head is saved on the trail first.
    // Returns true iff at least one node was processed.
    bool solver::unit_propagate() {
        if (m_nodes.size() <= m_nodes_qhead)
            return false;
        ctx.push(value_trail<unsigned>(m_nodes_qhead));
        for (; m_nodes_qhead < m_nodes.size(); ++m_nodes_qhead)
            unit_propagate(m_nodes[m_nodes_qhead]);
        return true;
    }

    // Adds the defining constraints for one queued node.
    // Boolean atoms: atom <=> (converted atom & side conditions), and the atom
    //   is asserted as a unit (with `sign`) when it is a root.
    // Conversion functions (to_fp, to_ubv, to_sbv, to_real, to_ieee_bv):
    //   e = convert(e), plus side conditions.
    void solver::unit_propagate(std::tuple<enode*, bool, bool> const& t) {
        auto [n, sign, root] = t;
        expr* e = n->get_expr();
        app* a = to_app(e);
        if (m.is_bool(e)) {
            sat::literal atom(ctx.get_si().add_bool_var(e), false);
            atom = ctx.attach_lit(atom, e);
            sat::literal bv_atom = mk_literal(m_rw.convert_atom(m_th_rw, e));
            sat::literal_vector conds = mk_side_conditions();
            conds.push_back(bv_atom);
            add_equiv_and(atom, conds);
            if (root) {
                if (sign)
                    atom.neg();
                add_unit(atom);
            }
        }
        else {
            switch (a->get_decl_kind()) {
            case OP_FPA_TO_FP:
            case OP_FPA_TO_UBV:
            case OP_FPA_TO_SBV:
            case OP_FPA_TO_REAL:
            case OP_FPA_TO_IEEE_BV: {
                expr_ref conv = convert(e);
                add_unit(eq_internalize(e, conv));
                add_units(mk_side_conditions());
                break;
            }
            default: /* ignore */
                break;
            }
        }
        activate(e);
    }

    // Ties an FP / rounding-mode term to its bit-vector wrapper.
    // - rounding-mode numerals:  wrap(n) = 3-bit numeral of the mode.
    // - FP numerals:             wrap(n) = concat(a, b, c) and convert(n) = n,
    //                            where convert(n) = fp(a, b, c).
    // - other non-fp() terms:    unwrap(wrap(n)) = n.
    // ite terms and fp(...) applications are skipped.
    void solver::activate(expr* n) {
        TRACE("t_fpa", tout << "relevant_eh for: " << mk_ismt2_pp(n, m) << "\n";);

        mpf_manager& mpfm = m_fpa_util.fm();

        if (m.is_ite(n)) {
            // skip
        }
        else if (m_fpa_util.is_float(n) || m_fpa_util.is_rm(n)) {
            expr* a = nullptr, * b = nullptr, * c = nullptr;
            if (!m_fpa_util.is_fp(n)) {
                app_ref wrapped = m_converter.wrap(n);
                mpf_rounding_mode rm;
                scoped_mpf val(mpfm);
                if (m_fpa_util.is_rm_numeral(n, rm)) {
                    expr_ref rm_num(m);
                    rm_num = m_bv_util.mk_numeral(rm, 3);
                    add_unit(eq_internalize(wrapped, rm_num));
                }
                else if (m_fpa_util.is_numeral(n, val)) {
                    expr_ref bv_val_e(convert(n), m);
                    VERIFY(m_fpa_util.is_fp(bv_val_e, a, b, c));
                    expr* args[] = { a, b, c };
                    expr_ref cc_args(m_bv_util.mk_concat(3, args), m);
                    // Require
                    //   wrap(n) = concat(a, b, c)
                    //   fp(a, b, c) = n
                    add_unit(eq_internalize(wrapped, cc_args));
                    add_unit(eq_internalize(bv_val_e, n));
                    add_units(mk_side_conditions());
                }
                else
                    add_unit(eq_internalize(m_converter.unwrap(wrapped, n->get_sort()), n));
            }
        }
        else if (is_app(n) && to_app(n)->get_family_id() == get_id()) {
            // These are the conversion functions fp.to_*
            SASSERT(!m_fpa_util.is_float(n) && !m_fpa_util.is_rm(n));
        }
        else {
            /* Theory variables can be merged when (= bv-term (bvwrap fp-term)) */
            SASSERT(m_bv_util.is_bv(n));
        }
    }

    // Adds (x = y) <=> c, where c relates the converted forms of x and y.
    // - For two floats or two rounding modes, c comes from m_converter.mk_eq.
    // - Otherwise c is a plain equality.
    // Equalities involving a bvwrap term are skipped.
    void solver::ensure_equality_relation(theory_var x, theory_var y) {
        fpa_util& fu = m_fpa_util;
        enode* e_x = var2enode(x);
        enode* e_y = var2enode(y);
        expr* xe = e_x->get_expr();
        expr* ye = e_y->get_expr();

        if (fu.is_bvwrap(xe) || fu.is_bvwrap(ye))
            return;

        TRACE("t_fpa", tout << "new eq: " << x << " = " << y << "\n";
              tout << mk_ismt2_pp(xe, m) << "\n" << " = " << "\n" << mk_ismt2_pp(ye, m) << "\n";);

        expr_ref xc = convert(xe);
        expr_ref yc = convert(ye);

        TRACE("t_fpa_detail", tout << "xc = " << mk_ismt2_pp(xc, m) << "\n" <<
              "yc = " << mk_ismt2_pp(yc, m) << "\n";);

        expr_ref c(m);

        if ((fu.is_float(xe) && fu.is_float(ye)) ||
            (fu.is_rm(xe) && fu.is_rm(ye)))
            m_converter.mk_eq(xc, yc, c);
        else
            c = m.mk_eq(xc, yc);

        m_th_rw(c);

        sat::literal eq1 = eq_internalize(xe, ye);
        sat::literal eq2 = mk_literal(c);
        add_equiv(eq1, eq2);
        add_units(mk_side_conditions());
    }

    void solver::new_eq_eh(euf::th_eq const& eq) {
        ensure_equality_relation(eq.v1(), eq.v2());
    }

    // The axiom added by ensure_equality_relation is an equivalence, so it
    // also covers the disequality case.
    void solver::new_diseq_eh(euf::th_eq const& eq) {
        ensure_equality_relation(eq.v1(), eq.v2());
    }

    // Adds the direction of  e <=> (convert(e) & side conditions)
    // that can propagate under the current assignment of `e`.
    // The asserted literal `l` is true, so any clause containing `l` itself
    // would already be satisfied; the clauses therefore use ~l.
    //   e true  (l = e):   e => sc   for each sc, i.e. (~l | sc)
    //   e false (l = ~e):  (sc_1 & ... & sc_k) => e, i.e. (~sc_1 | ... | ~sc_k | ~l)
    void solver::asserted(sat::literal l) {
        expr* e = ctx.bool_var2expr(l.var());

        TRACE("t_fpa", tout << "assign_eh for: " << l << "\n" << mk_ismt2_pp(e, m) << "\n";);

        sat::literal c = mk_literal(convert(e));
        sat::literal_vector conds = mk_side_conditions();
        conds.push_back(c);
        if (!l.sign()) {
            for (sat::literal sc : conds)
                add_clause(~l, sc);
        }
        else {
            for (auto& sc : conds)
                sc.neg();
            conds.push_back(~l);
            add_clause(conds);
        }
    }

    // Computes the model value of `n` and stores it under its root id.
    // Sources for the value, in order:
    // - numerals: the term itself;
    // - fp(...) and bv2rm(...): the values of their arguments;
    // - wrapped terms: the value of the wrapper node, if it exists;
    // - otherwise a default (round-toward-zero for RM, +0 for floats).
    void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) {
        expr* e = n->get_expr();
        app_ref wrapped(m);
        expr_ref value(m);

        // Lazily builds wrap(e) and reports whether it has an e-node.
        auto is_wrapped = [&]() {
            if (!wrapped)
                wrapped = m_converter.wrap(e);
            return expr2enode(wrapped) != nullptr;
        };

        if (m_fpa_util.is_rm_numeral(e) || m_fpa_util.is_numeral(e))
            value = e;
        else if (m_fpa_util.is_fp(e)) {
            SASSERT(n->num_args() == 3);
            expr* a = values.get(n->get_arg(0)->get_root_id());
            expr* b = values.get(n->get_arg(1)->get_root_id());
            expr* c = values.get(n->get_arg(2)->get_root_id());
            value = m_converter.bv2fpa_value(e->get_sort(), a, b, c);
        }
        else if (m_fpa_util.is_bv2rm(e)) {
            SASSERT(n->num_args() == 1);
            value = m_converter.bv2rm_value(values.get(n->get_arg(0)->get_root_id()));
        }
        else if (m_fpa_util.is_rm(e) && is_wrapped())
            value = m_converter.bv2rm_value(values.get(expr2enode(wrapped)->get_root_id()));
        else if (m_fpa_util.is_rm(e))
            value = m_fpa_util.mk_round_toward_zero();
        else if (m_fpa_util.is_float(e) && is_wrapped()) {
            expr* a = values.get(expr2enode(wrapped)->get_root_id());
            value = m_converter.bv2fpa_value(e->get_sort(), a);
        }
        else {
            SASSERT(m_fpa_util.is_float(e));
            unsigned ebits = m_fpa_util.get_ebits(e->get_sort());
            unsigned sbits = m_fpa_util.get_sbits(e->get_sort());
            value = m_fpa_util.mk_pzero(ebits, sbits);
        }
        values.set(n->get_root_id(), value);
        TRACE("t_fpa", tout << ctx.bpp(n) << " := " << value << "\n";);
    }

    // Declares which nodes' values add_value reads for `n`.
    // Returns true iff `n` has such dependencies.
    bool solver::add_dep(euf::enode* n, top_sort<euf::enode>& dep) {
        expr* e = n->get_expr();
        if (m_fpa_util.is_fp(e)) {
            SASSERT(n->num_args() == 3);
            for (enode* arg : euf::enode_args(n))
                dep.add(n, arg);
            return true;
        }
        else if (m_fpa_util.is_bv2rm(e)) {
            SASSERT(n->num_args() == 1);
            dep.add(n, n->get_arg(0));
            return true;
        }
        else if (m_fpa_util.is_rm(e) || m_fpa_util.is_float(e)) {
            euf::enode* wrapped = expr2enode(m_converter.wrap(e));
            if (wrapped)
                dep.add(n, wrapped);
            return nullptr != wrapped;
        }
        else
            return false;
    }

    // Debug dump of the fpa, bv and arith theory variables and of the
    // equivalence classes.
    // Nothing is printed when there are no fpa theory variables.
    std::ostream& solver::display(std::ostream& out) const {
        bool first = true;
        for (enode* n : ctx.get_egraph().nodes()) {
            theory_var v = n->get_th_var(m_fpa_util.get_family_id());
            if (v != -1) {
                if (first)
                    out << "fpa theory variables:" << "\n";
                out << v << " -> " << mk_ismt2_pp(n->get_expr(), m) << "\n";
                first = false;
            }
        }
        // if there are no fpa theory variables, was fp ever used?
        if (first)
            return out;

        out << "bv theory variables:" << "\n";
        for (enode* n : ctx.get_egraph().nodes()) {
            theory_var v = n->get_th_var(m_bv_util.get_family_id());
            if (v != -1)
                out << v << " -> " << mk_ismt2_pp(n->get_expr(), m) << "\n";
        }

        out << "arith theory variables:" << "\n";
        for (enode* n : ctx.get_egraph().nodes()) {
            theory_var v = n->get_th_var(m_arith_util.get_family_id());
            if (v != -1)
                out << v << " -> " << mk_ismt2_pp(n->get_expr(), m) << "\n";
        }

        out << "equivalence classes:\n";
        for (enode* n : ctx.get_egraph().nodes()) {
            expr* e = n->get_expr();
            out << n->get_root_id() << " --> " << mk_ismt2_pp(e, m) << "\n";
        }
        return out;
    }

    // Rewrites the converter-introduced declarations in `mdl`:
    // 1. bv2fpa_converter builds replacement interpretations in new_model and
    //    collects the declarations it handled in `seen`.
    // 2. The declarations in `seen` are unregistered from `mdl`.
    // 3. The constants and (copied) functions of new_model are registered
    //    into `mdl`.
    void solver::finalize_model(model& mdl) {
        model new_model(m);

        bv2fpa_converter bv2fp(m, m_converter);

        obj_hashtable<func_decl> seen;
        bv2fp.convert_min_max_specials(&mdl, &new_model, seen);
        bv2fp.convert_uf2bvuf(&mdl, &new_model, seen);

        for (func_decl* f : seen)
            mdl.unregister_decl(f);

        for (unsigned i = 0; i < new_model.get_num_constants(); i++) {
            func_decl* f = new_model.get_constant(i);
            mdl.register_decl(f, new_model.get_const_interp(f));
        }

        for (unsigned i = 0; i < new_model.get_num_functions(); i++) {
            func_decl* f = new_model.get_function(i);
            func_interp* fi = new_model.get_func_interp(f)->copy();
            mdl.register_decl(f, fi);
        }
    }

} // namespace fpa