/* Provenance residue captured with this file (web-page header), kept as a comment:
   3
   0
   Fork 0
   mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 17:44:08 +00:00
   z3/src/interp/iz3proof_itp.cpp
   2014-04-03 13:20:08 -07:00

   3007 lines
   89 KiB
   C++
   Executable file
*/
/*++
Copyright (c) 2011 Microsoft Corporation
Module Name:
iz3proof_itp.cpp
Abstract:
This class defines a simple interpolating proof system.
Author:
Ken McMillan (kenmcmil)
Revision History:
--*/
#ifdef _WINDOWS
#pragma warning(disable:4996)
#pragma warning(disable:4800)
#pragma warning(disable:4267)
#pragma warning(disable:4101)
#endif
#include "iz3proof_itp.h"
using namespace stl_ext;
// #define INVARIANT_CHECKING
class iz3proof_itp_impl : public iz3proof_itp {
// Prover used to query term scopes (ast_scope and the range predicates).
prover *pv;
// Range that term scopes are compared against when classifying terms.
prover::range rng;
// Selects the weak scope test in get_term_type.
bool weak;
// Classification returned by get_term_type; LitMixed is returned for an empty scope.
enum LitType {LitA,LitB,LitMixed};
// Memo for get_placeholder: term -> fresh "@p" constant standing in for it.
hash_map<ast,ast> placeholders;
// These symbols represent deduction rules
/* This symbol represents a proof by contradiction. That is,
contra(p,l1 /\ ... /\ lk) takes a proof p of
l1,...,lk |- false
and returns a proof of
|- ~l1,...,~lk
*/
symb contra;
/* The summation rule. The term sum(p,c,i) takes a proof p of an
inequality i', an integer coefficient c and an inequality i, and
yields a proof of i' + ci. */
symb sum;
/* Proof rotation. The proof term rotate_sum(q,p) takes a
proof p of:
Gamma, q |- false
and yields a proof of:
Gamma |- ~q
*/
symb rotate_sum;
/* Inequalities to equality. leq2eq(p, q, r) takes a proof
p of ~x=y, a proof q of x <= y and a proof r of y <= x
and yields a proof of false. */
symb leq2eq;
/* Equality to inequality. eq2leq(p, q) takes a proof p of x=y, and
a proof q of ~(x <= y) and yields a proof of false. */
symb eq2leq;
/* Proof term cong(p,q) takes a proof p of x=y and a proof
q of t != t<y/x> and returns a proof of false. */
symb cong;
/* Excluded middle. exmid(phi,p,q) takes a proof p of phi and a
proof q of ~\phi and returns a proof of false. */
symb exmid;
/* Symmetry. symm(p) takes a proof p of x=y and produces
a proof of y=x. */
symb symm;
/* Modus ponens. modpon(p,e,q) takes proofs p of P, e of P=Q
and q of ~Q and returns a proof of false. */
symb modpon;
/* This operator represents a concatenation of rewrites. The term
a=b;c=d represents an A rewrite from a to b, followed by a B
rewrite from b to c, followed by an A rewrite from c to d.
*/
symb concat;
/* This represents a lack of a proof */
ast no_proof;
// This is used to represent an infinitesimal value
ast epsilon;
// Represents the top position of a term
ast top_pos;
// add_pos(i,pos) represents position pos of the ith argument
symb add_pos;
// rewrite proof rules
/* rewrite_A(pos,cond,x=y) derives A |- cond => t[x]_p = t[y]_p
where t is an arbitrary term */
symb rewrite_A;
/* rewrite_B(pos,cond,x=y) derives B |- cond => t[x]_p = t[y]_p,
where t is an arbitrary term */
symb rewrite_B;
/* a normalization step is of the form (lhs=rhs) : proof, where "proof"
is a proof of lhs=rhs and lhs is a mixed term. If rhs is a mixed term
then it must have a greater index than lhs. */
symb normal_step;
/* A chain of normalization steps is either "true" (the null chain)
or normal_chain(<step> <tail>), where step is a normalization step
and tail is a normalization chain. The lhs of <step> must have
a lower term index than any lhs in the chain. Moreover, the rhs of
<step> may not occur as the lhs of step in <tail>. If we wish to
add lhs=rhs to the beginning of <tail> and rhs=rhs' occurs in <tail>
we must apply transitivity, transforming <step> to lhs=rhs'. */
symb normal_chain;
/* If p is a proof of Q and c is a normalization chain, then normal(p,c)
is a proof of Q(c) (that is, Q with all substitutions in c performed). */
symb normal;
/** Stand-ins for quantifiers */
symb sforall, sexists;
ast get_placeholder(ast t){
  // Each term gets exactly one fresh "@p" stand-in constant; later requests
  // for the same term return the stored one.
  hash_map<ast,ast>::iterator found = placeholders.find(t);
  if(found != placeholders.end())
    return found->second;
  ast &slot = placeholders[t];
  slot = mk_fresh_constant("@p",get_type(t));
  return slot;
}
/* Build the contra(...) conjunction for a proof pf of lits |- false.
   With no literals, pf itself is returned. Otherwise the result is
   And(contra(pf,false), contra(bar_1,lits[0]), ..., contra(bar_k,lits[k-1])),
   where bar_i is rotate_sum(lits[i],pf) if bit i of pfok is set and
   no_proof otherwise.
   pfok: bit mask; the default -1 enables every literal. Bits past the width
   of int are taken to equal the sign bit, so -1 still means "all" for
   clauses with more literals than an int has bits. */
ast make_contra_node(const ast &pf, const std::vector<ast> &lits, int pfok = -1){
  if(lits.size() == 0)
    return pf;
  std::vector<ast> reslits;
  reslits.reserve(lits.size() + 1);
  reslits.push_back(make(contra,pf,mk_false()));
  // `pfok & (1 << i)` is undefined once i reaches the bit width of int, so
  // test the bit on an unsigned copy and sign-extend beyond the word.
  const unsigned nbits = sizeof(unsigned) * CHAR_BIT;
  const unsigned mask = static_cast<unsigned>(pfok);
  for(unsigned i = 0; i < lits.size(); i++){
    bool rotatable;
    if(i < nbits)
      rotatable = ((mask >> i) & 1u) != 0;
    else
      rotatable = pfok < 0;
    ast bar;
    if(rotatable) bar = make(rotate_sum,lits[i],pf);
    else bar = no_proof;
    reslits.push_back(make(contra,bar,lits[i]));
  }
  return make(And,reslits);
}
LitType get_term_type(const ast &lit){
  prover::range scope = pv->ast_scope(lit);
  if(pv->range_is_empty(scope))
    return LitMixed;
  // In weak mode an overlap with rng suffices, except for scopes whose
  // minimum is SHRT_MIN, which (as in strong mode) must be contained in rng.
  bool overlap_test = weak && pv->range_min(scope) != SHRT_MIN;
  bool on_a_side = overlap_test ? pv->ranges_intersect(scope,rng)
                                : pv->range_contained(scope,rng);
  return on_a_side ? LitA : LitB;
}
bool term_common(const ast &t){
prover::range r = pv->ast_scope(t);
return pv->ranges_intersect(r,rng) && !pv->range_contained(r,rng);
}
bool term_in_vocab(LitType ty, const ast &lit){
  // A vocabulary: scope overlaps rng. Otherwise: scope is not inside rng.
  prover::range scope = pv->ast_scope(lit);
  return ty == LitA ? pv->ranges_intersect(scope,rng)
                    : !pv->range_contained(scope,rng);
}
/** Make a resolution node with given pivot literal and premises.
    The conclusion of premise1 should contain the negation of the
    pivot literal, while the conclusion of premise2 should contain the
    pivot literal.
    A pivot classified LitA combines the premises with my_or, one
    classified LitB with my_and; a mixed pivot is handed to resolve_arith.
*/
node make_resolution(ast pivot, const std::vector<ast> &conc, node premise1, node premise2) {
  LitType lt = get_term_type(pivot);
  if(lt == LitA)
    return my_or(premise1,premise2);
  if(lt == LitB)
    return my_and(premise1,premise2);
  /* the mixed case is a bit complicated */
  ast res = resolve_arith(pivot,conc,premise1,premise2);
#ifdef INVARIANT_CHECKING
  check_contra(conc,res);
#endif
  return res;
}
/* Handles the case of resolution on a mixed arith atom.
   The premises are oriented before resolve_arith_rec1 is called: they are
   swapped for a positive pivot, and an equality between two Selects
   flips both the negated literal and that swap decision. */
ast resolve_arith(const ast &pivot, const std::vector<ast> &conc, node premise1, node premise2){
  ast atom = get_lit_atom(pivot);
  ast neg_pivot_lit = mk_not(atom);
  bool exchange = op(pivot) != Not;
  bool select_eq = op(pivot) == Equal && op(arg(pivot,0)) == Select && op(arg(pivot,1)) == Select;
  if(select_eq){
    neg_pivot_lit = mk_not(neg_pivot_lit);
    exchange = !exchange;
  }
  if(exchange)
    std::swap(premise1,premise2);
  hash_map<ast,ast> memo;
  return resolve_arith_rec1(memo, neg_pivot_lit, premise1, premise2);
}
/* Scale t by coeff symbolically, as Times(coeff,t). No arithmetic
   normalization is done here. */
ast apply_coeff(const ast &coeff, const ast &t){
  ast product = make(Times,coeff,t);
  return product;
}
/* Form coeff1*ineq1 + coeff2*ineq2 side by side, simplifying each side.
   The result is strict (Lt) if either input is strict, else Leq. */
ast sum_ineq(const ast &coeff1, const ast &ineq1, const ast &coeff2, const ast &ineq2){
  bool strict = op(ineq1) == Lt || op(ineq2) == Lt;
  ast lhs = make(Plus,apply_coeff(coeff1,arg(ineq1,0)),apply_coeff(coeff2,arg(ineq2,0)));
  lhs = z3_simplify(lhs);
  ast rhs = make(Plus,apply_coeff(coeff1,arg(ineq1,1)),apply_coeff(coeff2,arg(ineq2,1)));
  rhs = z3_simplify(rhs);
  opr rel = strict ? Lt : Leq;
  return make(rel,lhs,rhs);
}
void collect_contra_resolvents(int from, const ast &pivot1, const ast &pivot, const ast &conj, std::vector<ast> &res){
int nargs = num_args(conj);
for(int i = from; i < nargs; i++){
ast f = arg(conj,i);
if(!(f == pivot)){
ast ph = get_placeholder(mk_not(arg(pivot1,1)));
ast pf = arg(pivot1,0);
ast thing = pf == no_proof ? no_proof : subst_term_and_simp(ph,pf,arg(f,0));
ast newf = make(contra,thing,arg(f,1));
res.push_back(newf);
}
}
}
bool is_negative_equality(const ast &e){
if(op(e) == Not){
opr o = op(arg(e,0));
return o == Equal || o == Iff;
}
return false;
}
int count_negative_equalities(const std::vector<ast> &resolvent){
int res = 0;
for(unsigned i = 0; i < resolvent.size(); i++)
if(is_negative_equality(arg(resolvent[i],1)))
res++;
return res;
}
/* Resolve two contra conjunctions on pivot1/pivot2, substituting pivot1's
   proof into conj2's entries and pivot2's proof into conj1's (skipping
   conj1's argument 0). */
ast resolve_contra_nf(const ast &pivot1, const ast &conj1,
                      const ast &pivot2, const ast &conj2){
  std::vector<ast> resolvent;
  collect_contra_resolvents(0,pivot1,pivot2,conj2,resolvent);
  collect_contra_resolvents(1,pivot2,pivot1,conj1,resolvent);
  // More than one surviving negated equality is rejected.
  if(count_negative_equalities(resolvent) > 1)
    throw proof_error();
  if(resolvent.size() != 1)
    return make(And,resolvent);
  // Only the contradiction is left: simplify its proof into an interpolant.
  return simplify(arg(resolvent[0],0));
}
/* Resolve using whichever pivot carries a proof (pivot1 first). If neither
   does, defer to resolve_with_quantifier. */
ast resolve_contra(const ast &pivot1, const ast &conj1,
                   const ast &pivot2, const ast &conj2){
  bool first_has_proof = arg(pivot1,0) != no_proof;
  if(first_has_proof)
    return resolve_contra_nf(pivot1, conj1, pivot2, conj2);
  if(arg(pivot2,0) == no_proof)
    return resolve_with_quantifier(pivot1, conj1, pivot2, conj2);
  return resolve_contra_nf(pivot2, conj2, pivot1, conj1);
}
bool is_contra_itp(const ast &pivot1, ast itp2, ast &pivot2){
if(op(itp2) == And){
int nargs = num_args(itp2);
for(int i = 1; i < nargs; i++){
ast foo = arg(itp2,i);
if(op(foo) == Uninterpreted && sym(foo) == contra){
if(arg(foo,1) == pivot1){
pivot2 = foo;
return true;
}
}
else break;
}
}
return false;
}
/* Second phase of mixed-literal resolution. Walks itp2 (the other premise's
   interpolant, as passed by resolve_arith_rec1) looking for a contra(...)
   conjunct whose literal is the negation of pivot1's literal. When one is
   found, it is resolved against pivot1/conj1 with resolve_contra.
   And/Or/Implies nodes and the sforall/sexists stand-ins are traversed;
   every other node is returned unchanged.
   memo: cache keyed by itp2 subterm; a null entry means "not computed". */
ast resolve_arith_rec2(hash_map<ast,ast> &memo, const ast &pivot1, const ast &conj1, const ast &itp2){
// NOTE(review): res refers into memo and is held across the recursive calls
// below, which insert new keys. This is only safe if hash_map never
// relocates its entries -- confirm for stl_ext::hash_map.
ast &res = memo[itp2];
if(!res.null())
return res;
ast pivot2;
// pivot1 is contra(pf,lit); look for a conjunct contra(_, ~lit) in itp2.
if(is_contra_itp(mk_not(arg(pivot1,1)),itp2,pivot2))
res = resolve_contra(pivot1,conj1,pivot2,itp2);
else {
switch(op(itp2)){
case Or:
case And:
case Implies: {
// Propositional connective: resolve inside each argument.
unsigned nargs = num_args(itp2);
std::vector<ast> args; args.resize(nargs);
for(unsigned i = 0; i < nargs; i++)
args[i] = resolve_arith_rec2(memo, pivot1, conj1, arg(itp2,i));
ast foo = itp2; // get rid of const
res = clone(foo,args);
break;
}
default:
{
opr o = op(itp2);
if(o == Uninterpreted){
symb s = sym(itp2);
// Quantifier stand-ins: keep the bound argument, recurse into the body.
if(s == sforall || s == sexists)
res = make(s,arg(itp2,0),resolve_arith_rec2(memo, pivot1, conj1, arg(itp2,1)));
else
res = itp2;
}
else {
res = itp2;
}
}
}
}
return res;
}
/* First phase of mixed-literal resolution. Walks itp1 looking for a
   contra(...) conjunct whose literal is neg_pivot_lit. For each such node,
   resolve_arith_rec2 rewrites itp2 against it, using a fresh memo per
   match. And/Or/Implies nodes and the sforall/sexists stand-ins are
   traversed; every other node is returned unchanged.
   memo: cache keyed by itp1 subterm; a null entry means "not computed". */
ast resolve_arith_rec1(hash_map<ast,ast> &memo, const ast &neg_pivot_lit, const ast &itp1, const ast &itp2){
// NOTE(review): res refers into memo and is held across recursive inserts;
// this relies on hash_map entry stability -- confirm.
ast &res = memo[itp1];
if(!res.null())
return res;
ast pivot1;
if(is_contra_itp(neg_pivot_lit,itp1,pivot1)){
hash_map<ast,ast> memo2;
res = resolve_arith_rec2(memo2,pivot1,itp1,itp2);
}
else {
switch(op(itp1)){
case Or:
case And:
case Implies: {
unsigned nargs = num_args(itp1);
std::vector<ast> args; args.resize(nargs);
for(unsigned i = 0; i < nargs; i++)
args[i] = resolve_arith_rec1(memo, neg_pivot_lit, arg(itp1,i), itp2);
ast foo = itp1; // get rid of const
res = clone(foo,args);
break;
}
default:
{
opr o = op(itp1);
if(o == Uninterpreted){
symb s = sym(itp1);
// Quantifier stand-ins: keep the bound argument, recurse into the body.
if(s == sforall || s == sexists)
res = make(s,arg(itp1,0),resolve_arith_rec1(memo, neg_pivot_lit, arg(itp1,1), itp2));
else
res = itp1;
}
else {
res = itp1;
}
}
}
}
return res;
}
/* Consistency check (reached from make_resolution under INVARIANT_CHECKING).
   Walks the And/Or/Implies structure of foo and throws "lost a literal" if
   a contra(_, lit) leaf carries a literal that is neither false nor in
   neg_lits. memo holds the nodes already visited. */
void check_contra(hash_set<ast> &memo, hash_set<ast> &neg_lits, const ast &foo){
  if(memo.find(foo) != memo.end())
    return;
  memo.insert(foo);
  if(op(foo) == Uninterpreted && sym(foo) == contra){
    ast neg_lit = arg(foo,1);
    if(!is_false(neg_lit) && neg_lits.find(neg_lit) == neg_lits.end())
      throw "lost a literal";
    return;
  }
  switch(op(foo)){
  case Or:
  case And:
  case Implies: {
    // The previous version allocated an unused std::vector per node here.
    unsigned nargs = num_args(foo);
    for(unsigned i = 0; i < nargs; i++)
      check_contra(memo, neg_lits, arg(foo,i));
    break;
  }
  default:
    break;
  }
}
void check_contra(const std::vector<ast> &neg_lits, const ast &foo){
hash_set<ast> memo;
hash_set<ast> neg_lits_set;
for(unsigned i = 0; i < neg_lits.size(); i++)
if(get_term_type(neg_lits[i]) == LitMixed)
neg_lits_set.insert(mk_not(neg_lits[i]));
check_contra(memo,neg_lits_set,foo);
}
hash_map<ast,ast> subst_memo; // memo of subst function
/* Substitute t for var throughout e, simplifying the nodes that are rebuilt.
   subst_memo is keyed on e alone, so it is cleared before each top-level
   substitution. */
ast subst_term_and_simp(const ast &var, const ast &t, const ast &e){
  subst_memo.clear();
  ast substituted = subst_term_and_simp_rec(var,t,e);
  return substituted;
}
/* Recursive worker for subst_term_and_simp: replaces var by t in e,
   memoizing in subst_memo (keyed by e only).
   - A rotate_sum node whose literal's placeholder is var is left as is;
     otherwise only its proof argument (argument 1) is rewritten.
   - concat (rewrite chain) nodes are never entered.
   - Rebuilt nodes are simplified: x = x becomes true, And/Or go through
     my_and/my_or, and Idiv goes through mk_idiv. */
ast subst_term_and_simp_rec(const ast &var, const ast &t, const ast &e){
if(e == var) return t;
// Insert a null placeholder entry; bar.second is true on first visit.
std::pair<ast,ast> foo(e,ast());
std::pair<hash_map<ast,ast>::iterator,bool> bar = subst_memo.insert(foo);
// NOTE(review): res refers into subst_memo across recursive inserts;
// this relies on hash_map entry stability -- confirm.
ast &res = bar.first->second;
if(bar.second){
if(op(e) == Uninterpreted){
symb g = sym(e);
if(g == rotate_sum){
if(var == get_placeholder(arg(e,0))){
res = e;
}
else
res = make(rotate_sum,arg(e,0),subst_term_and_simp_rec(var,t,arg(e,1)));
return res;
}
if(g == concat){
res = e;
return res;
}
}
int nargs = num_args(e);
std::vector<ast> args(nargs);
for(int i = 0; i < nargs; i++)
args[i] = subst_term_and_simp_rec(var,t,arg(e,i));
opr f = op(e);
if(f == Equal && args[0] == args[1]) res = mk_true();
else if(f == And) res = my_and(args);
else if(f == Or) res = my_or(args);
else if(f == Idiv) res = mk_idiv(args[0],args[1]);
else res = clone(e,args);
}
return res;
}
/* This is where the real work happens. Here, we simplify the
proof obtained by cut elimination, obtaining an interpolant. */
// Thrown by the simplify_* helpers when a rule instance cannot be
// simplified; simplify_rec catches it and rebuilds the node over its
// simplified arguments.
struct cannot_simplify {};
// Memo for simplify_rec, keyed by the input term.
hash_map<ast,ast> simplify_memo;
// Simplify proof term t bottom-up with simplify_rec, then normalize it.
// With BOGUS_QUANTS defined, the result also goes through add_quants when
// localization_vars is non-empty.
ast simplify(const ast &t){
ast res = normalize(simplify_rec(t));
#ifdef BOGUS_QUANTS
if(localization_vars.size())
res = add_quants(z3_simplify(res));
#endif
return res;
}
/* Memoized bottom-up simplification of a proof term.
   - concat (rewrite chain) nodes are returned unchanged.
   - Argument 0 of rotate_sum (the rotated literal) is not simplified.
   - Rule-specific simplifiers run only on Uninterpreted nodes none of whose
     simplified arguments is a placeholder. If a simplifier throws
     cannot_simplify, the node is rebuilt over its simplified arguments. */
ast simplify_rec(const ast &e){
std::pair<ast,ast> foo(e,ast());
std::pair<hash_map<ast,ast>::iterator,bool> bar = simplify_memo.insert(foo);
// NOTE(review): if a simplifier throws anything other than cannot_simplify,
// the null entry inserted above stays in simplify_memo.
ast &res = bar.first->second;
if(bar.second){
int nargs = num_args(e);
std::vector<ast> args(nargs);
bool placeholder_arg = false;
symb g = sym(e);
if(g == concat){
res = e;
return res;
}
for(int i = 0; i < nargs; i++){
if(i == 0 && g == rotate_sum)
args[i] = arg(e,i);
else
args[i] = simplify_rec(arg(e,i));
placeholder_arg |= is_placeholder(args[i]);
}
try {
opr f = op(e);
if(f == Equal && args[0] == args[1]) res = mk_true();
else if(f == And) res = my_and(args);
else if(f == Or)
res = my_or(args);
else if(f == Idiv) res = mk_idiv(args[0],args[1]);
else if(f == Uninterpreted && !placeholder_arg){
if(g == rotate_sum) res = simplify_rotate(args);
else if(g == symm) res = simplify_symm(args);
else if(g == modpon) res = simplify_modpon(args);
else if(g == sum) res = simplify_sum(args);
else if(g == exmid) res = simplify_exmid(args);
else if(g == cong) res = simplify_cong(args);
// Disabled alternatives (the modpon line duplicates the one above).
#if 0
else if(g == modpon) res = simplify_modpon(args);
else if(g == leq2eq) res = simplify_leq2eq(args);
else if(g == eq2leq) res = simplify_eq2leq(args);
#endif
else res = clone(e,args);
}
else res = clone(e,args);
}
catch (const cannot_simplify &){
res = clone(e,args);
}
}
return res;
}
/* Simplify rotate_sum(lit, pf) by dispatching on the rule at the root of pf.
   args[0] is the rotated literal and args[1] the proof. */
ast simplify_rotate(const std::vector<ast> &args){
  const ast &lit = args[0];
  const ast &proof = args[1];
  ast pl = get_placeholder(lit);
  opr root = op(proof);
  // NOTE(review): this throws a const char*, which simplify_rec's
  // cannot_simplify handler does not catch.
  if(root == Leq)
    throw "foo!";
  if(root != Uninterpreted)
    throw cannot_simplify();
  symb rule = sym(proof);
  if(rule == sum)
    return simplify_rotate_sum(pl,proof);
  if(rule == leq2eq)
    return simplify_rotate_leq2eq(pl,lit,proof);
  if(rule == eq2leq)
    return simplify_rotate_eq2leq(pl,lit,proof);
  if(rule == cong)
    return simplify_rotate_cong(pl,lit,proof);
  if(rule == modpon)
    return simplify_rotate_modpon(pl,lit,proof);
  throw cannot_simplify();
}
bool is_normal_ineq(const ast &ineq){
if(sym(ineq) == normal)
return is_ineq(arg(ineq,0));
return is_ineq(ineq);
}
/* Peel an optional And(acond, ...) layer, conjoining acond into Aproves,
   then an optional Implies(bcond, ...) layer, conjoining bcond into
   Bproves. Returns what remains. */
ast destruct_cond_ineq(const ast &ineq, ast &Aproves, ast &Bproves){
  ast body = ineq;
  if(op(body) == And){
    Aproves = my_and(Aproves,arg(body,0));
    body = arg(body,1);
  }
  if(op(body) == Implies){
    Bproves = my_and(Bproves,arg(body,0));
    body = arg(body,1);
  }
  return body;
}
/* Simplify sum(p,c,i): args[0] is the accumulated (possibly conditional)
   inequality, and args[1] is the coefficient applied to args[2]. */
ast simplify_sum(std::vector<ast> &args){
  ast Aproves = mk_true();
  ast Bproves = mk_true();
  ast acc = destruct_cond_ineq(args[0],Aproves,Bproves);
  if(!is_normal_ineq(acc))
    throw cannot_simplify();
  sum_cond_ineq(acc,args[1],args[2],Aproves,Bproves);
  ast guarded = my_implies(Bproves,acc);
  return my_and(Aproves,guarded);
}
/* Rotate a sum proof around placeholder pl, starting the accumulation from
   the trivial inequality 0 <= 0. Conditions are attached only if some
   were collected. */
ast simplify_rotate_sum(const ast &pl, const ast &pf){
  ast Aproves = mk_true();
  ast Bproves = mk_true();
  ast seed = make(Leq,make_int("0"),make_int("0"));
  ast total = rotate_sum_rec(pl,pf,Aproves,Bproves,seed);
  if(!is_true(Aproves) || !is_true(Bproves))
    return my_and(Aproves,my_implies(Bproves,total));
  return total;
}
bool is_rewrite_chain(const ast &chain){
return sym(chain) == concat;
}
// Disabled (compiled out by #if 0). Appears to be a simpler variant of
// ineq_from_chain for chains whose last step is an unconditional A rewrite
// from true; kept for reference only.
#if 0
ast ineq_from_chain_simple(const ast &chain, ast &cond){
if(is_true(chain))
return chain;
ast last = chain_last(chain);
ast rest = chain_rest(chain);
if(is_true(rest) && is_rewrite_side(LitA,last)
&& is_true(rewrite_lhs(last))){
cond = my_and(cond,rewrite_cond(last));
return rewrite_rhs(last);
}
if(is_rewrite_side(LitB,last) && is_true(rewrite_cond(last)))
return ineq_from_chain_simple(rest,cond);
return chain;
}
#endif
/* Convert a rewrite chain to a (possibly normal-wrapped) inequality; any
   other term is returned unchanged. */
ast ineq_from_chain(const ast &chain, ast &Aproves, ast &Bproves){
  if(!is_rewrite_chain(chain))
    return chain;
  return rewrite_chain_to_normal_ineq(chain,Aproves,Bproves);
}
/* ineq += coeff2 * ineq2, where ineq2 may be guarded by And (A condition)
   or Implies (B condition) layers, or may be a rewrite chain.
   Throws cannot_simplify if ineq2 does not reduce to an inequality. */
void sum_cond_ineq(ast &ineq, const ast &coeff2, const ast &ineq2, ast &Aproves, ast &Bproves){
  opr shape = op(ineq2);
  if(shape == And || shape == Implies){
    // Add the guarded inequality first, then record its guard.
    sum_cond_ineq(ineq,coeff2,arg(ineq2,1),Aproves,Bproves);
    if(shape == And)
      Aproves = my_and(Aproves,arg(ineq2,0));
    else
      Bproves = my_and(Bproves,arg(ineq2,0));
    return;
  }
  ast the_ineq = ineq_from_chain(ineq2,Aproves,Bproves);
  if(sym(ineq) == normal || sym(the_ineq) == normal)
    sum_normal_ineq(ineq,coeff2,the_ineq,Aproves,Bproves);
  else if(is_ineq(the_ineq))
    linear_comb(ineq,coeff2,the_ineq);
  else
    throw cannot_simplify();
}
/* Split normal(p,n) into its parts; a term without the wrapper is treated
   as normal(pf,true). */
void destruct_normal(const ast &pf, ast &p, ast &n){
  bool wrapped = sym(pf) == normal;
  p = wrapped ? arg(pf,0) : pf;
  n = wrapped ? arg(pf,1) : mk_true();
}
/* ineq += coeff2 * ineq2 where either side may be normal(...)-wrapped. The
   bare inequalities are summed and the normalization chains merged;
   conditions from merging go to Aproves/Bproves. */
void sum_normal_ineq(ast &ineq, const ast &coeff2, const ast &ineq2, ast &Aproves, ast &Bproves){
  ast in1,in2,n1,n2;
  destruct_normal(ineq,in1,n1);
  destruct_normal(ineq2,in2,n2);
  // sum_cond_ineq may conjoin conditions into its Aproves/Bproves arguments
  // with my_and (when in2 is an And/Implies, or a rewrite chain going
  // through rewrite_chain_to_normal_ineq). Seed the scratch accumulators
  // with true rather than a default-constructed (null) ast, so my_and never
  // receives a null operand.
  // NOTE(review): these scratch conditions are discarded, as before;
  // confirm that this is intended.
  ast dummy1 = mk_true(), dummy2 = mk_true();
  sum_cond_ineq(in1,coeff2,in2,dummy1,dummy2);
  n1 = merge_normal_chains(n1,n2, Aproves, Bproves);
  ineq = is_true(n1) ? in1 : make_normal(in1,n1);
}
bool is_ineq(const ast &ineq){
opr o = op(ineq);
if(o == Not) o = op(arg(ineq,0));
return o == Leq || o == Lt || o == Geq || o == Gt;
}
// Divide both sides of an inequality by an integer divisor using mk_idiv.
// NOTE(review): the old note said "non-negative", but a zero divisor would
// make the idiv meaningless, so the divisor is presumably positive -- confirm.
// A strict x < y is first weakened to x <= y - 1 (integer-valued sides).
// normal(...) wrappers are kept, and a divisor of 1 is a no-op.
ast idiv_ineq(const ast &ineq1, const ast &divisor){
  if(sym(ineq1) == normal){
    ast inner, chain;
    destruct_normal(ineq1,inner,chain);
    return make_normal(idiv_ineq(inner,divisor),chain);
  }
  if(divisor == make_int(rational(1)))
    return ineq1;
  ast base = ineq1;
  if(op(base) == Lt){
    ast weakened = make(Leq,arg(base,0),make(Sub,arg(base,1),make_int("1")));
    base = simplify_ineq(weakened);
  }
  ast lhs = mk_idiv(arg(base,0),divisor);
  ast rhs = mk_idiv(arg(base,1),divisor);
  return make(op(base),lhs,rhs);
}
/* Walk down the spine of sum(p,c,i) nodes of pf, adding c*i into ineq at
   each step, until the placeholder pl is reached.
   - If a node's i is pl itself, p (with coefficient 1) is added instead and
     the result is divided by c.
   - Throws cannot_simplify on any other node. */
ast rotate_sum_rec(const ast &pl, const ast &pf, ast &Aproves, ast &Bproves, ast &ineq){
  ast cur = pf;
  while(cur != pl){
    if(op(cur) != Uninterpreted || !(sym(cur) == sum))
      throw cannot_simplify();
    if(arg(cur,2) == pl){
      sum_cond_ineq(ineq,make_int("1"),arg(cur,0),Aproves,Bproves);
      ineq = idiv_ineq(ineq,arg(cur,1));
      return ineq;
    }
    sum_cond_ineq(ineq,arg(cur,1),arg(cur,2),Aproves,Bproves);
    ast next = arg(cur,0);
    cur = next;
  }
  if(sym(ineq) == normal)
    return ineq;
  return simplify_ineq(ineq);
}
/* Rotation of leq2eq(p,q,r) on its first premise p (the proof of ~x=y):
   from q (x <= y) and r (y <= x) build a two-step rewrite chain x = iter = y.
   The step is split between A and B according to whether x (or failing
   that, y) is classified LitA. Throws cannot_simplify if pl is not the
   first premise or if neither x nor y is LitA. */
ast simplify_rotate_leq2eq(const ast &pl, const ast &neg_equality, const ast &pf){
if(pl == arg(pf,0)){
ast equality = arg(neg_equality,0);
ast x = arg(equality,0);
ast y = arg(equality,1);
ast Aproves1 = mk_true(), Bproves1 = mk_true();
ast pf1 = destruct_cond_ineq(arg(pf,1), Aproves1, Bproves1);
ast pf2 = destruct_cond_ineq(arg(pf,2), Aproves1, Bproves1);
ast xleqy = round_ineq(ineq_from_chain(pf1,Aproves1,Bproves1));
ast yleqx = round_ineq(ineq_from_chain(pf2,Aproves1,Bproves1));
// ineq1 = -(x<=y) - (y<=x), used in the A-side condition.
ast ineq1 = make(Leq,make_int("0"),make_int("0"));
sum_cond_ineq(ineq1,make_int("-1"),xleqy,Aproves1,Bproves1);
sum_cond_ineq(ineq1,make_int("-1"),yleqx,Aproves1,Bproves1);
ast Acond = my_implies(Aproves1,my_and(Bproves1,z3_simplify(ineq1)));
// ineq2 = (x<=y) + (y<=x), used in the B-side condition.
ast Aproves2 = mk_true(), Bproves2 = mk_true();
ast ineq2 = make(Leq,make_int("0"),make_int("0"));
sum_cond_ineq(ineq2,make_int("1"),xleqy,Aproves2,Bproves2);
sum_cond_ineq(ineq2,make_int("1"),yleqx,Aproves2,Bproves2);
// NOTE(review): Bcond uses Aproves1/Bproves1; anything sum_cond_ineq added
// to Aproves2/Bproves2 above is dropped -- confirm this is intended.
ast Bcond = my_implies(Bproves1,my_and(Aproves1,z3_simplify(ineq2)));
// if(!is_true(Aproves1) || !is_true(Bproves1))
// std::cout << "foo!\n";;
if(get_term_type(x) == LitA){
ast iter = z3_simplify(make(Plus,x,get_ineq_rhs(xleqy)));
ast rewrite1 = make_rewrite(LitA,top_pos,Acond,make(Equal,x,iter));
ast rewrite2 = make_rewrite(LitB,top_pos,Bcond,make(Equal,iter,y));
return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2);
}
if(get_term_type(y) == LitA){
ast iter = z3_simplify(make(Plus,y,get_ineq_rhs(yleqx)));
ast rewrite2 = make_rewrite(LitA,top_pos,Acond,make(Equal,iter,y));
ast rewrite1 = make_rewrite(LitB,top_pos,Bcond,make(Equal,x,iter));
return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2);
}
throw cannot_simplify();
}
throw cannot_simplify();
}
/* Simplify an inequality and turn a strict x < y into x <= y - 1. Works
   under a normal(...) wrapper; throws cannot_simplify on non-inequalities. */
