/*++ Copyright (c) 2011 Microsoft Corporation Module Name: iz3proof.cpp Abstract: This class defines a simple interpolating proof system. Author: Ken McMillan (kenmcmil) Revision History: --*/ #ifdef _WINDOWS #pragma warning(disable:4996) #pragma warning(disable:4800) #pragma warning(disable:4267) #pragma warning(disable:4101) #endif #include "iz3proof_itp.h" using namespace stl_ext; // #define INVARIANT_CHECKING class iz3proof_itp_impl : public iz3proof_itp { prover *pv; prover::range rng; bool weak; enum LitType {LitA,LitB,LitMixed}; hash_map placeholders; // These symbols represent deduction rules /* This symbol represents a proof by contradiction. That is, contra(p,l1 /\ ... /\ lk) takes a proof p of l1,...,lk |- false and returns a proof of |- ~l1,...,~l2 */ symb contra; /* The summation rule. The term sum(p,c,i) takes a proof p of an inequality i', an integer coefficient c and an inequality i, and yieds a proof of i' + ci. */ symb sum; /* Proof rotation. The proof term rotate(q,p) takes a proof p of: Gamma, q |- false and yields a proof of: Gamma |- ~q */ symb rotate_sum; /* Inequalities to equality. leq2eq(p, q, r) takes a proof p of ~x=y, a proof q of x <= y and a proof r of y <= x and yields a proof of false. */ symb leq2eq; /* Equality to inequality. eq2leq(p, q) takes a proof p of x=y, and a proof q ~(x <= y) and and yields a proof of false. */ symb eq2leq; /* Proof term cong(p,q) takes a proof p of x=y and a proof q of t != t and returns a proof of false. */ symb cong; /* Excluded middle. exmid(phi,p,q) takes a proof p of phi and a proof q of ~\phi and returns a proof of false. */ symb exmid; /* Symmetry. symm(p) takes a proof p of x=y and produces a proof of y=x. */ symb symm; /* Modus ponens. modpon(p,e,q) takes proofs p of P, e of P=Q and q of ~Q and returns a proof of false. */ symb modpon; /* This oprerator represents a concatenation of rewrites. The term a=b;c=d represents an A rewrite from a to b, followed by a B rewrite fron b to c, followed by an A rewrite from c to d. */ symb concat; /* This represents a lack of a proof */ ast no_proof; // This is used to represent an infinitessimal value ast epsilon; // Represents the top position of a term ast top_pos; // add_pos(i,pos) represents position pos if the ith argument symb add_pos; // rewrite proof rules /* rewrite_A(pos,cond,x=y) derives A |- cond => t[x]_p = t[y]_p where t is an arbitrary term */ symb rewrite_A; /* rewrite_B(pos,cond,x=y) derives B |- cond => t[x]_p = t[y]_p, where t is an arbitrary term */ symb rewrite_B; /* a normalization step is of the form (lhs=rhs) : proof, where "proof" is a proof of lhs=rhs and lhs is a mixed term. If rhs is a mixed term then it must have a greater index than lhs. */ symb normal_step; /* A chain of normalization steps is either "true" (the null chain) or normal_chain( ), where step is a normalization step and tail is a normalization chain. The lhs of must have a less term index than any lhs in the chain. Moreover, the rhs of may not occur as the lhs of step in . If we wish to add lhs=rhs to the beginning of and rhs=rhs' occurs in we must apply transitivity, transforming to lhs=rhs'. */ symb normal_chain; /* If p is a proof of Q and c is a normalization chain, then normal(p,c) is a proof of Q(c) (that is, Q with all substitutions in c performed). */ symb normal; /** Stand-ins for quantifiers */ symb sforall, sexists; ast get_placeholder(ast t){ hash_map::iterator it = placeholders.find(t); if(it != placeholders.end()) return it->second; ast &res = placeholders[t]; res = mk_fresh_constant("@p",get_type(t)); #if 0 std::cout << "placeholder "; print_expr(std::cout,res); std::cout << " = "; print_expr(std::cout,t); std::cout << std::endl; #endif return res; } ast make_contra_node(const ast &pf, const std::vector &lits, int pfok = -1){ if(lits.size() == 0) return pf; std::vector reslits; reslits.push_back(make(contra,pf,mk_false())); for(unsigned i = 0; i < lits.size(); i++){ ast bar; if(pfok & (1 << i)) bar = make(rotate_sum,lits[i],pf); else bar = no_proof; ast foo = make(contra,bar,lits[i]); reslits.push_back(foo); } return make(And,reslits); } LitType get_term_type(const ast &lit){ prover::range r = pv->ast_scope(lit); if(pv->range_is_empty(r)) return LitMixed; if(weak) { if(pv->range_min(r) == SHRT_MIN) return pv->range_contained(r,rng) ? LitA : LitB; else return pv->ranges_intersect(r,rng) ? LitA : LitB; } else return pv->range_contained(r,rng) ? LitA : LitB; } bool term_common(const ast &t){ prover::range r = pv->ast_scope(t); return pv->ranges_intersect(r,rng) && !pv->range_contained(r,rng); } bool term_in_vocab(LitType ty, const ast &lit){ prover::range r = pv->ast_scope(lit); if(ty == LitA){ return pv->ranges_intersect(r,rng); } return !pv->range_contained(r,rng); } /** Make a resolution node with given pivot literal and premises. The conclusion of premise1 should contain the negation of the pivot literal, while the conclusion of premise2 should contain the pivot literal. */ node make_resolution(ast pivot, const std::vector &conc, node premise1, node premise2) { LitType lt = get_term_type(pivot); if(lt == LitA) return my_or(premise1,premise2); if(lt == LitB) return my_and(premise1,premise2); /* the mixed case is a bit complicated */ static int non_local_count = 0; ast res = resolve_arith(pivot,conc,premise1,premise2); #ifdef INVARIANT_CHECKING check_contra(conc,res); #endif non_local_count++; return res; } /* Handles the case of resolution on a mixed arith atom. */ ast resolve_arith(const ast &pivot, const std::vector &conc, node premise1, node premise2){ ast atom = get_lit_atom(pivot); hash_map memo; ast neg_pivot_lit = mk_not(atom); if(op(pivot) != Not) std::swap(premise1,premise2); if(op(pivot) == Equal && op(arg(pivot,0)) == Select && op(arg(pivot,1)) == Select){ neg_pivot_lit = mk_not(neg_pivot_lit); std::swap(premise1,premise2); } return resolve_arith_rec1(memo, neg_pivot_lit, premise1, premise2); } ast apply_coeff(const ast &coeff, const ast &t){ #if 0 rational r; if(!is_integer(coeff,r)) throw "ack!"; ast n = make_int(r.numerator()); ast res = make(Times,n,t); if(!r.is_int()) { ast d = make_int(r.numerator()); res = mk_idiv(res,d); } return res; #endif return make(Times,coeff,t); } ast sum_ineq(const ast &coeff1, const ast &ineq1, const ast &coeff2, const ast &ineq2){ opr sum_op = Leq; if(op(ineq1) == Lt || op(ineq2) == Lt) sum_op = Lt; ast sum_sides[2]; for(int i = 0; i < 2; i++){ sum_sides[i] = make(Plus,apply_coeff(coeff1,arg(ineq1,i)),apply_coeff(coeff2,arg(ineq2,i))); sum_sides[i] = z3_simplify(sum_sides[i]); } return make(sum_op,sum_sides[0],sum_sides[1]); } void collect_contra_resolvents(int from, const ast &pivot1, const ast &pivot, const ast &conj, std::vector &res){ int nargs = num_args(conj); for(int i = from; i < nargs; i++){ ast f = arg(conj,i); if(!(f == pivot)){ ast ph = get_placeholder(mk_not(arg(pivot1,1))); ast pf = arg(pivot1,0); ast thing = pf == no_proof ? no_proof : subst_term_and_simp(ph,pf,arg(f,0)); ast newf = make(contra,thing,arg(f,1)); res.push_back(newf); } } } bool is_negative_equality(const ast &e){ if(op(e) == Not){ opr o = op(arg(e,0)); return o == Equal || o == Iff; } return false; } int count_negative_equalities(const std::vector &resolvent){ int res = 0; for(unsigned i = 0; i < resolvent.size(); i++) if(is_negative_equality(arg(resolvent[i],1))) res++; return res; } ast resolve_contra_nf(const ast &pivot1, const ast &conj1, const ast &pivot2, const ast &conj2){ std::vector resolvent; collect_contra_resolvents(0,pivot1,pivot2,conj2,resolvent); collect_contra_resolvents(1,pivot2,pivot1,conj1,resolvent); if(count_negative_equalities(resolvent) > 1) throw proof_error(); if(resolvent.size() == 1) // we have proved a contradiction return simplify(arg(resolvent[0],0)); // this is the proof -- get interpolant return make(And,resolvent); } ast resolve_contra(const ast &pivot1, const ast &conj1, const ast &pivot2, const ast &conj2){ if(arg(pivot1,0) != no_proof) return resolve_contra_nf(pivot1, conj1, pivot2, conj2); if(arg(pivot2,0) != no_proof) return resolve_contra_nf(pivot2, conj2, pivot1, conj1); return resolve_with_quantifier(pivot1, conj1, pivot2, conj2); } bool is_contra_itp(const ast &pivot1, ast itp2, ast &pivot2){ if(op(itp2) == And){ int nargs = num_args(itp2); for(int i = 1; i < nargs; i++){ ast foo = arg(itp2,i); if(op(foo) == Uninterpreted && sym(foo) == contra){ if(arg(foo,1) == pivot1){ pivot2 = foo; return true; } } else break; } } return false; } ast resolve_arith_rec2(hash_map &memo, const ast &pivot1, const ast &conj1, const ast &itp2){ ast &res = memo[itp2]; if(!res.null()) return res; ast pivot2; if(is_contra_itp(mk_not(arg(pivot1,1)),itp2,pivot2)) res = resolve_contra(pivot1,conj1,pivot2,itp2); else { switch(op(itp2)){ case Or: case And: case Implies: { unsigned nargs = num_args(itp2); std::vector args; args.resize(nargs); for(unsigned i = 0; i < nargs; i++) args[i] = resolve_arith_rec2(memo, pivot1, conj1, arg(itp2,i)); ast foo = itp2; // get rid of const res = clone(foo,args); break; } default: { opr o = op(itp2); if(o == Uninterpreted){ symb s = sym(itp2); if(s == sforall || s == sexists) res = make(s,arg(itp2,0),resolve_arith_rec2(memo, pivot1, conj1, arg(itp2,1))); else res = itp2; } else { res = itp2; } } } } return res; } ast resolve_arith_rec1(hash_map &memo, const ast &neg_pivot_lit, const ast &itp1, const ast &itp2){ ast &res = memo[itp1]; if(!res.null()) return res; ast pivot1; if(is_contra_itp(neg_pivot_lit,itp1,pivot1)){ hash_map memo2; res = resolve_arith_rec2(memo2,pivot1,itp1,itp2); } else { switch(op(itp1)){ case Or: case And: case Implies: { unsigned nargs = num_args(itp1); std::vector args; args.resize(nargs); for(unsigned i = 0; i < nargs; i++) args[i] = resolve_arith_rec1(memo, neg_pivot_lit, arg(itp1,i), itp2); ast foo = itp1; // get rid of const res = clone(foo,args); break; } default: { opr o = op(itp1); if(o == Uninterpreted){ symb s = sym(itp1); if(s == sforall || s == sexists) res = make(s,arg(itp1,0),resolve_arith_rec1(memo, neg_pivot_lit, arg(itp1,1), itp2)); else res = itp1; } else { res = itp1; } } } } return res; } void check_contra(hash_set &memo, hash_set &neg_lits, const ast &foo){ if(memo.find(foo) != memo.end()) return; memo.insert(foo); if(op(foo) == Uninterpreted && sym(foo) == contra){ ast neg_lit = arg(foo,1); if(!is_false(neg_lit) && neg_lits.find(neg_lit) == neg_lits.end()) throw "lost a literal"; return; } else { switch(op(foo)){ case Or: case And: case Implies: { unsigned nargs = num_args(foo); std::vector args; args.resize(nargs); for(unsigned i = 0; i < nargs; i++) check_contra(memo, neg_lits, arg(foo,i)); break; } default: break; } } } void check_contra(const std::vector &neg_lits, const ast &foo){ hash_set memo; hash_set neg_lits_set; for(unsigned i = 0; i < neg_lits.size(); i++) if(get_term_type(neg_lits[i]) == LitMixed) neg_lits_set.insert(mk_not(neg_lits[i])); check_contra(memo,neg_lits_set,foo); } hash_map subst_memo; // memo of subst function ast subst_term_and_simp(const ast &var, const ast &t, const ast &e){ subst_memo.clear(); return subst_term_and_simp_rec(var,t,e); } ast subst_term_and_simp_rec(const ast &var, const ast &t, const ast &e){ if(e == var) return t; std::pair foo(e,ast()); std::pair::iterator,bool> bar = subst_memo.insert(foo); ast &res = bar.first->second; if(bar.second){ if(op(e) == Uninterpreted){ symb g = sym(e); if(g == rotate_sum){ if(var == get_placeholder(arg(e,0))){ res = e; } else res = make(rotate_sum,arg(e,0),subst_term_and_simp_rec(var,t,arg(e,1))); return res; } if(g == concat){ res = e; return res; } } int nargs = num_args(e); std::vector args(nargs); for(int i = 0; i < nargs; i++) args[i] = subst_term_and_simp_rec(var,t,arg(e,i)); opr f = op(e); if(f == Equal && args[0] == args[1]) res = mk_true(); else if(f == And) res = my_and(args); else if(f == Or) res = my_or(args); else if(f == Idiv) res = mk_idiv(args[0],args[1]); else res = clone(e,args); } return res; } /* This is where the real work happens. Here, we simplify the proof obtained by cut elimination, obtaining an interpolant. */ struct cannot_simplify {}; hash_map simplify_memo; ast simplify(const ast &t){ ast res = normalize(simplify_rec(t)); #ifdef BOGUS_QUANTS if(localization_vars.size()) res = add_quants(z3_simplify(res)); #endif return res; } ast simplify_rec(const ast &e){ std::pair foo(e,ast()); std::pair::iterator,bool> bar = simplify_memo.insert(foo); ast &res = bar.first->second; if(bar.second){ int nargs = num_args(e); std::vector args(nargs); bool placeholder_arg = false; symb g = sym(e); if(g == concat){ res = e; return res; } for(int i = 0; i < nargs; i++){ if(i == 0 && g == rotate_sum) args[i] = arg(e,i); else args[i] = simplify_rec(arg(e,i)); placeholder_arg |= is_placeholder(args[i]); } try { opr f = op(e); if(f == Equal && args[0] == args[1]) res = mk_true(); else if(f == And) res = my_and(args); else if(f == Or) res = my_or(args); else if(f == Idiv) res = mk_idiv(args[0],args[1]); else if(f == Uninterpreted && !placeholder_arg){ if(g == rotate_sum) res = simplify_rotate(args); else if(g == symm) res = simplify_symm(args); else if(g == modpon) res = simplify_modpon(args); else if(g == sum) res = simplify_sum(args); else if(g == exmid) res = simplify_exmid(args); else if(g == cong) res = simplify_cong(args); #if 0 else if(g == modpon) res = simplify_modpon(args); else if(g == leq2eq) res = simplify_leq2eq(args); else if(g == eq2leq) res = simplify_eq2leq(args); #endif else res = clone(e,args); } else res = clone(e,args); } catch (const cannot_simplify &){ res = clone(e,args); } } return res; } ast simplify_rotate(const std::vector &args){ const ast &pf = args[1]; ast pl = get_placeholder(args[0]); if(op(pf) == Uninterpreted){ symb g = sym(pf); if(g == sum) return simplify_rotate_sum(pl,pf); if(g == leq2eq) return simplify_rotate_leq2eq(pl,args[0],pf); if(g == eq2leq) return simplify_rotate_eq2leq(pl,args[0],pf); if(g == cong) return simplify_rotate_cong(pl,args[0],pf); if(g == modpon) return simplify_rotate_modpon(pl,args[0],pf); // if(g == symm) return simplify_rotate_symm(pl,args[0],pf); } if(op(pf) == Leq) throw "foo!"; throw cannot_simplify(); } bool is_normal_ineq(const ast &ineq){ if(sym(ineq) == normal) return is_ineq(arg(ineq,0)); return is_ineq(ineq); } ast destruct_cond_ineq(const ast &ineq, ast &Aproves, ast &Bproves){ ast res = ineq; opr o = op(res); if(o == And){ Aproves = my_and(Aproves,arg(res,0)); res = arg(res,1); o = op(res); } if(o == Implies){ Bproves = my_and(Bproves,arg(res,0)); res = arg(res,1); } return res; } ast simplify_sum(std::vector &args){ ast Aproves = mk_true(), Bproves = mk_true(); ast ineq = destruct_cond_ineq(args[0],Aproves,Bproves); if(!is_normal_ineq(ineq)) throw cannot_simplify(); sum_cond_ineq(ineq,args[1],args[2],Aproves,Bproves); return my_and(Aproves,my_implies(Bproves,ineq)); } ast simplify_rotate_sum(const ast &pl, const ast &pf){ ast Aproves = mk_true(), Bproves = mk_true(); ast ineq = make(Leq,make_int("0"),make_int("0")); ineq = rotate_sum_rec(pl,pf,Aproves,Bproves,ineq); if(is_true(Aproves) && is_true(Bproves)) return ineq; return my_and(Aproves,my_implies(Bproves,ineq)); } bool is_rewrite_chain(const ast &chain){ return sym(chain) == concat; } #if 0 ast ineq_from_chain_simple(const ast &chain, ast &cond){ if(is_true(chain)) return chain; ast last = chain_last(chain); ast rest = chain_rest(chain); if(is_true(rest) && is_rewrite_side(LitA,last) && is_true(rewrite_lhs(last))){ cond = my_and(cond,rewrite_cond(last)); return rewrite_rhs(last); } if(is_rewrite_side(LitB,last) && is_true(rewrite_cond(last))) return ineq_from_chain_simple(rest,cond); return chain; } #endif ast ineq_from_chain(const ast &chain, ast &Aproves, ast &Bproves){ if(is_rewrite_chain(chain)) return rewrite_chain_to_normal_ineq(chain,Aproves,Bproves); return chain; } void sum_cond_ineq(ast &ineq, const ast &coeff2, const ast &ineq2, ast &Aproves, ast &Bproves){ opr o = op(ineq2); if(o == And){ sum_cond_ineq(ineq,coeff2,arg(ineq2,1),Aproves,Bproves); Aproves = my_and(Aproves,arg(ineq2,0)); } else if(o == Implies){ sum_cond_ineq(ineq,coeff2,arg(ineq2,1),Aproves,Bproves); Bproves = my_and(Bproves,arg(ineq2,0)); } else { ast the_ineq = ineq_from_chain(ineq2,Aproves,Bproves); if(sym(ineq) == normal || sym(the_ineq) == normal){ sum_normal_ineq(ineq,coeff2,the_ineq,Aproves,Bproves); return; } if(is_ineq(the_ineq)) linear_comb(ineq,coeff2,the_ineq); else throw cannot_simplify(); } } void destruct_normal(const ast &pf, ast &p, ast &n){ if(sym(pf) == normal){ p = arg(pf,0); n = arg(pf,1); } else { p = pf; n = mk_true(); } } void sum_normal_ineq(ast &ineq, const ast &coeff2, const ast &ineq2, ast &Aproves, ast &Bproves){ ast in1,in2,n1,n2; destruct_normal(ineq,in1,n1); destruct_normal(ineq2,in2,n2); ast dummy1, dummy2; sum_cond_ineq(in1,coeff2,in2,dummy1,dummy2); n1 = merge_normal_chains(n1,n2, Aproves, Bproves); ineq = is_true(n1) ? in1 : make_normal(in1,n1); } bool is_ineq(const ast &ineq){ opr o = op(ineq); if(o == Not) o = op(arg(ineq,0)); return o == Leq || o == Lt || o == Geq || o == Gt; } // divide both sides of inequality by a non-negative integer divisor ast idiv_ineq(const ast &ineq1, const ast &divisor){ if(sym(ineq1) == normal){ ast in1,n1; destruct_normal(ineq1,in1,n1); in1 = idiv_ineq(in1,divisor); return make_normal(in1,n1); } if(divisor == make_int(rational(1))) return ineq1; ast ineq = ineq1; if(op(ineq) == Lt) ineq = simplify_ineq(make(Leq,arg(ineq,0),make(Sub,arg(ineq,1),make_int("1")))); return make(op(ineq),mk_idiv(arg(ineq,0),divisor),mk_idiv(arg(ineq,1),divisor)); } ast rotate_sum_rec(const ast &pl, const ast &pf, ast &Aproves, ast &Bproves, ast &ineq){ if(pf == pl){ if(sym(ineq) == normal) return ineq; return simplify_ineq(ineq); } if(op(pf) == Uninterpreted && sym(pf) == sum){ if(arg(pf,2) == pl){ sum_cond_ineq(ineq,make_int("1"),arg(pf,0),Aproves,Bproves); ineq = idiv_ineq(ineq,arg(pf,1)); return ineq; } sum_cond_ineq(ineq,arg(pf,1),arg(pf,2),Aproves,Bproves); return rotate_sum_rec(pl,arg(pf,0),Aproves,Bproves,ineq); } throw cannot_simplify(); } ast simplify_rotate_leq2eq(const ast &pl, const ast &neg_equality, const ast &pf){ if(pl == arg(pf,0)){ ast equality = arg(neg_equality,0); ast x = arg(equality,0); ast y = arg(equality,1); ast Aproves1 = mk_true(), Bproves1 = mk_true(); ast pf1 = destruct_cond_ineq(arg(pf,1), Aproves1, Bproves1); ast pf2 = destruct_cond_ineq(arg(pf,2), Aproves1, Bproves1); ast xleqy = round_ineq(ineq_from_chain(pf1,Aproves1,Bproves1)); ast yleqx = round_ineq(ineq_from_chain(pf2,Aproves1,Bproves1)); ast ineq1 = make(Leq,make_int("0"),make_int("0")); sum_cond_ineq(ineq1,make_int("-1"),xleqy,Aproves1,Bproves1); sum_cond_ineq(ineq1,make_int("-1"),yleqx,Aproves1,Bproves1); ast Acond = my_implies(Aproves1,my_and(Bproves1,z3_simplify(ineq1))); ast Aproves2 = mk_true(), Bproves2 = mk_true(); ast ineq2 = make(Leq,make_int("0"),make_int("0")); sum_cond_ineq(ineq2,make_int("1"),xleqy,Aproves2,Bproves2); sum_cond_ineq(ineq2,make_int("1"),yleqx,Aproves2,Bproves2); ast Bcond = my_implies(Bproves1,my_and(Aproves1,z3_simplify(ineq2))); // if(!is_true(Aproves1) || !is_true(Bproves1)) // std::cout << "foo!\n";; if(get_term_type(x) == LitA){ ast iter = z3_simplify(make(Plus,x,get_ineq_rhs(xleqy))); ast rewrite1 = make_rewrite(LitA,top_pos,Acond,make(Equal,x,iter)); ast rewrite2 = make_rewrite(LitB,top_pos,Bcond,make(Equal,iter,y)); return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2); } if(get_term_type(y) == LitA){ ast iter = z3_simplify(make(Plus,y,get_ineq_rhs(yleqx))); ast rewrite2 = make_rewrite(LitA,top_pos,Acond,make(Equal,iter,y)); ast rewrite1 = make_rewrite(LitB,top_pos,Bcond,make(Equal,x,iter)); return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2); } throw cannot_simplify(); } throw cannot_simplify(); } ast round_ineq(const ast &ineq){ if(sym(ineq) == normal) return make_normal(round_ineq(arg(ineq,0)),arg(ineq,1)); if(!is_ineq(ineq)) throw cannot_simplify(); ast res = simplify_ineq(ineq); if(op(res) == Lt) res = make(Leq,arg(res,0),make(Sub,arg(res,1),make_int("1"))); return res; } ast unmixed_eq2ineq(const ast &lhs, const ast &rhs, opr comp_op, const ast &equa, ast &cond){ ast ineqs= chain_ineqs(comp_op,LitA,equa,lhs,rhs); // chain must be from lhs to rhs cond = my_and(cond,chain_conditions(LitA,equa)); ast Bconds = z3_simplify(chain_conditions(LitB,equa)); if(is_true(Bconds) && op(ineqs) != And) return my_implies(cond,ineqs); if(op(ineqs) != And) return my_and(Bconds,my_implies(cond,ineqs)); throw "help!"; } ast add_mixed_eq2ineq(const ast &lhs, const ast &rhs, const ast &equa, const ast &itp){ if(is_true(equa)) return itp; std::vector args(3); args[0] = itp; args[1] = make_int("1"); ast ineq = make(Leq,make_int(rational(0)),make_int(rational(0))); args[2] = make_normal(ineq,cons_normal(fix_normal(lhs,rhs,equa),mk_true())); return simplify_sum(args); } ast simplify_rotate_eq2leq(const ast &pl, const ast &neg_equality, const ast &pf){ if(pl == arg(pf,1)){ ast cond = mk_true(); ast equa = sep_cond(arg(pf,0),cond); if(is_equivrel_chain(equa)){ ast lhs,rhs; eq_from_ineq(arg(neg_equality,0),lhs,rhs); // get inequality we need to prove if(!rewrites_from_to(equa,lhs,rhs)){ lhs = arg(arg(neg_equality,0),0); // the equality proved is ambiguous, sadly rhs = arg(arg(neg_equality,0),1); } LitType lhst = get_term_type(lhs), rhst = get_term_type(rhs); if(lhst != LitMixed && rhst != LitMixed) return unmixed_eq2ineq(lhs, rhs, op(arg(neg_equality,0)), equa, cond); else { ast left, left_term, middle, right_term, right; left = get_left_movers(equa,lhs,middle,left_term); middle = get_right_movers(middle,rhs,right,right_term); ast itp = unmixed_eq2ineq(left_term, right_term, op(arg(neg_equality,0)), middle, cond); // itp = my_implies(cond,itp); itp = add_mixed_eq2ineq(lhs, left_term, left, itp); itp = add_mixed_eq2ineq(right_term, rhs, right, itp); return itp; } } } throw "help!"; } void reverse_modpon(std::vector &args){ std::vector sargs(1); sargs[0] = args[1]; args[1] = simplify_symm(sargs); if(is_equivrel_chain(args[2])) args[1] = down_chain(args[1]); std::swap(args[0],args[2]); } ast simplify_rotate_modpon(const ast &pl, const ast &neg_equality, const ast &pf){ std::vector args; args.resize(3); args[0] = arg(pf,0); args[1] = arg(pf,1); args[2] = arg(pf,2); if(pl == args[0]) reverse_modpon(args); if(pl == args[2]){ ast cond = mk_true(); ast chain = simplify_modpon_fwd(args, cond); return my_implies(cond,chain); } throw cannot_simplify(); } ast get_ineq_rhs(const ast &ineq2){ opr o = op(ineq2); if(o == Implies) return get_ineq_rhs(arg(ineq2,1)); else if(o == Leq || o == Lt) return arg(ineq2,1); throw cannot_simplify(); } ast simplify_rotate_cong(const ast &pl, const ast &neg_equality, const ast &pf){ if(pl == arg(pf,2)){ if(op(arg(pf,0)) == True) return mk_true(); rational pos; if(is_numeral(arg(pf,1),pos)){ int ipos = pos.get_unsigned(); ast cond = mk_true(); ast equa = sep_cond(arg(pf,0),cond); #if 0 if(op(equa) == Equal){ ast pe = mk_not(neg_equality); ast lhs = subst_in_arg_pos(ipos,arg(equa,0),arg(pe,0)); ast rhs = subst_in_arg_pos(ipos,arg(equa,1),arg(pe,1)); ast res = make(Equal,lhs,rhs); return my_implies(cond,res); } #endif ast res = chain_pos_add(ipos,equa); return my_implies(cond,res); } } throw cannot_simplify(); } ast simplify_symm(const std::vector &args){ if(op(args[0]) == True) return mk_true(); ast cond = mk_true(); ast equa = sep_cond(args[0],cond); if(is_equivrel_chain(equa)) return my_implies(cond,reverse_chain(equa)); if(is_negation_chain(equa)) return commute_negation_chain(equa); throw cannot_simplify(); } ast simplify_modpon_fwd(const std::vector &args, ast &cond){ ast P = sep_cond(args[0],cond); ast PeqQ = sep_cond(args[1],cond); ast chain; if(is_equivrel_chain(P)){ try { ast split[2]; split_chain(PeqQ,split); chain = reverse_chain(split[0]); chain = concat_rewrite_chain(chain,P); chain = concat_rewrite_chain(chain,split[1]); } catch(const cannot_split &){ static int this_count = 0; this_count++; ast tail, pref = get_head_chain(PeqQ,tail,false); // pref is x=y, tail is x=y -> x'=y' ast split[2]; split_chain(tail,split); // rewrites from x to x' and y to y' ast head = chain_last(pref); ast prem = make_rewrite(rewrite_side(head),top_pos,rewrite_cond(head),make(Iff,mk_true(),mk_not(rewrite_lhs(head)))); ast back_chain = chain_cons(mk_true(),prem); back_chain = concat_rewrite_chain(back_chain,chain_pos_add(0,reverse_chain(chain_rest(pref)))); ast cond = contra_chain(back_chain,P); if(is_rewrite_side(LitA,head)) cond = mk_not(cond); ast fwd_rewrite = make_rewrite(rewrite_side(head),top_pos,cond,rewrite_rhs(head)); P = chain_cons(mk_true(),fwd_rewrite); chain = reverse_chain(split[0]); chain = concat_rewrite_chain(chain,P); chain = concat_rewrite_chain(chain,split[1]); } } else { // if not an equivalence, must be of form T <-> pred chain = concat_rewrite_chain(P,PeqQ); } return chain; } struct subterm_normals_failed {}; void get_subterm_normals(const ast &ineq1, const ast &ineq2, const ast &chain, ast &normals, const ast &pos, hash_set &memo, ast &Aproves, ast &Bproves){ opr o1 = op(ineq1); opr o2 = op(ineq2); if(o1 == Not || o1 == Leq || o1 == Lt || o1 == Geq || o1 == Gt || o1 == Plus || o1 == Times){ int n = num_args(ineq1); if(o2 != o1 || num_args(ineq2) != n) throw "bad inequality rewriting"; for(int i = 0; i < n; i++){ ast new_pos = add_pos_to_end(pos,i); get_subterm_normals(arg(ineq1,i), arg(ineq2,i), chain, normals, new_pos, memo, Aproves, Bproves); } } else if(get_term_type(ineq2) == LitMixed){ if(memo.find(ineq2) == memo.end()){ memo.insert(ineq2); ast sub_chain = extract_rewrites(chain,pos); if(is_true(sub_chain)) throw "bad inequality rewriting"; ast new_normal = make_normal_step(ineq2,ineq1,reverse_chain(sub_chain)); normals = merge_normal_chains(normals,cons_normal(new_normal,mk_true()), Aproves, Bproves); } } else if(!(ineq1 == ineq2)) throw subterm_normals_failed(); } ast rewrites_to_normals(const ast &ineq1, const ast &chain, ast &normals, ast &Aproves, ast &Bproves, ast &Aineqs){ if(is_true(chain)) return ineq1; ast last = chain_last(chain); ast rest = chain_rest(chain); ast new_ineq1 = rewrites_to_normals(ineq1, rest, normals, Aproves, Bproves, Aineqs); ast p1 = rewrite_pos(last); ast term1; ast coeff = arith_rewrite_coeff(new_ineq1,p1,term1); ast res = subst_in_pos(new_ineq1,rewrite_pos(last),rewrite_rhs(last)); ast rpos; pos_diff(p1,rewrite_pos(last),rpos); ast term2 = subst_in_pos(term1,rpos,rewrite_rhs(last)); if(get_term_type(term1) != LitMixed && get_term_type(term2) != LitMixed){ if(is_rewrite_side(LitA,last)) linear_comb(Aineqs,coeff,make(Leq,make_int(rational(0)),make(Sub,term2,term1))); } else { ast pf = extract_rewrites(make(concat,mk_true(),rest),p1); ast new_normal = fix_normal(term1,term2,pf); normals = merge_normal_chains(normals,cons_normal(new_normal,mk_true()), Aproves, Bproves); } return res; } ast arith_rewrite_coeff(const ast &ineq, ast &p1, ast &term){ ast coeff = make_int(rational(1)); if(p1 == top_pos){ term = ineq; return coeff; } int argpos = pos_arg(p1); opr o = op(ineq); switch(o){ case Leq: case Lt: coeff = argpos ? make_int(rational(1)) : make_int(rational(-1)); break; case Geq: case Gt: coeff = argpos ? make_int(rational(-1)) : make_int(rational(1)); break; case Not: case Plus: break; case Times: coeff = arg(ineq,0); break; default: p1 = top_pos; term = ineq; return coeff; } p1 = arg(p1,1); ast res = arith_rewrite_coeff(arg(ineq,argpos),p1,term); p1 = pos_add(argpos,p1); return coeff == make_int(rational(1)) ? res : make(Times,coeff,res); } ast rewrite_chain_to_normal_ineq(const ast &chain, ast &Aproves, ast &Bproves){ ast tail, pref = get_head_chain(chain,tail,false); // pref is x=y, tail is x=y -> x'=y' ast head = chain_last(pref); ast ineq1 = rewrite_rhs(head); ast ineq2 = apply_rewrite_chain(ineq1,tail); ast nc = mk_true(); hash_set memo; ast itp = make(Leq,make_int(rational(0)),make_int(rational(0))); ast Aproves_save = Aproves, Bproves_save = Bproves; try { get_subterm_normals(ineq1,ineq2,tail,nc,top_pos,memo, Aproves, Bproves); } catch (const subterm_normals_failed &){ Aproves = Aproves_save; Bproves = Bproves_save; nc = mk_true(); rewrites_to_normals(ineq1, tail, nc, Aproves, Bproves, itp); } if(is_rewrite_side(LitA,head)){ linear_comb(itp,make_int("1"),ineq1); // make sure it is normal form //itp = ineq1; ast mc = z3_simplify(chain_side_proves(LitB,pref)); Bproves = my_and(Bproves,mc); } else { ast mc = z3_simplify(chain_side_proves(LitA,pref)); Aproves = my_and(Aproves,mc); } if(is_true(nc)) return itp; return make_normal(itp,nc); } /* Given a chain rewrite chain deriving not P and a rewrite chain deriving P, return an interpolant. */ ast contra_chain(const ast &neg_chain, const ast &pos_chain){ // equality is a special case. we use the derivation of x=y to rewrite not(x=y) to not(y=y) if(is_equivrel_chain(pos_chain)){ ast tail, pref = get_head_chain(neg_chain,tail); // pref is not(x=y), tail is not(x,y) -> not(x',y') ast split[2]; split_chain(down_chain(tail),split); // rewrites from x to x' and y to y' ast chain = split[0]; chain = concat_rewrite_chain(chain,pos_chain); // rewrites from x to y' chain = concat_rewrite_chain(chain,reverse_chain(split[1])); // rewrites from x to y chain = concat_rewrite_chain(pref,chain_pos_add(0,chain_pos_add(0,chain))); // rewrites t -> not(y=y) ast head = chain_last(pref); if(is_rewrite_side(LitB,head)){ ast condition = chain_conditions(LitB,chain); return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),condition); } else { ast condition = chain_conditions(LitA,chain); return my_and(chain_conditions(LitB,chain),my_implies(condition,mk_not(chain_formulas(LitB,chain)))); } // ast chain = concat_rewrite_chain(neg_chain,chain_pos_add(0,chain_pos_add(0,pos_chain))); // return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain)); } // otherwise, we reverse the derivation of t = P and use it to rewrite not(P) to not(t) ast chain = concat_rewrite_chain(neg_chain,chain_pos_add(0,reverse_chain(pos_chain))); return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain)); } ast simplify_modpon(const std::vector &args){ ast Aproves = mk_true(), Bproves = mk_true(); ast chain = simplify_modpon_fwd(args,Bproves); ast Q2 = destruct_cond_ineq(args[2],Aproves,Bproves); ast interp; if(is_normal_ineq(Q2)){ // inequalities are special ast nQ2 = rewrite_chain_to_normal_ineq(chain,Aproves,Bproves); sum_cond_ineq(nQ2,make_int(rational(1)),Q2,Aproves,Bproves); interp = normalize(nQ2); } else interp = is_negation_chain(chain) ? contra_chain(chain,Q2) : contra_chain(Q2,chain); return my_and(Aproves,my_implies(Bproves,interp)); } ast simplify_exmid(const std::vector &args){ if(is_equivrel(args[0])){ ast Aproves = mk_true(), Bproves = mk_true(); ast chain = destruct_cond_ineq(args[1],Aproves,Bproves); ast Q2 = destruct_cond_ineq(args[2],Aproves,Bproves); ast interp = contra_chain(Q2,chain); return my_and(Aproves,my_implies(Bproves,interp)); } throw "bad exmid"; } ast simplify_cong(const std::vector &args){ ast Aproves = mk_true(), Bproves = mk_true(); ast chain = destruct_cond_ineq(args[0],Aproves,Bproves); rational pos; if(is_numeral(args[1],pos)){ int ipos = pos.get_unsigned(); chain = chain_pos_add(ipos,chain); ast Q2 = destruct_cond_ineq(args[2],Aproves,Bproves); ast interp = contra_chain(Q2,chain); return my_and(Aproves,my_implies(Bproves,interp)); } throw "bad cong"; } bool is_equivrel(const ast &p){ opr o = op(p); return o == Equal || o == Iff; } struct rewrites_failed{}; /* Suppose p in Lang(B) and A |- p -> q and B |- q -> r. Return a z in Lang(B) such that B |- p -> z and A |- z -> q. Collect any side conditions in "rules". */ ast commute_rewrites(const ast &p, const ast &q, const ast &r, ast &rules){ if(q == r) return p; if(p == q) return r; else { ast rew = make(Equal,q,r); if(get_term_type(rew) == LitB){ apply_common_rewrites(p,p,q,rules); // A rewrites must be over comon vocab return r; } } if(sym(p) != sym(q) || sym(q) != sym(r)) throw rewrites_failed(); int nargs = num_args(p); if(nargs != num_args(q) || nargs != num_args(r)) throw rewrites_failed(); std::vector args; args.resize(nargs); for(int i = 0; i < nargs; i++) args[i] = commute_rewrites(arg(p,i),arg(q,i),arg(r,i),rules); return clone(p,args); } ast apply_common_rewrites(const ast &p, const ast &q, const ast &r, ast &rules){ if(q == r) return p; ast rew = make(Equal,q,r); if(term_common(rew)){ if(p != q) throw rewrites_failed(); rules = my_and(rules,rew); return r; } if(sym(p) != sym(q) || sym(q) != sym(r)) return p; int nargs = num_args(p); if(nargs != num_args(q) || nargs != num_args(r)) return p; std::vector args; args.resize(nargs); for(int i = 0; i < nargs; i++) args[i] = apply_common_rewrites(arg(p,i),arg(q,i),arg(r,i),rules); return clone(p,args); } ast apply_all_rewrites(const ast &p, const ast &q, const ast &r){ if(q == r) return p; if(p == q) return r; if(sym(p) != sym(q) || sym(q) != sym(r)) throw rewrites_failed(); int nargs = num_args(p); if(nargs != num_args(q) || nargs != num_args(r)) throw rewrites_failed(); std::vector args; args.resize(nargs); for(int i = 0; i < nargs; i++) args[i] = apply_all_rewrites(arg(p,i),arg(q,i),arg(r,i)); return clone(p,args); } ast delta(const ast &x, const ast &y){ if(op(x) != op(y) || (op(x) == Uninterpreted && sym(x) != sym(y)) || num_args(x) != num_args(y)) return make(Equal,x,y); ast res = mk_true(); int nargs = num_args(x); for(int i = 0; i < nargs; i++) res = my_and(res,delta(arg(x,i),arg(y,i))); return res; } bool diff_rec(LitType t, const ast &p, const ast &q, ast &pd, ast &qd){ if(p == q) return false; if(term_in_vocab(t,p) && term_in_vocab(t,q)){ pd = p; qd = q; return true; } else { if(sym(p) != sym(q)) return false; int nargs = num_args(p); if(num_args(q) != nargs) return false; for(int i = 0; i < nargs; i++) if(diff_rec(t,arg(p,i),arg(q,i),pd,qd)) return true; return false; } } void diff(LitType t, const ast &p, const ast &q, ast &pd, ast &qd){ if(!diff_rec(t,p,q,pd,qd)) throw cannot_simplify(); } bool apply_diff_rec(LitType t, const ast &inp, const ast &p, const ast &q, ast &out){ if(p == q) return false; if(term_in_vocab(t,p) && term_in_vocab(t,q)){ if(inp != p) throw cannot_simplify(); out = q; return true; } else { int nargs = num_args(p); if(sym(p) != sym(q)) throw cannot_simplify(); if(num_args(q) != nargs) throw cannot_simplify(); if(sym(p) != sym(inp)) throw cannot_simplify(); if(num_args(inp) != nargs) throw cannot_simplify(); for(int i = 0; i < nargs; i++) if(apply_diff_rec(t,arg(inp,i),arg(p,i),arg(q,i),out)) return true; return false; } } ast apply_diff(LitType t, const ast &inp, const ast &p, const ast &q){ ast out; if(!apply_diff_rec(t,inp,p,q,out)) throw cannot_simplify(); return out; } bool merge_A_rewrites(const ast &A1, const ast &A2, ast &merged) { if(arg(A1,1) == arg(A2,0)){ merged = make(op(A1),arg(A1,0),arg(A2,1)); return true; } ast diff1l, diff1r, diff2l, diff2r,diffBl,diffBr; diff(LitA,arg(A1,0),arg(A1,1),diff1l,diff1r); diff(LitA,arg(A2,0),arg(A2,1),diff2l,diff2r); diff(LitB,arg(A1,1),arg(A2,0),diffBl,diffBr); if(!term_common(diff2l) && !term_common(diffBr)){ ast A1r = apply_diff(LitB,arg(A2,1),arg(A2,0),arg(A1,1)); merged = make(op(A1),arg(A1,0),A1r); return true; } if(!term_common(diff1r) && !term_common(diffBl)){ ast A2l = apply_diff(LitB,arg(A1,0),arg(A1,1),arg(A2,0)); merged = make(op(A1),A2l,arg(A2,1)); return true; } return false; } void collect_A_rewrites(const ast &t, std::vector &res){ if(is_true(t)) return; if(sym(t) == concat){ res.push_back(arg(t,0)); collect_A_rewrites(arg(t,0),res); return; } res.push_back(t); } ast concat_A_rewrites(const std::vector &rew){ if(rew.size() == 0) return mk_true(); ast res = rew[0]; for(unsigned i = 1; i < rew.size(); i++) res = make(concat,res,rew[i]); return res; } ast merge_concat_rewrites(const ast &A1, const ast &A2){ std::vector rew; collect_A_rewrites(A1,rew); int first = rew.size(), last = first; // range that might need merging collect_A_rewrites(A2,rew); while(first > 0 && first < (int)rew.size() && first <= last){ ast merged; if(merge_A_rewrites(rew[first-1],rew[first],merged)){ rew[first] = merged; first--; rew.erase(rew.begin()+first); last--; if(first >= last) last = first+1; } else first++; } return concat_A_rewrites(rew); } ast sep_cond(const ast &t, ast &cond){ if(op(t) == Implies){ cond = my_and(cond,arg(t,0)); return arg(t,1); } return t; } /* operations on term positions */ /** Finds the difference between two positions. If p1 < p2 (p1 is a position below p2), returns -1 and sets diff to p2-p1 (the psath from position p2 to position p1). If p2 < p1 (p2 is a position below p1), returns 1 and sets diff to p1-p2 (the psath from position p1 to position p2). If equal, return 0 and set diff to top_pos. Else (if p1 and p2 are independent) returns 2 and leaves diff unchanged. */ int pos_diff(const ast &p1, const ast &p2, ast &diff){ if(p1 == top_pos && p2 != top_pos){ diff = p2; return 1; } if(p2 == top_pos && p1 != top_pos){ diff = p1; return -1; } if(p1 == top_pos && p2 == top_pos){ diff = p1; return 0; } if(arg(p1,0) == arg(p2,0)) // same argument position, recur return pos_diff(arg(p1,1),arg(p2,1),diff); return 2; } /* return the position of pos in the argth argument */ ast pos_add(int arg, const ast &pos){ return make(add_pos,make_int(rational(arg)),pos); } ast add_pos_to_end(const ast &pos, int i){ if(pos == top_pos) return pos_add(i,pos); return make(add_pos,arg(pos,0),add_pos_to_end(arg(pos,1),i)); } /* return the argument number of position, if not top */ int pos_arg(const ast &pos){ rational r; if(is_numeral(arg(pos,0),r)) return r.get_unsigned(); throw "bad position!"; } /* substitute y into position pos in x */ ast subst_in_pos(const ast &x, const ast &pos, const ast &y){ if(pos == top_pos) return y; int p = pos_arg(pos); int nargs = num_args(x); if(p >= 0 && p < nargs){ std::vector args(nargs); for(int i = 0; i < nargs; i++) args[i] = i == p ? subst_in_pos(arg(x,i),arg(pos,1),y) : arg(x,i); return clone(x,args); } throw "bad term position!"; } ast diff_chain(LitType t, const ast &pos, const ast &x, const ast &y, const ast &prefix){ int nargs = num_args(x); if(x == y) return prefix; if(sym(x) == sym(y) && nargs == num_args(y)){ ast res = prefix; for(int i = 0; i < nargs; i++) res = diff_chain(t,pos_add(i,pos),arg(x,i),arg(y,i),res); return res; } return chain_cons(prefix,make_rewrite(t,pos,mk_true(),make_equiv_rel(x,y))); } /* operations on rewrites */ ast make_rewrite(LitType t, const ast &pos, const ast &cond, const ast &equality){ #if 0 if(pos == top_pos && op(equality) == Iff && !is_true(arg(equality,0))) throw "bad rewrite"; #endif if(!is_equivrel(equality)) throw "bad rewrite"; return make(t == LitA ? rewrite_A : rewrite_B, pos, cond, equality); } ast rewrite_pos(const ast &rew){ return arg(rew,0); } ast rewrite_cond(const ast &rew){ return arg(rew,1); } ast rewrite_equ(const ast &rew){ return arg(rew,2); } ast rewrite_lhs(const ast &rew){ return arg(arg(rew,2),0); } ast rewrite_rhs(const ast &rew){ return arg(arg(rew,2),1); } /* operations on rewrite chains */ ast chain_cons(const ast &chain, const ast &elem){ return make(concat,chain,elem); } ast chain_rest(const ast &chain){ return arg(chain,0); } ast chain_last(const ast &chain){ return arg(chain,1); } ast rewrite_update_rhs(const ast &rew, const ast &pos, const ast &new_rhs, const ast &new_cond){ ast foo = subst_in_pos(rewrite_rhs(rew),pos,new_rhs); ast equality = arg(rew,2); return make(sym(rew),rewrite_pos(rew),my_and(rewrite_cond(rew),new_cond),make(op(equality),arg(equality,0),foo)); } ast rewrite_update_lhs(const ast &rew, const ast &pos, const ast &new_lhs, const ast &new_cond){ ast foo = subst_in_pos(rewrite_lhs(rew),pos,new_lhs); ast equality = arg(rew,2); return make(sym(rew),rewrite_pos(rew),my_and(rewrite_cond(rew),new_cond),make(op(equality),foo,arg(equality,1))); } bool is_common_rewrite(const ast &rew){ return term_common(arg(rew,2)); } bool is_right_mover(const ast &rew){ return term_common(rewrite_lhs(rew)) && !term_common(rewrite_rhs(rew)); } bool is_left_mover(const ast &rew){ return term_common(rewrite_rhs(rew)) && !term_common(rewrite_lhs(rew)); } bool same_side(const ast &rew1, const ast &rew2){ return sym(rew1) == sym(rew2); } bool is_rewrite_side(LitType t, const ast &rew){ if(t == LitA) return sym(rew) == rewrite_A; return sym(rew) == rewrite_B; } LitType rewrite_side(const ast &rew){ return (sym(rew) == rewrite_A) ? LitA : LitB; } ast rewrite_to_formula(const ast &rew){ return my_implies(arg(rew,1),arg(rew,2)); } // make rewrite rew conditon on rewrite cond ast rewrite_conditional(const ast &cond, const ast &rew){ ast cf = rewrite_to_formula(cond); return make(sym(rew),arg(rew,0),my_and(arg(rew,1),cf),arg(rew,2)); } ast reverse_rewrite(const ast &rew){ ast equ = arg(rew,2); return make(sym(rew),arg(rew,0),arg(rew,1),make(op(equ),arg(equ,1),arg(equ,0))); } ast rewrite_pos_add(int apos, const ast &rew){ return make(sym(rew),pos_add(apos,arg(rew,0)),arg(rew,1),arg(rew,2)); } ast rewrite_pos_set(const ast &pos, const ast &rew){ return make(sym(rew),pos,arg(rew,1),arg(rew,2)); } ast rewrite_up(const ast &rew){ return make(sym(rew),arg(arg(rew,0),1),arg(rew,1),arg(rew,2)); } /** Adds a rewrite to a chain of rewrites, keeping the chain in normal form. An empty chain is represented by true.*/ ast add_rewrite_to_chain(const ast &chain, const ast &rewrite){ if(is_true(chain)) return chain_cons(chain,rewrite); ast last = chain_last(chain); ast rest = chain_rest(chain); if(same_side(last,rewrite)){ ast p1 = rewrite_pos(last); ast p2 = rewrite_pos(rewrite); ast diff; switch(pos_diff(p1,p2,diff)){ case 1: { ast absorb = rewrite_update_rhs(last,diff,rewrite_rhs(rewrite),rewrite_cond(rewrite)); return add_rewrite_to_chain(rest,absorb); } case 0: case -1: { ast absorb = rewrite_update_lhs(rewrite,diff,rewrite_lhs(last),rewrite_cond(last)); return add_rewrite_to_chain(rest,absorb); } default: {// independent case bool rm = is_right_mover(last); bool lm = is_left_mover(rewrite); if((lm && !rm) || (rm && !lm)) return chain_swap(rest,last,rewrite); } } } else { if(is_left_mover(rewrite)){ if(is_common_rewrite(last)) return add_rewrite_to_chain(chain_cons(rest,flip_rewrite(last)),rewrite); if(!is_left_mover(last)) return chain_swap(rest,last,rewrite); } if(is_right_mover(last)){ if(is_common_rewrite(rewrite)) return add_rewrite_to_chain(chain,flip_rewrite(rewrite)); if(!is_right_mover(rewrite)) return chain_swap(rest,last,rewrite); } } return chain_cons(chain,rewrite); } ast chain_swap(const ast &rest, const ast &last, const ast &rewrite){ return chain_cons(add_rewrite_to_chain(rest,rewrite),last); } ast flip_rewrite(const ast &rew){ symb flip_sym = (sym(rew) == rewrite_A) ? rewrite_B : rewrite_A; ast cf = rewrite_to_formula(rew); return make(flip_sym,arg(rew,0),my_implies(arg(rew,1),cf),arg(rew,2)); } /** concatenates two rewrite chains, keeping result in normal form. */ ast concat_rewrite_chain(const ast &chain1, const ast &chain2){ if(is_true(chain2)) return chain1; if(is_true(chain1)) return chain2; ast foo = concat_rewrite_chain(chain1,chain_rest(chain2)); return add_rewrite_to_chain(foo,chain_last(chain2)); } /** reverse a chain of rewrites */ ast reverse_chain_rec(const ast &chain, const ast &prefix){ if(is_true(chain)) return prefix; ast last = reverse_rewrite(chain_last(chain)); ast rest = chain_rest(chain); return reverse_chain_rec(rest,chain_cons(prefix,last)); } ast reverse_chain(const ast &chain){ return reverse_chain_rec(chain,mk_true()); } bool is_equivrel_chain(const ast &chain){ if(is_true(chain)) return true; ast last = chain_last(chain); ast rest = chain_rest(chain); if(is_true(rest)) return !is_true(rewrite_lhs(last)); return is_equivrel_chain(rest); } bool is_negation_chain(const ast &chain){ if(is_true(chain)) return false; ast last = chain_last(chain); ast rest = chain_rest(chain); if(is_true(rest)) return op(rewrite_rhs(last)) == Not; return is_negation_chain(rest); } ast commute_negation_chain(const ast &chain){ if(is_true(chain)) return chain; ast last = chain_last(chain); ast rest = chain_rest(chain); if(is_true(rest)){ ast old = rewrite_rhs(last); if(!(op(old) == Not)) throw "bad negative equality chain"; ast equ = arg(old,0); if(!is_equivrel(equ)) throw "bad negative equality chain"; last = rewrite_update_rhs(last,top_pos,make(Not,make(op(equ),arg(equ,1),arg(equ,0))),make(True)); return chain_cons(rest,last); } ast pos = rewrite_pos(last); if(pos == top_pos) throw "bad negative equality chain"; int idx = pos_arg(pos); if(idx != 0) throw "bad negative equality chain"; pos = arg(pos,1); if(pos == top_pos){ ast lhs = rewrite_lhs(last); ast rhs = rewrite_rhs(last); if(op(lhs) != Equal || op(rhs) != Equal) throw "bad negative equality chain"; last = make_rewrite(rewrite_side(last),rewrite_pos(last),rewrite_cond(last), make(Iff,make(Equal,arg(lhs,1),arg(lhs,0)),make(Equal,arg(rhs,1),arg(rhs,0)))); } else { idx = pos_arg(pos); if(idx == 0) idx = 1; else if(idx == 1) idx = 0; else throw "bad negative equality chain"; pos = pos_add(0,pos_add(idx,arg(pos,1))); last = make_rewrite(rewrite_side(last),pos,rewrite_cond(last),rewrite_equ(last)); } return chain_cons(commute_negation_chain(rest),last); } // split a rewrite chain into head and tail at last top-level rewrite ast get_head_chain(const ast &chain, ast &tail, bool is_not = true){ ast last = chain_last(chain); ast rest = chain_rest(chain); ast pos = rewrite_pos(last); if(pos == top_pos || (is_not && arg(pos,1) == top_pos)){ tail = mk_true(); return chain; } if(is_true(rest)) throw "bad rewrite chain"; ast head = get_head_chain(rest,tail,is_not); tail = chain_cons(tail,last); return head; } bool has_mixed_summands(const ast &e){ if(op(e) == Plus){ int nargs = num_args(e); for(int i = 0; i < nargs; i++) if(has_mixed_summands(arg(e,i))) return true; return false; } return get_term_type(e) == LitMixed; } // split a rewrite chain into head and tail at last sum with no mixed sumands ast get_right_movers(const ast &chain, const ast &rhs, ast &tail, ast &mid){ if(is_true(chain) || !has_mixed_summands(rhs)){ mid = rhs; tail = mk_true(); return chain; } ast last = chain_last(chain); ast rest = chain_rest(chain); ast mm = subst_in_pos(rhs,rewrite_pos(last),rewrite_lhs(last)); ast res = get_right_movers(rest,mm,tail,mid); tail = chain_cons(tail,last); return res; } // split a rewrite chain into head and tail at first sum with no mixed sumands ast get_left_movers(const ast &chain, const ast &lhs, ast &tail, ast &mid){ if(is_true(chain)){ mid = lhs; if(!has_mixed_summands(lhs)){ tail = mk_true(); return chain; } return ast(); } ast last = chain_last(chain); ast rest = chain_rest(chain); ast res = get_left_movers(rest,lhs,tail,mid); if(res.null()){ mid = subst_in_pos(mid,rewrite_pos(last),rewrite_rhs(last)); if(get_term_type(mid) != LitMixed){ tail = mk_true(); return chain; } return ast(); } tail = chain_cons(tail,last); return res; } struct cannot_split {}; /** Split a chain of rewrites two chains, operating on positions 0 and 1. Fail if any rewrite in the chain operates on top position. */ void split_chain_rec(const ast &chain, ast *res){ if(is_true(chain)) return; ast last = chain_last(chain); ast rest = chain_rest(chain); split_chain_rec(rest,res); ast pos = rewrite_pos(last); if(pos == top_pos){ if(rewrite_lhs(last) == rewrite_rhs(last)) return; // skip if it's a noop throw cannot_split(); } int arg = pos_arg(pos); if(arg<0 || arg > 1) throw cannot_split(); res[arg] = chain_cons(res[arg],rewrite_up(last)); } void split_chain(const ast &chain, ast *res){ res[0] = res[1] = mk_true(); split_chain_rec(chain,res); } ast extract_rewrites(const ast &chain, const ast &pos){ if(is_true(chain)) return chain; ast last = chain_last(chain); ast rest = chain_rest(chain); ast new_rest = extract_rewrites(rest,pos); ast p1 = rewrite_pos(last); ast diff; switch(pos_diff(p1,pos,diff)){ case -1: { ast new_last = rewrite_pos_set(diff, last); return chain_cons(new_rest,new_last); } case 1: if(rewrite_lhs(last) != rewrite_rhs(last)) throw "bad rewrite chain"; break; default:; } return new_rest; } ast down_chain(const ast &chain){ ast split[2]; split_chain(chain,split); return split[0]; } ast chain_conditions(LitType t, const ast &chain){ if(is_true(chain)) return mk_true(); ast last = chain_last(chain); ast rest = chain_rest(chain); ast cond = chain_conditions(t,rest); if(is_rewrite_side(t,last)) cond = my_and(cond,rewrite_cond(last)); return cond; } ast chain_formulas(LitType t, const ast &chain){ if(is_true(chain)) return mk_true(); ast last = chain_last(chain); ast rest = chain_rest(chain); ast cond = chain_formulas(t,rest); if(is_rewrite_side(t,last)) cond = my_and(cond,rewrite_equ(last)); return cond; } bool rewrites_from_to(const ast &chain, const ast &lhs, const ast &rhs){ if(is_true(chain)) return lhs == rhs; ast last = chain_last(chain); ast rest = chain_rest(chain); ast mid = subst_in_pos(rhs,rewrite_pos(last),rewrite_lhs(last)); return rewrites_from_to(rest,lhs,mid); } struct bad_ineq_inference {}; ast chain_ineqs(opr comp_op, LitType t, const ast &chain, const ast &lhs, const ast &rhs){ if(is_true(chain)){ if(lhs != rhs) throw bad_ineq_inference(); return make(Leq,make_int(rational(0)),make_int(rational(0))); } ast last = chain_last(chain); ast rest = chain_rest(chain); ast mid = subst_in_pos(rhs,rewrite_pos(last),rewrite_lhs(last)); ast cond = chain_ineqs(comp_op,t,rest,lhs,mid); if(is_rewrite_side(t,last)){ ast diff; if(comp_op == Leq) diff = make(Sub,rhs,mid); else diff = make(Sub,mid,rhs); ast foo = make(Leq,make_int("0"),z3_simplify(diff)); if(is_true(cond)) cond = foo; else { linear_comb(cond,make_int(rational(1)),foo); cond = simplify_ineq(cond); } } return cond; } ast ineq_to_lhs(const ast &ineq){ ast s = make(Leq,make_int(rational(0)),make_int(rational(0))); linear_comb(s,make_int(rational(1)),ineq); return simplify_ineq(s); } void eq_from_ineq(const ast &ineq, ast &lhs, ast &rhs){ // ast s = ineq_to_lhs(ineq); // ast srhs = arg(s,1); ast srhs = arg(ineq,0); if(op(srhs) == Plus && num_args(srhs) == 2 && arg(ineq,1) == make_int(rational(0))){ lhs = arg(srhs,0); rhs = arg(srhs,1); // if(op(lhs) == Times) // std::swap(lhs,rhs); if(op(rhs) == Times){ rhs = arg(rhs,1); // if(op(ineq) == Leq) // std::swap(lhs,rhs); return; } } if(op(ineq) == Leq || op(ineq) == Geq){ lhs = srhs; rhs = arg(ineq,1); return; } throw "bad ineq"; } ast chain_pos_add(int arg, const ast &chain){ if(is_true(chain)) return mk_true(); ast last = rewrite_pos_add(arg,chain_last(chain)); ast rest = chain_pos_add(arg,chain_rest(chain)); return chain_cons(rest,last); } ast apply_rewrite_chain(const ast &t, const ast &chain){ if(is_true(chain)) return t; ast last = chain_last(chain); ast rest = chain_rest(chain); ast mid = apply_rewrite_chain(t,rest); ast res = subst_in_pos(mid,rewrite_pos(last),rewrite_rhs(last)); return res; } ast drop_rewrites(LitType t, const ast &chain, ast &remainder){ if(!is_true(chain)){ ast last = chain_last(chain); ast rest = chain_rest(chain); if(is_rewrite_side(t,last)){ ast res = drop_rewrites(t,rest,remainder); remainder = chain_cons(remainder,last); return res; } } remainder = mk_true(); return chain; } // Normalization chains ast cons_normal(const ast &first, const ast &rest){ return make(normal_chain,first,rest); } ast normal_first(const ast &t){ return arg(t,0); } ast normal_rest(const ast &t){ return arg(t,1); } ast normal_lhs(const ast &t){ return arg(arg(t,0),0); } ast normal_rhs(const ast &t){ return arg(arg(t,0),1); } ast normal_proof(const ast &t){ return arg(t,1); } ast make_normal_step(const ast &lhs, const ast &rhs, const ast &proof){ return make(normal_step,make_equiv(lhs,rhs),proof); } ast make_normal(const ast &ineq, const ast &nrml){ if(!is_ineq(ineq)) throw "what?"; return make(normal,ineq,nrml); } ast fix_normal(const ast &lhs, const ast &rhs, const ast &proof){ LitType lhst = get_term_type(lhs); LitType rhst = get_term_type(rhs); if(lhst == LitMixed && (rhst != LitMixed || ast_id(lhs) < ast_id(rhs))) return make_normal_step(lhs,rhs,proof); if(rhst == LitMixed && (lhst != LitMixed || ast_id(rhs) < ast_id(lhs))) return make_normal_step(rhs,lhs,reverse_chain(proof)); throw "help!"; } ast chain_side_proves(LitType side, const ast &chain){ LitType other_side = side == LitA ? LitB : LitA; return my_and(chain_conditions(other_side,chain),my_implies(chain_conditions(side,chain),chain_formulas(side,chain))); } // Merge two normalization chains ast merge_normal_chains_rec(const ast &chain1, const ast &chain2, hash_map &trans, ast &Aproves, ast &Bproves){ if(is_true(chain1)) return chain2; if(is_true(chain2)) return chain1; ast f1 = normal_first(chain1); ast f2 = normal_first(chain2); ast lhs1 = normal_lhs(f1); ast lhs2 = normal_lhs(f2); int id1 = ast_id(lhs1); int id2 = ast_id(lhs2); if(id1 < id2) return cons_normal(f1,merge_normal_chains_rec(normal_rest(chain1),chain2,trans,Aproves,Bproves)); if(id2 < id1) return cons_normal(f2,merge_normal_chains_rec(chain1,normal_rest(chain2),trans,Aproves,Bproves)); ast rhs1 = normal_rhs(f1); ast rhs2 = normal_rhs(f2); LitType t1 = get_term_type(rhs1); LitType t2 = get_term_type(rhs2); int tid1 = ast_id(rhs1); int tid2 = ast_id(rhs2); ast pf1 = normal_proof(f1); ast pf2 = normal_proof(f2); ast new_normal; if(t1 == LitMixed && (t2 != LitMixed || tid2 > tid1)){ ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); new_normal = f2; trans[rhs1] = make_normal_step(rhs1,rhs2,new_proof); } else if(t2 == LitMixed && (t1 != LitMixed || tid1 > tid2)) return merge_normal_chains_rec(chain2,chain1,trans,Aproves,Bproves); else if(t1 == LitA && t2 == LitB){ ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); ast Bproof, Aproof = drop_rewrites(LitB,new_proof,Bproof); ast mcA = chain_side_proves(LitB,Aproof); Bproves = my_and(Bproves,mcA); ast mcB = chain_side_proves(LitA,Bproof); Aproves = my_and(Aproves,mcB); ast rep = apply_rewrite_chain(rhs1,Aproof); new_proof = concat_rewrite_chain(pf1,Aproof); new_normal = make_normal_step(lhs1,rep,new_proof); ast A_normal = make_normal_step(rhs1,rep,Aproof); ast res = cons_normal(new_normal,merge_normal_chains_rec(normal_rest(chain1),normal_rest(chain2),trans,Aproves,Bproves)); res = merge_normal_chains_rec(res,cons_normal(A_normal,make(True)),trans,Aproves,Bproves); return res; } else if(t1 == LitB && t2 == LitA) return merge_normal_chains_rec(chain2,chain1,trans,Aproves,Bproves); else if(t1 == LitA) { ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); ast mc = chain_side_proves(LitB,new_proof); Bproves = my_and(Bproves,mc); new_normal = f1; // choice is arbitrary } else { /* t1 = t2 = LitB */ ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); ast mc = chain_side_proves(LitA,new_proof); Aproves = my_and(Aproves,mc); new_normal = f1; // choice is arbitrary } return cons_normal(new_normal,merge_normal_chains_rec(normal_rest(chain1),normal_rest(chain2),trans,Aproves,Bproves)); } ast trans_normal_chain(const ast &chain, hash_map &trans){ if(is_true(chain)) return chain; ast f = normal_first(chain); ast r = normal_rest(chain); r = trans_normal_chain(r,trans); ast rhs = normal_rhs(f); hash_map::iterator it = trans.find(rhs); ast new_normal; if(it != trans.end() && get_term_type(normal_lhs(f)) == LitMixed){ const ast &f2 = it->second; ast pf = concat_rewrite_chain(normal_proof(f),normal_proof(f2)); new_normal = make_normal_step(normal_lhs(f),normal_rhs(f2),pf); } else new_normal = f; if(get_term_type(normal_lhs(f)) == LitMixed) trans[normal_lhs(f)] = new_normal; return cons_normal(new_normal,r); } ast merge_normal_chains(const ast &chain1, const ast &chain2, ast &Aproves, ast &Bproves){ hash_map trans; ast res = merge_normal_chains_rec(chain1,chain2,trans,Aproves,Bproves); res = trans_normal_chain(res,trans); return res; } bool destruct_cond_ineq(ast t, ast &Aproves, ast &Bproves, ast&ineq){ if(op(t) == And){ Aproves = arg(t,0); t = arg(t,1); } else Aproves = mk_true(); if(op(t) == Implies){ Bproves = arg(t,0); t = arg(t,1); } else Bproves = mk_true(); if(is_normal_ineq(t)){ ineq = t; return true; } return false; } ast cons_cond_ineq(const ast &Aproves, const ast &Bproves, const ast &ineq){ return my_and(Aproves,my_implies(Bproves,ineq)); } ast normalize(const ast &ct){ ast Aproves,Bproves,t; if(!destruct_cond_ineq(ct,Aproves,Bproves,t)) return ct; if(sym(t) != normal) return ct; ast chain = arg(t,1); hash_map map; for(ast c = chain; !is_true(c); c = normal_rest(c)){ ast first = normal_first(c); ast lhs = normal_lhs(first); ast rhs = normal_rhs(first); map[lhs] = rhs; } ast res = subst(map,arg(t,0)); return cons_cond_ineq(Aproves,Bproves,res); } /** Make an assumption node. The given clause is assumed in the given frame. */ virtual node make_assumption(int frame, const std::vector &assumption){ if(!weak){ if(pv->in_range(frame,rng)){ std::vector itp_clause; for(unsigned i = 0; i < assumption.size(); i++) if(get_term_type(assumption[i]) != LitA) itp_clause.push_back(assumption[i]); ast res = my_or(itp_clause); return res; } else { return mk_true(); } } else { if(pv->in_range(frame,rng)){ return mk_false(); } else { std::vector itp_clause; for(unsigned i = 0; i < assumption.size(); i++) if(get_term_type(assumption[i]) != LitB) itp_clause.push_back(assumption[i]); ast res = my_or(itp_clause); return mk_not(res); } } } ast make_local_rewrite(LitType t, const ast &p){ ast rew = is_equivrel(p) ? p : make(Iff,mk_true(),p); #if 0 if(op(rew) == Iff && !is_true(arg(rew,0))) return diff_chain(t,top_pos,arg(rew,0),arg(rew,1), mk_true()); #endif return chain_cons(mk_true(),make_rewrite(t, top_pos, mk_true(), rew)); } ast triv_interp(const symb &rule, const std::vector &premises, int mask_in){ std::vector ps; ps.resize(premises.size()); std::vector conjs; int mask = 0; for(unsigned i = 0; i < ps.size(); i++){ ast p = premises[i]; LitType t = get_term_type(p); switch(t){ case LitA: case LitB: ps[i] = make_local_rewrite(t,p); break; default: ps[i] = get_placeholder(p); // can only prove consequent! if(mask_in & (1 << i)) mask |= (1 << conjs.size()); conjs.push_back(p); } } ast ref = make(rule,ps); ast res = make_contra_node(ref,conjs,mask); return res; } ast triv_interp(const symb &rule, const ast &p0, const ast &p1, const ast &p2, int mask){ std::vector ps; ps.resize(3); ps[0] = p0; ps[1] = p1; ps[2] = p2; return triv_interp(rule,ps,mask); } /** Make a modus-ponens node. This takes derivations of |- x and |- x = y and produces |- y */ virtual node make_mp(const ast &p_eq_q, const ast &prem1, const ast &prem2){ /* Interpolate the axiom p, p=q -> q */ ast p = arg(p_eq_q,0); ast q = arg(p_eq_q,1); ast itp; if(get_term_type(p_eq_q) == LitMixed){ int mask = 1 << 2; if(op(p) == Not && is_equivrel(arg(p,0))) mask |= 1; // we may need to run this rule backward if first premise is negative equality itp = triv_interp(modpon,p,p_eq_q,mk_not(q),mask); } else { if(get_term_type(p) == LitA){ if(get_term_type(q) == LitA) itp = mk_false(); else { if(get_term_type(p_eq_q) == LitA) itp = q; else throw proof_error(); } } else { if(get_term_type(q) == LitA){ if(get_term_type(make(Equal,p,q)) == LitA) itp = mk_not(p); else throw proof_error(); } else itp = mk_true(); } } /* Resolve it with the premises */ std::vector conc; conc.push_back(q); conc.push_back(mk_not(p_eq_q)); itp = make_resolution(p,conc,itp,prem1); conc.pop_back(); itp = make_resolution(p_eq_q,conc,itp,prem2); return itp; } ast capture_localization(ast e){ // #define CAPTURE_LOCALIZATION #ifdef CAPTURE_LOCALIZATION for(int i = localization_vars.size() - 1; i >= 0; i--){ LocVar &lv = localization_vars[i]; if(occurs_in(lv.var,e)){ symb q = (pv->in_range(lv.frame,rng)) ? sexists : sforall; e = make(q,make(Equal,lv.var,lv.term),e); // use Equal because it is polymorphic } } #endif return e; } /** Make an axiom node. The conclusion must be an instance of an axiom. */ virtual node make_axiom(const std::vector &conclusion, prover::range frng){ int nargs = conclusion.size(); std::vector largs(nargs); std::vector eqs; std::vector pfs; for(int i = 0; i < nargs; i++){ ast argpf; ast lit = conclusion[i]; largs[i] = localize_term(lit,frng,argpf); frng = pv->range_glb(frng,pv->ast_scope(largs[i])); if(largs[i] != lit){ eqs.push_back(make_equiv(largs[i],lit)); pfs.push_back(argpf); } } int frame = pv->range_max(frng); ast itp = make_assumption(frame,largs); for(unsigned i = 0; i < eqs.size(); i++) itp = make_mp(eqs[i],itp,pfs[i]); return capture_localization(itp); } virtual node make_axiom(const std::vector &conclusion){ return make_axiom(conclusion,pv->range_full()); } /** Make a Contra node. This rule takes a derivation of the form Gamma |- False and produces |- \/~Gamma. */ virtual node make_contra(node prem, const std::vector &conclusion){ return prem; } /** Make hypothesis. Creates a node of the form P |- P. */ virtual node make_hypothesis(const ast &P){ if(is_not(P)) return make_hypothesis(arg(P,0)); switch(get_term_type(P)){ case LitA: return mk_false(); case LitB: return mk_true(); default: // mixed hypothesis switch(op(P)){ case Geq: case Leq: case Gt: case Lt: { ast zleqz = make(Leq,make_int("0"),make_int("0")); ast fark1 = make(sum,zleqz,make_int("1"),get_placeholder(P)); ast fark2 = make(sum,fark1,make_int("1"),get_placeholder(mk_not(P))); ast res = make(And,make(contra,fark2,mk_false()), make(contra,get_placeholder(mk_not(P)),P), make(contra,get_placeholder(P),mk_not(P))); return res; } default: { ast em = make(exmid,P,get_placeholder(P),get_placeholder(mk_not(P))); ast res = make(And,make(contra,em,mk_false()), make(contra,get_placeholder(mk_not(P)),P), make(contra,get_placeholder(P),mk_not(P))); return res; } } } } /** Make a Reflexivity node. This rule produces |- x = x */ virtual node make_reflexivity(ast con){ if(get_term_type(con) == LitA) return mk_false(); if(get_term_type(con) == LitB) return mk_true(); ast itp = make(And,make(contra,no_proof,mk_false()), make(contra,mk_true(),mk_not(con))); return itp; } /** Make a Symmetry node. This takes a derivation of |- x = y and produces | y = x. Ditto for ~(x=y) */ virtual node make_symmetry(ast con, const ast &premcon, node prem){ #if 0 ast x = arg(con,0); ast y = arg(con,1); ast p = make(op(con),y,x); #endif if(get_term_type(con) != LitMixed) return prem; // symmetry shmymmetry... ast em = make(exmid,con,make(symm,get_placeholder(premcon)),get_placeholder(mk_not(con))); ast itp = make(And,make(contra,em,mk_false()), make(contra,make(symm,get_placeholder(mk_not(con))),premcon), make(contra,make(symm,get_placeholder(premcon)),mk_not(con))); std::vector conc; conc.push_back(con); itp = make_resolution(premcon,conc,itp,prem); return itp; } ast make_equiv_rel(const ast &x, const ast &y){ if(is_bool_type(get_type(x))) return make(Iff,x,y); return make(Equal,x,y); } /** Make a transitivity node. This takes derivations of |- x = y and |- y = z produces | x = z */ virtual node make_transitivity(const ast &x, const ast &y, const ast &z, node prem1, node prem2){ /* Interpolate the axiom x=y,y=z,-> x=z */ ast p = make_equiv_rel(x,y); ast q = make_equiv_rel(y,z); ast r = make_equiv_rel(x,z); ast equiv = make(Iff,p,r); ast itp; itp = make_congruence(q,equiv,prem2); itp = make_mp(equiv,prem1,itp); return itp; } /** Make a congruence node. This takes derivations of |- x_i = y_i and produces |- f(x_1,...,x_n) = f(y_1,...,y_n) */ virtual node make_congruence(const ast &p, const ast &con, const ast &prem1){ ast x = arg(p,0), y = arg(p,1); ast itp; LitType con_t = get_term_type(con); if(get_term_type(p) == LitA){ if(con_t == LitA) itp = mk_false(); else if(con_t == LitB) itp = p; else itp = make_mixed_congruence(x, y, p, con, prem1); } else { if(con_t == LitA) itp = mk_not(p); else{ if(con_t == LitB) itp = mk_true(); else itp = make_mixed_congruence(x, y, p, con, prem1); } } std::vector conc; conc.push_back(con); itp = make_resolution(p,conc,itp,prem1); return itp; } int find_congruence_position(const ast &p, const ast &con){ // find the argument position of x and y const ast &x = arg(p,0); const ast &y = arg(p,1); int nargs = num_args(arg(con,0)); for(int i = 0; i < nargs; i++) if(x == arg(arg(con,0),i) && y == arg(arg(con,1),i)) return i; throw proof_error(); } /** Make a congruence node. This takes derivations of |- x_i1 = y_i1, |- x_i2 = y_i2,... and produces |- f(...x_i1...x_i2...) = f(...y_i1...y_i2...) */ node make_congruence(const std::vector &p, const ast &con, const std::vector &prems){ if(p.size() == 0) throw proof_error(); if(p.size() == 1) return make_congruence(p[0],con,prems[0]); ast thing = con; ast res = mk_true(); for(unsigned i = 0; i < p.size(); i++){ int pos = find_congruence_position(p[i],thing); ast next = subst_in_arg_pos(pos,arg(p[i],1),arg(thing,0)); ast goal = make(op(thing),arg(thing,0),next); ast equa = make_congruence(p[i],goal,prems[i]); if(i == 0) res = equa; else { ast trace = make(op(con),arg(con,0),arg(thing,0)); ast equiv = make(Iff,trace,make(op(trace),arg(trace,0),next)); ast foo = make_congruence(goal,equiv,equa); res = make_mp(equiv,res,foo); } thing = make(op(thing),next,arg(thing,1)); } return res; } /* Interpolate a mixed congruence axiom. */ virtual ast make_mixed_congruence(const ast &x, const ast &y, const ast &p, const ast &con, const ast &prem1){ ast foo = p; std::vector conjs; LitType t = get_term_type(foo); switch(t){ case LitA: case LitB: foo = make_local_rewrite(t,foo); break; case LitMixed: conjs.push_back(foo); foo = get_placeholder(foo); } // find the argument position of x and y int pos = -1; int nargs = num_args(arg(con,0)); for(int i = 0; i < nargs; i++) if(x == arg(arg(con,0),i) && y == arg(arg(con,1),i)) pos = i; if(pos == -1) throw proof_error(); ast bar = make(cong,foo,make_int(rational(pos)),get_placeholder(mk_not(con))); conjs.push_back(mk_not(con)); return make_contra_node(bar,conjs); } ast subst_in_arg_pos(int pos, ast term, ast app){ std::vector args; get_args(app,args); args[pos] = term; return clone(app,args); } /** Make a farkas proof node. */ virtual node make_farkas(ast con, const std::vector &prems, const std::vector &prem_cons, const std::vector &coeffs){ /* Compute the interpolant for the clause */ ast zero = make_int("0"); std::vector conjs; ast thing = make(Leq,zero,zero); for(unsigned i = 0; i < prem_cons.size(); i++){ const ast &lit = prem_cons[i]; if(get_term_type(lit) == LitA) // Farkas rule seems to assume strict integer inequalities are rounded linear_comb(thing,coeffs[i],lit,true /*round_off*/); } thing = simplify_ineq(thing); for(unsigned i = 0; i < prem_cons.size(); i++){ const ast &lit = prem_cons[i]; if(get_term_type(lit) == LitMixed){ thing = make(sum,thing,coeffs[i],get_placeholder(lit)); conjs.push_back(lit); } } thing = make_contra_node(thing,conjs); /* Resolve it with the premises */ std::vector conc; conc.resize(prem_cons.size()); for(unsigned i = 0; i < prem_cons.size(); i++) conc[prem_cons.size()-i-1] = prem_cons[i]; for(unsigned i = 0; i < prem_cons.size(); i++){ thing = make_resolution(prem_cons[i],conc,thing,prems[i]); conc.pop_back(); } return thing; } /** Set P to P + cQ, where P and Q are linear inequalities. Assumes P is 0 <= y or 0 < y. */ void linear_comb(ast &P, const ast &c, const ast &Q, bool round_off = false){ ast Qrhs; bool qstrict = false; if(is_not(Q)){ ast nQ = arg(Q,0); switch(op(nQ)){ case Gt: Qrhs = make(Sub,arg(nQ,1),arg(nQ,0)); break; case Lt: Qrhs = make(Sub,arg(nQ,0),arg(nQ,1)); break; case Geq: Qrhs = make(Sub,arg(nQ,1),arg(nQ,0)); qstrict = true; break; case Leq: Qrhs = make(Sub,arg(nQ,0),arg(nQ,1)); qstrict = true; break; default: throw proof_error(); } } else { switch(op(Q)){ case Leq: Qrhs = make(Sub,arg(Q,1),arg(Q,0)); break; case Geq: Qrhs = make(Sub,arg(Q,0),arg(Q,1)); break; case Lt: Qrhs = make(Sub,arg(Q,1),arg(Q,0)); qstrict = true; break; case Gt: Qrhs = make(Sub,arg(Q,0),arg(Q,1)); qstrict = true; break; default: throw proof_error(); } } #if 0 bool pstrict = op(P) == Lt, strict = pstrict || qstrict; if(pstrict && qstrict && round_off) Qrhs = make(Sub,Qrhs,make_int(rational(1))); #else bool pstrict = op(P) == Lt; if(qstrict && round_off && (pstrict || !(c == make_int(rational(1))))){ Qrhs = make(Sub,Qrhs,make_int(rational(1))); qstrict = false; } Qrhs = make(Times,c,Qrhs); bool strict = pstrict || qstrict; #endif if(strict) P = make(Lt,arg(P,0),make(Plus,arg(P,1),Qrhs)); else P = make(Leq,arg(P,0),make(Plus,arg(P,1),Qrhs)); } /* Make an axiom instance of the form |- x<=y, y<= x -> x =y */ virtual node make_leq2eq(ast x, ast y, const ast &xleqy, const ast &yleqx){ ast con = make(Equal,x,y); ast itp; switch(get_term_type(con)){ case LitA: itp = mk_false(); break; case LitB: itp = mk_true(); break; default: { // mixed equality if(get_term_type(x) == LitMixed || get_term_type(y) == LitMixed){ // std::cerr << "WARNING: mixed term in leq2eq\n"; std::vector lits; lits.push_back(con); lits.push_back(make(Not,xleqy)); lits.push_back(make(Not,yleqx)); return make_axiom(lits); } std::vector conjs; conjs.resize(3); conjs[0] = mk_not(con); conjs[1] = xleqy; conjs[2] = yleqx; itp = make_contra_node(make(leq2eq, get_placeholder(mk_not(con)), get_placeholder(xleqy), get_placeholder(yleqx)), conjs,1); } } return itp; } /* Make an axiom instance of the form |- x = y -> x <= y */ virtual node make_eq2leq(ast x, ast y, const ast &xleqy){ ast itp; switch(get_term_type(xleqy)){ case LitA: itp = mk_false(); break; case LitB: itp = mk_true(); break; default: { // mixed equality std::vector conjs; conjs.resize(2); conjs[0] = make(Equal,x,y); conjs[1] = mk_not(xleqy); itp = make(eq2leq,get_placeholder(conjs[0]),get_placeholder(conjs[1])); itp = make_contra_node(itp,conjs,2); } } return itp; } /* Make an inference of the form t <= c |- t/d <= floor(c/d) where t is an affine term divisble by d and c is an integer constant */ virtual node make_cut_rule(const ast &tleqc, const ast &d, const ast &con, const ast &prem){ ast itp = mk_false(); switch(get_term_type(con)){ case LitA: itp = mk_false(); break; case LitB: itp = mk_true(); break; default: { std::vector conjs; conjs.resize(2); conjs[0] = tleqc; conjs[1] = mk_not(con); itp = make(sum,get_placeholder(conjs[0]),d,get_placeholder(conjs[1])); itp = make_contra_node(itp,conjs); } } std::vector conc; conc.push_back(con); itp = make_resolution(tleqc,conc,itp,prem); return itp; } // create a fresh variable for localization ast fresh_localization_var(const ast &term, int frame){ std::ostringstream s; s << "%" << (localization_vars.size()); ast var = make_var(s.str().c_str(),get_type(term)); pv->sym_range(sym(var)) = pv->range_full(); // make this variable global localization_vars.push_back(LocVar(var,term,frame)); return var; } struct LocVar { // localization vars ast var; // a fresh variable ast term; // term it represents int frame; // frame in which it's defined LocVar(ast v, ast t, int f){var=v;term=t;frame=f;} }; std::vector localization_vars; // localization vars in order of creation hash_map localization_map; // maps terms to their localization vars hash_map localization_pf_map; // maps terms to proofs of their localizations /* "localize" a term e to a given frame range, creating new symbols to represent non-local subterms. This returns the localized version e_l, as well as a proof thet e_l = l. */ ast make_refl(const ast &e){ if(get_term_type(e) == LitA) return mk_false(); return mk_true(); // TODO: is this right? } ast make_equiv(const ast &x, const ast &y){ if(get_type(x) == bool_type()) return make(Iff,x,y); else return make(Equal,x,y); } ast localize_term(ast e, const prover::range &rng, ast &pf){ ast orig_e = e; pf = make_refl(e); // proof that e = e prover::range erng = pv->ast_scope(e); #if 0 if(!(erng.lo > erng.hi) && pv->ranges_intersect(pv->ast_scope(e),rng)){ return e; // this term occurs in range, so it's O.K. } #endif hash_map::iterator it = localization_map.find(e); if(it != localization_map.end() && is_bool_type(get_type(e)) && !pv->ranges_intersect(pv->ast_scope(it->second),rng)) it = localization_map.end(); // prevent quantifiers over booleans if(it != localization_map.end()){ pf = localization_pf_map[e]; e = it->second; } else { // if it is non-local, we must first localize the arguments to // the range of its function symbol int nargs = num_args(e); if(nargs > 0 /* && (!is_local(e) || flo <= hi || fhi >= lo) */){ prover::range frng = rng; opr o = op(e); if(o == Uninterpreted){ symb f = sym(e); prover::range srng = pv->sym_range(f); if(pv->ranges_intersect(srng,rng)) // localize to desired range if possible frng = pv->range_glb(srng,rng); else frng = srng; // this term will be localized } else if(o == Plus || o == Times){ // don't want bound variables inside arith ops // std::cout << "WARNING: non-local arithmetic\n"; // frng = erng; // this term will be localized } else if(o == Select){ // treat the array term like a function symbol prover::range srng = pv->ast_scope(arg(e,0)); if(!(srng.lo > srng.hi) && pv->ranges_intersect(srng,rng)) // localize to desired range if possible frng = pv->range_glb(srng,rng); else frng = srng; // this term will be localized } std::vector largs(nargs); std::vector eqs; std::vector pfs; for(int i = 0; i < nargs; i++){ ast argpf; largs[i] = localize_term(arg(e,i),frng,argpf); frng = pv->range_glb(frng,pv->ast_scope(largs[i])); if(largs[i] != arg(e,i)){ eqs.push_back(make_equiv(largs[i],arg(e,i))); pfs.push_back(argpf); } } e = clone(e,largs); if(pfs.size()) pf = make_congruence(eqs,make_equiv(e,orig_e),pfs); // assert(is_local(e)); } localization_pf_map[orig_e] = pf; localization_map[orig_e] = e; } if(pv->ranges_intersect(pv->ast_scope(e),rng)) return e; // this term occurs in range, so it's O.K. if(is_array_type(get_type(e))) std::cerr << "WARNING: array quantifier\n"; // choose a frame for the constraint that is close to range int frame = pv->range_near(pv->ast_scope(e),rng); ast new_var = fresh_localization_var(e,frame); localization_map[orig_e] = new_var; std::vector foo; foo.push_back(make_equiv(new_var,e)); ast bar = make_assumption(frame,foo); pf = make_transitivity(new_var,e,orig_e,bar,pf); localization_pf_map[orig_e] = pf; return new_var; } ast delete_quant(hash_map &memo, const ast &v, const ast &e){ std::pair foo(e,ast()); std::pair::iterator,bool> bar = memo.insert(foo); ast &res = bar.first->second; if(bar.second){ opr o = op(e); switch(o){ case Or: case And: case Implies: { unsigned nargs = num_args(e); std::vector args; args.resize(nargs); for(unsigned i = 0; i < nargs; i++) args[i] = delete_quant(memo, v, arg(e,i)); res = make(o,args); break; } case Uninterpreted: { symb s = sym(e); ast w = arg(arg(e,0),0); if(s == sforall || s == sexists){ res = delete_quant(memo,v,arg(e,1)); if(w != v) res = make(s,w,res); break; } } default: res = e; } } return res; } ast insert_quants(hash_map &memo, const ast &e){ std::pair foo(e,ast()); std::pair::iterator,bool> bar = memo.insert(foo); ast &res = bar.first->second; if(bar.second){ opr o = op(e); switch(o){ case Or: case And: case Implies: { unsigned nargs = num_args(e); std::vector args; args.resize(nargs); for(unsigned i = 0; i < nargs; i++) args[i] = insert_quants(memo, arg(e,i)); res = make(o,args); break; } case Uninterpreted: { symb s = sym(e); if(s == sforall || s == sexists){ opr q = (s == sforall) ? Forall : Exists; ast v = arg(arg(e,0),0); hash_map dmemo; ast body = delete_quant(dmemo,v,arg(e,1)); body = insert_quants(memo,body); res = apply_quant(q,v,body); break; } } default: res = e; } } return res; } ast add_quants(ast e){ #ifdef CAPTURE_LOCALIZATION if(!localization_vars.empty()){ hash_map memo; e = insert_quants(memo,e); } #else for(int i = localization_vars.size() - 1; i >= 0; i--){ LocVar &lv = localization_vars[i]; opr quantifier = (pv->in_range(lv.frame,rng)) ? Exists : Forall; e = apply_quant(quantifier,lv.var,e); } #endif return e; } node make_resolution(ast pivot, node premise1, node premise2) { std::vector lits; return make_resolution(pivot,lits,premise1,premise2); } /* Return an interpolant from a proof of false */ ast interpolate(const node &pf){ // proof of false must be a formula, with quantified symbols #ifndef BOGUS_QUANTS return add_quants(z3_simplify(pf)); #else return z3_simplify(pf); #endif } ast resolve_with_quantifier(const ast &pivot1, const ast &conj1, const ast &pivot2, const ast &conj2){ if(is_not(arg(pivot1,1))) return resolve_with_quantifier(pivot2,conj2,pivot1,conj1); ast eqpf; ast P = arg(pivot1,1); ast Ploc = localize_term(P, rng, eqpf); ast pPloc = make_hypothesis(Ploc); ast pP = make_mp(make(Iff,Ploc,P),pPloc,eqpf); ast rP = make_resolution(P,conj1,pP); ast nP = mk_not(P); ast nPloc = mk_not(Ploc); ast neqpf = make_congruence(make(Iff,Ploc,P),make(Iff,nPloc,nP),eqpf); ast npPloc = make_hypothesis(nPloc); ast npP = make_mp(make(Iff,nPloc,nP),npPloc,neqpf); ast nrP = make_resolution(nP,conj2,npP); ast res = make_resolution(Ploc,rP,nrP); return capture_localization(res); } ast get_contra_coeff(const ast &f){ ast c = arg(f,0); // if(!is_not(arg(f,1))) // c = make(Uminus,c); return c; } ast my_or(const ast &a, const ast &b){ return mk_or(a,b); } ast my_and(const ast &a, const ast &b){ return mk_and(a,b); } ast my_implies(const ast &a, const ast &b){ return mk_implies(a,b); } ast my_or(const std::vector &a){ return mk_or(a); } ast my_and(const std::vector &a){ return mk_and(a); } ast get_lit_atom(const ast &l){ if(op(l) == Not) return arg(l,0); return l; } bool is_placeholder(const ast &e){ if(op(e) == Uninterpreted){ std::string name = string_of_symbol(sym(e)); if(name.size() > 2 && name[0] == '@' && name[1] == 'p') return true; } return false; } public: iz3proof_itp_impl(prover *p, const prover::range &r, bool w) : iz3proof_itp(*p) { pv = p; rng = r; weak = false ; //w; type boolintbooldom[3] = {bool_type(),int_type(),bool_type()}; type booldom[1] = {bool_type()}; type boolbooldom[2] = {bool_type(),bool_type()}; type boolboolbooldom[3] = {bool_type(),bool_type(),bool_type()}; type intbooldom[2] = {int_type(),bool_type()}; contra = function("@contra",2,boolbooldom,bool_type()); m().inc_ref(contra); sum = function("@sum",3,boolintbooldom,bool_type()); m().inc_ref(sum); rotate_sum = function("@rotsum",2,boolbooldom,bool_type()); m().inc_ref(rotate_sum); leq2eq = function("@leq2eq",3,boolboolbooldom,bool_type()); m().inc_ref(leq2eq); eq2leq = function("@eq2leq",2,boolbooldom,bool_type()); m().inc_ref(eq2leq); cong = function("@cong",3,boolintbooldom,bool_type()); m().inc_ref(cong); exmid = function("@exmid",3,boolboolbooldom,bool_type()); m().inc_ref(exmid); symm = function("@symm",1,booldom,bool_type()); m().inc_ref(symm); epsilon = make_var("@eps",int_type()); modpon = function("@mp",3,boolboolbooldom,bool_type()); m().inc_ref(modpon); no_proof = make_var("@nop",bool_type()); concat = function("@concat",2,boolbooldom,bool_type()); m().inc_ref(concat); top_pos = make_var("@top_pos",bool_type()); add_pos = function("@add_pos",2,intbooldom,bool_type()); m().inc_ref(add_pos); rewrite_A = function("@rewrite_A",3,boolboolbooldom,bool_type()); m().inc_ref(rewrite_A); rewrite_B = function("@rewrite_B",3,boolboolbooldom,bool_type()); m().inc_ref(rewrite_B); normal_step = function("@normal_step",2,boolbooldom,bool_type()); m().inc_ref(normal_step); normal_chain = function("@normal_chain",2,boolbooldom,bool_type()); m().inc_ref(normal_chain); normal = function("@normal",2,boolbooldom,bool_type()); m().inc_ref(normal); sforall = function("@sforall",2,boolbooldom,bool_type()); m().inc_ref(sforall); sexists = function("@sexists",2,boolbooldom,bool_type()); m().inc_ref(sexists); } ~iz3proof_itp_impl(){ m().dec_ref(contra); m().dec_ref(sum); m().dec_ref(rotate_sum); m().dec_ref(leq2eq); m().dec_ref(eq2leq); m().dec_ref(cong); m().dec_ref(exmid); m().dec_ref(symm); m().dec_ref(modpon); m().dec_ref(concat); m().dec_ref(add_pos); m().dec_ref(rewrite_A); m().dec_ref(rewrite_B); } }; iz3proof_itp *iz3proof_itp::create(prover *p, const prover::range &r, bool w){ return new iz3proof_itp_impl(p,r,w); }