/*++
Copyright (c) 2011 Microsoft Corporation

Module Name:

    iz3proof.cpp

Abstract:

    This class defines a simple interpolating proof system.

Author:

    Ken McMillan (kenmcmil)

Revision History:

--*/

#include "iz3proof_itp.h"

#ifndef WIN32
using namespace stl_ext;
#endif

// #define INVARIANT_CHECKING

class iz3proof_itp_impl : public iz3proof_itp {

  // Prover used to query term scopes (ast_scope) and range relations.
  prover *pv;

  // Reference range: get_term_type classifies a term whose scope is
  // contained in rng as LitA and anything else (non-empty) as LitB.
  prover::range rng;

  // Selects the "weak" classification branch in get_term_type.
  bool weak;

  enum LitType {LitA,LitB,LitMixed};

  // Maps a term to the fresh "@p" placeholder constant that stands for it
  // (populated by get_placeholder).
  hash_map<ast,ast> placeholders;

  // These symbols represent deduction rules

  /* This symbol represents a proof by contradiction. That is,
     contra(p,l1 /\ ... /\ lk) takes a proof p of

     l1,...,lk |- false

     and returns a proof of

     |- ~l1,...,~lk
  */
  symb contra;

  /* The summation rule. The term sum(p,c,i) takes a proof p of an
     inequality i', an integer coefficient c and an inequality i, and
     yields a proof of i' + ci. */
  symb sum;

  /* Proof rotation. The proof term rotate_sum(q,p) takes a
     proof p of:

     Gamma, q |- false

     and yields a proof of:

     Gamma |- ~q
  */
  symb rotate_sum;

  /* Inequalities to equality. leq2eq(p, q, r) takes a proof
     p of ~x=y, a proof q of x <= y and a proof r of y <= x
     and yields a proof of false. */
  symb leq2eq;

  /* Equality to inequality. eq2leq(p, q) takes a proof p of x=y, and
     a proof q of ~(x <= y) and yields a proof of false. */
  symb eq2leq;

  /* Proof term cong(p,q) takes a proof p of x=y and a proof
     q of t != t and returns a proof of false.
     NOTE(review): simplify_rotate_cong reads three arguments (an
     argument-position numeral in slot 1 and q in slot 2), i.e.
     cong(p,pos,q) -- confirm and update this description. */
  symb cong;

  /* Excluded middle. exmid(phi,p,q) takes a proof p of phi and a
     proof q of ~\phi and returns a proof of false. */
  symb exmid;

  /* Symmetry. symm(p) takes a proof p of x=y and produces
     a proof of y=x. */
  symb symm;

  /* Modus ponens. modpon(p,e,q) takes proofs p of P, e of P=Q
     and q of ~Q and returns a proof of false. */
  symb modpon;

