From aec59e4ff77b5389e377a5df2336491cd99ed84e Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 16 Oct 2016 15:43:28 -0400 Subject: [PATCH] add consequence finding to inc-sat-solver Signed-off-by: Nikolaj Bjorner --- contrib/cmake/src/test/CMakeLists.txt | 1 + src/ast/pb_decl_plugin.cpp | 50 +++++---- src/ast/pb_decl_plugin.h | 1 + src/ast/rewriter/pb_rewriter.cpp | 16 +++ src/sat/sat_solver.cpp | 46 ++++---- src/sat/sat_solver.h | 3 + src/sat/sat_solver/inc_sat_solver.cpp | 144 +++++++++++++++++++++++++- src/sat/sat_types.h | 12 +++ src/test/get_consequences.cpp | 50 +++++++++ src/test/main.cpp | 1 + 10 files changed, 283 insertions(+), 41 deletions(-) create mode 100644 src/test/get_consequences.cpp diff --git a/contrib/cmake/src/test/CMakeLists.txt b/contrib/cmake/src/test/CMakeLists.txt index 427cedcdb..acaf186ba 100644 --- a/contrib/cmake/src/test/CMakeLists.txt +++ b/contrib/cmake/src/test/CMakeLists.txt @@ -42,6 +42,7 @@ add_executable(test-z3 factor_rewriter.cpp fixed_bit_vector.cpp for_each_file.cpp + get_consequences.cpp get_implied_equalities.cpp "${CMAKE_CURRENT_BINARY_DIR}/gparams_register_modules.cpp" hashtable.cpp diff --git a/src/ast/pb_decl_plugin.cpp b/src/ast/pb_decl_plugin.cpp index e87bc15ac..18a652859 100644 --- a/src/ast/pb_decl_plugin.cpp +++ b/src/ast/pb_decl_plugin.cpp @@ -101,35 +101,47 @@ void pb_decl_plugin::get_op_names(svector & op_names, symbol const } void pb_util::normalize(unsigned num_args, rational const* coeffs, rational const& k) { - rational d(1); - for (unsigned i = 0; i < num_args; ++i) { - d = lcm(d, denominator(coeffs[i])); - } m_coeffs.reset(); - for (unsigned i = 0; i < num_args; ++i) { - m_coeffs.push_back(d*coeffs[i]); + bool all_ones = true; + for (unsigned i = 0; i < num_args && all_ones; ++i) { + all_ones = denominator(coeffs[i]).is_one(); + } + if (all_ones) { + for (unsigned i = 0; i < num_args; ++i) { + m_coeffs.push_back(coeffs[i]); + } + m_k = k; + } + else { + rational d(1); + for (unsigned i = 0; i < num_args; ++i) { + d = lcm(d, denominator(coeffs[i])); + } + for (unsigned i = 0; i < num_args; ++i) { + m_coeffs.push_back(d*coeffs[i]); + } + m_k = d*k; } - m_k = d*k; } app * pb_util::mk_le(unsigned num_args, rational const * coeffs, expr * const * args, rational const& k) { normalize(num_args, coeffs, k); - vector params; - params.push_back(parameter(floor(m_k))); + m_params.reset(); + m_params.push_back(parameter(floor(m_k))); for (unsigned i = 0; i < num_args; ++i) { - params.push_back(parameter(m_coeffs[i])); + m_params.push_back(parameter(m_coeffs[i])); } - return m.mk_app(m_fid, OP_PB_LE, params.size(), params.c_ptr(), num_args, args, m.mk_bool_sort()); + return m.mk_app(m_fid, OP_PB_LE, m_params.size(), m_params.c_ptr(), num_args, args, m.mk_bool_sort()); } app * pb_util::mk_ge(unsigned num_args, rational const * coeffs, expr * const * args, rational const& k) { normalize(num_args, coeffs, k); - vector params; - params.push_back(parameter(ceil(m_k))); + m_params.reset(); + m_params.push_back(parameter(ceil(m_k))); for (unsigned i = 0; i < num_args; ++i) { - params.push_back(parameter(m_coeffs[i])); + m_params.push_back(parameter(m_coeffs[i])); } - return m.mk_app(m_fid, OP_PB_GE, params.size(), params.c_ptr(), num_args, args, m.mk_bool_sort()); + return m.mk_app(m_fid, OP_PB_GE, m_params.size(), m_params.c_ptr(), num_args, args, m.mk_bool_sort()); } app * pb_util::mk_eq(unsigned num_args, rational const * coeffs, expr * const * args, rational const& k) { @@ -137,12 +149,12 @@ app * pb_util::mk_eq(unsigned num_args, rational const * coeffs, expr * const * if (!m_k.is_int()) { return m.mk_false(); } - vector params; - params.push_back(parameter(m_k)); + m_params.reset(); + m_params.push_back(parameter(m_k)); for (unsigned i = 0; i < num_args; ++i) { - params.push_back(parameter(m_coeffs[i])); + m_params.push_back(parameter(m_coeffs[i])); } - return m.mk_app(m_fid, OP_PB_EQ, params.size(), params.c_ptr(), num_args, args, m.mk_bool_sort()); + return m.mk_app(m_fid, OP_PB_EQ, m_params.size(), m_params.c_ptr(), num_args, args, m.mk_bool_sort()); } // ax + by < k diff --git a/src/ast/pb_decl_plugin.h b/src/ast/pb_decl_plugin.h index e1b16f0c9..2ed14a4ce 100644 --- a/src/ast/pb_decl_plugin.h +++ b/src/ast/pb_decl_plugin.h @@ -80,6 +80,7 @@ class pb_util { ast_manager & m; family_id m_fid; vector m_coeffs; + vector m_params; rational m_k; void normalize(unsigned num_args, rational const* coeffs, rational const& k); public: diff --git a/src/ast/rewriter/pb_rewriter.cpp b/src/ast/rewriter/pb_rewriter.cpp index eb85f8ec9..74062fbfa 100644 --- a/src/ast/rewriter/pb_rewriter.cpp +++ b/src/ast/rewriter/pb_rewriter.cpp @@ -277,6 +277,22 @@ br_status pb_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * cons tout << tmp << "\n"; tout << result << "\n"; ); + +#if 0 + static unsigned num_changes = 0; + static unsigned num_calls = 0; + static unsigned inc = 1; + { + expr_ref tmp(m); + tmp = m.mk_app(f, num_args, args); + num_calls++; + if (tmp != result) ++num_changes; + if (num_calls > inc) { + std::cout << num_calls << " " << num_changes << "\n"; + inc *= 2; + } + } +#endif TRACE("pb_validate", validate_rewrite(f, num_args, args, result);); diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index f2279b3c4..915080c4a 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -3075,45 +3075,51 @@ namespace sat { m_binary_clause_graph[l1.index()].push_back(l2); m_binary_clause_graph[l2.index()].push_back(l1); } - while (!ps.empty()) { + bool non_empty = true; + m_seen[0].reset(); + while (non_empty) { literal_vector mutex; - literal_set other(ps); - while (!other.empty()) { - literal_set conseq; - literal p = other.pop(); + bool turn = false; + m_reachable[turn] = ps; + while (!m_reachable[turn].empty()) { + literal p = m_reachable[turn].pop(); + if (m_seen[0].contains(p)) { + continue; + } + m_reachable[turn].remove(p); + m_seen[0].insert(p); mutex.push_back(p); - if (other.empty()) { + if (m_reachable[turn].empty()) { break; } - get_reachable(p, other, conseq); - other = conseq; + m_reachable[!turn].reset(); + get_reachable(p, m_reachable[turn], m_reachable[!turn]); + turn = !turn; } if (mutex.size() > 1) { mutexes.push_back(mutex); } - for (unsigned i = 0; i < mutex.size(); ++i) { - ps.erase(mutex[i]); - } + non_empty = !mutex.empty(); } return l_true; } void solver::get_reachable(literal p, literal_set const& goal, literal_set& reachable) { - literal_set seen; - literal_vector todo; - todo.push_back(p); - while (!todo.empty()) { - p = todo.back(); - todo.pop_back(); - if (seen.contains(p)) { + m_seen[1].reset(); + m_todo.reset(); + m_todo.push_back(p); + while (!m_todo.empty()) { + p = m_todo.back(); + m_todo.pop_back(); + if (m_seen[1].contains(p)) { continue; } - seen.insert(p); + m_seen[1].insert(p); literal np = ~p; if (goal.contains(np)) { reachable.insert(np); } - todo.append(m_binary_clause_graph[np.index()]); + m_todo.append(m_binary_clause_graph[np.index()]); } } diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 785bc6856..85836b889 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -447,6 +447,9 @@ namespace sat { u_map m_antecedents; vector m_binary_clause_graph; + literal_set m_reachable[2]; + literal_set m_seen[2]; + literal_vector m_todo; void extract_assumptions(literal lit, index_set& s); diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 6139b3e22..83ccfcac4 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -33,6 +33,7 @@ Notes: #include "filter_model_converter.h" #include "bit_blaster_model_converter.h" #include "ast_translation.h" +#include "ast_util.h" // incremental SAT solver. class inc_sat_solver : public solver { @@ -232,6 +233,41 @@ public: return 0; } + virtual lbool get_consequences_core(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq) { + TRACE("sat", tout << assumptions << "\n" << vars << "\n";); + sat::literal_vector asms; + sat::bool_var_vector bvars; + vector lconseq; + dep2asm_t dep2asm; + m_solver.pop_to_base_level(); + lbool r = internalize_formulas(); + if (r != l_true) return r; + r = internalize_assumptions(assumptions.size(), assumptions.c_ptr(), dep2asm); + if (r != l_true) return r; + r = internalize_vars(vars, bvars); + + r = m_solver.get_consequences(m_asms, bvars, lconseq); + if (r != l_true) return r; + + // build map from bound variables to + // the consequences that cover them. + u_map bool_var2conseq; + for (unsigned i = 0; i < lconseq.size(); ++i) { + TRACE("sat", tout << lconseq[i] << "\n";); + bool_var2conseq.insert(lconseq[i][0].var(), i); + } + + // extract original fixed variables + for (unsigned i = 0; i < vars.size(); ++i) { + expr_ref cons(m); + if (extract_fixed_variable(dep2asm, vars[i], bool_var2conseq, lconseq, cons)) { + conseq.push_back(cons); + } + } + + return r; + } + virtual lbool find_mutexes(expr_ref_vector const& vars, vector& mutexes) { sat::literal_vector ls; u_map lit2var; @@ -359,6 +395,106 @@ private: return res; } + lbool internalize_vars(expr_ref_vector const& vars, sat::bool_var_vector& bvars) { + for (unsigned i = 0; i < vars.size(); ++i) { + internalize_var(vars[i], bvars); + } + return l_true; + } + + bool internalize_var(expr* v, sat::bool_var_vector& bvars) { + obj_map const& const2bits = m_bb_rewriter->const2bits(); + expr* bv; + bv_util bvutil(m); + bool internalized = false; + if (is_uninterp_const(v) && m.is_bool(v)) { + sat::bool_var b = m_map.to_bool_var(v); + + if (b != sat::null_bool_var) { + bvars.push_back(b); + internalized = true; + } + } + else if (is_uninterp_const(v) && const2bits.find(to_app(v)->get_decl(), bv)) { + SASSERT(bvutil.is_bv(bv)); + app* abv = to_app(bv); + internalized = true; + unsigned sz = abv->get_num_args(); + for (unsigned j = 0; j < sz; ++j) { + SASSERT(is_uninterp_const(abv->get_arg(j))); + sat::bool_var b = m_map.to_bool_var(abv->get_arg(j)); + if (b == sat::null_bool_var) { + internalized = false; + } + else { + bvars.push_back(b); + } + } + CTRACE("sat", internalized, tout << "var: "; for (unsigned j = 0; j < sz; ++j) tout << bvars[bvars.size()-sz+j] << " "; tout << "\n";); + } + else if (is_uninterp_const(v) && bvutil.is_bv(v)) { + // variable does not occur in assertions, so is unconstrained. + } + CTRACE("sat", !internalized, tout << "unhandled variable " << mk_pp(v, m) << "\n";); + return internalized; + } + + bool extract_fixed_variable(dep2asm_t& dep2asm, expr* v, u_map const& bool_var2conseq, vector const& lconseq, expr_ref& conseq) { + u_map asm2dep; + extract_asm2dep(dep2asm, asm2dep); + + sat::bool_var_vector bvars; + if (!internalize_var(v, bvars)) { + return false; + } + sat::literal_vector value; + sat::literal_set premises; + for (unsigned i = 0; i < bvars.size(); ++i) { + unsigned index; + if (bool_var2conseq.find(bvars[i], index)) { + value.push_back(lconseq[index][0]); + for (unsigned j = 1; j < lconseq[index].size(); ++j) { + premises.insert(lconseq[index][j]); + } + } + else { + TRACE("sat", tout << "variable is not bound " << mk_pp(v, m) << "\n";); + return false; + } + } + expr_ref val(m); + expr_ref_vector conj(m); + internalize_value(value, v, val); + while (!premises.empty()) { + expr* e = 0; + VERIFY(asm2dep.find(premises.pop().index(), e)); + conj.push_back(e); + } + conseq = m.mk_implies(mk_and(conj), val); + return true; + } + + void internalize_value(sat::literal_vector const& value, expr* v, expr_ref& val) { + bv_util bvutil(m); + if (is_uninterp_const(v) && m.is_bool(v)) { + SASSERT(value.size() == 1); + val = value[0].sign() ? m.mk_not(v) : v; + } + else if (is_uninterp_const(v) && bvutil.is_bv_sort(m.get_sort(v))) { + SASSERT(value.size() == bvutil.get_bv_size(v)); + rational r(0); + for (unsigned i = 0; i < value.size(); ++i) { + if (!value[i].sign()) { + r += rational(2).expt(i); + } + } + val = m.mk_eq(v, bvutil.mk_numeral(r, value.size())); + } + else { + UNREACHABLE(); + } + } + lbool internalize_formulas() { if (m_fmls_head == m_fmls.size()) { return l_true; @@ -395,13 +531,17 @@ private: SASSERT(dep2asm.size() == m_asms.size()); } - void extract_core(dep2asm_t& dep2asm) { - u_map asm2dep; + void extract_asm2dep(dep2asm_t const& dep2asm, u_map& asm2dep) { dep2asm_t::iterator it = dep2asm.begin(), end = dep2asm.end(); for (; it != end; ++it) { expr* e = it->m_key; asm2dep.insert(it->m_value.index(), e); } + } + + void extract_core(dep2asm_t& dep2asm) { + u_map asm2dep; + extract_asm2dep(dep2asm, asm2dep); sat::literal_vector const& core = m_solver.get_core(); TRACE("sat", dep2asm_t::iterator it2 = dep2asm.begin(); diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index 697af2e2d..93109a74f 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -239,6 +239,18 @@ namespace sat { } return result; } + literal_set& operator=(literal_vector const& v) { + reset(); + for (unsigned i = 0; i < v.size(); ++i) insert(v[i]); + return *this; + } + literal_set& operator=(literal_set const& other) { + if (this != &other) { + m_set = other.m_set; + } + return *this; + } + void insert(literal l) { m_set.insert(l.index()); } void remove(literal l) { m_set.remove(l.index()); } literal pop() { return to_literal(m_set.erase()); } diff --git a/src/test/get_consequences.cpp b/src/test/get_consequences.cpp new file mode 100644 index 000000000..8bd6bccba --- /dev/null +++ b/src/test/get_consequences.cpp @@ -0,0 +1,50 @@ +/*++ +Copyright (c) 2016 Microsoft Corporation + +--*/ + +#include "inc_sat_solver.h" +#include "bv_decl_plugin.h" +#include "reg_decl_plugins.h" +#include "ast_pp.h" +//include + +static expr_ref mk_const(ast_manager& m, char const* name, sort* s) { + return expr_ref(m.mk_const(symbol(name), s), m); +} + +static expr_ref mk_bool(ast_manager& m, char const* name) { + return expr_ref(m.mk_const(symbol(name), m.mk_bool_sort()), m); +} + +static expr_ref mk_bv(ast_manager& m, char const* name, unsigned sz) { + bv_util bv(m); + return expr_ref(m.mk_const(symbol(name), bv.mk_sort(sz)), m); +} + +void tst_get_consequences() { + ast_manager m; + reg_decl_plugins(m); + bv_util bv(m); + params_ref p; + + ref solver = mk_inc_sat_solver(m, p); + expr_ref a = mk_bool(m, "a"), b = mk_bool(m, "b"), c = mk_bool(m, "c"); + expr_ref ba = mk_bv(m, "ba", 3), bb = mk_bv(m, "bb", 3), bc = mk_bv(m, "bc", 3); + + solver->assert_expr(m.mk_implies(a, b)); + solver->assert_expr(m.mk_implies(b, c)); + expr_ref_vector asms(m), vars(m), conseq(m); + asms.push_back(a); + vars.push_back(b); + vars.push_back(c); + vars.push_back(bb); + vars.push_back(bc); + solver->assert_expr(m.mk_eq(ba, bc)); + solver->assert_expr(m.mk_eq(bv.mk_numeral(2, 3), ba)); + solver->get_consequences(asms, vars, conseq); + + std::cout << conseq << "\n"; + + +} diff --git a/src/test/main.cpp b/src/test/main.cpp index 8fc0a2de6..9c6cdd668 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -228,6 +228,7 @@ int main(int argc, char ** argv) { TST(pdr); TST_ARGV(ddnf); TST(model_evaluator); + TST(get_consequences); //TST_ARGV(hs); }