ast round_ineq(const ast &ineq){
  if(sym(ineq) == normal){
    ast rounded = round_ineq(arg(ineq,0));
    return make_normal(rounded,arg(ineq,1));
  }
  if(!is_ineq(ineq))
    throw cannot_simplify();
  ast simplified = simplify_ineq(ineq);
  if(op(simplified) != Lt)
    return simplified;
  return make(Leq,arg(simplified,0),make(Sub,arg(simplified,1),make_int("1")));
}
/* Turn an equality chain between unmixed lhs and rhs into inequalities
   guarded by the chain's A conditions (added to cond) and, when they are
   not trivially true, its B conditions. */
ast unmixed_eq2ineq(const ast &lhs, const ast &rhs, opr comp_op, const ast &equa, ast &cond){
  ast ineqs = chain_ineqs(comp_op,LitA,equa,lhs,rhs); // chain must be from lhs to rhs
  cond = my_and(cond,chain_conditions(LitA,equa));
  ast Bconds = z3_simplify(chain_conditions(LitB,equa));
  if(op(ineqs) == And)
    throw "help!";
  ast guarded = my_implies(cond,ineqs);
  return is_true(Bconds) ? guarded : my_and(Bconds,guarded);
}
/* Add to itp (with coefficient 1) the trivial inequality 0 <= 0, wrapped
   with a normalization step built by fix_normal(lhs,rhs,equa). A trivially
   true equa leaves itp unchanged. */
