diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 6d210c92d..b27d8166e 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -959,6 +959,7 @@ void basic_decl_plugin::get_op_names(svector & op_names, symbol co op_names.push_back(builtin_name("=", OP_EQ)); op_names.push_back(builtin_name("distinct", OP_DISTINCT)); op_names.push_back(builtin_name("ite", OP_ITE)); + op_names.push_back(builtin_name("if", OP_ITE)); op_names.push_back(builtin_name("and", OP_AND)); op_names.push_back(builtin_name("or", OP_OR)); op_names.push_back(builtin_name("xor", OP_XOR)); @@ -969,7 +970,6 @@ void basic_decl_plugin::get_op_names(svector & op_names, symbol co op_names.push_back(builtin_name("implies", OP_IMPLIES)); op_names.push_back(builtin_name("iff", OP_EQ)); op_names.push_back(builtin_name("if_then_else", OP_ITE)); - op_names.push_back(builtin_name("if", OP_ITE)); op_names.push_back(builtin_name("&&", OP_AND)); op_names.push_back(builtin_name("||", OP_OR)); op_names.push_back(builtin_name("equals", OP_EQ)); diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index 709886e4d..7b1411edb 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -803,6 +803,8 @@ void cmd_context::insert(symbol const & s, func_decl * f) { #endif func_decls & fs = m_func_decls.insert_if_not_there(s, func_decls()); if (!fs.insert(m(), f)) { + if (m_allow_duplicate_declarations) + return; std::string msg = "invalid declaration, "; msg += f->get_arity() == 0 ? "constant" : "function"; msg += " '"; diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index b7beb2625..6b21d3836 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -192,6 +192,7 @@ protected: bool m_numeral_as_real; bool m_ignore_check; // used by the API to disable check-sat() commands when parsing SMT 2.0 files. bool m_exit_on_error; + bool m_allow_duplicate_declarations { false }; static std::ostringstream g_error_stream; @@ -346,6 +347,7 @@ public: void set_produce_unsat_cores(bool flag); void set_produce_proofs(bool flag); void set_produce_unsat_assumptions(bool flag) { m_produce_unsat_assumptions = flag; } + void set_allow_duplicate_declarations() { m_allow_duplicate_declarations = true; } bool produce_assignments() const { return m_produce_assignments; } bool produce_unsat_assumptions() const { return m_produce_unsat_assumptions; } void set_produce_assignments(bool flag) { m_produce_assignments = flag; } diff --git a/src/sat/dimacs.cpp b/src/sat/dimacs.cpp index eb276b9e0..155d3a77b 100644 --- a/src/sat/dimacs.cpp +++ b/src/sat/dimacs.cpp @@ -159,6 +159,10 @@ namespace dimacs { return out << pp << " " << r.m_lits << " 0\n"; case drat_record::tag_t::is_node: return out << "e " << r.m_node_id << " " << r.m_name << " " << r.m_args << "0\n"; + case drat_record::tag_t::is_sort: + return out << "s " << r.m_node_id << " " << r.m_name << " " << r.m_args << "0\n"; + case drat_record::tag_t::is_decl: + return out << "f " << r.m_node_id << " " << r.m_name << " " << r.m_args << "0\n"; case drat_record::tag_t::is_bool_def: return out << "b " << r.m_node_id << " " << r.m_args << "0\n"; } @@ -208,6 +212,24 @@ namespace dimacs { bool drat_parser::next() { int n, b, e, theory_id; + auto parse_ast = [&](drat_record::tag_t tag) { + ++in; + skip_whitespace(in); + n = parse_int(in, err); + skip_whitespace(in); + m_record.m_name = parse_sexpr(); + m_record.m_tag = tag; + m_record.m_node_id = n; + m_record.m_args.reset(); + while (true) { + n = parse_int(in, err); + if (n == 0) + break; + if (n < 0) + throw lex_error(); + m_record.m_args.push_back(n); + } + }; try { loop: skip_whitespace(in); @@ -235,22 +257,13 @@ namespace dimacs { m_record.m_status = sat::status::th(false, theory_id); break; case 'e': - ++in; - skip_whitespace(in); - n = parse_int(in, err); - skip_whitespace(in); - m_record.m_name = parse_sexpr(); - m_record.m_tag = drat_record::tag_t::is_node; - m_record.m_node_id = n; - m_record.m_args.reset(); - while (true) { - n = parse_int(in, err); - if (n == 0) - break; - if (n < 0) - throw lex_error(); - m_record.m_args.push_back(n); - } + parse_ast(drat_record::tag_t::is_node); + break; + case 'f': + parse_ast(drat_record::tag_t::is_decl); + break; + case 's': + parse_ast(drat_record::tag_t::is_sort); break; case 'b': ++in; diff --git a/src/sat/dimacs.h b/src/sat/dimacs.h index 681f65b1f..690f5b0fc 100644 --- a/src/sat/dimacs.h +++ b/src/sat/dimacs.h @@ -53,7 +53,7 @@ namespace dimacs { }; struct drat_record { - enum class tag_t { is_clause, is_node, is_bool_def }; + enum class tag_t { is_clause, is_node, is_decl, is_sort, is_bool_def }; tag_t m_tag{ tag_t::is_clause }; // a clause populates m_lits and m_status // a node populates m_node_id, m_name, m_args diff --git a/src/sat/sat_drat.cpp b/src/sat/sat_drat.cpp index ede975b50..23ff01f15 100644 --- a/src/sat/sat_drat.cpp +++ b/src/sat/sat_drat.cpp @@ -86,6 +86,9 @@ namespace sat { return; if (m_activity && ((m_stats.m_num_add % 1000) == 0)) dump_activity(); + + SASSERT(!(n == 2 && c[0] == literal(3802, true) && c[1] == literal(3808, false))); + VERIFY(!(n == 2 && c[0] == literal(3802, true) && c[1] == literal(3808, false))); char buffer[10000]; char digits[20]; // enough for storing unsigned @@ -262,9 +265,9 @@ namespace sat { (*m_out) << "b " << v << " " << n << " 0\n"; } - void drat::def_begin(unsigned n, std::string const& name) { + void drat::def_begin(char id, unsigned n, std::string const& name) { if (m_out) - (*m_out) << "e " << n << " " << name; + (*m_out) << id << " " << n << " " << name; } void drat::def_add_arg(unsigned arg) { diff --git a/src/sat/sat_drat.h b/src/sat/sat_drat.h index d7d6477f2..f39d18017 100644 --- a/src/sat/sat_drat.h +++ b/src/sat/sat_drat.h @@ -130,7 +130,7 @@ namespace sat { void bool_def(bool_var v, unsigned n); // declare AST node n with 'name' and arguments arg - void def_begin(unsigned n, std::string const& name); + void def_begin(char id, unsigned n, std::string const& name); void def_add_arg(unsigned arg); void def_end(); diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 747e7752c..86536ac47 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -13,6 +13,7 @@ z3_add_component(sat_smt ba_xor.cpp bv_ackerman.cpp bv_internalize.cpp + bv_invariant.cpp bv_solver.cpp euf_ackerman.cpp euf_internalize.cpp diff --git a/src/sat/smt/bv_ackerman.cpp b/src/sat/smt/bv_ackerman.cpp index 669ebf0b8..62251d7a0 100644 --- a/src/sat/smt/bv_ackerman.cpp +++ b/src/sat/smt/bv_ackerman.cpp @@ -91,7 +91,7 @@ namespace bv { } if (glue < max_glue) - v.m_glue = glue <= sz ? 0 : glue; + v.m_glue = 2*glue <= sz ? 0 : glue; } void ackerman::remove(vv* p) { @@ -146,7 +146,7 @@ namespace bv { sort* s2 = s.m.get_sort(s.var2expr(v2)); if (s1 != s2 || !s.bv.is_bv_sort(s1)) return; - IF_VERBOSE(0, verbose_stream() << "assert ackerman " << v1 << " " << v2 << "\n"); + // IF_VERBOSE(0, verbose_stream() << "assert ackerman " << v1 << " " << v2 << "\n"); s.assert_ackerman(v1, v2); } diff --git a/src/sat/smt/bv_internalize.cpp b/src/sat/smt/bv_internalize.cpp index 4f6bbea45..2686a65f7 100644 --- a/src/sat/smt/bv_internalize.cpp +++ b/src/sat/smt/bv_internalize.cpp @@ -74,7 +74,10 @@ namespace bv { SASSERT(m.is_bool(e)); if (!visit_rec(m, e, sign, root, redundant)) return sat::null_literal; - return expr2literal(e); + sat::literal lit = expr2literal(e); + if (sign) + lit.neg(); + return lit; } void solver::internalize(expr* e, bool redundant) { @@ -179,8 +182,9 @@ namespace bv { m_bits[v].reset(); for (unsigned i = 0; i < bv_size; i++) { expr_ref b2b(bv.mk_bit2bool(e, i), m); + m_bits[v].push_back(sat::null_literal); sat::literal lit = ctx.internalize(b2b, false, false, m_is_redundant); - m_bits[v].push_back(lit); + SASSERT(m_bits[v].back() == lit); } } @@ -234,7 +238,15 @@ namespace bv { set_bit_eh(v, l, idx); } + solver::bit_atom* solver::mk_bit_atom(sat::bool_var bv) { + bit_atom* b = new (get_region()) bit_atom(); + insert_bv2a(bv, b); + ctx.push(mk_atom_trail(bv, *this)); + return b; + } + void solver::set_bit_eh(theory_var v, literal l, unsigned idx) { + SASSERT(m_bits[v][idx] == l); if (s().value(l) != l_undef && s().lvl(l) == 0) register_true_false_bit(v, idx); else { @@ -248,9 +260,7 @@ namespace bv { b->m_occs = new (get_region()) var_pos_occ(v, idx, b->m_occs); } else { - bit_atom* b = new (get_region()) bit_atom(); - insert_bv2a(l.var(), b); - ctx.push(mk_atom_trail(l.var(), *this)); + bit_atom* b = mk_bit_atom(l.var()); SASSERT(!b->m_occs); b->m_occs = new (get_region()) var_pos_occ(v, idx); } @@ -460,7 +470,7 @@ namespace bv { new_bits.reset(); fn(arg_bits.size(), arg_bits.c_ptr(), bits.c_ptr(), new_bits); bits.swap(new_bits); - } + } init_bits(n, bits); TRACE("bv_verbose", tout << arg_bits << " " << bits << " " << new_bits << "\n";); } @@ -505,17 +515,18 @@ namespace bv { init_bits(n, bits); } - void solver::internalize_extract(app* n) { - SASSERT(n->get_num_args() == 1); - euf::enode* e = expr2enode(n); - theory_var v = e->get_th_var(get_id()); - theory_var arg = get_arg_var(e, 0); - unsigned start = n->get_decl()->get_parameter(1).get_int(); - unsigned end = n->get_decl()->get_parameter(0).get_int(); - SASSERT(start <= end && end < get_bv_size(v)); + void solver::internalize_extract(app* e) { + expr* arg_e = nullptr; + unsigned lo = 0, hi = 0; + VERIFY(bv.is_extract(e, lo, hi, arg_e)); + euf::enode* n = expr2enode(e); + theory_var v = n->get_th_var(get_id()); + theory_var arg_v = get_arg_var(n, 0); + SASSERT(hi - lo + 1 == get_bv_size(v)); + SASSERT(lo <= hi && hi < get_bv_size(arg_v)); m_bits[v].reset(); - for (unsigned i = start; i <= end; ++i) - add_bit(v, m_bits[arg][i]); + for (unsigned i = lo; i <= hi; ++i) + add_bit(v, m_bits[arg_v][i]); find_wpos(v); } @@ -528,17 +539,22 @@ namespace bv { mk_var(argn); } theory_var v_arg = argn->get_th_var(get_id()); - sat::literal lit = expr2literal(n); - sat::bool_var b = lit.var(); - bit_atom* a = new (get_region()) bit_atom(); SASSERT(idx < get_bv_size(v_arg)); - a->m_occs = new (get_region()) var_pos_occ(v_arg, idx); - insert_bv2a(b, a); - ctx.push(mk_atom_trail(b, *this)); - if (idx < m_bits[v_arg].size() && lit != m_bits[v_arg][idx]) { - add_clause(m_bits[v_arg][idx], ~lit); - add_clause(~m_bits[v_arg][idx], lit); + sat::literal lit = expr2literal(n); + sat::literal lit0 = m_bits[v_arg][idx]; + if (lit0 == sat::null_literal) { + m_bits[v_arg][idx] = lit; + bit_atom* a = new (get_region()) bit_atom(); + a->m_occs = new (get_region()) var_pos_occ(v_arg, idx); + insert_bv2a(lit.var(), a); + ctx.push(mk_atom_trail(lit.var(), *this)); } + else if (lit != lit0) { + add_clause(lit0, ~lit); + add_clause(~lit0, lit); + } + + // validate_atoms(); // axiomatize bit2bool on constants. rational val; unsigned sz; diff --git a/src/sat/smt/bv_solver.cpp b/src/sat/smt/bv_solver.cpp index 1c241e1c9..cd35b5b20 100644 --- a/src/sat/smt/bv_solver.cpp +++ b/src/sat/smt/bv_solver.cpp @@ -36,6 +36,20 @@ namespace bv { } }; + class solver::bit_occs_trail : public trail { + solver& s; + bit_atom& a; + var_pos_occ* m_occs; + + public: + bit_occs_trail(solver& s, bit_atom& a):s(s), a(a), m_occs(a.m_occs) {} + + virtual void undo(euf::solver& euf) { + std::cout << "add back occurrences " << & a << "\n"; + a.m_occs = m_occs; + } + }; + solver::solver(euf::solver& ctx, theory_id id) : euf::th_euf_solver(ctx, id), bv(m), @@ -254,7 +268,7 @@ namespace bv { void solver::asserted(literal l) { atom* a = get_bv2a(l.var()); TRACE("bv", tout << "asserted: " << l << "\n";); - if (a->is_bit()) + if (a && a->is_bit()) for (auto vp : a->to_bit()) m_prop_queue.push_back(vp); } @@ -327,16 +341,25 @@ namespace bv { bool solver::set_root(literal l, literal r) { atom* a = get_bv2a(l.var()); + atom* b = get_bv2a(r.var()); if (!a || !a->is_bit()) return true; - for (auto vp : a->to_bit()) { + if (b && !b->is_bit()) + return false; + for (auto vp : a->to_bit()) { sat::literal l2 = m_bits[vp.first][vp.second]; - sat::literal r2 = (l.sign() == l2.sign()) ? r : ~r; + if (l2.var() == r.var()) + continue; SASSERT(l2.var() == l.var()); - ctx.push(bit_trail(*this, vp)); + VERIFY(l2.var() == l.var()); + sat::literal r2 = (l.sign() == l2.sign()) ? r : ~r; + ctx.push(vector2_value_trail(m_bits, vp.first, vp.second)); m_bits[vp.first][vp.second] = r2; set_bit_eh(vp.first, r2, vp.second); } + ctx.push(bit_occs_trail(*this, a->to_bit())); + a->to_bit().m_occs = nullptr; + // validate_atoms(); return true; } @@ -443,6 +466,7 @@ namespace bv { def_atom* new_a = new (result->get_region()) def_atom(a->to_def().m_var, a->to_def().m_def); m_bool_var2atom.setx(i, new_a, nullptr); } + validate_atoms(); } return result; } @@ -456,6 +480,10 @@ namespace bv { void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) { SASSERT(bv.is_bv(n->get_expr())); + if (bv.is_numeral(n->get_expr())) { + values[n->get_root_id()] = n->get_expr(); + return; + } theory_var v = n->get_th_var(get_id()); rational val; unsigned i = 0; @@ -619,52 +647,4 @@ namespace bv { return m_power2[i]; } - /** - \brief Check whether m_zero_one_bits is an accurate summary of the bits in the - equivalence class rooted by v. - \remark The method does nothing if v is not the root of the equivalence class. - */ - bool solver::check_zero_one_bits(theory_var v) { - if (s().inconsistent()) - return true; // property is only valid if the context is not in a conflict. - if (!is_root(v) || !is_bv(v)) - return true; - bool_vector bits[2]; - unsigned num_bits = 0; - unsigned bv_sz = get_bv_size(v); - bits[0].resize(bv_sz, false); - bits[1].resize(bv_sz, false); - - theory_var curr = v; - do { - literal_vector const& lits = m_bits[curr]; - for (unsigned i = 0; i < lits.size(); i++) { - literal l = lits[i]; - if (s().value(l) != l_undef) { - unsigned is_true = s().value(l) == l_true; - if (bits[!is_true][i]) { - // expect a conflict later on. - return true; - } - if (!bits[is_true][i]) { - bits[is_true][i] = true; - num_bits++; - } - } - } - curr = m_find.next(curr); - } while (curr != v); - - zero_one_bits const& _bits = m_zero_one_bits[v]; - SASSERT(_bits.size() == num_bits); - bool_vector already_found; - already_found.resize(bv_sz, false); - for (auto& zo : _bits) { - SASSERT(find(zo.m_owner) == v); - SASSERT(bits[zo.m_is_true][zo.m_idx]); - SASSERT(!already_found[zo.m_idx]); - already_found[zo.m_idx] = true; - } - return true; - } } diff --git a/src/sat/smt/bv_solver.h b/src/sat/smt/bv_solver.h index 45f90d109..ea9f364c6 100644 --- a/src/sat/smt/bv_solver.h +++ b/src/sat/smt/bv_solver.h @@ -145,7 +145,9 @@ namespace bv { class bit_trail; class add_var_pos_trail; class mk_atom_trail; + class bit_occs_trail; typedef ptr_vector bool_var2atom; + typedef vector bits_vector; bv_util bv; arith_util m_autil; @@ -153,7 +155,7 @@ namespace bv { ackerman m_ackerman; bit_blaster m_bb; bv_union_find m_find; - vector m_bits; // per var, the bits of a given variable. + bits_vector m_bits; // per var, the bits of a given variable. unsigned_vector m_wpos; // per var, watch position for fixed variable detection. vector m_zero_one_bits; // per var, see comment in the struct zero_one_bit bool_var2atom m_bool_var2atom; @@ -188,6 +190,7 @@ namespace bv { sat::status status() const { return sat::status::th(m_is_redundant, get_id()); } void register_true_false_bit(theory_var v, unsigned i); void add_bit(theory_var v, sat::literal lit); + bit_atom* mk_bit_atom(sat::bool_var bv); void set_bit_eh(theory_var v, literal l, unsigned idx); void init_bits(expr* e, expr_ref_vector const & bits); void mk_bits(theory_var v); @@ -228,6 +231,7 @@ namespace bv { // invariants bool check_zero_one_bits(theory_var v); + void validate_atoms() const; public: solver(euf::solver& ctx, theory_id id); diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index fd6954d7e..d45525fe7 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -15,7 +15,6 @@ Author: --*/ -#include "ast/ast_pp.h" #include "ast/pb_decl_plugin.h" #include "sat/smt/euf_solver.h" @@ -98,9 +97,9 @@ namespace euf { } sat::literal solver::attach_lit(literal lit, expr* e) { - sat::bool_var v = lit.var(); + sat::bool_var v = lit.var(); s().set_external(v); - s().set_eliminated(v, false); + s().set_eliminated(v, false); if (lit.sign()) { v = si.add_bool_var(e); @@ -112,7 +111,8 @@ namespace euf { lit = lit2; } m_var2expr.reserve(v + 1, nullptr); - SASSERT(m_var2expr[v] == nullptr); + if (m_var2expr[v]) + return lit; m_var2expr[v] = e; m_var_trail.push_back(v); if (!m_egraph.find(e)) { diff --git a/src/sat/smt/euf_proof.cpp b/src/sat/smt/euf_proof.cpp index ab95c99dc..cbb49b847 100644 --- a/src/sat/smt/euf_proof.cpp +++ b/src/sat/smt/euf_proof.cpp @@ -15,7 +15,6 @@ Author: --*/ -#include "ast/ast_ll_pp.h" #include "sat/smt/euf_solver.h" namespace euf { @@ -33,12 +32,13 @@ namespace euf { return; if (is_app(e)) { app* a = to_app(e); + drat_log_decl(a->get_decl()); if (a->get_num_parameters() == 0) - get_drat().def_begin(e->get_id(), a->get_decl()->get_name().str()); + get_drat().def_begin('e', e->get_id(), a->get_decl()->get_name().str()); else { std::stringstream strm; strm << mk_ismt2_func(a->get_decl(), m); - get_drat().def_begin(e->get_id(), strm.str()); + get_drat().def_begin('e', e->get_id(), strm.str()); } for (expr* arg : *a) get_drat().def_add_arg(arg->get_id()); @@ -49,6 +49,20 @@ namespace euf { } } + void solver::drat_log_decl(func_decl* f) { + if (f->get_family_id() != null_family_id) + return; + if (m_drat_asts.contains(f)) + return; + m_drat_asts.insert(f); + push(insert_obj_trail(m_drat_asts, f)); + std::ostringstream strm; + smt2_pp_environment_dbg env(m); + ast_smt2_pp(strm, f, env); + get_drat().def_begin('f', f->get_decl_id(), strm.str()); + get_drat().def_end(); + } + /** * \brief logs antecedents to a proof trail. * diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 101af94ba..a64b4b652 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -147,6 +147,8 @@ namespace euf { // proofs void log_antecedents(std::ostream& out, literal l, literal_vector const& r); void log_antecedents(literal l, literal_vector const& r); + void drat_log_decl(func_decl* f); + obj_hashtable m_drat_asts; bool m_drat_initialized{ false }; void init_drat(); diff --git a/src/sat/smt/sat_smt.h b/src/sat/smt/sat_smt.h index fe1d32df6..58e33c06d 100644 --- a/src/sat/smt/sat_smt.h +++ b/src/sat/smt/sat_smt.h @@ -17,6 +17,7 @@ Author: #pragma once #include "ast/ast.h" #include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" #include "sat/sat_solver.h" namespace sat { diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index e31d9f6a9..953400063 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -381,11 +381,10 @@ struct goal2sat::imp : public sat::sat_internalizer { sat::bool_var k = add_var(false, t); sat::literal l(k, false); m_cache.insert(t, l); - sat::literal * lits = m_result_stack.end() - num; - - for (unsigned i = 0; i < num; i++) { + sat::literal * lits = m_result_stack.end() - num; + for (unsigned i = 0; i < num; i++) mk_clause(~lits[i], l); - } + m_result_stack.push_back(~l); lits = m_result_stack.end() - num - 1; if (m_aig) { @@ -395,9 +394,9 @@ struct goal2sat::imp : public sat::sat_internalizer { // remark: mk_clause may perform destructive updated to lits. // I have to execute it after the binary mk_clause above. mk_clause(num+1, lits); - if (m_aig) { + if (m_aig) m_aig->add_or(l, num, aig_lits.c_ptr()); - } + m_result_stack.shrink(old_sz); if (sign) l.neg(); diff --git a/src/shell/drat_frontend.cpp b/src/shell/drat_frontend.cpp index 3ebf592c1..0897166b2 100644 --- a/src/shell/drat_frontend.cpp +++ b/src/shell/drat_frontend.cpp @@ -5,6 +5,7 @@ Copyright (c) 2020 Microsoft Corporation #include #include +#include "ast/bv_decl_plugin.h" #include "util/memory_manager.h" #include "util/statistics.h" #include "sat/dimacs.h" @@ -25,6 +26,7 @@ class smt_checker { params_ref m_params; scoped_ptr m_lemma_solver, m_input_solver; sat::literal_vector m_units; + bool m_check_inputs { false }; expr* fresh(expr* e) { unsigned i = e->get_id(); @@ -140,7 +142,7 @@ public: for (sat::literal lit : lits) while (lit.var() >= m_drat.get_solver().num_vars()) m_drat.get_solver().mk_var(true); - if (st.is_input()) + if (st.is_input() && m_check_inputs) check_assertion_redundant(lits); else if (!st.is_sat() && !st.is_deleted()) check_clause(lits); @@ -160,6 +162,69 @@ public: s->assert_expr(e); s->display(std::cout); } + + symbol name; + unsigned_vector params; + ptr_vector sorts; + + void parse_sexpr(sexpr_ref const& sexpr, cmd_context& ctx, expr_ref_vector const& args, expr_ref& result) { + params.reset(); + sorts.reset(); + for (expr* arg : args) + sorts.push_back(m.get_sort(arg)); + sort_ref rng(m); + switch (sexpr->get_kind()) { + case sexpr::kind_t::COMPOSITE: { + unsigned sz = sexpr->get_num_children(); + if (sz == 0) + goto bail; + if (sexpr->get_child(0)->get_symbol() == symbol("_")) { + name = sexpr->get_child(1)->get_symbol(); + if (name == "bv" && sz == 4) { + bv_util bvu(m); + auto val = sexpr->get_child(2)->get_numeral(); + auto n = sexpr->get_child(3)->get_numeral().get_unsigned(); + result = bvu.mk_numeral(val, n); + return; + } + for (unsigned i = 2; i < sz; ++i) { + auto* child = sexpr->get_child(i); + if (child->is_numeral() && child->get_numeral().is_unsigned()) + params.push_back(child->get_numeral().get_unsigned()); + else + goto bail; + } + break; + } + goto bail; + } + case sexpr::kind_t::SYMBOL: + name = sexpr->get_symbol(); + break; + case sexpr::kind_t::BV_NUMERAL: { + std::cout << "bv numeral\n"; + goto bail; + unsigned sz = sexpr->get_bv_size(); + rational r = sexpr->get_numeral(); + break; + } + case sexpr::kind_t::STRING: + case sexpr::kind_t::KEYWORD: + case sexpr::kind_t::NUMERAL: + default: + goto bail; + } + func_decl* f = ctx.find_func_decl(name, params.size(), params.c_ptr(), args.size(), sorts.c_ptr(), rng.get()); + if (!f) + goto bail; + result = ctx.m().mk_app(f, args); + return; + bail: + std::cout << "Could not parse expression\n"; + sexpr->display(std::cout); + std::cout << "\n"; + exit(0); + } }; static void verify_smt(char const* drat_file, char const* smt_file) { @@ -194,14 +259,14 @@ static void verify_smt(char const* drat_file, char const* smt_file) { expr_ref_vector bool_var2expr(m); expr_ref_vector exprs(m), args(m), inputs(m); - func_decl* f = nullptr; - ptr_vector sorts; + sort_ref_vector sargs(m), sorts(m); + func_decl_ref_vector decls(m); smt_checker checker(drat_checker, bool_var2expr); for (expr* a : ctx.assertions()) checker.add_assertion(a); - + for (auto const& r : drat) { std::cout << dimacs::drat_pp(r, write_theory); std::cout.flush(); @@ -211,25 +276,39 @@ static void verify_smt(char const* drat_file, char const* smt_file) { if (drat_checker.inconsistent()) std::cout << "inconsistent\n"; break; - case dimacs::drat_record::tag_t::is_node: + case dimacs::drat_record::tag_t::is_node: { + expr_ref e(m); args.reset(); - sorts.reset(); - for (auto n : r.m_args) { + for (auto n : r.m_args) args.push_back(exprs.get(n)); - sorts.push_back(ctx.m().get_sort(args.back())); - } - if (r.m_name[0] == '(') { - std::cout << "parsing sexprs is TBD\n"; - exit(0); - } - f = ctx.find_func_decl(symbol(r.m_name.c_str()), 0, nullptr, args.size(), sorts.c_ptr(), nullptr); - if (!f) { - std::cout << "could not find function\n"; - exit(0); - } + std::istringstream strm(r.m_name); + auto sexpr = parse_sexpr(ctx, strm, p, drat_file); + checker.parse_sexpr(sexpr, ctx, args, e); exprs.reserve(r.m_node_id+1); - exprs.set(r.m_node_id, ctx.m().mk_app(f, args.size(), args.c_ptr())); + exprs.set(r.m_node_id, e); break; + } + case dimacs::drat_record::tag_t::is_decl: { + std::istringstream strm(r.m_name); + ctx.set_allow_duplicate_declarations(); + parse_smt2_commands(ctx, strm); + break; + } + case dimacs::drat_record::tag_t::is_sort: { + sort_ref srt(m); + symbol name = symbol(r.m_name.c_str()); + sargs.reset(); + for (auto n : r.m_args) + sargs.push_back(sorts.get(n)); + psort_decl* pd = ctx.find_psort_decl(name); + if (pd) + srt = pd->instantiate(ctx.pm(), sargs.size(), sargs.c_ptr()); + else + srt = m.mk_uninterpreted_sort(name); + sorts.reserve(r.m_node_id+1); + sorts.set(r.m_node_id, srt); + break; + } case dimacs::drat_record::tag_t::is_bool_def: bool_var2expr.reserve(r.m_node_id+1); bool_var2expr.set(r.m_node_id, exprs.get(r.m_args[0])); diff --git a/src/util/trail.h b/src/util/trail.h index 40b50a59a..afe46f702 100644 --- a/src/util/trail.h +++ b/src/util/trail.h @@ -126,6 +126,28 @@ public: } }; +template +class vector2_value_trail : public trail { + V & m_vector; + unsigned m_i; + unsigned m_j; + T m_old_value; +public: + vector2_value_trail(V& v, unsigned i, unsigned j): + m_vector(v), + m_i(i), + m_j(j), + m_old_value(v[i][j]) { + } + + ~vector2_value_trail() override { + } + + void undo(Ctx & ctx) override { + m_vector[m_i][m_j] = m_old_value; + } +}; + template class insert_obj_map : public trail {