diff --git a/src/interp/iz3proof_itp.cpp b/src/interp/iz3proof_itp.cpp index 9d241398c..1eb4fcc84 100755 --- a/src/interp/iz3proof_itp.cpp +++ b/src/interp/iz3proof_itp.cpp @@ -729,29 +729,31 @@ class iz3proof_itp_impl : public iz3proof_itp { ast x = arg(equality,0); ast y = arg(equality,1); ast Aproves1 = mk_true(), Bproves1 = mk_true(); - ast xleqy = round_ineq(ineq_from_chain(arg(pf,1),Aproves1,Bproves1)); - ast yleqx = round_ineq(ineq_from_chain(arg(pf,2),Aproves1,Bproves1)); + ast pf1 = destruct_cond_ineq(arg(pf,1), Aproves1, Bproves1); + ast pf2 = destruct_cond_ineq(arg(pf,2), Aproves1, Bproves1); + ast xleqy = round_ineq(ineq_from_chain(pf1,Aproves1,Bproves1)); + ast yleqx = round_ineq(ineq_from_chain(pf2,Aproves1,Bproves1)); ast ineq1 = make(Leq,make_int("0"),make_int("0")); sum_cond_ineq(ineq1,make_int("-1"),xleqy,Aproves1,Bproves1); sum_cond_ineq(ineq1,make_int("-1"),yleqx,Aproves1,Bproves1); - Bproves1 = my_and(Bproves1,z3_simplify(ineq1)); + ast Acond = my_implies(Aproves1,my_and(Bproves1,z3_simplify(ineq1))); ast Aproves2 = mk_true(), Bproves2 = mk_true(); ast ineq2 = make(Leq,make_int("0"),make_int("0")); sum_cond_ineq(ineq2,make_int("1"),xleqy,Aproves2,Bproves2); sum_cond_ineq(ineq2,make_int("1"),yleqx,Aproves2,Bproves2); - Bproves2 = z3_simplify(ineq2); - if(!is_true(Aproves1) || !is_true(Aproves2)) - throw "help!"; + ast Bcond = my_implies(Bproves1,my_and(Aproves1,z3_simplify(ineq2))); + // if(!is_true(Aproves1) || !is_true(Bproves1)) + // std::cout << "foo!\n";; if(get_term_type(x) == LitA){ ast iter = z3_simplify(make(Plus,x,get_ineq_rhs(xleqy))); - ast rewrite1 = make_rewrite(LitA,top_pos,Bproves1,make(Equal,x,iter)); - ast rewrite2 = make_rewrite(LitB,top_pos,Bproves2,make(Equal,iter,y)); + ast rewrite1 = make_rewrite(LitA,top_pos,Acond,make(Equal,x,iter)); + ast rewrite2 = make_rewrite(LitB,top_pos,Bcond,make(Equal,iter,y)); return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2); } if(get_term_type(y) == LitA){ ast iter = z3_simplify(make(Plus,y,get_ineq_rhs(yleqx))); - ast rewrite2 = make_rewrite(LitA,top_pos,Bproves1,make(Equal,iter,y)); - ast rewrite1 = make_rewrite(LitB,top_pos,Bproves2,make(Equal,x,iter)); + ast rewrite2 = make_rewrite(LitA,top_pos,Acond,make(Equal,iter,y)); + ast rewrite1 = make_rewrite(LitB,top_pos,Bcond,make(Equal,x,iter)); return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2); } throw cannot_simplify(); @@ -922,6 +924,8 @@ class iz3proof_itp_impl : public iz3proof_itp { return chain; } + struct subterm_normals_failed {}; + void get_subterm_normals(const ast &ineq1, const ast &ineq2, const ast &chain, ast &normals, const ast &pos, hash_set &memo, ast &Aproves, ast &Bproves){ opr o1 = op(ineq1); @@ -935,14 +939,77 @@ class iz3proof_itp_impl : public iz3proof_itp { get_subterm_normals(arg(ineq1,i), arg(ineq2,i), chain, normals, new_pos, memo, Aproves, Bproves); } } - else if(get_term_type(ineq2) == LitMixed && memo.find(ineq2) == memo.end()){ - memo.insert(ineq2); - ast sub_chain = extract_rewrites(chain,pos); - if(is_true(sub_chain)) - throw "bad inequality rewriting"; - ast new_normal = make_normal_step(ineq2,ineq1,reverse_chain(sub_chain)); - normals = merge_normal_chains(normals,cons_normal(new_normal,mk_true()), Aproves, Bproves); + else if(get_term_type(ineq2) == LitMixed){ + if(memo.find(ineq2) == memo.end()){ + memo.insert(ineq2); + ast sub_chain = extract_rewrites(chain,pos); + if(is_true(sub_chain)) + throw "bad inequality rewriting"; + ast new_normal = make_normal_step(ineq2,ineq1,reverse_chain(sub_chain)); + normals = merge_normal_chains(normals,cons_normal(new_normal,mk_true()), Aproves, Bproves); + } } + else if(!(ineq1 == ineq2)) + throw subterm_normals_failed(); + } + + ast rewrites_to_normals(const ast &ineq1, const ast &chain, ast &normals, ast &Aproves, ast &Bproves, ast &Aineqs){ + if(is_true(chain)) + return ineq1; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast new_ineq1 = rewrites_to_normals(ineq1, rest, normals, Aproves, Bproves, Aineqs); + ast p1 = rewrite_pos(last); + ast term1; + ast coeff = arith_rewrite_coeff(new_ineq1,p1,term1); + ast res = subst_in_pos(new_ineq1,rewrite_pos(last),rewrite_rhs(last)); + ast rpos; + pos_diff(p1,rewrite_pos(last),rpos); + ast term2 = subst_in_pos(term1,rpos,rewrite_rhs(last)); + if(get_term_type(term1) != LitMixed && get_term_type(term2) != LitMixed){ + if(is_rewrite_side(LitA,last)) + linear_comb(Aineqs,coeff,make(Leq,make_int(rational(0)),make(Sub,term2,term1))); + } + else { + ast pf = extract_rewrites(make(concat,mk_true(),rest),p1); + ast new_normal = fix_normal(term1,term2,pf); + normals = merge_normal_chains(normals,cons_normal(new_normal,mk_true()), Aproves, Bproves); + } + return res; + } + + ast arith_rewrite_coeff(const ast &ineq, ast &p1, ast &term){ + ast coeff = make_int(rational(1)); + if(p1 == top_pos){ + term = ineq; + return coeff; + } + int argpos = pos_arg(p1); + opr o = op(ineq); + switch(o){ + case Leq: + case Lt: + coeff = argpos ? make_int(rational(1)) : make_int(rational(-1)); + break; + case Geq: + case Gt: + coeff = argpos ? make_int(rational(-1)) : make_int(rational(1)); + break; + case Not: + case Plus: + break; + case Times: + coeff = arg(ineq,0); + break; + default: + p1 = top_pos; + term = ineq; + return coeff; + } + p1 = arg(p1,1); + ast res = arith_rewrite_coeff(arg(ineq,argpos),p1,term); + p1 = pos_add(argpos,p1); + return coeff == make_int(rational(1)) ? res : make(Times,coeff,res); } ast rewrite_chain_to_normal_ineq(const ast &chain, ast &Aproves, ast &Bproves){ @@ -952,17 +1019,20 @@ class iz3proof_itp_impl : public iz3proof_itp { ast ineq2 = apply_rewrite_chain(ineq1,tail); ast nc = mk_true(); hash_set memo; - get_subterm_normals(ineq1,ineq2,tail,nc,top_pos,memo, Aproves, Bproves); - ast itp; + ast itp = make(Leq,make_int(rational(0)),make_int(rational(0))); + ast Aproves_save = Aproves, Bproves_save = Bproves; try { + get_subterm_normals(ineq1,ineq2,tail,nc,top_pos,memo, Aproves, Bproves); + } + catch (const subterm_normals_failed &){ Aproves = Aproves_save; Bproves = Bproves_save; nc = mk_true(); + rewrites_to_normals(ineq1, tail, nc, Aproves, Bproves, itp); + } if(is_rewrite_side(LitA,head)){ - itp = make(Leq,make_int("0"),make_int("0")); linear_comb(itp,make_int("1"),ineq1); // make sure it is normal form //itp = ineq1; ast mc = z3_simplify(chain_side_proves(LitB,pref)); Bproves = my_and(Bproves,mc); } else { - itp = make(Leq,make_int(rational(0)),make_int(rational(0))); ast mc = z3_simplify(chain_side_proves(LitA,pref)); Aproves = my_and(Aproves,mc); }