ast add_mixed_eq2ineq(const ast &lhs, const ast &rhs, const ast &equa, const ast &itp){
  if(is_true(equa))
    return itp;
  ast trivial = make(Leq,make_int(rational(0)),make_int(rational(0)));
  ast chain = cons_normal(fix_normal(lhs,rhs,equa),mk_true());
  std::vector<ast> summands;
  summands.push_back(itp);
  summands.push_back(make_int("1"));
  summands.push_back(make_normal(trivial,chain));
  return simplify_sum(summands);
}
/* Rotation of eq2leq(p,q) on its second premise q: turns the equality
   chain proved by p into inequalities.
   - If both ends are unmixed, unmixed_eq2ineq is used directly.
   - Otherwise the chain's left and right mover segments are peeled off and
     re-added as normalization steps with add_mixed_eq2ineq. */
ast simplify_rotate_eq2leq(const ast &pl, const ast &neg_equality, const ast &pf){
if(pl == arg(pf,1)){
ast cond = mk_true();
ast equa = sep_cond(arg(pf,0),cond);
if(is_equivrel_chain(equa)){
ast lhs,rhs; eq_from_ineq(arg(neg_equality,0),lhs,rhs); // get inequality we need to prove
if(!rewrites_from_to(equa,lhs,rhs)){
lhs = arg(arg(neg_equality,0),0); // the equality proved is ambiguous, sadly
rhs = arg(arg(neg_equality,0),1);
}
LitType lhst = get_term_type(lhs), rhst = get_term_type(rhs);
if(lhst != LitMixed && rhst != LitMixed)
return unmixed_eq2ineq(lhs, rhs, op(arg(neg_equality,0)), equa, cond);
else {
ast left, left_term, middle, right_term, right;
left = get_left_movers(equa,lhs,middle,left_term);
middle = get_right_movers(middle,rhs,right,right_term);
ast itp = unmixed_eq2ineq(left_term, right_term, op(arg(neg_equality,0)), middle, cond);
// itp = my_implies(cond,itp);
itp = add_mixed_eq2ineq(lhs, left_term, left, itp);
itp = add_mixed_eq2ineq(right_term, rhs, right, itp);
return itp;
}
}
}
// NOTE(review): this const char* throw is not caught by simplify_rec's
// cannot_simplify handler. The sibling simplify_rotate_* functions throw
// cannot_simplify when the placeholder does not match -- confirm which
// behavior is wanted here.
throw "help!";
}
void reverse_modpon(std::vector<ast> &args){
std::vector<ast> sargs(1); sargs[0] = args[1];
args[1] = simplify_symm(sargs);
if(is_equivrel_chain(args[2]))
args[1] = down_chain(args[1]);
std::swap(args[0],args[2]);
}
/* Rotation of modpon(p,e,q): orient the premises so that the placeholder
   is the last one, then run the forward chain construction. */
ast simplify_rotate_modpon(const ast &pl, const ast &neg_equality, const ast &pf){
  std::vector<ast> premises;
  for(int i = 0; i < 3; i++)
    premises.push_back(arg(pf,i));
  if(premises[0] == pl)
    reverse_modpon(premises);
  if(premises[2] != pl)
    throw cannot_simplify();
  ast cond = mk_true();
  ast chain = simplify_modpon_fwd(premises, cond);
  return my_implies(cond,chain);
}
/* Strip any Implies guards and return the right-hand side of the <= or <
   underneath; throws cannot_simplify otherwise. */
ast get_ineq_rhs(const ast &ineq2){
  ast cur = ineq2;
  while(op(cur) == Implies){
    ast body = arg(cur,1);
    cur = body;
  }
  opr rel = op(cur);
  if(rel != Leq && rel != Lt)
    throw cannot_simplify();
  return arg(cur,1);
}
/* Rotation of cong(p,pos,q) on q: lift the equality chain proved by p to
   argument position pos, under the conditions separated from p. */
ast simplify_rotate_cong(const ast &pl, const ast &neg_equality, const ast &pf){
  if(pl != arg(pf,2))
    throw cannot_simplify();
  if(op(arg(pf,0)) == True)
    return mk_true();
  rational pos;
  if(!is_numeral(arg(pf,1),pos))
    throw cannot_simplify();
  int ipos = pos.get_unsigned();
  ast cond = mk_true();
  ast equa = sep_cond(arg(pf,0),cond);
  ast lifted = chain_pos_add(ipos,equa);
  return my_implies(cond,lifted);
}
/* Simplify symm(p): reverse an equivalence chain (keeping its conditions),
   or commute a negation chain. */
ast simplify_symm(const std::vector<ast> &args){
  if(op(args[0]) == True)
    return mk_true();
  ast cond = mk_true();
  ast equa = sep_cond(args[0],cond);
  if(is_equivrel_chain(equa)){
    ast reversed = reverse_chain(equa);
    return my_implies(cond,reversed);
  }
  if(!is_negation_chain(equa))
    throw cannot_simplify();
  return commute_negation_chain(equa);
}
/* Forward modus ponens: from args[0], a proof of P, and args[1], a proof of
   P = Q, build a rewrite chain (presumably deriving Q). Conditions that
   sep_cond separates from the premises are accumulated in cond.
   - If P is an equivalence chain, the P = Q chain is split into its two
     sides. When split_chain throws cannot_split, a conditional forward
     rewrite is built from the head of the P = Q chain instead.
   - Otherwise P is simply followed by the P = Q chain. */
ast simplify_modpon_fwd(const std::vector<ast> &args, ast &cond){
  ast P = sep_cond(args[0],cond);
  ast PeqQ = sep_cond(args[1],cond);
  ast chain;
  if(is_equivrel_chain(P)){
    try {
      ast split[2];
      split_chain(PeqQ,split);
      chain = reverse_chain(split[0]);
      chain = concat_rewrite_chain(chain,P);
      chain = concat_rewrite_chain(chain,split[1]);
    }
    catch(const cannot_split &){
      ast tail, pref = get_head_chain(PeqQ,tail,false); // pref is x=y, tail is x=y -> x'=y'
      ast split[2]; split_chain(tail,split); // rewrites from x to x' and y to y'
      ast head = chain_last(pref);
      ast prem = make_rewrite(rewrite_side(head),top_pos,rewrite_cond(head),make(Iff,mk_true(),mk_not(rewrite_lhs(head))));
      ast back_chain = chain_cons(mk_true(),prem);
      back_chain = concat_rewrite_chain(back_chain,chain_pos_add(0,reverse_chain(chain_rest(pref))));
      // Condition of the forward rewrite. Named fwd_cond so it no longer
      // shadows the `cond` out-parameter, which this branch does not update.
      ast fwd_cond = contra_chain(back_chain,P);
      if(is_rewrite_side(LitA,head))
        fwd_cond = mk_not(fwd_cond);
      ast fwd_rewrite = make_rewrite(rewrite_side(head),top_pos,fwd_cond,rewrite_rhs(head));
      P = chain_cons(mk_true(),fwd_rewrite);
      chain = reverse_chain(split[0]);
      chain = concat_rewrite_chain(chain,P);
      chain = concat_rewrite_chain(chain,split[1]);
    }
  }
  else { // if not an equivalence, must be of form T <-> pred
    chain = concat_rewrite_chain(P,PeqQ);
  }
  return chain;
}
// Thrown by get_subterm_normals when ineq1 and ineq2 differ at a subterm
// that is neither arithmetic structure nor a mixed term.
struct subterm_normals_failed {};
/* Walk ineq1 and ineq2 in parallel through Not/Leq/Lt/Geq/Gt/Plus/Times
   nodes; pos tracks the current position. At each subterm of ineq2 that is
   classified LitMixed (once per subterm, via memo):
   - the rewrites of chain at pos are extracted;
   - a normalization step ineq2 -> ineq1 is built from their reversal;
   - that step is merged into normals, with merge conditions going to
     Aproves/Bproves.
   Throws "bad inequality rewriting" on a shape mismatch or when no
   rewrites exist at a mixed position. */
void get_subterm_normals(const ast &ineq1, const ast &ineq2, const ast &chain, ast &normals,
const ast &pos, hash_set<ast> &memo, ast &Aproves, ast &Bproves){
opr o1 = op(ineq1);
opr o2 = op(ineq2);
if(o1 == Not || o1 == Leq || o1 == Lt || o1 == Geq || o1 == Gt || o1 == Plus || o1 == Times){
int n = num_args(ineq1);
if(o2 != o1 || num_args(ineq2) != n)
throw "bad inequality rewriting";
for(int i = 0; i < n; i++){
ast new_pos = add_pos_to_end(pos,i);
get_subterm_normals(arg(ineq1,i), arg(ineq2,i), chain, normals, new_pos, memo, Aproves, Bproves);
}
}
else if(get_term_type(ineq2) == LitMixed){
if(memo.find(ineq2) == memo.end()){
memo.insert(ineq2);
ast sub_chain = extract_rewrites(chain,pos);
if(is_true(sub_chain))
throw "bad inequality rewriting";
ast new_normal = make_normal_step(ineq2,ineq1,reverse_chain(sub_chain));
normals = merge_normal_chains(normals,cons_normal(new_normal,mk_true()), Aproves, Bproves);
}
}
else if(!(ineq1 == ineq2))
throw subterm_normals_failed();
}
/* Replay the rewrite chain on ineq1, first rewrite to last (the recursion
   on chain_rest applies the earlier rewrites first), and return the
   rewritten inequality.
   For each rewrite, arith_rewrite_coeff descends along the rewrite position
   through arithmetic operators. It stops at term1, the first
   non-arithmetic subterm (or the rewrite position itself), and returns the
   coefficient term1 carries; term2 is term1 after the rewrite.
   - If neither term1 nor term2 is mixed, an A-side rewrite adds
     coeff * (0 <= term2 - term1) to Aineqs; a B-side rewrite adds nothing
     here.
   - Otherwise a normalization step term1 -> term2 is merged into normals,
     with merge conditions going to Aproves/Bproves. */
ast rewrites_to_normals(const ast &ineq1, const ast &chain, ast &normals, ast &Aproves, ast &Bproves, ast &Aineqs){
  if(is_true(chain))
    return ineq1;
  ast last = chain_last(chain);
  ast rest = chain_rest(chain);
  ast new_ineq1 = rewrites_to_normals(ineq1, rest, normals, Aproves, Bproves, Aineqs);
  ast p1 = rewrite_pos(last);
  ast term1;
  // On return p1 is the position of term1 within new_ineq1.
  ast coeff = arith_rewrite_coeff(new_ineq1,p1,term1);
  ast res = subst_in_pos(new_ineq1,rewrite_pos(last),rewrite_rhs(last));
  ast rpos;
  pos_diff(p1,rewrite_pos(last),rpos);
  ast term2 = subst_in_pos(term1,rpos,rewrite_rhs(last));
  if(get_term_type(term1) != LitMixed && get_term_type(term2) != LitMixed){
    if(is_rewrite_side(LitA,last))
      linear_comb(Aineqs,coeff,make(Leq,make_int(rational(0)),make(Sub,term2,term1)));
  }
  else {
    // term1 = term2 is justified by the current rewrite `last`, restricted
    // to term1's position. The previous code passed `rest` here, but `rest`
    // justifies ineq1 -> new_ineq1, not term1 -> term2. Being a whole
    // chain, it also does not fit the one-rewrite shape that
    // chain_cons(mk_true(), r) takes elsewhere in this file.
    ast pf = extract_rewrites(make(concat,mk_true(),last),p1);
    ast new_normal = fix_normal(term1,term2,pf);
    normals = merge_normal_chains(normals,cons_normal(new_normal,mk_true()), Aproves, Bproves);
  }
  return res;
}
/* Descend from ineq along position p1 through arithmetic operators
   (Leq/Lt/Geq/Gt/Not/Plus/Times). The descent stops at top_pos or at the
   first other operator; that subterm is stored in `term`.
   Returns the coefficient of `term` as a product of the per-level factors:
   - Leq/Lt: -1 on the left side, +1 on the right;
   - Geq/Gt: the reverse;
   - Times: argument 0;
   - Not/Plus: 1.
   On return p1 is the position of `term` within ineq. */
ast arith_rewrite_coeff(const ast &ineq, ast &p1, ast &term){
ast coeff = make_int(rational(1));
if(p1 == top_pos){
term = ineq;
return coeff;
}
int argpos = pos_arg(p1);
opr o = op(ineq);
switch(o){
case Leq:
case Lt:
coeff = argpos ? make_int(rational(1)) : make_int(rational(-1));
break;
case Geq:
case Gt:
coeff = argpos ? make_int(rational(-1)) : make_int(rational(1));
break;
case Not:
case Plus:
break;
case Times:
// NOTE(review): assumes the constant factor is argument 0 of Times.
coeff = arg(ineq,0);
break;
default:
// Non-arithmetic operator: stop here.
p1 = top_pos;
term = ineq;
return coeff;
}
// Drop the first step of the position, recurse, then re-attach it.
p1 = arg(p1,1);
ast res = arith_rewrite_coeff(arg(ineq,argpos),p1,term);
p1 = pos_add(argpos,p1);
return coeff == make_int(rational(1)) ? res : make(Times,coeff,res);
}
/* Convert a rewrite chain that derives an inequality into a (possibly
   normal-wrapped) inequality. The head rewrite yields ineq1, and the tail
   rewrites it to ineq2.
   - First try get_subterm_normals, which treats the difference as
     normalization steps on mixed subterms.
   - If that throws subterm_normals_failed, roll back Aproves/Bproves and
     the chain, then use rewrites_to_normals, which also accumulates
     A-side arithmetic into itp. */
ast rewrite_chain_to_normal_ineq(const ast &chain, ast &Aproves, ast &Bproves){
ast tail, pref = get_head_chain(chain,tail,false); // pref is x=y, tail is x=y -> x'=y'
ast head = chain_last(pref);
ast ineq1 = rewrite_rhs(head);
ast ineq2 = apply_rewrite_chain(ineq1,tail);
ast nc = mk_true();
hash_set<ast> memo;
ast itp = make(Leq,make_int(rational(0)),make_int(rational(0)));
// Save the condition accumulators so the fallback path starts clean.
ast Aproves_save = Aproves, Bproves_save = Bproves; try {
get_subterm_normals(ineq1,ineq2,tail,nc,top_pos,memo, Aproves, Bproves);
}
catch (const subterm_normals_failed &){ Aproves = Aproves_save; Bproves = Bproves_save; nc = mk_true();
rewrites_to_normals(ineq1, tail, nc, Aproves, Bproves, itp);
}
// An A head contributes ineq1 to itp and adds the prefix's B-side facts to
// Bproves; a B head adds the prefix's A-side facts to Aproves.
if(is_rewrite_side(LitA,head)){
linear_comb(itp,make_int("1"),ineq1); // make sure it is normal form
//itp = ineq1;
ast mc = z3_simplify(chain_side_proves(LitB,pref));
Bproves = my_and(Bproves,mc);
}
else {
ast mc = z3_simplify(chain_side_proves(LitA,pref));
Aproves = my_and(Aproves,mc);
}
if(is_true(nc))
return itp;
return make_normal(itp,nc);
}
/* Given a rewrite chain deriving not P and a rewrite chain deriving P, return an interpolant.
   - For an equality P (pos_chain is an equivalence chain), x = y is used
     to rewrite not(x=y) to not(y=y). The interpolant's shape depends on
     which side the head rewrite of neg_chain belongs to.
   - Otherwise the reversed derivation of P rewrites not(P) to not(t). */
ast contra_chain(const ast &neg_chain, const ast &pos_chain){
// equality is a special case. we use the derivation of x=y to rewrite not(x=y) to not(y=y)
if(is_equivrel_chain(pos_chain)){
ast tail, pref = get_head_chain(neg_chain,tail); // pref is not(x=y), tail is not(x,y) -> not(x',y')
ast split[2]; split_chain(down_chain(tail),split); // rewrites from x to x' and y to y'
ast chain = split[0];
chain = concat_rewrite_chain(chain,pos_chain); // rewrites from x to y'
chain = concat_rewrite_chain(chain,reverse_chain(split[1])); // rewrites from x to y
chain = concat_rewrite_chain(pref,chain_pos_add(0,chain_pos_add(0,chain))); // rewrites t -> not(y=y)
ast head = chain_last(pref);
if(is_rewrite_side(LitB,head)){
ast condition = chain_conditions(LitB,chain);
return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),condition);
}
else {
ast condition = chain_conditions(LitA,chain);
return my_and(chain_conditions(LitB,chain),my_implies(condition,mk_not(chain_formulas(LitB,chain))));
}
// ast chain = concat_rewrite_chain(neg_chain,chain_pos_add(0,chain_pos_add(0,pos_chain)));
// return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain));
}
// otherwise, we reverse the derivation of t = P and use it to rewrite not(P) to not(t)
ast chain = concat_rewrite_chain(neg_chain,chain_pos_add(0,reverse_chain(pos_chain)));
return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain));
}
/* Simplify modpon(p,e,q): build the forward chain from p and e (its
   conditions go to Bproves), then contradict it with q. Inequalities are
   combined arithmetically rather than through contra_chain. */
