From d28ef1d47185d6d67a95e1cb3b6251ed9e72c6da Mon Sep 17 00:00:00 2001 From: Murphy Berzish Date: Mon, 15 Aug 2016 17:38:24 -0400 Subject: [PATCH] add theory_str::check_contain_by_eq_nodes --- src/smt/theory_str.cpp | 382 ++++++++++++++++++++++++++++++++++++++-- src/smt/theory_str.h | 27 ++- src/util/obj_pair_set.h | 5 + 3 files changed, 394 insertions(+), 20 deletions(-) diff --git a/src/smt/theory_str.cpp b/src/smt/theory_str.cpp index 28b310196..93173402c 100644 --- a/src/smt/theory_str.cpp +++ b/src/smt/theory_str.cpp @@ -152,6 +152,11 @@ void theory_str::assert_axiom(expr * e) { //TRACE("t_str_detail", tout << "done asserting " << mk_ismt2_pp(e, get_manager()) << std::endl;); } +expr * theory_str::rewrite_implication(expr * premise, expr * conclusion) { + ast_manager & m = get_manager(); + return m.mk_or(m.mk_not(premise), conclusion); +} + void theory_str::assert_implication(expr * premise, expr * conclusion) { ast_manager & m = get_manager(); TRACE("t_str_detail", tout << "asserting implication " << mk_ismt2_pp(premise, m) << " -> " << mk_ismt2_pp(conclusion, m) << std::endl;); @@ -1119,20 +1124,28 @@ void theory_str::instantiate_axiom_Contains(enode * e) { context & ctx = get_context(); ast_manager & m = get_manager(); - app * expr = e->get_owner(); - if (axiomatized_terms.contains(expr)) { - TRACE("t_str_detail", tout << "already set up Contains axiom for " << mk_pp(expr, m) << std::endl;); + app * ex = e->get_owner(); + if (axiomatized_terms.contains(ex)) { + TRACE("t_str_detail", tout << "already set up Contains axiom for " << mk_pp(ex, m) << std::endl;); return; } - axiomatized_terms.insert(expr); - contains_map.push_back(expr); // replaces registerContain() in Z3str2 + axiomatized_terms.insert(ex); + { // register Contains() + expr * str = ex->get_arg(0); + expr * substr = ex->get_arg(1); + contains_map.push_back(ex); + std::pair key = std::pair(str, substr); + contain_pair_bool_map.insert(str, substr, ex); + contain_pair_idx_map[str].insert(key); + contain_pair_idx_map[substr].insert(key); + } - TRACE("t_str_detail", tout << "instantiate Contains axiom for " << mk_pp(expr, m) << std::endl;); + TRACE("t_str_detail", tout << "instantiate Contains axiom for " << mk_pp(ex, m) << std::endl;); expr_ref ts0(mk_str_var("ts0"), m); expr_ref ts1(mk_str_var("ts1"), m); - expr_ref breakdownAssert(ctx.mk_eq_atom(expr, ctx.mk_eq_atom(expr->get_arg(0), mk_concat(ts0, mk_concat(expr->get_arg(1), ts1)))), m); + expr_ref breakdownAssert(ctx.mk_eq_atom(ex, ctx.mk_eq_atom(ex->get_arg(0), mk_concat(ts0, mk_concat(ex->get_arg(1), ts1)))), m); SASSERT(breakdownAssert); assert_axiom(breakdownAssert); } @@ -4168,12 +4181,7 @@ void theory_str::check_contain_by_eqc_val(expr * varNode, expr * constNode) { expr_ref_vector litems(m); - // Modification from Z3str: - // since we don't track containPairIdxMap any more, - // we check each element of contains_map to see whether - // either of its arguments are equal to varNode. - // This could possibly be made faster if we had a map class that - // let us use an expr_ref as a key. + // TODO refactor to use the new contain_pair_idx_map expr_ref_vector::iterator itor1 = contains_map.begin(); for (; itor1 != contains_map.end(); ++itor1) { @@ -4304,8 +4312,7 @@ void theory_str::check_contain_by_substr(expr * varNode, expr_ref_vector & willE ast_manager & m = get_manager(); expr_ref_vector litems(m); - // same deal as before, we do not track containPairIdxMap - // and so we check elements of contains_map instead + // TODO refactor to use the new contain_pair_idx_map expr_ref_vector::iterator itor1 = contains_map.begin(); for (; itor1 != contains_map.end(); ++itor1) { @@ -4368,8 +4375,347 @@ void theory_str::check_contain_by_substr(expr * varNode, expr_ref_vector & willE } } +bool theory_str::in_contain_idx_map(expr * n) { + return contain_pair_idx_map.contains(n); +} + void theory_str::check_contain_by_eq_nodes(expr * n1, expr * n2) { - NOT_IMPLEMENTED_YET(); // TODO NEXT + context & ctx = get_context(); + ast_manager & m = get_manager(); + + if (in_contain_idx_map(n1) && in_contain_idx_map(n2)) { + obj_pair_set::iterator keysItor1 = contain_pair_idx_map[n1].begin(); + obj_pair_set::iterator keysItor2; + + for (; keysItor1 != contain_pair_idx_map[n1].end(); keysItor1++) { + // keysItor1 is on set {<.., n1>, ..., , ...} + std::pair key1 = *keysItor1; + if (key1.first == n1 && key1.second == n2) { + expr_ref implyL(m); + expr_ref implyR(contain_pair_bool_map[key1], m); + if (n1 != n2) { + implyL = ctx.mk_eq_atom(n1, n2); + assert_implication(implyL, implyR); + } else { + assert_axiom(implyR); + } + } + + for (keysItor2 = contain_pair_idx_map[n2].begin(); + keysItor2 != contain_pair_idx_map[n2].end(); keysItor2++) { + // keysItor2 is on set {<.., n2>, ..., , ...} + std::pair key2 = *keysItor2; + // skip if the pair is eq + if (key1 == key2) { + continue; + } + + // *************************** + // Case 1: Contains(m, ...) /\ Contains(n, ) /\ m = n + // *************************** + if (key1.first == n1 && key2.first == n2) { + expr * subAst1 = key1.second; + expr * subAst2 = key2.second; + bool subAst1HasValue = false; + bool subAst2HasValue = false; + expr * subValue1 = get_eqc_value(subAst1, subAst1HasValue); + expr * subValue2 = get_eqc_value(subAst2, subAst2HasValue); + + TRACE("t_str_detail", + tout << "(Contains " << mk_pp(n1, m) << " " << mk_pp(subAst1, m) << ")" << std::endl; + tout << "(Contains " << mk_pp(n2, m) << " " << mk_pp(subAst2, m) << ")" << std::endl; + if (subAst1 != subValue1) { + tout << mk_pp(subAst1, m) << " = " << mk_pp(subValue1, m) << std::endl; + } + if (subAst2 != subValue2) { + tout << mk_pp(subAst2, m) << " = " << mk_pp(subValue2, m) << std::endl; + } + ); + + if (subAst1HasValue && subAst2HasValue) { + expr_ref_vector litems1(m); + if (n1 != n2) { + litems1.push_back(ctx.mk_eq_atom(n1, n2)); + } + if (subValue1 != subAst1) { + litems1.push_back(ctx.mk_eq_atom(subAst1, subValue1)); + } + if (subValue2 != subAst2) { + litems1.push_back(ctx.mk_eq_atom(subAst2, subValue2)); + } + + std::string subConst1 = m_strutil.get_string_constant_value(subValue1); + std::string subConst2 = m_strutil.get_string_constant_value(subValue2); + expr_ref implyR(m); + if (subConst1 == subConst2) { + // key1.first = key2.first /\ key1.second = key2.second + // ==> (containPairBoolMap[key1] = containPairBoolMap[key2]) + implyR = ctx.mk_eq_atom(contain_pair_bool_map[key1], contain_pair_bool_map[key2]); + } else if (subConst1.find(subConst2) != std::string::npos) { + // key1.first = key2.first /\ Contains(key1.second, key2.second) + // ==> (containPairBoolMap[key1] --> containPairBoolMap[key2]) + implyR = rewrite_implication(contain_pair_bool_map[key1], contain_pair_bool_map[key2]); + } else if (subConst2.find(subConst1) != std::string::npos) { + // key1.first = key2.first /\ Contains(key2.second, key1.second) + // ==> (containPairBoolMap[key2] --> containPairBoolMap[key1]) + implyR = rewrite_implication(contain_pair_bool_map[key2], contain_pair_bool_map[key1]); + } + + if (implyR) { + if (litems1.empty()) { + assert_axiom(implyR); + } else { + assert_implication(mk_and(litems1), implyR); + } + } + } else { + expr_ref_vector subAst1Eqc(m); + expr_ref_vector subAst2Eqc(m); + collect_eq_nodes(subAst1, subAst1Eqc); + collect_eq_nodes(subAst2, subAst2Eqc); + + if (subAst1Eqc.contains(subAst2)) { + // ----------------------------------------------------------- + // * key1.first = key2.first /\ key1.second = key2.second + // --> containPairBoolMap[key1] = containPairBoolMap[key2] + // ----------------------------------------------------------- + expr_ref_vector litems2(m); + if (n1 != n2) { + litems2.push_back(ctx.mk_eq_atom(n1, n2)); + } + if (subAst1 != subAst2) { + litems2.push_back(ctx.mk_eq_atom(subAst1, subAst2)); + } + expr_ref implyR(ctx.mk_eq_atom(contain_pair_bool_map[key1], contain_pair_bool_map[key2]), m); + if (litems2.empty()) { + assert_axiom(implyR); + } else { + assert_implication(mk_and(litems2), implyR); + } + } else { + // ----------------------------------------------------------- + // * key1.first = key2.first + // check eqc(key1.second) and eqc(key2.second) + // ----------------------------------------------------------- + expr_ref_vector::iterator eqItorSub1 = subAst1Eqc.begin(); + for (; eqItorSub1 != subAst1Eqc.end(); eqItorSub1++) { + expr_ref_vector::iterator eqItorSub2 = subAst2Eqc.begin(); + for (; eqItorSub2 != subAst2Eqc.end(); eqItorSub2++) { + // ------------ + // key1.first = key2.first /\ containPairBoolMap[] + // ==> (containPairBoolMap[key1] --> containPairBoolMap[key2]) + // ------------ + { + expr_ref_vector litems3(m); + if (n1 != n2) { + litems3.push_back(ctx.mk_eq_atom(n1, n2)); + } + expr * eqSubVar1 = *eqItorSub1; + if (eqSubVar1 != subAst1) { + litems3.push_back(ctx.mk_eq_atom(subAst1, eqSubVar1)); + } + expr * eqSubVar2 = *eqItorSub2; + if (eqSubVar2 != subAst2) { + litems3.push_back(ctx.mk_eq_atom(subAst2, eqSubVar2)); + } + std::pair tryKey1 = std::make_pair(eqSubVar1, eqSubVar2); + if (contain_pair_bool_map.contains(tryKey1)) { + TRACE("t_str_detail", tout << "(Contains " << mk_pp(eqSubVar1, m) << " " << mk_pp(eqSubVar2, m) << ")" << std::endl;); + litems3.push_back(contain_pair_bool_map[tryKey1]); + expr_ref implR(rewrite_implication(contain_pair_bool_map[key1], contain_pair_bool_map[key2]), m); + assert_implication(mk_and(litems3), implR); + } + } + // ------------ + // key1.first = key2.first /\ containPairBoolMap[] + // ==> (containPairBoolMap[key2] --> containPairBoolMap[key1]) + // ------------ + { + expr_ref_vector litems4(m); + if (n1 != n2) { + litems4.push_back(ctx.mk_eq_atom(n1, n2)); + } + expr * eqSubVar1 = *eqItorSub1; + if (eqSubVar1 != subAst1) { + litems4.push_back(ctx.mk_eq_atom(subAst1, eqSubVar1)); + } + expr * eqSubVar2 = *eqItorSub2; + if (eqSubVar2 != subAst2) { + litems4.push_back(ctx.mk_eq_atom(subAst2, eqSubVar2)); + } + std::pair tryKey2 = std::make_pair(eqSubVar2, eqSubVar1); + if (contain_pair_bool_map.contains(tryKey2)) { + TRACE("t_str_detail", tout << "(Contains " << mk_pp(eqSubVar2, m) << " " << mk_pp(eqSubVar1, m) << ")" << std::endl;); + litems4.push_back(contain_pair_bool_map[tryKey2]); + expr_ref implR(rewrite_implication(contain_pair_bool_map[key2], contain_pair_bool_map[key1]), m); + assert_implication(mk_and(litems4), implR); + } + } + } + } + } + } + } + // *************************** + // Case 2: Contains(..., m) /\ Contains(... , n) /\ m = n + // *************************** + else if (key1.second == n1 && key2.second == n2) { + expr * str1 = key1.first; + expr * str2 = key2.first; + bool str1HasValue = false; + bool str2HasValue = false; + expr * strVal1 = get_eqc_value(str1, str1HasValue); + expr * strVal2 = get_eqc_value(str2, str2HasValue); + + TRACE("t_str_detail", + tout << "(Contains " << mk_pp(str1, m) << " " << mk_pp(n1, m) << ")" << std::endl; + tout << "(Contains " << mk_pp(str2, m) << " " << mk_pp(n2, m) << ")" << std::endl; + if (str1 != strVal1) { + tout << mk_pp(str1, m) << " = " << mk_pp(strVal1, m) << std::endl; + } + if (str2 != strVal2) { + tout << mk_pp(str2, m) << " = " << mk_pp(strVal2, m) << std::endl; + } + ); + + if (str1HasValue && str2HasValue) { + expr_ref_vector litems1(m); + if (n1 != n2) { + litems1.push_back(ctx.mk_eq_atom(n1, n2)); + } + if (strVal1 != str1) { + litems1.push_back(ctx.mk_eq_atom(str1, strVal1)); + } + if (strVal2 != str2) { + litems1.push_back(ctx.mk_eq_atom(str2, strVal2)); + } + + std::string const1 = m_strutil.get_string_constant_value(strVal1); + std::string const2 = m_strutil.get_string_constant_value(strVal2); + expr_ref implyR(m); + + if (const1 == const2) { + // key1.second = key2.second /\ key1.first = key2.first + // ==> (containPairBoolMap[key1] = containPairBoolMap[key2]) + implyR = ctx.mk_eq_atom(contain_pair_bool_map[key1], contain_pair_bool_map[key2]); + } else if (const1.find(const2) != std::string::npos) { + // key1.second = key2.second /\ Contains(key1.first, key2.first) + // ==> (containPairBoolMap[key2] --> containPairBoolMap[key1]) + implyR = rewrite_implication(contain_pair_bool_map[key2], contain_pair_bool_map[key1]); + } else if (const2.find(const1) != std::string::npos) { + // key1.first = key2.first /\ Contains(key2.first, key1.first) + // ==> (containPairBoolMap[key1] --> containPairBoolMap[key2]) + implyR = rewrite_implication(contain_pair_bool_map[key1], contain_pair_bool_map[key2]); + } + + if (implyR) { + if (litems1.size() == 0) { + assert_axiom(implyR); + } else { + assert_implication(mk_and(litems1), implyR); + } + } + } + + else { + expr_ref_vector str1Eqc(m); + expr_ref_vector str2Eqc(m); + collect_eq_nodes(str1, str1Eqc); + collect_eq_nodes(str2, str2Eqc); + + if (str1Eqc.contains(str2)) { + // ----------------------------------------------------------- + // * key1.first = key2.first /\ key1.second = key2.second + // --> containPairBoolMap[key1] = containPairBoolMap[key2] + // ----------------------------------------------------------- + expr_ref_vector litems2(m); + if (n1 != n2) { + litems2.push_back(ctx.mk_eq_atom(n1, n2)); + } + if (str1 != str2) { + litems2.push_back(ctx.mk_eq_atom(str1, str2)); + } + expr_ref implyR(ctx.mk_eq_atom(contain_pair_bool_map[key1], contain_pair_bool_map[key2]), m); + if (litems2.empty()) { + assert_axiom(implyR); + } else { + assert_implication(mk_and(litems2), implyR); + } + } else { + // ----------------------------------------------------------- + // * key1.second = key2.second + // check eqc(key1.first) and eqc(key2.first) + // ----------------------------------------------------------- + expr_ref_vector::iterator eqItorStr1 = str1Eqc.begin(); + for (; eqItorStr1 != str1Eqc.end(); eqItorStr1++) { + expr_ref_vector::iterator eqItorStr2 = str2Eqc.begin(); + for (; eqItorStr2 != str2Eqc.end(); eqItorStr2++) { + { + expr_ref_vector litems3(m); + if (n1 != n2) { + litems3.push_back(ctx.mk_eq_atom(n1, n2)); + } + expr * eqStrVar1 = *eqItorStr1; + if (eqStrVar1 != str1) { + litems3.push_back(ctx.mk_eq_atom(str1, eqStrVar1)); + } + expr * eqStrVar2 = *eqItorStr2; + if (eqStrVar2 != str2) { + litems3.push_back(ctx.mk_eq_atom(str2, eqStrVar2)); + } + std::pair tryKey1 = std::make_pair(eqStrVar1, eqStrVar2); + if (contain_pair_bool_map.contains(tryKey1)) { + TRACE("t_str_detail", tout << "(Contains " << mk_pp(eqStrVar1, m) << " " << mk_pp(eqStrVar2, m) << ")" << std::endl;); + litems3.push_back(contain_pair_bool_map[tryKey1]); + + // ------------ + // key1.second = key2.second /\ containPairBoolMap[] + // ==> (containPairBoolMap[key2] --> containPairBoolMap[key1]) + // ------------ + expr_ref implR(rewrite_implication(contain_pair_bool_map[key2], contain_pair_bool_map[key1]), m); + assert_implication(mk_and(litems3), implR); + } + } + + { + expr_ref_vector litems4(m); + if (n1 != n2) { + litems4.push_back(ctx.mk_eq_atom(n1, n2)); + } + expr * eqStrVar1 = *eqItorStr1; + if (eqStrVar1 != str1) { + litems4.push_back(ctx.mk_eq_atom(str1, eqStrVar1)); + } + expr *eqStrVar2 = *eqItorStr2; + if (eqStrVar2 != str2) { + litems4.push_back(ctx.mk_eq_atom(str2, eqStrVar2)); + } + std::pair tryKey2 = std::make_pair(eqStrVar2, eqStrVar1); + + if (contain_pair_bool_map.contains(tryKey2)) { + TRACE("t_str_detail", tout << "(Contains " << mk_pp(eqStrVar2, m) << " " << mk_pp(eqStrVar1, m) << ")" << std::endl;); + litems4.push_back(contain_pair_bool_map[tryKey2]); + // ------------ + // key1.first = key2.first /\ containPairBoolMap[] + // ==> (containPairBoolMap[key1] --> containPairBoolMap[key2]) + // ------------ + expr_ref implR(ctx.mk_eq_atom(contain_pair_bool_map[key1], contain_pair_bool_map[key2]), m); + assert_implication(mk_and(litems4), implR); + } + } + } + } + } + } + + } + } + + if (n1 == n2) { + break; + } + } + } // (in_contain_idx_map(n1) && in_contain_idx_map(n2)) } void theory_str::check_contain_in_new_eq(expr * n1, expr * n2) { @@ -4444,9 +4790,7 @@ void theory_str::check_contain_in_new_eq(expr * n1, expr * n2) { expr_ref_vector::iterator varItor2 = varItor1; for (; varItor2 != willEqClass.end(); ++varItor2) { expr * varAst2 = *varItor2; - // for testing purposes - TRACE("t_str", tout << "WARNING: some Contains checks disabled!" << std::endl;); - // check_contain_by_eq_nodes(varAst1, varAst2); + check_contain_by_eq_nodes(varAst1, varAst2); } } } diff --git a/src/smt/theory_str.h b/src/smt/theory_str.h index 476519e5c..e5fd25894 100644 --- a/src/smt/theory_str.h +++ b/src/smt/theory_str.h @@ -49,6 +49,26 @@ namespace smt { virtual void register_value(expr * n) { /* Ignore */ } }; + // rather than modify obj_pair_map I inherit from it and add my own helper methods + class theory_str_contain_pair_bool_map_t : public obj_pair_map { + public: + expr * operator[](std::pair key) const { + expr * value; + bool found = this->find(key.first, key.second, value); + if (found) { + return value; + } else { + TRACE("t_str", tout << "WARNING: lookup miss in contain_pair_bool_map!" << std::endl;); + return NULL; + } + } + + bool contains(std::pair key) const { + expr * unused; + return this->find(key.first, key.second, unused); + } + }; + class theory_str : public theory { struct T_cut { @@ -191,7 +211,10 @@ namespace smt { std::map unroll_var_map; std::map, expr*> concat_eq_unroll_ast_map; - expr_ref_vector contains_map; // was containPairBoolMap in Z3str2 + expr_ref_vector contains_map; + + theory_str_contain_pair_bool_map_t contain_pair_bool_map; + obj_map > contain_pair_idx_map; char * char_set; std::map charSetLookupTable; @@ -200,6 +223,7 @@ namespace smt { protected: void assert_axiom(expr * e); void assert_implication(expr * premise, expr * conclusion); + expr * rewrite_implication(expr * premise, expr * conclusion); app * mk_strlen(expr * e); expr * mk_concat(expr * n1, expr * n2); @@ -313,6 +337,7 @@ namespace smt { void check_contain_by_eqc_val(expr * varNode, expr * constNode); void check_contain_by_substr(expr * varNode, expr_ref_vector & willEqClass); void check_contain_by_eq_nodes(expr * n1, expr * n2); + bool in_contain_idx_map(expr * n); void get_nodes_in_concat(expr * node, ptr_vector & nodeList); expr * simplify_concat(expr * node); diff --git a/src/util/obj_pair_set.h b/src/util/obj_pair_set.h index 29139a51d..c4212977c 100644 --- a/src/util/obj_pair_set.h +++ b/src/util/obj_pair_set.h @@ -46,6 +46,11 @@ public: bool contains(obj_pair const & p) const { return m_set.contains(p); } void reset() { m_set.reset(); } bool empty() const { return m_set.empty(); } + + typedef typename chashtable::iterator iterator; + + iterator begin() { return m_set.begin(); } + iterator end() { return m_set.end(); } }; #endif