diff --git a/src/duality/duality.h b/src/duality/duality.h index 166e8ef0d..158ad84d9 100644 --- a/src/duality/duality.h +++ b/src/duality/duality.h @@ -277,6 +277,7 @@ namespace Duality { public: std::list edges; std::list nodes; + std::list constraints; }; @@ -286,6 +287,8 @@ namespace Duality { literals dualLabels; std::list stack; std::vector axioms; // only saved here for printing purposes + solver aux_solver; + public: @@ -296,7 +299,7 @@ namespace Duality { inherit the axioms. */ - RPFP(LogicSolver *_ls) : Z3User(*(_ls->ctx), *(_ls->slvr)), dualModel(*(_ls->ctx)) + RPFP(LogicSolver *_ls) : Z3User(*(_ls->ctx), *(_ls->slvr)), dualModel(*(_ls->ctx)), aux_solver(*(_ls->ctx)) { ls = _ls; nodeCount = 0; @@ -351,10 +354,10 @@ namespace Duality { bool SubsetEq(const Transformer &other){ Term t = owner->SubstParams(other.IndParams,IndParams,other.Formula); expr test = Formula && !t; - owner->slvr.push(); - owner->slvr.add(test); - check_result res = owner->slvr.check(); - owner->slvr.pop(1); + owner->aux_solver.push(); + owner->aux_solver.add(test); + check_result res = owner->aux_solver.check(); + owner->aux_solver.pop(1); return res == unsat; } @@ -444,6 +447,19 @@ namespace Duality { return n; } + /** Delete a node. You can only do this if not connected to any edges.*/ + void DeleteNode(Node *node){ + if(node->Outgoing || !node->Incoming.empty()) + throw "cannot delete RPFP node"; + for(std::vector::iterator it = nodes.end(), en = nodes.begin(); it != en;){ + if(*(--it) == node){ + nodes.erase(it); + break; + } + } + delete node; + } + /** This class represents a hyper-edge in the RPFP graph */ class Edge @@ -460,6 +476,7 @@ namespace Duality { hash_map varMap; Edge *map; Term labeled; + std::vector constraints; Edge(Node *_Parent, const Transformer &_F, const std::vector &_Children, RPFP *_owner, int _number) : F(_F), Parent(_Parent), Children(_Children), dual(expr(_owner->ctx)) { @@ -480,6 +497,29 @@ namespace Duality { return e; } + + /** Delete a hyper-edge and unlink it from any nodes. */ + void DeleteEdge(Edge *edge){ + if(edge->Parent) + edge->Parent->Outgoing = 0; + for(unsigned int i = 0; i < edge->Children.size(); i++){ + std::vector &ic = edge->Children[i]->Incoming; + for(std::vector::iterator it = ic.begin(), en = ic.end(); it != en; ++it){ + if(*it == edge){ + ic.erase(it); + break; + } + } + } + for(std::vector::iterator it = edges.end(), en = edges.begin(); it != en;){ + if(*(--it) == edge){ + edges.erase(it); + break; + } + } + delete edge; + } + /** Create an edge that lower-bounds its parent. */ Edge *CreateLowerBoundEdge(Node *_Parent) { @@ -494,13 +534,25 @@ namespace Duality { void AssertEdge(Edge *e, int persist = 0, bool with_children = false, bool underapprox = false); - + /* Constrain an edge by the annotation of one of its children. */ + + void ConstrainParent(Edge *parent, Node *child); + /** For incremental solving, asserts the negation of the upper bound associated * with a node. * */ void AssertNode(Node *n); + /** Assert a constraint on an edge in the SMT context. + */ + void ConstrainEdge(Edge *e, const Term &t); + + /** Fix the truth values of atomic propositions in the given + edge to their values in the current assignment. */ + void FixCurrentState(Edge *root); + + /** Declare a constant in the background theory. */ void DeclareConstant(const FuncDecl &f); @@ -592,6 +644,9 @@ namespace Duality { Term ComputeUnderapprox(Node *root, int persist); + /** Try to strengthen the annotation of a node by removing disjuncts. */ + void Generalize(Node *node); + /** Push a scope. Assertions made after Push can be undone by Pop. */ void Push(); @@ -803,7 +858,15 @@ namespace Duality { Term SubstBound(hash_map &subst, const Term &t); + void ConstrainEdgeLocalized(Edge *e, const Term &t); + void GreedyReduce(solver &s, std::vector &conjuncts); + + void NegateLits(std::vector &lits); + + expr SimplifyOr(std::vector &lits); + + void SetAnnotation(Node *root, const expr &t); }; /** RPFP solver base class. */ diff --git a/src/duality/duality_rpfp.cpp b/src/duality/duality_rpfp.cpp index fe5ad8672..17e3de6bb 100644 --- a/src/duality/duality_rpfp.cpp +++ b/src/duality/duality_rpfp.cpp @@ -283,7 +283,10 @@ namespace Duality { children[i] = ToTermTree(e->Children[i]); // Term top = ReducedDualEdge(e); Term top = e->dual.null() ? ctx.bool_val(true) : e->dual; - return new TermTree(top, children); + TermTree *res = new TermTree(top, children); + for(unsigned i = 0; i < e->constraints.size(); i++) + res->addTerm(e->constraints[i]); + return res; } TermTree *RPFP::GetGoalTree(Node *root){ @@ -375,6 +378,19 @@ namespace Duality { x = x && y; } + void RPFP::SetAnnotation(Node *root, const expr &t){ + hash_map memo; + Term b; + std::vector v; + RedVars(root, b, v); + memo[b] = ctx.bool_val(true); + for (unsigned i = 0; i < v.size(); i++) + memo[v[i]] = root->Annotation.IndParams[i]; + Term annot = SubstRec(memo, t); + // Strengthen(ref root.Annotation.Formula, annot); + root->Annotation.Formula = annot; + } + void RPFP::DecodeTree(Node *root, TermTree *interp, int persist) { std::vector &ic = interp->getChildren(); @@ -384,16 +400,7 @@ namespace Duality { for (unsigned i = 0; i < nc.size(); i++) DecodeTree(nc[i], ic[i], persist); } - hash_map memo; - Term b; - std::vector v; - RedVars(root, b, v); - memo[b] = ctx.bool_val(true); - for (unsigned i = 0; i < v.size(); i++) - memo[v[i]] = root->Annotation.IndParams[i]; - Term annot = SubstRec(memo, interp->getTerm()); - // Strengthen(ref root.Annotation.Formula, annot); - root->Annotation.Formula = annot; + SetAnnotation(root,interp->getTerm()); #if 0 if(persist != 0) Z3_persist_ast(ctx,root->Annotation.Formula,persist); @@ -511,6 +518,10 @@ namespace Duality { timer_stop("solver add"); } + void RPFP::ConstrainParent(Edge *parent, Node *child){ + ConstrainEdgeLocalized(parent,GetAnnotation(child)); + } + /** For incremental solving, asserts the negation of the upper bound associated * with a node. @@ -526,6 +537,24 @@ namespace Duality { } } + /** Assert a constraint on an edge in the SMT context. + */ + + void RPFP::ConstrainEdge(Edge *e, const Term &t) + { + Term tl = Localize(e, t); + ConstrainEdgeLocalized(e,tl); + } + + void RPFP::ConstrainEdgeLocalized(Edge *e, const Term &tl) + { + e->constraints.push_back(tl); + stack.back().constraints.push_back(e); + slvr.add(tl); + } + + + /** Declare a constant in the background theory. */ void RPFP::DeclareConstant(const FuncDecl &f){ @@ -1064,7 +1093,7 @@ namespace Duality { } } /* Unreachable! */ - throw "error in RPFP::ImplicantRed"; + std::cerr << "error in RPFP::ImplicantRed"; goto done; } else if(k == Not) { @@ -1671,6 +1700,17 @@ namespace Duality { return eu; } + void RPFP::FixCurrentState(Edge *edge){ + hash_set dont_cares; + resolve_ite_memo.clear(); + timer_start("UnderapproxFormula"); + Term dual = edge->dual.null() ? ctx.bool_val(true) : edge->dual; + Term eu = UnderapproxFormula(dual,dont_cares); + timer_stop("UnderapproxFormula"); + ConstrainEdgeLocalized(edge,eu); + } + + RPFP::Term RPFP::ModelValueAsConstraint(const Term &t){ if(t.is_array()){ @@ -1714,6 +1754,69 @@ namespace Duality { res = CreateRelation(p->Annotation.IndParams,funder); } + void RPFP::GreedyReduce(solver &s, std::vector &conjuncts){ + // verify + s.push(); + expr conj = ctx.make(And,conjuncts); + s.add(conj); + check_result res = s.check(); + s.pop(1); + if(res != unsat) + throw "should be unsat"; + + for(unsigned i = 0; i < conjuncts.size(); ){ + std::swap(conjuncts[i],conjuncts.back()); + expr save = conjuncts.back(); + conjuncts.pop_back(); + s.push(); + expr conj = ctx.make(And,conjuncts); + s.add(conj); + check_result res = s.check(); + s.pop(1); + if(res != unsat){ + conjuncts.push_back(save); + std::swap(conjuncts[i],conjuncts.back()); + i++; + } + } + } + + void RPFP::NegateLits(std::vector &lits){ + for(unsigned i = 0; i < lits.size(); i++){ + expr &f = lits[i]; + if(f.is_app() && f.decl().get_decl_kind() == Not) + f = f.arg(0); + else + f = !f; + } + } + + expr RPFP::SimplifyOr(std::vector &lits){ + if(lits.size() == 0) + return ctx.bool_val(false); + if(lits.size() == 1) + return lits[0]; + return ctx.make(Or,lits); + } + + void RPFP::Generalize(Node *node){ + std::vector conjuncts; + expr fmla = GetAnnotation(node); + CollectConjuncts(fmla,conjuncts,false); + // try to remove conjuncts one at a tme + aux_solver.push(); + Edge *edge = node->Outgoing; + if(!edge->dual.null()) + aux_solver.add(edge->dual); + for(unsigned i = 0; i < edge->constraints.size(); i++){ + expr tl = edge->constraints[i]; + aux_solver.add(tl); + } + GreedyReduce(aux_solver,conjuncts); + aux_solver.pop(1); + NegateLits(conjuncts); + SetAnnotation(node,SimplifyOr(conjuncts)); + } /** Push a scope. Assertions made after Push can be undone by Pop. */ @@ -1735,6 +1838,8 @@ namespace Duality { (*it)->dual = expr(ctx,NULL); for(std::list::iterator it = back.nodes.begin(), en = back.nodes.end(); it != en; ++it) (*it)->dual = expr(ctx,NULL); + for(std::list::iterator it = back.constraints.begin(), en = back.constraints.end(); it != en; ++it) + (*it)->constraints.pop_back(); stack.pop_back(); } } diff --git a/src/duality/duality_solver.cpp b/src/duality/duality_solver.cpp index 0043998d2..afa1c7683 100644 --- a/src/duality/duality_solver.cpp +++ b/src/duality/duality_solver.cpp @@ -1270,18 +1270,24 @@ namespace Duality { } } + bool UpdateNodeToNode(Node *node, Node *top){ + if(!node->Annotation.SubsetEq(top->Annotation)){ + reporter->Update(node,top->Annotation); + indset->Update(node,top->Annotation); + updated_nodes.insert(node->map); + node->Annotation.IntersectWith(top->Annotation); + return true; + } + return false; + } + /** Update the unwinding solution, using an interpolant for the derivation tree. */ void UpdateWithInterpolant(Node *node, RPFP *tree, Node *top){ if(top->Outgoing) for(unsigned i = 0; i < top->Outgoing->Children.size(); i++) UpdateWithInterpolant(node->Outgoing->Children[i],tree,top->Outgoing->Children[i]); - if(!node->Annotation.SubsetEq(top->Annotation)){ - reporter->Update(node,top->Annotation); - indset->Update(node,top->Annotation); - updated_nodes.insert(node->map); - node->Annotation.IntersectWith(top->Annotation); - } + UpdateNodeToNode(node, top); heuristic->Update(node); } @@ -1305,7 +1311,8 @@ namespace Duality { if(node->Bound.IsFull()) return true; reporter->Bound(node); int start_decs = rpfp->CumulativeDecisions(); - DerivationTree dt(this,unwinding,reporter,heuristic,FullExpand); + DerivationTree *dtp = new DerivationTreeSlow(this,unwinding,reporter,heuristic,FullExpand); + DerivationTree &dt = *dtp; bool res = dt.Derive(unwinding,node,UseUnderapprox); int end_decs = rpfp->CumulativeDecisions(); // std::cout << "decisions: " << (end_decs - start_decs) << std::endl; @@ -1321,6 +1328,7 @@ namespace Duality { UpdateWithInterpolant(node,dt.tree,dt.top); delete dt.tree; } + delete dtp; return !res; } @@ -1491,7 +1499,7 @@ namespace Duality { return res != unsat; } - bool Build(){ + virtual bool Build(){ #ifdef EFFORT_BOUNDED_STRAT start_decs = tree->CumulativeDecisions(); #endif @@ -1545,7 +1553,7 @@ namespace Duality { } } - void ExpandNode(RPFP::Node *p){ + virtual void ExpandNode(RPFP::Node *p){ // tree->RemoveEdge(p->Outgoing); Edge *edge = duality->GetNodeOutgoing(p->map,last_decs); std::vector &cs = edge->Children; @@ -1573,6 +1581,7 @@ namespace Duality { } #else #if 0 + void ExpansionChoices(std::set &best){ std::vector unused_set, used_set; std::set choices; @@ -1668,7 +1677,7 @@ namespace Duality { #endif #endif - bool ExpandSomeNodes(bool high_priority = false){ + bool ExpandSomeNodes(bool high_priority = false, int max = INT_MAX){ #ifdef EFFORT_BOUNDED_STRAT last_decs = tree->CumulativeDecisions() - start_decs; #endif @@ -1679,17 +1688,194 @@ namespace Duality { timer_stop("ExpansionChoices"); std::list leaves_copy = leaves; // copy so can modify orig leaves.clear(); + int count = 0; for(std::list::iterator it = leaves_copy.begin(), en = leaves_copy.end(); it != en; ++it){ - if(choices.find(*it) != choices.end()) + if(choices.find(*it) != choices.end() && count < max){ + count++; ExpandNode(*it); + } else leaves.push_back(*it); } timer_stop("ExpandSomeNodes"); return !choices.empty(); } + void RemoveExpansion(RPFP::Node *p){ + Edge *edge = p->Outgoing; + Node *parent = edge->Parent; + std::vector cs = edge->Children; + tree->DeleteEdge(edge); + for(unsigned i = 0; i < cs.size(); i++) + tree->DeleteNode(cs[i]); + leaves.push_back(parent); + } }; + class DerivationTreeSlow : public DerivationTree { + public: + + struct stack_entry { + unsigned level; // SMT solver stack level + std::vector expansions; + }; + + std::vector stack; + + hash_map updates; + + DerivationTreeSlow(Duality *_duality, RPFP *rpfp, Reporter *_reporter, Heuristic *_heuristic, bool _full_expand) + : DerivationTree(_duality, rpfp, _reporter, _heuristic, _full_expand) { + stack.push_back(stack_entry()); + } + + virtual bool Build(){ + + stack.back().level = tree->slvr.get_scope_level(); + + while (true) + { + lbool res; + + unsigned slvr_level = tree->slvr.get_scope_level(); + if(slvr_level != stack.back().level) + throw "stacks out of sync!"; + + res = tree->Solve(top, 1); // incremental solve, keep interpolants for one pop + + if (res == l_false) { + if (stack.empty()) // should never happen + return false; + + std::vector &expansions = stack.back().expansions; + int update_count = 0; + for(unsigned i = 0; i < expansions.size(); i++){ + tree->Generalize(expansions[i]); + if(RecordUpdate(expansions[i])) + update_count++; + } + if(update_count == 0) + std::cout << "backtracked without learning\n"; + tree->Pop(1); + hash_set leaves_to_remove; + for(unsigned i = 0; i < expansions.size(); i++){ + Node *node = expansions[i]; + // if(node != top) + // tree->ConstrainParent(node->Incoming[0],node); + std::vector &cs = node->Outgoing->Children; + for(unsigned i = 0; i < cs.size(); i++){ + leaves_to_remove.insert(cs[i]); + UnmapNode(cs[i]); + if(std::find(updated_nodes.begin(),updated_nodes.end(),cs[i]) != updated_nodes.end()) + throw "help!"; + } + RemoveExpansion(node); + } + RemoveLeaves(leaves_to_remove); + stack.pop_back(); + HandleUpdatedNodes(); + if(stack.size() == 1) + return false; + } + else { + tree->Push(); + std::vector &expansions = stack.back().expansions; + for(unsigned i = 0; i < expansions.size(); i++){ + tree->FixCurrentState(expansions[i]->Outgoing); + } + if(tree->slvr.check() == unsat) + throw "help!"; + stack.push_back(stack_entry()); + stack.back().level = tree->slvr.get_scope_level(); + if(ExpandSomeNodes(false,1)){ + continue; + } + while(stack.size() > 1){ + tree->Pop(1); + stack.pop_back(); + } + return true; + } + } + } + + void RemoveLeaves(hash_set &leaves_to_remove){ + std::list leaves_copy; + leaves_copy.swap(leaves); + for(std::list::iterator it = leaves_copy.begin(), en = leaves_copy.end(); it != en; ++it){ + if(leaves_to_remove.find(*it) == leaves_to_remove.end()) + leaves.push_back(*it); + } + } + + hash_map > node_map; + std::list updated_nodes; + + virtual void ExpandNode(RPFP::Node *p){ + stack.back().expansions.push_back(p); + DerivationTree::ExpandNode(p); + std::vector &new_nodes = p->Outgoing->Children; + for(unsigned i = 0; i < new_nodes.size(); i++){ + Node *n = new_nodes[i]; + node_map[n->map].push_back(n); + } + } + + bool RecordUpdate(Node *node){ + bool res = duality->UpdateNodeToNode(node->map,node); + if(res){ + std::vector to_update = node_map[node->map]; + for(unsigned i = 0; i < to_update.size(); i++){ + Node *node2 = to_update[i]; + // maintain invariant that no nodes on updated list are created at current stack level + if(node2 == node || !(node->Incoming.size() > 0 && AtCurrentStackLevel(node2->Incoming[0]->Parent))){ + updated_nodes.push_back(node2); + if(node2 != node) + node2->Annotation = node->Annotation; + } + } + } + return res; + } + + void HandleUpdatedNodes(){ + for(std::list::iterator it = updated_nodes.begin(), en = updated_nodes.end(); it != en;){ + Node *node = *it; + node->Annotation = node->map->Annotation; + if(node->Incoming.size() > 0) + tree->ConstrainParent(node->Incoming[0],node); + if(AtCurrentStackLevel(node->Incoming[0]->Parent)){ + std::list::iterator victim = it; + ++it; + updated_nodes.erase(victim); + } + else + ++it; + } + } + + bool AtCurrentStackLevel(Node *node){ + std::vector vec = stack.back().expansions; + for(unsigned i = 0; i < vec.size(); i++) + if(vec[i] == node) + return true; + return false; + } + + void UnmapNode(Node *node){ + std::vector &vec = node_map[node->map]; + for(unsigned i = 0; i < vec.size(); i++){ + if(vec[i] == node){ + std::swap(vec[i],vec.back()); + vec.pop_back(); + return; + } + } + throw "can't unmap node"; + } + + }; + + class Covering { struct cover_info { diff --git a/src/duality/duality_wrapper.cpp b/src/duality/duality_wrapper.cpp index fef70e031..dd64052a0 100644 --- a/src/duality/duality_wrapper.cpp +++ b/src/duality/duality_wrapper.cpp @@ -425,15 +425,18 @@ expr context::make_quant(decl_kind op, const std::vector &_sorts, const st static int linearize_assumptions(int num, TermTree *assumptions, - std::vector &linear_assumptions, + std::vector > &linear_assumptions, std::vector &parents){ for(unsigned i = 0; i < assumptions->getChildren().size(); i++) num = linearize_assumptions(num, assumptions->getChildren()[i], linear_assumptions, parents); - linear_assumptions[num] = assumptions->getTerm(); + // linear_assumptions[num].push_back(assumptions->getTerm()); for(unsigned i = 0; i < assumptions->getChildren().size(); i++) parents[assumptions->getChildren()[i]->getNumber()] = num; parents[num] = SHRT_MAX; // in case we have no parent - linear_assumptions[num] = assumptions->getTerm(); + linear_assumptions[num].push_back(assumptions->getTerm()); + std::vector &ts = assumptions->getTerms(); + for(unsigned i = 0; i < ts.size(); i++) + linear_assumptions[num].push_back(ts[i]); return num + 1; } @@ -462,14 +465,15 @@ expr context::make_quant(decl_kind op, const std::vector &_sorts, const st { int size = assumptions->number(0); - std::vector linear_assumptions(size); + std::vector > linear_assumptions(size); std::vector parents(size); linearize_assumptions(0,assumptions,linear_assumptions,parents); ptr_vector< ::ast> _interpolants(size-1); - ptr_vector< ::ast>_assumptions(size); + vector >_assumptions(size); for(int i = 0; i < size; i++) - _assumptions[i] = linear_assumptions[i]; + for(unsigned j = 0; j < linear_assumptions[i].size(); j++) + _assumptions[i].push_back(linear_assumptions[i][j]); ::vector _parents; _parents.resize(parents.size()); for(unsigned i = 0; i < parents.size(); i++) _parents[i] = parents[i]; @@ -481,7 +485,8 @@ expr context::make_quant(decl_kind op, const std::vector &_sorts, const st if(!incremental){ for(unsigned i = 0; i < linear_assumptions.size(); i++) - add(linear_assumptions[i]); + for(unsigned j = 0; j < linear_assumptions[i].size(); j++) + add(linear_assumptions[i][j]); } check_result res = check(); diff --git a/src/duality/duality_wrapper.h b/src/duality/duality_wrapper.h index 21ed45479..291ddfbcf 100755 --- a/src/duality/duality_wrapper.h +++ b/src/duality/duality_wrapper.h @@ -867,6 +867,9 @@ namespace Duality { if(m_solver) m_solver->cancel(); } + + unsigned get_scope_level(){return m_solver->get_scope_level();} + }; #if 0 @@ -1199,6 +1202,8 @@ namespace Duality { inline expr getTerm(){return term;} + inline std::vector &getTerms(){return terms;} + inline std::vector &getChildren(){ return children; } @@ -1215,6 +1220,8 @@ namespace Duality { } inline void setTerm(expr t){term = t;} + + inline void addTerm(expr t){terms.push_back(t);} inline void setChildren(const std::vector & _children){ children = _children; @@ -1231,6 +1238,7 @@ namespace Duality { private: expr term; + std::vector terms; std::vector children; int num; }; diff --git a/src/interp/iz3interp.cpp b/src/interp/iz3interp.cpp index 92afc5723..7ef7e6aa2 100755 --- a/src/interp/iz3interp.cpp +++ b/src/interp/iz3interp.cpp @@ -75,15 +75,16 @@ struct frame_reducer : public iz3mgr { } } - void get_frames(const std::vector &z3_preds, + void get_frames(const std::vector >&z3_preds, const std::vector &orig_parents, - std::vector &assertions, + std::vector >&assertions, std::vector &parents, z3pf proof){ frames = z3_preds.size(); orig_parents_copy = orig_parents; for(unsigned i = 0; i < z3_preds.size(); i++) - frame_map[z3_preds[i]] = i; + for(unsigned j = 0; j < z3_preds[i].size(); j++) + frame_map[z3_preds[i][j]] = i; used_frames.resize(frames); hash_set memo; get_proof_assumptions_rec(proof,memo,used_frames); @@ -202,7 +203,7 @@ public: } void proof_to_interpolant(z3pf proof, - const std::vector &cnsts, + const std::vector > &cnsts, const std::vector &parents, std::vector &interps, const std::vector &theory, @@ -216,7 +217,7 @@ public: // get rid of frames not used in proof - std::vector cnsts_vec; + std::vector > cnsts_vec; std::vector parents_vec; frame_reducer fr(*this); fr.get_frames(cnsts,parents,cnsts_vec,parents_vec,proof); @@ -235,10 +236,7 @@ public: #define BINARY_INTERPOLATION #ifndef BINARY_INTERPOLATION // create a translator - std::vector > cnsts_vec_vec(cnsts_vec.size()); - for(unsigned i = 0; i < cnsts_vec.size(); i++) - cnsts_vec_vec[i].push_back(cnsts_vec[i]); - iz3translation *tr = iz3translation::create(*this,sp,cnsts_vec_vec,parents_vec,theory); + iz3translation *tr = iz3translation::create(*this,sp,cnsts_vec,parents_vec,theory); tr_killer.set(tr); // set the translation options, if needed @@ -273,7 +271,8 @@ public: std::vector > cnsts_vec_vec(2); for(unsigned j = 0; j < cnsts_vec.size(); j++){ bool is_A = the_base.in_range(j,rng); - cnsts_vec_vec[is_A ? 0 : 1].push_back(cnsts_vec[j]); + for(unsigned k = 0; k < cnsts_vec[j].size(); k++) + cnsts_vec_vec[is_A ? 0 : 1].push_back(cnsts_vec[j][k]); } killme tr_killer_i; @@ -308,6 +307,19 @@ public: } + void proof_to_interpolant(z3pf proof, + std::vector &cnsts, + const std::vector &parents, + std::vector &interps, + const std::vector &theory, + interpolation_options_struct *options = 0 + ){ + std::vector > cnsts_vec(cnsts.size()); + for(unsigned i = 0; i < cnsts.size(); i++) + cnsts_vec[i].push_back(cnsts[i]); + proof_to_interpolant(proof,cnsts_vec,parents,interps,theory,options); + } + // same as above, but represents the tree using an ast void proof_to_interpolant(const z3pf &proof, @@ -322,7 +334,6 @@ public: to_parents_vec_representation(_cnsts, tree, cnsts, parents, theory, pos_map); - //use the parents vector representation to compute interpolant proof_to_interpolant(proof,cnsts,parents,interps,theory,options); @@ -397,6 +408,35 @@ void iz3interpolate(ast_manager &_m_manager, interps[i] = itp.uncook(_interps[i]); } +void iz3interpolate(ast_manager &_m_manager, + ast *proof, + const ::vector > &cnsts, + const ::vector &parents, + ptr_vector &interps, + const ptr_vector &theory, + interpolation_options_struct * options) +{ + iz3interp itp(_m_manager); + if(options) + options->apply(itp); + std::vector > _cnsts(cnsts.size()); + std::vector _parents(parents.size()); + std::vector _interps; + std::vector _theory(theory.size()); + for(unsigned i = 0; i < cnsts.size(); i++) + for(unsigned j = 0; j < cnsts[i].size(); j++) + _cnsts[i].push_back(itp.cook(cnsts[i][j])); + for(unsigned i = 0; i < parents.size(); i++) + _parents[i] = parents[i]; + for(unsigned i = 0; i < theory.size(); i++) + _theory[i] = itp.cook(theory[i]); + iz3mgr::ast _proof = itp.cook(proof); + itp.proof_to_interpolant(_proof,_cnsts,_parents,_interps,_theory,options); + interps.resize(_interps.size()); + for(unsigned i = 0; i < interps.size(); i++) + interps[i] = itp.uncook(_interps[i]); +} + void iz3interpolate(ast_manager &_m_manager, ast *proof, const ptr_vector &cnsts, diff --git a/src/interp/iz3interp.h b/src/interp/iz3interp.h index 62f967c02..52aa716c3 100644 --- a/src/interp/iz3interp.h +++ b/src/interp/iz3interp.h @@ -56,6 +56,16 @@ void iz3interpolate(ast_manager &_m_manager, const ptr_vector &theory, interpolation_options_struct * options = 0); +/* Same as above, but each constraint is a vector of formulas. */ + +void iz3interpolate(ast_manager &_m_manager, + ast *proof, + const vector > &cnsts, + const ::vector &parents, + ptr_vector &interps, + const ptr_vector &theory, + interpolation_options_struct * options = 0); + /* Compute an interpolant from a proof. This version uses the ast representation, for compatibility with the new API. */ diff --git a/src/interp/iz3mgr.cpp b/src/interp/iz3mgr.cpp index faa4a636d..24df25f4e 100644 --- a/src/interp/iz3mgr.cpp +++ b/src/interp/iz3mgr.cpp @@ -815,6 +815,22 @@ iz3mgr::ast iz3mgr::subst(ast var, ast t, ast e){ return subst(memo,var,t,e); } +iz3mgr::ast iz3mgr::subst(stl_ext::hash_map &subst_memo,ast e){ + std::pair foo(e,ast()); + std::pair::iterator,bool> bar = subst_memo.insert(foo); + ast &res = bar.first->second; + if(bar.second){ + int nargs = num_args(e); + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = subst(subst_memo,arg(e,i)); + opr f = op(e); + if(f == Equal && args[0] == args[1]) res = mk_true(); + else res = clone(e,args); + } + return res; +} + // apply a quantifier to a formula, with some optimizations // 1) bound variable does not occur -> no quantifier // 2) bound variable must be equal to some term -> substitute diff --git a/src/interp/iz3mgr.h b/src/interp/iz3mgr.h index f6c0bdf87..760feb000 100644 --- a/src/interp/iz3mgr.h +++ b/src/interp/iz3mgr.h @@ -631,6 +631,9 @@ class iz3mgr { ast subst(ast var, ast t, ast e); + // apply a substitution defined by a map + ast subst(stl_ext::hash_map &map, ast e); + // apply a quantifier to a formula, with some optimizations // 1) bound variable does not occur -> no quantifier // 2) bound variable must be equal to some term -> substitute diff --git a/src/interp/iz3proof_itp.cpp b/src/interp/iz3proof_itp.cpp index a2d05f7f9..60bdcde9a 100644 --- a/src/interp/iz3proof_itp.cpp +++ b/src/interp/iz3proof_itp.cpp @@ -118,6 +118,28 @@ class iz3proof_itp_impl : public iz3proof_itp { where t is an arbitrary term */ symb rewrite_B; + /* a normalization step is of the form (lhs=rhs) : proof, where "proof" + is a proof of lhs=rhs and lhs is a mixed term. If rhs is a mixed term + then it must have a greater index than lhs. */ + symb normal_step; + + /* A chain of normalization steps is either "true" (the null chain) + or normal_chain( ), where step is a normalization step + and tail is a normalization chain. The lhs of must have + a less term index than any lhs in the chain. Moreover, the rhs of + may not occur as the lhs of step in . If we wish to + add lhs=rhs to the beginning of and rhs=rhs' occurs in + we must apply transitivity, transforming to lhs=rhs'. */ + + symb normal_chain; + + /* If p is a proof of Q and c is a normalization chain, then normal(p,c) + is a proof of Q(c) (that is, Q with all substitutions in c performed). */ + + symb normal; + + + ast get_placeholder(ast t){ hash_map::iterator it = placeholders.find(t); @@ -521,10 +543,16 @@ class iz3proof_itp_impl : public iz3proof_itp { throw cannot_simplify(); } + bool is_normal_ineq(const ast &ineq){ + if(sym(ineq) == normal) + return is_ineq(arg(ineq,0)); + return is_ineq(ineq); + } + ast simplify_sum(std::vector &args){ ast cond = mk_true(); ast ineq = args[0]; - if(!is_ineq(ineq)) throw cannot_simplify(); + if(!is_normal_ineq(ineq)) throw cannot_simplify(); sum_cond_ineq(ineq,cond,args[1],args[2]); return my_implies(cond,ineq); } @@ -540,6 +568,8 @@ class iz3proof_itp_impl : public iz3proof_itp { } ast ineq_from_chain(const ast &chain, ast &cond){ + if(sym(chain) == normal) + throw "normalized inequalities not supported here"; if(is_rewrite_chain(chain)){ ast last = chain_last(chain); ast rest = chain_rest(chain); @@ -561,6 +591,13 @@ class iz3proof_itp_impl : public iz3proof_itp { cond = my_and(cond,arg(ineq2,0)); } else { + if(sym(ineq) == normal || sym(ineq2) == normal){ + ast Aproves = mk_true(); + sum_normal_ineq(ineq,coeff2,ineq2,Aproves,cond); + if(!is_true(Aproves)) + throw "Aproves not handled in sum_cond_ineq"; + return; + } ast the_ineq = ineq_from_chain(ineq2,cond); if(is_ineq(the_ineq)) linear_comb(ineq,coeff2,the_ineq); @@ -569,6 +606,27 @@ class iz3proof_itp_impl : public iz3proof_itp { } } + void destruct_normal(const ast &pf, ast &p, ast &n){ + if(sym(pf) == normal){ + p = arg(pf,0); + n = arg(pf,1); + } + else { + p = pf; + n = mk_true(); + } + } + + void sum_normal_ineq(ast &ineq, const ast &coeff2, const ast &ineq2, ast &Aproves, ast &Bproves){ + ast in1,in2,n1,n2; + destruct_normal(ineq,in1,n1); + destruct_normal(ineq2,in2,n2); + ast dummy; + sum_cond_ineq(in1,dummy,coeff2,in2); + n1 = merge_normal_chains(n1,n2, Aproves, Bproves); + ineq = make(normal,in1,n1); + } + bool is_ineq(const ast &ineq){ opr o = op(ineq); if(o == Not) o = op(arg(ineq,0)); @@ -577,6 +635,12 @@ class iz3proof_itp_impl : public iz3proof_itp { // divide both sides of inequality by a non-negative integer divisor ast idiv_ineq(const ast &ineq1, const ast &divisor){ + if(sym(ineq1) == normal){ + ast in1,n1; + destruct_normal(ineq1,in1,n1); + in1 = idiv_ineq(in1,divisor); + return make(normal,in1,n1); + } if(divisor == make_int(rational(1))) return ineq1; ast ineq = ineq1; @@ -649,11 +713,18 @@ class iz3proof_itp_impl : public iz3proof_itp { ast equa = sep_cond(arg(pf,0),cond); if(is_equivrel_chain(equa)){ ast lhs,rhs; eq_from_ineq(arg(neg_equality,0),lhs,rhs); // get inequality we need to prove - ast ineqs= chain_ineqs(op(arg(neg_equality,0)),LitA,equa,lhs,rhs); // chain must be from lhs to rhs - cond = my_and(cond,chain_conditions(LitA,equa)); - ast Bconds = chain_conditions(LitB,equa); - if(is_true(Bconds) && op(ineqs) != And) - return my_implies(cond,ineqs); + LitType lhst = get_term_type(lhs), rhst = get_term_type(rhs); + if(lhst != LitMixed && rhst != LitMixed){ + ast ineqs= chain_ineqs(op(arg(neg_equality,0)),LitA,equa,lhs,rhs); // chain must be from lhs to rhs + cond = my_and(cond,chain_conditions(LitA,equa)); + ast Bconds = z3_simplify(chain_conditions(LitB,equa)); + if(is_true(Bconds) && op(ineqs) != And) + return my_implies(cond,ineqs); + } + else { + ast itp = make(Leq,make_int(rational(0)),make_int(rational(0))); + return make(normal,itp,cons_normal(fix_normal(lhs,rhs,equa),mk_true())); + } } } throw cannot_simplify(); @@ -757,11 +828,57 @@ class iz3proof_itp_impl : public iz3proof_itp { chain = concat_rewrite_chain(chain,split[1]); } } - else // if not an equivalence, must be of form T <-> pred + else { // if not an equivalence, must be of form T <-> pred chain = concat_rewrite_chain(P,PeqQ); + } return chain; } + void get_subterm_normals(const ast &ineq1, const ast &ineq2, const ast &chain, ast &normals, + const ast &pos, hash_set &memo, ast &Aproves, ast &Bproves){ + opr o1 = op(ineq1); + opr o2 = op(ineq2); + if(o1 == Not || o1 == Leq || o1 == Lt || o1 == Geq || o1 == Gt || o1 == Plus || o1 == Times){ + int n = num_args(ineq1); + if(o2 != o1 || num_args(ineq2) != n) + throw "bad inequality rewriting"; + for(int i = 0; i < n; i++){ + ast new_pos = add_pos_to_end(pos,i); + get_subterm_normals(arg(ineq1,i), arg(ineq2,i), chain, normals, new_pos, memo, Aproves, Bproves); + } + } + else if(get_term_type(ineq2) == LitMixed && memo.find(ineq2) == memo.end()){ + memo.insert(ineq2); + ast sub_chain = extract_rewrites(chain,pos); + if(is_true(sub_chain)) + throw "bad inequality rewriting"; + ast new_normal = make_normal(ineq2,ineq1,reverse_chain(sub_chain)); + normals = merge_normal_chains(normals,cons_normal(new_normal,mk_true()), Aproves, Bproves); + } + } + + ast rewrite_chain_to_normal_ineq(const ast &chain, ast &Aproves, ast &Bproves){ + ast tail, pref = get_head_chain(chain,tail,false); // pref is x=y, tail is x=y -> x'=y' + ast head = chain_last(pref); + ast ineq1 = rewrite_rhs(head); + ast ineq2 = apply_rewrite_chain(ineq1,tail); + ast nc = mk_true(); + hash_set memo; + get_subterm_normals(ineq1,ineq2,tail,nc,top_pos,memo, Aproves, Bproves); + ast itp; + if(is_rewrite_side(LitA,head)){ + itp = ineq1; + ast mc = z3_simplify(chain_side_proves(LitB,pref)); + Bproves = my_and(Bproves,mc); + } + else { + itp = make(Leq,make_int(rational(0)),make_int(rational(0))); + ast mc = z3_simplify(chain_side_proves(LitA,pref)); + Aproves = my_and(Aproves,mc); + } + return make(normal,itp,nc); + } + /* Given a chain rewrite chain deriving not P and a rewrite chain deriving P, return an interpolant. */ ast contra_chain(const ast &neg_chain, const ast &pos_chain){ // equality is a special case. we use the derivation of x=y to rewrite not(x=y) to not(y=y) @@ -790,11 +907,18 @@ class iz3proof_itp_impl : public iz3proof_itp { } ast simplify_modpon(const std::vector &args){ - ast cond = mk_true(); - ast chain = simplify_modpon_fwd(args,cond); - ast Q2 = sep_cond(args[2],cond); - ast interp = is_negation_chain(chain) ? contra_chain(chain,Q2) : contra_chain(Q2,chain); - return my_implies(cond,interp); + ast Aproves = mk_true(), Bproves = mk_true(); + ast chain = simplify_modpon_fwd(args,Bproves); + ast Q2 = sep_cond(args[2],Bproves); + ast interp; + if(is_normal_ineq(Q2)){ // inequalities are special + ast nQ2 = rewrite_chain_to_normal_ineq(chain,Aproves,Bproves); + sum_cond_ineq(nQ2,Bproves,make_int(rational(1)),Q2); + interp = normalize(nQ2); + } + else + interp = is_negation_chain(chain) ? contra_chain(chain,Q2) : contra_chain(Q2,chain); + return my_and(Aproves,my_implies(Bproves,interp)); } @@ -1035,6 +1159,12 @@ class iz3proof_itp_impl : public iz3proof_itp { return make(add_pos,make_int(rational(arg)),pos); } + ast add_pos_to_end(const ast &pos, int i){ + if(pos == top_pos) + return pos_add(i,pos); + return make(add_pos,arg(pos,0),add_pos_to_end(arg(pos,1),i)); + } + /* return the argument number of position, if not top */ int pos_arg(const ast &pos){ rational r; @@ -1170,6 +1300,10 @@ class iz3proof_itp_impl : public iz3proof_itp { return make(sym(rew),pos_add(apos,arg(rew,0)),arg(rew,1),arg(rew,2)); } + ast rewrite_pos_set(const ast &pos, const ast &rew){ + return make(sym(rew),pos,arg(rew,1),arg(rew,2)); + } + ast rewrite_up(const ast &rew){ return make(sym(rew),arg(arg(rew,0),1),arg(rew,1),arg(rew,2)); } @@ -1317,6 +1451,28 @@ class iz3proof_itp_impl : public iz3proof_itp { split_chain_rec(chain,res); } + ast extract_rewrites(const ast &chain, const ast &pos){ + if(is_true(chain)) + return chain; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast new_rest = extract_rewrites(rest,pos); + ast p1 = rewrite_pos(last); + ast diff; + switch(pos_diff(p1,pos,diff)){ + case -1: { + ast new_last = rewrite_pos_set(diff, last); + return chain_cons(new_rest,new_last); + } + case 1: + if(rewrite_lhs(last) != rewrite_rhs(last)) + throw "bad rewrite chain"; + break; + default:; + } + return new_rest; + } + ast down_chain(const ast &chain){ ast split[2]; split_chain(chain,split); @@ -1381,7 +1537,7 @@ class iz3proof_itp_impl : public iz3proof_itp { // ast s = ineq_to_lhs(ineq); // ast srhs = arg(s,1); ast srhs = arg(ineq,0); - if(op(srhs) == Plus && num_args(srhs) == 2){ + if(op(srhs) == Plus && num_args(srhs) == 2 && arg(ineq,1) == make_int(rational(0))){ lhs = arg(srhs,0); rhs = arg(srhs,1); // if(op(lhs) == Times) @@ -1393,6 +1549,11 @@ class iz3proof_itp_impl : public iz3proof_itp { return; } } + if(op(ineq) == Leq){ + lhs = srhs; + rhs = arg(ineq,1); + return; + } throw "bad ineq"; } @@ -1404,7 +1565,171 @@ class iz3proof_itp_impl : public iz3proof_itp { return chain_cons(rest,last); } + ast apply_rewrite_chain(const ast &t, const ast &chain){ + if(is_true(chain)) + return t; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast mid = apply_rewrite_chain(t,rest); + ast res = subst_in_pos(mid,rewrite_pos(last),rewrite_rhs(last)); + return res; + } + ast drop_rewrites(LitType t, const ast &chain, ast &remainder){ + if(!is_true(chain)){ + ast last = chain_last(chain); + ast rest = chain_rest(chain); + if(is_rewrite_side(t,last)){ + ast res = drop_rewrites(t,rest,remainder); + remainder = chain_cons(remainder,last); + return res; + } + } + remainder = mk_true(); + return chain; + } + + // Normalization chains + + ast cons_normal(const ast &first, const ast &rest){ + return make(normal_chain,first,rest); + } + + ast normal_first(const ast &t){ + return arg(t,0); + } + + ast normal_rest(const ast &t){ + return arg(t,1); + } + + ast normal_lhs(const ast &t){ + return arg(arg(t,0),1); + } + + ast normal_rhs(const ast &t){ + return arg(arg(t,0),1); + } + + ast normal_proof(const ast &t){ + return arg(t,1); + } + + ast make_normal(const ast &lhs, const ast &rhs, const ast &proof){ + return make(normal_step,make_equiv(lhs,rhs),proof); + } + + ast fix_normal(const ast &lhs, const ast &rhs, const ast &proof){ + LitType rhst = get_term_type(rhs); + if(rhst != LitMixed || ast_id(lhs) < ast_id(rhs)) + return make_normal(lhs,rhs,proof); + else + return make_normal(rhs,lhs,reverse_chain(proof)); + } + + ast chain_side_proves(LitType side, const ast &chain){ + LitType other_side = side == LitA ? LitB : LitA; + return my_and(chain_conditions(other_side,chain),my_implies(chain_conditions(side,chain),chain_formulas(side,chain))); + } + + // Merge two normalization chains + ast merge_normal_chains_rec(const ast &chain1, const ast &chain2, hash_map &trans, ast &Aproves, ast &Bproves){ + if(is_true(chain1)) + return chain2; + if(is_true(chain2)) + return chain1; + ast f1 = normal_first(chain1); + ast f2 = normal_first(chain2); + ast lhs1 = normal_lhs(f1); + ast lhs2 = normal_lhs(f2); + int id1 = ast_id(lhs1); + int id2 = ast_id(lhs2); + if(id1 < id2) return cons_normal(f1,merge_normal_chains_rec(normal_rest(chain1),chain2,trans,Aproves,Bproves)); + if(id2 < id1) return cons_normal(f2,merge_normal_chains_rec(chain1,normal_rest(chain2),trans,Aproves,Bproves)); + ast rhs1 = normal_rhs(f1); + ast rhs2 = normal_rhs(f2); + LitType t1 = get_term_type(rhs1); + LitType t2 = get_term_type(rhs2); + int tid1 = ast_id(rhs1); + int tid2 = ast_id(rhs2); + ast pf1 = normal_proof(f1); + ast pf2 = normal_proof(f2); + ast new_normal; + if(t1 == LitMixed && (t2 != LitMixed || tid2 > tid1)){ + ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); + new_normal = f2; + trans[rhs1] = make_normal(rhs1,rhs2,new_proof); + } + else if(t2 == LitMixed && (t1 != LitMixed || tid1 > tid2)) + return merge_normal_chains_rec(chain2,chain1,trans,Aproves,Bproves); + else if(t1 == LitA && t2 == LitB){ + ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); + ast Bproof, Aproof = drop_rewrites(LitB,new_proof,Bproof); + ast mcA = chain_side_proves(LitB,Aproof); + Bproves = my_and(Bproves,mcA); + ast mcB = chain_side_proves(LitA,Bproof); + Aproves = my_and(Aproves,mcB); + ast rep = apply_rewrite_chain(rhs1,Aproof); + new_proof = concat_rewrite_chain(pf1,Aproof); + new_normal = make_normal(rhs1,rep,new_proof); + } + else if(t1 == LitA && t2 == LitB) + return merge_normal_chains_rec(chain2,chain1,trans,Aproves,Bproves); + else if(t1 == LitA) { + ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); + ast mc = chain_side_proves(LitB,new_proof); + Bproves = my_and(Bproves,mc); + new_normal = f1; // choice is arbitrary + } + else { /* t1 = t2 = LitB */ + ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); + ast mc = chain_side_proves(LitA,new_proof); + Aproves = my_and(Aproves,mc); + new_normal = f1; // choice is arbitrary + } + return cons_normal(new_normal,merge_normal_chains_rec(normal_rest(chain1),normal_rest(chain2),trans,Aproves,Bproves)); + } + + ast trans_normal_chain(const ast &chain, hash_map &trans){ + if(is_true(chain)) + return chain; + ast f = normal_first(chain); + ast r = normal_rest(chain); + ast rhs = normal_rhs(f); + hash_map::iterator it = trans.find(rhs); + ast new_normal; + if(it != trans.end()){ + const ast &f2 = it->second; + ast pf = concat_rewrite_chain(normal_proof(f),normal_proof(f2)); + new_normal = make_normal(normal_lhs(f),normal_rhs(f2),pf); + } + else + new_normal = f; + return cons_normal(new_normal,trans_normal_chain(r,trans)); + } + + ast merge_normal_chains(const ast &chain1, const ast &chain2, ast &Aproves, ast &Bproves){ + hash_map trans; + ast res = merge_normal_chains_rec(chain1,chain2,trans,Aproves,Bproves); + res = trans_normal_chain(res,trans); + return res; + } + + ast normalize(const ast &t){ + if(sym(t) != normal) + return t; + ast chain = arg(t,1); + hash_map map; + for(ast c = chain; !is_true(c); c = normal_rest(c)){ + ast first = normal_first(c); + ast lhs = normal_lhs(first); + ast rhs = normal_rhs(first); + map[lhs] = rhs; + } + ast res = subst(map,arg(t,0)); + return res; + } + /** Make an assumption node. The given clause is assumed in the given frame. */ virtual node make_assumption(int frame, const std::vector &assumption){ if(!weak){ @@ -1939,6 +2264,8 @@ class iz3proof_itp_impl : public iz3proof_itp { */ ast make_refl(const ast &e){ + if(get_term_type(e) == LitA) + return mk_false(); return mk_true(); // TODO: is this right? } @@ -2141,6 +2468,12 @@ public: m().inc_ref(rewrite_A); rewrite_B = function("@rewrite_B",3,boolboolbooldom,bool_type()); m().inc_ref(rewrite_B); + normal_step = function("@normal_step",2,boolbooldom,bool_type()); + m().inc_ref(normal_step); + normal_chain = function("@normal_chain",2,boolbooldom,bool_type()); + m().inc_ref(normal_chain); + normal = function("@normal",2,boolbooldom,bool_type()); + m().inc_ref(normal); } ~iz3proof_itp_impl(){