ast simplify_modpon(const std::vector<ast> &args){
  ast Aproves = mk_true(), Bproves = mk_true();
  ast chain = simplify_modpon_fwd(args,Bproves);
  ast Q2 = destruct_cond_ineq(args[2],Aproves,Bproves);
  ast interp;
  if(!is_normal_ineq(Q2)){
    if(is_negation_chain(chain))
      interp = contra_chain(chain,Q2);
    else
      interp = contra_chain(Q2,chain);
  }
  else {
    ast nQ2 = rewrite_chain_to_normal_ineq(chain,Aproves,Bproves);
    sum_cond_ineq(nQ2,make_int(rational(1)),Q2,Aproves,Bproves);
    interp = normalize(nQ2);
  }
  return my_and(Aproves,my_implies(Bproves,interp));
}
/* Simplify exmid(phi,p,q) for an equivalence phi: contradict the chain for
   phi (args[1]) with the chain for ~phi (args[2]). */
ast simplify_exmid(const std::vector<ast> &args){
  if(!is_equivrel(args[0]))
    throw "bad exmid";
  ast Aproves = mk_true(), Bproves = mk_true();
  ast pos_chain = destruct_cond_ineq(args[1],Aproves,Bproves);
  ast neg_chain = destruct_cond_ineq(args[2],Aproves,Bproves);
  ast interp = contra_chain(neg_chain,pos_chain);
  return my_and(Aproves,my_implies(Bproves,interp));
}
/* Simplify cong(p,pos,q): lift p's equality chain to argument position pos
   and contradict it with q. */
ast simplify_cong(const std::vector<ast> &args){
  ast Aproves = mk_true(), Bproves = mk_true();
  ast chain = destruct_cond_ineq(args[0],Aproves,Bproves);
  rational pos;
  if(!is_numeral(args[1],pos))
    throw "bad cong";
  int ipos = pos.get_unsigned();
  ast lifted = chain_pos_add(ipos,chain);
  ast Q2 = destruct_cond_ineq(args[2],Aproves,Bproves);
  ast interp = contra_chain(Q2,lifted);
  return my_and(Aproves,my_implies(Bproves,interp));
}
bool is_equivrel(const ast &p){
opr o = op(p);
return o == Equal || o == Iff;
}
/* Thrown by commute_rewrites, apply_common_rewrites and apply_all_rewrites
   when the terms involved cannot be aligned. */
struct rewrites_failed {};
/* Suppose p in Lang(B) and A |- p -> q and B |- q -> r. Return a z in Lang(B) such that
B |- p -> z and A |- z -> q. Collect any side conditions in "rules". */
/* Cases handled below:
   - q == r: nothing to commute, p is returned;
   - p == q: r is returned;
   - q = r is a B-local equality: p is checked against the common-vocabulary
     rewrites p = q (possibly extending rules or throwing) and r is returned;
   - otherwise p, q, r must share head symbol and arity (else rewrites_failed
     is thrown) and the function recurses argument-wise. */
ast commute_rewrites(const ast &p, const ast &q, const ast &r, ast &rules){
if(q == r)
return p;
if(p == q)
return r;
else {
ast rew = make(Equal,q,r);
if(get_term_type(rew) == LitB){
// result discarded: called only for its effect on rules and its possible
// rewrites_failed throw (A rewrites must be over common vocab)
apply_common_rewrites(p,p,q,rules);
return r;
}
}
// structural case: recurse into matching arguments
if(sym(p) != sym(q) || sym(q) != sym(r))
throw rewrites_failed();
int nargs = num_args(p);
if(nargs != num_args(q) || nargs != num_args(r))
throw rewrites_failed();
std::vector<ast> args; args.resize(nargs);
for(int i = 0; i < nargs; i++)
args[i] = commute_rewrites(arg(p,i),arg(q,i),arg(r,i),rules);
return clone(p,args);
}
/* Apply to p the rewrites q -> r whose equality q = r is in the common
   vocabulary, conjoining each such equality onto rules.
   - q == r: p is returned unchanged.
   - q = r common: p must equal q (else rewrites_failed), and r is returned.
   - Mismatched head symbol or arity: p is returned unchanged.
   - Otherwise the function recurses argument-wise. */
ast apply_common_rewrites(const ast &p, const ast &q, const ast &r, ast &rules){
  if(q == r)
    return p;
  ast rew = make(Equal,q,r);
  if(term_common(rew)){
    if(p != q)
      throw rewrites_failed();
    rules = my_and(rules,rew);
    return r;
  }
  if(sym(p) != sym(q) || sym(q) != sym(r))
    return p;
  const int n = num_args(p);
  if(n != num_args(q) || n != num_args(r))
    return p;
  std::vector<ast> kids(n);
  for(int k = 0; k < n; ++k)
    kids[k] = apply_common_rewrites(arg(p,k),arg(q,k),arg(r,k),rules);
  return clone(p,kids);
}
/* Rewrite p along q -> r.
   - If q == r, p is returned unchanged.
   - If p coincides with q, r is returned.
   - Otherwise p, q and r must share head symbol and arity (else
     rewrites_failed is thrown) and the function recurses into their arguments. */
ast apply_all_rewrites(const ast &p, const ast &q, const ast &r){
  if(q == r)
    return p;
  if(p == q)
    return r;
  const bool same_head = sym(p) == sym(q) && sym(q) == sym(r);
  if(!same_head)
    throw rewrites_failed();
  const int n = num_args(p);
  if(num_args(q) != n || num_args(r) != n)
    throw rewrites_failed();
  std::vector<ast> kids(n);
  for(int k = 0; k < n; ++k)
    kids[k] = apply_all_rewrites(arg(p,k),arg(q,k),arg(r,k));
  return clone(p,kids);
}
/* Conjunction of equalities x' = y' at the outermost places where x and y
   disagree in operator, uninterpreted symbol, or arity. Where they agree, the
   function recurses into the arguments; a matching leaf contributes true.
   NOTE(review): for non-Uninterpreted leaves only op() is compared, so two
   distinct leaves with the same op yield true; confirm this is intended. */
ast delta(const ast &x, const ast &y){
  const bool same_head = op(x) == op(y)
    && !(op(x) == Uninterpreted && sym(x) != sym(y))
    && num_args(x) == num_args(y);
  if(!same_head)
    return make(Equal,x,y);
  ast conj = mk_true();
  const int n = num_args(x);
  for(int k = 0; k < n; ++k)
    conj = my_and(conj,delta(arg(x,k),arg(y,k)));
  return conj;
}
/* Search p and q in parallel for the first (leftmost, outermost) pair of
   differing subterms that both lie in the vocabulary of t. On success the pair
   is stored in pd/qd and true is returned. Returns false if p == q, or if the
   structures diverge (symbol/arity) before such a pair is found. */
bool diff_rec(LitType t, const ast &p, const ast &q, ast &pd, ast &qd){
  if(p == q)
    return false;
  if(term_in_vocab(t,p) && term_in_vocab(t,q)){
    pd = p;
    qd = q;
    return true;
  }
  if(sym(p) != sym(q))
    return false;
  const int n = num_args(p);
  if(num_args(q) != n)
    return false;
  for(int k = 0; k < n; ++k)
    if(diff_rec(t,arg(p,k),arg(q,k),pd,qd))
      return true;
  return false;
}
/* Like diff_rec, but the absence of a difference is an error (cannot_simplify). */
void diff(LitType t, const ast &p, const ast &q, ast &pd, ast &qd){
  const bool found = diff_rec(t,p,q,pd,qd);
  if(!found)
    throw cannot_simplify();
}
/* Walk inp, p and q in parallel to the first pair of differing subterms of p
   and q that both lie in t's vocabulary. At that place inp must equal p; the
   replacement q is stored in out and true is returned. Any structural mismatch
   between inp, p and q throws cannot_simplify. Returns false if p == q. */
bool apply_diff_rec(LitType t, const ast &inp, const ast &p, const ast &q, ast &out){
  if(p == q)
    return false;
  if(term_in_vocab(t,p) && term_in_vocab(t,q)){
    if(inp != p)
      throw cannot_simplify();
    out = q;
    return true;
  }
  const int n = num_args(p);
  if(sym(p) != sym(q) || num_args(q) != n || sym(p) != sym(inp) || num_args(inp) != n)
    throw cannot_simplify();
  for(int k = 0; k < n; ++k)
    if(apply_diff_rec(t,arg(inp,k),arg(p,k),arg(q,k),out))
      return true;
  return false;
}
/* Return the replacement found by apply_diff_rec; throws cannot_simplify if none. */
ast apply_diff(LitType t, const ast &inp, const ast &p, const ast &q){
  ast result;
  if(apply_diff_rec(t,inp,p,q,result))
    return result;
  throw cannot_simplify();
}
/* Try to merge two adjacent rewrites A1 = (a1 ~ b1) and A2 = (a2 ~ b2) into a
   single relation of A1's operator, stored in merged. Returns true on success
   and false if no merge pattern applies. May throw cannot_simplify (from
   diff/apply_diff) when an expected difference is missing. */
bool merge_A_rewrites(const ast &A1, const ast &A2, ast &merged) {
// direct transitivity: b1 == a2 gives a1 ~ b2
if(arg(A1,1) == arg(A2,0)){
merged = make(op(A1),arg(A1,0),arg(A2,1));
return true;
}
// locate the A-differences inside each rewrite and the B-difference between them
// (diff1l and diff2r are computed but not used below)
ast diff1l, diff1r, diff2l, diff2r,diffBl,diffBr;
diff(LitA,arg(A1,0),arg(A1,1),diff1l,diff1r);
diff(LitA,arg(A2,0),arg(A2,1),diff2l,diff2r);
diff(LitB,arg(A1,1),arg(A2,0),diffBl,diffBr);
// push A2's change onto A1's right-hand side
if(!term_common(diff2l) && !term_common(diffBr)){
ast A1r = apply_diff(LitB,arg(A2,1),arg(A2,0),arg(A1,1));
merged = make(op(A1),arg(A1,0),A1r);
return true;
}
// otherwise pull A1's change back onto A2's left-hand side
if(!term_common(diff1r) && !term_common(diffBl)){
ast A2l = apply_diff(LitB,arg(A1,0),arg(A1,1),arg(A2,0));
merged = make(op(A1),A2l,arg(A2,1));
return true;
}
return false;
}
/* Flatten a rewrite sequence into res, in order. The sequence is either true
   (empty, contributes nothing), a single rewrite, or a left-nested
   concat(prefix, last) as built by concat_A_rewrites.
   The earlier elements live in arg 0 (the prefix) and the final element in
   arg 1. The previous code pushed the prefix node itself and recursed into it,
   so it never collected any arg(t,1). */
void collect_A_rewrites(const ast &t, std::vector<ast> &res){
  if(is_true(t))
    return;
  if(sym(t) == concat){
    collect_A_rewrites(arg(t,0),res); // earlier elements first
    res.push_back(arg(t,1));          // then the last one
    return;
  }
  res.push_back(t);
}
/* Build the left-nested sequence concat(...concat(rew[0],rew[1])...,rew[n-1]).
   An empty vector yields true; a single element is returned as-is. */
ast concat_A_rewrites(const std::vector<ast> &rew){
  if(rew.empty())
    return mk_true();
  ast acc = rew.front();
  for(std::vector<ast>::const_iterator it = rew.begin() + 1; it != rew.end(); ++it)
    acc = make(concat,acc,*it);
  return acc;
}
/* Concatenate the rewrite sequences A1 and A2. Adjacent rewrites around the
   junction are merged with merge_A_rewrites where possible.
   NOTE(review): correctness relies on collect_A_rewrites returning the
   elements of each sequence in order; confirm against its implementation. */
ast merge_concat_rewrites(const ast &A1, const ast &A2){
std::vector<ast> rew;
collect_A_rewrites(A1,rew);
int first = rew.size(), last = first; // range that might need merging
collect_A_rewrites(A2,rew);
// first indexes the right element of the pair (rew[first-1], rew[first]) being tried
while(first > 0 && first < (int)rew.size() && first <= last){
ast merged;
if(merge_A_rewrites(rew[first-1],rew[first],merged)){
// the merged rewrite replaces the pair; step left to try merging it again
rew[first] = merged;
first--;
rew.erase(rew.begin()+first);
last--;
if(first >= last) last = first+1;
}
else
first++;
}
return concat_A_rewrites(rew);
}
/* Split a conditional fact. If t is an implication c -> b, c is conjoined onto
   cond and b is returned. Otherwise t is returned and cond is left untouched. */
ast sep_cond(const ast &t, ast &cond){
  if(op(t) != Implies)
    return t;
  cond = my_and(cond,arg(t,0));
  return arg(t,1);
}
/* operations on term positions */
/** Compare two term positions p1 and p2.
    - p2 strictly below p1: returns 1, diff = path from p1 down to p2.
    - p1 strictly below p2: returns -1, diff = path from p2 down to p1.
    - equal: returns 0, diff = top_pos.
    - independent positions: returns 2, diff is left unchanged. */
int pos_diff(const ast &p1, const ast &p2, ast &diff){
  const bool top1 = (p1 == top_pos);
  const bool top2 = (p2 == top_pos);
  if(top1 && top2){
    diff = p1;
    return 0;
  }
  if(top1){
    diff = p2;
    return 1;
  }
  if(top2){
    diff = p1;
    return -1;
  }
  if(arg(p1,0) != arg(p2,0))
    return 2; // they diverge at this argument step
  // same argument at this level: compare the remaining sub-paths
  return pos_diff(arg(p1,1),arg(p2,1),diff);
}
/* Position obtained by first stepping into argument number argnum and then
   following pos inside that argument. */
ast pos_add(int argnum, const ast &pos){
  const ast index = make_int(rational(argnum));
  return make(add_pos,index,pos);
}
/* Extend position pos at its bottom with one more step into argument i. */
ast add_pos_to_end(const ast &pos, int i){
  if(pos != top_pos)
    return make(add_pos,arg(pos,0),add_pos_to_end(arg(pos,1),i));
  return pos_add(i,pos);
}
/* return the argument number of position, if not top */
int pos_arg(const ast &pos){
rational r;
if(is_numeral(arg(pos,0),r))
return r.get_unsigned();
throw "bad position!";
}
/* Replace the subterm of x at position pos by y. Throws "bad term position!"
   when the position's argument index is out of range for x. */
ast subst_in_pos(const ast &x, const ast &pos, const ast &y){
  if(pos == top_pos)
    return y;
  const int target = pos_arg(pos);
  const int n = num_args(x);
  if(target < 0 || target >= n)
    throw "bad term position!";
  std::vector<ast> kids(n);
  for(int k = 0; k < n; ++k)
    kids[k] = (k == target) ? subst_in_pos(arg(x,k),arg(pos,1),y) : arg(x,k);
  return clone(x,kids);
}
/* Extend prefix with side-t rewrites turning x into y. The terms are descended
   in parallel while head symbol and arity agree. At each outermost
   disagreement, an unconditional rewrite x' ~ y' is emitted at the
   corresponding position (relative to pos). */
ast diff_chain(LitType t, const ast &pos, const ast &x, const ast &y, const ast &prefix){
  if(x == y)
    return prefix;
  const int n = num_args(x);
  if(sym(x) != sym(y) || n != num_args(y))
    return chain_cons(prefix,make_rewrite(t,pos,mk_true(),make_equiv_rel(x,y)));
  ast acc = prefix;
  for(int k = 0; k < n; ++k)
    acc = diff_chain(t,pos_add(k,pos),arg(x,k),arg(y,k),acc);
  return acc;
}
/* operations on rewrites */
/* Build a rewrite node (pos, cond, equality) tagged rewrite_A when t is LitA
   and rewrite_B otherwise. Throws "bad rewrite" unless equality is an Equal/Iff. */
ast make_rewrite(LitType t, const ast &pos, const ast &cond, const ast &equality){
#if 0
  if(pos == top_pos && op(equality) == Iff && !is_true(arg(equality,0)))
    throw "bad rewrite";
#endif
  if(is_equivrel(equality)){
    symb side = (t == LitA) ? rewrite_A : rewrite_B;
    return make(side, pos, cond, equality);
  }
  throw "bad rewrite";
}
/* Position component (argument 0) of a rewrite node. */
ast rewrite_pos(const ast &rew){
  const ast where = arg(rew,0);
  return where;
}
/* Condition component (argument 1) of a rewrite node. */
ast rewrite_cond(const ast &rew){
  const ast guard = arg(rew,1);
  return guard;
}
/* Equivalence component (argument 2) of a rewrite node. */
ast rewrite_equ(const ast &rew){
  const ast equ = arg(rew,2);
  return equ;
}
/* Left-hand side of the equivalence carried by a rewrite node. */
ast rewrite_lhs(const ast &rew){
  ast equ = arg(rew,2);
  return arg(equ,0);
}
/* Right-hand side of the equivalence carried by a rewrite node. */
ast rewrite_rhs(const ast &rew){
  ast equ = arg(rew,2);
  return arg(equ,1);
}
/* operations on rewrite chains */
/* Append elem at the end of chain. Chains are left-nested concat nodes whose
   innermost prefix is true (the empty chain). */
ast chain_cons(const ast &chain, const ast &elem){
  ast extended = make(concat,chain,elem);
  return extended;
}
/* All but the last rewrite of a non-empty chain. */
ast chain_rest(const ast &chain){
  const ast prefix = arg(chain,0);
  return prefix;
}
/* The last rewrite of a non-empty chain. */
ast chain_last(const ast &chain){
  const ast final_step = arg(chain,1);
  return final_step;
}
/* Copy of rewrite rew with the subterm at pos (relative to its rhs) replaced
   by new_rhs and with new_cond conjoined onto its condition. Side symbol,
   position and lhs are kept. */
