diff --git a/src/smt/theory_str.cpp b/src/smt/theory_str.cpp index f44cb8322..939a63160 100644 --- a/src/smt/theory_str.cpp +++ b/src/smt/theory_str.cpp @@ -54,12 +54,15 @@ theory_str::theory_str(ast_manager & m): tmpValTestVarCount(0), avoidLoopCut(true), loopDetected(false), - contains_map(m) + contains_map(m), + m_find(*this), + m_trail_stack(*this) { initialize_charset(); } theory_str::~theory_str() { + m_trail_stack.reset(); } void theory_str::initialize_charset() { @@ -284,7 +287,7 @@ theory_var theory_str::mk_var(enode* n) { } else { theory_var v = theory::mk_var(n); - // m_find.mk_var(); + m_find.mk_var(); get_context().attach_th_var(n, this, v); get_context().mark_as_relevant(n); return v; @@ -1586,6 +1589,8 @@ void theory_str::attach_new_th_var(enode * n) { void theory_str::reset_eh() { TRACE("t_str", tout << "resetting" << std::endl;); + m_trail_stack.reset(); + m_basicstr_axiom_todo.reset(); m_str_eq_todo.reset(); m_concat_axiom_todo.reset(); @@ -1673,13 +1678,40 @@ bool theory_str::new_eq_check(expr * lhs, expr * rhs) { return true; } +// support for user_smt_theory-style EQC handling + +app * theory_str::get_ast(theory_var i) { + return get_enode(i)->get_owner(); +} + +theory_var theory_str::get_var(expr * n) const { + if (!is_app(n)) { + return null_theory_var; + } + context & ctx = get_context(); + if (ctx.e_internalized(to_app(n))) { + enode * e = ctx.get_enode(to_app(n)); + return e->get_th_var(get_id()); + } + return null_theory_var; +} + +// simulate Z3_theory_get_eqc_next() +expr * theory_str::get_eqc_next(expr * n) { + theory_var v = get_var(n); + if (v != null_theory_var) { + theory_var r = m_find.next(v); + return get_ast(r); + } + return n; +} + void theory_str::group_terms_by_eqc(expr * n, std::set & concats, std::set & vars, std::set & consts) { context & ctx = get_context(); - enode * nNode = ctx.get_enode(n); - enode * eqcNode = nNode; + expr * eqcNode = n; do { - app * ast = eqcNode->get_owner(); - if (is_concat(eqcNode)) { + app * ast 
= to_app(eqcNode); + if (is_concat(ast)) { expr * simConcat = simplify_concat(ast); if (simConcat != ast) { if (is_concat(to_app(simConcat))) { @@ -1694,13 +1726,13 @@ void theory_str::group_terms_by_eqc(expr * n, std::set & concats, std::se } else { concats.insert(simConcat); } - } else if (is_string(eqcNode)) { + } else if (is_string(ast)) { consts.insert(ast); } else { vars.insert(ast); } - eqcNode = eqcNode->get_next(); - } while (eqcNode != nNode); + eqcNode = get_eqc_next(eqcNode); + } while (eqcNode != n); } void theory_str::get_nodes_in_concat(expr * node, ptr_vector & nodeList) { @@ -3975,6 +4007,22 @@ expr * theory_str::get_eqc_value(expr * n, bool & hasEqcValue) { return n; } +// Simulate the behaviour of get_eqc_value() from Z3str2. +// We only check m_find for a string constant. + +expr * theory_str::z3str2_get_eqc_value(expr * n , bool & hasEqcValue) { + expr * curr = n; + do { + if (m_strutil.is_string(curr)) { + hasEqcValue = true; + return curr; + } + curr = get_eqc_next(curr); + } while (curr != n); + hasEqcValue = false; + return n; +} + // from Z3: theory_seq.cpp static theory_mi_arith* get_th_arith(context& ctx, theory_id afid, expr* e) { @@ -6110,106 +6158,107 @@ void theory_str::handle_equality(expr * lhs, expr * rhs) { instantiate_str_eq_length_axiom(ctx.get_enode(lhs), ctx.get_enode(rhs)); // group terms by equivalence class (groupNodeInEqc()) - // Previously we did the check between LHS and RHS equivalence classes. - // However these have since been merged. - // We start by asserting that the EQCs, in fact, really are merged. - if (!in_same_eqc(lhs, rhs)) { - TRACE("t_str", tout << "BUG: lhs and rhs not in same eqc in new_eq_eh(), loss of invariant!" 
<< std::endl;); - UNREACHABLE(); - } - std::set eqc_concat; - std::set eqc_var; - std::set eqc_const; - group_terms_by_eqc(lhs, eqc_concat, eqc_var, eqc_const); + std::set eqc_concat_lhs; + std::set eqc_var_lhs; + std::set eqc_const_lhs; + group_terms_by_eqc(lhs, eqc_concat_lhs, eqc_var_lhs, eqc_const_lhs); + + std::set eqc_concat_rhs; + std::set eqc_var_rhs; + std::set eqc_const_rhs; + group_terms_by_eqc(rhs, eqc_concat_rhs, eqc_var_rhs, eqc_const_rhs); TRACE("t_str_detail", - tout << "eqc:" << std::endl; + tout << "lhs eqc:" << std::endl; tout << "Concats:" << std::endl; - for (std::set::iterator it = eqc_concat.begin(); it != eqc_concat.end(); ++it) { + for (std::set::iterator it = eqc_concat_lhs.begin(); it != eqc_concat_lhs.end(); ++it) { expr * ex = *it; tout << mk_ismt2_pp(ex, get_manager()) << std::endl; } tout << "Variables:" << std::endl; - for (std::set::iterator it = eqc_var.begin(); it != eqc_var.end(); ++it) { + for (std::set::iterator it = eqc_var_lhs.begin(); it != eqc_var_lhs.end(); ++it) { expr * ex = *it; tout << mk_ismt2_pp(ex, get_manager()) << std::endl; } tout << "Constants:" << std::endl; - for (std::set::iterator it = eqc_const.begin(); it != eqc_const.end(); ++it) { + for (std::set::iterator it = eqc_const_lhs.begin(); it != eqc_const_lhs.end(); ++it) { + expr * ex = *it; + tout << mk_ismt2_pp(ex, get_manager()) << std::endl; + } + + tout << "rhs eqc:" << std::endl; + tout << "Concats:" << std::endl; + for (std::set::iterator it = eqc_concat_rhs.begin(); it != eqc_concat_rhs.end(); ++it) { + expr * ex = *it; + tout << mk_ismt2_pp(ex, get_manager()) << std::endl; + } + tout << "Variables:" << std::endl; + for (std::set::iterator it = eqc_var_rhs.begin(); it != eqc_var_rhs.end(); ++it) { + expr * ex = *it; + tout << mk_ismt2_pp(ex, get_manager()) << std::endl; + } + tout << "Constants:" << std::endl; + for (std::set::iterator it = eqc_const_rhs.begin(); it != eqc_const_rhs.end(); ++it) { expr * ex = *it; tout << mk_ismt2_pp(ex, 
get_manager()) << std::endl; } ); // step 1: Concat == Concat - - // enhancement from Z3str2: all-pairs match over LHS and RHS wrt. other concats - if (eqc_concat.size() != 0) { - std::set::iterator itor1, itor2; - for (itor1 = eqc_concat.begin(); itor1 != eqc_concat.end(); ++itor1) { - for (itor2 = itor1; itor2 != eqc_concat.end(); ++itor2) { - if (itor1 == itor2) { - continue; - } - expr * e1 = *itor1; - expr * e2 = *itor2; - TRACE("t_str_detail", tout << "simplify concat-concat pair " << mk_pp(e1, m) << " and " << mk_pp(e2, m) << std::endl;); - simplify_concat_equality(e1, e2); + int hasCommon = 0; + if (eqc_concat_lhs.size() != 0 && eqc_concat_rhs.size() != 0) { + std::set::iterator itor1 = eqc_concat_lhs.begin(); + std::set::iterator itor2 = eqc_concat_rhs.begin(); + for (; itor1 != eqc_concat_lhs.end(); itor1++) { + if (eqc_concat_rhs.find(*itor1) != eqc_concat_rhs.end()) { + hasCommon = 1; + break; } } + for (; itor2 != eqc_concat_rhs.end(); itor2++) { + if (eqc_concat_lhs.find(*itor2) != eqc_concat_lhs.end()) { + hasCommon = 1; + break; + } + } + if (hasCommon == 0) { + simplify_concat_equality(*(eqc_concat_lhs.begin()), *(eqc_concat_rhs.begin())); + } } // step 2: Concat == Constant - // same enhancement as above wrt. Z3str2's behaviour - if (eqc_const.size() != 0) { - expr * conStr = *(eqc_const.begin()); - std::set::iterator itor2; - for (itor2 = eqc_concat.begin(); itor2 != eqc_concat.end(); ++itor2) { + + if (eqc_const_lhs.size() != 0) { // lhs eqc holds a constant: solve it against every rhs concat + expr * conStr = *(eqc_const_lhs.begin()); + std::set::iterator itor2 = eqc_concat_rhs.begin(); + for (; itor2 != eqc_concat_rhs.end(); itor2++) { solve_concat_eq_str(*itor2, conStr); } + } else if (eqc_const_rhs.size() != 0) { // only rhs eqc holds a constant: solve it against every lhs concat + expr* conStr = *(eqc_const_rhs.begin()); + std::set::iterator itor1 = eqc_concat_lhs.begin(); + for (; itor1 != eqc_concat_lhs.end(); itor1++) { + solve_concat_eq_str(*itor1, conStr); + } } // simplify parents wrt. 
the equivalence class of both sides - // TODO this is slightly broken, re-enable it once some semantics have been fixed - // Briefly, Z3str2 expects that as this function is entered, - // lhs and rhs are NOT in the same equivalence class yet. - // However, newer versions of Z3 appear to behave differently, - // putting lhs and rhs into the same equivalence class - // *before* this function is called. - // Instead we do something possibly more aggressive here. - /* - bool lhs_has_eqc_value = false; - bool rhs_has_eqc_value = false; - expr * lhs_value = get_eqc_value(lhs, lhs_has_eqc_value); - expr * rhs_value = get_eqc_value(rhs, rhs_has_eqc_value); - if (lhs_has_eqc_value && !rhs_has_eqc_value) { - simplify_parent(rhs, lhs_value); + bool nn1HasEqcValue = false; + bool nn2HasEqcValue = false; + // we want the Z3str2 eqc check here... + expr * nn1_value = z3str2_get_eqc_value(lhs, nn1HasEqcValue); + expr * nn2_value = z3str2_get_eqc_value(rhs, nn2HasEqcValue); + if (nn1HasEqcValue && !nn2HasEqcValue) { + simplify_parent(rhs, nn1_value); } - if (!lhs_has_eqc_value && rhs_has_eqc_value) { - simplify_parent(lhs, rhs_value); - } - */ - bool lhs_has_eqc_value = false; - bool rhs_has_eqc_value = false; - expr * lhs_value = get_eqc_value(lhs, lhs_has_eqc_value); - expr * rhs_value = get_eqc_value(rhs, rhs_has_eqc_value); - - // TODO this depends on the old, possibly broken, semantics of is_string(). - // we explicitly want to test whether lhs/rhs is actually a string constant. 
- bool lhs_is_string_constant = m_strutil.is_string(lhs); - bool rhs_is_string_constant = m_strutil.is_string(rhs); - - - if (lhs_has_eqc_value && !rhs_is_string_constant) { - simplify_parent(rhs, lhs_value); - } - if (rhs_has_eqc_value && !lhs_is_string_constant) { - simplify_parent(lhs, rhs_value); + if (!nn1HasEqcValue && nn2HasEqcValue) { + simplify_parent(lhs, nn2_value); } // regex unroll + // TODO NEXT check EQC semantics here too expr * nn1EqConst = NULL; std::set nn1EqUnrollFuncs; @@ -6229,6 +6278,7 @@ void theory_str::handle_equality(expr * lhs, expr * rhs) { process_unroll_eq_const_str(*itor2, nn1EqConst); } } + } void theory_str::set_up_axioms(expr * ex) { @@ -6407,7 +6457,15 @@ void theory_str::new_eq_eh(theory_var x, theory_var y) { //TRACE("t_str_detail", tout << "new eq: v#" << x << " = v#" << y << std::endl;); TRACE("t_str", tout << "new eq: " << mk_ismt2_pp(get_enode(x)->get_owner(), get_manager()) << " = " << mk_ismt2_pp(get_enode(y)->get_owner(), get_manager()) << std::endl;); + /* + if (m_find.find(x) == m_find.find(y)) { + return; + } + */ handle_equality(get_enode(x)->get_owner(), get_enode(y)->get_owner()); + + // replicate Z3str2 behaviour: merge eqc **AFTER** handle_equality + m_find.merge(x, y); } void theory_str::new_diseq_eh(theory_var x, theory_var y) { @@ -6427,6 +6485,8 @@ void theory_str::assign_eh(bool_var v, bool is_true) { void theory_str::push_scope_eh() { theory::push_scope_eh(); + m_trail_stack.push_scope(); + sLevel += 1; TRACE("t_str", tout << "push to " << sLevel << std::endl;); TRACE("t_str_dump_assign_on_scope_change", dump_assignments();); @@ -6549,6 +6609,7 @@ void theory_str::pop_scope_eh(unsigned num_scopes) { m_basicstr_axiom_todo.reset(); m_basicstr_axiom_todo = new_m_basicstr; + m_trail_stack.pop_scope(num_scopes); theory::pop_scope_eh(num_scopes); check_variable_scope(); diff --git a/src/smt/theory_str.h b/src/smt/theory_str.h index 9f7d51a8f..58b104209 100644 --- a/src/smt/theory_str.h +++ b/src/smt/theory_str.h 
@@ -27,6 +27,7 @@ Revision History: #include #include #include"str_rewriter.h" +#include"union_find.h" namespace smt { @@ -81,6 +82,10 @@ namespace smt { level = -100; } }; + + typedef trail_stack th_trail_stack; + typedef union_find th_union_find; + protected: // Some options that control how the solver operates. @@ -252,6 +257,12 @@ namespace smt { obj_pair_map concat_astNode_map; + th_union_find m_find; + th_trail_stack m_trail_stack; + theory_var get_var(expr * n) const; + expr * get_eqc_next(expr * n); + app * get_ast(theory_var i); + protected: void assert_axiom(expr * e); void assert_implication(expr * premise, expr * conclusion); @@ -347,6 +358,7 @@ namespace smt { app * mk_value_helper(app * n); expr * get_eqc_value(expr * n, bool & hasEqcValue); + expr * z3str2_get_eqc_value(expr * n , bool & hasEqcValue); bool in_same_eqc(expr * n1, expr * n2); expr * collect_eq_nodes(expr * n, expr_ref_vector & eqcSet); @@ -479,6 +491,11 @@ namespace smt { virtual void display(std::ostream & out) const; bool overlapping_variables_detected() const { return loopDetected; } + + th_trail_stack& get_trail_stack() { return m_trail_stack; } + void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2) {} + void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) { } + void unmerge_eh(theory_var v1, theory_var v2) {} protected: virtual bool internalize_atom(app * atom, bool gate_ctx); virtual bool internalize_term(app * term);