  /* This operator represents a concatenation of rewrites. The term
     a=b;c=d represents an A rewrite from a to b, followed by a B
     rewrite from b to c, followed by an A rewrite from c to d.
*/ symb concat; /* This represents a lack of a proof */ ast no_proof; // This is used to represent an infinitessimal value ast epsilon; // Represents the top position of a term ast top_pos; // add_pos(i,pos) represents position pos if the ith argument symb add_pos; // rewrite proof rules /* rewrite_A(pos,cond,x=y) derives A |- cond => t[x]_p = t[y]_p where t is an arbitrary term */ symb rewrite_A; /* rewrite_B(pos,cond,x=y) derives B |- cond => t[x]_p = t[y]_p, where t is an arbitrary term */ symb rewrite_B; ast get_placeholder(ast t){ hash_map::iterator it = placeholders.find(t); if(it != placeholders.end()) return it->second; ast &res = placeholders[t]; res = mk_fresh_constant("@p",get_type(t)); #if 0 std::cout << "placeholder "; print_expr(std::cout,res); std::cout << " = "; print_expr(std::cout,t); std::cout << std::endl; #endif return res; } ast make_contra_node(const ast &pf, const std::vector &lits, int pfok = -1){ if(lits.size() == 0) return pf; std::vector reslits; reslits.push_back(make(contra,pf,mk_false())); for(unsigned i = 0; i < lits.size(); i++){ ast bar; if(pfok & (1 << i)) bar = make(rotate_sum,lits[i],pf); else bar = no_proof; ast foo = make(contra,bar,lits[i]); reslits.push_back(foo); } return make(And,reslits); } LitType get_term_type(const ast &lit){ prover::range r = pv->ast_scope(lit); if(pv->range_is_empty(r)) return LitMixed; if(weak) { if(pv->range_min(r) == SHRT_MIN) return pv->range_contained(r,rng) ? LitA : LitB; else return pv->ranges_intersect(r,rng) ? LitA : LitB; } else return pv->range_contained(r,rng) ? LitA : LitB; } bool term_common(const ast &t){ prover::range r = pv->ast_scope(t); return pv->ranges_intersect(r,rng) && !pv->range_contained(r,rng); } bool term_in_vocab(LitType ty, const ast &lit){ prover::range r = pv->ast_scope(lit); if(ty == LitA){ return pv->ranges_intersect(r,rng); } return !pv->range_contained(r,rng); } /** Make a resolution node with given pivot literal and premises. 
The conclusion of premise1 should contain the negation of the pivot literal, while the conclusion of premise2 should contain the pivot literal. */ node make_resolution(ast pivot, const std::vector &conc, node premise1, node premise2) { LitType lt = get_term_type(pivot); if(lt == LitA) return my_or(premise1,premise2); if(lt == LitB) return my_and(premise1,premise2); /* the mixed case is a bit complicated */ static int non_local_count = 0; ast res = resolve_arith(pivot,conc,premise1,premise2); #ifdef INVARIANT_CHECKING check_contra(conc,res); #endif non_local_count++; return res; } /* Handles the case of resolution on a mixed arith atom. */ ast resolve_arith(const ast &pivot, const std::vector &conc, node premise1, node premise2){ ast atom = get_lit_atom(pivot); hash_map memo; ast neg_pivot_lit = mk_not(atom); if(op(pivot) != Not) std::swap(premise1,premise2); return resolve_arith_rec1(memo, neg_pivot_lit, premise1, premise2); } ast apply_coeff(const ast &coeff, const ast &t){ #if 0 rational r; if(!is_integer(coeff,r)) throw "ack!"; ast n = make_int(r.numerator()); ast res = make(Times,n,t); if(!r.is_int()) { ast d = make_int(r.numerator()); res = mk_idiv(res,d); } return res; #endif return make(Times,coeff,t); } ast sum_ineq(const ast &coeff1, const ast &ineq1, const ast &coeff2, const ast &ineq2){ opr sum_op = Leq; if(op(ineq1) == Lt || op(ineq2) == Lt) sum_op = Lt; ast sum_sides[2]; for(int i = 0; i < 2; i++){ sum_sides[i] = make(Plus,apply_coeff(coeff1,arg(ineq1,i)),apply_coeff(coeff2,arg(ineq2,i))); sum_sides[i] = z3_simplify(sum_sides[i]); } return make(sum_op,sum_sides[0],sum_sides[1]); } void collect_contra_resolvents(int from, const ast &pivot1, const ast &pivot, const ast &conj, std::vector &res){ int nargs = num_args(conj); for(int i = from; i < nargs; i++){ ast f = arg(conj,i); if(!(f == pivot)){ ast ph = get_placeholder(mk_not(arg(pivot1,1))); ast pf = arg(pivot1,0); ast thing = pf == no_proof ? 
no_proof : subst_term_and_simp(ph,pf,arg(f,0)); ast newf = make(contra,thing,arg(f,1)); res.push_back(newf); } } } bool is_negative_equality(const ast &e){ if(op(e) == Not){ opr o = op(arg(e,0)); return o == Equal || o == Iff; } return false; } int count_negative_equalities(const std::vector &resolvent){ int res = 0; for(unsigned i = 0; i < resolvent.size(); i++) if(is_negative_equality(arg(resolvent[i],1))) res++; return res; } ast resolve_contra_nf(const ast &pivot1, const ast &conj1, const ast &pivot2, const ast &conj2){ std::vector resolvent; collect_contra_resolvents(0,pivot1,pivot2,conj2,resolvent); collect_contra_resolvents(1,pivot2,pivot1,conj1,resolvent); if(count_negative_equalities(resolvent) > 1) throw proof_error(); if(resolvent.size() == 1) // we have proved a contradiction return simplify(arg(resolvent[0],0)); // this is the proof -- get interpolant return make(And,resolvent); } ast resolve_contra(const ast &pivot1, const ast &conj1, const ast &pivot2, const ast &conj2){ if(arg(pivot1,0) != no_proof) return resolve_contra_nf(pivot1, conj1, pivot2, conj2); if(arg(pivot2,0) != no_proof) return resolve_contra_nf(pivot2, conj2, pivot1, conj1); return resolve_with_quantifier(pivot1, conj1, pivot2, conj2); } bool is_contra_itp(const ast &pivot1, ast itp2, ast &pivot2){ if(op(itp2) == And){ int nargs = num_args(itp2); for(int i = 1; i < nargs; i++){ ast foo = arg(itp2,i); if(op(foo) == Uninterpreted && sym(foo) == contra){ if(arg(foo,1) == pivot1){ pivot2 = foo; return true; } } else break; } } return false; } ast resolve_arith_rec2(hash_map &memo, const ast &pivot1, const ast &conj1, const ast &itp2){ ast &res = memo[itp2]; if(!res.null()) return res; ast pivot2; if(is_contra_itp(mk_not(arg(pivot1,1)),itp2,pivot2)) res = resolve_contra(pivot1,conj1,pivot2,itp2); else { switch(op(itp2)){ case Or: case And: case Implies: { unsigned nargs = num_args(itp2); std::vector args; args.resize(nargs); for(unsigned i = 0; i < nargs; i++) args[i] = 
resolve_arith_rec2(memo, pivot1, conj1, arg(itp2,i)); ast foo = itp2; // get rid of const res = clone(foo,args); break; } default: res = itp2; } } return res; } ast resolve_arith_rec1(hash_map &memo, const ast &neg_pivot_lit, const ast &itp1, const ast &itp2){ ast &res = memo[itp1]; if(!res.null()) return res; ast pivot1; if(is_contra_itp(neg_pivot_lit,itp1,pivot1)){ hash_map memo2; res = resolve_arith_rec2(memo2,pivot1,itp1,itp2); } else { switch(op(itp1)){ case Or: case And: case Implies: { unsigned nargs = num_args(itp1); std::vector args; args.resize(nargs); for(unsigned i = 0; i < nargs; i++) args[i] = resolve_arith_rec1(memo, neg_pivot_lit, arg(itp1,i), itp2); ast foo = itp1; // get rid of const res = clone(foo,args); break; } default: res = itp1; } } return res; } void check_contra(hash_set &memo, hash_set &neg_lits, const ast &foo){ if(memo.find(foo) != memo.end()) return; memo.insert(foo); if(op(foo) == Uninterpreted && sym(foo) == contra){ ast neg_lit = arg(foo,1); if(!is_false(neg_lit) && neg_lits.find(neg_lit) == neg_lits.end()) throw "lost a literal"; return; } else { switch(op(foo)){ case Or: case And: case Implies: { unsigned nargs = num_args(foo); std::vector args; args.resize(nargs); for(unsigned i = 0; i < nargs; i++) check_contra(memo, neg_lits, arg(foo,i)); break; } default: break; } } } void check_contra(const std::vector &neg_lits, const ast &foo){ hash_set memo; hash_set neg_lits_set; for(unsigned i = 0; i < neg_lits.size(); i++) if(get_term_type(neg_lits[i]) == LitMixed) neg_lits_set.insert(mk_not(neg_lits[i])); check_contra(memo,neg_lits_set,foo); } hash_map subst_memo; // memo of subst function ast subst_term_and_simp(const ast &var, const ast &t, const ast &e){ subst_memo.clear(); return subst_term_and_simp_rec(var,t,e); } ast subst_term_and_simp_rec(const ast &var, const ast &t, const ast &e){ if(e == var) return t; std::pair foo(e,ast()); std::pair::iterator,bool> bar = subst_memo.insert(foo); ast &res = bar.first->second; 
if(bar.second){ if(sym(e) == rotate_sum && var == get_placeholder(arg(e,0))){ res = e; return res; } int nargs = num_args(e); std::vector args(nargs); for(int i = 0; i < nargs; i++) args[i] = subst_term_and_simp_rec(var,t,arg(e,i)); opr f = op(e); if(f == Equal && args[0] == args[1]) res = mk_true(); else if(f == And) res = my_and(args); else if(f == Or) res = my_or(args); else if(f == Idiv) res = mk_idiv(args[0],args[1]); else res = clone(e,args); } return res; } /* This is where the real work happens. Here, we simplify the proof obtained by cut elimination, obtaining an interpolant. */ struct cannot_simplify {}; hash_map simplify_memo; ast simplify(const ast &t){ return simplify_rec(t); } ast simplify_rec(const ast &e){ std::pair foo(e,ast()); std::pair::iterator,bool> bar = simplify_memo.insert(foo); ast &res = bar.first->second; if(bar.second){ int nargs = num_args(e); std::vector args(nargs); bool placeholder_arg = false; for(int i = 0; i < nargs; i++){ args[i] = simplify_rec(arg(e,i)); placeholder_arg |= is_placeholder(args[i]); } try { opr f = op(e); if(f == Equal && args[0] == args[1]) res = mk_true(); else if(f == And) res = my_and(args); else if(f == Or) res = my_or(args); else if(f == Idiv) res = mk_idiv(args[0],args[1]); else if(f == Uninterpreted && !placeholder_arg){ symb g = sym(e); if(g == rotate_sum) res = simplify_rotate(args); else if(g == symm) res = simplify_symm(args); else if(g == modpon) res = simplify_modpon(args); else if(g == sum) res = simplify_sum(args); #if 0 else if(g == cong) res = simplify_cong(args); else if(g == modpon) res = simplify_modpon(args); else if(g == leq2eq) res = simplify_leq2eq(args); else if(g == eq2leq) res = simplify_eq2leq(args); #endif else res = clone(e,args); } else res = clone(e,args); } catch (const cannot_simplify &){ res = clone(e,args); } } return res; } ast simplify_rotate(const std::vector &args){ const ast &pf = args[1]; ast pl = get_placeholder(args[0]); if(op(pf) == Uninterpreted){ symb g = sym(pf); 
if(g == sum) return simplify_rotate_sum(pl,pf); if(g == leq2eq) return simplify_rotate_leq2eq(pl,args[0],pf); if(g == eq2leq) return simplify_rotate_eq2leq(pl,args[0],pf); if(g == cong) return simplify_rotate_cong(pl,args[0],pf); if(g == modpon) return simplify_rotate_modpon(pl,args[0],pf); // if(g == symm) return simplify_rotate_symm(pl,args[0],pf); } if(op(pf) == Leq) throw "foo!"; throw cannot_simplify(); } ast simplify_sum(std::vector &args){ ast cond = mk_true(); ast ineq = args[0]; if(!is_ineq(ineq)) throw cannot_simplify(); sum_cond_ineq(ineq,cond,args[1],args[2]); return my_implies(cond,ineq); } ast simplify_rotate_sum(const ast &pl, const ast &pf){ ast cond = mk_true(); ast ineq = make(Leq,make_int("0"),make_int("0")); return rotate_sum_rec(pl,pf,cond,ineq); } bool is_rewrite_chain(const ast &chain){ return sym(chain) == concat; } ast ineq_from_chain(const ast &chain, ast &cond){ if(is_rewrite_chain(chain)){ ast last = chain_last(chain); ast rest = chain_rest(chain); if(is_true(rest) && is_rewrite_side(LitA,last) && is_true(rewrite_lhs(last))){ cond = my_and(cond,rewrite_cond(last)); return rewrite_rhs(last); } if(is_rewrite_side(LitB,last) && is_true(rewrite_cond(last))) return ineq_from_chain(rest,cond); } return chain; } void sum_cond_ineq(ast &ineq, ast &cond, const ast &coeff2, const ast &ineq2){ opr o = op(ineq2); if(o == Implies){ sum_cond_ineq(ineq,cond,coeff2,arg(ineq2,1)); cond = my_and(cond,arg(ineq2,0)); } else { ast the_ineq = ineq_from_chain(ineq2,cond); if(is_ineq(the_ineq)) linear_comb(ineq,coeff2,the_ineq); else throw cannot_simplify(); } } bool is_ineq(const ast &ineq){ opr o = op(ineq); if(o == Not) o = op(arg(ineq,0)); return o == Leq || o == Lt || o == Geq || o == Gt; } // divide both sides of inequality by a non-negative integer divisor ast idiv_ineq(const ast &ineq1, const ast &divisor){ if(divisor == make_int(rational(1))) return ineq1; ast ineq = ineq1; if(op(ineq) == Lt) ineq = 
simplify_ineq(make(Leq,arg(ineq,0),make(Sub,arg(ineq,1),make_int("1")))); return make(op(ineq),mk_idiv(arg(ineq,0),divisor),mk_idiv(arg(ineq,1),divisor)); } ast rotate_sum_rec(const ast &pl, const ast &pf, ast &cond, ast &ineq){ if(pf == pl) return my_implies(cond,simplify_ineq(ineq)); if(op(pf) == Uninterpreted && sym(pf) == sum){ if(arg(pf,2) == pl){ sum_cond_ineq(ineq,cond,make_int("1"),arg(pf,0)); ineq = idiv_ineq(ineq,arg(pf,1)); return my_implies(cond,ineq); } sum_cond_ineq(ineq,cond,arg(pf,1),arg(pf,2)); return rotate_sum_rec(pl,arg(pf,0),cond,ineq); } throw cannot_simplify(); } ast simplify_rotate_leq2eq(const ast &pl, const ast &neg_equality, const ast &pf){ if(pl == arg(pf,0)){ ast equality = arg(neg_equality,0); ast x = arg(equality,0); ast y = arg(equality,1); ast cond1 = mk_true(); ast xleqy = round_ineq(ineq_from_chain(arg(pf,1),cond1)); ast yleqx = round_ineq(ineq_from_chain(arg(pf,2),cond1)); ast ineq1 = make(Leq,make_int("0"),make_int("0")); sum_cond_ineq(ineq1,cond1,make_int("-1"),xleqy); sum_cond_ineq(ineq1,cond1,make_int("-1"),yleqx); cond1 = my_and(cond1,z3_simplify(ineq1)); ast cond2 = mk_true(); ast ineq2 = make(Leq,make_int("0"),make_int("0")); sum_cond_ineq(ineq2,cond2,make_int("1"),xleqy); sum_cond_ineq(ineq2,cond2,make_int("1"),yleqx); cond2 = z3_simplify(ineq2); if(get_term_type(x) == LitA){ ast iter = z3_simplify(make(Plus,x,get_ineq_rhs(xleqy))); ast rewrite1 = make_rewrite(LitA,top_pos,cond1,make(Equal,x,iter)); ast rewrite2 = make_rewrite(LitB,top_pos,cond2,make(Equal,iter,y)); return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2); } if(get_term_type(y) == LitA){ ast iter = z3_simplify(make(Plus,y,get_ineq_rhs(yleqx))); ast rewrite2 = make_rewrite(LitA,top_pos,cond1,make(Equal,iter,y)); ast rewrite1 = make_rewrite(LitB,top_pos,cond2,make(Equal,x,iter)); return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2); } throw cannot_simplify(); } throw cannot_simplify(); } ast round_ineq(const ast &ineq){ if(!is_ineq(ineq)) throw 
cannot_simplify(); ast res = simplify_ineq(ineq); if(op(res) == Lt) res = make(Leq,arg(res,0),make(Sub,arg(res,1),make_int("1"))); return res; } ast simplify_rotate_eq2leq(const ast &pl, const ast &neg_equality, const ast &pf){ if(pl == arg(pf,1)){ ast cond = mk_true(); ast equa = sep_cond(arg(pf,0),cond); if(is_equivrel_chain(equa)){ ast lhs,rhs; eq_from_ineq(arg(neg_equality,0),lhs,rhs); // get inequality we need to prove ast ineqs= chain_ineqs(op(arg(neg_equality,0)),LitA,equa,lhs,rhs); // chain must be from lhs to rhs cond = my_and(cond,chain_conditions(LitA,equa)); ast Bconds = chain_conditions(LitB,equa); if(is_true(Bconds) && op(ineqs) != And) return my_implies(cond,ineqs); } } throw cannot_simplify(); } void reverse_modpon(std::vector &args){ std::vector sargs(1); sargs[0] = args[1]; args[1] = simplify_symm(sargs); if(is_equivrel_chain(args[2])) args[1] = down_chain(args[1]); std::swap(args[0],args[2]); } ast simplify_rotate_modpon(const ast &pl, const ast &neg_equality, const ast &pf){ std::vector args; args.resize(3); args[0] = arg(pf,0); args[1] = arg(pf,1); args[2] = arg(pf,2); if(pl == args[0]) reverse_modpon(args); if(pl == args[2]){ ast cond = mk_true(); ast chain = simplify_modpon_fwd(args, cond); return my_implies(cond,chain); } throw cannot_simplify(); } ast get_ineq_rhs(const ast &ineq2){ opr o = op(ineq2); if(o == Implies) return get_ineq_rhs(arg(ineq2,1)); else if(o == Leq || o == Lt) return arg(ineq2,1); throw cannot_simplify(); } ast simplify_rotate_cong(const ast &pl, const ast &neg_equality, const ast &pf){ if(pl == arg(pf,2)){ if(op(arg(pf,0)) == True) return mk_true(); rational pos; if(is_numeral(arg(pf,1),pos)){ int ipos = pos.get_unsigned(); ast cond = mk_true(); ast equa = sep_cond(arg(pf,0),cond); #if 0 if(op(equa) == Equal){ ast pe = mk_not(neg_equality); ast lhs = subst_in_arg_pos(ipos,arg(equa,0),arg(pe,0)); ast rhs = subst_in_arg_pos(ipos,arg(equa,1),arg(pe,1)); ast res = make(Equal,lhs,rhs); return my_implies(cond,res); } #endif 
ast res = chain_pos_add(ipos,equa); return my_implies(cond,res); } } throw cannot_simplify(); } ast simplify_symm(const std::vector &args){ if(op(args[0]) == True) return mk_true(); ast cond = mk_true(); ast equa = sep_cond(args[0],cond); if(is_equivrel_chain(equa)) return my_implies(cond,reverse_chain(equa)); throw cannot_simplify(); } ast simplify_modpon_fwd(const std::vector &args, ast &cond){ ast P = sep_cond(args[0],cond); ast PeqQ = sep_cond(args[1],cond); ast chain; if(is_equivrel_chain(P)){ try { ast split[2]; split_chain(PeqQ,split); chain = reverse_chain(split[0]); chain = concat_rewrite_chain(chain,P); chain = concat_rewrite_chain(chain,split[1]); } catch(const cannot_split &){ static int this_count = 0; this_count++; ast tail, pref = get_head_chain(PeqQ,tail,false); // pref is x=y, tail is x=y -> x'=y' ast split[2]; split_chain(tail,split); // rewrites from x to x' and y to y' ast head = chain_last(pref); ast prem = make_rewrite(rewrite_side(head),top_pos,rewrite_cond(head),make(Iff,mk_true(),mk_not(rewrite_lhs(head)))); ast back_chain = chain_cons(mk_true(),prem); back_chain = concat_rewrite_chain(back_chain,chain_pos_add(0,reverse_chain(chain_rest(pref)))); ast cond = contra_chain(back_chain,P); if(is_rewrite_side(LitA,head)) cond = mk_not(cond); ast fwd_rewrite = make_rewrite(rewrite_side(head),top_pos,cond,rewrite_rhs(head)); P = chain_cons(mk_true(),fwd_rewrite); chain = reverse_chain(split[0]); chain = concat_rewrite_chain(chain,P); chain = concat_rewrite_chain(chain,split[1]); } } else // if not an equivalence, must be of form T <-> pred chain = concat_rewrite_chain(P,PeqQ); return chain; } /* Given a chain rewrite chain deriving not P and a rewrite chain deriving P, return an interpolant. */ ast contra_chain(const ast &neg_chain, const ast &pos_chain){ // equality is a special case. 
we use the derivation of x=y to rewrite not(x=y) to not(y=y) if(is_equivrel_chain(pos_chain)){ ast tail, pref = get_head_chain(neg_chain,tail); // pref is not(x=y), tail is not(x,y) -> not(x',y') ast split[2]; split_chain(down_chain(tail),split); // rewrites from x to x' and y to y' ast chain = split[0]; chain = concat_rewrite_chain(chain,pos_chain); // rewrites from x to y' chain = concat_rewrite_chain(chain,reverse_chain(split[1])); // rewrites from x to y chain = concat_rewrite_chain(pref,chain_pos_add(0,chain_pos_add(0,chain))); // rewrites t -> not(y=y) ast head = chain_last(pref); if(is_rewrite_side(LitB,head)){ ast condition = chain_conditions(LitB,chain); return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),condition); } else { ast condition = chain_conditions(LitA,chain); return my_and(chain_conditions(LitB,chain),my_implies(condition,mk_not(chain_formulas(LitB,chain)))); } // ast chain = concat_rewrite_chain(neg_chain,chain_pos_add(0,chain_pos_add(0,pos_chain))); // return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain)); } // otherwise, we reverse the derivation of t = P and use it to rewrite not(P) to not(t) ast chain = concat_rewrite_chain(neg_chain,chain_pos_add(0,reverse_chain(pos_chain))); return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain)); } ast simplify_modpon(const std::vector &args){ ast cond = mk_true(); ast chain = simplify_modpon_fwd(args,cond); ast Q2 = sep_cond(args[2],cond); ast interp = is_negation_chain(chain) ? contra_chain(chain,Q2) : contra_chain(Q2,chain); return my_implies(cond,interp); } bool is_equivrel(const ast &p){ opr o = op(p); return o == Equal || o == Iff; } struct rewrites_failed{}; /* Suppose p in Lang(B) and A |- p -> q and B |- q -> r. Return a z in Lang(B) such that B |- p -> z and A |- z -> q. Collect any side conditions in "rules". 
*/ ast commute_rewrites(const ast &p, const ast &q, const ast &r, ast &rules){ if(q == r) return p; if(p == q) return r; else { ast rew = make(Equal,q,r); if(get_term_type(rew) == LitB){ apply_common_rewrites(p,p,q,rules); // A rewrites must be over comon vocab return r; } } if(sym(p) != sym(q) || sym(q) != sym(r)) throw rewrites_failed(); int nargs = num_args(p); if(nargs != num_args(q) || nargs != num_args(r)) throw rewrites_failed(); std::vector args; args.resize(nargs); for(int i = 0; i < nargs; i++) args[i] = commute_rewrites(arg(p,i),arg(q,i),arg(r,i),rules); return clone(p,args); } ast apply_common_rewrites(const ast &p, const ast &q, const ast &r, ast &rules){ if(q == r) return p; ast rew = make(Equal,q,r); if(term_common(rew)){ if(p != q) throw rewrites_failed(); rules = my_and(rules,rew); return r; } if(sym(p) != sym(q) || sym(q) != sym(r)) return p; int nargs = num_args(p); if(nargs != num_args(q) || nargs != num_args(r)) return p; std::vector args; args.resize(nargs); for(int i = 0; i < nargs; i++) args[i] = apply_common_rewrites(arg(p,i),arg(q,i),arg(r,i),rules); return clone(p,args); } ast apply_all_rewrites(const ast &p, const ast &q, const ast &r){ if(q == r) return p; if(p == q) return r; if(sym(p) != sym(q) || sym(q) != sym(r)) throw rewrites_failed(); int nargs = num_args(p); if(nargs != num_args(q) || nargs != num_args(r)) throw rewrites_failed(); std::vector args; args.resize(nargs); for(int i = 0; i < nargs; i++) args[i] = apply_all_rewrites(arg(p,i),arg(q,i),arg(r,i)); return clone(p,args); } ast delta(const ast &x, const ast &y){ if(op(x) != op(y) || (op(x) == Uninterpreted && sym(x) != sym(y)) || num_args(x) != num_args(y)) return make(Equal,x,y); ast res = mk_true(); int nargs = num_args(x); for(int i = 0; i < nargs; i++) res = my_and(res,delta(arg(x,i),arg(y,i))); return res; } bool diff_rec(LitType t, const ast &p, const ast &q, ast &pd, ast &qd){ if(p == q) return false; if(term_in_vocab(t,p) && term_in_vocab(t,q)){ pd = p; qd = q; 
return true; } else { if(sym(p) != sym(q)) return false; int nargs = num_args(p); if(num_args(q) != nargs) return false; for(int i = 0; i < nargs; i++) if(diff_rec(t,arg(p,i),arg(q,i),pd,qd)) return true; return false; } } void diff(LitType t, const ast &p, const ast &q, ast &pd, ast &qd){ if(!diff_rec(t,p,q,pd,qd)) throw cannot_simplify(); } bool apply_diff_rec(LitType t, const ast &inp, const ast &p, const ast &q, ast &out){ if(p == q) return false; if(term_in_vocab(t,p) && term_in_vocab(t,q)){ if(inp != p) throw cannot_simplify(); out = q; return true; } else { int nargs = num_args(p); if(sym(p) != sym(q)) throw cannot_simplify(); if(num_args(q) != nargs) throw cannot_simplify(); if(sym(p) != sym(inp)) throw cannot_simplify(); if(num_args(inp) != nargs) throw cannot_simplify(); for(int i = 0; i < nargs; i++) if(apply_diff_rec(t,arg(inp,i),arg(p,i),arg(q,i),out)) return true; return false; } } ast apply_diff(LitType t, const ast &inp, const ast &p, const ast &q){ ast out; if(!apply_diff_rec(t,inp,p,q,out)) throw cannot_simplify(); return out; } bool merge_A_rewrites(const ast &A1, const ast &A2, ast &merged) { if(arg(A1,1) == arg(A2,0)){ merged = make(op(A1),arg(A1,0),arg(A2,1)); return true; } ast diff1l, diff1r, diff2l, diff2r,diffBl,diffBr; diff(LitA,arg(A1,0),arg(A1,1),diff1l,diff1r); diff(LitA,arg(A2,0),arg(A2,1),diff2l,diff2r); diff(LitB,arg(A1,1),arg(A2,0),diffBl,diffBr); if(!term_common(diff2l) && !term_common(diffBr)){ ast A1r = apply_diff(LitB,arg(A2,1),arg(A2,0),arg(A1,1)); merged = make(op(A1),arg(A1,0),A1r); return true; } if(!term_common(diff1r) && !term_common(diffBl)){ ast A2l = apply_diff(LitB,arg(A1,0),arg(A1,1),arg(A2,0)); merged = make(op(A1),A2l,arg(A2,1)); return true; } return false; } void collect_A_rewrites(const ast &t, std::vector &res){ if(is_true(t)) return; if(sym(t) == concat){ res.push_back(arg(t,0)); collect_A_rewrites(arg(t,0),res); return; } res.push_back(t); } ast concat_A_rewrites(const std::vector &rew){ if(rew.size() == 0) 
return mk_true(); ast res = rew[0]; for(unsigned i = 1; i < rew.size(); i++) res = make(concat,res,rew[i]); return res; } ast merge_concat_rewrites(const ast &A1, const ast &A2){ std::vector rew; collect_A_rewrites(A1,rew); int first = rew.size(), last = first; // range that might need merging collect_A_rewrites(A2,rew); while(first > 0 && first < (int)rew.size() && first <= last){ ast merged; if(merge_A_rewrites(rew[first-1],rew[first],merged)){ rew[first] = merged; first--; rew.erase(rew.begin()+first); last--; if(first >= last) last = first+1; } else first++; } return concat_A_rewrites(rew); } ast sep_cond(const ast &t, ast &cond){ if(op(t) == Implies){ cond = my_and(cond,arg(t,0)); return arg(t,1); } return t; } /* operations on term positions */ /** Finds the difference between two positions. If p1 < p2 (p1 is a position below p2), returns -1 and sets diff to p2-p1 (the psath from position p2 to position p1). If p2 < p1 (p2 is a position below p1), returns 1 and sets diff to p1-p2 (the psath from position p1 to position p2). If equal, return 0 and set diff to top_pos. Else (if p1 and p2 are independent) returns 2 and leaves diff unchanged. 
*/ int pos_diff(const ast &p1, const ast &p2, ast &diff){ if(p1 == top_pos && p2 != top_pos){ diff = p2; return 1; } if(p2 == top_pos && p1 != top_pos){ diff = p1; return -1; } if(p1 == top_pos && p2 == top_pos){ diff = p1; return 0; } if(arg(p1,0) == arg(p2,0)) // same argument position, recur return pos_diff(arg(p1,1),arg(p2,1),diff); return 2; } /* return the position of pos in the argth argument */ ast pos_add(int arg, const ast &pos){ return make(add_pos,make_int(rational(arg)),pos); } /* return the argument number of position, if not top */ int pos_arg(const ast &pos){ rational r; if(is_numeral(arg(pos,0),r)) return r.get_unsigned(); throw "bad position!"; } /* substitute y into position pos in x */ ast subst_in_pos(const ast &x, const ast &pos, const ast &y){ if(pos == top_pos) return y; int p = pos_arg(pos); int nargs = num_args(x); if(p >= 0 && p < nargs){ std::vector args(nargs); for(int i = 0; i < nargs; i++) args[i] = i == p ? subst_in_pos(arg(x,i),arg(pos,1),y) : arg(x,i); return clone(x,args); } throw "bad term position!"; } ast diff_chain(LitType t, const ast &pos, const ast &x, const ast &y, const ast &prefix){ int nargs = num_args(x); if(x == y) return prefix; if(sym(x) == sym(y) && nargs == num_args(y)){ ast res = prefix; for(int i = 0; i < nargs; i++) res = diff_chain(t,pos_add(i,pos),arg(x,i),arg(y,i),res); return res; } return chain_cons(prefix,make_rewrite(t,pos,mk_true(),make_equiv_rel(x,y))); } /* operations on rewrites */ ast make_rewrite(LitType t, const ast &pos, const ast &cond, const ast &equality){ #if 0 if(pos == top_pos && op(equality) == Iff && !is_true(arg(equality,0))) throw "bad rewrite"; #endif return make(t == LitA ? 
rewrite_A : rewrite_B, pos, cond, equality); } ast rewrite_pos(const ast &rew){ return arg(rew,0); } ast rewrite_cond(const ast &rew){ return arg(rew,1); } ast rewrite_equ(const ast &rew){ return arg(rew,2); } ast rewrite_lhs(const ast &rew){ return arg(arg(rew,2),0); } ast rewrite_rhs(const ast &rew){ return arg(arg(rew,2),1); } /* operations on rewrite chains */ ast chain_cons(const ast &chain, const ast &elem){ return make(concat,chain,elem); } ast chain_rest(const ast &chain){ return arg(chain,0); } ast chain_last(const ast &chain){ return arg(chain,1); } ast rewrite_update_rhs(const ast &rew, const ast &pos, const ast &new_rhs, const ast &new_cond){ ast foo = subst_in_pos(rewrite_rhs(rew),pos,new_rhs); ast equality = arg(rew,2); return make(sym(rew),rewrite_pos(rew),my_and(rewrite_cond(rew),new_cond),make(op(equality),arg(equality,0),foo)); } ast rewrite_update_lhs(const ast &rew, const ast &pos, const ast &new_lhs, const ast &new_cond){ ast foo = subst_in_pos(rewrite_lhs(rew),pos,new_lhs); ast equality = arg(rew,2); return make(sym(rew),rewrite_pos(rew),my_and(rewrite_cond(rew),new_cond),make(op(equality),foo,arg(equality,1))); } bool is_common_rewrite(const ast &rew){ return term_common(arg(rew,2)); } bool is_right_mover(const ast &rew){ return term_common(rewrite_lhs(rew)) && !term_common(rewrite_rhs(rew)); } bool is_left_mover(const ast &rew){ return term_common(rewrite_rhs(rew)) && !term_common(rewrite_lhs(rew)); } bool same_side(const ast &rew1, const ast &rew2){ return sym(rew1) == sym(rew2); } bool is_rewrite_side(LitType t, const ast &rew){ if(t == LitA) return sym(rew) == rewrite_A; return sym(rew) == rewrite_B; } LitType rewrite_side(const ast &rew){ return (sym(rew) == rewrite_A) ? 
LitA : LitB; } ast rewrite_to_formula(const ast &rew){ return my_implies(arg(rew,1),arg(rew,2)); } // make rewrite rew conditon on rewrite cond ast rewrite_conditional(const ast &cond, const ast &rew){ ast cf = rewrite_to_formula(cond); return make(sym(rew),arg(rew,0),my_and(arg(rew,1),cf),arg(rew,2)); } ast reverse_rewrite(const ast &rew){ ast equ = arg(rew,2); return make(sym(rew),arg(rew,0),arg(rew,1),make(op(equ),arg(equ,1),arg(equ,0))); } ast rewrite_pos_add(int apos, const ast &rew){ return make(sym(rew),pos_add(apos,arg(rew,0)),arg(rew,1),arg(rew,2)); } ast rewrite_up(const ast &rew){ return make(sym(rew),arg(arg(rew,0),1),arg(rew,1),arg(rew,2)); } /** Adds a rewrite to a chain of rewrites, keeping the chain in normal form. An empty chain is represented by true.*/ ast add_rewrite_to_chain(const ast &chain, const ast &rewrite){ if(is_true(chain)) return chain_cons(chain,rewrite); ast last = chain_last(chain); ast rest = chain_rest(chain); if(same_side(last,rewrite)){ ast p1 = rewrite_pos(last); ast p2 = rewrite_pos(rewrite); ast diff; switch(pos_diff(p1,p2,diff)){ case 1: { ast absorb = rewrite_update_rhs(last,diff,rewrite_rhs(rewrite),rewrite_cond(rewrite)); return add_rewrite_to_chain(rest,absorb); } case 0: case -1: { ast absorb = rewrite_update_lhs(rewrite,diff,rewrite_lhs(last),rewrite_cond(last)); return add_rewrite_to_chain(rest,absorb); } default: {// independent case bool rm = is_right_mover(last); bool lm = is_left_mover(rewrite); if((lm && !rm) || (rm && !lm)) return chain_swap(rest,last,rewrite); } } } else { if(is_left_mover(rewrite)){ if(is_common_rewrite(last)) return add_rewrite_to_chain(chain_cons(rest,flip_rewrite(last)),rewrite); if(!is_left_mover(last)) return chain_swap(rest,last,rewrite); } if(is_right_mover(last)){ if(is_common_rewrite(rewrite)) return add_rewrite_to_chain(chain,flip_rewrite(rewrite)); if(!is_right_mover(rewrite)) return chain_swap(rest,last,rewrite); } } return chain_cons(chain,rewrite); } ast chain_swap(const ast 
&rest, const ast &last, const ast &rewrite){ return chain_cons(add_rewrite_to_chain(rest,rewrite),last); } ast flip_rewrite(const ast &rew){ symb flip_sym = (sym(rew) == rewrite_A) ? rewrite_B : rewrite_A; ast cf = rewrite_to_formula(rew); return make(flip_sym,arg(rew,0),my_implies(arg(rew,1),cf),arg(rew,2)); } /** concatenates two rewrite chains, keeping result in normal form. */ ast concat_rewrite_chain(const ast &chain1, const ast &chain2){ if(is_true(chain2)) return chain1; if(is_true(chain1)) return chain2; ast foo = concat_rewrite_chain(chain1,chain_rest(chain2)); return add_rewrite_to_chain(foo,chain_last(chain2)); } /** reverse a chain of rewrites */ ast reverse_chain_rec(const ast &chain, const ast &prefix){ if(is_true(chain)) return prefix; ast last = reverse_rewrite(chain_last(chain)); ast rest = chain_rest(chain); return reverse_chain_rec(rest,chain_cons(prefix,last)); } ast reverse_chain(const ast &chain){ return reverse_chain_rec(chain,mk_true()); } bool is_equivrel_chain(const ast &chain){ if(is_true(chain)) return true; ast last = chain_last(chain); ast rest = chain_rest(chain); if(is_true(rest)) return !is_true(rewrite_lhs(last)); return is_equivrel_chain(rest); } bool is_negation_chain(const ast &chain){ if(is_true(chain)) return false; ast last = chain_last(chain); ast rest = chain_rest(chain); if(is_true(rest)) return op(rewrite_rhs(last)) == Not; return is_negation_chain(rest); } // split a rewrite chain into head and tail at last top-level rewrite ast get_head_chain(const ast &chain, ast &tail, bool is_not = true){ ast last = chain_last(chain); ast rest = chain_rest(chain); ast pos = rewrite_pos(last); if(pos == top_pos || (is_not && arg(pos,1) == top_pos)){ tail = mk_true(); return chain; } if(is_true(rest)) throw "bad rewrite chain"; ast head = get_head_chain(rest,tail,is_not); tail = chain_cons(tail,last); return head; } struct cannot_split {}; /** Split a chain of rewrites two chains, operating on positions 0 and 1. 
Fail if any rewrite in the chain operates on top position. */ void split_chain_rec(const ast &chain, ast *res){ if(is_true(chain)) return; ast last = chain_last(chain); ast rest = chain_rest(chain); split_chain_rec(rest,res); ast pos = rewrite_pos(last); if(pos == top_pos){ if(rewrite_lhs(last) == rewrite_rhs(last)) return; // skip if it's a noop throw cannot_split(); } int arg = pos_arg(pos); if(arg<0 || arg > 1) throw cannot_split(); res[arg] = chain_cons(res[arg],rewrite_up(last)); } void split_chain(const ast &chain, ast *res){ res[0] = res[1] = mk_true(); split_chain_rec(chain,res); } ast down_chain(const ast &chain){ ast split[2]; split_chain(chain,split); return split[0]; } ast chain_conditions(LitType t, const ast &chain){ if(is_true(chain)) return mk_true(); ast last = chain_last(chain); ast rest = chain_rest(chain); ast cond = chain_conditions(t,rest); if(is_rewrite_side(t,last)) cond = my_and(cond,rewrite_cond(last)); return cond; } ast chain_formulas(LitType t, const ast &chain){ if(is_true(chain)) return mk_true(); ast last = chain_last(chain); ast rest = chain_rest(chain); ast cond = chain_formulas(t,rest); if(is_rewrite_side(t,last)) cond = my_and(cond,rewrite_equ(last)); return cond; } ast chain_ineqs(opr comp_op, LitType t, const ast &chain, const ast &lhs, const ast &rhs){ if(is_true(chain)){ if(lhs != rhs) throw "bad ineq inference"; return make(Leq,make_int(rational(0)),make_int(rational(0))); } ast last = chain_last(chain); ast rest = chain_rest(chain); ast mid = subst_in_pos(rhs,rewrite_pos(last),rewrite_lhs(last)); ast cond = chain_ineqs(comp_op,t,rest,lhs,mid); if(is_rewrite_side(t,last)){ ast diff; if(comp_op == Leq) diff = make(Sub,rhs,mid); else diff = make(Sub,mid,rhs); ast foo = z3_simplify(make(Leq,make_int("0"),diff)); if(is_true(cond)) cond = foo; else { linear_comb(cond,make_int(rational(1)),foo); cond = simplify_ineq(cond); } } return cond; } ast ineq_to_lhs(const ast &ineq){ ast s = 
make(Leq,make_int(rational(0)),make_int(rational(0))); linear_comb(s,make_int(rational(1)),ineq); return simplify_ineq(s); } void eq_from_ineq(const ast &ineq, ast &lhs, ast &rhs){ // ast s = ineq_to_lhs(ineq); // ast srhs = arg(s,1); ast srhs = arg(ineq,0); if(op(srhs) == Plus && num_args(srhs) == 2){ lhs = arg(srhs,0); rhs = arg(srhs,1); // if(op(lhs) == Times) // std::swap(lhs,rhs); if(op(rhs) == Times){ rhs = arg(rhs,1); // if(op(ineq) == Leq) // std::swap(lhs,rhs); return; } } throw "bad ineq"; } ast chain_pos_add(int arg, const ast &chain){ if(is_true(chain)) return mk_true(); ast last = rewrite_pos_add(arg,chain_last(chain)); ast rest = chain_pos_add(arg,chain_rest(chain)); return chain_cons(rest,last); } /** Make an assumption node. The given clause is assumed in the given frame. */ virtual node make_assumption(int frame, const std::vector &assumption){ if(!weak){ if(pv->in_range(frame,rng)){ std::vector itp_clause; for(unsigned i = 0; i < assumption.size(); i++) if(get_term_type(assumption[i]) != LitA) itp_clause.push_back(assumption[i]); ast res = my_or(itp_clause); return res; } else { return mk_true(); } } else { if(pv->in_range(frame,rng)){ return mk_false(); } else { std::vector itp_clause; for(unsigned i = 0; i < assumption.size(); i++) if(get_term_type(assumption[i]) != LitB) itp_clause.push_back(assumption[i]); ast res = my_or(itp_clause); return mk_not(res); } } } ast make_local_rewrite(LitType t, const ast &p){ ast rew = is_equivrel(p) ? 
p : make(Iff,mk_true(),p); #if 0 if(op(rew) == Iff && !is_true(arg(rew,0))) return diff_chain(t,top_pos,arg(rew,0),arg(rew,1), mk_true()); #endif return chain_cons(mk_true(),make_rewrite(t, top_pos, mk_true(), rew)); } ast triv_interp(const symb &rule, const std::vector &premises, int mask_in){ std::vector ps; ps.resize(premises.size()); std::vector conjs; int mask = 0; for(unsigned i = 0; i < ps.size(); i++){ ast p = premises[i]; LitType t = get_term_type(p); switch(t){ case LitA: case LitB: ps[i] = make_local_rewrite(t,p); break; default: ps[i] = get_placeholder(p); // can only prove consequent! if(mask_in & (1 << i)) mask |= (1 << conjs.size()); conjs.push_back(p); } } ast ref = make(rule,ps); ast res = make_contra_node(ref,conjs,mask); return res; } ast triv_interp(const symb &rule, const ast &p0, const ast &p1, const ast &p2, int mask){ std::vector ps; ps.resize(3); ps[0] = p0; ps[1] = p1; ps[2] = p2; return triv_interp(rule,ps,mask); } /** Make a modus-ponens node. This takes derivations of |- x and |- x = y and produces |- y */ virtual node make_mp(const ast &p_eq_q, const ast &prem1, const ast &prem2){ /* Interpolate the axiom p, p=q -> q */ ast p = arg(p_eq_q,0); ast q = arg(p_eq_q,1); ast itp; if(get_term_type(p_eq_q) == LitMixed){ int mask = 1 << 2; if(op(p) == Not && is_equivrel(arg(p,0))) mask |= 1; // we may need to run this rule backward if first premise is negative equality itp = triv_interp(modpon,p,p_eq_q,mk_not(q),mask); } else { if(get_term_type(p) == LitA){ if(get_term_type(q) == LitA) itp = mk_false(); else { if(get_term_type(p_eq_q) == LitA) itp = q; else throw proof_error(); } } else { if(get_term_type(q) == LitA){ if(get_term_type(make(Equal,p,q)) == LitA) itp = mk_not(p); else throw proof_error(); } else itp = mk_true(); } } /* Resolve it with the premises */ std::vector conc; conc.push_back(q); conc.push_back(mk_not(p_eq_q)); itp = make_resolution(p,conc,itp,prem1); conc.pop_back(); itp = make_resolution(p_eq_q,conc,itp,prem2); return 
itp; } /** Make an axiom node. The conclusion must be an instance of an axiom. */ virtual node make_axiom(const std::vector &conclusion){ prover::range frng = pv->range_full(); int nargs = conclusion.size(); std::vector largs(nargs); std::vector eqs; std::vector pfs; for(int i = 0; i < nargs; i++){ ast argpf; ast lit = conclusion[i]; largs[i] = localize_term(lit,frng,argpf); frng = pv->range_glb(frng,pv->ast_scope(largs[i])); if(largs[i] != lit){ eqs.push_back(make_equiv(largs[i],lit)); pfs.push_back(argpf); } } int frame = pv->range_max(frng); ast itp = make_assumption(frame,largs); for(unsigned i = 0; i < eqs.size(); i++) itp = make_mp(eqs[i],itp,pfs[i]); return itp; } /** Make a Contra node. This rule takes a derivation of the form Gamma |- False and produces |- \/~Gamma. */ virtual node make_contra(node prem, const std::vector &conclusion){ return prem; } /** Make hypothesis. Creates a node of the form P |- P. */ virtual node make_hypothesis(const ast &P){ if(is_not(P)) return make_hypothesis(arg(P,0)); switch(get_term_type(P)){ case LitA: return mk_false(); case LitB: return mk_true(); default: // mixed hypothesis switch(op(P)){ case Geq: case Leq: case Gt: case Lt: { ast zleqz = make(Leq,make_int("0"),make_int("0")); ast fark1 = make(sum,zleqz,make_int("1"),get_placeholder(P)); ast fark2 = make(sum,fark1,make_int("1"),get_placeholder(mk_not(P))); ast res = make(And,make(contra,fark2,mk_false()), make(contra,get_placeholder(mk_not(P)),P), make(contra,get_placeholder(P),mk_not(P))); return res; } default: { ast em = make(exmid,P,get_placeholder(P),get_placeholder(mk_not(P))); ast res = make(And,make(contra,em,mk_false()), make(contra,get_placeholder(mk_not(P)),P), make(contra,get_placeholder(P),mk_not(P))); return res; } } } } /** Make a Reflexivity node. This rule produces |- x = x */ virtual node make_reflexivity(ast con){ throw proof_error(); } /** Make a Symmetry node. This takes a derivation of |- x = y and produces | y = x. 
Ditto for ~(x=y) */ virtual node make_symmetry(ast con, const ast &premcon, node prem){ #if 0 ast x = arg(con,0); ast y = arg(con,1); ast p = make(op(con),y,x); #endif if(get_term_type(con) != LitMixed) return prem; // symmetry shmymmetry... ast em = make(exmid,con,make(symm,get_placeholder(premcon)),get_placeholder(mk_not(con))); ast itp = make(And,make(contra,em,mk_false()), make(contra,make(symm,get_placeholder(mk_not(con))),premcon), make(contra,make(symm,get_placeholder(premcon)),mk_not(con))); std::vector conc; conc.push_back(con); itp = make_resolution(premcon,conc,itp,prem); return itp; } ast make_equiv_rel(const ast &x, const ast &y){ if(is_bool_type(get_type(x))) return make(Iff,x,y); return make(Equal,x,y); } /** Make a transitivity node. This takes derivations of |- x = y and |- y = z produces | x = z */ virtual node make_transitivity(const ast &x, const ast &y, const ast &z, node prem1, node prem2){ /* Interpolate the axiom x=y,y=z,-> x=z */ ast p = make_equiv_rel(x,y); ast q = make_equiv_rel(y,z); ast r = make_equiv_rel(x,z); ast equiv = make(Iff,p,r); ast itp; itp = make_congruence(q,equiv,prem2); itp = make_mp(equiv,prem1,itp); return itp; } /** Make a congruence node. 
This takes derivations of |- x_i = y_i and produces |- f(x_1,...,x_n) = f(y_1,...,y_n) */ virtual node make_congruence(const ast &p, const ast &con, const ast &prem1){ ast x = arg(p,0), y = arg(p,1); ast itp; LitType con_t = get_term_type(con); if(get_term_type(p) == LitA){ if(con_t == LitA) itp = mk_false(); else if(con_t == LitB) itp = p; else itp = make_mixed_congruence(x, y, p, con, prem1); } else { if(con_t == LitA) itp = mk_not(p); else{ if(con_t == LitB) itp = mk_true(); else itp = make_mixed_congruence(x, y, p, con, prem1); } } std::vector conc; conc.push_back(con); itp = make_resolution(p,conc,itp,prem1); return itp; } int find_congruence_position(const ast &p, const ast &con){ // find the argument position of x and y const ast &x = arg(p,0); const ast &y = arg(p,1); int nargs = num_args(arg(con,0)); for(int i = 0; i < nargs; i++) if(x == arg(arg(con,0),i) && y == arg(arg(con,1),i)) return i; throw proof_error(); } /** Make a congruence node. This takes derivations of |- x_i1 = y_i1, |- x_i2 = y_i2,... and produces |- f(...x_i1...x_i2...) = f(...y_i1...y_i2...) */ node make_congruence(const std::vector &p, const ast &con, const std::vector &prems){ if(p.size() == 0) throw proof_error(); if(p.size() == 1) return make_congruence(p[0],con,prems[0]); ast thing = con; ast res = mk_true(); for(unsigned i = 0; i < p.size(); i++){ int pos = find_congruence_position(p[i],thing); ast next = subst_in_arg_pos(pos,arg(p[i],1),arg(thing,0)); ast goal = make(op(thing),arg(thing,0),next); ast equa = make_congruence(p[i],goal,prems[i]); if(i == 0) res = equa; else { ast trace = make(op(con),arg(con,0),arg(thing,0)); ast equiv = make(Iff,trace,make(op(trace),arg(trace,0),next)); ast foo = make_congruence(goal,equiv,equa); res = make_mp(equiv,res,foo); } thing = make(op(thing),next,arg(thing,1)); } return res; } /* Interpolate a mixed congruence axiom. 
*/ virtual ast make_mixed_congruence(const ast &x, const ast &y, const ast &p, const ast &con, const ast &prem1){ ast foo = p; std::vector conjs; LitType t = get_term_type(foo); switch(t){ case LitA: case LitB: foo = make_local_rewrite(t,foo); break; case LitMixed: conjs.push_back(foo); foo = get_placeholder(foo); } // find the argument position of x and y int pos = -1; int nargs = num_args(arg(con,0)); for(int i = 0; i < nargs; i++) if(x == arg(arg(con,0),i) && y == arg(arg(con,1),i)) pos = i; if(pos == -1) throw proof_error(); ast bar = make(cong,foo,make_int(rational(pos)),get_placeholder(mk_not(con))); conjs.push_back(mk_not(con)); return make_contra_node(bar,conjs); } ast subst_in_arg_pos(int pos, ast term, ast app){ std::vector args; get_args(app,args); args[pos] = term; return clone(app,args); } /** Make a farkas proof node. */ virtual node make_farkas(ast con, const std::vector &prems, const std::vector &prem_cons, const std::vector &coeffs){ /* Compute the interpolant for the clause */ ast zero = make_int("0"); std::vector conjs; ast thing = make(Leq,zero,zero); for(unsigned i = 0; i < prem_cons.size(); i++){ const ast &lit = prem_cons[i]; if(get_term_type(lit) == LitA) linear_comb(thing,coeffs[i],lit); } thing = simplify_ineq(thing); for(unsigned i = 0; i < prem_cons.size(); i++){ const ast &lit = prem_cons[i]; if(get_term_type(lit) == LitMixed){ thing = make(sum,thing,coeffs[i],get_placeholder(lit)); conjs.push_back(lit); } } thing = make_contra_node(thing,conjs); /* Resolve it with the premises */ std::vector conc; conc.resize(prem_cons.size()); for(unsigned i = 0; i < prem_cons.size(); i++) conc[prem_cons.size()-i-1] = prem_cons[i]; for(unsigned i = 0; i < prem_cons.size(); i++){ thing = make_resolution(prem_cons[i],conc,thing,prems[i]); conc.pop_back(); } return thing; } /** Set P to P + cQ, where P and Q are linear inequalities. Assumes P is 0 <= y or 0 < y. 
*/ void linear_comb(ast &P, const ast &c, const ast &Q){ ast Qrhs; bool strict = op(P) == Lt; if(is_not(Q)){ ast nQ = arg(Q,0); switch(op(nQ)){ case Gt: Qrhs = make(Sub,arg(nQ,1),arg(nQ,0)); break; case Lt: Qrhs = make(Sub,arg(nQ,0),arg(nQ,1)); break; case Geq: Qrhs = make(Sub,arg(nQ,1),arg(nQ,0)); strict = true; break; case Leq: Qrhs = make(Sub,arg(nQ,0),arg(nQ,1)); strict = true; break; default: throw proof_error(); } } else { switch(op(Q)){ case Leq: Qrhs = make(Sub,arg(Q,1),arg(Q,0)); break; case Geq: Qrhs = make(Sub,arg(Q,0),arg(Q,1)); break; case Lt: Qrhs = make(Sub,arg(Q,1),arg(Q,0)); strict = true; break; case Gt: Qrhs = make(Sub,arg(Q,0),arg(Q,1)); strict = true; break; default: throw proof_error(); } } Qrhs = make(Times,c,Qrhs); if(strict) P = make(Lt,arg(P,0),make(Plus,arg(P,1),Qrhs)); else P = make(Leq,arg(P,0),make(Plus,arg(P,1),Qrhs)); } /* Make an axiom instance of the form |- x<=y, y<= x -> x =y */ virtual node make_leq2eq(ast x, ast y, const ast &xleqy, const ast &yleqx){ ast con = make(Equal,x,y); ast itp; switch(get_term_type(con)){ case LitA: itp = mk_false(); break; case LitB: itp = mk_true(); break; default: { // mixed equality std::vector conjs; conjs.resize(3); conjs[0] = mk_not(con); conjs[1] = xleqy; conjs[2] = yleqx; itp = make_contra_node(make(leq2eq, get_placeholder(mk_not(con)), get_placeholder(xleqy), get_placeholder(yleqx)), conjs,1); } } return itp; } /* Make an axiom instance of the form |- x = y -> x <= y */ virtual node make_eq2leq(ast x, ast y, const ast &xleqy){ ast itp; switch(get_term_type(xleqy)){ case LitA: itp = mk_false(); break; case LitB: itp = mk_true(); break; default: { // mixed equality std::vector conjs; conjs.resize(2); conjs[0] = make(Equal,x,y); conjs[1] = mk_not(xleqy); itp = make(eq2leq,get_placeholder(conjs[0]),get_placeholder(conjs[1])); itp = make_contra_node(itp,conjs,2); } } return itp; } /* Make an inference of the form t <= c |- t/d <= floor(c/d) where t is an affine term divisble by d and c is an 
integer constant */ virtual node make_cut_rule(const ast &tleqc, const ast &d, const ast &con, const ast &prem){ ast itp = mk_false(); switch(get_term_type(con)){ case LitA: itp = mk_false(); break; case LitB: itp = mk_true(); break; default: { std::vector conjs; conjs.resize(2); conjs[0] = tleqc; conjs[1] = mk_not(con); itp = make(sum,get_placeholder(conjs[0]),d,get_placeholder(conjs[1])); itp = make_contra_node(itp,conjs); } } std::vector conc; conc.push_back(con); itp = make_resolution(tleqc,conc,itp,prem); return itp; } // create a fresh variable for localization ast fresh_localization_var(const ast &term, int frame){ std::ostringstream s; s << "%" << (localization_vars.size()); ast var = make_var(s.str().c_str(),get_type(term)); pv->sym_range(sym(var)) = pv->range_full(); // make this variable global localization_vars.push_back(LocVar(var,term,frame)); return var; } struct LocVar { // localization vars ast var; // a fresh variable ast term; // term it represents int frame; // frame in which it's defined LocVar(ast v, ast t, int f){var=v;term=t;frame=f;} }; std::vector localization_vars; // localization vars in order of creation hash_map localization_map; // maps terms to their localization vars hash_map localization_pf_map; // maps terms to proofs of their localizations /* "localize" a term e to a given frame range, creating new symbols to represent non-local subterms. This returns the localized version e_l, as well as a proof thet e_l = l. */ ast make_refl(const ast &e){ return mk_true(); // TODO: is this right? } ast make_equiv(const ast &x, const ast &y){ if(get_type(x) == bool_type()) return make(Iff,x,y); else return make(Equal,x,y); } ast localize_term(ast e, const prover::range &rng, ast &pf){ ast orig_e = e; pf = make_refl(e); // proof that e = e prover::range erng = pv->ast_scope(e); if(!(erng.lo > erng.hi) && pv->ranges_intersect(pv->ast_scope(e),rng)){ return e; // this term occurs in range, so it's O.K. 
} hash_map::iterator it = localization_map.find(e); if(it != localization_map.end()){ pf = localization_pf_map[e]; return it->second; } // if is is non-local, we must first localize the arguments to // the range of its function symbol int nargs = num_args(e); if(nargs > 0 /* && (!is_local(e) || flo <= hi || fhi >= lo) */){ prover::range frng = rng; if(op(e) == Uninterpreted){ symb f = sym(e); prover::range srng = pv->sym_range(f); if(pv->ranges_intersect(srng,rng)) // localize to desired range if possible frng = pv->range_glb(srng,rng); } std::vector largs(nargs); std::vector eqs; std::vector pfs; for(int i = 0; i < nargs; i++){ ast argpf; largs[i] = localize_term(arg(e,i),frng,argpf); frng = pv->range_glb(frng,pv->ast_scope(largs[i])); if(largs[i] != arg(e,i)){ eqs.push_back(make_equiv(largs[i],arg(e,i))); pfs.push_back(argpf); } } e = clone(e,largs); if(pfs.size()) pf = make_congruence(eqs,make_equiv(e,orig_e),pfs); // assert(is_local(e)); } if(pv->ranges_intersect(pv->ast_scope(e),rng)) return e; // this term occurs in range, so it's O.K. // choose a frame for the constraint that is close to range int frame = pv->range_near(pv->ast_scope(e),rng); ast new_var = fresh_localization_var(e,frame); localization_map[e] = new_var; std::vector foo; foo.push_back(make_equiv(new_var,e)); ast bar = make_assumption(frame,foo); pf = make_transitivity(new_var,e,orig_e,bar,pf); localization_pf_map[orig_e] = pf; return new_var; } ast add_quants(ast e){ for(int i = localization_vars.size() - 1; i >= 0; i--){ LocVar &lv = localization_vars[i]; opr quantifier = (pv->in_range(lv.frame,rng)) ? 
Exists : Forall; e = apply_quant(quantifier,lv.var,e); } return e; } node make_resolution(ast pivot, node premise1, node premise2) { std::vector lits; return make_resolution(pivot,lits,premise1,premise2); } /* Return an interpolant from a proof of false */ ast interpolate(const node &pf){ // proof of false must be a formula, with quantified symbols return add_quants(z3_simplify(pf)); } ast resolve_with_quantifier(const ast &pivot1, const ast &conj1, const ast &pivot2, const ast &conj2){ if(is_not(arg(pivot1,1))) return resolve_with_quantifier(pivot2,conj2,pivot1,conj1); ast eqpf; ast P = arg(pivot1,1); ast Ploc = localize_term(P, rng, eqpf); ast pPloc = make_hypothesis(Ploc); ast pP = make_mp(make(Iff,Ploc,P),pPloc,eqpf); ast rP = make_resolution(P,conj1,pP); ast nP = mk_not(P); ast nPloc = mk_not(Ploc); ast neqpf = make_congruence(make(Iff,Ploc,P),make(Iff,nPloc,nP),eqpf); ast npPloc = make_hypothesis(nPloc); ast npP = make_mp(make(Iff,nPloc,nP),npPloc,neqpf); ast nrP = make_resolution(nP,conj2,npP); ast res = make_resolution(Ploc,rP,nrP); return res; } ast get_contra_coeff(const ast &f){ ast c = arg(f,0); // if(!is_not(arg(f,1))) // c = make(Uminus,c); return c; } ast my_or(const ast &a, const ast &b){ return mk_or(a,b); } ast my_and(const ast &a, const ast &b){ return mk_and(a,b); } ast my_implies(const ast &a, const ast &b){ return mk_implies(a,b); } ast my_or(const std::vector &a){ return mk_or(a); } ast my_and(const std::vector &a){ return mk_and(a); } ast get_lit_atom(const ast &l){ if(op(l) == Not) return arg(l,0); return l; } bool is_placeholder(const ast &e){ if(op(e) == Uninterpreted){ std::string name = string_of_symbol(sym(e)); if(name.size() > 2 && name[0] == '@' && name[1] == 'p') return true; } return false; } public: iz3proof_itp_impl(prover *p, const prover::range &r, bool w) : iz3proof_itp(*p) { pv = p; rng = r; weak = false ; //w; type boolintbooldom[3] = {bool_type(),int_type(),bool_type()}; type booldom[1] = {bool_type()}; type boolbooldom[2] 
= {bool_type(),bool_type()}; type boolboolbooldom[3] = {bool_type(),bool_type(),bool_type()}; type intbooldom[2] = {int_type(),bool_type()}; contra = function("@contra",2,boolbooldom,bool_type()); m().inc_ref(contra); sum = function("@sum",3,boolintbooldom,bool_type()); m().inc_ref(sum); rotate_sum = function("@rotsum",2,boolbooldom,bool_type()); m().inc_ref(rotate_sum); leq2eq = function("@leq2eq",3,boolboolbooldom,bool_type()); m().inc_ref(leq2eq); eq2leq = function("@eq2leq",2,boolbooldom,bool_type()); m().inc_ref(eq2leq); cong = function("@cong",3,boolintbooldom,bool_type()); m().inc_ref(cong); exmid = function("@exmid",3,boolboolbooldom,bool_type()); m().inc_ref(exmid); symm = function("@symm",1,booldom,bool_type()); m().inc_ref(symm); epsilon = make_var("@eps",int_type()); modpon = function("@mp",3,boolboolbooldom,bool_type()); m().inc_ref(modpon); no_proof = make_var("@nop",bool_type()); concat = function("@concat",2,boolbooldom,bool_type()); m().inc_ref(concat); top_pos = make_var("@top_pos",bool_type()); add_pos = function("@add_pos",2,intbooldom,bool_type()); m().inc_ref(add_pos); rewrite_A = function("@rewrite_A",3,boolboolbooldom,bool_type()); m().inc_ref(rewrite_A); rewrite_B = function("@rewrite_B",3,boolboolbooldom,bool_type()); m().inc_ref(rewrite_B); } ~iz3proof_itp_impl(){ m().dec_ref(contra); m().dec_ref(sum); m().dec_ref(rotate_sum); m().dec_ref(leq2eq); m().dec_ref(eq2leq); m().dec_ref(cong); m().dec_ref(exmid); m().dec_ref(symm); m().dec_ref(modpon); m().dec_ref(concat); m().dec_ref(add_pos); m().dec_ref(rewrite_A); m().dec_ref(rewrite_B); } }; iz3proof_itp *iz3proof_itp::create(prover *p, const prover::range &r, bool w){ return new iz3proof_itp_impl(p,r,w); }