ast rewrite_update_rhs(const ast &rew, const ast &pos, const ast &new_rhs, const ast &new_cond){
  ast updated_rhs = subst_in_pos(rewrite_rhs(rew),pos,new_rhs);
  ast equ = arg(rew,2);
  ast updated_equ = make(op(equ),arg(equ,0),updated_rhs);
  return make(sym(rew),rewrite_pos(rew),my_and(rewrite_cond(rew),new_cond),updated_equ);
}
/* Copy of rewrite rew with the subterm at pos (relative to its lhs) replaced
   by new_lhs and with new_cond conjoined onto its condition. Side symbol,
   position and rhs are kept. */
ast rewrite_update_lhs(const ast &rew, const ast &pos, const ast &new_lhs, const ast &new_cond){
  ast updated_lhs = subst_in_pos(rewrite_lhs(rew),pos,new_lhs);
  ast equ = arg(rew,2);
  ast updated_equ = make(op(equ),updated_lhs,arg(equ,1));
  return make(sym(rew),rewrite_pos(rew),my_and(rewrite_cond(rew),new_cond),updated_equ);
}
bool is_common_rewrite(const ast &rew){
return term_common(arg(rew,2));
}
/* A right mover rewrites a common lhs into a non-common rhs. */
bool is_right_mover(const ast &rew){
  if(!term_common(rewrite_lhs(rew)))
    return false;
  return !term_common(rewrite_rhs(rew));
}
/* A left mover rewrites a non-common lhs into a common rhs. */
bool is_left_mover(const ast &rew){
  if(!term_common(rewrite_rhs(rew)))
    return false;
  return !term_common(rewrite_lhs(rew));
}
/* True when both rewrites carry the same side tag (rewrite_A / rewrite_B). */
bool same_side(const ast &rew1, const ast &rew2){
  const bool same = (sym(rew1) == sym(rew2));
  return same;
}
/* True when rew is tagged for side t (rewrite_A for LitA, rewrite_B otherwise). */
bool is_rewrite_side(LitType t, const ast &rew){
  symb wanted = (t == LitA) ? rewrite_A : rewrite_B;
  return sym(rew) == wanted;
}
/* Side of a rewrite: LitA for rewrite_A nodes, LitB for anything else. */
LitType rewrite_side(const ast &rew){
  if(sym(rew) == rewrite_A)
    return LitA;
  return LitB;
}
/* The fact a rewrite stands for: its condition implies its equivalence. */
ast rewrite_to_formula(const ast &rew){
  ast guard = arg(rew,1);
  ast equ = arg(rew,2);
  return my_implies(guard,equ);
}
// make rewrite rew conditional on rewrite cond: cond's formula is conjoined onto rew's condition
ast rewrite_conditional(const ast &cond, const ast &rew){
  ast cond_fact = rewrite_to_formula(cond);
  ast strengthened = my_and(arg(rew,1),cond_fact);
  return make(sym(rew),arg(rew,0),strengthened,arg(rew,2));
}
/* Same rewrite with the two sides of its equivalence swapped. */
ast reverse_rewrite(const ast &rew){
  ast equ = arg(rew,2);
  ast swapped = make(op(equ),arg(equ,1),arg(equ,0));
  return make(sym(rew),arg(rew,0),arg(rew,1),swapped);
}
/* Push rew one level down, into argument apos. */
ast rewrite_pos_add(int apos, const ast &rew){
  ast deeper = pos_add(apos,arg(rew,0));
  return make(sym(rew),deeper,arg(rew,1),arg(rew,2));
}
/* Same rewrite relocated to position pos. */
ast rewrite_pos_set(const ast &pos, const ast &rew){
  ast guard = arg(rew,1);
  ast equ = arg(rew,2);
  return make(sym(rew),pos,guard,equ);
}
/* Move rew one level up by dropping the first argument step of its position. */
ast rewrite_up(const ast &rew){
  ast inner = arg(arg(rew,0),1);
  return make(sym(rew),inner,arg(rew,1),arg(rew,2));
}
/** Adds a rewrite to a chain of rewrites, keeping the chain in
normal form. An empty chain is represented by true.
Same-side rewrites at nested positions are absorbed into one rewrite.
Independent same-side rewrites, and opposite-side rewrites, are reordered
(via chain_swap) according to the left/right-mover classification.
Common opposite-side rewrites may be flipped to the other side (flip_rewrite). */
ast add_rewrite_to_chain(const ast &chain, const ast &rewrite){
if(is_true(chain))
return chain_cons(chain,rewrite);
ast last = chain_last(chain);
ast rest = chain_rest(chain);
if(same_side(last,rewrite)){
ast p1 = rewrite_pos(last);
ast p2 = rewrite_pos(rewrite);
ast diff;
switch(pos_diff(p1,p2,diff)){
case 1: {
// rewrite acts strictly below last: fold it into last's rhs at the relative position
ast absorb = rewrite_update_rhs(last,diff,rewrite_rhs(rewrite),rewrite_cond(rewrite));
return add_rewrite_to_chain(rest,absorb);
}
case 0:
case -1: {
// last acts at or below rewrite: fold last's lhs into rewrite's lhs
ast absorb = rewrite_update_lhs(rewrite,diff,rewrite_lhs(last),rewrite_cond(last));
return add_rewrite_to_chain(rest,absorb);
}
default: {// independent case
// swap when exactly one of the two is a mover in the right direction
bool rm = is_right_mover(last);
bool lm = is_left_mover(rewrite);
if((lm && !rm) || (rm && !lm))
return chain_swap(rest,last,rewrite);
}
}
}
else {
// opposite sides
if(is_left_mover(rewrite)){
if(is_common_rewrite(last))
return add_rewrite_to_chain(chain_cons(rest,flip_rewrite(last)),rewrite);
if(!is_left_mover(last))
return chain_swap(rest,last,rewrite);
}
if(is_right_mover(last)){
if(is_common_rewrite(rewrite))
return add_rewrite_to_chain(chain,flip_rewrite(rewrite));
if(!is_right_mover(rewrite))
return chain_swap(rest,last,rewrite);
}
}
// no normalization applies: simply append
return chain_cons(chain,rewrite);
}
/* Move rewrite in front of last: insert rewrite into rest (renormalizing)
   and then append last after it. */
ast chain_swap(const ast &rest, const ast &last, const ast &rewrite){
  ast reordered = add_rewrite_to_chain(rest,rewrite);
  return chain_cons(reordered,last);
}
/* Re-tag rew for the opposite side. Its new condition is
   cond -> (cond -> equ), built from its own formula. */
ast flip_rewrite(const ast &rew){
  const bool was_A = (sym(rew) == rewrite_A);
  ast as_formula = rewrite_to_formula(rew);
  ast new_cond = my_implies(arg(rew,1),as_formula);
  return make(was_A ? rewrite_B : rewrite_A,arg(rew,0),new_cond,arg(rew,2));
}
/** concatenates two rewrite chains, keeping result in normal form.
    If either chain is empty (true) the other is returned unchanged; otherwise
    chain2's rewrites are added to chain1 one by one, in chain order. */
ast concat_rewrite_chain(const ast &chain1, const ast &chain2){
  if(is_true(chain2))
    return chain1;
  if(is_true(chain1))
    return chain2;
  std::vector<ast> pending; // chain2's rewrites, last to first
  for(ast c = chain2; !is_true(c); c = chain_rest(c))
    pending.push_back(chain_last(c));
  ast acc = chain1;
  for(std::size_t i = pending.size(); i-- > 0; )
    acc = add_rewrite_to_chain(acc,pending[i]);
  return acc;
}
/** reverse a chain of rewrites: each rewrite of chain, from last to first, is
    reversed and appended to prefix. */
ast reverse_chain_rec(const ast &chain, const ast &prefix){
  ast acc = prefix;
  for(ast c = chain; !is_true(c); c = chain_rest(c))
    acc = chain_cons(acc,reverse_rewrite(chain_last(c)));
  return acc;
}
/* Reverse chain: reversed rewrites in reverse order, starting from the empty chain. */
ast reverse_chain(const ast &chain){
  const ast empty = mk_true();
  return reverse_chain_rec(chain,empty);
}
bool is_equivrel_chain(const ast &chain){
if(is_true(chain))
return true;
ast last = chain_last(chain);
ast rest = chain_rest(chain);
if(is_true(rest))
return !is_true(rewrite_lhs(last));
return is_equivrel_chain(rest);
}
bool is_negation_chain(const ast &chain){
if(is_true(chain))
return false;
ast last = chain_last(chain);
ast rest = chain_rest(chain);
if(is_true(rest))
return op(rewrite_rhs(last)) == Not;
return is_negation_chain(rest);
}
/* Given a chain whose first rewrite produces a negated equivalence not(a ~ b),
   return the corresponding chain over not(b ~ a).
   - The first rewrite's rhs is replaced by not(b ~ a).
   - Each later rewrite must lie under argument 0 (inside the Not).
   - A later rewrite on the equality itself is rewritten between the swapped
     equalities (Equal only).
   - A later rewrite deeper down has its equality argument index swapped (0 <-> 1).
   Any other shape throws "bad negative equality chain". */
ast commute_negation_chain(const ast &chain){
if(is_true(chain))
return chain;
ast last = chain_last(chain);
ast rest = chain_rest(chain);
if(is_true(rest)){
// first rewrite: its rhs must be not(equ); replace it by not(equ with sides swapped)
ast old = rewrite_rhs(last);
if(!(op(old) == Not))
throw "bad negative equality chain";
ast equ = arg(old,0);
if(!is_equivrel(equ))
throw "bad negative equality chain";
last = rewrite_update_rhs(last,top_pos,make(Not,make(op(equ),arg(equ,1),arg(equ,0))),make(True));
return chain_cons(rest,last);
}
// later rewrites must act inside the Not (argument 0)
ast pos = rewrite_pos(last);
if(pos == top_pos)
throw "bad negative equality chain";
int idx = pos_arg(pos);
if(idx != 0)
throw "bad negative equality chain";
pos = arg(pos,1);
if(pos == top_pos){
// rewrite of the equality itself: swap both Equal terms
ast lhs = rewrite_lhs(last);
ast rhs = rewrite_rhs(last);
if(op(lhs) != Equal || op(rhs) != Equal)
throw "bad negative equality chain";
last = make_rewrite(rewrite_side(last),rewrite_pos(last),rewrite_cond(last),
make(Iff,make(Equal,arg(lhs,1),arg(lhs,0)),make(Equal,arg(rhs,1),arg(rhs,0))));
}
else {
// rewrite inside one side of the equality: redirect it to the other argument index
idx = pos_arg(pos);
if(idx == 0)
idx = 1;
else if(idx == 1)
idx = 0;
else
throw "bad negative equality chain";
pos = pos_add(0,pos_add(idx,arg(pos,1)));
last = make_rewrite(rewrite_side(last),pos,rewrite_cond(last),rewrite_equ(last));
}
return chain_cons(commute_negation_chain(rest),last);
}
// split a rewrite chain into head and tail at last top-level rewrite
/* The split point is the last rewrite at top_pos, or, when is_not is set, one
   whose position has a single argument step. The head keeps the chain up to
   and including that rewrite; tail receives the following rewrites, in order.
   Throws "bad rewrite chain" if no such rewrite exists. The chain must be
   non-empty. */
ast get_head_chain(const ast &chain, ast &tail, bool is_not = true){
ast last = chain_last(chain);
ast rest = chain_rest(chain);
ast pos = rewrite_pos(last);
if(pos == top_pos || (is_not && arg(pos,1) == top_pos)){
tail = mk_true();
return chain;
}
if(is_true(rest))
throw "bad rewrite chain";
ast head = get_head_chain(rest,tail,is_not);
tail = chain_cons(tail,last);
return head;
}
bool has_mixed_summands(const ast &e){
if(op(e) == Plus){
int nargs = num_args(e);
for(int i = 0; i < nargs; i++)
if(has_mixed_summands(arg(e,i)))
return true;
return false;
}
return get_term_type(e) == LitMixed;
}
// split a rewrite chain into head and tail at last sum with no mixed summands
/* Working backward from the end of chain, rewrites are undone on rhs (each
   rewrite's lhs is substituted back at its position). This stops at the first
   term without mixed summands, or at the start of the chain.
   Returns the remaining prefix of the chain. tail receives the undone suffix
   (in chain order) and mid the term at the split point. */
ast get_right_movers(const ast &chain, const ast &rhs, ast &tail, ast &mid){
if(is_true(chain) || !has_mixed_summands(rhs)){
mid = rhs;
tail = mk_true();
return chain;
}
ast last = chain_last(chain);
ast rest = chain_rest(chain);
ast mm = subst_in_pos(rhs,rewrite_pos(last),rewrite_lhs(last));
ast res = get_right_movers(rest,mm,tail,mid);
tail = chain_cons(tail,last);
return res;
}
// split a rewrite chain into head and tail at first sum with no mixed summands
/* Replaying chain forward from lhs, find the first point at which the current
   term is not mixed. Returns the prefix of the chain up to that point; tail
   receives the remaining rewrites (in chain order) and mid the term at the
   split. Returns a null ast if no such point exists.
   NOTE(review): the starting term is tested with has_mixed_summands, while
   intermediate terms are tested with get_term_type(mid) != LitMixed; confirm
   the asymmetry is intended. */
ast get_left_movers(const ast &chain, const ast &lhs, ast &tail, ast &mid){
if(is_true(chain)){
mid = lhs;
if(!has_mixed_summands(lhs)){
tail = mk_true();
return chain;
}
return ast();
}
ast last = chain_last(chain);
ast rest = chain_rest(chain);
ast res = get_left_movers(rest,lhs,tail,mid);
if(res.null()){
// no split point yet: apply this rewrite and test the result
mid = subst_in_pos(mid,rewrite_pos(last),rewrite_rhs(last));
if(get_term_type(mid) != LitMixed){
tail = mk_true();
return chain;
}
return ast();
}
tail = chain_cons(tail,last);
return res;
}
/* Thrown by split_chain_rec when a chain cannot be split into argument-0 and argument-1 parts. */
struct cannot_split {};
/** Split a chain of rewrites two chains, operating on positions 0 and 1.
Fail if any rewrite in the chain operates on top position. */
void split_chain_rec(const ast &chain, ast *res){
if(is_true(chain))
return;
ast last = chain_last(chain);
ast rest = chain_rest(chain);
split_chain_rec(rest,res);
ast pos = rewrite_pos(last);
if(pos == top_pos){
if(rewrite_lhs(last) == rewrite_rhs(last))
return; // skip if it's a noop
throw cannot_split();
}
int arg = pos_arg(pos);
if(arg<0 || arg > 1)
throw cannot_split();
res[arg] = chain_cons(res[arg],rewrite_up(last));
}
/* Split chain into res[0] and res[1] (both start empty); see split_chain_rec. */
void split_chain(const ast &chain, ast *res){
  res[0] = mk_true();
  res[1] = mk_true();
  split_chain_rec(chain,res);
}
/* Keep the rewrites of chain that act strictly below position pos,
   re-positioned relative to pos. Rewrites above pos must be no-ops, else
   "bad rewrite chain" is thrown. Rewrites at independent positions are dropped.
   NOTE(review): a rewrite exactly at pos (pos_diff == 0) is dropped too, while
   add_rewrite_to_chain groups case 0 with case -1; confirm this is intended. */
ast extract_rewrites(const ast &chain, const ast &pos){
if(is_true(chain))
return chain;
ast last = chain_last(chain);
ast rest = chain_rest(chain);
ast new_rest = extract_rewrites(rest,pos);
ast p1 = rewrite_pos(last);
ast diff;
switch(pos_diff(p1,pos,diff)){
case -1: {
// last acts below pos: keep it, with its position relative to pos
ast new_last = rewrite_pos_set(diff, last);
return chain_cons(new_rest,new_last);
}
case 1:
// last acts above pos: only allowed if it changes nothing
if(rewrite_lhs(last) != rewrite_rhs(last))
throw "bad rewrite chain";
break;
default:;
}
return new_rest;
}
/* The argument-0 half of chain as produced by split_chain (throws cannot_split
   in the same cases). */
ast down_chain(const ast &chain){
  ast halves[2];
  split_chain(chain,halves);
  return halves[0];
}
/* Conjunction, in chain order, of the conditions of the side-t rewrites in chain. */
ast chain_conditions(LitType t, const ast &chain){
  std::vector<ast> steps; // rewrites, last to first
  for(ast c = chain; !is_true(c); c = chain_rest(c))
    steps.push_back(chain_last(c));
  ast cond = mk_true();
  for(std::size_t i = steps.size(); i-- > 0; )
    if(is_rewrite_side(t,steps[i]))
      cond = my_and(cond,rewrite_cond(steps[i]));
  return cond;
}
/* Conjunction, in chain order, of the equivalences of the side-t rewrites in chain. */
ast chain_formulas(LitType t, const ast &chain){
  std::vector<ast> steps; // rewrites, last to first
  for(ast c = chain; !is_true(c); c = chain_rest(c))
    steps.push_back(chain_last(c));
  ast conj = mk_true();
  for(std::size_t i = steps.size(); i-- > 0; )
    if(is_rewrite_side(t,steps[i]))
      conj = my_and(conj,rewrite_equ(steps[i]));
  return conj;
}
bool rewrites_from_to(const ast &chain, const ast &lhs, const ast &rhs){
if(is_true(chain))
return lhs == rhs;
ast last = chain_last(chain);
ast rest = chain_rest(chain);
ast mid = subst_in_pos(rhs,rewrite_pos(last),rewrite_lhs(last));
return rewrites_from_to(rest,lhs,mid);
}
/* Thrown by chain_ineqs when the rewrite chain does not lead from lhs to rhs. */
struct bad_ineq_inference {};
/* Sum the inequalities justified by the side-t rewrites of chain. The chain is
   replayed backward from rhs toward lhs. For each side-t rewrite with
   before-term mid and after-term rhs', the inequality 0 <= rhs' - mid (if
   comp_op == Leq) or 0 <= mid - rhs' (otherwise) is added.
   The result starts from 0 <= 0. bad_ineq_inference is thrown if undoing the
   chain on rhs does not give lhs. */
