diff --git a/src/interp/iz3mgr.h b/src/interp/iz3mgr.h index 218651520..dc5af9e26 100644 --- a/src/interp/iz3mgr.h +++ b/src/interp/iz3mgr.h @@ -325,6 +325,13 @@ class iz3mgr { return rational(1); } + ast get_linear_var(const ast& t){ + rational res; + if(op(t) == Times && is_numeral(arg(t,0),res)) + return arg(t,1); + return t; + } + int get_quantifier_num_bound(const ast &t) { return to_quantifier(t.raw())->get_num_decls(); } diff --git a/src/interp/iz3proof_itp.cpp b/src/interp/iz3proof_itp.cpp index f460dfaff..41237200c 100644 --- a/src/interp/iz3proof_itp.cpp +++ b/src/interp/iz3proof_itp.cpp @@ -89,12 +89,35 @@ class iz3proof_itp_impl : public iz3proof_itp { and q of ~Q and returns a proof of false. */ symb modpon; + /* This oprerator represents a concatenation of rewrites. The term + a=b;c=d represents an A rewrite from a to b, followed by a B + rewrite fron b to c, followed by an A rewrite from c to d. + */ + symb concat; + /* This represents a lack of a proof */ ast no_proof; // This is used to represent an infinitessimal value ast epsilon; + // Represents the top position of a term + ast top_pos; + + // add_pos(i,pos) represents position pos if the ith argument + symb add_pos; + + // rewrite proof rules + + /* rewrite_A(pos,cond,x=y) derives A |- cond => t[x]_p = t[y]_p + where t is an arbitrary term */ + symb rewrite_A; + + /* rewrite_B(pos,cond,x=y) derives B |- cond => t[x]_p = t[y]_p, + where t is an arbitrary term */ + symb rewrite_B; + + ast get_placeholder(ast t){ hash_map::iterator it = placeholders.find(t); if(it != placeholders.end()) @@ -143,6 +166,14 @@ class iz3proof_itp_impl : public iz3proof_itp { return pv->ranges_intersect(r,rng) && !pv->range_contained(r,rng); } + bool term_in_vocab(LitType ty, const ast &lit){ + prover::range r = pv->ast_scope(lit); + if(ty == LitA){ + return pv->ranges_intersect(r,rng); + } + return !pv->range_contained(r,rng); + } + /** Make a resolution node with given pivot literal and premises. The conclusion of premise1 should contain the negation of the pivot literal, while the conclusion of premise2 should contain the @@ -537,16 +568,38 @@ class iz3proof_itp_impl : public iz3proof_itp { return rotate_sum_rec(pl,pf,cond,ineq); } + bool is_rewrite_chain(const ast &chain){ + return sym(chain) == concat; + } + + ast ineq_from_chain(const ast &chain, ast &cond){ + if(is_rewrite_chain(chain)){ + ast last = chain_last(chain); + ast rest = chain_rest(chain); + if(is_true(rest) && is_rewrite_side(LitA,last) + && is_true(rewrite_lhs(last))){ + cond = my_and(cond,rewrite_cond(last)); + return rewrite_rhs(last); + } + if(is_rewrite_side(LitB,last) && is_true(rewrite_cond(last))) + return ineq_from_chain(rest,cond); + } + return chain; + } + void sum_cond_ineq(ast &ineq, ast &cond, const ast &coeff2, const ast &ineq2){ opr o = op(ineq2); if(o == Implies){ sum_cond_ineq(ineq,cond,coeff2,arg(ineq2,1)); cond = my_and(cond,arg(ineq2,0)); } - else if(is_ineq(ineq2)) - linear_comb(ineq,coeff2,ineq2); - else - throw cannot_simplify(); + else { + ast the_ineq = ineq_from_chain(ineq2,cond); + if(is_ineq(the_ineq)) + linear_comb(ineq,coeff2,the_ineq); + else + throw cannot_simplify(); + } } bool is_ineq(const ast &ineq){ @@ -556,7 +609,12 @@ class iz3proof_itp_impl : public iz3proof_itp { } // divide both sides of inequality by a non-negative integer divisor - ast idiv_ineq(const ast &ineq, const ast &divisor){ + ast idiv_ineq(const ast &ineq1, const ast &divisor){ + if(divisor == make_int(rational(1))) + return ineq1; + ast ineq = ineq1; + if(op(ineq) == Lt) + ineq = simplify_ineq(make(Leq,arg(ineq,0),make(Sub,arg(ineq,1),make_int("1")))); return make(op(ineq),mk_idiv(arg(ineq,0),divisor),mk_idiv(arg(ineq,1),divisor)); } @@ -580,21 +638,31 @@ class iz3proof_itp_impl : public iz3proof_itp { ast equality = arg(neg_equality,0); ast x = arg(equality,0); ast y = arg(equality,1); - ast xleqy = round_ineq(arg(pf,1)); - ast yleqx = round_ineq(arg(pf,2)); - ast itpeq; - if(get_term_type(x) == LitA) - itpeq = make(Equal,x,z3_simplify(make(Plus,x,get_ineq_rhs(xleqy)))); - else if(get_term_type(y) == LitA) - itpeq = make(Equal,z3_simplify(make(Plus,y,get_ineq_rhs(yleqx))),y); - else - throw cannot_simplify(); - ast cond = mk_true(); - ast ineq = make(Leq,make_int("0"),make_int("0")); - sum_cond_ineq(ineq,cond,make_int("-1"),xleqy); - sum_cond_ineq(ineq,cond,make_int("-1"),yleqx); - cond = z3_simplify(my_and(cond,ineq)); - return my_implies(cond,itpeq); + ast cond1 = mk_true(); + ast xleqy = round_ineq(ineq_from_chain(arg(pf,1),cond1)); + ast yleqx = round_ineq(ineq_from_chain(arg(pf,2),cond1)); + ast ineq1 = make(Leq,make_int("0"),make_int("0")); + sum_cond_ineq(ineq1,cond1,make_int("-1"),xleqy); + sum_cond_ineq(ineq1,cond1,make_int("-1"),yleqx); + cond1 = my_and(cond1,z3_simplify(ineq1)); + ast cond2 = mk_true(); + ast ineq2 = make(Leq,make_int("0"),make_int("0")); + sum_cond_ineq(ineq2,cond2,make_int("1"),xleqy); + sum_cond_ineq(ineq2,cond2,make_int("1"),yleqx); + cond2 = z3_simplify(ineq2); + if(get_term_type(x) == LitA){ + ast iter = z3_simplify(make(Plus,x,get_ineq_rhs(xleqy))); + ast rewrite1 = make_rewrite(LitA,top_pos,cond1,make(Equal,x,iter)); + ast rewrite2 = make_rewrite(LitB,top_pos,cond2,make(Equal,iter,y)); + return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2); + } + if(get_term_type(y) == LitA){ + ast iter = z3_simplify(make(Plus,y,get_ineq_rhs(yleqx))); + ast rewrite2 = make_rewrite(LitA,top_pos,cond1,make(Equal,iter,y)); + ast rewrite1 = make_rewrite(LitB,top_pos,cond2,make(Equal,x,iter)); + return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2); + } + throw cannot_simplify(); } throw cannot_simplify(); } @@ -612,10 +680,13 @@ class iz3proof_itp_impl : public iz3proof_itp { if(pl == arg(pf,1)){ ast cond = mk_true(); ast equa = sep_cond(arg(pf,0),cond); - if(op(equa) == Equal) - return my_implies(cond,z3_simplify(make(Leq,make_int("0"),make(Sub,arg(equa,1),arg(equa,0))))); - if(op(equa) == True) - return my_implies(cond,z3_simplify(make(Leq,make_int("0"),make_int("0")))); + if(is_equivrel_chain(equa)){ + ast ineqs= chain_ineqs(LitA,equa); + cond = my_and(cond,chain_conditions(LitA,equa)); + ast Bconds = chain_conditions(LitB,equa); + if(is_true(Bconds) && op(ineqs) != And) + return my_implies(cond,ineqs); + } } throw cannot_simplify(); } @@ -626,7 +697,9 @@ class iz3proof_itp_impl : public iz3proof_itp { args[0] = arg(pf,0); args[1] = arg(pf,1); args[2] = mk_true(); - return simplify_modpon(args); + ast cond = mk_true(); + ast chain = simplify_modpon_fwd(args, cond); + return my_implies(cond,chain); } throw cannot_simplify(); } @@ -649,6 +722,7 @@ class iz3proof_itp_impl : public iz3proof_itp { int ipos = pos.get_unsigned(); ast cond = mk_true(); ast equa = sep_cond(arg(pf,0),cond); +#if 0 if(op(equa) == Equal){ ast pe = mk_not(neg_equality); ast lhs = subst_in_arg_pos(ipos,arg(equa,0),arg(pe,0)); @@ -656,6 +730,9 @@ class iz3proof_itp_impl : public iz3proof_itp { ast res = make(Equal,lhs,rhs); return my_implies(cond,res); } +#endif + ast res = chain_pos_add(ipos,equa); + return my_implies(cond,res); } } throw cannot_simplify(); @@ -666,11 +743,12 @@ class iz3proof_itp_impl : public iz3proof_itp { return mk_true(); ast cond = mk_true(); ast equa = sep_cond(args[0],cond); - if(is_equivrel(equa)) - return my_implies(cond,make(op(equa),arg(equa,1),arg(equa,0))); + if(is_equivrel_chain(equa)) + return my_implies(cond,reverse_chain(equa)); throw cannot_simplify(); } +#if 0 ast simplify_modpon(const std::vector &args){ if(op(args[1]) == True){ ast cond = mk_true(); @@ -720,6 +798,62 @@ class iz3proof_itp_impl : public iz3proof_itp { } throw cannot_simplify(); } +#else + ast simplify_modpon_fwd(const std::vector &args, ast &cond){ + ast P = sep_cond(args[0],cond); + ast PeqQ = sep_cond(args[1],cond); + ast chain; + if(is_equivrel_chain(P)){ + ast split[2]; + split_chain(PeqQ,split); + chain = reverse_chain(split[0]); + chain = concat_rewrite_chain(chain,P); + chain = concat_rewrite_chain(chain,split[1]); + } + else // if not an equavalence, must be of form T <-> pred + chain = concat_rewrite_chain(P,PeqQ); + return chain; + } + +#if 0 + ast simplify_modpon(const std::vector &args){ + ast cond = mk_true(); + ast chain = simplify_modpon_fwd(args,cond); + ast Q2 = sep_cond(args[2],cond); + ast interp; + if(is_equivrel_chain(Q2)){ + chain = concat_rewrite_chain(chain,chain_pos_add(0,chain_pos_add(0,Q2))); + interp = my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain)); + } + else if(is_rewrite_side(LitB,chain_last(Q2))) + interp = my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain)); + else + interp = my_and(chain_conditions(LitB,chain),my_implies(chain_conditions(LitA,chain),mk_not(chain_formulas(LitB,chain)))); + return my_implies(cond,interp); + } +#endif + + /* Given a chain rewrite chain deriving not P and a rewrite chain deriving P, return an interpolant. */ + ast contra_chain(const ast &neg_chain, const ast &pos_chain){ + // equality is a special case. we use the derivation of x=y to rewrite not(x=y) to not(y=y) + if(is_equivrel_chain(pos_chain)){ + ast chain = concat_rewrite_chain(neg_chain,chain_pos_add(0,chain_pos_add(0,pos_chain))); + return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain)); + } + // otherwise, we reverse the derivation of t = P and use it to rewrite not(P) to not(t) + ast chain = concat_rewrite_chain(neg_chain,chain_pos_add(0,reverse_chain(pos_chain))); + return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain)); + } + + ast simplify_modpon(const std::vector &args){ + ast cond = mk_true(); + ast chain = simplify_modpon_fwd(args,cond); + ast Q2 = sep_cond(args[2],cond); + ast interp = is_negation_chain(chain) ? contra_chain(chain,Q2) : contra_chain(Q2,chain); + return my_implies(cond,interp); + } + +#endif bool is_equivrel(const ast &p){ opr o = op(p); @@ -801,6 +935,121 @@ class iz3proof_itp_impl : public iz3proof_itp { return res; } + bool diff_rec(LitType t, const ast &p, const ast &q, ast &pd, ast &qd){ + if(p == q) + return false; + if(term_in_vocab(t,p) && term_in_vocab(t,q)){ + pd = p; + qd = q; + return true; + } + else { + if(sym(p) != sym(q)) return false; + int nargs = num_args(p); + if(num_args(q) != nargs) return false; + for(int i = 0; i < nargs; i++) + if(diff_rec(t,arg(p,i),arg(q,i),pd,qd)) + return true; + return false; + } + } + + void diff(LitType t, const ast &p, const ast &q, ast &pd, ast &qd){ + if(!diff_rec(t,p,q,pd,qd)) + throw cannot_simplify(); + } + + bool apply_diff_rec(LitType t, const ast &inp, const ast &p, const ast &q, ast &out){ + if(p == q) + return false; + if(term_in_vocab(t,p) && term_in_vocab(t,q)){ + if(inp != p) + throw cannot_simplify(); + out = q; + return true; + } + else { + int nargs = num_args(p); + if(sym(p) != sym(q)) throw cannot_simplify(); + if(num_args(q) != nargs) throw cannot_simplify(); + if(sym(p) != sym(inp)) throw cannot_simplify(); + if(num_args(inp) != nargs) throw cannot_simplify(); + for(int i = 0; i < nargs; i++) + if(apply_diff_rec(t,arg(inp,i),arg(p,i),arg(q,i),out)) + return true; + return false; + } + } + + ast apply_diff(LitType t, const ast &inp, const ast &p, const ast &q){ + ast out; + if(!apply_diff_rec(t,inp,p,q,out)) + throw cannot_simplify(); + return out; + } + + bool merge_A_rewrites(const ast &A1, const ast &A2, ast &merged) { + if(arg(A1,1) == arg(A2,0)){ + merged = make(op(A1),arg(A1,0),arg(A2,1)); + return true; + } + ast diff1l, diff1r, diff2l, diff2r,diffBl,diffBr; + diff(LitA,arg(A1,0),arg(A1,1),diff1l,diff1r); + diff(LitA,arg(A2,0),arg(A2,1),diff2l,diff2r); + diff(LitB,arg(A1,1),arg(A2,0),diffBl,diffBr); + if(!term_common(diff2l) && !term_common(diffBr)){ + ast A1r = apply_diff(LitB,arg(A2,1),arg(A2,0),arg(A1,1)); + merged = make(op(A1),arg(A1,0),A1r); + return true; + } + if(!term_common(diff1r) && !term_common(diffBl)){ + ast A2l = apply_diff(LitB,arg(A1,0),arg(A1,1),arg(A2,0)); + merged = make(op(A1),A2l,arg(A2,1)); + return true; + } + return false; + } + + void collect_A_rewrites(const ast &t, std::vector &res){ + if(is_true(t)) + return; + if(sym(t) == concat){ + res.push_back(arg(t,0)); + collect_A_rewrites(arg(t,0),res); + return; + } + res.push_back(t); + } + + ast concat_A_rewrites(const std::vector &rew){ + if(rew.size() == 0) + return mk_true(); + ast res = rew[0]; + for(unsigned i = 1; i < rew.size(); i++) + res = make(concat,res,rew[i]); + return res; + } + + ast merge_concat_rewrites(const ast &A1, const ast &A2){ + std::vector rew; + collect_A_rewrites(A1,rew); + int first = rew.size(), last = first; // range that might need merging + collect_A_rewrites(A2,rew); + while(first > 0 && first < (int)rew.size() && first <= last){ + ast merged; + if(merge_A_rewrites(rew[first-1],rew[first],merged)){ + rew[first] = merged; + first--; + rew.erase(rew.begin()+first); + last--; + if(first >= last) last = first+1; + } + else + first++; + } + return concat_A_rewrites(rew); + } + ast sep_cond(const ast &t, ast &cond){ if(op(t) == Implies){ cond = my_and(cond,arg(t,0)); @@ -809,6 +1058,326 @@ class iz3proof_itp_impl : public iz3proof_itp { return t; } + + /* operations on term positions */ + + /** Finds the difference between two positions. If p1 < p2 (p1 is a + position below p2), returns -1 and sets diff to p2-p1 (the psath + from position p2 to position p1). If p2 < p1 (p2 is a position + below p1), returns 1 and sets diff to p1-p2 (the psath from + position p1 to position p2). If equal, return 0 and set diff to + top_pos. Else (if p1 and p2 are independent) returns 2 and + leaves diff unchanged. */ + + int pos_diff(const ast &p1, const ast &p2, ast &diff){ + if(p1 == top_pos && p2 != top_pos){ + diff = p2; + return 1; + } + if(p2 == top_pos && p1 != top_pos){ + diff = p1; + return -1; + } + if(p1 == top_pos && p2 == top_pos){ + diff = p1; + return 0; + } + if(arg(p1,0) == arg(p2,0)) // same argument position, recur + return pos_diff(arg(p1,1),arg(p2,1),diff); + return 2; + } + + /* return the position of pos in the argth argument */ + ast pos_add(int arg, const ast &pos){ + return make(add_pos,make_int(rational(arg)),pos); + } + + /* return the argument number of position, if not top */ + int pos_arg(const ast &pos){ + rational r; + if(is_numeral(arg(pos,0),r)) + return r.get_unsigned(); + throw "bad position!"; + } + + /* substitute y into position pos in x */ + ast subst_in_pos(const ast &x, const ast &pos, const ast &y){ + if(pos == top_pos) + return y; + int p = pos_arg(pos); + int nargs = num_args(x); + if(p >= 0 && p < nargs){ + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = i == p ? subst_in_pos(arg(x,i),arg(pos,1),y) : arg(x,i); + return clone(x,args); + } + throw "bad term position!"; + } + + + /* operations on rewrites */ + ast make_rewrite(LitType t, const ast &pos, const ast &cond, const ast &equality){ + return make(t == LitA ? rewrite_A : rewrite_B, pos, cond, equality); + } + + ast rewrite_pos(const ast &rew){ + return arg(rew,0); + } + + ast rewrite_cond(const ast &rew){ + return arg(rew,1); + } + + ast rewrite_equ(const ast &rew){ + return arg(rew,2); + } + + ast rewrite_lhs(const ast &rew){ + return arg(arg(rew,2),0); + } + + ast rewrite_rhs(const ast &rew){ + return arg(arg(rew,2),1); + } + + /* operations on rewrite chains */ + + ast chain_cons(const ast &chain, const ast &elem){ + return make(concat,chain,elem); + } + + ast chain_rest(const ast &chain){ + return arg(chain,0); + } + + ast chain_last(const ast &chain){ + return arg(chain,1); + } + + ast rewrite_update_rhs(const ast &rew, const ast &pos, const ast &new_rhs, const ast &new_cond){ + ast foo = subst_in_pos(rewrite_rhs(rew),pos,new_rhs); + ast equality = arg(rew,2); + return make(sym(rew),rewrite_pos(rew),my_and(rewrite_cond(rew),new_cond),make(op(equality),arg(equality,0),foo)); + } + + ast rewrite_update_lhs(const ast &rew, const ast &pos, const ast &new_lhs, const ast &new_cond){ + ast foo = subst_in_pos(rewrite_lhs(rew),pos,new_lhs); + ast equality = arg(rew,2); + return make(sym(rew),rewrite_pos(rew),my_and(rewrite_cond(rew),new_cond),make(op(equality),foo,arg(equality,1))); + } + + bool is_common_rewrite(const ast &rew){ + return term_common(arg(rew,2)); + } + + bool is_right_mover(const ast &rew){ + return term_common(rewrite_lhs(rew)) && !term_common(rewrite_rhs(rew)); + } + + bool is_left_mover(const ast &rew){ + return term_common(rewrite_rhs(rew)) && !term_common(rewrite_lhs(rew)); + } + + bool same_side(const ast &rew1, const ast &rew2){ + return sym(rew1) == sym(rew2); + } + + bool is_rewrite_side(LitType t, const ast &rew){ + if(t == LitA) + return sym(rew) == rewrite_A; + return sym(rew) == rewrite_B; + } + + ast rewrite_to_formula(const ast &rew){ + return my_implies(arg(rew,1),arg(rew,2)); + } + + // make rewrite rew conditon on rewrite cond + ast rewrite_conditional(const ast &cond, const ast &rew){ + ast cf = rewrite_to_formula(cond); + return make(sym(rew),arg(rew,0),my_and(arg(rew,1),cf),arg(rew,2)); + } + + ast reverse_rewrite(const ast &rew){ + ast equ = arg(rew,2); + return make(sym(rew),arg(rew,0),arg(rew,1),make(op(equ),arg(equ,1),arg(equ,0))); + } + + ast rewrite_pos_add(int apos, const ast &rew){ + return make(sym(rew),pos_add(apos,arg(rew,0)),arg(rew,1),arg(rew,2)); + } + + ast rewrite_up(const ast &rew){ + return make(sym(rew),arg(arg(rew,0),1),arg(rew,1),arg(rew,2)); + } + + /** Adds a rewrite to a chain of rewrites, keeping the chain in + normal form. An empty chain is represented by true.*/ + + ast add_rewrite_to_chain(const ast &chain, const ast &rewrite){ + if(is_true(chain)) + return chain_cons(chain,rewrite); + ast last = chain_last(chain); + ast rest = chain_rest(chain); + if(same_side(last,rewrite)){ + ast p1 = rewrite_pos(last); + ast p2 = rewrite_pos(rewrite); + ast diff; + switch(pos_diff(p1,p2,diff)){ + case 1: { + ast absorb = rewrite_update_rhs(last,diff,rewrite_rhs(rewrite),rewrite_cond(rewrite)); + return add_rewrite_to_chain(rest,absorb); + } + case 0: + case -1: { + ast absorb = rewrite_update_lhs(rewrite,diff,rewrite_lhs(last),rewrite_cond(last)); + return add_rewrite_to_chain(rest,absorb); + } + default: {// independent case + bool rm = is_right_mover(last); + bool lm = is_left_mover(rewrite); + if((lm && !rm) || (rm && !lm)) + return chain_swap(rest,last,rewrite); + } + } + } + else { + if(is_left_mover(rewrite)){ + if(is_common_rewrite(last)) + return add_rewrite_to_chain(chain_cons(rest,flip_rewrite(last)),rewrite); + if(!is_left_mover(last)) + return chain_swap(rest,last,rewrite); + } + if(is_right_mover(last)){ + if(is_common_rewrite(rewrite)) + return add_rewrite_to_chain(chain,flip_rewrite(rewrite)); + if(!is_right_mover(rewrite)) + return chain_swap(rest,last,rewrite); + } + } + return chain_cons(chain,rewrite); + } + + ast chain_swap(const ast &rest, const ast &last, const ast &rewrite){ + return chain_cons(add_rewrite_to_chain(rest,rewrite),last); + } + + ast flip_rewrite(const ast &rew){ + symb flip_sym = (sym(rew) == rewrite_A) ? rewrite_B : rewrite_A; + ast cf = rewrite_to_formula(rew); + return make(flip_sym,arg(rew,0),my_implies(arg(rew,1),cf),arg(rew,2)); + } + + /** concatenates two rewrite chains, keeping result in normal form. */ + + ast concat_rewrite_chain(const ast &chain1, const ast &chain2){ + if(is_true(chain2)) return chain1; + if(is_true(chain1)) return chain2; + ast foo = concat_rewrite_chain(chain1,chain_rest(chain2)); + return add_rewrite_to_chain(foo,chain_last(chain2)); + } + + /** reverse a chain of rewrites */ + + ast reverse_chain_rec(const ast &chain, const ast &prefix){ + if(is_true(chain)) + return prefix; + ast last = reverse_rewrite(chain_last(chain)); + ast rest = chain_rest(chain); + return reverse_chain_rec(rest,chain_cons(prefix,last)); + } + + ast reverse_chain(const ast &chain){ + return reverse_chain_rec(chain,mk_true()); + } + + bool is_equivrel_chain(const ast &chain){ + if(is_true(chain)) + return true; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + if(is_true(rest)) + return !is_true(rewrite_lhs(last)); + return is_equivrel_chain(rest); + } + + bool is_negation_chain(const ast &chain){ + if(is_true(chain)) + return false; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + if(is_true(rest)) + return op(rewrite_rhs(last)) == Not; + return is_negation_chain(rest); + } + + /** Split a chain of rewrites two chains, operating on positions 0 and 1. + Fail if any rewrite in the chain operates on top position. */ + void split_chain_rec(const ast &chain, ast *res){ + if(is_true(chain)) + return; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + split_chain_rec(rest,res); + ast pos = rewrite_pos(last); + if(pos == top_pos) + throw "bad rewrite chain"; + int arg = pos_arg(pos); + if(arg<0 || arg > 1) + throw "bad position!"; + res[arg] = chain_cons(res[arg],rewrite_up(last)); + } + + void split_chain(const ast &chain, ast *res){ + res[0] = res[1] = mk_true(); + split_chain_rec(chain,res); + } + + ast chain_conditions(LitType t, const ast &chain){ + if(is_true(chain)) + return mk_true(); + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast cond = chain_conditions(t,rest); + if(is_rewrite_side(t,last)) + cond = my_and(cond,rewrite_cond(last)); + return cond; + } + + ast chain_formulas(LitType t, const ast &chain){ + if(is_true(chain)) + return mk_true(); + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast cond = chain_formulas(t,rest); + if(is_rewrite_side(t,last)) + cond = my_and(cond,rewrite_equ(last)); + return cond; + } + + ast chain_ineqs(LitType t, const ast &chain){ + if(is_true(chain)) + return mk_true(); + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast cond = chain_ineqs(t,rest); + if(is_rewrite_side(t,last)){ + ast equa = rewrite_equ(last); + cond = my_and(cond,z3_simplify(make(Leq,make_int("0"),make(Sub,arg(equa,1),arg(equa,0))))); + } + return cond; + } + + ast chain_pos_add(int arg, const ast &chain){ + if(is_true(chain)) + return mk_true(); + ast last = rewrite_pos_add(arg,chain_last(chain)); + ast rest = chain_pos_add(arg,chain_rest(chain)); + return chain_cons(rest,last); + } + + /** Make an assumption node. The given clause is assumed in the given frame. */ virtual node make_assumption(int frame, const std::vector &assumption){ if(pv->in_range(frame,rng)){ @@ -822,7 +1391,12 @@ class iz3proof_itp_impl : public iz3proof_itp { else { return mk_true(); } -} + } + + ast make_local_rewrite(LitType t, const ast &p){ + ast rew = is_equivrel(p) ? p : make(Iff,mk_true(),p); + return chain_cons(mk_true(),make_rewrite(t, top_pos, mk_true(), rew)); + } ast triv_interp(const symb &rule, const std::vector &premises){ std::vector ps; ps.resize(premises.size()); @@ -830,12 +1404,11 @@ class iz3proof_itp_impl : public iz3proof_itp { int mask = 0; for(unsigned i = 0; i < ps.size(); i++){ ast p = premises[i]; - switch(get_term_type(p)){ + LitType t = get_term_type(p); + switch(t){ case LitA: - ps[i] = p; - break; case LitB: - ps[i] = mk_true(); + ps[i] = make_local_rewrite(t,p); break; default: ps[i] = get_placeholder(p); // can only prove consequent! @@ -1159,11 +1732,11 @@ class iz3proof_itp_impl : public iz3proof_itp { virtual ast make_mixed_congruence(const ast &x, const ast &y, const ast &p, const ast &con, const ast &prem1){ ast foo = p; std::vector conjs; - switch(get_term_type(foo)){ + LitType t = get_term_type(foo); + switch(t){ case LitA: - break; case LitB: - foo = mk_true(); + foo = make_local_rewrite(t,foo); break; case LitMixed: conjs.push_back(foo); @@ -1561,6 +2134,7 @@ public: type booldom[1] = {bool_type()}; type boolbooldom[2] = {bool_type(),bool_type()}; type boolboolbooldom[3] = {bool_type(),bool_type(),bool_type()}; + type intbooldom[2] = {int_type(),bool_type()}; contra = function("@contra",2,boolbooldom,bool_type()); sum = function("@sum",3,boolintbooldom,bool_type()); rotate_sum = function("@rotsum",2,boolbooldom,bool_type()); @@ -1572,6 +2146,11 @@ public: epsilon = make_var("@eps",int_type()); modpon = function("@mp",3,boolboolbooldom,bool_type()); no_proof = make_var("@nop",bool_type()); + concat = function("@concat",2,boolbooldom,bool_type()); + top_pos = make_var("@top_pos",bool_type()); + add_pos = function("@add_pos",2,intbooldom,bool_type()); + rewrite_A = function("@rewrite_A",3,boolboolbooldom,bool_type()); + rewrite_B = function("@rewrite_B",3,boolboolbooldom,bool_type()); } }; diff --git a/src/interp/iz3translate.cpp b/src/interp/iz3translate.cpp index 042dbe192..a41fdcba9 100755 --- a/src/interp/iz3translate.cpp +++ b/src/interp/iz3translate.cpp @@ -829,6 +829,42 @@ public: return make_int(d); } + ast get_bounded_variable(const ast &ineq, bool &lb){ + ast nineq = normalize_inequality(ineq); + ast lhs = arg(nineq,0); + switch(op(lhs)){ + case Uninterpreted: + lb = false; + return lhs; + case Times: + if(arg(lhs,0) == make_int(rational(1))) + lb = false; + else if(arg(lhs,0) == make_int(rational(-1))) + lb = true; + else + throw unsupported(); + return arg(lhs,1); + default: + throw unsupported(); + } + } + + rational get_term_coefficient(const ast &t1, const ast &v){ + ast t = arg(normalize_inequality(t1),0); + if(op(t) == Plus){ + int nargs = num_args(t); + for(int i = 0; i < nargs; i++){ + if(get_linear_var(arg(t,i)) == v) + return get_coeff(arg(t,i)); + } + } + else + if(get_linear_var(t) == v) + return get_coeff(t); + return rational(0); + } + + Iproof::node GCDtoDivRule(const ast &proof, bool pol, std::vector &coeffs, std::vector &prems, ast &cut_con){ // gather the summands of the desired polarity std::vector my_prems; @@ -843,6 +879,24 @@ public: } } ast my_con = sum_inequalities(my_coeffs,my_prem_cons); + + // handle generalized GCD test. sadly, we dont' get the coefficients... + if(coeffs[0].is_zero()){ + bool lb; + int xtra_prem = 0; + ast bv = get_bounded_variable(conc(prem(proof,0)),lb); + rational bv_coeff = get_term_coefficient(my_con,bv); + if(bv_coeff.is_pos() != lb) + xtra_prem = 1; + if(bv_coeff.is_neg()) + bv_coeff = -bv_coeff; + + my_prems.push_back(prems[xtra_prem]); + my_coeffs.push_back(make_int(bv_coeff)); + my_prem_cons.push_back(conc(prem(proof,xtra_prem))); + my_con = sum_inequalities(my_coeffs,my_prem_cons); + } + my_con = normalize_inequality(my_con); Iproof::node hyp = iproof->make_hypothesis(mk_not(my_con)); my_prems.push_back(hyp); @@ -961,6 +1015,12 @@ public: else lits.push_back(from_ast(con)); + // special case + if(dk == PR_MODUS_PONENS && pr(prem(proof,0)) == PR_QUANT_INST && pr(prem(proof,1)) == PR_REWRITE ) { + res = iproof->make_axiom(lits); + return res; + } + // translate all the premises std::vector args(nprems); for(unsigned i = 0; i < nprems; i++) @@ -1057,6 +1117,10 @@ public: res = iproof->make_hypothesis(conc(proof)); break; } + case PR_QUANT_INST: { + res = iproof->make_axiom(lits); + break; + } default: assert(0 && "translate_main: unsupported proof rule"); throw unsupported(); @@ -1066,6 +1130,11 @@ public: return res; } + void clear_translation(){ + translation.first.clear(); + translation.second.clear(); + } + // We actually compute the interpolant here and then produce a proof consisting of just a lemma iz3proof::node translate(ast proof, iz3proof &dst){ @@ -1075,6 +1144,7 @@ public: ast itp = translate_main(proof); itps.push_back(itp); delete iproof; + clear_translation(); } // Very simple proof -- lemma of the empty clause with computed interpolation iz3proof::node Ipf = dst.make_lemma(std::vector(),itps); // builds result in dst