ast chain_ineqs(opr comp_op, LitType t, const ast &chain, const ast &lhs, const ast &rhs){
if(is_true(chain)){
if(lhs != rhs)
throw bad_ineq_inference();
return make(Leq,make_int(rational(0)),make_int(rational(0)));
}
ast last = chain_last(chain);
ast rest = chain_rest(chain);
ast mid = subst_in_pos(rhs,rewrite_pos(last),rewrite_lhs(last));
ast cond = chain_ineqs(comp_op,t,rest,lhs,mid);
if(is_rewrite_side(t,last)){
ast diff;
if(comp_op == Leq) diff = make(Sub,rhs,mid);
else diff = make(Sub,mid,rhs);
ast foo = make(Leq,make_int("0"),z3_simplify(diff));
// NOTE(review): the base case returns 0 <= 0 rather than true, so this branch looks unreachable unless simplification yields true
if(is_true(cond))
cond = foo;
else {
linear_comb(cond,make_int(rational(1)),foo);
cond = simplify_ineq(cond);
}
}
return cond;
}
/* Add ineq (coefficient 1) to the trivial inequality 0 <= 0 and simplify the result. */
ast ineq_to_lhs(const ast &ineq){
  const ast zero = make_int(rational(0));
  ast acc = make(Leq,zero,zero);
  linear_comb(acc,make_int(rational(1)),ineq);
  return simplify_ineq(acc);
}
/* Extract two sides from an inequality.
   - If ineq is (a + (c * b)) <op> 0 with a Times second summand, lhs = a and
     rhs = b. The Times's first argument, presumably the coefficient, is dropped.
   - Otherwise, for Leq/Geq, lhs and rhs are the inequality's two arguments.
   - Any other shape throws "bad ineq". */
void eq_from_ineq(const ast &ineq, ast &lhs, ast &rhs){
// ast s = ineq_to_lhs(ineq);
// ast srhs = arg(s,1);
// note: despite its name, srhs is the first argument of ineq
ast srhs = arg(ineq,0);
if(op(srhs) == Plus && num_args(srhs) == 2 && arg(ineq,1) == make_int(rational(0))){
lhs = arg(srhs,0);
rhs = arg(srhs,1);
// if(op(lhs) == Times)
// std::swap(lhs,rhs);
if(op(rhs) == Times){
rhs = arg(rhs,1);
// if(op(ineq) == Leq)
// std::swap(lhs,rhs);
return;
}
}
// fallback: a plain Leq/Geq; overwrites any values assigned above
if(op(ineq) == Leq || op(ineq) == Geq){
lhs = srhs;
rhs = arg(ineq,1);
return;
}
throw "bad ineq";
}
/* Push every rewrite of chain one level down, into argument argnum. */
ast chain_pos_add(int argnum, const ast &chain){
  if(is_true(chain))
    return mk_true();
  ast shifted_last = rewrite_pos_add(argnum,chain_last(chain));
  ast shifted_rest = chain_pos_add(argnum,chain_rest(chain));
  return chain_cons(shifted_rest,shifted_last);
}
/* Apply the rewrites of chain to t in chain order. Each step places the
   rewrite's rhs at its position. */
ast apply_rewrite_chain(const ast &t, const ast &chain){
  std::vector<ast> steps; // rewrites, last to first
  for(ast c = chain; !is_true(c); c = chain_rest(c))
    steps.push_back(chain_last(c));
  ast cur = t;
  for(std::size_t i = steps.size(); i-- > 0; )
    cur = subst_in_pos(cur,rewrite_pos(steps[i]),rewrite_rhs(steps[i]));
  return cur;
}
/* Strip the maximal run of side-t rewrites at the end of chain. The stripped
   run (in chain order) goes into remainder; the remaining prefix is returned. */
ast drop_rewrites(LitType t, const ast &chain, ast &remainder){
  std::vector<ast> dropped; // trailing side-t rewrites, last to first
  ast head = chain;
  while(!is_true(head) && is_rewrite_side(t,chain_last(head))){
    dropped.push_back(chain_last(head));
    head = chain_rest(head);
  }
  remainder = mk_true();
  for(std::size_t i = dropped.size(); i-- > 0; )
    remainder = chain_cons(remainder,dropped[i]);
  return head;
}
// Normalization chains
/* Prepend the normalization step first to the chain rest. */
ast cons_normal(const ast &first, const ast &rest){
  ast node_ = make(normal_chain,first,rest);
  return node_;
}
/* First step of a non-empty normalization chain. */
ast normal_first(const ast &t){
  const ast head = arg(t,0);
  return head;
}
/* Remaining steps of a non-empty normalization chain. */
ast normal_rest(const ast &t){
  const ast tail = arg(t,1);
  return tail;
}
/* Left-hand side of a normalization step's equivalence. */
ast normal_lhs(const ast &t){
  ast equ = arg(t,0);
  return arg(equ,0);
}
/* Right-hand side of a normalization step's equivalence. */
ast normal_rhs(const ast &t){
  ast equ = arg(t,0);
  return arg(equ,1);
}
/* Proof (rewrite chain) justifying a normalization step. */
ast normal_proof(const ast &t){
  const ast proof = arg(t,1);
  return proof;
}
/* A normalization step lhs -> rhs justified by the rewrite chain proof. */
ast make_normal_step(const ast &lhs, const ast &rhs, const ast &proof){
  ast equ = make_equiv(lhs,rhs);
  return make(normal_step,equ,proof);
}
/* Attach the normalization chain nrml to ineq. Throws "what?" unless ineq is an inequality. */
ast make_normal(const ast &ineq, const ast &nrml){
  if(is_ineq(ineq))
    return make(normal,ineq,nrml);
  throw "what?";
}
/* Orient an equation lhs = rhs (proved by proof) as a normalization step. A
   mixed term is rewritten into a non-mixed one. When both sides are mixed, the
   term with the smaller ast_id is rewritten. If lhs becomes the step's rhs, the
   proof is reversed. Throws "help!" when neither side is mixed. */
ast fix_normal(const ast &lhs, const ast &rhs, const ast &proof){
  const LitType lt = get_term_type(lhs);
  const LitType rt = get_term_type(rhs);
  const bool rewrite_lhs_side = lt == LitMixed && (rt != LitMixed || ast_id(lhs) < ast_id(rhs));
  if(rewrite_lhs_side)
    return make_normal_step(lhs,rhs,proof);
  const bool rewrite_rhs_side = rt == LitMixed && (lt != LitMixed || ast_id(rhs) < ast_id(lhs));
  if(rewrite_rhs_side)
    return make_normal_step(rhs,lhs,reverse_chain(proof));
  throw "help!";
}
/* What the side-`side` rewrites of chain establish: the other side's conditions,
   together with (side's conditions -> side's equivalences). */
ast chain_side_proves(LitType side, const ast &chain){
  const LitType other = (side == LitA) ? LitB : LitA;
  ast others_conds = chain_conditions(other,chain);
  ast claim = my_implies(chain_conditions(side,chain),chain_formulas(side,chain));
  return my_and(others_conds,claim);
}
// Merge two normalization chains
/* Merge chain1 and chain2, interleaving by ast_id of each step's lhs (this
   presumably assumes both are sorted by increasing lhs id). For steps with the
   same lhs, one representative is chosen from the partition types of the two
   rhs terms:
   - mixed rhs: the mixed rhs is redirected to the other rhs via trans;
   - A rhs vs B rhs: the B rewrites are split off and the A proof is extended;
   - same-side rhs: either step is kept.
   Facts the other partition must supply are conjoined onto Aproves/Bproves. */
ast merge_normal_chains_rec(const ast &chain1, const ast &chain2, hash_map<ast,ast> &trans, ast &Aproves, ast &Bproves){
if(is_true(chain1))
return chain2;
if(is_true(chain2))
return chain1;
ast f1 = normal_first(chain1);
ast f2 = normal_first(chain2);
ast lhs1 = normal_lhs(f1);
ast lhs2 = normal_lhs(f2);
int id1 = ast_id(lhs1);
int id2 = ast_id(lhs2);
// different lhs: emit the smaller one first
if(id1 < id2)
return cons_normal(f1,merge_normal_chains_rec(normal_rest(chain1),chain2,trans,Aproves,Bproves));
if(id2 < id1)
return cons_normal(f2,merge_normal_chains_rec(chain1,normal_rest(chain2),trans,Aproves,Bproves));
// same lhs: reconcile the two right-hand sides
ast rhs1 = normal_rhs(f1);
ast rhs2 = normal_rhs(f2);
LitType t1 = get_term_type(rhs1);
LitType t2 = get_term_type(rhs2);
int tid1 = ast_id(rhs1);
int tid2 = ast_id(rhs2);
ast pf1 = normal_proof(f1);
ast pf2 = normal_proof(f2);
ast new_normal;
if(t1 == LitMixed && (t2 != LitMixed || tid2 > tid1)){
// keep f2 and record that rhs1 itself normalizes to rhs2
ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2);
new_normal = f2;
trans[rhs1] = make_normal_step(rhs1,rhs2,new_proof);
}
else if(t2 == LitMixed && (t1 != LitMixed || tid1 > tid2))
return merge_normal_chains_rec(chain2,chain1,trans,Aproves,Bproves);
else if(t1 == LitA && t2 == LitB){
// rhs1 -> rhs2 proof: keep its A prefix, hand the trailing B rewrites to side conditions
ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2);
ast Bproof, Aproof = drop_rewrites(LitB,new_proof,Bproof);
ast mcA = chain_side_proves(LitB,Aproof);
Bproves = my_and(Bproves,mcA);
ast mcB = chain_side_proves(LitA,Bproof);
Aproves = my_and(Aproves,mcB);
ast rep = apply_rewrite_chain(rhs1,Aproof);
new_proof = concat_rewrite_chain(pf1,Aproof);
new_normal = make_normal_step(lhs1,rep,new_proof);
ast A_normal = make_normal_step(rhs1,rep,Aproof);
ast res = cons_normal(new_normal,merge_normal_chains_rec(normal_rest(chain1),normal_rest(chain2),trans,Aproves,Bproves));
res = merge_normal_chains_rec(res,cons_normal(A_normal,make(True)),trans,Aproves,Bproves);
return res;
}
else if(t1 == LitB && t2 == LitA)
return merge_normal_chains_rec(chain2,chain1,trans,Aproves,Bproves);
else if(t1 == LitA) {
// both A: B must accept the rhs1 -> rhs2 proof
ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2);
ast mc = chain_side_proves(LitB,new_proof);
Bproves = my_and(Bproves,mc);
new_normal = f1; // choice is arbitrary
}
else { /* t1 = t2 = LitB */
ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2);
ast mc = chain_side_proves(LitA,new_proof);
Aproves = my_and(Aproves,mc);
new_normal = f1; // choice is arbitrary
}
return cons_normal(new_normal,merge_normal_chains_rec(normal_rest(chain1),normal_rest(chain2),trans,Aproves,Bproves));
}
/* Apply the redirections in trans to a normalization chain, processing it from
   its end. A step with a mixed lhs whose rhs has an entry in trans is composed
   with that entry (lhs -> entry's rhs, with the two proofs concatenated). Every
   mixed lhs is then recorded in trans with its (possibly updated) step, so
   steps earlier in the chain can chain through it. */
ast trans_normal_chain(const ast &chain, hash_map<ast,ast> &trans){
if(is_true(chain))
return chain;
ast f = normal_first(chain);
ast r = normal_rest(chain);
r = trans_normal_chain(r,trans);
ast rhs = normal_rhs(f);
hash_map<ast,ast>::iterator it = trans.find(rhs);
ast new_normal;
if(it != trans.end() && get_term_type(normal_lhs(f)) == LitMixed){
const ast &f2 = it->second;
ast pf = concat_rewrite_chain(normal_proof(f),normal_proof(f2));
new_normal = make_normal_step(normal_lhs(f),normal_rhs(f2),pf);
}
else
new_normal = f;
if(get_term_type(normal_lhs(f)) == LitMixed)
trans[normal_lhs(f)] = new_normal;
return cons_normal(new_normal,r);
}
/* Merge two normalization chains, then apply the mixed-term redirections
   collected during the merge. */
ast merge_normal_chains(const ast &chain1, const ast &chain2, ast &Aproves, ast &Bproves){
  hash_map<ast,ast> trans;
  ast merged = merge_normal_chains_rec(chain1,chain2,trans,Aproves,Bproves);
  return trans_normal_chain(merged,trans);
}
/* Take apart t of the shape [Aproves /\] ([Bproves ->] ineq). Missing parts
   default to true. Returns true and sets ineq only if the core is a normal
   inequality. */
bool destruct_cond_ineq(ast t, ast &Aproves, ast &Bproves, ast&ineq){
  const bool has_A = (op(t) == And);
  Aproves = has_A ? arg(t,0) : mk_true();
  if(has_A)
    t = arg(t,1);
  const bool has_B = (op(t) == Implies);
  Bproves = has_B ? arg(t,0) : mk_true();
  if(has_B)
    t = arg(t,1);
  if(!is_normal_ineq(t))
    return false;
  ineq = t;
  return true;
}
/* Build Aproves /\ (Bproves -> ineq), the shape destruct_cond_ineq takes apart. */
ast cons_cond_ineq(const ast &Aproves, const ast &Bproves, const ast &ineq){
  ast guarded = my_implies(Bproves,ineq);
  return my_and(Aproves,guarded);
}
/* If ct is a conditional inequality whose core is a `normal` node, substitute
   each step's lhs by its rhs in the underlying inequality. The result is
   rebuilt with the same conditions. Otherwise ct is returned unchanged. */
ast normalize(const ast &ct){
  ast Aproves, Bproves, t;
  if(!destruct_cond_ineq(ct,Aproves,Bproves,t) || sym(t) != normal)
    return ct;
  hash_map<ast,ast> lhs_to_rhs;
  for(ast c = arg(t,1); !is_true(c); c = normal_rest(c)){
    ast step = normal_first(c);
    ast from = normal_lhs(step);
    ast to = normal_rhs(step);
    lhs_to_rhs[from] = to;
  }
  ast res = subst(lhs_to_rhs,arg(t,0));
  return cons_cond_ineq(Aproves,Bproves,res);
}
/** Make an assumption node. The given clause is assumed in the given frame.
    When weak is false, a frame inside rng yields the disjunction of the
    clause's non-A literals, and any other frame yields true. When weak is true,
    a frame inside rng yields false, and any other frame yields the negated
    disjunction of the clause's non-B literals. */
virtual node make_assumption(int frame, const std::vector<ast> &assumption){
  const bool in_rng = pv->in_range(frame,rng);
  if(!weak && !in_rng)
    return mk_true();
  if(weak && in_rng)
    return mk_false();
  const LitType drop = weak ? LitB : LitA; // literals local to this side are omitted
  std::vector<ast> itp_clause;
  for(unsigned i = 0; i < assumption.size(); i++)
    if(get_term_type(assumption[i]) != drop)
      itp_clause.push_back(assumption[i]);
  ast res = my_or(itp_clause);
  return weak ? mk_not(res) : res;
}
/* A one-element chain holding an unconditional top-level side-t rewrite for p.
   A p that is not an Equal/Iff is encoded as (true <-> p). */
ast make_local_rewrite(LitType t, const ast &p){
  ast rew;
  if(is_equivrel(p))
    rew = p;
  else
    rew = make(Iff,mk_true(),p);
#if 0
  if(op(rew) == Iff && !is_true(arg(rew,0)))
    return diff_chain(t,top_pos,arg(rew,0),arg(rew,1), mk_true());
#endif
  ast single = make_rewrite(t, top_pos, mk_true(), rew);
  return chain_cons(mk_true(),single);
}
/* Build a contra node for rule applied to premises.
   - A- or B-local premises become one-element local rewrite chains.
   - Mixed premises become placeholders and are collected in conjs.
   Bit i of mask_in refers to premise i; it is remapped to the premise's index
   among the collected mixed premises before being passed on as the mask. */
ast triv_interp(const symb &rule, const std::vector<ast> &premises, int mask_in){
std::vector<ast> ps; ps.resize(premises.size());
std::vector<ast> conjs;
int mask = 0;
for(unsigned i = 0; i < ps.size(); i++){
ast p = premises[i];
LitType t = get_term_type(p);
switch(t){
case LitA:
case LitB:
ps[i] = make_local_rewrite(t,p);
break;
default:
ps[i] = get_placeholder(p); // can only prove consequent!
if(mask_in & (1 << i))
mask |= (1 << conjs.size());
conjs.push_back(p);
}
}
ast ref = make(rule,ps);
ast res = make_contra_node(ref,conjs,mask);
return res;
}
/* Three-premise convenience form of triv_interp. */
ast triv_interp(const symb &rule, const ast &p0, const ast &p1, const ast &p2, int mask){
  std::vector<ast> ps;
  ps.reserve(3);
  ps.push_back(p0);
  ps.push_back(p1);
  ps.push_back(p2);
  return triv_interp(rule,ps,mask);
}
/** Make a modus-ponens node. This takes derivations of |- x
and |- x = y and produces |- y */
/* p_eq_q is assumed to have the form p = q (or p <-> q); arg 0 is p and arg 1 is q.
   Throws proof_error for the unsupported local combinations below. */
virtual node make_mp(const ast &p_eq_q, const ast &prem1, const ast &prem2){
/* Interpolate the axiom p, p=q -> q */
ast p = arg(p_eq_q,0);
ast q = arg(p_eq_q,1);
ast itp;
if(get_term_type(p_eq_q) == LitMixed){
// mixed equivalence: build a modpon contra node; premise 2 (~q) is always flagged in the mask
int mask = 1 << 2;
if(op(p) == Not && is_equivrel(arg(p,0)))
mask |= 1; // we may need to run this rule backward if first premise is negative equality
itp = triv_interp(modpon,p,p_eq_q,mk_not(q),mask);
}
else {
// local equivalence: the interpolant is chosen from the partition types of p and q
if(get_term_type(p) == LitA){
if(get_term_type(q) == LitA)
itp = mk_false();
else {
if(get_term_type(p_eq_q) == LitA)
itp = q;
else
throw proof_error();
}
}
else {
if(get_term_type(q) == LitA){
if(get_term_type(make(Equal,p,q)) == LitA)
itp = mk_not(p);
else
throw proof_error();
}
else
itp = mk_true();
}
}
/* Resolve it with the premises */
std::vector<ast> conc; conc.push_back(q); conc.push_back(mk_not(p_eq_q));
itp = make_resolution(p,conc,itp,prem1);
conc.pop_back();
itp = make_resolution(p_eq_q,conc,itp,prem2);
return itp;
}
/* Optionally quantify the localization variables occurring in e. Variables
   from frames in rng are bound existentially, others universally. The loop is
   compiled only when CAPTURE_LOCALIZATION is defined; its #define below is
   commented out, so unless it is defined elsewhere e is returned unchanged. */
ast capture_localization(ast e){
// #define CAPTURE_LOCALIZATION
#ifdef CAPTURE_LOCALIZATION
for(int i = localization_vars.size() - 1; i >= 0; i--){
LocVar &lv = localization_vars[i];
if(occurs_in(lv.var,e)){
symb q = (pv->in_range(lv.frame,rng)) ? sexists : sforall;
e = make(q,make(Equal,lv.var,lv.term),e); // use Equal because it is polymorphic
}
}
#endif
return e;
}
/** Make an axiom node. The conclusion must be an instance of an axiom. */
/* Steps:
   1. Each literal is localized with localize_term, and frng is narrowed to the
      scope of the localized literals.
   2. The assumption interpolant is built at frame range_max(frng).
   3. Each literal that changed is mapped back to the original via make_mp,
      using its localization proof.
   4. The result passes through capture_localization. */
virtual node make_axiom(const std::vector<ast> &conclusion, prover::range frng){
int nargs = conclusion.size();
std::vector<ast> largs(nargs);
std::vector<ast> eqs;
std::vector<ast> pfs;
for(int i = 0; i < nargs; i++){
ast argpf;
ast lit = conclusion[i];
largs[i] = localize_term(lit,frng,argpf);
frng = pv->range_glb(frng,pv->ast_scope(largs[i]));
if(largs[i] != lit){
// remember localized <-> original and its proof for the transfer below
eqs.push_back(make_equiv(largs[i],lit));
pfs.push_back(argpf);
}
}
int frame = pv->range_max(frng);
ast itp = make_assumption(frame,largs);
for(unsigned i = 0; i < eqs.size(); i++)
itp = make_mp(eqs[i],itp,pfs[i]);
return capture_localization(itp);
}
/* Axiom node over the full frame range. */
virtual node make_axiom(const std::vector<ast> &conclusion){
  prover::range whole = pv->range_full();
  return make_axiom(conclusion,whole);
}
/** Make a Contra node. This rule takes a derivation of the form
    Gamma |- False and produces |- \/~Gamma.
    The premise's interpolant is returned unchanged; conclusion is not used. */
virtual node make_contra(node prem, const std::vector<ast> &conclusion){
  (void)conclusion;
  return prem;
}
/** Make hypothesis. Creates a node of the form P |- P. */
/* Negations are stripped first. A-local P gives false and B-local P gives true.
   For a mixed P, an And of three contra nodes is built:
   - a contradiction from P and ~P: a Farkas sum (0<=0) + P + ~P for
     Geq/Leq/Gt/Lt, or excluded middle otherwise;
   - a contra node proving P from the placeholder of ~P;
   - a contra node proving ~P from the placeholder of P. */
virtual node make_hypothesis(const ast &P){
if(is_not(P))
return make_hypothesis(arg(P,0));
switch(get_term_type(P)){
case LitA:
return mk_false();
case LitB:
return mk_true();
default: // mixed hypothesis
switch(op(P)){
case Geq:
case Leq:
case Gt:
case Lt: {
ast zleqz = make(Leq,make_int("0"),make_int("0"));
ast fark1 = make(sum,zleqz,make_int("1"),get_placeholder(P));
ast fark2 = make(sum,fark1,make_int("1"),get_placeholder(mk_not(P)));
ast res = make(And,make(contra,fark2,mk_false()),
make(contra,get_placeholder(mk_not(P)),P),
make(contra,get_placeholder(P),mk_not(P)));
return res;
}
default: {
ast em = make(exmid,P,get_placeholder(P),get_placeholder(mk_not(P)));
ast res = make(And,make(contra,em,mk_false()),
make(contra,get_placeholder(mk_not(P)),P),
make(contra,get_placeholder(P),mk_not(P)));
return res;
}
}
}
}
/** Make a Reflexivity node. This rule produces |- x = x.
    An A-local con yields false and a B-local con yields true. A mixed con
    yields an And of a trivially closed contra node and a contra node for
    ~con. */
virtual node make_reflexivity(ast con){
  switch(get_term_type(con)){
  case LitA:
    return mk_false();
  case LitB:
    return mk_true();
  default:
    break;
  }
  ast itp = make(And,make(contra,no_proof,mk_false()),
                 make(contra,mk_true(),mk_not(con)));
  return itp;
}
/** Make a Symmetry node. This takes a derivation of |- x = y and
produces |- y = x. Ditto for ~(x=y) */
/* Non-mixed con: the premise's interpolant is returned as-is. Mixed con: an
   excluded-middle contra node using symm of the placeholders is built and
   resolved against the premise on premcon. The #if 0 block is dead code. */
virtual node make_symmetry(ast con, const ast &premcon, node prem){
#if 0
ast x = arg(con,0);
ast y = arg(con,1);
ast p = make(op(con),y,x);
#endif
if(get_term_type(con) != LitMixed)
return prem; // symmetry shmymmetry...
ast em = make(exmid,con,make(symm,get_placeholder(premcon)),get_placeholder(mk_not(con)));
ast itp = make(And,make(contra,em,mk_false()),
make(contra,make(symm,get_placeholder(mk_not(con))),premcon),
make(contra,make(symm,get_placeholder(premcon)),mk_not(con)));
std::vector<ast> conc; conc.push_back(con);
itp = make_resolution(premcon,conc,itp,prem);
return itp;
}
/* x <-> y when x has Boolean type, x = y otherwise. */
ast make_equiv_rel(const ast &x, const ast &y){
  const bool boolean = is_bool_type(get_type(x));
  if(!boolean)
    return make(Equal,x,y);
  return make(Iff,x,y);
}
/** Make a transitivity node. This takes derivations of |- x = y
    and |- y = z produces |- x = z.
    The axiom x=y, y=z -> x=z is interpolated in two steps. Congruence on
    y = z derives (x = y) <-> (x = z), and modus ponens with |- x = y then
    yields x = z. */
virtual node make_transitivity(const ast &x, const ast &y, const ast &z, node prem1, node prem2){
  ast xy = make_equiv_rel(x,y);
  ast yz = make_equiv_rel(y,z);
  ast xz = make_equiv_rel(x,z);
  ast equiv = make(Iff,xy,xz);
  ast cong_itp = make_congruence(yz,equiv,prem2);
  ast itp = make_mp(equiv,prem1,cong_itp);
  return itp;
}
/** Make a congruence node. This takes derivations of |- x_i = y_i
    and produces |- f(x_1,...,x_n) = f(y_1,...,y_n).
    The axiom interpolant is chosen from the partition types of p and con (the
    mixed con case goes through make_mixed_congruence). It is then resolved
    against the premise on p. */
virtual node make_congruence(const ast &p, const ast &con, const ast &prem1){
  ast x = arg(p,0), y = arg(p,1);
  const LitType con_t = get_term_type(con);
  const bool p_in_A = (get_term_type(p) == LitA);
  ast itp;
  switch(con_t){
  case LitA:
    itp = p_in_A ? mk_false() : mk_not(p);
    break;
  case LitB:
    itp = p_in_A ? p : mk_true();
    break;
  default:
    itp = make_mixed_congruence(x, y, p, con, prem1);
  }
  std::vector<ast> conc(1,con);
  itp = make_resolution(p,conc,itp,prem1);
  return itp;
}
int find_congruence_position(const ast &p, const ast &con){
// find the argument position of x and y
const ast &x = arg(p,0);
const ast &y = arg(p,1);
int nargs = num_args(arg(con,0));
for(int i = 0; i < nargs; i++)
if(x == arg(arg(con,0),i) && y == arg(arg(con,1),i))
return i;
throw proof_error();
}
/** Make a congruence node. This takes derivations of |- x_i1 = y_i1, |- x_i2 = y_i2,...
    and produces |- f(...x_i1...x_i2...) = f(...y_i1...y_i2...)
    p[i] is the equation x_ik = y_ik and prems[i] its derivation.  The
    arguments are rewritten one at a time: each step proves
    lhs = (lhs with one more argument replaced), and successive steps are
    chained with congruence plus modus ponens.
    Throws proof_error if p is empty, if prems has fewer entries than p,
    or if an equation does not match any argument position of con. */
node make_congruence(const std::vector<ast> &p, const ast &con, const std::vector<ast> &prems){
  if(p.size() == 0)
    throw proof_error();
  // prems[i] is read for every i < p.size(); reject a short premise list
  // instead of indexing past its end.
  if(prems.size() < p.size())
    throw proof_error();
  if(p.size() == 1)
    return make_congruence(p[0],con,prems[0]);
  ast thing = con;  // current equation: (partially rewritten lhs) op rhs
  ast res = mk_true();
  for(unsigned i = 0; i < p.size(); i++){
    int pos = find_congruence_position(p[i],thing);
    ast next = subst_in_arg_pos(pos,arg(p[i],1),arg(thing,0));
    ast goal = make(op(thing),arg(thing,0),next);
    ast equa = make_congruence(p[i],goal,prems[i]);
    if(i == 0)
      res = equa;
    else {
      // lift the step t_i = next to (lhs = t_i) <-> (lhs = next) and apply it
      ast trace = make(op(con),arg(con,0),arg(thing,0));
      ast equiv = make(Iff,trace,make(op(trace),arg(trace,0),next));
      ast foo = make_congruence(goal,equiv,equa);
      res = make_mp(equiv,res,foo);
    }
    thing = make(op(thing),next,arg(thing,1));
  }
  return res;
}
/* Interpolate a mixed congruence axiom p -> con, where p is x = y.
   A- or B-local p is rewritten locally; a mixed p becomes a placeholder and
   is recorded as a conjunct.  The argument position used is the LAST index
   at which arg(con,0) has x and arg(con,1) has y; throws proof_error if
   there is none. */
virtual ast make_mixed_congruence(const ast &x, const ast &y, const ast &p, const ast &con, const ast &prem1){
  std::vector<ast> conjs;
  ast eq_term = p;
  const LitType p_type = get_term_type(p);
  if(p_type == LitMixed){
    conjs.push_back(p);
    eq_term = get_placeholder(p);
  }
  else
    eq_term = make_local_rewrite(p_type,p);
  // scan from the right so the last matching position wins
  int pos = num_args(arg(con,0)) - 1;
  while(pos >= 0 && !(x == arg(arg(con,0),pos) && y == arg(arg(con,1),pos)))
    --pos;
  if(pos < 0)
    throw proof_error();
  ast cong_term = make(cong,eq_term,make_int(rational(pos)),get_placeholder(mk_not(con)));
  conjs.push_back(mk_not(con));
  return make_contra_node(cong_term,conjs);
}
/* Return a copy of application app whose argument at index pos is replaced
   by term.  pos is not range-checked. */
ast subst_in_arg_pos(int pos, ast term, ast app){
  std::vector<ast> new_args;
  get_args(app,new_args);
  new_args[pos] = term;
  return clone(app,new_args);
}
/** Make a farkas proof node.
    A-local premises are summed (with their coefficients) into one
    inequality starting from 0 <= 0.  Mixed premises are then added as
    placeholder terms and recorded as conjuncts.  Finally the result is
    resolved with each premise derivation in order. */
virtual node make_farkas(ast con, const std::vector<node> &prems, const std::vector<ast> &prem_cons,
                         const std::vector<ast> &coeffs){
  const unsigned n = prem_cons.size();
  ast zero = make_int("0");
  ast acc = make(Leq,zero,zero);
  for(unsigned k = 0; k < n; k++){
    if(get_term_type(prem_cons[k]) == LitA)
      // Farkas rule seems to assume strict integer inequalities are rounded
      linear_comb(acc,coeffs[k],prem_cons[k],true /*round_off*/);
  }
  acc = simplify_ineq(acc);
  std::vector<ast> mixed;
  for(unsigned k = 0; k < n; k++){
    const ast &lit = prem_cons[k];
    if(get_term_type(lit) != LitMixed)
      continue;
    acc = make(sum,acc,coeffs[k],get_placeholder(lit));
    mixed.push_back(lit);
  }
  acc = make_contra_node(acc,mixed);
  // Premise literals still open, in reverse order; the back is always the
  // literal about to be resolved away.
  std::vector<ast> remaining(prem_cons.rbegin(), prem_cons.rend());
  for(unsigned k = 0; k < n; k++){
    acc = make_resolution(prem_cons[k],remaining,acc,prems[k]);
    remaining.pop_back();
  }
  return acc;
}
/** Set P to P + cQ, where P and Q are linear inequalities. Assumes P is 0 <= y or 0 < y.
    Q may be a (possibly negated) <=, <, >= or > atom; anything else throws
    proof_error.  With round_off, a strict Q is weakened to non-strict by
    subtracting 1 unless P is non-strict and c is exactly 1. */
void linear_comb(ast &P, const ast &c, const ast &Q, bool round_off = false){
  // Replace a negated atom by the complementary relation:
  // ~(a>b) = a<=b, ~(a>=b) = a<b, ~(a<b) = a>=b, ~(a<=b) = a>b.
  const bool negated = is_not(Q);
  const ast atom = negated ? arg(Q,0) : Q;
  opr rel = op(atom);
  if(negated){
    switch(rel){
    case Gt:  rel = Leq; break;
    case Geq: rel = Lt;  break;
    case Lt:  rel = Geq; break;
    case Leq: rel = Gt;  break;
    default:  throw proof_error();
    }
  }
  // Express Q as 0 <= Qrhs (0 < Qrhs when qstrict).
  ast Qrhs;
  bool qstrict = false;
  switch(rel){
  case Leq: Qrhs = make(Sub,arg(atom,1),arg(atom,0)); qstrict = false; break;
  case Lt:  Qrhs = make(Sub,arg(atom,1),arg(atom,0)); qstrict = true;  break;
  case Geq: Qrhs = make(Sub,arg(atom,0),arg(atom,1)); qstrict = false; break;
  case Gt:  Qrhs = make(Sub,arg(atom,0),arg(atom,1)); qstrict = true;  break;
  default:  throw proof_error();
  }
  const bool pstrict = op(P) == Lt;
  if(qstrict && round_off && (pstrict || !(c == make_int(rational(1))))){
    Qrhs = make(Sub,Qrhs,make_int(rational(1)));
    qstrict = false;
  }
  Qrhs = make(Times,c,Qrhs);
  const opr result_rel = (pstrict || qstrict) ? Lt : Leq;
  P = make(result_rel,arg(P,0),make(Plus,arg(P,1),Qrhs));
}
/* Make an axiom instance of the form |- x<=y, y<= x -> x =y
   A-local equality: false.  B-local: true.  Mixed equality with a mixed
   side: returned as a plain axiom over the three literals.  Otherwise a
   leq2eq contra node over placeholders of ~(x=y), x<=y and y<=x. */
virtual node make_leq2eq(ast x, ast y, const ast &xleqy, const ast &yleqx){
  ast con = make(Equal,x,y);
  const LitType con_type = get_term_type(con);
  if(con_type == LitA)
    return mk_false();
  if(con_type == LitB)
    return mk_true();
  if(get_term_type(x) == LitMixed || get_term_type(y) == LitMixed){
    // std::cerr << "WARNING: mixed term in leq2eq\n";
    std::vector<ast> lits;
    lits.push_back(con);
    lits.push_back(make(Not,xleqy));
    lits.push_back(make(Not,yleqx));
    return make_axiom(lits);
  }
  std::vector<ast> conjs;
  conjs.push_back(mk_not(con));
  conjs.push_back(xleqy);
  conjs.push_back(yleqx);
  return make_contra_node(make(leq2eq,
                               get_placeholder(mk_not(con)),
                               get_placeholder(xleqy),
                               get_placeholder(yleqx)),
                          conjs,1);
}
/* Make an axiom instance of the form |- x = y -> x <= y
   A-local xleqy: false.  B-local: true.  Mixed: an eq2leq contra node over
   placeholders of x = y and ~(x <= y). */
virtual node make_eq2leq(ast x, ast y, const ast &xleqy){
  const LitType lit_type = get_term_type(xleqy);
  if(lit_type == LitA)
    return mk_false();
  if(lit_type == LitB)
    return mk_true();
  std::vector<ast> conjs;
  conjs.push_back(make(Equal,x,y));
  conjs.push_back(mk_not(xleqy));
  ast rule = make(eq2leq,get_placeholder(conjs[0]),get_placeholder(conjs[1]));
  return make_contra_node(rule,conjs,2);
}
/* Make an inference of the form t <= c |- t/d <= floor(c/d) where t
   is an affine term divisble by d and c is an integer constant.
   The interpolant is false / true for an A- / B-local conclusion, or a sum
   contra node over placeholders otherwise; it is then resolved with prem
   on tleqc. */
virtual node make_cut_rule(const ast &tleqc, const ast &d, const ast &con, const ast &prem){
  ast itp;
  const LitType con_type = get_term_type(con);
  if(con_type == LitA)
    itp = mk_false();
  else if(con_type == LitB)
    itp = mk_true();
  else {
    std::vector<ast> conjs;
    conjs.push_back(tleqc);
    conjs.push_back(mk_not(con));
    itp = make_contra_node(make(sum,get_placeholder(conjs[0]),d,get_placeholder(conjs[1])),conjs);
  }
  std::vector<ast> conc(1,con);
  return make_resolution(tleqc,conc,itp,prem);
}
/* Create a fresh variable "%<n>" (n = number of localization vars so far)
   of term's type, give it the full frame range so it is global, and record
   it as standing for term in the given frame. */
ast fresh_localization_var(const ast &term, int frame){
  std::ostringstream name;
  name << "%" << localization_vars.size();
  ast fresh = make_var(name.str().c_str(),get_type(term));
  pv->sym_range(sym(fresh)) = pv->range_full();
  localization_vars.push_back(LocVar(fresh,term,frame));
  return fresh;
}
/* A localization variable: a fresh variable standing for a term, together
   with the frame in which its defining equation is placed. */
struct LocVar {
  ast var;   // the fresh variable
  ast term;  // the term it represents
  int frame; // frame in which it is defined
  LocVar(ast v, ast t, int f) : var(v), term(t), frame(f) {}
};
// Localization state, filled by fresh_localization_var and localize_term.
std::vector<LocVar> localization_vars; // localization vars in order of creation (add_quants walks them newest-first)
hash_map<ast,ast> localization_map; // maps original terms to their localized form (a rebuilt term or a fresh var)
hash_map<ast,ast> localization_pf_map; // maps original terms to proofs that their localized form equals them
/* "Localize" a term e to a given frame range (see localize_term below),
   creating new symbols to represent non-local subterms. This returns the
   localized version e_l, as well as a proof that e_l = e.
*/
/* Proof term used for the reflexive equation e = e: false when e is
   A-local, true otherwise. */
ast make_refl(const ast &e){
  const bool a_local = get_term_type(e) == LitA;
  return a_local ? mk_false() : mk_true(); // TODO: is the non-A case right?
}
/* Build the equivalence between x and y suited to x's sort: Iff for
   booleans, Equal otherwise.  Delegates to make_equiv_rel so both helpers
   use the same boolean-sort test (is_bool_type) instead of duplicating the
   logic with a direct comparison against bool_type(). */
ast make_equiv(const ast &x, const ast &y){
  return make_equiv_rel(x,y);
}
/* Localize term e to the frame range rng.
   Returns a term equivalent to e that occurs in rng: either e itself, a
   clone of e with localized arguments, or a fresh localization variable.
   On return pf holds a proof that the returned term equals the original e.
   Results are cached per original term in localization_map and
   localization_pf_map. */
ast localize_term(ast e, const prover::range &rng, ast &pf){
ast orig_e = e;
pf = make_refl(e); // proof that e = e
// NOTE(review): erng is only referenced by disabled code below.
prover::range erng = pv->ast_scope(e);
#if 0
if(!(erng.lo > erng.hi) && pv->ranges_intersect(pv->ast_scope(e),rng)){
return e; // this term occurs in range, so it's O.K.
}
#endif
hash_map<ast,ast>::iterator it = localization_map.find(e);
// A cached localization of a boolean term is reused only if it intersects rng.
if(it != localization_map.end() && is_bool_type(get_type(e))
&& !pv->ranges_intersect(pv->ast_scope(it->second),rng))
it = localization_map.end(); // prevent quantifiers over booleans
if(it != localization_map.end()){
// cache hit: reuse the stored localized term and its proof
pf = localization_pf_map[e];
e = it->second;
}
else {
// if it is non-local, we must first localize the arguments to
// the range of its function symbol
int nargs = num_args(e);
if(nargs > 0 /* && (!is_local(e) || flo <= hi || fhi >= lo) */){
prover::range frng = rng;
opr o = op(e);
if(o == Uninterpreted){
symb f = sym(e);
prover::range srng = pv->sym_range(f);
if(pv->ranges_intersect(srng,rng)) // localize to desired range if possible
frng = pv->range_glb(srng,rng);
else
frng = srng; // this term will be localized
}
else if(o == Plus || o == Times){ // don't want bound variables inside arith ops
// frng is left as rng here; narrowing to erng is disabled.
// std::cout << "WARNING: non-local arithmetic\n";
// frng = erng; // this term will be localized
}
else if(o == Select){ // treat the array term like a function symbol
prover::range srng = pv->ast_scope(arg(e,0));
if(!(srng.lo > srng.hi) && pv->ranges_intersect(srng,rng)) // localize to desired range if possible
frng = pv->range_glb(srng,rng);
else
frng = srng; // this term will be localized
}
std::vector<ast> largs(nargs);
std::vector<ast> eqs;  // localized_arg = original_arg, for changed args only
std::vector<ast> pfs;  // matching proofs
for(int i = 0; i < nargs; i++){
ast argpf;
largs[i] = localize_term(arg(e,i),frng,argpf);
// later arguments must also fit the scope of the ones already localized
frng = pv->range_glb(frng,pv->ast_scope(largs[i]));
if(largs[i] != arg(e,i)){
eqs.push_back(make_equiv(largs[i],arg(e,i)));
pfs.push_back(argpf);
}
}
e = clone(e,largs);
// prove localized e = orig_e by congruence over the changed arguments
if(pfs.size())
pf = make_congruence(eqs,make_equiv(e,orig_e),pfs);
// assert(is_local(e));
}
localization_pf_map[orig_e] = pf;
localization_map[orig_e] = e;
}
if(pv->ranges_intersect(pv->ast_scope(e),rng))
return e; // this term occurs in range, so it's O.K.
// Still out of range: replace it by a fresh variable defined by an
// assumption new_var = e placed in a nearby frame.
if(is_array_type(get_type(e)))
std::cerr << "WARNING: array quantifier\n";
// choose a frame for the constraint that is close to range
int frame = pv->range_near(pv->ast_scope(e),rng);
ast new_var = fresh_localization_var(e,frame);
localization_map[orig_e] = new_var;
std::vector<ast> foo; foo.push_back(make_equiv(new_var,e));
ast bar = make_assumption(frame,foo);
// new_var = e (assumption) and e = orig_e (pf) give new_var = orig_e
pf = make_transitivity(new_var,e,orig_e,bar,pf);
localization_pf_map[orig_e] = pf;
return new_var;
}
/* Return e with every @sforall/@sexists marker that binds v removed (its
   body kept).  Markers binding other variables are rebuilt around their
   processed body.  Recursion descends through Or/And/Implies and marker
   bodies only; every other term is returned unchanged.  Results are
   memoized in memo. */
ast delete_quant(hash_map<ast,ast> &memo, const ast &v, const ast &e){
  std::pair<ast,ast> foo(e,ast());
  std::pair<hash_map<ast,ast>::iterator,bool> bar = memo.insert(foo);
  ast &res = bar.first->second;
  if(bar.second){
    opr o = op(e);
    switch(o){
    case Or:
    case And:
    case Implies: {
      unsigned nargs = num_args(e);
      std::vector<ast> args; args.resize(nargs);
      for(unsigned i = 0; i < nargs; i++)
        args[i] = delete_quant(memo, v, arg(e,i));
      res = make(o,args);
      break;
    }
    case Uninterpreted: {
      symb s = sym(e);
      if(s == sforall || s == sexists){
        // Only markers are known to carry the bound variable as argument 0 of
        // argument 0; reading it for ordinary applications (predicates,
        // constants) would access arguments that do not exist.
        ast w = arg(arg(e,0),0);
        res = delete_quant(memo,v,arg(e,1));
        if(w != v)
          res = make(s,w,res);
      }
      else
        res = e; // ordinary uninterpreted atom: unchanged
      break;
    }
    default:
      res = e;
    }
  }
  return res;
}
/* Rebuild e with each @sforall/@sexists marker turned into a real
   Forall/Exists over its bound variable.  Inner markers for the same
   variable are first removed with delete_quant.  Recursion descends through
   Or/And/Implies and marker bodies; other terms are returned unchanged.
   Results are memoized in memo. */
ast insert_quants(hash_map<ast,ast> &memo, const ast &e){
  std::pair<hash_map<ast,ast>::iterator,bool> slot = memo.insert(std::pair<ast,ast>(e,ast()));
  ast &res = slot.first->second;
  if(!slot.second)
    return res; // already translated
  const opr o = op(e);
  if(o == Or || o == And || o == Implies){
    const unsigned n = num_args(e);
    std::vector<ast> kids(n);
    for(unsigned k = 0; k < n; k++)
      kids[k] = insert_quants(memo, arg(e,k));
    res = make(o,kids);
    return res;
  }
  if(o == Uninterpreted){
    const symb s = sym(e);
    if(s == sforall || s == sexists){
      const ast bound = arg(arg(e,0),0);
      hash_map<ast,ast> dmemo;
      ast body = delete_quant(dmemo,bound,arg(e,1));
      body = insert_quants(memo,body);
      res = apply_quant(s == sforall ? Forall : Exists,bound,body);
      return res;
    }
  }
  res = e;
  return res;
}
/* Close e over the localization variables.  With CAPTURE_LOCALIZATION the
   in-formula markers are expanded by insert_quants; otherwise each
   localization var is quantified, newest innermost... i.e. iterating from
   the most recently created: Exists if its frame lies in rng, else Forall. */
ast add_quants(ast e){
#ifdef CAPTURE_LOCALIZATION
  if(!localization_vars.empty()){
    hash_map<ast,ast> memo;
    e = insert_quants(memo,e);
  }
#else
  for(std::vector<LocVar>::reverse_iterator lv = localization_vars.rbegin();
      lv != localization_vars.rend(); ++lv){
    const opr quantifier = pv->in_range(lv->frame,rng) ? Exists : Forall;
    e = apply_quant(quantifier,lv->var,e);
  }
#endif
  return e;
}
/* Resolve premise1 and premise2 on pivot with no conclusion literals. */
node make_resolution(ast pivot, node premise1, node premise2) {
  std::vector<ast> no_conclusions;
  return make_resolution(pivot,no_conclusions,premise1,premise2);
}
/* Return an interpolant from a proof of false.
   The proof term is already a formula; it is simplified and, unless
   BOGUS_QUANTS is defined, closed over the localization variables. */
ast interpolate(const node &pf){
  ast result = z3_simplify(pf);
#ifndef BOGUS_QUANTS
  result = add_quants(result);
#endif
  return result;
}
/* Resolve conj1 and conj2 on complementary literals taken from argument 1
   of pivot1 and pivot2, localizing the literal P to rng first.
   P is replaced by its localization Ploc (with proof eqpf of Ploc <-> P);
   each side is rewritten to use Ploc / ~Ploc, the two are resolved on Ploc,
   and the result is passed to capture_localization.
   Throws proof_error if both literals are negations (the normalizing swap
   would otherwise recurse forever). */
ast resolve_with_quantifier(const ast &pivot1, const ast &conj1,
                            const ast &pivot2, const ast &conj2){
  // Normalize so that pivot1 carries the positive literal.
  if(is_not(arg(pivot1,1))){
    if(is_not(arg(pivot2,1)))
      throw proof_error();
    return resolve_with_quantifier(pivot2,conj2,pivot1,conj1);
  }
  ast eqpf;
  ast P = arg(pivot1,1);
  ast Ploc = localize_term(P, rng, eqpf);
  // positive side: hypothesis Ploc gives P, resolved against conj1
  ast pPloc = make_hypothesis(Ploc);
  ast pP = make_mp(make(Iff,Ploc,P),pPloc,eqpf);
  ast rP = make_resolution(P,conj1,pP);
  // negative side: lift Ploc <-> P to ~Ploc <-> ~P, resolved against conj2
  ast nP = mk_not(P);
  ast nPloc = mk_not(Ploc);
  ast neqpf = make_congruence(make(Iff,Ploc,P),make(Iff,nPloc,nP),eqpf);
  ast npPloc = make_hypothesis(nPloc);
  ast npP = make_mp(make(Iff,nPloc,nP),npPloc,neqpf);
  ast nrP = make_resolution(nP,conj2,npP);
  ast res = make_resolution(Ploc,rP,nrP);
  return capture_localization(res);
}
/* Return the coefficient stored as argument 0 of f.  (Negating it when
   argument 1 is not a negation is currently disabled.) */
ast get_contra_coeff(const ast &f){
  const ast coeff = arg(f,0);
  return coeff;
}
// Forwarding wrappers around mk_or / mk_and / mk_implies.
// Binary disjunction a \/ b.
ast my_or(const ast &a, const ast &b){
return mk_or(a,b);
}
// Binary conjunction a /\ b.
ast my_and(const ast &a, const ast &b){
return mk_and(a,b);
}
// Implication a -> b.
ast my_implies(const ast &a, const ast &b){
return mk_implies(a,b);
}
// Disjunction of all elements of a.
ast my_or(const std::vector<ast> &a){
return mk_or(a);
}
// Conjunction of all elements of a.
ast my_and(const std::vector<ast> &a){
return mk_and(a);
}
/* Return the atom of literal l: its argument if l is a Not, else l. */
ast get_lit_atom(const ast &l){
  if(op(l) != Not)
    return l;
  return arg(l,0);
}
/* True iff e is an uninterpreted application whose symbol name starts with
   "@p" and is longer than two characters. */
bool is_placeholder(const ast &e){
  if(op(e) != Uninterpreted)
    return false;
  const std::string name = string_of_symbol(sym(e));
  return name.size() > 2 && name.compare(0,2,"@p") == 0;
}
public:
/* Construct the interpolating proof system over prover p for frame range r.
   Creates the function symbols that encode the deduction rules (each one is
   inc_ref'd with the manager) and the special variables @eps, @nop and
   @top_pos.  The argument w is currently ignored: weak is forced to false. */
iz3proof_itp_impl(prover *p, const prover::range &r, bool w)
: iz3proof_itp(*p)
{
pv = p;
rng = r;
weak = false ; //w;
// Domain signatures used by the rule symbols below.
type boolintbooldom[3] = {bool_type(),int_type(),bool_type()};
type booldom[1] = {bool_type()};
type boolbooldom[2] = {bool_type(),bool_type()};
type boolboolbooldom[3] = {bool_type(),bool_type(),bool_type()};
type intbooldom[2] = {int_type(),bool_type()};
contra = function("@contra",2,boolbooldom,bool_type());
m().inc_ref(contra);
sum = function("@sum",3,boolintbooldom,bool_type());
m().inc_ref(sum);
rotate_sum = function("@rotsum",2,boolbooldom,bool_type());
m().inc_ref(rotate_sum);
leq2eq = function("@leq2eq",3,boolboolbooldom,bool_type());
m().inc_ref(leq2eq);
eq2leq = function("@eq2leq",2,boolbooldom,bool_type());
m().inc_ref(eq2leq);
cong = function("@cong",3,boolintbooldom,bool_type());
m().inc_ref(cong);
exmid = function("@exmid",3,boolboolbooldom,bool_type());
m().inc_ref(exmid);
symm = function("@symm",1,booldom,bool_type());
m().inc_ref(symm);
// variables (not inc_ref'd here)
epsilon = make_var("@eps",int_type());
modpon = function("@mp",3,boolboolbooldom,bool_type());
m().inc_ref(modpon);
no_proof = make_var("@nop",bool_type());
concat = function("@concat",2,boolbooldom,bool_type());
m().inc_ref(concat);
top_pos = make_var("@top_pos",bool_type());
add_pos = function("@add_pos",2,intbooldom,bool_type());
m().inc_ref(add_pos);
rewrite_A = function("@rewrite_A",3,boolboolbooldom,bool_type());
m().inc_ref(rewrite_A);
rewrite_B = function("@rewrite_B",3,boolboolbooldom,bool_type());
m().inc_ref(rewrite_B);
normal_step = function("@normal_step",2,boolbooldom,bool_type());
m().inc_ref(normal_step);
normal_chain = function("@normal_chain",2,boolbooldom,bool_type());
m().inc_ref(normal_chain);
normal = function("@normal",2,boolbooldom,bool_type());
m().inc_ref(normal);
// quantifier markers expanded by insert_quants / delete_quant
sforall = function("@sforall",2,boolbooldom,bool_type());
m().inc_ref(sforall);
sexists = function("@sexists",2,boolbooldom,bool_type());
m().inc_ref(sexists);
}
/* Release the references taken on the rule symbols by the constructor.
   Every symbol passed to m().inc_ref there gets exactly one dec_ref here. */
~iz3proof_itp_impl(){
  m().dec_ref(contra);
  m().dec_ref(sum);
  m().dec_ref(rotate_sum);
  m().dec_ref(leq2eq);
  m().dec_ref(eq2leq);
  m().dec_ref(cong);
  m().dec_ref(exmid);
  m().dec_ref(symm);
  m().dec_ref(modpon);
  m().dec_ref(concat);
  m().dec_ref(add_pos);
  m().dec_ref(rewrite_A);
  m().dec_ref(rewrite_B);
  // These are also inc_ref'd in the constructor and must be released too.
  m().dec_ref(normal_step);
  m().dec_ref(normal_chain);
  m().dec_ref(normal);
  m().dec_ref(sforall);
  m().dec_ref(sexists);
}
};
/* Factory: allocate the concrete interpolating proof system over prover p
   and frame range r, forwarding w to the constructor.
   NOTE(review): returns a heap object via new; presumably the caller owns
   and deletes it — confirm against callers. */
iz3proof_itp *iz3proof_itp::create(prover *p, const prover::range &r, bool w){
return new iz3proof_itp_impl(p,r,w);
}