diff --git a/examples/interp/iz3.cpp b/examples/interp/iz3.cpp new file mode 100755 index 000000000..fa7d8b896 --- /dev/null +++ b/examples/interp/iz3.cpp @@ -0,0 +1,452 @@ +#include +#include +#include +#include +#include +#include "z3.h" + + + +int usage(const char **argv){ + std::cerr << "usage: " << argv[0] << " [options] file.smt" << std::endl; + std::cerr << std::endl; + std::cerr << "options:" << std::endl; + std::cerr << " -t,--tree tree interpolation" << std::endl; + std::cerr << " -c,--check check result" << std::endl; + std::cerr << " -p,--profile profile execution" << std::endl; + std::cerr << " -w,--weak weak interpolants" << std::endl; + std::cerr << " -f,--flat ouput flat formulas" << std::endl; + std::cerr << " -o ouput to SMT-LIB file" << std::endl; + std::cerr << " -a,--anon anonymize" << std::endl; + std::cerr << " -s,--simple simple proof mode" << std::endl; + std::cerr << std::endl; + return 1; +} + +int main(int argc, const char **argv) { + + bool tree_mode = false; + bool check_mode = false; + bool profile_mode = false; + bool incremental_mode = false; + std::string output_file; + bool flat_mode = false; + bool anonymize = false; + bool write = false; + + Z3_config cfg = Z3_mk_config(); + // Z3_interpolation_options options = Z3_mk_interpolation_options(); + Z3_params options = 0; + + /* Parse the command line */ + int argn = 1; + while(argn < argc-1){ + std::string flag = argv[argn]; + if(flag[0] == '-'){ + if(flag == "-t" || flag == "--tree") + tree_mode = true; + else if(flag == "-c" || flag == "--check") + check_mode = true; + else if(flag == "-p" || flag == "--profile") + profile_mode = true; +#if 0 + else if(flag == "-w" || flag == "--weak") + Z3_set_interpolation_option(options,"weak","1"); + else if(flag == "--secondary") + Z3_set_interpolation_option(options,"secondary","1"); +#endif + else if(flag == "-i" || flag == "--incremental") + incremental_mode = true; + else if(flag == "-o"){ + argn++; + if(argn >= argc) return usage(argv); + output_file = argv[argn]; + } + else if(flag == "-f" || flag == "--flat") + flat_mode = true; + else if(flag == "-a" || flag == "--anon") + anonymize = true; + else if(flag == "-w" || flag == "--write") + write = true; + else if(flag == "-s" || flag == "--simple") + Z3_set_param_value(cfg,"PREPROCESS","false"); + else + return usage(argv); + } + argn++; + } + if(argn != argc-1) + return usage(argv); + const char *filename = argv[argn]; + + + /* Create a Z3 context to contain formulas */ + Z3_context ctx = Z3_mk_interpolation_context(cfg); + + if(write || anonymize) + Z3_set_ast_print_mode(ctx,Z3_PRINT_SMTLIB2_COMPLIANT); + else if(!flat_mode) + Z3_set_ast_print_mode(ctx,Z3_PRINT_SMTLIB_COMPLIANT); + + /* Read an interpolation problem */ + + int num; + Z3_ast *constraints; + int *parents = 0; + const char *error; + bool ok; + int num_theory; + Z3_ast *theory; + + ok = Z3_read_interpolation_problem(ctx, &num, &constraints, tree_mode ? &parents : 0, filename, &error, &num_theory, &theory); + + /* If parse failed, print the error message */ + + if(!ok){ + std::cerr << error << "\n"; + return 1; + } + + /* if we get only one formula, and it is a conjunction, split it into conjuncts. */ + if(!tree_mode && num == 1){ + Z3_app app = Z3_to_app(ctx,constraints[0]); + Z3_func_decl func = Z3_get_app_decl(ctx,app); + Z3_decl_kind dk = Z3_get_decl_kind(ctx,func); + if(dk == Z3_OP_AND){ + int nconjs = Z3_get_app_num_args(ctx,app); + if(nconjs > 1){ + std::cout << "Splitting formula into " << nconjs << " conjuncts...\n"; + num = nconjs; + constraints = new Z3_ast[num]; + for(int k = 0; k < num; k++) + constraints[k] = Z3_get_app_arg(ctx,app,k); + } + } + } + + /* Write out anonymized version. */ + + if(write || anonymize){ +#if 0 + Z3_anonymize_ast_vector(ctx,num,constraints); +#endif + std::string ofn = output_file.empty() ? "iz3out.smt2" : output_file; + Z3_write_interpolation_problem(ctx, num, constraints, parents, ofn.c_str(), num_theory, theory); + std::cout << "anonymized problem written to " << ofn << "\n"; + exit(0); + } + + /* Compute an interpolant, or get a model. */ + + Z3_ast *interpolants = (Z3_ast *)malloc((num-1) * sizeof(Z3_ast)); + Z3_model model = 0; + Z3_lbool result; + + if(!incremental_mode){ + /* In non-incremental mode, we just pass the constraints. */ + result = Z3_interpolate(ctx, num, constraints, (unsigned int *)parents, options, interpolants, &model, 0, false, num_theory, theory); + } + else { + + /* This is a somewhat useless demonstration of incremental mode. + Here, we assert the constraints in the context, then pass them to + iZ3 in an array, so iZ3 knows the sequence. Note it's safe to pass + "true", even though we haven't techically asserted if. */ + + Z3_push(ctx); + std::vector asserted(num); + + /* We start with nothing asserted. */ + for(int i = 0; i < num; i++) + asserted[i] = Z3_mk_true(ctx); + + /* Now we assert the constrints one at a time until UNSAT. */ + + for(int i = 0; i < num; i++){ + asserted[i] = constraints[i]; + Z3_assert_cnstr(ctx,constraints[i]); // assert one constraint + result = Z3_interpolate(ctx, num, &asserted[0], (unsigned int *)parents, options, interpolants, &model, 0, true, 0, 0); + if(result == Z3_L_FALSE){ + for(unsigned j = 0; j < num-1; j++) + /* Since we want the interpolant formulas to survive a "pop", we + "persist" them here. */ + Z3_persist_ast(ctx,interpolants[j],1); + break; + } + } + Z3_pop(ctx,1); + } + + switch (result) { + + /* If UNSAT, print the interpolants */ + case Z3_L_FALSE: + printf("unsat\n"); + if(output_file.empty()){ + printf("interpolant:\n"); + for(int i = 0; i < num-1; i++) + printf("%s\n", Z3_ast_to_string(ctx, interpolants[i])); + } + else { +#if 0 + Z3_write_interpolation_problem(ctx,num-1,interpolants,0,output_file.c_str()); + printf("interpolant written to %s\n",output_file.c_str()); +#endif + } +#if 1 + if(check_mode){ + std::cout << "Checking interpolant...\n"; + bool chk; + chk = Z3_check_interpolant(ctx,num,constraints,parents,interpolants,&error,num_theory,theory); + if(chk) + std::cout << "Interpolant is correct\n"; + else { + std::cout << "Interpolant is incorrect\n"; + std::cout << error; + return 1; + } + } +#endif + break; + case Z3_L_UNDEF: + printf("fail\n"); + break; + case Z3_L_TRUE: + printf("sat\n"); + printf("model:\n%s\n", Z3_model_to_string(ctx, model)); + break; + } + + if(profile_mode) + std::cout << Z3_interpolation_profile(ctx); + + /* Delete the model if there is one */ + + if (model) + Z3_del_model(ctx, model); + + /* Delete logical context. */ + + Z3_del_context(ctx); + free(interpolants); + + return 0; +} + + +#if 0 + + + +int test(){ + int i; + + /* Create a Z3 context to contain formulas */ + + Z3_config cfg = Z3_mk_config(); + Z3_context ctx = iz3_mk_context(cfg); + + int num = 2; + + Z3_ast *constraints = (Z3_ast *)malloc(num * sizeof(Z3_ast)); + +#if 1 + Z3_sort arr = Z3_mk_array_sort(ctx,Z3_mk_int_sort(ctx),Z3_mk_bool_sort(ctx)); + Z3_symbol as = Z3_mk_string_symbol(ctx, "a"); + Z3_symbol bs = Z3_mk_string_symbol(ctx, "b"); + Z3_symbol xs = Z3_mk_string_symbol(ctx, "x"); + + Z3_ast a = Z3_mk_const(ctx,as,arr); + Z3_ast b = Z3_mk_const(ctx,bs,arr); + Z3_ast x = Z3_mk_const(ctx,xs,Z3_mk_int_sort(ctx)); + + Z3_ast c1 = Z3_mk_eq(ctx,a,Z3_mk_store(ctx,b,x,Z3_mk_true(ctx))); + Z3_ast c2 = Z3_mk_not(ctx,Z3_mk_select(ctx,a,x)); +#else + Z3_symbol xs = Z3_mk_string_symbol(ctx, "x"); + Z3_ast x = Z3_mk_const(ctx,xs,Z3_mk_bool_sort(ctx)); + Z3_ast c1 = Z3_mk_eq(ctx,x,Z3_mk_true(ctx)); + Z3_ast c2 = Z3_mk_eq(ctx,x,Z3_mk_false(ctx)); + +#endif + + constraints[0] = c1; + constraints[1] = c2; + + /* print out the result for grins. */ + + // Z3_string smtout = Z3_benchmark_to_smtlib_string (ctx, "foo", "QFLIA", "sat", "", num, constraints, Z3_mk_true(ctx)); + + // Z3_string smtout = Z3_ast_to_string(ctx,constraints[0]); + // Z3_string smtout = Z3_context_to_string(ctx); + // puts(smtout); + + iz3_print(ctx,num,constraints,"iZ3temp.smt"); + + /* Make room for interpolants. */ + + Z3_ast *interpolants = (Z3_ast *)malloc((num-1) * sizeof(Z3_ast)); + + /* Make room for the model. */ + + Z3_model model = 0; + + /* Call the prover */ + + Z3_lbool result = iz3_interpolate(ctx, num, constraints, interpolants, &model); + + switch (result) { + + /* If UNSAT, print the interpolants */ + case Z3_L_FALSE: + printf("unsat, interpolants:\n"); + for(i = 0; i < num-1; i++) + printf("%s\n", Z3_ast_to_string(ctx, interpolants[i])); + break; + case Z3_L_UNDEF: + printf("fail\n"); + break; + case Z3_L_TRUE: + printf("sat\n"); + printf("model:\n%s\n", Z3_model_to_string(ctx, model)); + break; + } + + /* Delete the model if there is one */ + + if (model) + Z3_del_model(ctx, model); + + /* Delete logical context (note, we call iz3_del_context, not + Z3_del_context */ + + iz3_del_context(ctx); + + return 1; +} + +struct z3_error { + Z3_error_code c; + z3_error(Z3_error_code _c) : c(_c) {} +}; + +extern "C" { + static void throw_z3_error(Z3_error_code c){ + throw z3_error(c); + } +} + +int main(int argc, const char **argv) { + + /* Create a Z3 context to contain formulas */ + + Z3_config cfg = Z3_mk_config(); + Z3_context ctx = iz3_mk_context(cfg); + Z3_set_error_handler(ctx, throw_z3_error); + + /* Make some constraints, by parsing an smtlib formatted file given as arg 1 */ + + try { + Z3_parse_smtlib_file(ctx, argv[1], 0, 0, 0, 0, 0, 0); + } + catch(const z3_error &err){ + std::cerr << "Z3 error: " << Z3_get_error_msg(err.c) << "\n"; + std::cerr << Z3_get_smtlib_error(ctx) << "\n"; + return(1); + } + + /* Get the constraints from the parser. */ + + int num = Z3_get_smtlib_num_formulas(ctx); + + if(num == 0){ + std::cerr << "iZ3 error: File contains no formulas.\n"; + return 1; + } + + + Z3_ast *constraints = (Z3_ast *)malloc(num * sizeof(Z3_ast)); + + int i; + for (i = 0; i < num; i++) + constraints[i] = Z3_get_smtlib_formula(ctx, i); + + /* if we get only one formula, and it is a conjunction, split it into conjuncts. */ + if(num == 1){ + Z3_app app = Z3_to_app(ctx,constraints[0]); + Z3_func_decl func = Z3_get_app_decl(ctx,app); + Z3_decl_kind dk = Z3_get_decl_kind(ctx,func); + if(dk == Z3_OP_AND){ + int nconjs = Z3_get_app_num_args(ctx,app); + if(nconjs > 1){ + std::cout << "Splitting formula into " << nconjs << " conjuncts...\n"; + num = nconjs; + constraints = new Z3_ast[num]; + for(int k = 0; k < num; k++) + constraints[k] = Z3_get_app_arg(ctx,app,k); + } + } + } + + + /* print out the result for grins. */ + + // Z3_string smtout = Z3_benchmark_to_smtlib_string (ctx, "foo", "QFLIA", "sat", "", num, constraints, Z3_mk_true(ctx)); + + // Z3_string smtout = Z3_ast_to_string(ctx,constraints[0]); + // Z3_string smtout = Z3_context_to_string(ctx); + // puts(smtout); + + // iz3_print(ctx,num,constraints,"iZ3temp.smt"); + + /* Make room for interpolants. */ + + Z3_ast *interpolants = (Z3_ast *)malloc((num-1) * sizeof(Z3_ast)); + + /* Make room for the model. */ + + Z3_model model = 0; + + /* Call the prover */ + + Z3_lbool result = iz3_interpolate(ctx, num, constraints, interpolants, &model); + + switch (result) { + + /* If UNSAT, print the interpolants */ + case Z3_L_FALSE: + printf("unsat, interpolants:\n"); + for(i = 0; i < num-1; i++) + printf("%s\n", Z3_ast_to_string(ctx, interpolants[i])); + std::cout << "Checking interpolants...\n"; + const char *error; + if(iZ3_check_interpolant(ctx, num, constraints, 0, interpolants, &error)) + std::cout << "Interpolant is correct\n"; + else { + std::cout << "Interpolant is incorrect\n"; + std::cout << error << "\n"; + } + break; + case Z3_L_UNDEF: + printf("fail\n"); + break; + case Z3_L_TRUE: + printf("sat\n"); + printf("model:\n%s\n", Z3_model_to_string(ctx, model)); + break; + } + + /* Delete the model if there is one */ + + if (model) + Z3_del_model(ctx, model); + + /* Delete logical context (note, we call iz3_del_context, not + Z3_del_context */ + + iz3_del_context(ctx); + + return 0; +} + +#endif diff --git a/scripts/mk_project.py b/scripts/mk_project.py index be95032b7..6215e0fa0 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -34,7 +34,8 @@ def init_project_def(): add_lib('subpaving_tactic', ['core_tactics', 'subpaving'], 'math/subpaving/tactic') add_lib('aig_tactic', ['tactic'], 'tactic/aig') add_lib('solver', ['model', 'tactic']) - add_lib('cmd_context', ['solver', 'rewriter']) + add_lib('interp', ['solver']) + add_lib('cmd_context', ['solver', 'rewriter', 'interp']) add_lib('extra_cmds', ['cmd_context', 'subpaving_tactic', 'arith_tactics'], 'cmd_context/extra_cmds') add_lib('smt2parser', ['cmd_context', 'parser_util'], 'parsers/smt2') add_lib('proof_checker', ['rewriter'], 'ast/proof_checker') @@ -55,6 +56,7 @@ def init_project_def(): add_lib('smt_tactic', ['smt'], 'smt/tactic') add_lib('sls_tactic', ['tactic', 'normal_forms', 'core_tactics', 'bv_tactics'], 'tactic/sls') add_lib('qe', ['smt','sat'], 'qe') + add_lib('duality', ['smt', 'interp', 'qe']) add_lib('muz', ['smt', 'sat', 'smt2parser', 'aig_tactic', 'qe'], 'muz/base') add_lib('transforms', ['muz', 'hilbert'], 'muz/transforms') add_lib('rel', ['muz', 'transforms'], 'muz/rel') @@ -62,14 +64,15 @@ def init_project_def(): add_lib('clp', ['muz', 'transforms'], 'muz/clp') add_lib('tab', ['muz', 'transforms'], 'muz/tab') add_lib('bmc', ['muz', 'transforms'], 'muz/bmc') - add_lib('fp', ['muz', 'pdr', 'clp', 'tab', 'rel', 'bmc'], 'muz/fp') + add_lib('duality_intf', ['muz', 'transforms', 'duality'], 'muz/duality') + add_lib('fp', ['muz', 'pdr', 'clp', 'tab', 'rel', 'bmc', 'duality_intf'], 'muz/fp') add_lib('smtlogic_tactics', ['arith_tactics', 'bv_tactics', 'nlsat_tactic', 'smt_tactic', 'aig_tactic', 'fp', 'muz','qe'], 'tactic/smtlogics') add_lib('ufbv_tactic', ['normal_forms', 'core_tactics', 'macros', 'smt_tactic', 'rewriter'], 'tactic/ufbv') add_lib('portfolio', ['smtlogic_tactics', 'ufbv_tactic', 'fpa', 'aig_tactic', 'fp', 'qe','sls_tactic', 'subpaving_tactic'], 'tactic/portfolio') add_lib('smtparser', ['portfolio'], 'parsers/smt') add_lib('opt', ['smt', 'smtlogic_tactics'], 'opt') API_files = ['z3_api.h', 'z3_algebraic.h', 'z3_polynomial.h', 'z3_rcf.h'] - add_lib('api', ['portfolio', 'user_plugin', 'smtparser', 'realclosure','opt'], + add_lib('api', ['portfolio', 'user_plugin', 'smtparser', 'realclosure','interp','opt'], includes2install=['z3.h', 'z3_v1.h', 'z3_macros.h'] + API_files) add_exe('shell', ['api', 'sat', 'extra_cmds','opt'], exe_name='z3') add_exe('test', ['api', 'fuzzing'], exe_name='test-z3', install=False) @@ -84,6 +87,7 @@ def init_project_def(): set_z3py_dir('api/python') # Examples add_cpp_example('cpp_example', 'c++') + add_cpp_example('iz3', 'interp') add_cpp_example('z3_tptp', 'tptp') add_c_example('c_example', 'c') add_c_example('maxsat') diff --git a/scripts/mk_util.py b/scripts/mk_util.py index 98f213dd5..9295c949f 100644 --- a/scripts/mk_util.py +++ b/scripts/mk_util.py @@ -29,6 +29,7 @@ CXX=getenv("CXX", None) CC=getenv("CC", None) CPPFLAGS=getenv("CPPFLAGS", "") CXXFLAGS=getenv("CXXFLAGS", "") +EXAMP_DEBUG_FLAG='' LDFLAGS=getenv("LDFLAGS", "") JNI_HOME=getenv("JNI_HOME", None) @@ -70,6 +71,8 @@ VER_BUILD=None VER_REVISION=None PREFIX=os.path.split(os.path.split(os.path.split(PYTHON_PACKAGE_DIR)[0])[0])[0] GMP=False +FOCI2=False +FOCI2LIB='' VS_PAR=False VS_PAR_NUM=8 GPROF=False @@ -198,6 +201,14 @@ def test_gmp(cc): t.commit() return exec_compiler_cmd([cc, CPPFLAGS, 'tstgmp.cpp', LDFLAGS, '-lgmp']) == 0 +def test_foci2(cc,foci2lib): + if is_verbose(): + print("Testing FOCI2...") + t = TempFile('tstfoci2.cpp') + t.add('#include\nint main() { foci2 *f = foci2::create("lia"); return 0; }\n') + t.commit() + return exec_compiler_cmd([cc, CPPFLAGS, '-Isrc/interp', 'tstfoci2.cpp', LDFLAGS, foci2lib]) == 0 + def test_openmp(cc): if is_verbose(): print("Testing OpenMP...") @@ -443,6 +454,7 @@ def display_help(exit_code): if not IS_WINDOWS: print(" -g, --gmp use GMP.") print(" --gprof enable gprof") + print(" -f --foci2= use foci2 library at path") print("") print("Some influential environment variables:") if not IS_WINDOWS: @@ -458,18 +470,19 @@ def display_help(exit_code): # Parse configuration option for mk_make script def parse_options(): global VERBOSE, DEBUG_MODE, IS_WINDOWS, VS_X64, ONLY_MAKEFILES, SHOW_CPPS, VS_PROJ, TRACE, VS_PAR, VS_PAR_NUM - global DOTNET_ENABLED, JAVA_ENABLED, STATIC_LIB, PREFIX, GMP, PYTHON_PACKAGE_DIR, GPROF, GIT_HASH + global DOTNET_ENABLED, JAVA_ENABLED, STATIC_LIB, PREFIX, GMP, FOCI2, FOCI2LIB, PYTHON_PACKAGE_DIR, GPROF, GIT_HASH try: options, remainder = getopt.gnu_getopt(sys.argv[1:], - 'b:dsxhmcvtnp:gj', + 'b:df:sxhmcvtnp:gj', ['build=', 'debug', 'silent', 'x64', 'help', 'makefiles', 'showcpp', 'vsproj', - 'trace', 'nodotnet', 'staticlib', 'prefix=', 'gmp', 'java', 'parallel=', 'gprof', + 'trace', 'nodotnet', 'staticlib', 'prefix=', 'gmp', 'foci2=', 'java', 'parallel=', 'gprof', 'githash=']) except: print("ERROR: Invalid command line option") display_help(1) for opt, arg in options: + print('opt = %s, arg = %s' % (opt, arg)) if opt in ('-b', '--build'): if arg == 'src': raise MKException('The src directory should not be used to host the Makefile') @@ -507,6 +520,9 @@ def parse_options(): VS_PAR_NUM = int(arg) elif opt in ('-g', '--gmp'): GMP = True + elif opt in ('-f', '--foci2'): + FOCI2 = True + FOCI2LIB = arg elif opt in ('-j', '--java'): JAVA_ENABLED = True elif opt == '--gprof': @@ -665,6 +681,10 @@ class Component: self.src_dir = os.path.join(SRC_DIR, path) self.to_src_dir = os.path.join(REV_BUILD_DIR, self.src_dir) + def get_link_name(self): + return os.path.join(self.build_dir, self.name) + '$(LIB_EXT)' + + # Find fname in the include paths for the given component. # ownerfile is only used for creating error messages. # That is, we were looking for fname when processing ownerfile @@ -880,7 +900,7 @@ class ExeComponent(Component): out.write(obj) for dep in deps: c_dep = get_component(dep) - out.write(' %s$(LIB_EXT)' % os.path.join(c_dep.build_dir, c_dep.name)) + out.write(' ' + c_dep.get_link_name()) out.write('\n') out.write('\t$(LINK) $(LINK_OUT_FLAG)%s $(LINK_FLAGS)' % exefile) for obj in objs: @@ -888,7 +908,8 @@ class ExeComponent(Component): out.write(obj) for dep in deps: c_dep = get_component(dep) - out.write(' %s$(LIB_EXT)' % os.path.join(c_dep.build_dir, c_dep.name)) + out.write(' ' + c_dep.get_link_name()) + out.write(' ' + FOCI2LIB) out.write(' $(LINK_EXTRA_FLAGS)\n') out.write('%s: %s\n\n' % (self.name, exefile)) @@ -957,6 +978,12 @@ class DLLComponent(Component): self.install = install self.static = static + def get_link_name(self): + if self.static: + return os.path.join(self.build_dir, self.name) + '$(LIB_EXT)' + else: + return self.name + '$(SO_EXT)' + def mk_makefile(self, out): Component.mk_makefile(self, out) # generate rule for (SO_EXT) @@ -979,7 +1006,7 @@ class DLLComponent(Component): for dep in deps: if not dep in self.reexports: c_dep = get_component(dep) - out.write(' %s$(LIB_EXT)' % os.path.join(c_dep.build_dir, c_dep.name)) + out.write(' ' + c_dep.get_link_name()) out.write('\n') out.write('\t$(LINK) $(SLINK_OUT_FLAG)%s $(SLINK_FLAGS)' % dllfile) for obj in objs: @@ -988,7 +1015,8 @@ class DLLComponent(Component): for dep in deps: if not dep in self.reexports: c_dep = get_component(dep) - out.write(' %s$(LIB_EXT)' % os.path.join(c_dep.build_dir, c_dep.name)) + out.write(' ' + c_dep.get_link_name()) + out.write(' ' + FOCI2LIB) out.write(' $(SLINK_EXTRA_FLAGS)') if IS_WINDOWS: out.write(' /DEF:%s.def' % os.path.join(self.to_src_dir, self.name)) @@ -1230,7 +1258,7 @@ class CppExampleComponent(ExampleComponent): out.write(' ') out.write(os.path.join(self.to_ex_dir, cppfile)) out.write('\n') - out.write('\t%s $(OS_DEFINES) $(LINK_OUT_FLAG)%s $(LINK_FLAGS)' % (self.compiler(), exefile)) + out.write('\t%s $(OS_DEFINES) $(EXAMP_DEBUG_FLAG) $(LINK_OUT_FLAG)%s $(LINK_FLAGS)' % (self.compiler(), exefile)) # Add include dir components out.write(' -I%s' % get_component(API_COMPONENT).to_src_dir) out.write(' -I%s' % get_component(CPP_COMPONENT).to_src_dir) @@ -1450,7 +1478,7 @@ def mk_config(): print('JNI Bindings: %s' % JNI_HOME) print('Java Compiler: %s' % JAVAC) else: - global CXX, CC, GMP, CPPFLAGS, CXXFLAGS, LDFLAGS + global CXX, CC, GMP, FOCI2, CPPFLAGS, CXXFLAGS, LDFLAGS, EXAMP_DEBUG_FLAG OS_DEFINES = "" ARITH = "internal" check_ar() @@ -1468,6 +1496,14 @@ def mk_config(): SLIBEXTRAFLAGS = '%s -lgmp' % SLIBEXTRAFLAGS else: CPPFLAGS = '%s -D_MP_INTERNAL' % CPPFLAGS + if FOCI2: + if test_foci2(CXX,FOCI2LIB): + LDFLAGS = '%s %s' % (LDFLAGS,FOCI2LIB) + SLIBEXTRAFLAGS = '%s %s' % (SLIBEXTRAFLAGS,FOCI2LIB) + CPPFLAGS = '%s -D_FOCI2' % CPPFLAGS + else: + print "FAILED\n" + FOCI2 = False if GIT_HASH: CPPFLAGS = '%s -DZ3GITHASH=%s' % (CPPFLAGS, GIT_HASH) CXXFLAGS = '%s -c' % CXXFLAGS @@ -1480,6 +1516,7 @@ def mk_config(): CXXFLAGS = '%s -D_NO_OMP_' % CXXFLAGS if DEBUG_MODE: CXXFLAGS = '%s -g -Wall' % CXXFLAGS + EXAMP_DEBUG_FLAG = '-g' else: if GPROF: CXXFLAGS = '%s -O3 -D _EXTERNAL_RELEASE' % CXXFLAGS @@ -1526,6 +1563,7 @@ def mk_config(): config.write('CC=%s\n' % CC) config.write('CXX=%s\n' % CXX) config.write('CXXFLAGS=%s %s\n' % (CPPFLAGS, CXXFLAGS)) + config.write('EXAMP_DEBUG_FLAG=%s\n' % EXAMP_DEBUG_FLAG) config.write('CXX_OUT_FLAG=-o \n') config.write('OBJ_EXT=.o\n') config.write('LIB_EXT=.a\n') diff --git a/src/api/api_ast.cpp b/src/api/api_ast.cpp index f116f6868..1f25adc2c 100644 --- a/src/api/api_ast.cpp +++ b/src/api/api_ast.cpp @@ -209,6 +209,7 @@ extern "C" { MK_BINARY(Z3_mk_xor, mk_c(c)->get_basic_fid(), OP_XOR, SKIP); MK_NARY(Z3_mk_and, mk_c(c)->get_basic_fid(), OP_AND, SKIP); MK_NARY(Z3_mk_or, mk_c(c)->get_basic_fid(), OP_OR, SKIP); + MK_UNARY(Z3_mk_interp, mk_c(c)->get_basic_fid(), OP_INTERP, SKIP); Z3_ast mk_ite_core(Z3_context c, Z3_ast t1, Z3_ast t2, Z3_ast t3) { expr * result = mk_c(c)->m().mk_ite(to_expr(t1), to_expr(t2), to_expr(t3)); @@ -923,6 +924,7 @@ extern "C" { case OP_NOT: return Z3_OP_NOT; case OP_IMPLIES: return Z3_OP_IMPLIES; case OP_OEQ: return Z3_OP_OEQ; + case OP_INTERP: return Z3_OP_INTERP; case PR_UNDEF: return Z3_OP_PR_UNDEF; case PR_TRUE: return Z3_OP_PR_TRUE; diff --git a/src/api/api_interp.cpp b/src/api/api_interp.cpp new file mode 100644 index 000000000..560ee5885 --- /dev/null +++ b/src/api/api_interp.cpp @@ -0,0 +1,622 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + api_interp.cpp + +Abstract: + API for interpolation + +Author: + + Ken McMillan + +Revision History: + +--*/ +#include +#include +#include"z3.h" +#include"api_log_macros.h" +#include"api_context.h" +#include"api_tactic.h" +#include"api_solver.h" +#include"api_model.h" +#include"api_stats.h" +#include"api_ast_vector.h" +#include"tactic2solver.h" +#include"scoped_ctrl_c.h" +#include"cancel_eh.h" +#include"scoped_timer.h" +#include"smt_strategic_solver.h" +#include"smt_solver.h" +#include"smt_implied_equalities.h" +#include"iz3interp.h" +#include"iz3profiling.h" +#include"iz3hash.h" +#include"iz3pp.h" +#include"iz3checker.h" + +#ifndef WIN32 +using namespace stl_ext; +#endif + +#ifndef WIN32 +// WARNING: don't make a hash_map with this if the range type +// has a destructor: you'll get an address dependency!!! +namespace stl_ext { + template <> + class hash { + public: + size_t operator()(const Z3_ast p) const { + return (size_t) p; + } + }; +} +#endif + +typedef interpolation_options_struct *Z3_interpolation_options; + +extern "C" { + + Z3_context Z3_mk_interpolation_context(Z3_config cfg){ + if(!cfg) cfg = Z3_mk_config(); + Z3_set_param_value(cfg, "PROOF", "true"); + Z3_set_param_value(cfg, "MODEL", "true"); + // Z3_set_param_value(cfg, "PRE_SIMPLIFIER","false"); + // Z3_set_param_value(cfg, "SIMPLIFY_CLAUSES","false"); + + Z3_context ctx = Z3_mk_context(cfg); + Z3_del_config(cfg); + return ctx; + } + + void Z3_interpolate_proof(Z3_context ctx, + Z3_ast proof, + int num, + Z3_ast *cnsts, + unsigned *parents, + Z3_params options, + Z3_ast *interps, + int num_theory, + Z3_ast *theory + ){ + + if(num > 1){ // if we have interpolants to compute + + ptr_vector pre_cnsts_vec(num); // get constraints in a vector + for(int i = 0; i < num; i++){ + ast *a = to_ast(cnsts[i]); + pre_cnsts_vec[i] = a; + } + + ::vector pre_parents_vec; // get parents in a vector + if(parents){ + pre_parents_vec.resize(num); + for(int i = 0; i < num; i++) + pre_parents_vec[i] = parents[i]; + } + + ptr_vector theory_vec; // get background theory in a vector + if(theory){ + theory_vec.resize(num_theory); + for(int i = 0; i < num_theory; i++) + theory_vec[i] = to_ast(theory[i]); + } + + ptr_vector interpolants(num-1); // make space for result + + ast_manager &_m = mk_c(ctx)->m(); + iz3interpolate(_m, + to_ast(proof), + pre_cnsts_vec, + pre_parents_vec, + interpolants, + theory_vec, + 0); // ignore params for now FIXME + + // copy result back + for(unsigned i = 0; i < interpolants.size(); i++){ + mk_c(ctx)->save_ast_trail(interpolants[i]); + interps[i] = of_ast(interpolants[i]); + _m.dec_ref(interpolants[i]); + } + } + } + + Z3_lbool Z3_interpolate(Z3_context ctx, + int num, + Z3_ast *cnsts, + unsigned *parents, + Z3_params options, + Z3_ast *interps, + Z3_model *model, + Z3_literals *labels, + int incremental, + int num_theory, + Z3_ast *theory + ){ + + + profiling::timer_start("Solve"); + + if(!incremental){ + + profiling::timer_start("Z3 assert"); + + Z3_push(ctx); // so we can rewind later + + for(int i = 0; i < num; i++) + Z3_assert_cnstr(ctx,cnsts[i]); // assert all the constraints + + if(theory){ + for(int i = 0; i < num_theory; i++) + Z3_assert_cnstr(ctx,theory[i]); + } + + profiling::timer_stop("Z3 assert"); + } + + + // Get a proof of unsat + + Z3_ast proof; + Z3_lbool result; + + profiling::timer_start("Z3 solving"); + result = Z3_check_assumptions(ctx, 0, 0, model, &proof, 0, 0); + profiling::timer_stop("Z3 solving"); + + switch (result) { + case Z3_L_FALSE: + + Z3_interpolate_proof(ctx, + proof, + num, + cnsts, + parents, + options, + interps, + num_theory, + theory); + + if(!incremental) + for(int i = 0; i < num-1; i++) + Z3_persist_ast(ctx,interps[i],1); + break; + + case Z3_L_UNDEF: + if(labels) + *labels = Z3_get_relevant_labels(ctx); + break; + + case Z3_L_TRUE: + if(labels) + *labels = Z3_get_relevant_labels(ctx); + break; + } + + profiling::timer_start("Z3 pop"); + if(!incremental) + Z3_pop(ctx,1); + profiling::timer_stop("Z3 pop"); + + profiling::timer_stop("Solve"); + + return result; + + } + + static std::ostringstream itp_err; + + int Z3_check_interpolant(Z3_context ctx, + int num, + Z3_ast *cnsts, + int *parents, + Z3_ast *itp, + const char **error, + int num_theory, + Z3_ast *theory){ + + ast_manager &_m = mk_c(ctx)->m(); + itp_err.clear(); + + // need a solver -- make one here, but how? + params_ref p = params_ref::get_empty(); //FIXME + scoped_ptr sf(mk_smt_solver_factory()); + scoped_ptr sp((*(sf))(_m, p, false, true, false, symbol("AUFLIA"))); + + ptr_vector cnsts_vec(num); // get constraints in a vector + for(int i = 0; i < num; i++){ + ast *a = to_ast(cnsts[i]); + cnsts_vec[i] = a; + } + + ptr_vector itp_vec(num); // get interpolants in a vector + for(int i = 0; i < num-1; i++){ + ast *a = to_ast(itp[i]); + itp_vec[i] = a; + } + + ::vector parents_vec; // get parents in a vector + if(parents){ + parents_vec.resize(num); + for(int i = 0; i < num; i++) + parents_vec[i] = parents[i]; + } + + ptr_vector theory_vec; // get background theory in a vector + if(theory){ + theory_vec.resize(num_theory); + for(int i = 0; i < num_theory; i++) + theory_vec[i] = to_ast(theory[i]); + } + + bool res = iz3check(_m, + sp.get(), + itp_err, + cnsts_vec, + parents_vec, + itp_vec, + theory_vec); + + *error = res ? 0 : itp_err.str().c_str(); + return res; + } + + + static std::string Z3_profile_string; + + Z3_string Z3_interpolation_profile(Z3_context ctx){ + std::ostringstream f; + profiling::print(f); + Z3_profile_string = f.str(); + return Z3_profile_string.c_str(); + } + + + Z3_interpolation_options + Z3_mk_interpolation_options(){ + return (Z3_interpolation_options) new interpolation_options_struct; + } + + void + Z3_del_interpolation_options(Z3_interpolation_options opts){ + delete opts; + } + + void + Z3_set_interpolation_option(Z3_interpolation_options opts, + Z3_string name, + Z3_string value){ + opts->map[name] = value; + } + + + +}; + + +static void tokenize(const std::string &str, std::vector &tokens){ + for(unsigned i = 0; i < str.size();){ + if(str[i] == ' '){i++; continue;} + unsigned beg = i; + while(i < str.size() && str[i] != ' ')i++; + if(i > beg) + tokens.push_back(str.substr(beg,i-beg)); + } +} + +static void get_file_params(const char *filename, hash_map ¶ms){ + std::ifstream f(filename); + if(f){ + std::string first_line; + std::getline(f,first_line); + // std::cout << "first line: '" << first_line << "'" << std::endl; + if(first_line.size() >= 2 && first_line[0] == ';' && first_line[1] == '!'){ + std::vector tokens; + tokenize(first_line.substr(2,first_line.size()-2),tokens); + for(unsigned i = 0; i < tokens.size(); i++){ + std::string &tok = tokens[i]; + size_t eqpos = tok.find('='); + if(eqpos >= 0 && eqpos < tok.size()){ + std::string left = tok.substr(0,eqpos); + std::string right = tok.substr(eqpos+1,tok.size()-eqpos-1); + params[left] = right; + } + } + } + f.close(); + } +} + +extern "C" { + +#if 0 + static void iZ3_write_seq(Z3_context ctx, int num, Z3_ast *cnsts, const char *filename, int num_theory, Z3_ast *theory){ + int num_fmlas = num+num_theory; + std::vector fmlas(num_fmlas); + if(num_theory) + std::copy(theory,theory+num_theory,fmlas.begin()); + for(int i = 0; i < num_theory; i++) + fmlas[i] = Z3_mk_implies(ctx,Z3_mk_true(ctx),fmlas[i]); + std::copy(cnsts,cnsts+num,fmlas.begin()+num_theory); + Z3_string smt = Z3_benchmark_to_smtlib_string(ctx,"none","AUFLIA","unknown","",num_fmlas-1,&fmlas[0],fmlas[num_fmlas-1]); + std::ofstream f(filename); + if(num_theory) + f << ";! THEORY=" << num_theory << "\n"; + f << smt; + f.close(); + } + + void Z3_write_interpolation_problem(Z3_context ctx, int num, Z3_ast *cnsts, int *parents, const char *filename, int num_theory, Z3_ast *theory){ + if(!parents){ + iZ3_write_seq(ctx,num,cnsts,filename,num_theory,theory); + return; + } + std::vector tcnsts(num); + hash_map syms; + for(int j = 0; j < num - 1; j++){ + std::ostringstream oss; + oss << "$P" << j; + std::string name = oss.str(); + Z3_symbol s = Z3_mk_string_symbol(ctx, name.c_str()); + Z3_ast symbol = Z3_mk_const(ctx, s, Z3_mk_bool_sort(ctx)); + syms[j] = symbol; + tcnsts[j] = Z3_mk_implies(ctx,cnsts[j],symbol); + } + tcnsts[num-1] = Z3_mk_implies(ctx,cnsts[num-1],Z3_mk_false(ctx)); + for(int j = num-2; j >= 0; j--){ + int parent = parents[j]; + // assert(parent >= 0 && parent < num); + tcnsts[parent] = Z3_mk_implies(ctx,syms[j],tcnsts[parent]); + } + iZ3_write_seq(ctx,num,&tcnsts[0],filename,num_theory,theory); + } +#else + + + static Z3_ast and_vec(Z3_context ctx,svector &c){ + return (c.size() > 1) ? Z3_mk_and(ctx,c.size(),&c[0]) : c[0]; + } + + static Z3_ast parents_vector_to_tree(Z3_context ctx, int num, Z3_ast *cnsts, int *parents){ + Z3_ast res; + if(!parents){ + res = Z3_mk_interp(ctx,cnsts[0]); + for(int i = 1; i < num-1; i++){ + Z3_ast bar[2] = {res,cnsts[i]}; + res = Z3_mk_interp(ctx,Z3_mk_and(ctx,2,bar)); + } + if(num > 1){ + Z3_ast bar[2] = {res,cnsts[num-1]}; + res = Z3_mk_and(ctx,2,bar); + } + } + else { + std::vector > chs(num); + for(int i = 0; i < num-1; i++){ + svector &c = chs[i]; + c.push_back(cnsts[i]); + Z3_ast foo = Z3_mk_interp(ctx,and_vec(ctx,c)); + chs[parents[i]].push_back(foo); + } + { + svector &c = chs[num-1]; + c.push_back(cnsts[num-1]); + res = and_vec(ctx,c); + } + } + Z3_inc_ref(ctx,res); + return res; + } + + void Z3_write_interpolation_problem(Z3_context ctx, int num, Z3_ast *cnsts, int *parents, const char *filename, int num_theory, Z3_ast *theory){ + std::ofstream f(filename); + if(num > 0){ + ptr_vector cnsts_vec(num); // get constraints in a vector + for(int i = 0; i < num; i++){ + expr *a = to_expr(cnsts[i]); + cnsts_vec[i] = a; + } + Z3_ast tree = parents_vector_to_tree(ctx,num,cnsts,parents); + for(int i = 0; i < num_theory; i++){ + expr *a = to_expr(theory[i]); + cnsts_vec.push_back(a); + } + iz3pp(mk_c(ctx)->m(),cnsts_vec,to_expr(tree),f); + Z3_dec_ref(ctx,tree); + } + f.close(); + +#if 0 + + + if(!parents){ + iZ3_write_seq(ctx,num,cnsts,filename,num_theory,theory); + return; + } + std::vector tcnsts(num); + hash_map syms; + for(int j = 0; j < num - 1; j++){ + std::ostringstream oss; + oss << "$P" << j; + std::string name = oss.str(); + Z3_symbol s = Z3_mk_string_symbol(ctx, name.c_str()); + Z3_ast symbol = Z3_mk_const(ctx, s, Z3_mk_bool_sort(ctx)); + syms[j] = symbol; + tcnsts[j] = Z3_mk_implies(ctx,cnsts[j],symbol); + } + tcnsts[num-1] = Z3_mk_implies(ctx,cnsts[num-1],Z3_mk_false(ctx)); + for(int j = num-2; j >= 0; j--){ + int parent = parents[j]; + // assert(parent >= 0 && parent < num); + tcnsts[parent] = Z3_mk_implies(ctx,syms[j],tcnsts[parent]); + } + iZ3_write_seq(ctx,num,&tcnsts[0],filename,num_theory,theory); +#endif + + } + + +#endif + + static std::vector read_cnsts; + static std::vector read_parents; + static std::ostringstream read_error; + static std::string read_msg; + static std::vector read_theory; + + static bool iZ3_parse(Z3_context ctx, const char *filename, const char **error, svector &assertions){ + read_error.clear(); + try { + std::string foo(filename); + if(foo.size() >= 5 && foo.substr(foo.size()-5) == ".smt2"){ + Z3_ast ass = Z3_parse_smtlib2_file(ctx, filename, 0, 0, 0, 0, 0, 0); + Z3_app app = Z3_to_app(ctx,ass); + int nconjs = Z3_get_app_num_args(ctx,app); + assertions.resize(nconjs); + for(int k = 0; k < nconjs; k++) + assertions[k] = Z3_get_app_arg(ctx,app,k); + } + else { + Z3_parse_smtlib_file(ctx, filename, 0, 0, 0, 0, 0, 0); + int numa = Z3_get_smtlib_num_assumptions(ctx); + int numf = Z3_get_smtlib_num_formulas(ctx); + int num = numa + numf; + + assertions.resize(num); + for(int j = 0; j < num; j++){ + if(j < numa) + assertions[j] = Z3_get_smtlib_assumption(ctx,j); + else + assertions[j] = Z3_get_smtlib_formula(ctx,j-numa); + } + } + } + catch(...) { + read_error << "SMTLIB parse error: " << Z3_get_smtlib_error(ctx); + read_msg = read_error.str(); + *error = read_msg.c_str(); + return false; + } + Z3_set_error_handler(ctx, 0); + return true; + } + + + int Z3_read_interpolation_problem(Z3_context ctx, int *_num, Z3_ast **cnsts, int **parents, const char *filename, const char **error, int *ret_num_theory, Z3_ast **theory ){ + + hash_map file_params; + get_file_params(filename,file_params); + + unsigned num_theory = 0; + if(file_params.find("THEORY") != file_params.end()) + num_theory = atoi(file_params["THEORY"].c_str()); + + svector assertions; + if(!iZ3_parse(ctx,filename,error,assertions)) + return false; + + if(num_theory > assertions.size()) + num_theory = assertions.size(); + unsigned num = assertions.size() - num_theory; + + read_cnsts.resize(num); + read_parents.resize(num); + read_theory.resize(num_theory); + + for(unsigned j = 0; j < num_theory; j++) + read_theory[j] = assertions[j]; + for(unsigned j = 0; j < num; j++) + read_cnsts[j] = assertions[j+num_theory]; + + if(ret_num_theory) + *ret_num_theory = num_theory; + if(theory) + *theory = &read_theory[0]; + + if(!parents){ + *_num = num; + *cnsts = &read_cnsts[0]; + return true; + } + + for(unsigned j = 0; j < num; j++) + read_parents[j] = SHRT_MAX; + + hash_map pred_map; + + for(unsigned j = 0; j < num; j++){ + Z3_ast lhs = 0, rhs = read_cnsts[j]; + + if(Z3_get_decl_kind(ctx,Z3_get_app_decl(ctx,Z3_to_app(ctx,rhs))) == Z3_OP_IMPLIES){ + Z3_app app1 = Z3_to_app(ctx,rhs); + Z3_ast lhs1 = Z3_get_app_arg(ctx,app1,0); + Z3_ast rhs1 = Z3_get_app_arg(ctx,app1,1); + if(Z3_get_decl_kind(ctx,Z3_get_app_decl(ctx,Z3_to_app(ctx,lhs1))) == Z3_OP_AND){ + Z3_app app2 = Z3_to_app(ctx,lhs1); + int nconjs = Z3_get_app_num_args(ctx,app2); + for(int k = nconjs - 1; k >= 0; --k) + rhs1 = Z3_mk_implies(ctx,Z3_get_app_arg(ctx,app2,k),rhs1); + rhs = rhs1; + } + } + + while(1){ + Z3_app app = Z3_to_app(ctx,rhs); + Z3_func_decl func = Z3_get_app_decl(ctx,app); + Z3_decl_kind dk = Z3_get_decl_kind(ctx,func); + if(dk == Z3_OP_IMPLIES){ + if(lhs){ + Z3_ast child = lhs; + if(pred_map.find(child) == pred_map.end()){ + read_error << "formula " << j+1 << ": unknown: " << Z3_ast_to_string(ctx,child); + goto fail; + } + int child_num = pred_map[child]; + if(read_parents[child_num] != SHRT_MAX){ + read_error << "formula " << j+1 << ": multiple reference: " << Z3_ast_to_string(ctx,child); + goto fail; + } + read_parents[child_num] = j; + } + lhs = Z3_get_app_arg(ctx,app,0); + rhs = Z3_get_app_arg(ctx,app,1); + } + else { + if(!lhs){ + read_error << "formula " << j+1 << ": should be (implies {children} fmla parent)"; + goto fail; + } + read_cnsts[j] = lhs; + Z3_ast name = rhs; + if(pred_map.find(name) != pred_map.end()){ + read_error << "formula " << j+1 << ": duplicate symbol"; + goto fail; + } + pred_map[name] = j; + break; + } + } + } + + for(unsigned j = 0; j < num-1; j++) + if(read_parents[j] == SHRT_MIN){ + read_error << "formula " << j+1 << ": unreferenced"; + goto fail; + } + + *_num = num; + *cnsts = &read_cnsts[0]; + *parents = &read_parents[0]; + return true; + + fail: + read_msg = read_error.str(); + *error = read_msg.c_str(); + return false; + + } +} diff --git a/src/api/dotnet/Properties/AssemblyInfo b/src/api/dotnet/Properties/AssemblyInfo index 1ac6cb520..ee0e3f354 100644 --- a/src/api/dotnet/Properties/AssemblyInfo +++ b/src/api/dotnet/Properties/AssemblyInfo @@ -36,4 +36,3 @@ using System.Security.Permissions; // [assembly: AssemblyVersion("4.2.0.0")] [assembly: AssemblyVersion("4.3.2.0")] [assembly: AssemblyFileVersion("4.3.2.0")] - diff --git a/src/api/java/BitVecNum.java b/src/api/java/BitVecNum.java index 7615308b7..2dd0dd75a 100644 --- a/src/api/java/BitVecNum.java +++ b/src/api/java/BitVecNum.java @@ -35,7 +35,7 @@ public class BitVecNum extends BitVecExpr { Native.LongPtr res = new Native.LongPtr(); if (Native.getNumeralInt64(getContext().nCtx(), getNativeObject(), res) ^ true) - throw new Z3Exception("Numeral is not an int64"); + throw new Z3Exception("Numeral is not a long"); return res.value; } diff --git a/src/api/python/z3.py b/src/api/python/z3.py index 22ec30ad2..959566619 100644 --- a/src/api/python/z3.py +++ b/src/api/python/z3.py @@ -6424,7 +6424,7 @@ class Tactic: _z3_assert(isinstance(goal, Goal) or isinstance(goal, BoolRef), "Z3 Goal or Boolean expressions expected") goal = _to_goal(goal) if len(arguments) > 0 or len(keywords) > 0: - p = args2params(arguments, keywords, a.ctx) + p = args2params(arguments, keywords, self.ctx) return ApplyResult(Z3_tactic_apply_ex(self.ctx.ref(), self.tactic, goal.goal, p.params), self.ctx) else: return ApplyResult(Z3_tactic_apply(self.ctx.ref(), self.tactic, goal.goal), self.ctx) diff --git a/src/api/z3_api.h b/src/api/z3_api.h index 58bbe0fe9..64476c1a9 100644 --- a/src/api/z3_api.h +++ b/src/api/z3_api.h @@ -251,6 +251,8 @@ typedef enum - Z3_OP_OEQ Binary equivalence modulo namings. This binary predicate is used in proof terms. It captures equisatisfiability and equivalence modulo renamings. + - Z3_OP_INTERP Marks a sub-formula for interpolation. + - Z3_OP_ANUM Arithmetic numeral. - Z3_OP_AGNUM Arithmetic algebraic numeral. Algebraic numbers are used to represent irrational numbers in Z3. @@ -901,6 +903,7 @@ typedef enum { Z3_OP_NOT, Z3_OP_IMPLIES, Z3_OP_OEQ, + Z3_OP_INTERP, // Arithmetic Z3_OP_ANUM = 0x200, @@ -2121,6 +2124,16 @@ END_MLAPI_EXCLUDE */ Z3_ast Z3_API Z3_mk_not(__in Z3_context c, __in Z3_ast a); + /** + \brief \mlh mk_interp c a \endmlh + Create an AST node marking a formula position for interpolation. + + The node \c a must have Boolean sort. + + def_API('Z3_mk_interp', AST, (_in(CONTEXT), _in(AST))) + */ + Z3_ast Z3_API Z3_mk_interp(__in Z3_context c, __in Z3_ast a); + /** \brief \mlh mk_ite c t1 t2 t2 \endmlh Create an AST node representing an if-then-else: ite(t1, t2, @@ -7863,6 +7876,211 @@ END_MLAPI_EXCLUDE /*@}*/ + /** + @name Interpolation + */ + /*@{*/ + + /** \brief This function generates a Z3 context suitable for generation of + interpolants. Formulas can be generated as abstract syntx trees in + this context using the Z3 C API. + + Interpolants are also generated as AST's in this context. + + If cfg is non-null, it will be used as the base configuration + for the Z3 context. This makes it possible to set Z3 options + to be used during interpolation. This feature should be used + with some caution however, as it may be that certain Z3 options + are incompatible with interpolation. + + def_API('Z3_mk_interpolation_context', CONTEXT, (_in(CONFIG),)) + + */ + + Z3_context Z3_API Z3_mk_interpolation_context(__in Z3_config cfg); + + +/** Constant reprepresenting a root of a formula tree for tree interpolation */ +#define IZ3_ROOT SHRT_MAX + +/** This function uses Z3 to determine satisfiability of a set of + constraints. If UNSAT, an interpolant is returned, based on the + refutation generated by Z3. If SAT, a model is returned. + + If "parents" is non-null, computes a tree interpolant. The tree is + defined by the array "parents". This array maps each formula in + the tree to its parent, where formulas are indicated by their + integer index in "cnsts". The parent of formula n must have index + greater than n. The last formula is the root of the tree. Its + parent entry should be the constant IZ3_ROOT. + + If "parents" is null, computes a sequence interpolant. + + \param ctx The Z3 context. Must be generated by iz3_mk_context + \param num The number of constraints in the sequence + \param cnsts Array of constraints (AST's in context ctx) + \param parents The parents vector defining the tree structure + \param options Interpolation options (may be NULL) + \param interps Array to return interpolants (size at least num-1, may be NULL) + \param model Returns a Z3 model if constraints SAT (may be NULL) + \param labels Returns relevant labels if SAT (may be NULL) + \param incremental + + VERY IMPORTANT: All the Z3 formulas in cnsts must be in Z3 + context ctx. The model and interpolants returned are also + in this context. + + The return code is as in Z3_check_assumptions, that is, + + Z3_L_FALSE = constraints UNSAT (interpolants returned) + Z3_L_TRUE = constraints SAT (model returned) + Z3_L_UNDEF = Z3 produced no result, or interpolation not possible + + Currently, this function supports integer and boolean variables, + as well as arrays over these types, with linear arithmetic, + uninterpreted functions and quantifiers over integers (that is + AUFLIA). Interpolants are produced in AUFLIA. However, some + uses of array operations may cause quantifiers to appear in the + interpolants even when there are no quantifiers in the input formulas. + Although quantifiers may appear in the input formulas, Z3 may give up in + this case, returning Z3_L_UNDEF. + + If "incremental" is true, cnsts must contain exactly the set of + formulas that are currently asserted in the context. If false, + there must be no formulas currently asserted in the context. + Setting "incremental" to true makes it posisble to incrementally + add and remove constraints from the context until the context + becomes UNSAT, at which point an interpolant is computed. Caution + must be used, however. Before popping the context, if you wish to + keep the interolant formulas, you *must* preserve them by using + Z3_persist_ast. Also, if you want to simplify the interpolant + formulas using Z3_simplify, you must first pop all of the + assertions in the context (or use a different context). Otherwise, + the formulas will be simplified *relative* to these constraints, + which is almost certainly not what you want. + + + Current limitations on tree interpolants. In a tree interpolation + problem, each constant (0-ary function symbol) must occur only + along one path from root to leaf. Function symbols (of arity > 0) + are considered to have global scope (i.e., may appear in any + interpolant formula). + + + */ + + + Z3_lbool Z3_API Z3_interpolate(__in Z3_context ctx, + __in int num, + __in_ecount(num) Z3_ast *cnsts, + __in_ecount(num) unsigned *parents, + __in Z3_params options, + __out_ecount(num-1) Z3_ast *interps, + __out Z3_model *model, + __out Z3_literals *labels, + __in int incremental, + __in int num_theory, + __in_ecount(num_theory) Z3_ast *theory); + + /** Return a string summarizing cumulative time used for + interpolation. This string is purely for entertainment purposes + and has no semantics. + + \param ctx The context (currently ignored) + + def_API('Z3_interpolation_profile', STRING, (_in(CONTEXT),)) + */ + + Z3_string Z3_API Z3_interpolation_profile(__in Z3_context ctx); + + /** + \brief Read an interpolation problem from file. + + \param ctx The Z3 context. This resets the error handler of ctx. + \param filename The file name to read. + \param num Returns length of sequence. + \param cnsts Returns sequence of formulas (do not free) + \param parents Returns the parents vector (or NULL for sequence) + \param error Returns an error message in case of failure (do not free the string) + + Returns true on success. + + File formats: Currently two formats are supported, based on + SMT-LIB2. For sequence interpolants, the sequence of constraints is + represented by the sequence of "assert" commands in the file. + + For tree interpolants, one symbol of type bool is associated to + each vertex of the tree. For each vertex v there is an "assert" + of the form: + + (implies (and c1 ... cn f) v) + + where c1 .. cn are the children of v (which must precede v in the file) + and f is the formula assiciated to node v. The last formula in the + file is the root vertex, and is represented by the predicate "false". + + A solution to a tree interpolation problem can be thought of as a + valuation of the vertices that makes all the implications true + where each value is represented using the common symbols between + the formulas in the subtree and the remainder of the formulas. + + */ + + + int Z3_API Z3_read_interpolation_problem(__in Z3_context ctx, + __out int *num, + __out_ecount(*num) Z3_ast **cnsts, + __out_ecount(*num) int **parents, + __in const char *filename, + __out const char **error, + __out int *num_theory, + __out_ecount(*num_theory) Z3_ast **theory); + + + + /** Check the correctness of an interpolant. The Z3 context must + have no constraints asserted when this call is made. That means + that after interpolating, you must first fully pop the Z3 + context before calling this. See Z3_interpolate for meaning of parameters. + + \param ctx The Z3 context. Must be generated by Z3_mk_interpolation_context + \param num The number of constraints in the sequence + \param cnsts Array of constraints (AST's in context ctx) + \param parents The parents vector (or NULL for sequence) + \param interps The interpolant to check + \param error Returns an error message if interpolant incorrect (do not free the string) + + Return value is Z3_L_TRUE if interpolant is verified, Z3_L_FALSE if + incorrect, and Z3_L_UNDEF if unknown. + + */ + + int Z3_API Z3_check_interpolant(Z3_context ctx, int num, Z3_ast *cnsts, int *parents, Z3_ast *interps, const char **error, + int num_theory, Z3_ast *theory); + + /** Write an interpolation problem to file suitable for reading with + Z3_read_interpolation_problem. The output file is a sequence + of SMT-LIB2 format commands, suitable for reading with command-line Z3 + or other interpolating solvers. + + \param ctx The Z3 context. Must be generated by z3_mk_interpolation_context + \param num The number of constraints in the sequence + \param cnsts Array of constraints + \param parents The parents vector (or NULL for sequence) + \param filename The file name to write + + */ + + void Z3_API Z3_write_interpolation_problem(Z3_context ctx, + int num, + Z3_ast *cnsts, + int *parents, + const char *filename, + int num_theory, + Z3_ast *theory); + + + #endif diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 0555e5fc8..b000201b7 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -640,6 +640,7 @@ basic_decl_plugin::basic_decl_plugin(): m_iff_decl(0), m_xor_decl(0), m_not_decl(0), + m_interp_decl(0), m_implies_decl(0), m_proof_sort(0), @@ -863,6 +864,7 @@ void basic_decl_plugin::set_manager(ast_manager * m, family_id id) { m_iff_decl = mk_bool_op_decl("iff", OP_IFF, 2, false, true, false, false, true); m_xor_decl = mk_bool_op_decl("xor", OP_XOR, 2, true, true); m_not_decl = mk_bool_op_decl("not", OP_NOT, 1); + m_interp_decl = mk_bool_op_decl("interp", OP_INTERP, 1); m_implies_decl = mk_implies_decl(); m_proof_sort = m->mk_sort(symbol("Proof"), sort_info(id, PROOF_SORT)); @@ -887,6 +889,7 @@ void basic_decl_plugin::get_op_names(svector & op_names, symbol co op_names.push_back(builtin_name("or", OP_OR)); op_names.push_back(builtin_name("xor", OP_XOR)); op_names.push_back(builtin_name("not", OP_NOT)); + op_names.push_back(builtin_name("interp", OP_INTERP)); op_names.push_back(builtin_name("=>", OP_IMPLIES)); if (logic == symbol::null) { // user friendly aliases @@ -898,6 +901,7 @@ void basic_decl_plugin::get_op_names(svector & op_names, symbol co op_names.push_back(builtin_name("||", OP_OR)); op_names.push_back(builtin_name("equals", OP_EQ)); op_names.push_back(builtin_name("equiv", OP_IFF)); + op_names.push_back(builtin_name("@@", OP_INTERP)); } } @@ -918,6 +922,7 @@ void basic_decl_plugin::finalize() { DEC_REF(m_and_decl); DEC_REF(m_or_decl); DEC_REF(m_not_decl); + DEC_REF(m_interp_decl); DEC_REF(m_iff_decl); DEC_REF(m_xor_decl); DEC_REF(m_implies_decl); @@ -1016,6 +1021,7 @@ func_decl * basic_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters case OP_AND: return m_and_decl; case OP_OR: return m_or_decl; case OP_NOT: return m_not_decl; + case OP_INTERP: return m_interp_decl; case OP_IFF: return m_iff_decl; case OP_IMPLIES: return m_implies_decl; case OP_XOR: return m_xor_decl; @@ -1051,6 +1057,7 @@ func_decl * basic_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters case OP_AND: return m_and_decl; case OP_OR: return m_or_decl; case OP_NOT: return m_not_decl; + case OP_INTERP: return m_interp_decl; case OP_IFF: return m_iff_decl; case OP_IMPLIES: return m_implies_decl; case OP_XOR: return m_xor_decl; @@ -3146,4 +3153,14 @@ void scoped_mark::pop_scope(unsigned num_scopes) { } } +// Added by KLM for use in GDB +// show an expr_ref on stdout + +void prexpr(expr_ref &e){ + std::cout << mk_pp(e.get(), e.get_manager()) << std::endl; +} + +void ast_manager::show_id_gen(){ + std::cout << "id_gen: " << m_expr_id_gen.show_hash() << " " << m_decl_id_gen.show_hash() << "\n"; +} diff --git a/src/ast/ast.h b/src/ast/ast.h index ef4ab3b55..68f08e1ac 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -1006,7 +1006,7 @@ enum basic_sort_kind { }; enum basic_op_kind { - OP_TRUE, OP_FALSE, OP_EQ, OP_DISTINCT, OP_ITE, OP_AND, OP_OR, OP_IFF, OP_XOR, OP_NOT, OP_IMPLIES, OP_OEQ, LAST_BASIC_OP, + OP_TRUE, OP_FALSE, OP_EQ, OP_DISTINCT, OP_ITE, OP_AND, OP_OR, OP_IFF, OP_XOR, OP_NOT, OP_IMPLIES, OP_OEQ, OP_INTERP, LAST_BASIC_OP, PR_UNDEF, PR_TRUE, PR_ASSERTED, PR_GOAL, PR_MODUS_PONENS, PR_REFLEXIVITY, PR_SYMMETRY, PR_TRANSITIVITY, PR_TRANSITIVITY_STAR, PR_MONOTONICITY, PR_QUANT_INTRO, PR_DISTRIBUTIVITY, PR_AND_ELIM, PR_NOT_OR_ELIM, PR_REWRITE, PR_REWRITE_STAR, PR_PULL_QUANT, @@ -1028,6 +1028,7 @@ protected: func_decl * m_iff_decl; func_decl * m_xor_decl; func_decl * m_not_decl; + func_decl * m_interp_decl; func_decl * m_implies_decl; ptr_vector m_eq_decls; // cached eqs ptr_vector m_ite_decls; // cached ites @@ -1417,6 +1418,8 @@ protected: public: typedef expr_dependency_array_manager::ref expr_dependency_array; + void show_id_gen(); + protected: small_object_allocator m_alloc; family_manager m_family_manager; @@ -2000,6 +2003,7 @@ public: app * mk_distinct_expanded(unsigned num_args, expr * const * args); app * mk_true() { return m_true; } app * mk_false() { return m_false; } + app * mk_interp(expr * arg) { return mk_app(m_basic_family_id, OP_INTERP, arg); } func_decl* mk_and_decl() { sort* domain[2] = { m_bool_sort, m_bool_sort }; diff --git a/src/ast/float_decl_plugin.cpp b/src/ast/float_decl_plugin.cpp index 128dd2ce7..eb8aba9ae 100644 --- a/src/ast/float_decl_plugin.cpp +++ b/src/ast/float_decl_plugin.cpp @@ -475,6 +475,7 @@ void float_decl_plugin::get_op_names(svector & op_names, symbol co op_names.push_back(builtin_name("plusInfinity", OP_FLOAT_PLUS_INF)); op_names.push_back(builtin_name("minusInfinity", OP_FLOAT_MINUS_INF)); op_names.push_back(builtin_name("NaN", OP_FLOAT_NAN)); + op_names.push_back(builtin_name("roundNearestTiesToEven", OP_RM_NEAREST_TIES_TO_EVEN)); op_names.push_back(builtin_name("roundNearestTiesToAway", OP_RM_NEAREST_TIES_TO_AWAY)); op_names.push_back(builtin_name("roundTowardPositive", OP_RM_TOWARD_POSITIVE)); @@ -486,7 +487,7 @@ void float_decl_plugin::get_op_names(svector & op_names, symbol co op_names.push_back(builtin_name("/", OP_FLOAT_DIV)); op_names.push_back(builtin_name("*", OP_FLOAT_MUL)); - op_names.push_back(builtin_name("abs", OP_FLOAT_ABS)); + op_names.push_back(builtin_name("abs", OP_FLOAT_ABS)); op_names.push_back(builtin_name("remainder", OP_FLOAT_REM)); op_names.push_back(builtin_name("fusedMA", OP_FLOAT_FUSED_MA)); op_names.push_back(builtin_name("squareRoot", OP_FLOAT_SQRT)); @@ -515,6 +516,49 @@ void float_decl_plugin::get_op_names(svector & op_names, symbol co if (m_bv_plugin) op_names.push_back(builtin_name("asIEEEBV", OP_TO_IEEE_BV)); + + // We also support draft version 3 + op_names.push_back(builtin_name("fp", OP_TO_FLOAT)); + + op_names.push_back(builtin_name("RNE", OP_RM_NEAREST_TIES_TO_EVEN)); + op_names.push_back(builtin_name("RNA", OP_RM_NEAREST_TIES_TO_AWAY)); + op_names.push_back(builtin_name("RTP", OP_RM_TOWARD_POSITIVE)); + op_names.push_back(builtin_name("RTN", OP_RM_TOWARD_NEGATIVE)); + op_names.push_back(builtin_name("RTZ", OP_RM_TOWARD_ZERO)); + + op_names.push_back(builtin_name("fp.abs", OP_FLOAT_ABS)); + op_names.push_back(builtin_name("fp.neg", OP_FLOAT_UMINUS)); + op_names.push_back(builtin_name("fp.add", OP_FLOAT_ADD)); + op_names.push_back(builtin_name("fp.sub", OP_FLOAT_SUB)); + op_names.push_back(builtin_name("fp.mul", OP_FLOAT_MUL)); + op_names.push_back(builtin_name("fp.div", OP_FLOAT_DIV)); + op_names.push_back(builtin_name("fp.fma", OP_FLOAT_FUSED_MA)); + op_names.push_back(builtin_name("fp.sqrt", OP_FLOAT_SQRT)); + op_names.push_back(builtin_name("fp.rem", OP_FLOAT_REM)); + op_names.push_back(builtin_name("fp.eq", OP_FLOAT_EQ)); + op_names.push_back(builtin_name("fp.leq", OP_FLOAT_LE)); + op_names.push_back(builtin_name("fp.lt", OP_FLOAT_LT)); + op_names.push_back(builtin_name("fp.geq", OP_FLOAT_GE)); + op_names.push_back(builtin_name("fp.gt", OP_FLOAT_GT)); + op_names.push_back(builtin_name("fp.isNormal", OP_FLOAT_IS_NORMAL)); + op_names.push_back(builtin_name("fp.isSubnormal", OP_FLOAT_IS_SUBNORMAL)); + op_names.push_back(builtin_name("fp.isZero", OP_FLOAT_IS_ZERO)); + op_names.push_back(builtin_name("fp.isInfinite", OP_FLOAT_IS_INF)); + op_names.push_back(builtin_name("fp.isNaN", OP_FLOAT_IS_NAN)); + op_names.push_back(builtin_name("fp.min", OP_FLOAT_MIN)); + op_names.push_back(builtin_name("fp.max", OP_FLOAT_MAX)); + op_names.push_back(builtin_name("fp.convert", OP_TO_FLOAT)); + + if (m_bv_plugin) { + // op_names.push_back(builtin_name("fp.fromBv", OP_TO_FLOAT)); + // op_names.push_back(builtin_name("fp.fromUBv", OP_TO_FLOAT)); + // op_names.push_back(builtin_name("fp.fromSBv", OP_TO_FLOAT)); + // op_names.push_back(builtin_name("fp.toUBv", OP_TO_IEEE_BV)); + // op_names.push_back(builtin_name("fp.toSBv", OP_TO_IEEE_BV)); + } + + op_names.push_back(builtin_name("fp.fromReal", OP_TO_FLOAT)); + // op_names.push_back(builtin_name("fp.toReal", ?)); } void float_decl_plugin::get_sort_names(svector & sort_names, symbol const & logic) { diff --git a/src/ast/proof_checker/proof_checker.cpp b/src/ast/proof_checker/proof_checker.cpp index 85e0cc791..41c43b26c 100644 --- a/src/ast/proof_checker/proof_checker.cpp +++ b/src/ast/proof_checker/proof_checker.cpp @@ -479,7 +479,7 @@ bool proof_checker::check1_basic(proof* p, expr_ref_vector& side_conditions) { // otherwise t2 is also a quantifier. return true; } - UNREACHABLE(); + IF_VERBOSE(0, verbose_stream() << "does not match last rule: " << mk_pp(p, m) << "\n";); return false; } case PR_DER: { @@ -488,13 +488,12 @@ bool proof_checker::check1_basic(proof* p, expr_ref_vector& side_conditions) { match_fact(p, fact) && match_iff(fact.get(), t1, t2) && match_quantifier(t1, is_forall, decls1, body1) && - is_forall && - match_or(body1.get(), terms1)) { + is_forall) { // TBD: check that terms are set of equalities. // t2 is an instance of a predicate in terms1 return true; - } - UNREACHABLE(); + } + IF_VERBOSE(0, verbose_stream() << "does not match last rule: " << mk_pp(p, m) << "\n";); return false; } case PR_HYPOTHESIS: { @@ -832,7 +831,7 @@ bool proof_checker::check1_basic(proof* p, expr_ref_vector& side_conditions) { } else { IF_VERBOSE(0, verbose_stream() << "Could not establish complementarity for:\n" << - mk_pp(lit1, m) << "\n" << mk_pp(lit2, m) << "\n";); + mk_pp(lit1, m) << "\n" << mk_pp(lit2, m) << "\n" << mk_pp(p, m) << "\n";); } fmls[i] = premise1; } diff --git a/src/ast/simplifier/array_simplifier_plugin.cpp b/src/ast/simplifier/array_simplifier_plugin.cpp index bf9ca8ffe..065f8bd96 100644 --- a/src/ast/simplifier/array_simplifier_plugin.cpp +++ b/src/ast/simplifier/array_simplifier_plugin.cpp @@ -68,7 +68,7 @@ bool array_simplifier_plugin::reduce(func_decl * f, unsigned num_args, expr * co set_reduce_invoked(); if (m_presimp) return false; -#if _DEBUG +#if Z3DEBUG for (unsigned i = 0; i < num_args && i < f->get_arity(); ++i) { SASSERT(m_manager.get_sort(args[i]) == f->get_domain(i)); } diff --git a/src/cmd_context/basic_cmds.cpp b/src/cmd_context/basic_cmds.cpp index 9f18608d3..527230e16 100644 --- a/src/cmd_context/basic_cmds.cpp +++ b/src/cmd_context/basic_cmds.cpp @@ -240,6 +240,8 @@ protected: symbol m_produce_unsat_cores; symbol m_produce_models; symbol m_produce_assignments; + symbol m_produce_interpolants; + symbol m_check_interpolants; symbol m_regular_output_channel; symbol m_diagnostic_output_channel; symbol m_random_seed; @@ -253,7 +255,9 @@ protected: return s == m_print_success || s == m_print_warning || s == m_expand_definitions || s == m_interactive_mode || s == m_produce_proofs || s == m_produce_unsat_cores || - s == m_produce_models || s == m_produce_assignments || s == m_regular_output_channel || s == m_diagnostic_output_channel || + s == m_produce_models || s == m_produce_assignments || s == m_produce_interpolants || + s == m_check_interpolants || + s == m_regular_output_channel || s == m_diagnostic_output_channel || s == m_random_seed || s == m_verbosity || s == m_global_decls; } @@ -270,6 +274,8 @@ public: m_produce_unsat_cores(":produce-unsat-cores"), m_produce_models(":produce-models"), m_produce_assignments(":produce-assignments"), + m_produce_interpolants(":produce-interpolants"), + m_check_interpolants(":check-interpolants"), m_regular_output_channel(":regular-output-channel"), m_diagnostic_output_channel(":diagnostic-output-channel"), m_random_seed(":random-seed"), @@ -337,6 +343,13 @@ class set_option_cmd : public set_get_option_cmd { check_not_initialized(ctx, m_produce_proofs); ctx.set_produce_proofs(to_bool(value)); } + else if (m_option == m_produce_interpolants) { + check_not_initialized(ctx, m_produce_interpolants); + ctx.set_produce_interpolants(to_bool(value)); + } + else if (m_option == m_check_interpolants) { + ctx.set_check_interpolants(to_bool(value)); + } else if (m_option == m_produce_unsat_cores) { check_not_initialized(ctx, m_produce_unsat_cores); ctx.set_produce_unsat_cores(to_bool(value)); @@ -485,6 +498,9 @@ public: else if (opt == m_produce_proofs) { print_bool(ctx, ctx.produce_proofs()); } + else if (opt == m_produce_interpolants) { + print_bool(ctx, ctx.produce_interpolants()); + } else if (opt == m_produce_unsat_cores) { print_bool(ctx, ctx.produce_unsat_cores()); } diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index 47575be7e..079597fdc 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -39,6 +39,7 @@ Notes: #include"model_evaluator.h" #include"for_each_expr.h" #include"scoped_timer.h" +#include"interpolant_cmds.h" func_decls::func_decls(ast_manager & m, func_decl * f): m_decls(TAG(func_decl*, f, 0)) { @@ -324,6 +325,7 @@ cmd_context::cmd_context(bool main_ctx, ast_manager * m, symbol const & l): install_basic_cmds(*this); install_ext_basic_cmds(*this); install_core_tactic_cmds(*this); + install_interpolant_cmds(*this); SASSERT(m != 0 || !has_manager()); if (m) init_external_manager(); @@ -381,6 +383,19 @@ void cmd_context::set_produce_proofs(bool f) { m_params.m_proof = f; } +void cmd_context::set_produce_interpolants(bool f) { + // can only be set before initialization + // FIXME currently synonym for produce_proofs + // also sets the default solver to be simple smt + SASSERT(!has_manager()); + m_params.m_proof = f; + // set_solver_factory(mk_smt_solver_factory()); +} + +void cmd_context::set_check_interpolants(bool f) { + m_params.m_check_interpolants = f; +} + bool cmd_context::produce_models() const { return m_params.m_model; } @@ -389,6 +404,15 @@ bool cmd_context::produce_proofs() const { return m_params.m_proof; } +bool cmd_context::produce_interpolants() const { + // FIXME currently synonym for produce_proofs + return m_params.m_proof; +} + +bool cmd_context::check_interpolants() const { + return m_params.m_check_interpolants; +} + bool cmd_context::produce_unsat_cores() const { return m_params.m_unsat_core; } @@ -1458,11 +1482,27 @@ void cmd_context::validate_model() { } } +// FIXME: really interpolants_enabled ought to be a parameter to the solver factory, +// but this is a global change, so for now, we use an alternate solver factory +// for interpolation + void cmd_context::mk_solver() { bool proofs_enabled, models_enabled, unsat_core_enabled; params_ref p; m_params.get_solver_params(m(), p, proofs_enabled, models_enabled, unsat_core_enabled); - m_solver = (*m_solver_factory)(m(), p, proofs_enabled, models_enabled, unsat_core_enabled, m_logic); + if(produce_interpolants()){ + SASSERT(m_interpolating_solver_factory); + m_solver = (*m_interpolating_solver_factory)(m(), p, true /* must have proofs */, models_enabled, unsat_core_enabled, m_logic); + } + else + m_solver = (*m_solver_factory)(m(), p, proofs_enabled, models_enabled, unsat_core_enabled, m_logic); +} + + + +void cmd_context::set_interpolating_solver_factory(solver_factory * f) { + SASSERT(!has_manager()); + m_interpolating_solver_factory = f; } void cmd_context::set_solver_factory(solver_factory * f) { diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index 11828c4c6..5f24ff808 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -186,6 +186,7 @@ protected: svector m_scopes; scoped_ptr m_solver_factory; + scoped_ptr m_interpolating_solver_factory; ref m_solver; ref m_check_sat_result; @@ -252,6 +253,8 @@ public: void cancel() { set_cancel(true); } void reset_cancel() { set_cancel(false); } context_params & params() { return m_params; } + solver_factory &get_solver_factory() { return *m_solver_factory; } + solver_factory &get_interpolating_solver_factory() { return *m_interpolating_solver_factory; } void global_params_updated(); // this method should be invoked when global (and module) params are updated. bool set_logic(symbol const & s); bool has_logic() const { return m_logic != symbol::null; } @@ -274,12 +277,16 @@ public: void set_random_seed(unsigned s) { m_random_seed = s; } bool produce_models() const; bool produce_proofs() const; + bool produce_interpolants() const; + bool check_interpolants() const; bool produce_unsat_cores() const; bool well_sorted_check_enabled() const; bool validate_model_enabled() const; void set_produce_models(bool flag); void set_produce_unsat_cores(bool flag); void set_produce_proofs(bool flag); + void set_produce_interpolants(bool flag); + void set_check_interpolants(bool flag); bool produce_assignments() const { return m_produce_assignments; } void set_produce_assignments(bool flag) { m_produce_assignments = flag; } void set_status(status st) { m_status = st; } @@ -293,6 +300,7 @@ public: sexpr_manager & sm() const { if (!m_sexpr_manager) const_cast(this)->m_sexpr_manager = alloc(sexpr_manager); return *m_sexpr_manager; } void set_solver_factory(solver_factory * s); + void set_interpolating_solver_factory(solver_factory * s); void set_check_sat_result(check_sat_result * r) { m_check_sat_result = r; } check_sat_result * get_check_sat_result() const { return m_check_sat_result.get(); } check_sat_state cs_state() const; diff --git a/src/cmd_context/context_params.cpp b/src/cmd_context/context_params.cpp index 2752967dd..0ec44e0cf 100644 --- a/src/cmd_context/context_params.cpp +++ b/src/cmd_context/context_params.cpp @@ -61,6 +61,9 @@ void context_params::set(char const * param, char const * value) { else if (p == "proof") { set_bool(m_proof, param, value); } + else if (p == "check_interpolants") { + set_bool(m_check_interpolants, param, value); + } else if (p == "model") { set_bool(m_model, param, value); } @@ -96,6 +99,7 @@ void context_params::updt_params(params_ref const & p) { m_well_sorted_check = p.get_bool("type_check", p.get_bool("well_sorted_check", true)); m_auto_config = p.get_bool("auto_config", true); m_proof = p.get_bool("proof", false); + m_check_interpolants = p.get_bool("check_interpolants", false); m_model = p.get_bool("model", true); m_model_validate = p.get_bool("model_validate", false); m_trace = p.get_bool("trace", false); @@ -111,6 +115,7 @@ void context_params::collect_param_descrs(param_descrs & d) { d.insert("type_check", CPK_BOOL, "type checker (alias for well_sorted_check)", "true"); d.insert("auto_config", CPK_BOOL, "use heuristics to automatically select solver and configure it", "true"); d.insert("proof", CPK_BOOL, "proof generation, it must be enabled when the Z3 context is created", "false"); + d.insert("check_interpolants", CPK_BOOL, "check correctness of interpolants", "false"); d.insert("model", CPK_BOOL, "model generation for solvers, this parameter can be overwritten when creating a solver", "true"); d.insert("model_validate", CPK_BOOL, "validate models produced by solvers", "false"); d.insert("trace", CPK_BOOL, "trace generation for VCC", "false"); diff --git a/src/cmd_context/context_params.h b/src/cmd_context/context_params.h index 526b4f517..7271bb84f 100644 --- a/src/cmd_context/context_params.h +++ b/src/cmd_context/context_params.h @@ -29,6 +29,8 @@ class context_params { public: bool m_auto_config; bool m_proof; + bool m_interpolants; + bool m_check_interpolants; bool m_debug_ref_count; bool m_trace; std::string m_trace_file_name; diff --git a/src/cmd_context/interpolant_cmds.cpp b/src/cmd_context/interpolant_cmds.cpp new file mode 100644 index 000000000..2a7de3176 --- /dev/null +++ b/src/cmd_context/interpolant_cmds.cpp @@ -0,0 +1,272 @@ +/*++ +Copyright (c) 2013 Microsoft Corporation + +Module Name: + + interpolant_cmds.cpp + +Abstract: + Commands for interpolation. + +Author: + + Leonardo (leonardo) 2011-12-23 + +Notes: + +--*/ +#include +#include"cmd_context.h" +#include"cmd_util.h" +#include"scoped_timer.h" +#include"scoped_ctrl_c.h" +#include"cancel_eh.h" +#include"ast_pp.h" +#include"ast_smt_pp.h" +#include"ast_smt2_pp.h" +#include"parametric_cmd.h" +#include"mpq.h" +#include"expr2var.h" +#include"pp.h" +#include"pp_params.hpp" +#include"iz3interp.h" +#include"iz3checker.h" +#include"iz3profiling.h" +#include"interp_params.hpp" + +static void show_interpolant_and_maybe_check(cmd_context & ctx, + ptr_vector &cnsts, + expr *t, + ptr_vector &interps, + params_ref &m_params, + bool check) +{ + + if (m_params.get_bool("som", false)) + m_params.set_bool("flat", true); + th_rewriter s(ctx.m(), m_params); + + for(unsigned i = 0; i < interps.size(); i++){ + + expr_ref r(ctx.m()); + proof_ref pr(ctx.m()); + s(to_expr(interps[i]),r,pr); + + ctx.regular_stream() << mk_pp(r.get(), ctx.m()) << std::endl; +#if 0 + ast_smt_pp pp(ctx.m()); + pp.set_logic(ctx.get_logic().str().c_str()); + pp.display_smt2(ctx.regular_stream(), to_expr(interps[i])); + ctx.regular_stream() << std::endl; +#endif + } + + s.cleanup(); + + // verify, for the paranoid... + if(check || ctx.check_interpolants()){ + std::ostringstream err; + ast_manager &_m = ctx.m(); + + // need a solver -- make one here FIXME is this right? + bool proofs_enabled, models_enabled, unsat_core_enabled; + params_ref p; + ctx.params().get_solver_params(_m, p, proofs_enabled, models_enabled, unsat_core_enabled); + scoped_ptr sp = (ctx.get_solver_factory())(_m, p, false, true, false, ctx.get_logic()); + + if(iz3check(_m,sp.get(),err,cnsts,t,interps)) + ctx.regular_stream() << "correct\n"; + else + ctx.regular_stream() << "incorrect: " << err.str().c_str() << "\n"; + } + + for(unsigned i = 0; i < interps.size(); i++){ + ctx.m().dec_ref(interps[i]); + } + + interp_params itp_params(m_params); + if(itp_params.profile()) + profiling::print(ctx.regular_stream()); + +} + +static void check_can_interpolate(cmd_context & ctx){ + if (!ctx.produce_interpolants()) + throw cmd_exception("interpolation is not enabled, use command (set-option :produce-interpolants true)"); +} + + +static void get_interpolant_and_maybe_check(cmd_context & ctx, expr * t, params_ref &m_params, bool check) { + + check_can_interpolate(ctx); + + // get the proof, if there is one + + if (!ctx.has_manager() || + ctx.cs_state() != cmd_context::css_unsat) + throw cmd_exception("proof is not available"); + expr_ref pr(ctx.m()); + pr = ctx.get_check_sat_result()->get_proof(); + if (pr == 0) + throw cmd_exception("proof is not available"); + + // get the assertions from the context + + ptr_vector::const_iterator it = ctx.begin_assertions(); + ptr_vector::const_iterator end = ctx.end_assertions(); + ptr_vector cnsts(end - it); + for (int i = 0; it != end; ++it, ++i) + cnsts[i] = *it; + + // compute an interpolant + + ptr_vector interps; + + try { + iz3interpolate(ctx.m(),pr.get(),cnsts,t,interps,0); + } + catch (iz3_bad_tree &) { + throw cmd_exception("interpolation pattern contains non-asserted formula"); + } + catch (iz3_incompleteness &) { + throw cmd_exception("incompleteness in interpolator"); + } + + show_interpolant_and_maybe_check(ctx, cnsts, t, interps, m_params, check); +} + +static void get_interpolant(cmd_context & ctx, expr * t, params_ref &m_params) { + get_interpolant_and_maybe_check(ctx,t,m_params,false); +} + +static void get_and_check_interpolant(cmd_context & ctx, params_ref &m_params, expr * t) { + get_interpolant_and_maybe_check(ctx,t,m_params,true); +} + + +static void compute_interpolant_and_maybe_check(cmd_context & ctx, expr * t, params_ref &m_params, bool check){ + + // create a fresh solver suitable for interpolation + bool proofs_enabled, models_enabled, unsat_core_enabled; + params_ref p; + ast_manager &_m = ctx.m(); + // TODO: the following is a HACK to enable proofs in the old smt solver + // When we stop using that solver, this hack can be removed + _m.toggle_proof_mode(PGM_FINE); + ctx.params().get_solver_params(_m, p, proofs_enabled, models_enabled, unsat_core_enabled); + p.set_bool("proof", true); + scoped_ptr sp = (ctx.get_interpolating_solver_factory())(_m, p, true, models_enabled, false, ctx.get_logic()); + + ptr_vector cnsts; + ptr_vector interps; + model_ref m; + + // compute an interpolant + + lbool res; + try { + res = iz3interpolate(_m, *sp.get(), t, cnsts, interps, m, 0); + } + catch (iz3_incompleteness &) { + throw cmd_exception("incompleteness in interpolator"); + } + + switch(res){ + case l_false: + ctx.regular_stream() << "unsat\n"; + show_interpolant_and_maybe_check(ctx, cnsts, t, interps, m_params, check); + break; + + case l_true: + ctx.regular_stream() << "sat\n"; + // TODO: how to return the model to the context, if it exists? + break; + + case l_undef: + ctx.regular_stream() << "unknown\n"; + // TODO: how to return the model to the context, if it exists? + break; + } + + for(unsigned i = 0; i < cnsts.size(); i++) + ctx.m().dec_ref(cnsts[i]); + +} + +static expr *make_tree(cmd_context & ctx, const ptr_vector &exprs){ + if(exprs.size() == 0) + throw cmd_exception("not enough arguments"); + expr *foo = exprs[0]; + for(unsigned i = 1; i < exprs.size(); i++){ + foo = ctx.m().mk_and(ctx.m().mk_interp(foo),exprs[i]); + } + return foo; +} + +static void get_interpolant(cmd_context & ctx, const ptr_vector &exprs, params_ref &m_params) { + expr *foo = make_tree(ctx,exprs); + ctx.m().inc_ref(foo); + get_interpolant(ctx,foo,m_params); + ctx.m().dec_ref(foo); +} + +static void compute_interpolant(cmd_context & ctx, const ptr_vector &exprs, params_ref &m_params) { + expr *foo = make_tree(ctx, exprs); + ctx.m().inc_ref(foo); + compute_interpolant_and_maybe_check(ctx,foo,m_params,false); + ctx.m().dec_ref(foo); +} + + +// UNARY_CMD(get_interpolant_cmd, "get-interpolant", "", "get interpolant for marked positions in fmla", CPK_EXPR, expr *, get_interpolant(ctx, arg);); + +// UNARY_CMD(get_and_check_interpolant_cmd, "get-and-check-interpolant", "", "get and check interpolant for marked positions in fmla", CPK_EXPR, expr *, get_and_check_interpolant(ctx, arg);); + +class get_interpolant_cmd : public parametric_cmd { +protected: + ptr_vector m_targets; +public: + get_interpolant_cmd(char const * name = "get-interpolant"):parametric_cmd(name) {} + + virtual char const * get_usage() const { return "+"; } + + virtual char const * get_main_descr() const { + return "get interpolant for formulas"; + } + + virtual void init_pdescrs(cmd_context & ctx, param_descrs & p) { + } + + virtual void prepare(cmd_context & ctx) { + parametric_cmd::prepare(ctx); + m_targets.resize(0); + } + + virtual cmd_arg_kind next_arg_kind(cmd_context & ctx) const { + return CPK_EXPR; + } + + virtual void set_next_arg(cmd_context & ctx, expr * arg) { + m_targets.push_back(arg); + } + + virtual void execute(cmd_context & ctx) { + get_interpolant(ctx,m_targets,m_params); + } +}; + +class compute_interpolant_cmd : public get_interpolant_cmd { +public: + compute_interpolant_cmd(char const * name = "compute-interpolant"):get_interpolant_cmd(name) {} + + virtual void execute(cmd_context & ctx) { + compute_interpolant(ctx,m_targets,m_params); + } + +}; + +void install_interpolant_cmds(cmd_context & ctx) { + ctx.insert(alloc(get_interpolant_cmd)); + ctx.insert(alloc(compute_interpolant_cmd)); + // ctx.insert(alloc(get_and_check_interpolant_cmd)); +} diff --git a/src/cmd_context/interpolant_cmds.h b/src/cmd_context/interpolant_cmds.h new file mode 100644 index 000000000..d7ceb5dc2 --- /dev/null +++ b/src/cmd_context/interpolant_cmds.h @@ -0,0 +1,24 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + interpolant_cmds.h + +Abstract: + Commands for interpolation. + +Author: + + Leonardo (leonardo) 2011-12-23 + +Notes: + +--*/ +#ifndef _INTERPOLANT_CMDS_H_ +#define _INTERPOLANT_CMDS_H_ + +class cmd_context; +void install_interpolant_cmds(cmd_context & ctx); + +#endif diff --git a/src/duality/duality.h b/src/duality/duality.h new file mode 100644 index 000000000..979071639 --- /dev/null +++ b/src/duality/duality.h @@ -0,0 +1,1044 @@ +/*++ +Copyright (c) 2012 Microsoft Corporation + +Module Name: + + duality.h + +Abstract: + + main header for duality + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + + +--*/ + +#pragma once + +#include "duality_wrapper.h" +#include +#include + +// make hash_map and hash_set available +#ifndef WIN32 +using namespace stl_ext; +#endif + +namespace Duality { + + /* Generic operations on Z3 formulas */ + + struct Z3User { + + context &ctx; + solver &slvr; + + typedef func_decl FuncDecl; + typedef expr Term; + + Z3User(context &_ctx, solver &_slvr) : ctx(_ctx), slvr(_slvr){} + + const char *string_of_int(int n); + + Term conjoin(const std::vector &args); + + Term sum(const std::vector &args); + + Term CloneQuantifier(const Term &t, const Term &new_body); + + Term SubstRec(hash_map &memo, const Term &t); + + void Strengthen(Term &x, const Term &y); + + // return the func_del of an app if it is uninterpreted + + bool get_relation(const Term &t, func_decl &R); + + // return true if term is an individual variable + // TODO: have to check that it is not a background symbol + + bool is_variable(const Term &t); + + FuncDecl SuffixFuncDecl(Term t, int n); + + + Term SubstRecHide(hash_map &memo, const Term &t, int number); + + void CollectConjuncts(const Term &f, std::vector &lits, bool pos = true); + + void SortTerms(std::vector &terms); + + Term SortSum(const Term &t); + + void Summarize(const Term &t); + + int CumulativeDecisions(); + + int CountOperators(const Term &t); + + Term SubstAtom(hash_map &memo, const expr &t, const expr &atom, const expr &val); + + Term RemoveRedundancy(const Term &t); + + bool IsLiteral(const expr &lit, expr &atom, expr &val); + + expr Negate(const expr &f); + + expr SimplifyAndOr(const std::vector &args, bool is_and); + + expr ReallySimplifyAndOr(const std::vector &args, bool is_and); + + int MaxIndex(hash_map &memo, const Term &t); + + bool IsClosedFormula(const Term &t); + + Term AdjustQuantifiers(const Term &t); +private: + + void SummarizeRec(hash_set &memo, std::vector &lits, int &ops, const Term &t); + int CountOperatorsRec(hash_set &memo, const Term &t); + void RemoveRedundancyOp(bool pol, std::vector &args, hash_map &smemo); + Term RemoveRedundancyRec(hash_map &memo, hash_map &smemo, const Term &t); + Term SubstAtomTriv(const expr &foo, const expr &atom, const expr &val); + expr ReduceAndOr(const std::vector &args, bool is_and, std::vector &res); + expr FinishAndOr(const std::vector &args, bool is_and); + expr PullCommonFactors(std::vector &args, bool is_and); + + +}; + + /** This class represents a relation post-fixed point (RPFP) problem as + * a "problem graph". The graph consists of Nodes and hyper-edges. + * + * A node consists of + * - Annotation, a symbolic relation + * - Bound, a symbolic relation giving an upper bound on Annotation + * + * + * A hyper-edge consists of: + * - Children, a sequence of children Nodes, + * - F, a symbolic relational transformer, + * - Parent, a single parent Node. + * + * The graph is "solved" when: + * - For every Node n, n.Annotation subseteq n.Bound + * - For every hyperedge e, e.F(e.Children.Annotation) subseteq e.Parent.Annotation + * + * where, if x is a sequence of Nodes, x.Annotation is the sequences + * of Annotations of the nodes in the sequence. + * + * A symbolic Transformer consists of + * - RelParams, a sequence of relational symbols + * - IndParams, a sequence of individual symbols + * - Formula, a formula over RelParams and IndParams + * + * A Transformer t represents a function that takes sequence R of relations + * and yields the relation lambda (t.Indparams). Formula(R/RelParams). + * + * As a special case, a nullary Transformer (where RelParams is the empty sequence) + * represents a fixed relation. + * + * An RPFP consists of + * - Nodes, a set of Nodes + * - Edges, a set of hyper-edges + * - Context, a prover context that contains formula AST's + * + * Multiple RPFP's can use the same Context, but you should be careful + * that only one RPFP asserts constraints in the context at any time. + * + * */ + class RPFP : public Z3User + { + public: + + class Edge; + class Node; + bool HornClauses; + + + /** Interface class for interpolating solver. */ + + class LogicSolver { + public: + + context *ctx; /** Z3 context for formulas */ + solver *slvr; /** Z3 solver */ + bool need_goals; /** Can the solver use the goal tree to optimize interpolants? */ + solver aux_solver; /** For temporary use -- don't leave assertions here. */ + + /** Tree interpolation. This method assumes the formulas in TermTree + "assumptions" are currently asserted in the solver. The return + value indicates whether the assertions are satisfiable. In the + UNSAT case, a tree interpolant is returned in "interpolants". + In the SAT case, a model is returned. + */ + + virtual + lbool interpolate_tree(TermTree *assumptions, + TermTree *&interpolants, + model &_model, + TermTree *goals = 0, + bool weak = false + ) = 0; + + /** Declare a constant in the background theory. */ + virtual void declare_constant(const func_decl &f) = 0; + + /** Is this a background constant? */ + virtual bool is_constant(const func_decl &f) = 0; + + /** Assert a background axiom. */ + virtual void assert_axiom(const expr &axiom) = 0; + + /** Get the background axioms. */ + virtual const std::vector &get_axioms() = 0; + + /** Return a string describing performance. */ + virtual std::string profile() = 0; + + virtual void write_interpolation_problem(const std::string &file_name, + const std::vector &assumptions, + const std::vector &theory + ){} + + /** Cancel, throw Canceled object if possible. */ + virtual void cancel(){ } + + /* Note: aux solver uses extensional array theory, since it + needs to be able to produce counter-models for + interpolants the have array equalities in them. + */ + LogicSolver(context &c) : aux_solver(c,true){} + + virtual ~LogicSolver(){} + }; + + /** This solver uses iZ3. */ + class iZ3LogicSolver : public LogicSolver { + public: + interpolating_context *ictx; /** iZ3 context for formulas */ + interpolating_solver *islvr; /** iZ3 solver */ + + lbool interpolate_tree(TermTree *assumptions, + TermTree *&interpolants, + model &_model, + TermTree *goals = 0, + bool weak = false) + { + literals _labels; + islvr->SetWeakInterpolants(weak); + return islvr->interpolate_tree(assumptions,interpolants,_model,_labels,true); + } + + void assert_axiom(const expr &axiom){ + islvr->AssertInterpolationAxiom(axiom); + } + + const std::vector &get_axioms() { + return islvr->GetInterpolationAxioms(); + } + + std::string profile(){ + return islvr->profile(); + } + +#if 0 + iZ3LogicSolver(config &_config){ + ctx = ictx = new interpolating_context(_config); + slvr = islvr = new interpolating_solver(*ictx); + need_goals = false; + islvr->SetWeakInterpolants(true); + } +#endif + + iZ3LogicSolver(context &c) : LogicSolver(c) { + ctx = ictx = &c; + slvr = islvr = new interpolating_solver(*ictx); + need_goals = false; + islvr->SetWeakInterpolants(true); + } + + void write_interpolation_problem(const std::string &file_name, + const std::vector &assumptions, + const std::vector &theory + ){ +#if 0 + islvr->write_interpolation_problem(file_name,assumptions,theory); +#endif + + } + + void cancel(){islvr->cancel();} + + /** Declare a constant in the background theory. */ + virtual void declare_constant(const func_decl &f){ + bckg.insert(f); + } + + /** Is this a background constant? */ + virtual bool is_constant(const func_decl &f){ + return bckg.find(f) != bckg.end(); + } + + ~iZ3LogicSolver(){ + // delete ictx; + delete islvr; + } + private: + hash_set bckg; + + }; + +#if 0 + /** Create a logic solver from a Z3 configuration. */ + static iZ3LogicSolver *CreateLogicSolver(config &_config){ + return new iZ3LogicSolver(_config); + } +#endif + + /** Create a logic solver from a low-level Z3 context. + Only use this if you know what you're doing. */ + static iZ3LogicSolver *CreateLogicSolver(context c){ + return new iZ3LogicSolver(c); + } + + LogicSolver *ls; + + private: + int nodeCount; + int edgeCount; + + class stack_entry + { + public: + std::list edges; + std::list nodes; + std::list > constraints; + }; + + + public: + model dualModel; + private: + literals dualLabels; + std::list stack; + std::vector axioms; // only saved here for printing purposes + solver &aux_solver; + hash_set *proof_core; + + public: + + /** Construct an RPFP graph with a given interpolating prover context. It is allowed to + have multiple RPFP's use the same context, but you should never have teo RPFP's + with the same conext asserting nodes or edges at the same time. Note, if you create + axioms in one RPFP, them create a second RPFP with the same context, the second will + inherit the axioms. + */ + + RPFP(LogicSolver *_ls) : Z3User(*(_ls->ctx), *(_ls->slvr)), dualModel(*(_ls->ctx)), aux_solver(_ls->aux_solver) + { + ls = _ls; + nodeCount = 0; + edgeCount = 0; + stack.push_back(stack_entry()); + HornClauses = false; + proof_core = 0; + } + + ~RPFP(); + + /** Symbolic representation of a relational transformer */ + class Transformer + { + public: + std::vector RelParams; + std::vector IndParams; + Term Formula; + RPFP *owner; + hash_map labels; + + Transformer *Clone() + { + return new Transformer(*this); + } + + void SetEmpty(){ + Formula = owner->ctx.bool_val(false); + } + + void SetFull(){ + Formula = owner->ctx.bool_val(true); + } + + bool IsEmpty(){ + return eq(Formula,owner->ctx.bool_val(false)); + } + + bool IsFull(){ + return eq(Formula,owner->ctx.bool_val(true)); + } + + void UnionWith(const Transformer &other){ + Term t = owner->SubstParams(other.IndParams,IndParams,other.Formula); + Formula = Formula || t; + } + + void IntersectWith(const Transformer &other){ + Term t = owner->SubstParams(other.IndParams,IndParams,other.Formula); + Formula = Formula && t; + } + + bool SubsetEq(const Transformer &other){ + Term t = owner->SubstParams(other.IndParams,IndParams,other.Formula); + expr test = Formula && !t; + owner->aux_solver.push(); + owner->aux_solver.add(test); + check_result res = owner->aux_solver.check(); + owner->aux_solver.pop(1); + return res == unsat; + } + + void Complement(){ + Formula = !Formula; + } + + void Simplify(){ + Formula = Formula.simplify(); + } + + Transformer(const std::vector &_RelParams, const std::vector &_IndParams, const Term &_Formula, RPFP *_owner) + : RelParams(_RelParams), IndParams(_IndParams), Formula(_Formula) {owner = _owner;} + }; + + /** Create a symbolic transformer. */ + Transformer CreateTransformer(const std::vector &_RelParams, const std::vector &_IndParams, const Term &_Formula) + { + // var ops = new Util.ContextOps(ctx); + // var foo = ops.simplify_lhs(_Formula); + // t.Formula = foo.Item1; + // t.labels = foo.Item2; + return Transformer(_RelParams,_IndParams,_Formula,this); + } + + /** Create a relation (nullary relational transformer) */ + Transformer CreateRelation(const std::vector &_IndParams, const Term &_Formula) + { + return CreateTransformer(std::vector(), _IndParams, _Formula); + } + + /** A node in the RPFP graph */ + class Node + { + public: + FuncDecl Name; + Transformer Annotation; + Transformer Bound; + Transformer Underapprox; + RPFP *owner; + int number; + Edge *Outgoing; + std::vector Incoming; + Term dual; + Node *map; + + Node(const FuncDecl &_Name, const Transformer &_Annotation, const Transformer &_Bound, const Transformer &_Underapprox, const Term &_dual, RPFP *_owner, int _number) + : Name(_Name), Annotation(_Annotation), Bound(_Bound), Underapprox(_Underapprox), dual(_dual) {owner = _owner; number = _number; Outgoing = 0;} + }; + + /** Create a node in the graph. The input is a term R(v_1...v_n) + * where R is an arbitrary relational symbol and v_1...v_n are + * arbitary distinct variables. The names are only of mnemonic value, + * however, the number and type of arguments determine the type + * of the relation at this node. */ + + Node *CreateNode(const Term &t) + { + std::vector _IndParams; + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++) + _IndParams.push_back(t.arg(i)); + Node *n = new Node(t.decl(), + CreateRelation(_IndParams,ctx.bool_val(true)), + CreateRelation(_IndParams,ctx.bool_val(true)), + CreateRelation(_IndParams,ctx.bool_val(false)), + expr(ctx), this, ++nodeCount + ); + nodes.push_back(n); + return n; + } + + /** Clone a node (can be from another graph). */ + + Node *CloneNode(Node *old) + { + Node *n = new Node(old->Name, + old->Annotation, + old->Bound, + old->Underapprox, + expr(ctx), + this, + ++nodeCount + ); + nodes.push_back(n); + n->map = old; + return n; + } + + /** Delete a node. You can only do this if not connected to any edges.*/ + void DeleteNode(Node *node){ + if(node->Outgoing || !node->Incoming.empty()) + throw "cannot delete RPFP node"; + for(std::vector::iterator it = nodes.end(), en = nodes.begin(); it != en;){ + if(*(--it) == node){ + nodes.erase(it); + break; + } + } + delete node; + } + + /** This class represents a hyper-edge in the RPFP graph */ + + class Edge + { + public: + Transformer F; + Node *Parent; + std::vector Children; + RPFP *owner; + int number; + // these should be internal... + Term dual; + hash_map relMap; + hash_map varMap; + Edge *map; + Term labeled; + std::vector constraints; + + Edge(Node *_Parent, const Transformer &_F, const std::vector &_Children, RPFP *_owner, int _number) + : F(_F), Parent(_Parent), Children(_Children), dual(expr(_owner->ctx)) { + owner = _owner; + number = _number; + } + }; + + + /** Create a hyper-edge. */ + Edge *CreateEdge(Node *_Parent, const Transformer &_F, const std::vector &_Children) + { + Edge *e = new Edge(_Parent,_F,_Children,this,++edgeCount); + _Parent->Outgoing = e; + for(unsigned i = 0; i < _Children.size(); i++) + _Children[i]->Incoming.push_back(e); + edges.push_back(e); + return e; + } + + + /** Delete a hyper-edge and unlink it from any nodes. */ + void DeleteEdge(Edge *edge){ + if(edge->Parent) + edge->Parent->Outgoing = 0; + for(unsigned int i = 0; i < edge->Children.size(); i++){ + std::vector &ic = edge->Children[i]->Incoming; + for(std::vector::iterator it = ic.begin(), en = ic.end(); it != en; ++it){ + if(*it == edge){ + ic.erase(it); + break; + } + } + } + for(std::vector::iterator it = edges.end(), en = edges.begin(); it != en;){ + if(*(--it) == edge){ + edges.erase(it); + break; + } + } + delete edge; + } + + /** Create an edge that lower-bounds its parent. */ + Edge *CreateLowerBoundEdge(Node *_Parent) + { + return CreateEdge(_Parent, _Parent->Annotation, std::vector()); + } + + + /** For incremental solving, asserts the constraint associated + * with this edge in the SMT context. If this edge is removed, + * you must pop the context accordingly. The second argument is + * the number of pushes we are inside. */ + + void AssertEdge(Edge *e, int persist = 0, bool with_children = false, bool underapprox = false); + + /* Constrain an edge by the annotation of one of its children. */ + + void ConstrainParent(Edge *parent, Node *child); + + /** For incremental solving, asserts the negation of the upper bound associated + * with a node. + * */ + + void AssertNode(Node *n); + + /** Assert a constraint on an edge in the SMT context. + */ + void ConstrainEdge(Edge *e, const Term &t); + + /** Fix the truth values of atomic propositions in the given + edge to their values in the current assignment. */ + void FixCurrentState(Edge *root); + + void FixCurrentStateFull(Edge *edge, const expr &extra); + + /** Declare a constant in the background theory. */ + + void DeclareConstant(const FuncDecl &f); + + /** Assert a background axiom. Background axioms can be used to provide the + * theory of auxilliary functions or relations. All symbols appearing in + * background axioms are considered global, and may appear in both transformer + * and relational solutions. Semantically, a solution to the RPFP gives + * an interpretation of the unknown relations for each interpretation of the + * auxilliary symbols that is consistent with the axioms. Axioms should be + * asserted before any calls to Push. They cannot be de-asserted by Pop. */ + + void AssertAxiom(const Term &t); + +#if 0 + /** Do not call this. */ + + void RemoveAxiom(const Term &t); +#endif + + /** Solve an RPFP graph. This means either strengthen the annotation + * so that the bound at the given root node is satisfied, or + * show that this cannot be done by giving a dual solution + * (i.e., a counterexample). + * + * In the current implementation, this only works for graphs that + * are: + * - tree-like + * + * - closed. + * + * In a tree-like graph, every nod has out most one incoming and one out-going edge, + * and there are no cycles. In a closed graph, every node has exactly one out-going + * edge. This means that the leaves of the tree are all hyper-edges with no + * children. Such an edge represents a relation (nullary transformer) and thus + * a lower bound on its parent. The parameter root must be the root of this tree. + * + * If Solve returns LBool.False, this indicates success. The annotation of the tree + * has been updated to satisfy the upper bound at the root. + * + * If Solve returns LBool.True, this indicates a counterexample. For each edge, + * you can then call Eval to determine the values of symbols in the transformer formula. + * You can also call Empty on a node to determine if its value in the counterexample + * is the empty relation. + * + * \param root The root of the tree + * \param persist Number of context pops through which result should persist + * + * + */ + + lbool Solve(Node *root, int persist); + + /** Same as Solve, but annotates only a single node. */ + + lbool SolveSingleNode(Node *root, Node *node); + + /** Get the constraint tree (but don't solve it) */ + + TermTree *GetConstraintTree(Node *root, Node *skip_descendant = 0); + + /** Dispose of the dual model (counterexample) if there is one. */ + + void DisposeDualModel(); + + /** Check satisfiability of asserted edges and nodes. Same functionality as + * Solve, except no primal solution (interpolant) is generated in the unsat case. */ + + check_result Check(Node *root, std::vector underapproxes = std::vector(), + std::vector *underapprox_core = 0); + + /** Update the model, attempting to make the propositional literals in assumps true. If possible, + return sat, else return unsat and keep the old model. */ + + check_result CheckUpdateModel(Node *root, std::vector assumps); + + /** Determines the value in the counterexample of a symbol occuring in the transformer formula of + * a given edge. */ + + Term Eval(Edge *e, Term t); + + /** Return the fact derived at node p in a counterexample. */ + + Term EvalNode(Node *p); + + /** Returns true if the given node is empty in the primal solution. For proecudure summaries, + this means that the procedure is not called in the current counter-model. */ + + bool Empty(Node *p); + + /** Compute an underapproximation of every node in a tree rooted at "root", + based on a previously computed counterexample. */ + + Term ComputeUnderapprox(Node *root, int persist); + + /** Try to strengthen the annotation of a node by removing disjuncts. */ + void Generalize(Node *root, Node *node); + + + /** Compute disjunctive interpolant for node by case splitting */ + void InterpolateByCases(Node *root, Node *node); + + /** Push a scope. Assertions made after Push can be undone by Pop. */ + + void Push(); + + /** Exception thrown when bad clause is encountered */ + + struct bad_clause { + std::string msg; + int i; + bad_clause(const std::string &_msg, int _i){ + msg = _msg; + i = _i; + } + }; + + struct parse_error { + std::string msg; + parse_error(const std::string &_msg){ + msg = _msg; + } + }; + + struct file_not_found { + }; + + struct bad_format { + }; + + /** Pop a scope (see Push). Note, you cannot pop axioms. */ + + void Pop(int num_scopes); + + /** Erase the proof by performing a Pop, Push and re-assertion of + all the popped constraints */ + void PopPush(); + + /** Return true if the given edge is used in the proof of unsat. + Can be called only after Solve or Check returns an unsat result. */ + + bool EdgeUsedInProof(Edge *edge); + + + /** Convert a collection of clauses to Nodes and Edges in the RPFP. + + Predicate unknowns are uninterpreted predicates not + occurring in the background theory. + + Clauses are of the form + + B => P(t_1,...,t_k) + + where P is a predicate unknown and predicate unknowns + occur only positivey in H and only under existential + quantifiers in prenex form. + + Each predicate unknown maps to a node. Each clause maps to + an edge. Let C be a clause B => P(t_1,...,t_k) where the + sequence of predicate unknowns occurring in B (in order + of occurrence) is P_1..P_n. The clause maps to a transformer + T where: + + T.Relparams = P_1..P_n + T.Indparams = x_1...x+k + T.Formula = B /\ t_1 = x_1 /\ ... /\ t_k = x_k + + Throws exception bad_clause(msg,i) if a clause i is + in the wrong form. + + */ + + struct label_struct { + symbol name; + expr value; + bool pos; + label_struct(const symbol &s, const expr &e, bool b) + : name(s), value(e), pos(b) {} + }; + + +#ifdef WIN32 + __declspec(dllexport) +#endif + void FromClauses(const std::vector &clauses); + + void FromFixpointContext(fixedpoint fp, std::vector &queries); + + void WriteSolution(std::ostream &s); + + void WriteCounterexample(std::ostream &s, Node *node); + + enum FileFormat {DualityFormat, SMT2Format, HornFormat}; + + /** Write the RPFP to a file (currently in SMTLIB 1.2 format) */ + void WriteProblemToFile(std::string filename, FileFormat format = DualityFormat); + + /** Read the RPFP from a file (in specificed format) */ + void ReadProblemFromFile(std::string filename, FileFormat format = DualityFormat); + + /** Translate problem to Horn clause form */ + void ToClauses(std::vector &cnsts, FileFormat format = DualityFormat); + + /** Translate the RPFP to a fixed point context, with queries */ + fixedpoint ToFixedPointProblem(std::vector &queries); + + /** Nodes of the graph. */ + std::vector nodes; + + /** Edges of the graph. */ + std::vector edges; + + Term SubstParams(const std::vector &from, + const std::vector &to, const Term &t); + + Term Localize(Edge *e, const Term &t); + + void EvalNodeAsConstraint(Node *p, Transformer &res); + + TermTree *GetGoalTree(Node *root); + + int EvalTruth(hash_map &memo, Edge *e, const Term &f); + + void GetLabels(Edge *e, std::vector &labels); + + // int GetLabelsRec(hash_map *memo, const Term &f, std::vector &labels, bool labpos); + + /** Compute and save the proof core for future calls to + EdgeUsedInProof. You only need to call this if you will pop + the solver before calling EdgeUsedInProof. + */ + void ComputeProofCore(); + + private: + + void ClearProofCore(){ + if(proof_core) + delete proof_core; + proof_core = 0; + } + + Term SuffixVariable(const Term &t, int n); + + Term HideVariable(const Term &t, int n); + + void RedVars(Node *node, Term &b, std::vector &v); + + Term RedDualRela(Edge *e, std::vector &args, int idx); + + Term LocalizeRec(Edge *e, hash_map &memo, const Term &t); + + void SetEdgeMaps(Edge *e); + + Term ReducedDualEdge(Edge *e); + + TermTree *ToTermTree(Node *root, Node *skip_descendant = 0); + + TermTree *ToGoalTree(Node *root); + + void CollapseTermTreeRec(TermTree *root, TermTree *node); + + TermTree *CollapseTermTree(TermTree *node); + + void DecodeTree(Node *root, TermTree *interp, int persist); + + Term GetUpperBound(Node *n); + + TermTree *AddUpperBound(Node *root, TermTree *t); + +#if 0 + void WriteInterps(System.IO.StreamWriter f, TermTree t); +#endif + + void WriteEdgeVars(Edge *e, hash_map &memo, const Term &t, std::ostream &s); + + void WriteEdgeAssignment(std::ostream &s, Edge *e); + + + // Scan the clause body for occurrences of the predicate unknowns + + Term ScanBody(hash_map &memo, + const Term &t, + hash_map &pmap, + std::vector &res, + std::vector &nodes); + + Term RemoveLabelsRec(hash_map &memo, const Term &t, std::vector &lbls); + + Term RemoveLabels(const Term &t, std::vector &lbls); + + Term GetAnnotation(Node *n); + + + Term GetUnderapprox(Node *n); + + Term UnderapproxFlag(Node *n); + + hash_map underapprox_flag_rev; + + Node *UnderapproxFlagRev(const Term &flag); + + Term ProjectFormula(std::vector &keep_vec, const Term &f); + + int SubtermTruth(hash_map &memo, const Term &); + + void ImplicantRed(hash_map &memo, const Term &f, std::vector &lits, + hash_set *done, bool truth, hash_set &dont_cares); + + void Implicant(hash_map &memo, const Term &f, std::vector &lits, hash_set &dont_cares); + + Term UnderapproxFormula(const Term &f, hash_set &dont_cares); + + void ImplicantFullRed(hash_map &memo, const Term &f, std::vector &lits, + hash_set &done, hash_set &dont_cares); + + Term UnderapproxFullFormula(const Term &f, hash_set &dont_cares); + + Term ToRuleRec(Edge *e, hash_map &memo, const Term &t, std::vector &quants); + + hash_map resolve_ite_memo; + + Term ResolveIte(hash_map &memo, const Term &t, std::vector &lits, + hash_set *done, hash_set &dont_cares); + + struct ArrayValue { + bool defined; + std::map entries; + expr def_val; + }; + + void EvalArrayTerm(const Term &t, ArrayValue &res); + + Term EvalArrayEquality(const Term &f); + + Term ModelValueAsConstraint(const Term &t); + + void GetLabelsRec(hash_map &memo, const Term &f, std::vector &labels, + hash_set *done, bool truth); + + Term SubstBoundRec(hash_map > &memo, hash_map &subst, int level, const Term &t); + + Term SubstBound(hash_map &subst, const Term &t); + + void ConstrainEdgeLocalized(Edge *e, const Term &t); + + void GreedyReduce(solver &s, std::vector &conjuncts); + + void NegateLits(std::vector &lits); + + expr SimplifyOr(std::vector &lits); + + void SetAnnotation(Node *root, const expr &t); + + void AddEdgeToSolver(Edge *edge); + + void AddToProofCore(hash_set &core); + + void GetGroundLitsUnderQuants(hash_set *memo, const Term &f, std::vector &res, int under); + + Term StrengthenFormulaByCaseSplitting(const Term &f, std::vector &case_lits); + + expr NegateLit(const expr &f); + + }; + + /** RPFP solver base class. */ + + class Solver { + + public: + + struct Counterexample { + RPFP *tree; + RPFP::Node *root; + Counterexample(){ + tree = 0; + root = 0; + } + }; + + /** Solve the problem. You can optionally give an old + counterexample to use as a guide. This is chiefly useful for + abstraction refinement metholdologies, and is only used as a + heuristic. */ + + virtual bool Solve() = 0; + + virtual Counterexample GetCounterexample() = 0; + + virtual bool SetOption(const std::string &option, const std::string &value) = 0; + + /** Learn heuristic information from another solver. This + is chiefly useful for abstraction refinement, when we want to + solve a series of similar problems. */ + + virtual void LearnFrom(Counterexample &old_cex) = 0; + + virtual ~Solver(){} + + static Solver *Create(const std::string &solver_class, RPFP *rpfp); + + /** This can be called asynchrnously to cause Solve to throw a + Canceled exception at some time in the future. + */ + virtual void Cancel() = 0; + + /** Object thrown on cancellation */ + struct Canceled {}; + + }; +} + + +// Allow to hash on nodes and edges in deterministic way + +namespace hash_space { + template <> + class hash { + public: + size_t operator()(const Duality::RPFP::Node *p) const { + return p->number; + } + }; +} + +namespace hash_space { + template <> + class hash { + public: + size_t operator()(const Duality::RPFP::Edge *p) const { + return p->number; + } + }; +} + +// allow to walk sets of nodes without address dependency + +namespace std { + template <> + class less { + public: + bool operator()(Duality::RPFP::Node * const &s, Duality::RPFP::Node * const &t) const { + return s->number < t->number; // s.raw()->get_id() < t.raw()->get_id(); + } + }; +} diff --git a/src/duality/duality_hash.h b/src/duality/duality_hash.h new file mode 100755 index 000000000..f6767c037 --- /dev/null +++ b/src/duality/duality_hash.h @@ -0,0 +1,169 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3hash.h + +Abstract: + + Wrapper for stl hash tables + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +// pull in the headers for has_map and hash_set +// these live in non-standard places + +#ifndef IZ3_HASH_H +#define IZ3_HASH_H + +//#define USE_UNORDERED_MAP +#ifdef USE_UNORDERED_MAP + +#define stl_ext std +#define hash_space std +#include +#include +#define hash_map unordered_map +#define hash_set unordered_set + +#else + +#if __GNUC__ >= 3 +#undef __DEPRECATED +#define stl_ext __gnu_cxx +#define hash_space stl_ext +#include +#include +#else +#ifdef WIN32 +#define stl_ext stdext +#define hash_space std +#include +#include +#else +#define stl_ext std +#define hash_space std +#include +#include +#endif +#endif + +#endif + +#include + +// stupid STL doesn't include hash function for class string + +#ifndef WIN32 + +namespace stl_ext { + template <> + class hash { + stl_ext::hash H; + public: + size_t operator()(const std::string &s) const { + return H(s.c_str()); + } + }; +} + +#endif + +namespace hash_space { + template <> + class hash > { + public: + size_t operator()(const std::pair &p) const { + return p.first + p.second; + } + }; +} + +#ifdef WIN32 +template <> inline +size_t stdext::hash_value >(const std::pair& p) +{ // hash _Keyval to size_t value one-to-one + return p.first + p.second; +} +#endif + +namespace hash_space { + template + class hash > { + public: + size_t operator()(const std::pair &p) const { + return (size_t)p.first + (size_t)p.second; + } + }; +} + +#if 0 +template inline +size_t stdext::hash_value >(const std::pair& p) +{ // hash _Keyval to size_t value one-to-one + return (size_t)p.first + (size_t)p.second; +} +#endif + +#ifdef WIN32 + +namespace std { + template <> + class less > { + public: + bool operator()(const pair &x, const pair &y) const { + return x.first < y.first || x.first == y.first && x.second < y.second; + } + }; + +} + +namespace std { + template + class less > { + public: + bool operator()(const pair &x, const pair &y) const { + return (size_t)x.first < (size_t)y.first || (size_t)x.first == (size_t)y.first && (size_t)x.second < (size_t)y.second; + } + }; + +} + +#endif + + +#ifndef WIN32 + +namespace stl_ext { + template + class hash { + public: + size_t operator()(const T *p) const { + return (size_t) p; + } + }; +} + +#endif + +#ifdef WIN32 + + + + +template +class hash_map : public stl_ext::hash_map > > {}; + +template +class hash_set : public stl_ext::hash_set > > {}; + +#endif + +#endif diff --git a/src/duality/duality_profiling.cpp b/src/duality/duality_profiling.cpp new file mode 100755 index 000000000..bec32c51a --- /dev/null +++ b/src/duality/duality_profiling.cpp @@ -0,0 +1,126 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + duality_profiling.cpp + +Abstract: + + collection performance information for duality + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + + +--*/ + + +#include +#include +#include +#include +#include + +#include "duality_wrapper.h" + +namespace Duality { + + void show_time(){ + output_time(std::cout,current_time()); + std::cout << "\n"; + } + + typedef std::map nmap; + + struct node { + std::string name; + clock_t time; + clock_t start_time; + nmap sub; + struct node *parent; + + node(); + } top; + + node::node(){ + time = 0; + parent = 0; + } + + struct node *current; + + struct init { + init(){ + top.name = "TOTAL"; + current = ⊤ + } + } initializer; + + struct time_entry { + clock_t t; + time_entry(){t = 0;}; + void add(clock_t incr){t += incr;} + }; + + struct ltstr + { + bool operator()(const char* s1, const char* s2) const + { + return strcmp(s1, s2) < 0; + } + }; + + typedef std::map tmap; + + static std::ostream *pfs; + + void print_node(node &top, int indent, tmap &totals){ + for(int i = 0; i < indent; i++) (*pfs) << " "; + (*pfs) << top.name; + int dots = 70 - 2 * indent - top.name.size(); + for(int i = 0; i second,indent+1,totals); + } + + void print_profile(std::ostream &os) { + pfs = &os; + top.time = 0; + for(nmap::iterator it = top.sub.begin(); it != top.sub.end(); it++) + top.time += it->second.time; + tmap totals; + print_node(top,0,totals); + (*pfs) << "TOTALS:" << std::endl; + for(tmap::iterator it = totals.begin(); it != totals.end(); it++){ + (*pfs) << (it->first) << " "; + output_time(*pfs, it->second.t); + (*pfs) << std::endl; + } + } + + void timer_start(const char *name){ + node &child = current->sub[name]; + if(child.name.empty()){ // a new node + child.parent = current; + child.name = name; + } + child.start_time = current_time(); + current = &child; + } + + void timer_stop(const char *name){ + if(current->name != name || !current->parent){ + std::cerr << "imbalanced timer_start and timer_stop"; + exit(1); + } + current->time += (current_time() - current->start_time); + current = current->parent; + } +} diff --git a/src/duality/duality_profiling.h b/src/duality/duality_profiling.h new file mode 100755 index 000000000..ff70fae23 --- /dev/null +++ b/src/duality/duality_profiling.h @@ -0,0 +1,38 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + duality_profiling.h + +Abstract: + + collection performance information for duality + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + + +--*/ + +#ifndef DUALITYPROFILING_H +#define DUALITYPROFILING_H + +#include + +namespace Duality { + /** Start a timer with given name */ + void timer_start(const char *); + /** Stop a timer with given name */ + void timer_stop(const char *); + /** Print out timings */ + void print_profile(std::ostream &s); + /** Show the current time. */ + void show_time(); +} + +#endif + diff --git a/src/duality/duality_rpfp.cpp b/src/duality/duality_rpfp.cpp new file mode 100644 index 000000000..a5e2b9167 --- /dev/null +++ b/src/duality/duality_rpfp.cpp @@ -0,0 +1,2894 @@ +/*++ +Copyright (c) 2012 Microsoft Corporation + +Module Name: + + duality_rpfp.h + +Abstract: + + implements relational post-fixedpoint problem + (RPFP) data structure. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + + +--*/ + + + +#include "duality.h" +#include "duality_profiling.h" +#include +#include +#include +#include + +#ifndef WIN32 +// #define Z3OPS +#endif + +// TODO: do we need these? +#ifdef Z3OPS + +class Z3_subterm_truth { + public: + virtual bool eval(Z3_ast f) = 0; + ~Z3_subterm_truth(){} +}; + +Z3_subterm_truth *Z3_mk_subterm_truth(Z3_context ctx, Z3_model model); + +#endif + +#include + +// TODO: use something official for this +int debug_gauss = 0; + +namespace Duality { + + static char string_of_int_buffer[20]; + + const char *Z3User::string_of_int(int n){ + sprintf(string_of_int_buffer,"%d",n); + return string_of_int_buffer; + } + + RPFP::Term RPFP::SuffixVariable(const Term &t, int n) + { + std::string name = t.decl().name().str() + "_" + string_of_int(n); + return ctx.constant(name.c_str(), t.get_sort()); + } + + RPFP::Term RPFP::HideVariable(const Term &t, int n) + { + std::string name = std::string("@p_") + t.decl().name().str() + "_" + string_of_int(n); + return ctx.constant(name.c_str(), t.get_sort()); + } + + void RPFP::RedVars(Node *node, Term &b, std::vector &v) + { + int idx = node->number; + if(HornClauses) + b = ctx.bool_val(true); + else { + std::string name = std::string("@b_") + string_of_int(idx); + symbol sym = ctx.str_symbol(name.c_str()); + b = ctx.constant(sym,ctx.bool_sort()); + } + // ctx.constant(name.c_str(), ctx.bool_sort()); + v = node->Annotation.IndParams; + for(unsigned i = 0; i < v.size(); i++) + v[i] = SuffixVariable(v[i],idx); + } + + void Z3User::SummarizeRec(hash_set &memo, std::vector &lits, int &ops, const Term &t){ + if(memo.find(t) != memo.end()) + return; + memo.insert(t); + if(t.is_app()){ + decl_kind k = t.decl().get_decl_kind(); + if(k == And || k == Or || k == Not || k == Implies || k == Iff){ + ops++; + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++) + SummarizeRec(memo,lits,ops,t.arg(i)); + return; + } + } + lits.push_back(t); + } + + int Z3User::CumulativeDecisions(){ +#if 0 + std::string stats = Z3_statistics_to_string(ctx); + int pos = stats.find("decisions:"); + pos += 10; + int end = stats.find('\n',pos); + std::string val = stats.substr(pos,end-pos); + return atoi(val.c_str()); +#endif + return slvr.get_num_decisions(); + } + + + void Z3User::Summarize(const Term &t){ + hash_set memo; std::vector lits; int ops = 0; + SummarizeRec(memo, lits, ops, t); + std::cout << ops << ": "; + for(unsigned i = 0; i < lits.size(); i++){ + if(i > 0) std::cout << ", "; + std::cout << lits[i]; + } + } + + int Z3User::CountOperatorsRec(hash_set &memo, const Term &t){ + if(memo.find(t) != memo.end()) + return 0; + memo.insert(t); + if(t.is_app()){ + decl_kind k = t.decl().get_decl_kind(); + if(k == And || k == Or){ + int count = 1; + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++) + count += CountOperatorsRec(memo,t.arg(i)); + return count; + } + return 0; + } + if(t.is_quantifier()) + return CountOperatorsRec(memo,t.body())+2; // count 2 for a quantifier + return 0; + } + + int Z3User::CountOperators(const Term &t){ + hash_set memo; + return CountOperatorsRec(memo,t); + } + + + Z3User::Term Z3User::conjoin(const std::vector &args){ + return ctx.make(And,args); + } + + Z3User::Term Z3User::sum(const std::vector &args){ + return ctx.make(Plus,args); + } + + RPFP::Term RPFP::RedDualRela(Edge *e, std::vector &args, int idx){ + Node *child = e->Children[idx]; + Term b(ctx); + std::vector v; + RedVars(child, b, v); + for (unsigned i = 0; i < args.size(); i++) + { + if (eq(args[i].get_sort(),ctx.bool_sort())) + args[i] = ctx.make(Iff,args[i], v[i]); + else + args[i] = args[i] == v[i]; + } + return args.size() > 0 ? (b && conjoin(args)) : b; + } + + Z3User::Term Z3User::CloneQuantifier(const Term &t, const Term &new_body){ +#if 0 + Z3_context c = ctx; + Z3_ast a = t; + std::vector pats; + int num_pats = Z3_get_quantifier_num_patterns(c,a); + for(int i = 0; i < num_pats; i++) + pats.push_back(Z3_get_quantifier_pattern_ast(c,a,i)); + std::vector no_pats; + int num_no_pats = Z3_get_quantifier_num_patterns(c,a); + for(int i = 0; i < num_no_pats; i++) + no_pats.push_back(Z3_get_quantifier_no_pattern_ast(c,a,i)); + int bound = Z3_get_quantifier_num_bound(c,a); + std::vector sorts; + std::vector names; + for(int i = 0; i < bound; i++){ + sorts.push_back(Z3_get_quantifier_bound_sort(c,a,i)); + names.push_back(Z3_get_quantifier_bound_name(c,a,i)); + } + Z3_ast foo = Z3_mk_quantifier_ex(c, + Z3_is_quantifier_forall(c,a), + Z3_get_quantifier_weight(c,a), + 0, + 0, + num_pats, + &pats[0], + num_no_pats, + &no_pats[0], + bound, + &sorts[0], + &names[0], + new_body); + return expr(ctx,foo); +#endif + return clone_quantifier(t,new_body); + } + + + RPFP::Term RPFP::LocalizeRec(Edge *e, hash_map &memo, const Term &t) + { + std::pair foo(t,expr(ctx)); + std::pair::iterator, bool> bar = memo.insert(foo); + Term &res = bar.first->second; + if(!bar.second) return res; + hash_map::iterator it = e->varMap.find(t); + if (it != e->varMap.end()){ + res = it->second; + return res; + } + if (t.is_app()) + { + func_decl f = t.decl(); + std::vector args; + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++) + args.push_back(LocalizeRec(e, memo, t.arg(i))); + hash_map::iterator rit = e->relMap.find(f); + if(rit != e->relMap.end()) + res = RedDualRela(e,args,(rit->second)); + else { + if (args.size() == 0 && f.get_decl_kind() == Uninterpreted && !ls->is_constant(f)) + { + res = HideVariable(t,e->number); + } + else + { + res = f(args.size(),&args[0]); + } + } + } + else if (t.is_quantifier()) + { + std::vector pats; + t.get_patterns(pats); + for(unsigned i = 0; i < pats.size(); i++) + pats[i] = LocalizeRec(e,memo,pats[i]); + Term body = LocalizeRec(e,memo,t.body()); + res = clone_quantifier(t, body, pats); + } + else res = t; + return res; + } + + void RPFP::SetEdgeMaps(Edge *e){ + timer_start("SetEdgeMaps"); + e->relMap.clear(); + e->varMap.clear(); + for(unsigned i = 0; i < e->F.RelParams.size(); i++){ + e->relMap[e->F.RelParams[i]] = i; + } + Term b(ctx); + std::vector v; + RedVars(e->Parent, b, v); + for(unsigned i = 0; i < e->F.IndParams.size(); i++){ + // func_decl parentParam = e->Parent.Annotation.IndParams[i]; + expr oldname = e->F.IndParams[i]; + expr newname = v[i]; + e->varMap[oldname] = newname; + } + timer_stop("SetEdgeMaps"); + + } + + + RPFP::Term RPFP::Localize(Edge *e, const Term &t){ + timer_start("Localize"); + hash_map memo; + if(e->F.IndParams.size() > 0 && e->varMap.empty()) + SetEdgeMaps(e); // TODO: why is this needed? + Term res = LocalizeRec(e,memo,t); + timer_stop("Localize"); + return res; + } + + + RPFP::Term RPFP::ReducedDualEdge(Edge *e) + { + SetEdgeMaps(e); + timer_start("RedVars"); + Term b(ctx); + std::vector v; + RedVars(e->Parent, b, v); + timer_stop("RedVars"); + // ast_to_track = (ast)b; + return implies(b, Localize(e, e->F.Formula)); + } + + TermTree *RPFP::ToTermTree(Node *root, Node *skip_descendant) + { + if(skip_descendant && root == skip_descendant) + return new TermTree(ctx.bool_val(true)); + Edge *e = root->Outgoing; + if(!e) return new TermTree(ctx.bool_val(true), std::vector()); + std::vector children(e->Children.size()); + for(unsigned i = 0; i < children.size(); i++) + children[i] = ToTermTree(e->Children[i],skip_descendant); + // Term top = ReducedDualEdge(e); + Term top = e->dual.null() ? ctx.bool_val(true) : e->dual; + TermTree *res = new TermTree(top, children); + for(unsigned i = 0; i < e->constraints.size(); i++) + res->addTerm(e->constraints[i]); + return res; + } + + TermTree *RPFP::GetGoalTree(Node *root){ + std::vector children(1); + children[0] = ToGoalTree(root); + return new TermTree(ctx.bool_val(false),children); + } + + TermTree *RPFP::ToGoalTree(Node *root) + { + Term b(ctx); + std::vector v; + RedVars(root, b, v); + Term goal = root->Name(v); + Edge *e = root->Outgoing; + if(!e) return new TermTree(goal, std::vector()); + std::vector children(e->Children.size()); + for(unsigned i = 0; i < children.size(); i++) + children[i] = ToGoalTree(e->Children[i]); + // Term top = ReducedDualEdge(e); + return new TermTree(goal, children); + } + + Z3User::Term Z3User::SubstRec(hash_map &memo, const Term &t) + { + std::pair foo(t,expr(ctx)); + std::pair::iterator, bool> bar = memo.insert(foo); + Term &res = bar.first->second; + if(!bar.second) return res; + if (t.is_app()) + { + func_decl f = t.decl(); + std::vector args; + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++) + args.push_back(SubstRec(memo, t.arg(i))); + res = f(args.size(),&args[0]); + } + else if (t.is_quantifier()) + { + std::vector pats; + t.get_patterns(pats); + for(unsigned i = 0; i < pats.size(); i++) + pats[i] = SubstRec(memo,pats[i]); + Term body = SubstRec(memo,t.body()); + res = clone_quantifier(t, body, pats); + } + // res = CloneQuantifier(t,SubstRec(memo, t.body())); + else res = t; + return res; + } + + bool Z3User::IsLiteral(const expr &lit, expr &atom, expr &val){ + if(!(lit.is_quantifier() && IsClosedFormula(lit))){ + if(!lit.is_app()) + return false; + decl_kind k = lit.decl().get_decl_kind(); + if(k == Not){ + if(IsLiteral(lit.arg(0),atom,val)){ + val = eq(val,ctx.bool_val(true)) ? ctx.bool_val(false) : ctx.bool_val(true); + return true; + } + return false; + } + if(k == And || k == Or || k == Iff || k == Implies) + return false; + } + atom = lit; + val = ctx.bool_val(true); + return true; + } + + expr Z3User::Negate(const expr &f){ + if(f.is_app() && f.decl().get_decl_kind() == Not) + return f.arg(0); + else if(eq(f,ctx.bool_val(true))) + return ctx.bool_val(false); + else if(eq(f,ctx.bool_val(false))) + return ctx.bool_val(true); + return !f; + } + + expr Z3User::ReduceAndOr(const std::vector &args, bool is_and, std::vector &res){ + for(unsigned i = 0; i < args.size(); i++) + if(!eq(args[i],ctx.bool_val(is_and))){ + if(eq(args[i],ctx.bool_val(!is_and))) + return ctx.bool_val(!is_and); + res.push_back(args[i]); + } + return expr(); + } + + expr Z3User::FinishAndOr(const std::vector &args, bool is_and){ + if(args.size() == 0) + return ctx.bool_val(is_and); + if(args.size() == 1) + return args[0]; + return ctx.make(is_and ? And : Or,args); + } + + expr Z3User::SimplifyAndOr(const std::vector &args, bool is_and){ + std::vector sargs; + expr res = ReduceAndOr(args,is_and,sargs); + if(!res.null()) return res; + return FinishAndOr(sargs,is_and); + } + + expr Z3User::PullCommonFactors(std::vector &args, bool is_and){ + + // first check if there's anything to do... + if(args.size() < 2) + return FinishAndOr(args,is_and); + for(unsigned i = 0; i < args.size(); i++){ + const expr &a = args[i]; + if(!(a.is_app() && a.decl().get_decl_kind() == (is_and ? Or : And))) + return FinishAndOr(args,is_and); + } + std::vector common; + for(unsigned i = 0; i < args.size(); i++){ + unsigned n = args[i].num_args(); + std::vector v(n),w; + for(unsigned j = 0; j < n; j++) + v[j] = args[i].arg(j); + std::less comp; + std::sort(v.begin(),v.end(),comp); + if(i == 0) + common.swap(v); + else { + std::set_intersection(common.begin(),common.end(),v.begin(),v.end(),std::inserter(w,w.begin()),comp); + common.swap(w); + } + } + if(common.empty()) + return FinishAndOr(args,is_and); + std::set common_set(common.begin(),common.end()); + for(unsigned i = 0; i < args.size(); i++){ + unsigned n = args[i].num_args(); + std::vector lits; + for(unsigned j = 0; j < n; j++){ + const expr b = args[i].arg(j); + if(common_set.find(b) == common_set.end()) + lits.push_back(b); + } + args[i] = SimplifyAndOr(lits,!is_and); + } + common.push_back(SimplifyAndOr(args,is_and)); + return SimplifyAndOr(common,!is_and); + } + + expr Z3User::ReallySimplifyAndOr(const std::vector &args, bool is_and){ + std::vector sargs; + expr res = ReduceAndOr(args,is_and,sargs); + if(!res.null()) return res; + return PullCommonFactors(sargs,is_and); + } + + Z3User::Term Z3User::SubstAtomTriv(const expr &foo, const expr &atom, const expr &val){ + if(eq(foo,atom)) + return val; + else if(foo.is_app() && foo.decl().get_decl_kind() == Not && eq(foo.arg(0),atom)) + return Negate(val); + else + return foo; + } + + Z3User::Term Z3User::SubstAtom(hash_map &memo, const expr &t, const expr &atom, const expr &val){ + std::pair foo(t,expr(ctx)); + std::pair::iterator, bool> bar = memo.insert(foo); + Term &res = bar.first->second; + if(!bar.second) return res; + if (t.is_app()){ + func_decl f = t.decl(); + decl_kind k = f.get_decl_kind(); + + // TODO: recur here, but how much? We don't want to be quadractic in formula size + + if(k == And || k == Or){ + int nargs = t.num_args(); + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = SubstAtom(memo,t.arg(i),atom,val); + res = ReallySimplifyAndOr(args, k==And); + return res; + } + } + else if(t.is_quantifier() && atom.is_quantifier()){ + if(eq(t,atom)) + res = val; + else + res = clone_quantifier(t,SubstAtom(memo,t.body(),atom,val)); + return res; + } + res = SubstAtomTriv(t,atom,val); + return res; + } + + void Z3User::RemoveRedundancyOp(bool pol, std::vector &args, hash_map &smemo){ + for(unsigned i = 0; i < args.size(); i++){ + const expr &lit = args[i]; + expr atom, val; + if(IsLiteral(lit,atom,val)){ + for(unsigned j = 0; j < args.size(); j++) + if(j != i){ + smemo.clear(); + args[j] = SubstAtom(smemo,args[j],atom,pol ? val : !val); + } + } + } + } + + + Z3User::Term Z3User::RemoveRedundancyRec(hash_map &memo, hash_map &smemo, const Term &t) + { + std::pair foo(t,expr(ctx)); + std::pair::iterator, bool> bar = memo.insert(foo); + Term &res = bar.first->second; + if(!bar.second) return res; + if (t.is_app()) + { + func_decl f = t.decl(); + std::vector args; + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++) + args.push_back(RemoveRedundancyRec(memo, smemo, t.arg(i))); + + decl_kind k = f.get_decl_kind(); + if(k == And){ + RemoveRedundancyOp(true,args,smemo); + res = ReallySimplifyAndOr(args, true); + } + else if(k == Or){ + RemoveRedundancyOp(false,args,smemo); + res = ReallySimplifyAndOr(args, false); + } + else { + if(k == Equal && args[0].get_id() > args[1].get_id()) + std::swap(args[0],args[1]); + res = f(args.size(),&args[0]); + } + } + else if (t.is_quantifier()) + { + Term body = RemoveRedundancyRec(memo,smemo,t.body()); + res = clone_quantifier(t, body); + } + else res = t; + return res; + } + + Z3User::Term Z3User::RemoveRedundancy(const Term &t){ + hash_map memo; + hash_map smemo; + return RemoveRedundancyRec(memo,smemo,t); + } + + Z3User::Term Z3User::AdjustQuantifiers(const Term &t) + { + if(t.is_quantifier() || (t.is_app() && t.has_quantifiers())) + return t.qe_lite(); + return t; + } + + Z3User::Term Z3User::SubstRecHide(hash_map &memo, const Term &t, int number) + { + std::pair foo(t,expr(ctx)); + std::pair::iterator, bool> bar = memo.insert(foo); + Term &res = bar.first->second; + if(!bar.second) return res; + if (t.is_app()) + { + func_decl f = t.decl(); + std::vector args; + int nargs = t.num_args(); + if (nargs == 0 && f.get_decl_kind() == Uninterpreted){ + std::string name = std::string("@q_") + t.decl().name().str() + "_" + string_of_int(number); + res = ctx.constant(name.c_str(), t.get_sort()); + return res; + } + for(int i = 0; i < nargs; i++) + args.push_back(SubstRec(memo, t.arg(i))); + res = f(args.size(),&args[0]); + } + else if (t.is_quantifier()) + res = CloneQuantifier(t,SubstRec(memo, t.body())); + else res = t; + return res; + } + + RPFP::Term RPFP::SubstParams(const std::vector &from, + const std::vector &to, const Term &t){ + hash_map memo; + bool some_diff = false; + for(unsigned i = 0; i < from.size(); i++) + if(i < to.size() && !eq(from[i],to[i])){ + memo[from[i]] = to[i]; + some_diff = true; + } + return some_diff ? SubstRec(memo,t) : t; + } + + + void Z3User::Strengthen(Term &x, const Term &y) + { + if (eq(x,ctx.bool_val(true))) + x = y; + else + x = x && y; + } + + void RPFP::SetAnnotation(Node *root, const expr &t){ + hash_map memo; + Term b; + std::vector v; + RedVars(root, b, v); + memo[b] = ctx.bool_val(true); + for (unsigned i = 0; i < v.size(); i++) + memo[v[i]] = root->Annotation.IndParams[i]; + Term annot = SubstRec(memo, t); + // Strengthen(ref root.Annotation.Formula, annot); + root->Annotation.Formula = annot; + } + + void RPFP::DecodeTree(Node *root, TermTree *interp, int persist) + { + std::vector &ic = interp->getChildren(); + if (ic.size() > 0) + { + std::vector &nc = root->Outgoing->Children; + for (unsigned i = 0; i < nc.size(); i++) + DecodeTree(nc[i], ic[i], persist); + } + SetAnnotation(root,interp->getTerm()); +#if 0 + if(persist != 0) + Z3_persist_ast(ctx,root->Annotation.Formula,persist); +#endif + } + + RPFP::Term RPFP::GetUpperBound(Node *n) + { + // TODO: cache this + Term b(ctx); std::vector v; + RedVars(n, b, v); + hash_map memo; + for (unsigned int i = 0; i < v.size(); i++) + memo[n->Bound.IndParams[i]] = v[i]; + Term cnst = SubstRec(memo, n->Bound.Formula); + return b && !cnst; + } + + RPFP::Term RPFP::GetAnnotation(Node *n) + { + if(eq(n->Annotation.Formula,ctx.bool_val(true))) + return n->Annotation.Formula; + // TODO: cache this + Term b(ctx); std::vector v; + RedVars(n, b, v); + hash_map memo; + for (unsigned i = 0; i < v.size(); i++) + memo[n->Annotation.IndParams[i]] = v[i]; + Term cnst = SubstRec(memo, n->Annotation.Formula); + return !b || cnst; + } + + RPFP::Term RPFP::GetUnderapprox(Node *n) + { + // TODO: cache this + Term b(ctx); std::vector v; + RedVars(n, b, v); + hash_map memo; + for (unsigned i = 0; i < v.size(); i++) + memo[n->Underapprox.IndParams[i]] = v[i]; + Term cnst = SubstRecHide(memo, n->Underapprox.Formula, n->number); + return !b || cnst; + } + + TermTree *RPFP::AddUpperBound(Node *root, TermTree *t) + { + Term f = !((ast)(root->dual)) ? ctx.bool_val(true) : root->dual; + std::vector c(1); c[0] = t; + return new TermTree(f, c); + } + +#if 0 + void RPFP::WriteInterps(System.IO.StreamWriter f, TermTree t) + { + foreach (var c in t.getChildren()) + WriteInterps(f, c); + f.WriteLine(t.getTerm()); + } +#endif + + + /** For incremental solving, asserts the constraint associated + * with this edge in the SMT context. If this edge is removed, + * you must pop the context accordingly. The second argument is + * the number of pushes we are inside. The constraint formula + * will survive "persist" pops of the context. If you plan + * to reassert the edge after popping the context once, + * you can save time re-constructing the formula by setting + * "persist" to one. If you set "persist" too high, however, + * you could have a memory leak. + * + * The flag "with children" causes the annotations of the children + * to be asserted. The flag underapprox causes the underapproximations + * of the children to be asserted *conditionally*. See Check() on + * how to actually enforce these constraints. + * + */ + + void RPFP::AssertEdge(Edge *e, int persist, bool with_children, bool underapprox) + { + if(eq(e->F.Formula,ctx.bool_val(true)) && (!with_children || e->Children.empty())) + return; + if (e->dual.null()) + { + timer_start("ReducedDualEdge"); + e->dual = ReducedDualEdge(e); + timer_stop("ReducedDualEdge"); + timer_start("getting children"); + if(with_children) + for(unsigned i = 0; i < e->Children.size(); i++) + e->dual = e->dual && GetAnnotation(e->Children[i]); + if(underapprox){ + std::vector cus(e->Children.size()); + for(unsigned i = 0; i < e->Children.size(); i++) + cus[i] = !UnderapproxFlag(e->Children[i]) || GetUnderapprox(e->Children[i]); + expr cnst = conjoin(cus); + e->dual = e->dual && cnst; + } + timer_stop("getting children"); + timer_start("Persisting"); + std::list::reverse_iterator it = stack.rbegin(); + for(int i = 0; i < persist && it != stack.rend(); i++) + it++; + if(it != stack.rend()) + it->edges.push_back(e); +#if 0 + if(persist != 0) + Z3_persist_ast(ctx,e->dual,persist); +#endif + timer_stop("Persisting"); + //Console.WriteLine("{0}", cnst); + } + timer_start("solver add"); + slvr.add(e->dual); + timer_stop("solver add"); + } + + void RPFP::ConstrainParent(Edge *parent, Node *child){ + ConstrainEdgeLocalized(parent,GetAnnotation(child)); + } + + + /** For incremental solving, asserts the negation of the upper bound associated + * with a node. + * */ + + void RPFP::AssertNode(Node *n) + { + if (n->dual.null()) + { + n->dual = GetUpperBound(n); + stack.back().nodes.push_back(n); + slvr.add(n->dual); + } + } + + /** Assert a constraint on an edge in the SMT context. + */ + + void RPFP::ConstrainEdge(Edge *e, const Term &t) + { + Term tl = Localize(e, t); + ConstrainEdgeLocalized(e,tl); + } + + void RPFP::ConstrainEdgeLocalized(Edge *e, const Term &tl) + { + e->constraints.push_back(tl); + stack.back().constraints.push_back(std::pair(e,tl)); + slvr.add(tl); + } + + + + /** Declare a constant in the background theory. */ + + void RPFP::DeclareConstant(const FuncDecl &f){ + ls->declare_constant(f); + } + + /** Assert a background axiom. Background axioms can be used to provide the + * theory of auxilliary functions or relations. All symbols appearing in + * background axioms are considered global, and may appear in both transformer + * and relational solutions. Semantically, a solution to the RPFP gives + * an interpretation of the unknown relations for each interpretation of the + * auxilliary symbols that is consistent with the axioms. Axioms should be + * asserted before any calls to Push. They cannot be de-asserted by Pop. */ + + void RPFP::AssertAxiom(const Term &t) + { + ls->assert_axiom(t); + axioms.push_back(t); // for printing only + } + +#if 0 + /** Do not call this. */ + + void RPFP::RemoveAxiom(const Term &t) + { + slvr.RemoveInterpolationAxiom(t); + } +#endif + + /** Solve an RPFP graph. This means either strengthen the annotation + * so that the bound at the given root node is satisfied, or + * show that this cannot be done by giving a dual solution + * (i.e., a counterexample). + * + * In the current implementation, this only works for graphs that + * are: + * - tree-like + * + * - closed. + * + * In a tree-like graph, every nod has out most one incoming and one out-going edge, + * and there are no cycles. In a closed graph, every node has exactly one out-going + * edge. This means that the leaves of the tree are all hyper-edges with no + * children. Such an edge represents a relation (nullary transformer) and thus + * a lower bound on its parent. The parameter root must be the root of this tree. + * + * If Solve returns LBool.False, this indicates success. The annotation of the tree + * has been updated to satisfy the upper bound at the root. + * + * If Solve returns LBool.True, this indicates a counterexample. For each edge, + * you can then call Eval to determine the values of symbols in the transformer formula. + * You can also call Empty on a node to determine if its value in the counterexample + * is the empty relation. + * + * \param root The root of the tree + * \param persist Number of context pops through which result should persist + * + * + */ + + lbool RPFP::Solve(Node *root, int persist) + { + timer_start("Solve"); + TermTree *tree = GetConstraintTree(root); + TermTree *interpolant = NULL; + TermTree *goals = NULL; + if(ls->need_goals) + goals = GetGoalTree(root); + ClearProofCore(); + + // if (dualModel != null) dualModel.Dispose(); + // if (dualLabels != null) dualLabels.Dispose(); + + timer_start("interpolate_tree"); + lbool res = ls->interpolate_tree(tree, interpolant, dualModel,goals,true); + timer_stop("interpolate_tree"); + if (res == l_false) + { + DecodeTree(root, interpolant->getChildren()[0], persist); + delete interpolant; + } + + delete tree; + if(goals) + delete goals; + + timer_stop("Solve"); + return res; + } + + void RPFP::CollapseTermTreeRec(TermTree *root, TermTree *node){ + root->addTerm(node->getTerm()); + std::vector &cnsts = node->getTerms(); + for(unsigned i = 0; i < cnsts.size(); i++) + root->addTerm(cnsts[i]); + std::vector &chs = node->getChildren(); + for(unsigned i = 0; i < chs.size(); i++){ + CollapseTermTreeRec(root,chs[i]); + } + } + + TermTree *RPFP::CollapseTermTree(TermTree *node){ + std::vector &chs = node->getChildren(); + for(unsigned i = 0; i < chs.size(); i++) + CollapseTermTreeRec(node,chs[i]); + for(unsigned i = 0; i < chs.size(); i++) + delete chs[i]; + chs.clear(); + return node; + } + + lbool RPFP::SolveSingleNode(Node *root, Node *node) + { + timer_start("Solve"); + TermTree *tree = CollapseTermTree(GetConstraintTree(root,node)); + tree->getChildren().push_back(CollapseTermTree(ToTermTree(node))); + TermTree *interpolant = NULL; + ClearProofCore(); + + timer_start("interpolate_tree"); + lbool res = ls->interpolate_tree(tree, interpolant, dualModel,0,true); + timer_stop("interpolate_tree"); + if (res == l_false) + { + DecodeTree(node, interpolant->getChildren()[0], 0); + delete interpolant; + } + + delete tree; + timer_stop("Solve"); + return res; + } + + /** Get the constraint tree (but don't solve it) */ + + TermTree *RPFP::GetConstraintTree(Node *root, Node *skip_descendant) + { + return AddUpperBound(root, ToTermTree(root,skip_descendant)); + } + + /** Dispose of the dual model (counterexample) if there is one. */ + + void RPFP::DisposeDualModel() + { + dualModel = model(ctx,NULL); + } + + RPFP::Term RPFP::UnderapproxFlag(Node *n){ + expr flag = ctx.constant((std::string("@under") + string_of_int(n->number)).c_str(), ctx.bool_sort()); + underapprox_flag_rev[flag] = n; + return flag; + } + + RPFP::Node *RPFP::UnderapproxFlagRev(const Term &flag){ + return underapprox_flag_rev[flag]; + } + + /** Check satisfiability of asserted edges and nodes. Same functionality as + * Solve, except no primal solution (interpolant) is generated in the unsat case. + * The vector underapproxes gives the set of node underapproximations to be enforced + * (assuming these were conditionally asserted by AssertEdge). + * + */ + + check_result RPFP::Check(Node *root, std::vector underapproxes, std::vector *underapprox_core ) + { + ClearProofCore(); + // if (dualModel != null) dualModel.Dispose(); + check_result res; + if(!underapproxes.size()) + res = slvr.check(); + else { + std::vector us(underapproxes.size()); + for(unsigned i = 0; i < underapproxes.size(); i++) + us[i] = UnderapproxFlag(underapproxes[i]); + slvr.check(); // TODO: no idea why I need to do this + if(underapprox_core){ + std::vector unsat_core(us.size()); + unsigned core_size = 0; + res = slvr.check(us.size(),&us[0],&core_size,&unsat_core[0]); + underapprox_core->resize(core_size); + for(unsigned i = 0; i < core_size; i++) + (*underapprox_core)[i] = UnderapproxFlagRev(unsat_core[i]); + } + else { + res = slvr.check(us.size(),&us[0]); + bool dump = false; + if(dump){ + std::vector cnsts; + // cnsts.push_back(axioms[0]); + cnsts.push_back(root->dual); + cnsts.push_back(root->Outgoing->dual); + ls->write_interpolation_problem("temp.smt",cnsts,std::vector()); + } + } + // check_result temp = slvr.check(); + } + dualModel = slvr.get_model(); + return res; + } + + check_result RPFP::CheckUpdateModel(Node *root, std::vector assumps){ + // check_result temp1 = slvr.check(); // no idea why I need to do this + ClearProofCore(); + check_result res = slvr.check_keep_model(assumps.size(),&assumps[0]); + dualModel = slvr.get_model(); + return res; + } + + /** Determines the value in the counterexample of a symbol occuring in the transformer formula of + * a given edge. */ + + RPFP::Term RPFP::Eval(Edge *e, Term t) + { + Term tl = Localize(e, t); + return dualModel.eval(tl); + } + + + + /** Returns true if the given node is empty in the primal solution. For proecudure summaries, + this means that the procedure is not called in the current counter-model. */ + + bool RPFP::Empty(Node *p) + { + Term b; std::vector v; + RedVars(p, b, v); + // dualModel.show_internals(); + // std::cout << "b: " << b << std::endl; + expr bv = dualModel.eval(b); + // std::cout << "bv: " << bv << std::endl; + bool res = !eq(bv,ctx.bool_val(true)); + return res; + } + + RPFP::Term RPFP::EvalNode(Node *p) + { + Term b; std::vector v; + RedVars(p, b, v); + std::vector args; + for(unsigned i = 0; i < v.size(); i++) + args.push_back(dualModel.eval(v[i])); + return (p->Name)(args); + } + + void RPFP::EvalArrayTerm(const RPFP::Term &t, ArrayValue &res){ + if(t.is_app()){ + decl_kind k = t.decl().get_decl_kind(); + if(k == AsArray){ + func_decl fd = t.decl().get_func_decl_parameter(0); + func_interp r = dualModel.get_func_interp(fd); + int num = r.num_entries(); + res.defined = true; + for(int i = 0; i < num; i++){ + expr arg = r.get_arg(i,0); + expr value = r.get_value(i); + res.entries[arg] = value; + } + res.def_val = r.else_value(); + return; + } + else if(k == Store){ + EvalArrayTerm(t.arg(0),res); + if(!res.defined)return; + expr addr = t.arg(1); + expr val = t.arg(2); + if(addr.is_numeral() && val.is_numeral()){ + if(eq(val,res.def_val)) + res.entries.erase(addr); + else + res.entries[addr] = val; + } + else + res.defined = false; + return; + } + } + res.defined = false; + } + + int eae_count = 0; + + RPFP::Term RPFP::EvalArrayEquality(const RPFP::Term &f){ + ArrayValue lhs,rhs; + eae_count++; + EvalArrayTerm(f.arg(0),lhs); + EvalArrayTerm(f.arg(1),rhs); + if(lhs.defined && rhs.defined){ + if(eq(lhs.def_val,rhs.def_val)) + if(lhs.entries == rhs.entries) + return ctx.bool_val(true); + return ctx.bool_val(false); + } + return f; + } + + /** Compute truth values of all boolean subterms in current model. + Assumes formulas has been simplified by Z3, so only boolean ops + ands and, or, not. Returns result in memo. + */ + + int RPFP::SubtermTruth(hash_map &memo, const Term &f){ + if(memo.find(f) != memo.end()){ + return memo[f]; + } + int res; + if(f.is_app()){ + int nargs = f.num_args(); + decl_kind k = f.decl().get_decl_kind(); + if(k == Implies){ + res = SubtermTruth(memo,!f.arg(0) || f.arg(1)); + goto done; + } + if(k == And) { + res = 1; + for(int i = 0; i < nargs; i++){ + int ar = SubtermTruth(memo,f.arg(i)); + if(ar == 0){ + res = 0; + goto done; + } + if(ar == 2)res = 2; + } + goto done; + } + else if(k == Or) { + res = 0; + for(int i = 0; i < nargs; i++){ + int ar = SubtermTruth(memo,f.arg(i)); + if(ar == 1){ + res = 1; + goto done; + } + if(ar == 2)res = 2; + } + goto done; + } + else if(k == Not) { + int ar = SubtermTruth(memo,f.arg(0)); + res = (ar == 0) ? 1 : ((ar == 1) ? 0 : 2); + goto done; + } + } + { + bool pos; std::vector names; + if(f.is_label(pos,names)){ + res = SubtermTruth(memo,f.arg(0)); + goto done; + } + } + { + expr bv = dualModel.eval(f); + if(bv.is_app() && bv.decl().get_decl_kind() == Equal && + bv.arg(0).is_array()){ + bv = EvalArrayEquality(bv); + } + // Hack!!!! array equalities can occur negatively! + if(bv.is_app() && bv.decl().get_decl_kind() == Not && + bv.arg(0).decl().get_decl_kind() == Equal && + bv.arg(0).arg(0).is_array()){ + bv = dualModel.eval(!EvalArrayEquality(bv.arg(0))); + } + if(eq(bv,ctx.bool_val(true))) + res = 1; + else if(eq(bv,ctx.bool_val(false))) + res = 0; + else + res = 2; + } + done: + memo[f] = res; + return res; + } + + int RPFP::EvalTruth(hash_map &memo, Edge *e, const Term &f){ + Term tl = Localize(e, f); + return SubtermTruth(memo,tl); + } + + /** Compute truth values of all boolean subterms in current model. + Assumes formulas has been simplified by Z3, so only boolean ops + ands and, or, not. Returns result in memo. + */ + +#if 0 + int RPFP::GetLabelsRec(hash_map *memo, const Term &f, std::vector &labels, bool labpos){ + if(memo[labpos].find(f) != memo[labpos].end()){ + return memo[labpos][f]; + } + int res; + if(f.is_app()){ + int nargs = f.num_args(); + decl_kind k = f.decl().get_decl_kind(); + if(k == Implies){ + res = GetLabelsRec(memo,f.arg(1) || !f.arg(0), labels, labpos); + goto done; + } + if(k == And) { + res = 1; + for(int i = 0; i < nargs; i++){ + int ar = GetLabelsRec(memo,f.arg(i), labels, labpos); + if(ar == 0){ + res = 0; + goto done; + } + if(ar == 2)res = 2; + } + goto done; + } + else if(k == Or) { + res = 0; + for(int i = 0; i < nargs; i++){ + int ar = GetLabelsRec(memo,f.arg(i), labels, labpos); + if(ar == 1){ + res = 1; + goto done; + } + if(ar == 2)res = 2; + } + goto done; + } + else if(k == Not) { + int ar = GetLabelsRec(memo,f.arg(0), labels, !labpos); + res = (ar == 0) ? 1 : ((ar == 1) ? 0 : 2); + goto done; + } + } + { + bool pos; std::vector names; + if(f.is_label(pos,names)){ + res = GetLabelsRec(memo,f.arg(0), labels, labpos); + if(pos == labpos && res == (pos ? 1 : 0)) + for(unsigned i = 0; i < names.size(); i++) + labels.push_back(names[i]); + goto done; + } + } + { + expr bv = dualModel.eval(f); + if(bv.is_app() && bv.decl().get_decl_kind() == Equal && + bv.arg(0).is_array()){ + bv = EvalArrayEquality(bv); + } + // Hack!!!! array equalities can occur negatively! + if(bv.is_app() && bv.decl().get_decl_kind() == Not && + bv.arg(0).decl().get_decl_kind() == Equal && + bv.arg(0).arg(0).is_array()){ + bv = dualModel.eval(!EvalArrayEquality(bv.arg(0))); + } + if(eq(bv,ctx.bool_val(true))) + res = 1; + else if(eq(bv,ctx.bool_val(false))) + res = 0; + else + res = 2; + } + done: + memo[labpos][f] = res; + return res; + } +#endif + + void RPFP::GetLabelsRec(hash_map &memo, const Term &f, std::vector &labels, + hash_set *done, bool truth){ + if(done[truth].find(f) != done[truth].end()) + return; /* already processed */ + if(f.is_app()){ + int nargs = f.num_args(); + decl_kind k = f.decl().get_decl_kind(); + if(k == Implies){ + GetLabelsRec(memo,f.arg(1) || !f.arg(0) ,labels,done,truth); + goto done; + } + if(k == Iff){ + int b = SubtermTruth(memo,f.arg(0)); + if(b == 2) + throw "disaster in GetLabelsRec"; + GetLabelsRec(memo,f.arg(1),labels,done,truth ? b : !b); + goto done; + } + if(truth ? k == And : k == Or) { + for(int i = 0; i < nargs; i++) + GetLabelsRec(memo,f.arg(i),labels,done,truth); + goto done; + } + if(truth ? k == Or : k == And) { + for(int i = 0; i < nargs; i++){ + Term a = f.arg(i); + timer_start("SubtermTruth"); + int b = SubtermTruth(memo,a); + timer_stop("SubtermTruth"); + if(truth ? (b == 1) : (b == 0)){ + GetLabelsRec(memo,a,labels,done,truth); + goto done; + } + } + /* Unreachable! */ + throw "error in RPFP::GetLabelsRec"; + goto done; + } + else if(k == Not) { + GetLabelsRec(memo,f.arg(0),labels,done,!truth); + goto done; + } + else { + bool pos; std::vector names; + if(f.is_label(pos,names)){ + GetLabelsRec(memo,f.arg(0), labels, done, truth); + if(pos == truth) + for(unsigned i = 0; i < names.size(); i++) + labels.push_back(names[i]); + goto done; + } + } + } + done: + done[truth].insert(f); + } + + void RPFP::GetLabels(Edge *e, std::vector &labels){ + if(!e->map || e->map->labeled.null()) + return; + Term tl = Localize(e, e->map->labeled); + hash_map memo; + hash_set done[2]; + GetLabelsRec(memo,tl,labels,done,true); + } + +#ifdef Z3OPS + static Z3_subterm_truth *stt; +#endif + + int ir_count = 0; + + void RPFP::ImplicantRed(hash_map &memo, const Term &f, std::vector &lits, + hash_set *done, bool truth, hash_set &dont_cares){ + if(done[truth].find(f) != done[truth].end()) + return; /* already processed */ +#if 0 + int this_count = ir_count++; + if(this_count == 50092) + std::cout << "foo!\n"; +#endif + if(f.is_app()){ + int nargs = f.num_args(); + decl_kind k = f.decl().get_decl_kind(); + if(k == Implies){ + ImplicantRed(memo,f.arg(1) || !f.arg(0) ,lits,done,truth,dont_cares); + goto done; + } + if(k == Iff){ + int b = SubtermTruth(memo,f.arg(0)); + if(b == 2) + throw "disaster in ImplicantRed"; + ImplicantRed(memo,f.arg(1),lits,done,truth ? b : !b,dont_cares); + goto done; + } + if(truth ? k == And : k == Or) { + for(int i = 0; i < nargs; i++) + ImplicantRed(memo,f.arg(i),lits,done,truth,dont_cares); + goto done; + } + if(truth ? k == Or : k == And) { + for(int i = 0; i < nargs; i++){ + Term a = f.arg(i); +#if 0 + if(i == nargs - 1){ // last chance! + ImplicantRed(memo,a,lits,done,truth,dont_cares); + goto done; + } +#endif + timer_start("SubtermTruth"); +#ifdef Z3OPS + bool b = stt->eval(a); +#else + int b = SubtermTruth(memo,a); +#endif + timer_stop("SubtermTruth"); + if(truth ? (b == 1) : (b == 0)){ + ImplicantRed(memo,a,lits,done,truth,dont_cares); + goto done; + } + } + /* Unreachable! */ + // TODO: need to indicate this failure to caller + // std::cerr << "error in RPFP::ImplicantRed"; + goto done; + } + else if(k == Not) { + ImplicantRed(memo,f.arg(0),lits,done,!truth,dont_cares); + goto done; + } + } + { + if(dont_cares.find(f) == dont_cares.end()){ + expr rf = ResolveIte(memo,f,lits,done,dont_cares); + expr bv = truth ? rf : !rf; + lits.push_back(bv); + } + } + done: + done[truth].insert(f); + } + + void RPFP::ImplicantFullRed(hash_map &memo, const Term &f, std::vector &lits, + hash_set &done, hash_set &dont_cares){ + if(done.find(f) != done.end()) + return; /* already processed */ + if(f.is_app()){ + int nargs = f.num_args(); + decl_kind k = f.decl().get_decl_kind(); + if(k == Implies || k == Iff || k == And || k == Or || k == Not){ + for(int i = 0; i < nargs; i++) + ImplicantFullRed(memo,f.arg(i),lits,done,dont_cares); + goto done; + } + } + { + if(dont_cares.find(f) == dont_cares.end()){ + int b = SubtermTruth(memo,f); + if(b != 0 && b != 1) goto done; + expr bv = (b==1) ? f : !f; + lits.push_back(bv); + } + } + done: + done.insert(f); + } + + RPFP::Term RPFP::ResolveIte(hash_map &memo, const Term &t, std::vector &lits, + hash_set *done, hash_set &dont_cares){ + if(resolve_ite_memo.find(t) != resolve_ite_memo.end()) + return resolve_ite_memo[t]; + Term res; + if (t.is_app()) + { + func_decl f = t.decl(); + std::vector args; + int nargs = t.num_args(); + if(f.get_decl_kind() == Ite){ + timer_start("SubtermTruth"); +#ifdef Z3OPS + bool sel = stt->eval(t.arg(0)); +#else + int xval = SubtermTruth(memo,t.arg(0)); + bool sel; + if(xval == 0)sel = false; + else if(xval == 1)sel = true; + else + throw "unresolved ite in model"; +#endif + timer_stop("SubtermTruth"); + ImplicantRed(memo,t.arg(0),lits,done,sel,dont_cares); + res = ResolveIte(memo,t.arg(sel?1:2),lits,done,dont_cares); + } + else { + for(int i = 0; i < nargs; i++) + args.push_back(ResolveIte(memo,t.arg(i),lits,done,dont_cares)); + res = f(args.size(),&args[0]); + } + } + else res = t; + resolve_ite_memo[t] = res; + return res; + } + + void RPFP::Implicant(hash_map &memo, const Term &f, std::vector &lits, hash_set &dont_cares){ + hash_set done[2]; + ImplicantRed(memo,f,lits,done,true, dont_cares); + } + + + /** Underapproximate a formula using current counterexample. */ + + RPFP::Term RPFP::UnderapproxFormula(const Term &f, hash_set &dont_cares){ + /* first compute truth values of subterms */ + hash_map memo; + #ifdef Z3OPS + stt = Z3_mk_subterm_truth(ctx,dualModel); + #endif + // SubtermTruth(memo,f); + /* now compute an implicant */ + std::vector lits; + Implicant(memo,f,lits, dont_cares); +#ifdef Z3OPS + delete stt; stt = 0; +#endif + /* return conjunction of literals */ + return conjoin(lits); + } + + RPFP::Term RPFP::UnderapproxFullFormula(const Term &f, hash_set &dont_cares){ + /* first compute truth values of subterms */ + hash_map memo; + hash_set done; + std::vector lits; + ImplicantFullRed(memo,f,lits,done,dont_cares); + /* return conjunction of literals */ + return conjoin(lits); + } + + struct VariableProjector : Z3User { + + struct elim_cand { + Term var; + int sup; + Term val; + }; + + typedef expr Term; + + hash_set keep; + hash_map var_ord; + int num_vars; + std::vector elim_cands; + hash_map > sup_map; + hash_map elim_map; + std::vector ready_cands; + hash_map cand_map; + params simp_params; + + VariableProjector(Z3User &_user, std::vector &keep_vec) : + Z3User(_user), simp_params() + { + num_vars = 0; + for(unsigned i = 0; i < keep_vec.size(); i++){ + keep.insert(keep_vec[i]); + var_ord[keep_vec[i]] = num_vars++; + } + } + + int VarNum(const Term &v){ + if(var_ord.find(v) == var_ord.end()) + var_ord[v] = num_vars++; + return var_ord[v]; + } + + bool IsVar(const Term &t){ + return t.is_app() && t.num_args() == 0 && t.decl().get_decl_kind() == Uninterpreted; + } + + bool IsPropLit(const Term &t, Term &a){ + if(IsVar(t)){ + a = t; + return true; + } + else if(t.is_app() && t.decl().get_decl_kind() == Not) + return IsPropLit(t.arg(0),a); + return false; + } + + void CountOtherVarsRec(hash_map &memo, + const Term &t, + int id, + int &count){ + std::pair foo(t,0); + std::pair::iterator, bool> bar = memo.insert(foo); + // int &res = bar.first->second; + if(!bar.second) return; + if (t.is_app()) + { + func_decl f = t.decl(); + std::vector args; + int nargs = t.num_args(); + if (nargs == 0 && f.get_decl_kind() == Uninterpreted){ + if(cand_map.find(t) != cand_map.end()){ + count++; + sup_map[t].push_back(id); + } + } + for(int i = 0; i < nargs; i++) + CountOtherVarsRec(memo, t.arg(i), id, count); + } + else if (t.is_quantifier()) + CountOtherVarsRec(memo, t.body(), id, count); + } + + void NewElimCand(const Term &lhs, const Term &rhs){ + if(debug_gauss){ + std::cout << "mapping " << lhs << " to " << rhs << std::endl; + } + elim_cand cand; + cand.var = lhs; + cand.sup = 0; + cand.val = rhs; + elim_cands.push_back(cand); + cand_map[lhs] = elim_cands.size()-1; + } + + void MakeElimCand(const Term &lhs, const Term &rhs){ + if(eq(lhs,rhs)) + return; + if(!IsVar(lhs)){ + if(IsVar(rhs)){ + MakeElimCand(rhs,lhs); + return; + } + else{ + std::cout << "would have mapped a non-var\n"; + return; + } + } + if(IsVar(rhs) && VarNum(rhs) > VarNum(lhs)){ + MakeElimCand(rhs,lhs); + return; + } + if(keep.find(lhs) != keep.end()) + return; + if(cand_map.find(lhs) == cand_map.end()) + NewElimCand(lhs,rhs); + else { + int cand_idx = cand_map[lhs]; + if(IsVar(rhs) && cand_map.find(rhs) == cand_map.end() + && keep.find(rhs) == keep.end()) + NewElimCand(rhs,elim_cands[cand_idx].val); + elim_cands[cand_idx].val = rhs; + } + } + + Term FindRep(const Term &t){ + if(cand_map.find(t) == cand_map.end()) + return t; + Term &res = elim_cands[cand_map[t]].val; + if(IsVar(res)){ + assert(VarNum(res) < VarNum(t)); + res = FindRep(res); + return res; + } + return t; + } + + void GaussElimCheap(const std::vector &lits_in, + std::vector &lits_out){ + for(unsigned i = 0; i < lits_in.size(); i++){ + Term lit = lits_in[i]; + if(lit.is_app()){ + decl_kind k = lit.decl().get_decl_kind(); + if(k == Equal || k == Iff) + MakeElimCand(FindRep(lit.arg(0)),FindRep(lit.arg(1))); + } + } + + for(unsigned i = 0; i < elim_cands.size(); i++){ + elim_cand &cand = elim_cands[i]; + hash_map memo; + CountOtherVarsRec(memo,cand.val,i,cand.sup); + if(cand.sup == 0) + ready_cands.push_back(i); + } + + while(!ready_cands.empty()){ + elim_cand &cand = elim_cands[ready_cands.back()]; + ready_cands.pop_back(); + Term rep = FindRep(cand.var); + if(!eq(rep,cand.var)) + if(cand_map.find(rep) != cand_map.end()){ + int rep_pos = cand_map[rep]; + cand.val = elim_cands[rep_pos].val; + } + Term val = SubstRec(elim_map,cand.val); + if(debug_gauss){ + std::cout << "subbing " << cand.var << " --> " << val << std::endl; + } + elim_map[cand.var] = val; + std::vector &sup = sup_map[cand.var]; + for(unsigned i = 0; i < sup.size(); i++){ + int c = sup[i]; + if((--elim_cands[c].sup) == 0) + ready_cands.push_back(c); + } + } + + for(unsigned i = 0; i < lits_in.size(); i++){ + Term lit = lits_in[i]; + lit = SubstRec(elim_map,lit); + lit = lit.simplify(); + if(eq(lit,ctx.bool_val(true))) + continue; + Term a; + if(IsPropLit(lit,a)) + if(keep.find(lit) == keep.end()) + continue; + lits_out.push_back(lit); + } + } + + // maps variables to constrains in which the occur pos, neg + hash_map la_index[2]; + hash_map la_coeffs[2]; + std::vector la_pos_vars; + bool fixing; + + void IndexLAcoeff(const Term &coeff1, const Term &coeff2, Term t, int id){ + Term coeff = coeff1 * coeff2; + coeff = coeff.simplify(); + Term is_pos = (coeff >= ctx.int_val(0)); + is_pos = is_pos.simplify(); + if(eq(is_pos,ctx.bool_val(true))) + IndexLA(true,coeff,t, id); + else + IndexLA(false,coeff,t, id); + } + + void IndexLAremove(const Term &t){ + if(IsVar(t)){ + la_index[0][t] = -1; // means ineligible to be eliminated + la_index[1][t] = -1; // (more that one occurrence, or occurs not in linear comb) + } + else if(t.is_app()){ + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++) + IndexLAremove(t.arg(i)); + } + // TODO: quantifiers? + } + + + void IndexLA(bool pos, const Term &coeff, const Term &t, int id){ + if(t.is_numeral()) + return; + if(t.is_app()){ + int nargs = t.num_args(); + switch(t.decl().get_decl_kind()){ + case Plus: + for(int i = 0; i < nargs; i++) + IndexLA(pos,coeff,t.arg(i), id); + break; + case Sub: + IndexLA(pos,coeff,t.arg(0), id); + IndexLA(!pos,coeff,t.arg(1), id); + break; + case Times: + if(t.arg(0).is_numeral()) + IndexLAcoeff(coeff,t.arg(0),t.arg(1), id); + else if(t.arg(1).is_numeral()) + IndexLAcoeff(coeff,t.arg(1),t.arg(0), id); + break; + default: + if(IsVar(t) && (fixing || la_index[pos].find(t) == la_index[pos].end())){ + la_index[pos][t] = id; + la_coeffs[pos][t] = coeff; + if(pos && !fixing) + la_pos_vars.push_back(t); // this means we only add a var once + } + else + IndexLAremove(t); + } + } + } + + void IndexLAstart(bool pos, const Term &t, int id){ + IndexLA(pos,(pos ? ctx.int_val(1) : ctx.int_val(-1)), t, id); + } + + void IndexLApred(bool pos, const Term &p, int id){ + if(p.is_app()){ + switch(p.decl().get_decl_kind()){ + case Not: + IndexLApred(!pos, p.arg(0),id); + break; + case Leq: + case Lt: + IndexLAstart(!pos, p.arg(0), id); + IndexLAstart(pos, p.arg(1), id); + break; + case Geq: + case Gt: + IndexLAstart(pos,p.arg(0), id); + IndexLAstart(!pos,p.arg(1), id); + break; + default: + IndexLAremove(p); + } + } + } + + void IndexLAfix(const Term &p, int id){ + fixing = true; + IndexLApred(true,p,id); + fixing = false; + } + + bool IsCanonIneq(const Term &lit, Term &term, Term &bound){ + // std::cout << Z3_simplify_get_help(ctx) << std::endl; + bool pos = lit.decl().get_decl_kind() != Not; + Term atom = pos ? lit : lit.arg(0); + if(atom.decl().get_decl_kind() == Leq){ + if(pos){ + bound = atom.arg(0); + term = atom.arg(1).simplify(simp_params); +#if Z3_MAJOR_VERSION < 4 + term = SortSum(term); +#endif + } + else { + bound = (atom.arg(1) + ctx.int_val(1)); + term = atom.arg(0); + // std::cout << "simplifying bound: " << bound << std::endl; + bound = bound.simplify(); + term = term.simplify(simp_params); +#if Z3_MAJOR_VERSION < 4 + term = SortSum(term); +#endif + } + return true; + } + else if(atom.decl().get_decl_kind() == Geq){ + if(pos){ + bound = atom.arg(1); // integer axiom + term = atom.arg(0).simplify(simp_params); +#if Z3_MAJOR_VERSION < 4 + term = SortSum(term); +#endif + return true; + } + else{ + bound = -(atom.arg(1) - ctx.int_val(1)); // integer axiom + term = -atom.arg(0); + bound = bound.simplify(); + term = term.simplify(simp_params); +#if Z3_MAJOR_VERSION < 4 + term = SortSum(term); +#endif + } + return true; + } + return false; + } + + Term CanonIneqTerm(const Term &p){ + Term term,bound; + Term ps = p.simplify(); + bool ok = IsCanonIneq(ps,term,bound); + assert(ok); + return term - bound; + } + + void ElimRedundantBounds(std::vector &lits){ + hash_map best_bound; + for(unsigned i = 0; i < lits.size(); i++){ + lits[i] = lits[i].simplify(simp_params); + Term term,bound; + if(IsCanonIneq(lits[i],term,bound)){ + if(best_bound.find(term) == best_bound.end()) + best_bound[term] = i; + else { + int best = best_bound[term]; + Term bterm,bbound; + IsCanonIneq(lits[best],bterm,bbound); + Term comp = bound > bbound; + comp = comp.simplify(); + if(eq(comp,ctx.bool_val(true))){ + lits[best] = ctx.bool_val(true); + best_bound[term] = i; + } + else { + lits[i] = ctx.bool_val(true); + } + } + } + } + } + + void FourierMotzkinCheap(const std::vector &lits_in, + std::vector &lits_out){ + simp_params.set(":som",true); + simp_params.set(":sort-sums",true); + fixing = false; lits_out = lits_in; + ElimRedundantBounds(lits_out); + for(unsigned i = 0; i < lits_out.size(); i++) + IndexLApred(true,lits_out[i],i); + + for(unsigned i = 0; i < la_pos_vars.size(); i++){ + Term var = la_pos_vars[i]; + if(la_index[false].find(var) != la_index[false].end()){ + int pos_idx = la_index[true][var]; + int neg_idx = la_index[false][var]; + if(pos_idx >= 0 && neg_idx >= 0){ + if(keep.find(var) != keep.end()){ + std::cout << "would have eliminated keep var\n"; + continue; + } + Term tpos = CanonIneqTerm(lits_out[pos_idx]); + Term tneg = CanonIneqTerm(lits_out[neg_idx]); + Term pos_coeff = la_coeffs[true][var]; + Term neg_coeff = -la_coeffs[false][var]; + Term comb = neg_coeff * tpos + pos_coeff * tneg; + Term ineq = ctx.int_val(0) <= comb; + ineq = ineq.simplify(); + lits_out[pos_idx] = ineq; + lits_out[neg_idx] = ctx.bool_val(true); + IndexLAfix(ineq,pos_idx); + } + } + } + } + + Term ProjectFormula(const Term &f){ + std::vector lits, new_lits1, new_lits2; + CollectConjuncts(f,lits); + timer_start("GaussElimCheap"); + GaussElimCheap(lits,new_lits1); + timer_stop("GaussElimCheap"); + timer_start("FourierMotzkinCheap"); + FourierMotzkinCheap(new_lits1,new_lits2); + timer_stop("FourierMotzkinCheap"); + return conjoin(new_lits2); + } + }; + + void Z3User::CollectConjuncts(const Term &f, std::vector &lits, bool pos){ + if(f.is_app() && f.decl().get_decl_kind() == Not) + CollectConjuncts(f.arg(0), lits, !pos); + else if(pos && f.is_app() && f.decl().get_decl_kind() == And){ + int num_args = f.num_args(); + for(int i = 0; i < num_args; i++) + CollectConjuncts(f.arg(i),lits,true); + } + else if(!pos && f.is_app() && f.decl().get_decl_kind() == Or){ + int num_args = f.num_args(); + for(int i = 0; i < num_args; i++) + CollectConjuncts(f.arg(i),lits,false); + } + else if(pos) + lits.push_back(f); + else + lits.push_back(!f); + } + + struct TermLt { + bool operator()(const expr &x, const expr &y){ + unsigned xid = x.get_id(); + unsigned yid = y.get_id(); + return xid < yid; + } + }; + + void Z3User::SortTerms(std::vector &terms){ + TermLt foo; + std::sort(terms.begin(),terms.end(),foo); + } + + Z3User::Term Z3User::SortSum(const Term &t){ + if(!(t.is_app() && t.decl().get_decl_kind() == Plus)) + return t; + int nargs = t.num_args(); + if(nargs < 2) return t; + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = t.arg(i); + SortTerms(args); + if(nargs == 2) + return args[0] + args[1]; + return sum(args); + } + + + RPFP::Term RPFP::ProjectFormula(std::vector &keep_vec, const Term &f){ + VariableProjector vp(*this,keep_vec); + return vp.ProjectFormula(f); + } + + /** Compute an underapproximation of every node in a tree rooted at "root", + based on a previously computed counterexample. The underapproximation + may contain free variables that are implicitly existentially quantified. + */ + + RPFP::Term RPFP::ComputeUnderapprox(Node *root, int persist){ + /* if terminated underapprox is empty set (false) */ + bool show_model = false; + if(show_model) + std::cout << dualModel << std::endl; + if(!root->Outgoing){ + root->Underapprox.SetEmpty(); + return ctx.bool_val(true); + } + /* if not used in cex, underapprox is empty set (false) */ + if(Empty(root)){ + root->Underapprox.SetEmpty(); + return ctx.bool_val(true); + } + /* compute underapprox of children first */ + std::vector &chs = root->Outgoing->Children; + std::vector chu(chs.size()); + for(unsigned i = 0; i < chs.size(); i++) + chu[i] = ComputeUnderapprox(chs[i],persist); + + Term b; std::vector v; + RedVars(root, b, v); + /* underapproximate the edge formula */ + hash_set dont_cares; + dont_cares.insert(b); + resolve_ite_memo.clear(); + timer_start("UnderapproxFormula"); + Term dual = root->Outgoing->dual.null() ? ctx.bool_val(true) : root->Outgoing->dual; + Term eu = UnderapproxFormula(dual,dont_cares); + timer_stop("UnderapproxFormula"); + /* combine with children */ + chu.push_back(eu); + eu = conjoin(chu); + /* project onto appropriate variables */ + eu = ProjectFormula(v,eu); + eu = eu.simplify(); + +#if 0 + /* check the result is consistent */ + { + hash_map memo; + int res = SubtermTruth(memo, eu); + if(res != 1) + throw "inconsistent projection"; + } +#endif + + /* rename to appropriate variable names */ + hash_map memo; + for (unsigned i = 0; i < v.size(); i++) + memo[v[i]] = root->Annotation.IndParams[i]; /* copy names from annotation */ + Term funder = SubstRec(memo, eu); + root->Underapprox = CreateRelation(root->Annotation.IndParams,funder); +#if 0 + if(persist) + Z3_persist_ast(ctx,root->Underapprox.Formula,persist); +#endif + return eu; + } + + void RPFP::FixCurrentState(Edge *edge){ + hash_set dont_cares; + resolve_ite_memo.clear(); + timer_start("UnderapproxFormula"); + Term dual = edge->dual.null() ? ctx.bool_val(true) : edge->dual; + Term eu = UnderapproxFormula(dual,dont_cares); + timer_stop("UnderapproxFormula"); + ConstrainEdgeLocalized(edge,eu); + } + + void RPFP::FixCurrentStateFull(Edge *edge, const expr &extra){ + hash_set dont_cares; + resolve_ite_memo.clear(); + timer_start("UnderapproxFormula"); + Term dual = edge->dual.null() ? ctx.bool_val(true) : edge->dual; + for(unsigned i = 0; i < edge->constraints.size(); i++) + dual = dual && edge->constraints[i]; + // dual = dual && extra; + Term eu = UnderapproxFullFormula(dual,dont_cares); + timer_stop("UnderapproxFormula"); + ConstrainEdgeLocalized(edge,eu); + } + + + + + void RPFP::GetGroundLitsUnderQuants(hash_set *memo, const Term &f, std::vector &res, int under){ + if(memo[under].find(f) != memo[under].end()) + return; + memo[under].insert(f); + if(f.is_app()){ + if(!under && !f.has_quantifiers()) + return; + decl_kind k = f.decl().get_decl_kind(); + if(k == And || k == Or || k == Implies || k == Iff){ + int num_args = f.num_args(); + for(int i = 0; i < num_args; i++) + GetGroundLitsUnderQuants(memo,f.arg(i),res,under); + return; + } + } + else if (f.is_quantifier()){ +#if 0 + // treat closed quantified formula as a literal 'cause we hate nested quantifiers + if(under && IsClosedFormula(f)) + res.push_back(f); + else +#endif + GetGroundLitsUnderQuants(memo,f.body(),res,1); + return; + } + if(under && f.is_ground()) + res.push_back(f); + } + + RPFP::Term RPFP::StrengthenFormulaByCaseSplitting(const Term &f, std::vector &case_lits){ + hash_set memo[2]; + std::vector lits; + GetGroundLitsUnderQuants(memo, f, lits, 0); + hash_set lits_hash; + for(unsigned i = 0; i < lits.size(); i++) + lits_hash.insert(lits[i]); + hash_map subst; + hash_map stt_memo; + std::vector conjuncts; + for(unsigned i = 0; i < lits.size(); i++){ + const expr &lit = lits[i]; + if(lits_hash.find(NegateLit(lit)) == lits_hash.end()){ + case_lits.push_back(lit); + bool tval = false; + expr atom = lit; + if(lit.is_app() && lit.decl().get_decl_kind() == Not){ + tval = true; + atom = lit.arg(0); + } + expr etval = ctx.bool_val(tval); + if(atom.is_quantifier()) + subst[atom] = etval; // this is a bit desperate, since we can't eval quants + else { + int b = SubtermTruth(stt_memo,atom); + if(b == (tval ? 1 : 0)) + subst[atom] = etval; + else { + if(b == 0 || b == 1){ + etval = ctx.bool_val(b ? true : false); + subst[atom] = etval; + conjuncts.push_back(b ? atom : !atom); + } + } + } + } + } + expr g = f; + if(!subst.empty()){ + g = SubstRec(subst,f); + if(conjuncts.size()) + g = g && ctx.make(And,conjuncts); + g = g.simplify(); + } +#if 1 + expr g_old = g; + g = RemoveRedundancy(g); + bool changed = !eq(g,g_old); + g = g.simplify(); + if(changed) { // a second pass can get some more simplification + g = RemoveRedundancy(g); + g = g.simplify(); + } +#else + g = RemoveRedundancy(g); + g = g.simplify(); +#endif + g = AdjustQuantifiers(g); + return g; + } + + RPFP::Term RPFP::ModelValueAsConstraint(const Term &t){ + if(t.is_array()){ + ArrayValue arr; + Term e = dualModel.eval(t); + EvalArrayTerm(e, arr); + if(arr.defined){ + std::vector cs; + for(std::map::iterator it = arr.entries.begin(), en = arr.entries.end(); it != en; ++it){ + expr foo = select(t,expr(ctx,it->first)) == expr(ctx,it->second); + cs.push_back(foo); + } + return conjoin(cs); + } + } + else { + expr r = dualModel.get_const_interp(t.decl()); + if(!r.null()){ + expr res = t == expr(ctx,r); + return res; + } + } + return ctx.bool_val(true); + } + + void RPFP::EvalNodeAsConstraint(Node *p, Transformer &res) + { + Term b; std::vector v; + RedVars(p, b, v); + std::vector args; + for(unsigned i = 0; i < v.size(); i++){ + expr val = ModelValueAsConstraint(v[i]); + if(!eq(val,ctx.bool_val(true))) + args.push_back(val); + } + expr cnst = conjoin(args); + hash_map memo; + for (unsigned i = 0; i < v.size(); i++) + memo[v[i]] = p->Annotation.IndParams[i]; /* copy names from annotation */ + Term funder = SubstRec(memo, cnst); + res = CreateRelation(p->Annotation.IndParams,funder); + } + +#if 0 + void RPFP::GreedyReduce(solver &s, std::vector &conjuncts){ + // verify + s.push(); + expr conj = ctx.make(And,conjuncts); + s.add(conj); + check_result res = s.check(); + if(res != unsat) + throw "should be unsat"; + s.pop(1); + + for(unsigned i = 0; i < conjuncts.size(); ){ + std::swap(conjuncts[i],conjuncts.back()); + expr save = conjuncts.back(); + conjuncts.pop_back(); + s.push(); + expr conj = ctx.make(And,conjuncts); + s.add(conj); + check_result res = s.check(); + s.pop(1); + if(res != unsat){ + conjuncts.push_back(save); + std::swap(conjuncts[i],conjuncts.back()); + i++; + } + } + } +#endif + + void RPFP::GreedyReduce(solver &s, std::vector &conjuncts){ + std::vector lits(conjuncts.size()); + for(unsigned i = 0; i < lits.size(); i++){ + func_decl pred = ctx.fresh_func_decl("@alit", ctx.bool_sort()); + lits[i] = pred(); + s.add(ctx.make(Implies,lits[i],conjuncts[i])); + } + + // verify + check_result res = s.check(lits.size(),&lits[0]); + if(res != unsat){ + // add the axioms in the off chance they are useful + const std::vector &theory = ls->get_axioms(); + for(unsigned i = 0; i < theory.size(); i++) + s.add(theory[i]); + if(s.check(lits.size(),&lits[0]) != unsat) + throw "should be unsat"; + } + + for(unsigned i = 0; i < conjuncts.size(); ){ + std::swap(conjuncts[i],conjuncts.back()); + std::swap(lits[i],lits.back()); + check_result res = s.check(lits.size()-1,&lits[0]); + if(res != unsat){ + std::swap(conjuncts[i],conjuncts.back()); + std::swap(lits[i],lits.back()); + i++; + } + else { + conjuncts.pop_back(); + lits.pop_back(); + } + } + } + + expr RPFP::NegateLit(const expr &f){ + if(f.is_app() && f.decl().get_decl_kind() == Not) + return f.arg(0); + else + return !f; + } + + void RPFP::NegateLits(std::vector &lits){ + for(unsigned i = 0; i < lits.size(); i++){ + expr &f = lits[i]; + if(f.is_app() && f.decl().get_decl_kind() == Not) + f = f.arg(0); + else + f = !f; + } + } + + expr RPFP::SimplifyOr(std::vector &lits){ + if(lits.size() == 0) + return ctx.bool_val(false); + if(lits.size() == 1) + return lits[0]; + return ctx.make(Or,lits); + } + + // set up edge constraint in aux solver + void RPFP::AddEdgeToSolver(Edge *edge){ + if(!edge->dual.null()) + aux_solver.add(edge->dual); + for(unsigned i = 0; i < edge->constraints.size(); i++){ + expr tl = edge->constraints[i]; + aux_solver.add(tl); + } + } + + void RPFP::InterpolateByCases(Node *root, Node *node){ + aux_solver.push(); + AddEdgeToSolver(node->Outgoing); + node->Annotation.SetEmpty(); + hash_set *core = new hash_set; + core->insert(node->Outgoing->dual); + while(1){ + aux_solver.push(); + expr annot = !GetAnnotation(node); + aux_solver.add(annot); + if(aux_solver.check() == unsat){ + aux_solver.pop(1); + break; + } + dualModel = aux_solver.get_model(); + aux_solver.pop(1); + Push(); + FixCurrentStateFull(node->Outgoing,annot); + ConstrainEdgeLocalized(node->Outgoing,!GetAnnotation(node)); + check_result foo = Check(root); + if(foo != unsat) + throw "should be unsat"; + AddToProofCore(*core); + Transformer old_annot = node->Annotation; + SolveSingleNode(root,node); + + { + expr itp = GetAnnotation(node); + dualModel = aux_solver.get_model(); + std::vector case_lits; + itp = StrengthenFormulaByCaseSplitting(itp, case_lits); + SetAnnotation(node,itp); + } + + if(node->Annotation.IsEmpty()){ + std::cout << "bad in InterpolateByCase -- core:\n"; + std::vector assumps; + slvr.get_proof().get_assumptions(assumps); + for(unsigned i = 0; i < assumps.size(); i++) + assumps[i].show(); + throw "ack!"; + } + Pop(1); + node->Annotation.UnionWith(old_annot); + } + if(proof_core) + delete proof_core; // shouldn't happen + proof_core = core; + aux_solver.pop(1); + } + + void RPFP::Generalize(Node *root, Node *node){ + aux_solver.push(); + AddEdgeToSolver(node->Outgoing); + expr fmla = GetAnnotation(node); + std::vector conjuncts; + CollectConjuncts(fmla,conjuncts,false); + GreedyReduce(aux_solver,conjuncts); // try to remove conjuncts one at a tme + aux_solver.pop(1); + NegateLits(conjuncts); + SetAnnotation(node,SimplifyOr(conjuncts)); + } + + /** Push a scope. Assertions made after Push can be undone by Pop. */ + + void RPFP::Push() + { + stack.push_back(stack_entry()); + slvr.push(); + } + + /** Pop a scope (see Push). Note, you cannot pop axioms. */ + + void RPFP::Pop(int num_scopes) + { + slvr.pop(num_scopes); + for (int i = 0; i < num_scopes; i++) + { + stack_entry &back = stack.back(); + for(std::list::iterator it = back.edges.begin(), en = back.edges.end(); it != en; ++it) + (*it)->dual = expr(ctx,NULL); + for(std::list::iterator it = back.nodes.begin(), en = back.nodes.end(); it != en; ++it) + (*it)->dual = expr(ctx,NULL); + for(std::list >::iterator it = back.constraints.begin(), en = back.constraints.end(); it != en; ++it) + (*it).first->constraints.pop_back(); + stack.pop_back(); + } + } + + /** Erase the proof by performing a Pop, Push and re-assertion of + all the popped constraints */ + + void RPFP::PopPush(){ + slvr.pop(1); + slvr.push(); + stack_entry &back = stack.back(); + for(std::list::iterator it = back.edges.begin(), en = back.edges.end(); it != en; ++it) + slvr.add((*it)->dual); + for(std::list::iterator it = back.nodes.begin(), en = back.nodes.end(); it != en; ++it) + slvr.add((*it)->dual); + for(std::list >::iterator it = back.constraints.begin(), en = back.constraints.end(); it != en; ++it) + slvr.add((*it).second); + } + + + + // This returns a new FuncDel with same sort as top-level function + // of term t, but with numeric suffix appended to name. + + Z3User::FuncDecl Z3User::SuffixFuncDecl(Term t, int n) + { + std::string name = t.decl().name().str() + "_" + string_of_int(n); + std::vector sorts; + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++) + sorts.push_back(t.arg(i).get_sort()); + return ctx.function(name.c_str(), nargs, &sorts[0], t.get_sort()); + } + + // Scan the clause body for occurrences of the predicate unknowns + + RPFP::Term RPFP::ScanBody(hash_map &memo, + const Term &t, + hash_map &pmap, + std::vector &parms, + std::vector &nodes) + { + if(memo.find(t) != memo.end()) + return memo[t]; + Term res(ctx); + if (t.is_app()) { + func_decl f = t.decl(); + if(pmap.find(f) != pmap.end()){ + nodes.push_back(pmap[f]); + f = SuffixFuncDecl(t,parms.size()); + parms.push_back(f); + } + int nargs = t.num_args(); + std::vector args; + for(int i = 0; i < nargs; i++) + args.push_back(ScanBody(memo,t.arg(i),pmap,parms,nodes)); + res = f(nargs,&args[0]); + } + else if (t.is_quantifier()) + res = CloneQuantifier(t,ScanBody(memo,t.body(),pmap,parms,nodes)); + else + res = t; + memo[t] = res; + return res; + } + + // return the func_del of an app if it is uninterpreted + + bool Z3User::get_relation(const Term &t, func_decl &R){ + if(!t.is_app()) + return false; + R = t.decl(); + return R.get_decl_kind() == Uninterpreted; + } + + // return true if term is an individual variable + // TODO: have to check that it is not a background symbol + + bool Z3User::is_variable(const Term &t){ + if(!t.is_app()) + return false; + return t.decl().get_decl_kind() == Uninterpreted + && t.num_args() == 0; + } + + RPFP::Term RPFP::RemoveLabelsRec(hash_map &memo, const Term &t, + std::vector &lbls){ + if(memo.find(t) != memo.end()) + return memo[t]; + Term res(ctx); + if (t.is_app()){ + func_decl f = t.decl(); + std::vector names; + bool pos; + if(t.is_label(pos,names)){ + res = RemoveLabelsRec(memo,t.arg(0),lbls); + for(unsigned i = 0; i < names.size(); i++) + lbls.push_back(label_struct(names[i],res,pos)); + } + else { + int nargs = t.num_args(); + std::vector args; + for(int i = 0; i < nargs; i++) + args.push_back(RemoveLabelsRec(memo,t.arg(i),lbls)); + res = f(nargs,&args[0]); + } + } + else if (t.is_quantifier()) + res = CloneQuantifier(t,RemoveLabelsRec(memo,t.body(),lbls)); + else + res = t; + memo[t] = res; + return res; + } + + RPFP::Term RPFP::RemoveLabels(const Term &t, std::vector &lbls){ + hash_map memo ; + return RemoveLabelsRec(memo,t,lbls); + } + + RPFP::Term RPFP::SubstBoundRec(hash_map > &memo, hash_map &subst, int level, const Term &t) + { + std::pair foo(t,expr(ctx)); + std::pair::iterator, bool> bar = memo[level].insert(foo); + Term &res = bar.first->second; + if(!bar.second) return res; + if (t.is_app()) + { + func_decl f = t.decl(); + std::vector args; + int nargs = t.num_args(); + if(nargs == 0) + ls->declare_constant(f); // keep track of background constants + for(int i = 0; i < nargs; i++) + args.push_back(SubstBoundRec(memo, subst, level, t.arg(i))); + res = f(args.size(),&args[0]); + } + else if (t.is_quantifier()){ + int bound = t.get_quantifier_num_bound(); + std::vector pats; + t.get_patterns(pats); + for(unsigned i = 0; i < pats.size(); i++) + pats[i] = SubstBoundRec(memo, subst, level + bound, pats[i]); + res = clone_quantifier(t, SubstBoundRec(memo, subst, level + bound, t.body()), pats); + } + else if (t.is_var()) { + int idx = t.get_index_value(); + if(idx >= level && subst.find(idx-level) != subst.end()){ + res = subst[idx-level]; + } + else res = t; + } + else res = t; + return res; + } + + RPFP::Term RPFP::SubstBound(hash_map &subst, const Term &t){ + hash_map > memo; + return SubstBoundRec(memo, subst, 0, t); + } + + int Z3User::MaxIndex(hash_map &memo, const Term &t) + { + std::pair foo(t,-1); + std::pair::iterator, bool> bar = memo.insert(foo); + int &res = bar.first->second; + if(!bar.second) return res; + if (t.is_app()){ + func_decl f = t.decl(); + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++){ + int m = MaxIndex(memo, t.arg(i)); + if(m > res) + res = m; + } + } + else if (t.is_quantifier()){ + int bound = t.get_quantifier_num_bound(); + res = MaxIndex(memo,t.body()) - bound; + } + else if (t.is_var()) { + res = t.get_index_value(); + } + return res; + } + + bool Z3User::IsClosedFormula(const Term &t){ + hash_map memo; + return MaxIndex(memo,t) < 0; + } + + + /** Convert a collection of clauses to Nodes and Edges in the RPFP. + + Predicate unknowns are uninterpreted predicates not + occurring in the background theory. + + Clauses are of the form + + B => P(t_1,...,t_k) + + where P is a predicate unknown and predicate unknowns + occur only positivey in H and only under existential + quantifiers in prenex form. + + Each predicate unknown maps to a node. Each clause maps to + an edge. Let C be a clause B => P(t_1,...,t_k) where the + sequence of predicate unknowns occurring in B (in order + of occurrence) is P_1..P_n. The clause maps to a transformer + T where: + + T.Relparams = P_1..P_n + T.Indparams = x_1...x+k + T.Formula = B /\ t_1 = x_1 /\ ... /\ t_k = x_k + + Throws exception bad_clause(msg,i) if a clause i is + in the wrong form. + + */ + + + static bool canonical_clause(const expr &clause){ + if(clause.decl().get_decl_kind() != Implies) + return false; + expr arg = clause.arg(1); + return arg.is_app() && (arg.decl().get_decl_kind() == False || + arg.decl().get_decl_kind() == Uninterpreted); + } + + void RPFP::FromClauses(const std::vector &unskolemized_clauses){ + hash_map pmap; + func_decl fail_pred = ctx.fresh_func_decl("@Fail", ctx.bool_sort()); + + std::vector clauses(unskolemized_clauses); + // first, skolemize the clauses + + for(unsigned i = 0; i < clauses.size(); i++){ + expr &t = clauses[i]; + if (t.is_quantifier() && t.is_quantifier_forall()) { + int bound = t.get_quantifier_num_bound(); + std::vector sorts; + std::vector names; + hash_map subst; + for(int j = 0; j < bound; j++){ + sort the_sort = t.get_quantifier_bound_sort(j); + symbol name = t.get_quantifier_bound_name(j); + expr skolem = ctx.constant(symbol(ctx,name),sort(ctx,the_sort)); + subst[bound-1-j] = skolem; + } + t = SubstBound(subst,t.body()); + } + } + + // create the nodes from the heads of the clauses + + for(unsigned i = 0; i < clauses.size(); i++){ + Term &clause = clauses[i]; + if(clause.is_app() && clause.decl().get_decl_kind() == Uninterpreted) + clause = implies(ctx.bool_val(true),clause); + if(!canonical_clause(clause)) + clause = implies((!clause).simplify(),ctx.bool_val(false)); + Term head = clause.arg(1); + func_decl R(ctx); + bool is_query = false; + if (eq(head,ctx.bool_val(false))){ + R = fail_pred; + // R = ctx.constant("@Fail", ctx.bool_sort()).decl(); + is_query = true; + } + else if(!get_relation(head,R)) + throw bad_clause("rhs must be a predicate application",i); + if(pmap.find(R) == pmap.end()){ + + // If the node doesn't exitst, create it. The Indparams + // are arbitrary, but we use the rhs arguments if they + // are variables for mnomonic value. + + hash_set seen; + std::vector Indparams; + for(unsigned j = 0; j < head.num_args(); j++){ + Term arg = head.arg(j); + if(!is_variable(arg) || seen.find(arg) != seen.end()){ + std::string name = std::string("@a_") + string_of_int(j); + arg = ctx.constant(name.c_str(),arg.get_sort()); + } + seen.insert(arg); + Indparams.push_back(arg); + } + Node *node = CreateNode(R(Indparams.size(),&Indparams[0])); + //nodes.push_back(node); + pmap[R] = node; + if (is_query) + node->Bound = CreateRelation(std::vector(), ctx.bool_val(false)); + } + } + + // create the edges + + for(unsigned i = 0; i < clauses.size(); i++){ + Term clause = clauses[i]; + Term body = clause.arg(0); + Term head = clause.arg(1); + func_decl R(ctx); + if (eq(head,ctx.bool_val(false))) + R = fail_pred; + //R = ctx.constant("@Fail", ctx.bool_sort()).decl(); + else get_relation(head,R); + Node *Parent = pmap[R]; + std::vector Indparams; + hash_set seen; + for(unsigned j = 0; j < head.num_args(); j++){ + Term arg = head.arg(j); + if(!is_variable(arg) || seen.find(arg) != seen.end()){ + std::string name = std::string("@a_") + string_of_int(j); + Term var = ctx.constant(name.c_str(),arg.get_sort()); + body = body && (arg == var); + arg = var; + } + seen.insert(arg); + Indparams.push_back(arg); + } + + // We extract the relparams positionally + + std::vector Relparams; + hash_map scan_memo; + std::vector Children; + body = ScanBody(scan_memo,body,pmap,Relparams,Children); + Term labeled = body; + std::vector lbls; // TODO: throw this away for now + body = RemoveLabels(body,lbls); + body = body.simplify(); + + // Create the edge + Transformer T = CreateTransformer(Relparams,Indparams,body); + Edge *edge = CreateEdge(Parent,T,Children); + edge->labeled = labeled;; // remember for label extraction + // edges.push_back(edge); + } + } + + + + void RPFP::WriteSolution(std::ostream &s){ + for(unsigned i = 0; i < nodes.size(); i++){ + Node *node = nodes[i]; + Term asgn = (node->Name)(node->Annotation.IndParams) == node->Annotation.Formula; + s << asgn << std::endl; + } + } + + void RPFP::WriteEdgeVars(Edge *e, hash_map &memo, const Term &t, std::ostream &s) + { + std::pair foo(t,0); + std::pair::iterator, bool> bar = memo.insert(foo); + // int &res = bar.first->second; + if(!bar.second) return; + hash_map::iterator it = e->varMap.find(t); + if (it != e->varMap.end()){ + return; + } + if (t.is_app()) + { + func_decl f = t.decl(); + // int idx; + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++) + WriteEdgeVars(e, memo, t.arg(i),s); + if (nargs == 0 && f.get_decl_kind() == Uninterpreted && !ls->is_constant(f)){ + Term rename = HideVariable(t,e->number); + Term value = dualModel.eval(rename); + s << " (= " << t << " " << value << ")\n"; + } + } + else if (t.is_quantifier()) + WriteEdgeVars(e,memo,t.body(),s); + return; + } + + void RPFP::WriteEdgeAssignment(std::ostream &s, Edge *e){ + s << "(\n"; + hash_map memo; + WriteEdgeVars(e, memo, e->F.Formula ,s); + s << ")\n"; + } + + void RPFP::WriteCounterexample(std::ostream &s, Node *node){ + for(unsigned i = 0; i < node->Outgoing->Children.size(); i++){ + Node *child = node->Outgoing->Children[i]; + if(!Empty(child)) + WriteCounterexample(s,child); + } + s << "(" << node->number << " : " << EvalNode(node) << " <- "; + for(unsigned i = 0; i < node->Outgoing->Children.size(); i++){ + Node *child = node->Outgoing->Children[i]; + if(!Empty(child)) + s << " " << node->Outgoing->Children[i]->number; + } + s << ")" << std::endl; + WriteEdgeAssignment(s,node->Outgoing); + } + + RPFP::Term RPFP::ToRuleRec(Edge *e, hash_map &memo, const Term &t, std::vector &quants) + { + std::pair foo(t,expr(ctx)); + std::pair::iterator, bool> bar = memo.insert(foo); + Term &res = bar.first->second; + if(!bar.second) return res; + if (t.is_app()) + { + func_decl f = t.decl(); + // int idx; + std::vector args; + int nargs = t.num_args(); + for(int i = 0; i < nargs; i++) + args.push_back(ToRuleRec(e, memo, t.arg(i),quants)); + hash_map::iterator rit = e->relMap.find(f); + if(rit != e->relMap.end()){ + Node* child = e->Children[rit->second]; + FuncDecl op = child->Name; + res = op(args.size(),&args[0]); + } + else { + res = f(args.size(),&args[0]); + if(nargs == 0 && t.decl().get_decl_kind() == Uninterpreted) + quants.push_back(t); + } + } + else if (t.is_quantifier()) + { + Term body = ToRuleRec(e,memo,t.body(),quants); + res = CloneQuantifier(t,body); + } + else res = t; + return res; + } + + + void RPFP::ToClauses(std::vector &cnsts, FileFormat format){ + cnsts.resize(edges.size()); + for(unsigned i = 0; i < edges.size(); i++){ + Edge *edge = edges[i]; + SetEdgeMaps(edge); + std::vector quants; + hash_map memo; + Term lhs = ToRuleRec(edge, memo, edge->F.Formula,quants); + Term rhs = (edge->Parent->Name)(edge->F.IndParams.size(),&edge->F.IndParams[0]); + for(unsigned j = 0; j < edge->F.IndParams.size(); j++) + ToRuleRec(edge,memo,edge->F.IndParams[j],quants); // just to get quants + Term cnst = implies(lhs,rhs); +#if 0 + for(unsigned i = 0; i < quants.size(); i++){ + std::cout << expr(ctx,(Z3_ast)quants[i]) << " : " << sort(ctx,Z3_get_sort(ctx,(Z3_ast)quants[i])) << std::endl; + } +#endif + if(format != DualityFormat) + cnst= forall(quants,cnst); + cnsts[i] = cnst; + } + // int num_rules = cnsts.size(); + + for(unsigned i = 0; i < nodes.size(); i++){ + Node *node = nodes[i]; + if(!node->Bound.IsFull()){ + Term lhs = (node->Name)(node->Bound.IndParams) && !node->Bound.Formula; + Term cnst = implies(lhs,ctx.bool_val(false)); + if(format != DualityFormat){ + std::vector quants; + for(unsigned j = 0; j < node->Bound.IndParams.size(); j++) + quants.push_back(node->Bound.IndParams[j]); + if(format == HornFormat) + cnst= exists(quants,!cnst); + else + cnst= forall(quants, cnst); + } + cnsts.push_back(cnst); + } + } + + } + + void RPFP::AddToProofCore(hash_set &core){ + std::vector assumps; + slvr.get_proof().get_assumptions(assumps); + for(unsigned i = 0; i < assumps.size(); i++) + core.insert(assumps[i]); + } + + void RPFP::ComputeProofCore(){ + if(!proof_core){ + proof_core = new hash_set; + AddToProofCore(*proof_core); + } + } + + bool RPFP::EdgeUsedInProof(Edge *edge){ + ComputeProofCore(); + if(!edge->dual.null() && proof_core->find(edge->dual) != proof_core->end()) + return true; + for(unsigned i = 0; i < edge->constraints.size(); i++) + if(proof_core->find(edge->constraints[i]) != proof_core->end()) + return true; + return false; + } + +RPFP::~RPFP(){ + ClearProofCore(); + for(unsigned i = 0; i < nodes.size(); i++) + delete nodes[i]; + for(unsigned i = 0; i < edges.size(); i++) + delete edges[i]; + } +} + +#if 0 +void show_ast(expr *a){ + std::cout << *a; +} +#endif diff --git a/src/duality/duality_solver.cpp b/src/duality/duality_solver.cpp new file mode 100644 index 000000000..13c839186 --- /dev/null +++ b/src/duality/duality_solver.cpp @@ -0,0 +1,2518 @@ +/*++ +Copyright (c) 2012 Microsoft Corporation + +Module Name: + + duality_solver.h + +Abstract: + + implements relational post-fixedpoint problem + (RPFP) solver + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + + +--*/ + +#include "duality.h" +#include "duality_profiling.h" + +#include +#include +#include +#include + +// TODO: make these official options or get rid of them + +#define NEW_CAND_SEL +// #define LOCALIZE_CONJECTURES +// #define CANDS_FROM_UPDATES +#define CANDS_FROM_COVER_FAIL +#define DEPTH_FIRST_EXPAND +#define MINIMIZE_CANDIDATES +// #define MINIMIZE_CANDIDATES_HARDER +#define BOUNDED +#define CHECK_CANDS_FROM_IND_SET +#define UNDERAPPROX_NODES +#define NEW_EXPAND +#define EARLY_EXPAND +// #define TOP_DOWN +// #define EFFORT_BOUNDED_STRAT +#define SKIP_UNDERAPPROX_NODES + + +namespace Duality { + + // TODO: must be a better place for this... + static char string_of_int_buffer[20]; + + static const char *string_of_int(int n){ + sprintf(string_of_int_buffer,"%d",n); + return string_of_int_buffer; + } + + /** Generic object for producing diagnostic output. */ + + class Reporter { + protected: + RPFP *rpfp; + public: + Reporter(RPFP *_rpfp){ + rpfp = _rpfp; + } + virtual void Extend(RPFP::Node *node){} + virtual void Update(RPFP::Node *node, const RPFP::Transformer &update){} + virtual void Bound(RPFP::Node *node){} + virtual void Expand(RPFP::Edge *edge){} + virtual void AddCover(RPFP::Node *covered, std::vector &covering){} + virtual void RemoveCover(RPFP::Node *covered, RPFP::Node *covering){} + virtual void Conjecture(RPFP::Node *node, const RPFP::Transformer &t){} + virtual void Forcing(RPFP::Node *covered, RPFP::Node *covering){} + virtual void Dominates(RPFP::Node *node, RPFP::Node *other){} + virtual void InductionFailure(RPFP::Edge *edge, const std::vector &children){} + virtual void UpdateUnderapprox(RPFP::Node *node, const RPFP::Transformer &update){} + virtual void Reject(RPFP::Edge *edge, const std::vector &Children){} + virtual void Message(const std::string &msg){} + virtual ~Reporter(){} + }; + + Reporter *CreateStdoutReporter(RPFP *rpfp); + + /** Object we throw in case of catastrophe. */ + + struct InternalError { + std::string msg; + InternalError(const std::string _msg) + : msg(_msg) {} + }; + + + /** This is the main solver. It takes anarbitrary (possibly cyclic) + RPFP and either annotates it with a solution, or returns a + counterexample derivation in the form of an embedd RPFP tree. */ + + class Duality : public Solver { + + public: + Duality(RPFP *_rpfp) + : ctx(_rpfp->ctx), + slvr(_rpfp->slvr), + nodes(_rpfp->nodes), + edges(_rpfp->edges) + { + rpfp = _rpfp; + reporter = 0; + heuristic = 0; + FullExpand = false; + NoConj = false; + FeasibleEdges = true; + UseUnderapprox = true; + Report = false; + StratifiedInlining = false; + RecursionBound = -1; + } + + typedef RPFP::Node Node; + typedef RPFP::Edge Edge; + + /** This struct represents a candidate for extending the + unwinding. It consists of an edge to instantiate + and a vector of children for the new instance. */ + + struct Candidate { + Edge *edge; std::vector + Children; + }; + + /** Comparison operator, allowing us to sort Nodes + by their number field. */ + + struct lnode + { + bool operator()(const Node* s1, const Node* s2) const + { + return s1->number < s2->number; + } + }; + + typedef std::set Unexpanded; // sorted set of Nodes + + /** This class provides a heuristic for expanding a derivation + tree. */ + + class Heuristic { + RPFP *rpfp; + + /** Heuristic score for unwinding nodes. Currently this + counts the number of updates. */ + struct score { + int updates; + score() : updates(0) {} + }; + hash_map scores; + + public: + Heuristic(RPFP *_rpfp){ + rpfp = _rpfp; + } + + virtual ~Heuristic(){} + + virtual void Update(RPFP::Node *node){ + scores[node].updates++; + } + + /** Heuristic choice of nodes to expand. Takes a set "choices" + and returns a subset "best". We currently choose the + nodes with the fewest updates. + */ +#if 0 + virtual void ChooseExpand(const std::set &choices, std::set &best){ + int best_score = INT_MAX; + for(std::set::iterator it = choices.begin(), en = choices.end(); it != en; ++it){ + Node *node = (*it)->map; + int score = scores[node].updates; + best_score = std::min(best_score,score); + } + for(std::set::iterator it = choices.begin(), en = choices.end(); it != en; ++it) + if(scores[(*it)->map].updates == best_score) + best.insert(*it); + } +#else + virtual void ChooseExpand(const std::set &choices, std::set &best, bool high_priority=false, bool best_only=false){ + if(high_priority) return; + int best_score = INT_MAX; + int worst_score = 0; + for(std::set::iterator it = choices.begin(), en = choices.end(); it != en; ++it){ + Node *node = (*it)->map; + int score = scores[node].updates; + best_score = std::min(best_score,score); + worst_score = std::max(worst_score,score); + } + int cutoff = best_only ? best_score : (best_score + (worst_score-best_score)/2); + for(std::set::iterator it = choices.begin(), en = choices.end(); it != en; ++it) + if(scores[(*it)->map].updates <= cutoff) + best.insert(*it); + } +#endif + + /** Called when done expanding a tree */ + virtual void Done() {} + }; + + + class Covering; // see below + + // These members represent the state of the algorithm. + + RPFP *rpfp; // the input RPFP + Reporter *reporter; // object for logging + Heuristic *heuristic; // expansion heuristic + context &ctx; // Z3 context + solver &slvr; // Z3 solver + std::vector &nodes; // Nodes of input RPFP + std::vector &edges; // Edges of input RPFP + std::vector leaves; // leaf nodes of unwinding (unused) + Unexpanded unexpanded; // unexpanded nodes + std::list candidates; // candidates for expansion + // maps children to edges in input RPFP + hash_map > edges_by_child; + // maps each node in input RPFP to its expanded instances + hash_map > insts_of_node; + // maps each node in input RPFP to all its instances + hash_map > all_of_node; + RPFP *unwinding; // the unwinding + Covering *indset; // proposed inductive subset + Counterexample cex; // counterexample + std::list to_expand; + hash_set updated_nodes; + hash_map underapprox_map; // maps underapprox nodes to the nodes they approximate + int last_decisions; + hash_set overapproxes; + +#ifdef BOUNDED + struct Counter { + int val; + Counter(){val = 0;} + }; + typedef std::map NodeToCounter; + hash_map back_edges; // counts of back edges +#endif + + /** Solve the problem. */ + virtual bool Solve(){ + reporter = Report ? CreateStdoutReporter(rpfp) : new Reporter(rpfp); +#ifndef LOCALIZE_CONJECTURES + heuristic = !cex.tree ? new Heuristic(rpfp) : new ReplayHeuristic(rpfp,cex); +#else + heuristic = !cex.tree ? (Heuristic *)(new LocalHeuristic(rpfp)) + : (Heuristic *)(new ReplayHeuristic(rpfp,cex)); +#endif + cex.tree = 0; // heuristic now owns it + unwinding = new RPFP(rpfp->ls); + unwinding->HornClauses = rpfp->HornClauses; + indset = new Covering(this); + last_decisions = 0; + CreateEdgesByChildMap(); + CreateLeaves(); +#ifndef TOP_DOWN + if(!StratifiedInlining){ + if(FeasibleEdges)NullaryCandidates(); + else InstantiateAllEdges(); + } +#else + for(unsigned i = 0; i < leaves.size(); i++) + if(!SatisfyUpperBound(leaves[i])) + return false; +#endif + StratifiedLeafCount = -1; + timer_start("SolveMain"); + bool res = SolveMain(); // does the actual work + timer_stop("SolveMain"); + // print_profile(std::cout); + delete indset; + delete heuristic; + delete unwinding; + delete reporter; + return res; + } + + void Cancel(){ + // TODO + } + +#if 0 + virtual void Restart(RPFP *_rpfp){ + rpfp = _rpfp; + delete unwinding; + nodes = _rpfp->nodes; + edges = _rpfp->edges; + leaves.clear(); + unexpanded.clear(); // unexpanded nodes + candidates.clear(); // candidates for expansion + edges_by_child.clear(); + insts_of_node.clear(); + all_of_node.clear(); + to_expand.clear(); + } +#endif + + virtual void LearnFrom(Counterexample &old_cex){ + cex = old_cex; + } + + /** Return the counterexample */ + virtual Counterexample GetCounterexample(){ + Counterexample res = cex; + cex.tree = 0; // Cex now belongs to caller + return res; + } + + // options + bool FullExpand; // do not use partial expansion of derivation tree + bool NoConj; // do not use conjectures (no forced covering) + bool FeasibleEdges; // use only feasible edges in unwinding + bool UseUnderapprox; // use underapproximations + bool Report; // spew on stdout + bool StratifiedInlining; // Do stratified inlining as preprocessing step + int RecursionBound; // Recursion bound for bounded verification + + bool SetBoolOption(bool &opt, const std::string &value){ + if(value == "0") { + opt = false; + return true; + } + if(value == "1") { + opt = true; + return true; + } + return false; + } + + bool SetIntOption(int &opt, const std::string &value){ + opt = atoi(value.c_str()); + return true; + } + + /** Set options (not currently used) */ + virtual bool SetOption(const std::string &option, const std::string &value){ + if(option == "full_expand"){ + return SetBoolOption(FullExpand,value); + } + if(option == "no_conj"){ + return SetBoolOption(NoConj,value); + } + if(option == "feasible_edges"){ + return SetBoolOption(FeasibleEdges,value); + } + if(option == "use_underapprox"){ + return SetBoolOption(UseUnderapprox,value); + } + if(option == "report"){ + return SetBoolOption(Report,value); + } + if(option == "stratified_inlining"){ + return SetBoolOption(StratifiedInlining,value); + } + if(option == "recursion_bound"){ + return SetIntOption(RecursionBound,value); + } + return false; + } + + /** Create an instance of a node in the unwinding. Set its + annotation to true, and mark it unexpanded. */ + Node* CreateNodeInstance(Node *node, int number = 0){ + RPFP::Node *inst = unwinding->CloneNode(node); + inst->Annotation.SetFull(); + if(number < 0) inst->number = number; + unexpanded.insert(inst); + all_of_node[node].push_back(inst); + return inst; + } + + /** Create an instance of an edge in the unwinding, with given + parent and children. */ + void CreateEdgeInstance(Edge *edge, Node *parent, const std::vector &children){ + RPFP::Edge *inst = unwinding->CreateEdge(parent,edge->F,children); + inst->map = edge; + } + + void MakeLeaf(Node *node, bool do_not_expand = false){ + node->Annotation.SetEmpty(); + Edge *e = unwinding->CreateLowerBoundEdge(node); +#ifdef TOP_DOWN + node->Annotation.SetFull(); // allow this node to cover others +#endif + if(StratifiedInlining) + node->Annotation.SetFull(); // allow this node to cover others + else + updated_nodes.insert(node); + e->map = 0; + reporter->Extend(node); +#ifdef EARLY_EXPAND + if(!do_not_expand) + TryExpandNode(node); +#endif + // e->F.SetEmpty(); + } + + void MakeOverapprox(Node *node){ + node->Annotation.SetFull(); + Edge *e = unwinding->CreateLowerBoundEdge(node); + overapproxes.insert(node); + e->map = 0; + } + + /** We start the unwinding with leaves that under-approximate + each relation with false. */ + void CreateLeaves(){ + unexpanded.clear(); + leaves.clear(); + for(unsigned i = 0; i < nodes.size(); i++){ + RPFP::Node *node = CreateNodeInstance(nodes[i]); + if(0 && nodes[i]->Outgoing->Children.size() == 0) + CreateEdgeInstance(nodes[i]->Outgoing,node,std::vector()); + else { + if(!StratifiedInlining) + MakeLeaf(node); + else { + MakeOverapprox(node); + LeafMap[nodes[i]] = node; + } + } + leaves.push_back(node); + } + } + + /** Create the map from children to edges in the input RPFP. This + is used to generate candidates for expansion. */ + void CreateEdgesByChildMap(){ + edges_by_child.clear(); + for(unsigned i = 0; i < edges.size(); i++){ + Edge *e = edges[i]; + std::set done; + for(unsigned j = 0; j < e->Children.size(); j++){ + Node *c = e->Children[j]; + if(done.find(c) == done.end()) // avoid duplicates + edges_by_child[c].push_back(e); + done.insert(c); + } + } + } + + void NullaryCandidates(){ + for(unsigned i = 0; i < edges.size(); i++){ + RPFP::Edge *edge = edges[i]; + if(edge->Children.size() == 0){ + Candidate cand; + cand.edge = edge; + candidates.push_back(cand); + } + } + } + + void InstantiateAllEdges(){ + hash_map leaf_map; + for(unsigned i = 0; i < leaves.size(); i++){ + leaf_map[leaves[i]->map] = leaves[i]; + insts_of_node[leaves[i]->map].push_back(leaves[i]); + } + unexpanded.clear(); + for(unsigned i = 0; i < edges.size(); i++){ + Edge *edge = edges[i]; + Candidate c; c.edge = edge; + c.Children.resize(edge->Children.size()); + for(unsigned j = 0; j < c.Children.size(); j++) + c.Children[j] = leaf_map[edge->Children[j]]; + Extend(c); + } + for(Unexpanded::iterator it = unexpanded.begin(), en = unexpanded.end(); it != en; ++it) + indset->Add(*it); + for(unsigned i = 0; i < leaves.size(); i++){ + std::vector &foo = insts_of_node[leaves[i]->map]; + foo.erase(foo.begin()); + } + } + + bool ProducedBySI(Edge *edge, std::vector &children){ + if(LeafMap.find(edge->Parent) == LeafMap.end()) return false; + Node *other = LeafMap[edge->Parent]; + if(other->Outgoing->map != edge) return false; + std::vector &ochs = other->Outgoing->Children; + for(unsigned i = 0; i < children.size(); i++) + if(ochs[i] != children[i]) return false; + return true; + } + + /** Add a candidate for expansion, but not if Stratified inlining has already + produced it */ + + void AddCandidate(Edge *edge, std::vector &children){ + if(StratifiedInlining && ProducedBySI(edge,children)) + return; + candidates.push_back(Candidate()); + candidates.back().edge = edge; + candidates.back().Children = children; + } + + /** Generate candidates for expansion, given a vector of candidate + sets for each argument position. This recursively produces + the cross product. + */ + void GenCandidatesRec(int pos, Edge *edge, + const std::vector > &vec, + std::vector &children){ + if(pos == (int)vec.size()){ + AddCandidate(edge,children); + } + else { + for(unsigned i = 0; i < vec[pos].size(); i++){ + children[pos] = vec[pos][i]; + GenCandidatesRec(pos+1,edge,vec,children); + } + } + } + + /** Setup for above recursion. */ + void GenCandidates(int pos, Edge *edge, + const std::vector > &vec){ + std::vector children(vec.size()); + GenCandidatesRec(0,edge,vec,children); + } + + /** Expand a node. We find all the candidates for expansion using + this node and other already expanded nodes. This is a little + tricky, since a node may be used for multiple argument + positions of an edge, and we don't want to produce duplicates. + */ + +#ifndef NEW_EXPAND + void ExpandNode(Node *node){ + std::vector &nedges = edges_by_child[node->map]; + for(unsigned i = 0; i < nedges.size(); i++){ + Edge *edge = nedges[i]; + for(unsigned npos = 0; npos < edge->Children.size(); ++npos){ + if(edge->Children[npos] == node->map){ + std::vector > vec(edge->Children.size()); + vec[npos].push_back(node); + for(unsigned j = 0; j < edge->Children.size(); j++){ + if(j != npos){ + std::vector &insts = insts_of_node[edge->Children[j]]; + for(unsigned k = 0; k < insts.size(); k++) + if(indset->Candidate(insts[k])) + vec[j].push_back(insts[k]); + } + if(j < npos && edge->Children[j] == node->map) + vec[j].push_back(node); + } + GenCandidates(0,edge,vec); + } + } + } + unexpanded.erase(node); + insts_of_node[node->map].push_back(node); + } +#else + /** If the current proposed solution is not inductive, + use the induction failure to generate candidates for extension. */ + void ExpandNode(Node *node){ + unexpanded.erase(node); + insts_of_node[node->map].push_back(node); + timer_start("GenCandIndFailUsing"); + std::vector &nedges = edges_by_child[node->map]; + for(unsigned i = 0; i < nedges.size(); i++){ + Edge *edge = nedges[i]; + slvr.push(); + RPFP *checker = new RPFP(rpfp->ls); + Node *root = CheckerJustForEdge(edge,checker,true); + if(root){ + expr using_cond = ctx.bool_val(false); + for(unsigned npos = 0; npos < edge->Children.size(); ++npos) + if(edge->Children[npos] == node->map) + using_cond = using_cond || checker->Localize(root->Outgoing->Children[npos]->Outgoing,NodeMarker(node)); + slvr.add(using_cond); + if(checker->Check(root) != unsat){ + Candidate candidate; + ExtractCandidateFromCex(edge,checker,root,candidate); + reporter->InductionFailure(edge,candidate.Children); + candidates.push_back(candidate); + } + } + slvr.pop(1); + delete checker; + } + timer_stop("GenCandIndFailUsing"); + } +#endif + + void ExpandNodeFromOther(Node *node, Node *other){ + std::vector &in = other->Incoming; + for(unsigned i = 0; i < in.size(); i++){ + Edge *edge = in[i]; + Candidate cand; + cand.edge = edge->map; + cand.Children = edge->Children; + for(unsigned j = 0; j < cand.Children.size(); j++) + if(cand.Children[j] == other) + cand.Children[j] = node; + candidates.push_front(cand); + } + // unexpanded.erase(node); + // insts_of_node[node->map].push_back(node); + } + + /** Expand a node based on some uncovered node it dominates. + This pushes cahdidates onto the *front* of the candidate + queue, so these expansions are done depth-first. */ + bool ExpandNodeFromCoverFail(Node *node){ + if(!node->Outgoing || node->Outgoing->Children.size() == 0) + return false; + Node *other = indset->GetSimilarNode(node); + if(!other) + return false; +#ifdef UNDERAPPROX_NODES + Node *under_node = CreateUnderapproxNode(node); + underapprox_map[under_node] = node; + indset->CoverByNode(node,under_node); + ExpandNodeFromOther(under_node,other); + ExpandNode(under_node); +#else + ExpandNodeFromOther(node,other); + unexpanded.erase(node); + insts_of_node[node->map].push_back(node); +#endif + return true; + } + + + /** Make a boolean variable to act as a "marker" for a node. */ + expr NodeMarker(Node *node){ + std::string name = std::string("@m_") + string_of_int(node->number); + return ctx.constant(name.c_str(),ctx.bool_sort()); + } + + /** Union the annotation of dst into src. If with_markers is + true, we conjoin the annotation formula of dst with its + marker. This allows us to discover which disjunct is + true in a satisfying assignment. */ + void UnionAnnotations(RPFP::Transformer &dst, Node *src, bool with_markers = false){ + if(!with_markers) + dst.UnionWith(src->Annotation); + else { + RPFP::Transformer t = src->Annotation; + t.Formula = t.Formula && NodeMarker(src); + dst.UnionWith(t); + } + } + + void GenNodeSolutionFromIndSet(Node *node, RPFP::Transformer &annot, bool with_markers = false){ + annot.SetEmpty(); + std::vector &insts = insts_of_node[node]; + for(unsigned j = 0; j < insts.size(); j++) + if(indset->Contains(insts[j])) + UnionAnnotations(annot,insts[j],with_markers); + annot.Simplify(); + } + + /** Generate a proposed solution of the input RPFP from + the unwinding, by unioning the instances of each node. */ + void GenSolutionFromIndSet(bool with_markers = false){ + for(unsigned i = 0; i < nodes.size(); i++){ + Node *node = nodes[i]; + GenNodeSolutionFromIndSet(node,node->Annotation,with_markers); + } + } + +#ifdef BOUNDED + bool NodePastRecursionBound(Node *node){ + if(RecursionBound < 0) return false; + NodeToCounter &backs = back_edges[node]; + for(NodeToCounter::iterator it = backs.begin(), en = backs.end(); it != en; ++it){ + if(it->second.val > RecursionBound) + return true; + } + return false; + } +#endif + + /** Test whether a given extension candidate actually represents + an induction failure. Right now we approximate this: if + the resulting node in the unwinding could be labeled false, + it clearly is not an induction failure. */ + + bool CandidateFeasible(const Candidate &cand){ + if(!FeasibleEdges) return true; + timer_start("CandidateFeasible"); + RPFP *checker = new RPFP(rpfp->ls); + // std::cout << "Checking feasibility of extension " << cand.edge->Parent->number << std::endl; + checker->Push(); + std::vector chs(cand.Children.size()); + Node *root = checker->CloneNode(cand.edge->Parent); +#ifdef BOUNDED + for(unsigned i = 0; i < cand.Children.size(); i++) + if(NodePastRecursionBound(cand.Children[i])){ + timer_stop("CandidateFeasible"); + return false; + } +#endif +#ifdef NEW_CAND_SEL + GenNodeSolutionFromIndSet(cand.edge->Parent,root->Bound); +#else + root->Bound.SetEmpty(); +#endif + checker->AssertNode(root); + for(unsigned i = 0; i < cand.Children.size(); i++) + chs[i] = checker->CloneNode(cand.Children[i]); + Edge *e = checker->CreateEdge(root,cand.edge->F,chs); + checker->AssertEdge(e,0,true); + // std::cout << "Checking SAT: " << e->dual << std::endl; + bool res = checker->Check(root) != unsat; + // std::cout << "Result: " << res << std::endl; + if(!res)reporter->Reject(cand.edge,cand.Children); + checker->Pop(1); + delete checker; + timer_stop("CandidateFeasible"); + return res; + } + + + /* For stratified inlining, we need a topological sort of the + nodes. */ + + hash_map TopoSort; + int TopoSortCounter; + + void DoTopoSortRec(Node *node){ + if(TopoSort.find(node) != TopoSort.end()) + return; + TopoSort[node] = TopoSortCounter++; // just to break cycles + Edge *edge = node->Outgoing; // note, this is just *one* outgoing edge + if(edge){ + std::vector &chs = edge->Children; + for(unsigned i = 0; i < chs.size(); i++) + DoTopoSortRec(chs[i]); + } + TopoSort[node] = TopoSortCounter++; + } + + void DoTopoSort(){ + TopoSort.clear(); + TopoSortCounter = 0; + for(unsigned i = 0; i < nodes.size(); i++) + DoTopoSortRec(nodes[i]); + } + + /** In stratified inlining, we build the unwinding from the bottom + down, trying to satisfy the node bounds. We do this as a pre-pass, + limiting the expansion. If we get a counterexample, we are done, + else we continue as usual expanding the unwinding upward. + */ + + int StratifiedLeafCount; + + bool DoStratifiedInlining(){ + timer_start("StratifiedInlining"); + DoTopoSort(); + for(unsigned i = 0; i < leaves.size(); i++){ + Node *node = leaves[i]; + bool res = SatisfyUpperBound(node); + if(!res){ + timer_stop("StratifiedInlining"); + return false; + } + } + // don't leave any dangling nodes! +#ifndef EFFORT_BOUNDED_STRAT + for(unsigned i = 0; i < leaves.size(); i++) + if(!leaves[i]->Outgoing) + MakeLeaf(leaves[i],true); +#endif + timer_stop("StratifiedInlining"); + return true; + } + + /** Here, we do the downward expansion for stratified inlining */ + + hash_map LeafMap, StratifiedLeafMap; + + Edge *GetNodeOutgoing(Node *node, int last_decs = 0){ + if(overapproxes.find(node) == overapproxes.end()) return node->Outgoing; /* already expanded */ + overapproxes.erase(node); +#ifdef EFFORT_BOUNDED_STRAT + if(last_decs > 5000){ + // RPFP::Transformer save = node->Annotation; + node->Annotation.SetEmpty(); + Edge *e = unwinding->CreateLowerBoundEdge(node); + // node->Annotation = save; + insts_of_node[node->map].push_back(node); + // std::cout << "made leaf: " << node->number << std::endl; + return e; + } +#endif + Edge *edge = node->map->Outgoing; + std::vector &chs = edge->Children; + + // make sure we don't create a covered node in this process! + + for(unsigned i = 0; i < chs.size(); i++){ + Node *child = chs[i]; + if(TopoSort[child] < TopoSort[node->map]){ + Node *leaf = LeafMap[child]; + if(!indset->Contains(leaf)) + return node->Outgoing; + } + } + + std::vector nchs(chs.size()); + for(unsigned i = 0; i < chs.size(); i++){ + Node *child = chs[i]; + if(TopoSort[child] < TopoSort[node->map]){ + Node *leaf = LeafMap[child]; + nchs[i] = leaf; + if(unexpanded.find(leaf) != unexpanded.end()){ + unexpanded.erase(leaf); + insts_of_node[child].push_back(leaf); + } + } + else { + if(StratifiedLeafMap.find(child) == StratifiedLeafMap.end()){ + RPFP::Node *nchild = CreateNodeInstance(child,StratifiedLeafCount--); + MakeLeaf(nchild); + nchild->Annotation.SetEmpty(); + StratifiedLeafMap[child] = nchild; + indset->SetDominated(nchild); + } + nchs[i] = StratifiedLeafMap[child]; + } + } + CreateEdgeInstance(edge,node,nchs); + reporter->Extend(node); + return node->Outgoing; + } + + void SetHeuristicOldNode(Node *node){ + LocalHeuristic *h = dynamic_cast(heuristic); + if(h) + h->SetOldNode(node); + } + + /** This does the actual solving work. We try to generate + candidates for extension. If we succed, we extend the + unwinding. If we fail, we have a solution. */ + bool SolveMain(){ + if(StratifiedInlining && !DoStratifiedInlining()) + return false; +#ifdef BOUNDED + DoTopoSort(); +#endif + while(true){ + timer_start("ProduceCandidatesForExtension"); + ProduceCandidatesForExtension(); + timer_stop("ProduceCandidatesForExtension"); + if(candidates.empty()){ + GenSolutionFromIndSet(); + return true; + } + Candidate cand = candidates.front(); + candidates.pop_front(); + if(CandidateFeasible(cand)) + if(!Extend(cand)) + return false; + } + } + + // hack: put something local into the underapproximation formula + // without this, interpolants can be pretty bad + void AddThing(expr &conj){ + std::string name = "@thing"; + expr thing = ctx.constant(name.c_str(),ctx.bool_sort()); + if(conj.is_app() && conj.decl().get_decl_kind() == And){ + std::vector conjs(conj.num_args()+1); + for(unsigned i = 0; i+1 < conjs.size(); i++) + conjs[i] = conj.arg(i); + conjs[conjs.size()-1] = thing; + conj = rpfp->conjoin(conjs); + } + } + + + Node *CreateUnderapproxNode(Node *node){ + // cex.tree->ComputeUnderapprox(cex.root,0); + RPFP::Node *under_node = CreateNodeInstance(node->map /* ,StratifiedLeafCount-- */); + under_node->Annotation.IntersectWith(cex.root->Underapprox); + AddThing(under_node->Annotation.Formula); + Edge *e = unwinding->CreateLowerBoundEdge(under_node); + under_node->Annotation.SetFull(); // allow this node to cover others + back_edges[under_node] = back_edges[node]; + e->map = 0; + reporter->Extend(under_node); + return under_node; + } + + /** Try to prove a conjecture about a node. If successful + update the unwinding annotation appropriately. */ + bool ProveConjecture(Node *node, const RPFP::Transformer &t,Node *other = 0, Counterexample *_cex = 0){ + reporter->Conjecture(node,t); + timer_start("ProveConjecture"); + RPFP::Transformer save = node->Bound; + node->Bound.IntersectWith(t); + +#ifndef LOCALIZE_CONJECTURES + bool ok = SatisfyUpperBound(node); +#else + SetHeuristicOldNode(other); + bool ok = SatisfyUpperBound(node); + SetHeuristicOldNode(0); +#endif + + if(ok){ + timer_stop("ProveConjecture"); + return true; + } +#ifdef UNDERAPPROX_NODES + if(UseUnderapprox && last_decisions > 500){ + std::cout << "making an underapprox\n"; + ExpandNodeFromCoverFail(node); + } +#endif + if(_cex) *_cex = cex; + else delete cex.tree; // delete the cex if not required + cex.tree = 0; + node->Bound = save; // put back original bound + timer_stop("ProveConjecture"); + return false; + } + + /** If a node is part of the inductive subset, expand it. + We ask the inductive subset to exclude the node if possible. + */ + void TryExpandNode(RPFP::Node *node){ + if(indset->Close(node)) return; + if(!NoConj && indset->Conjecture(node)){ +#ifdef UNDERAPPROX_NODES + /* TODO: temporary fix. this prevents an infinite loop in case + the node is covered by multiple others. This should be + removed when covering by a set is implemented. + */ + if(indset->Contains(node)){ + unexpanded.erase(node); + insts_of_node[node->map].push_back(node); + } +#endif + return; + } +#ifdef UNDERAPPROX_NODES + if(!indset->Contains(node)) + return; // could be covered by an underapprox node +#endif + indset->Add(node); +#if defined(CANDS_FROM_COVER_FAIL) && !defined(UNDERAPPROX_NODES) + if(ExpandNodeFromCoverFail(node)) + return; +#endif + ExpandNode(node); + } + + /** Make the conjunction of markers for all (expanded) instances of + a node in the input RPFP. */ + expr AllNodeMarkers(Node *node){ + expr res = ctx.bool_val(true); + std::vector &insts = insts_of_node[node]; + for(int k = insts.size()-1; k >= 0; k--) + res = res && NodeMarker(insts[k]); + return res; + } + + void RuleOutNodesPastBound(Node *node, RPFP::Transformer &t){ +#ifdef BOUNDED + if(RecursionBound < 0)return; + std::vector &insts = insts_of_node[node]; + for(unsigned i = 0; i < insts.size(); i++) + if(NodePastRecursionBound(insts[i])) + t.Formula = t.Formula && !NodeMarker(insts[i]); +#endif + } + + + void GenNodeSolutionWithMarkersAux(Node *node, RPFP::Transformer &annot, expr &marker_disjunction){ +#ifdef BOUNDED + if(RecursionBound >= 0 && NodePastRecursionBound(node)) + return; +#endif + RPFP::Transformer temp = node->Annotation; + expr marker = NodeMarker(node); + temp.Formula = (!marker || temp.Formula); + annot.IntersectWith(temp); + marker_disjunction = marker_disjunction || marker; + } + + bool GenNodeSolutionWithMarkers(Node *node, RPFP::Transformer &annot, bool expanded_only = false){ + bool res = false; + annot.SetFull(); + expr marker_disjunction = ctx.bool_val(false); + std::vector &insts = expanded_only ? insts_of_node[node] : all_of_node[node]; + for(unsigned j = 0; j < insts.size(); j++){ + Node *node = insts[j]; + if(indset->Contains(insts[j])){ + GenNodeSolutionWithMarkersAux(node, annot, marker_disjunction); res = true; + } + } + annot.Formula = annot.Formula && marker_disjunction; + annot.Simplify(); + return res; + } + + /** Make a checker to determine if an edge in the input RPFP + is satisfied. */ + Node *CheckerJustForEdge(Edge *edge, RPFP *checker, bool expanded_only = false){ + Node *root = checker->CloneNode(edge->Parent); + GenNodeSolutionFromIndSet(edge->Parent, root->Bound); + if(root->Bound.IsFull()) + return 0; + checker->AssertNode(root); + std::vector cs; + for(unsigned j = 0; j < edge->Children.size(); j++){ + Node *oc = edge->Children[j]; + Node *nc = checker->CloneNode(oc); + if(!GenNodeSolutionWithMarkers(oc,nc->Annotation,expanded_only)) + return 0; + Edge *e = checker->CreateLowerBoundEdge(nc); + checker->AssertEdge(e); + cs.push_back(nc); + } + checker->AssertEdge(checker->CreateEdge(root,edge->F,cs)); + return root; + } + +#ifndef MINIMIZE_CANDIDATES_HARDER + +#if 0 + /** Make a checker to detheermine if an edge in the input RPFP + is satisfied. */ + Node *CheckerForEdge(Edge *edge, RPFP *checker){ + Node *root = checker->CloneNode(edge->Parent); + root->Bound = edge->Parent->Annotation; + root->Bound.Formula = (!AllNodeMarkers(edge->Parent)) || root->Bound.Formula; + checker->AssertNode(root); + std::vector cs; + for(unsigned j = 0; j < edge->Children.size(); j++){ + Node *oc = edge->Children[j]; + Node *nc = checker->CloneNode(oc); + nc->Annotation = oc->Annotation; + RuleOutNodesPastBound(oc,nc->Annotation); + Edge *e = checker->CreateLowerBoundEdge(nc); + checker->AssertEdge(e); + cs.push_back(nc); + } + checker->AssertEdge(checker->CreateEdge(root,edge->F,cs)); + return root; + } + +#else + /** Make a checker to determine if an edge in the input RPFP + is satisfied. */ + Node *CheckerForEdge(Edge *edge, RPFP *checker){ + Node *root = checker->CloneNode(edge->Parent); + GenNodeSolutionFromIndSet(edge->Parent, root->Bound); +#if 0 + if(root->Bound.IsFull()) + return = 0; +#endif + checker->AssertNode(root); + std::vector cs; + for(unsigned j = 0; j < edge->Children.size(); j++){ + Node *oc = edge->Children[j]; + Node *nc = checker->CloneNode(oc); + GenNodeSolutionWithMarkers(oc,nc->Annotation,true); + Edge *e = checker->CreateLowerBoundEdge(nc); + checker->AssertEdge(e); + cs.push_back(nc); + } + checker->AssertEdge(checker->CreateEdge(root,edge->F,cs)); + return root; + } +#endif + + /** If an edge is not satisfied, produce an extension candidate + using instances of its children that violate the parent annotation. + We find these using the marker predicates. */ + void ExtractCandidateFromCex(Edge *edge, RPFP *checker, Node *root, Candidate &candidate){ + candidate.edge = edge; + for(unsigned j = 0; j < edge->Children.size(); j++){ + Edge *lb = root->Outgoing->Children[j]->Outgoing; + std::vector &insts = insts_of_node[edge->Children[j]]; +#ifndef MINIMIZE_CANDIDATES + for(int k = insts.size()-1; k >= 0; k--) +#else + for(unsigned k = 0; k < insts.size(); k++) +#endif + { + Node *inst = insts[k]; + if(indset->Contains(inst)){ + if(checker->Empty(lb->Parent) || + eq(checker->Eval(lb,NodeMarker(inst)),ctx.bool_val(true))){ + candidate.Children.push_back(inst); + goto next_child; + } + } + } + throw InternalError("No candidate from induction failure"); + next_child:; + } + } +#else + + + /** Make a checker to determine if an edge in the input RPFP + is satisfied. */ + Node *CheckerForEdge(Edge *edge, RPFP *checker){ + Node *root = checker->CloneNode(edge->Parent); + GenNodeSolutionFromIndSet(edge->Parent, root->Bound); + if(root->Bound.IsFull()) + return = 0; + checker->AssertNode(root); + std::vector cs; + for(unsigned j = 0; j < edge->Children.size(); j++){ + Node *oc = edge->Children[j]; + Node *nc = checker->CloneNode(oc); + GenNodeSolutionWithMarkers(oc,nc->Annotation,true); + Edge *e = checker->CreateLowerBoundEdge(nc); + checker->AssertEdge(e); + cs.push_back(nc); + } + checker->AssertEdge(checker->CreateEdge(root,edge->F,cs)); + return root; + } + + /** If an edge is not satisfied, produce an extension candidate + using instances of its children that violate the parent annotation. + We find these using the marker predicates. */ + void ExtractCandidateFromCex(Edge *edge, RPFP *checker, Node *root, Candidate &candidate){ + candidate.edge = edge; + std::vector assumps; + for(unsigned j = 0; j < edge->Children.size(); j++){ + Edge *lb = root->Outgoing->Children[j]->Outgoing; + std::vector &insts = insts_of_node[edge->Children[j]]; + for(unsigned k = 0; k < insts.size(); k++) + { + Node *inst = insts[k]; + expr marker = NodeMarker(inst); + if(indset->Contains(inst)){ + if(checker->Empty(lb->Parent) || + eq(checker->Eval(lb,marker),ctx.bool_val(true))){ + candidate.Children.push_back(inst); + assumps.push_back(checker->Localize(lb,marker)); + goto next_child; + } + assumps.push_back(checker->Localize(lb,marker)); + if(checker->CheckUpdateModel(root,assumps) != unsat){ + candidate.Children.push_back(inst); + goto next_child; + } + assumps.pop_back(); + } + } + throw InternalError("No candidate from induction failure"); + next_child:; + } + } + +#endif + + + /** If the current proposed solution is not inductive, + use the induction failure to generate candidates for extension. */ + void GenCandidatesFromInductionFailure(bool full_scan = false){ + timer_start("GenCandIndFail"); + GenSolutionFromIndSet(true /* add markers */); + for(unsigned i = 0; i < edges.size(); i++){ + Edge *edge = edges[i]; + if(!full_scan && updated_nodes.find(edge->Parent) == updated_nodes.end()) + continue; + slvr.push(); + RPFP *checker = new RPFP(rpfp->ls); + Node *root = CheckerForEdge(edge,checker); + if(checker->Check(root) != unsat){ + Candidate candidate; + ExtractCandidateFromCex(edge,checker,root,candidate); + reporter->InductionFailure(edge,candidate.Children); + candidates.push_back(candidate); + } + slvr.pop(1); + delete checker; + } + updated_nodes.clear(); + timer_stop("GenCandIndFail"); +#ifdef CHECK_CANDS_FROM_IND_SET + for(std::list::iterator it = candidates.begin(), en = candidates.end(); it != en; ++it){ + if(!CandidateFeasible(*it)) + throw "produced infeasible candidate"; + } +#endif + if(!full_scan && candidates.empty()){ + reporter->Message("No candidates from updates. Trying full scan."); + GenCandidatesFromInductionFailure(true); + } + } + +#ifdef CANDS_FROM_UPDATES + /** If the given edge is not inductive in the current proposed solution, + use the induction failure to generate candidates for extension. */ + void GenCandidatesFromEdgeInductionFailure(RPFP::Edge *edge){ + GenSolutionFromIndSet(true /* add markers */); + for(unsigned i = 0; i < edges.size(); i++){ + slvr.push(); + Edge *edge = edges[i]; + RPFP *checker = new RPFP(rpfp->ls); + Node *root = CheckerForEdge(edge,checker); + if(checker->Check(root) != unsat){ + Candidate candidate; + ExtractCandidateFromCex(edge,checker,root,candidate); + reporter->InductionFailure(edge,candidate.Children); + candidates.push_back(candidate); + } + slvr.pop(1); + delete checker; + } + } +#endif + + /** Find the unexpanded nodes in the inductive subset. */ + void FindNodesToExpand(){ + for(Unexpanded::iterator it = unexpanded.begin(), en = unexpanded.end(); it != en; ++it){ + Node *node = *it; + if(indset->Candidate(node)) + to_expand.push_back(node); + } + } + + /** Try to create some extension candidates from the unexpanded + nodes. */ + void ProduceSomeCandidates(){ + while(candidates.empty() && !to_expand.empty()){ + Node *node = to_expand.front(); + to_expand.pop_front(); + TryExpandNode(node); + } + } + + std::list postponed_candidates; + + /** Try to produce some extension candidates, first from unexpanded + nides, and if this fails, from induction failure. */ + void ProduceCandidatesForExtension(){ + if(candidates.empty()) + ProduceSomeCandidates(); + while(candidates.empty()){ + FindNodesToExpand(); + if(to_expand.empty()) break; + ProduceSomeCandidates(); + } + if(candidates.empty()){ +#ifdef DEPTH_FIRST_EXPAND + if(postponed_candidates.empty()){ + GenCandidatesFromInductionFailure(); + postponed_candidates.swap(candidates); + } + if(!postponed_candidates.empty()){ + candidates.push_back(postponed_candidates.front()); + postponed_candidates.pop_front(); + } +#else + GenCandidatesFromInductionFailure(); +#endif + } + } + + bool UpdateNodeToNode(Node *node, Node *top){ + if(!node->Annotation.SubsetEq(top->Annotation)){ + reporter->Update(node,top->Annotation); + indset->Update(node,top->Annotation); + updated_nodes.insert(node->map); + node->Annotation.IntersectWith(top->Annotation); + return true; + } + return false; + } + + /** Update the unwinding solution, using an interpolant for the + derivation tree. */ + void UpdateWithInterpolant(Node *node, RPFP *tree, Node *top){ + if(top->Outgoing) + for(unsigned i = 0; i < top->Outgoing->Children.size(); i++) + UpdateWithInterpolant(node->Outgoing->Children[i],tree,top->Outgoing->Children[i]); + UpdateNodeToNode(node, top); + heuristic->Update(node); + } + + /** Update unwinding lower bounds, using a counterexample. */ + + void UpdateWithCounterexample(Node *node, RPFP *tree, Node *top){ + if(top->Outgoing) + for(unsigned i = 0; i < top->Outgoing->Children.size(); i++) + UpdateWithCounterexample(node->Outgoing->Children[i],tree,top->Outgoing->Children[i]); + if(!top->Underapprox.SubsetEq(node->Underapprox)){ + reporter->UpdateUnderapprox(node,top->Underapprox); + // indset->Update(node,top->Annotation); + node->Underapprox.UnionWith(top->Underapprox); + heuristic->Update(node); + } + } + + /** Try to update the unwinding to satisfy the upper bound of a + node. */ + bool SatisfyUpperBound(Node *node){ + if(node->Bound.IsFull()) return true; + reporter->Bound(node); + int start_decs = rpfp->CumulativeDecisions(); + DerivationTree *dtp = new DerivationTreeSlow(this,unwinding,reporter,heuristic,FullExpand); + DerivationTree &dt = *dtp; + bool res = dt.Derive(unwinding,node,UseUnderapprox); + int end_decs = rpfp->CumulativeDecisions(); + // std::cout << "decisions: " << (end_decs - start_decs) << std::endl; + last_decisions = end_decs - start_decs; + if(res){ + cex.tree = dt.tree; + cex.root = dt.top; + if(UseUnderapprox){ + UpdateWithCounterexample(node,dt.tree,dt.top); + } + } + else { + UpdateWithInterpolant(node,dt.tree,dt.top); + delete dt.tree; + } + delete dtp; + return !res; + } + + /* If the counterexample derivation is partial due to + use of underapproximations, complete it. */ + + void BuildFullCex(Node *node){ + DerivationTree dt(this,unwinding,reporter,heuristic,FullExpand); + bool res = dt.Derive(unwinding,node,UseUnderapprox,true); // build full tree + if(!res) throw "Duality internal error in BuildFullCex"; + if(cex.tree) + delete cex.tree; + cex.tree = dt.tree; + cex.root = dt.top; + } + + void UpdateBackEdges(Node *node){ +#ifdef BOUNDED + std::vector &chs = node->Outgoing->Children; + for(unsigned i = 0; i < chs.size(); i++){ + Node *child = chs[i]; + bool is_back = TopoSort[child->map] >= TopoSort[node->map]; + NodeToCounter &nov = back_edges[node]; + NodeToCounter chv = back_edges[child]; + if(is_back) + chv[child->map].val++; + for(NodeToCounter::iterator it = chv.begin(), en = chv.end(); it != en; ++it){ + Node *back = it->first; + Counter &c = nov[back]; + c.val = std::max(c.val,it->second.val); + } + } +#endif + } + + /** Extend the unwinding, keeping it solved. */ + bool Extend(Candidate &cand){ + timer_start("Extend"); + Node *node = CreateNodeInstance(cand.edge->Parent); + CreateEdgeInstance(cand.edge,node,cand.Children); + UpdateBackEdges(node); + reporter->Extend(node); + bool res = SatisfyUpperBound(node); + if(res) indset->CloseDescendants(node); + else { +#ifdef UNDERAPPROX_NODES + ExpandUnderapproxNodes(cex.tree, cex.root); +#endif + if(UseUnderapprox) BuildFullCex(node); + timer_stop("Extend"); + return res; + } +#ifdef EARLY_EXPAND + TryExpandNode(node); +#endif + timer_stop("Extend"); + return res; + } + + void ExpandUnderapproxNodes(RPFP *tree, Node *root){ + Node *node = root->map; + if(underapprox_map.find(node) != underapprox_map.end()){ + RPFP::Transformer cnst = root->Annotation; + tree->EvalNodeAsConstraint(root, cnst); + cnst.Complement(); + Node *orig = underapprox_map[node]; + RPFP::Transformer save = orig->Bound; + orig->Bound = cnst; + DerivationTree dt(this,unwinding,reporter,heuristic,FullExpand); + bool res = dt.Derive(unwinding,orig,UseUnderapprox,true,tree); + if(!res){ + UpdateWithInterpolant(orig,dt.tree,dt.top); + throw "bogus underapprox!"; + } + ExpandUnderapproxNodes(tree,dt.top); + } + else if(root->Outgoing){ + std::vector &chs = root->Outgoing->Children; + for(unsigned i = 0; i < chs.size(); i++) + ExpandUnderapproxNodes(tree,chs[i]); + } + } + + + /** This class represents a derivation tree. */ + class DerivationTree { + public: + + DerivationTree(Duality *_duality, RPFP *rpfp, Reporter *_reporter, Heuristic *_heuristic, bool _full_expand) + : slvr(rpfp->slvr), + ctx(rpfp->ctx) + { + duality = _duality; + reporter = _reporter; + heuristic = _heuristic; + full_expand = _full_expand; + } + + Duality *duality; + Reporter *reporter; + Heuristic *heuristic; + solver &slvr; + context &ctx; + RPFP *tree; + RPFP::Node *top; + std::list leaves; + bool full_expand; + bool underapprox; + bool constrained; + bool false_approx; + std::vector underapprox_core; + int start_decs, last_decs; + + /* We build derivation trees in one of three modes: + + 1) In normal mode, we build the full tree without considering + underapproximations. + + 2) In underapprox mode, we use underapproximations to cut off + the tree construction. THis means the resulting tree may not + be complete. + + 3) In constrained mode, we build the full tree but use + underapproximations as upper bounds. This mode is used to + complete the partial derivation constructed in underapprox + mode. + */ + + bool Derive(RPFP *rpfp, RPFP::Node *root, bool _underapprox, bool _constrained = false, RPFP *_tree = 0){ + underapprox = _underapprox; + constrained = _constrained; + false_approx = true; + timer_start("Derive"); + tree = _tree ? _tree : new RPFP(rpfp->ls); + tree->HornClauses = rpfp->HornClauses; + tree->Push(); // so we can clear out the solver later when finished + top = CreateApproximatedInstance(root); + tree->AssertNode(top); // assert the negation of the top-level spec + timer_start("Build"); + bool res = Build(); + heuristic->Done(); + timer_stop("Build"); + timer_start("Pop"); + tree->Pop(1); + timer_stop("Pop"); + timer_stop("Derive"); + return res; + } + +#define WITH_CHILDREN + + Node *CreateApproximatedInstance(RPFP::Node *from){ + Node *to = tree->CloneNode(from); + to->Annotation = from->Annotation; +#ifndef WITH_CHILDREN + tree->CreateLowerBoundEdge(to); +#endif + leaves.push_back(to); + return to; + } + + bool CheckWithUnderapprox(){ + timer_start("CheckWithUnderapprox"); + std::vector leaves_vector(leaves.size()); + std::copy(leaves.begin(),leaves.end(),leaves_vector.begin()); + check_result res = tree->Check(top,leaves_vector); + timer_stop("CheckWithUnderapprox"); + return res != unsat; + } + + virtual bool Build(){ +#ifdef EFFORT_BOUNDED_STRAT + start_decs = tree->CumulativeDecisions(); +#endif + while(ExpandSomeNodes(true)); // do high-priority expansions + while (true) + { +#ifndef WITH_CHILDREN + timer_start("asserting leaves"); + timer_start("pushing"); + tree->Push(); + timer_stop("pushing"); + for(std::list::iterator it = leaves.begin(), en = leaves.end(); it != en; ++it) + tree->AssertEdge((*it)->Outgoing,1); // assert the overapproximation, and keep it past pop + timer_stop("asserting leaves"); + lbool res = tree->Solve(top, 2); // incremental solve, keep interpolants for two pops + timer_start("popping leaves"); + tree->Pop(1); + timer_stop("popping leaves"); +#else + lbool res; + if((underapprox || false_approx) && top->Outgoing && CheckWithUnderapprox()){ + if(constrained) goto expand_some_nodes; // in constrained mode, keep expanding + goto we_are_sat; // else if underapprox is sat, we stop + } + // tree->Check(top); + res = tree->Solve(top, 1); // incremental solve, keep interpolants for one pop +#endif + if (res == l_false) + return false; + + expand_some_nodes: + if(ExpandSomeNodes()) + continue; + + we_are_sat: + if(underapprox && !constrained){ + timer_start("ComputeUnderapprox"); + tree->ComputeUnderapprox(top,1); + timer_stop("ComputeUnderapprox"); + } + else { +#ifdef UNDERAPPROX_NODES +#ifndef SKIP_UNDERAPPROX_NODES + timer_start("ComputeUnderapprox"); + tree->ComputeUnderapprox(top,1); + timer_stop("ComputeUnderapprox"); +#endif +#endif + } + return true; + } + } + + virtual void ExpandNode(RPFP::Node *p){ + // tree->RemoveEdge(p->Outgoing); + Edge *edge = duality->GetNodeOutgoing(p->map,last_decs); + std::vector &cs = edge->Children; + std::vector children(cs.size()); + for(unsigned i = 0; i < cs.size(); i++) + children[i] = CreateApproximatedInstance(cs[i]); + Edge *ne = tree->CreateEdge(p, p->map->Outgoing->F, children); + ne->map = p->map->Outgoing->map; +#ifndef WITH_CHILDREN + tree->AssertEdge(ne); // assert the edge in the solver +#else + tree->AssertEdge(ne,0,!full_expand,(underapprox || false_approx)); // assert the edge in the solver +#endif + reporter->Expand(ne); + } + +#define UNDERAPPROXCORE +#ifndef UNDERAPPROXCORE + void ExpansionChoices(std::set &best){ + std::set choices; + for(std::list::iterator it = leaves.begin(), en = leaves.end(); it != en; ++it) + if (!tree->Empty(*it)) // if used in the counter-model + choices.insert(*it); + heuristic->ChooseExpand(choices, best); + } +#else +#if 0 + + void ExpansionChoices(std::set &best){ + std::vector unused_set, used_set; + std::set choices; + for(std::list::iterator it = leaves.begin(), en = leaves.end(); it != en; ++it){ + Node *n = *it; + if (!tree->Empty(n)) + used_set.push_back(n); + else + unused_set.push_back(n); + } + if(tree->Check(top,unused_set) == unsat) + throw "error in ExpansionChoices"; + for(unsigned i = 0; i < used_set.size(); i++){ + Node *n = used_set[i]; + unused_set.push_back(n); + if(!top->Outgoing || tree->Check(top,unused_set) == unsat){ + unused_set.pop_back(); + choices.insert(n); + } + else + std::cout << "Using underapprox of " << n->number << std::endl; + } + heuristic->ChooseExpand(choices, best); + } +#else + void ExpansionChoicesFull(std::set &best, bool high_priority, bool best_only = false){ + std::set choices; + for(std::list::iterator it = leaves.begin(), en = leaves.end(); it != en; ++it) + if (high_priority || !tree->Empty(*it)) // if used in the counter-model + choices.insert(*it); + heuristic->ChooseExpand(choices, best, high_priority, best_only); + } + + void ExpansionChoicesRec(std::vector &unused_set, std::vector &used_set, + std::set &choices, int from, int to){ + if(from == to) return; + int orig_unused = unused_set.size(); + unused_set.resize(orig_unused + (to - from)); + std::copy(used_set.begin()+from,used_set.begin()+to,unused_set.begin()+orig_unused); + if(!top->Outgoing || tree->Check(top,unused_set) == unsat){ + unused_set.resize(orig_unused); + if(to - from == 1){ +#if 1 + std::cout << "Not using underapprox of " << used_set[from] ->number << std::endl; +#endif + choices.insert(used_set[from]); + } + else { + int mid = from + (to - from)/2; + ExpansionChoicesRec(unused_set, used_set, choices, from, mid); + ExpansionChoicesRec(unused_set, used_set, choices, mid, to); + } + } + else { +#if 1 + std::cout << "Using underapprox of "; + for(int i = from; i < to; i++){ + std::cout << used_set[i]->number << " "; + if(used_set[i]->map->Underapprox.IsEmpty()) + std::cout << "(false!) "; + } + std::cout << std::endl; +#endif + } + } + + std::set old_choices; + + void ExpansionChoices(std::set &best, bool high_priority, bool best_only = false){ + if(!underapprox || constrained || high_priority){ + ExpansionChoicesFull(best, high_priority,best_only); + return; + } + std::vector unused_set, used_set; + std::set choices; + for(std::list::iterator it = leaves.begin(), en = leaves.end(); it != en; ++it){ + Node *n = *it; + if (!tree->Empty(n)){ + if(old_choices.find(n) != old_choices.end() || n->map->Underapprox.IsEmpty()) + choices.insert(n); + else + used_set.push_back(n); + } + else + unused_set.push_back(n); + } + if(tree->Check(top,unused_set) == unsat) + throw "error in ExpansionChoices"; + ExpansionChoicesRec(unused_set, used_set, choices, 0, used_set.size()); + old_choices = choices; + heuristic->ChooseExpand(choices, best, high_priority); + } +#endif +#endif + + bool ExpandSomeNodes(bool high_priority = false, int max = INT_MAX){ +#ifdef EFFORT_BOUNDED_STRAT + last_decs = tree->CumulativeDecisions() - start_decs; +#endif + timer_start("ExpandSomeNodes"); + timer_start("ExpansionChoices"); + std::set choices; + ExpansionChoices(choices,high_priority,max != INT_MAX); + timer_stop("ExpansionChoices"); + std::list leaves_copy = leaves; // copy so can modify orig + leaves.clear(); + int count = 0; + for(std::list::iterator it = leaves_copy.begin(), en = leaves_copy.end(); it != en; ++it){ + if(choices.find(*it) != choices.end() && count < max){ + count++; + ExpandNode(*it); + } + else leaves.push_back(*it); + } + timer_stop("ExpandSomeNodes"); + return !choices.empty(); + } + + void RemoveExpansion(RPFP::Node *p){ + Edge *edge = p->Outgoing; + Node *parent = edge->Parent; + std::vector cs = edge->Children; + tree->DeleteEdge(edge); + for(unsigned i = 0; i < cs.size(); i++) + tree->DeleteNode(cs[i]); + leaves.push_back(parent); + } + }; + + class DerivationTreeSlow : public DerivationTree { + public: + + struct stack_entry { + unsigned level; // SMT solver stack level + std::vector expansions; + }; + + std::vector stack; + + hash_map updates; + + DerivationTreeSlow(Duality *_duality, RPFP *rpfp, Reporter *_reporter, Heuristic *_heuristic, bool _full_expand) + : DerivationTree(_duality, rpfp, _reporter, _heuristic, _full_expand) { + stack.push_back(stack_entry()); + } + + virtual bool Build(){ + + stack.back().level = tree->slvr.get_scope_level(); + + while (true) + { + lbool res; + + unsigned slvr_level = tree->slvr.get_scope_level(); + if(slvr_level != stack.back().level) + throw "stacks out of sync!"; + + // res = tree->Solve(top, 1); // incremental solve, keep interpolants for one pop + check_result foo = tree->Check(top); + res = foo == unsat ? l_false : l_true; + + if (res == l_false) { + if (stack.empty()) // should never happen + return false; + + { + std::vector &expansions = stack.back().expansions; + int update_count = 0; + for(unsigned i = 0; i < expansions.size(); i++){ + Node *node = expansions[i]; + tree->SolveSingleNode(top,node); + if(expansions.size() == 1 && NodeTooComplicated(node)) + SimplifyNode(node); + tree->Generalize(top,node); + if(RecordUpdate(node)) + update_count++; + } + if(update_count == 0) + reporter->Message("backtracked without learning"); + } + tree->ComputeProofCore(); // need to compute the proof core before popping solver + while(1) { + std::vector &expansions = stack.back().expansions; + bool prev_level_used = LevelUsedInProof(stack.size()-2); // need to compute this before pop + tree->Pop(1); + hash_set leaves_to_remove; + for(unsigned i = 0; i < expansions.size(); i++){ + Node *node = expansions[i]; + // if(node != top) + // tree->ConstrainParent(node->Incoming[0],node); + std::vector &cs = node->Outgoing->Children; + for(unsigned i = 0; i < cs.size(); i++){ + leaves_to_remove.insert(cs[i]); + UnmapNode(cs[i]); + if(std::find(updated_nodes.begin(),updated_nodes.end(),cs[i]) != updated_nodes.end()) + throw "help!"; + } + } + RemoveLeaves(leaves_to_remove); // have to do this before actually deleting the children + for(unsigned i = 0; i < expansions.size(); i++){ + Node *node = expansions[i]; + RemoveExpansion(node); + } + stack.pop_back(); + if(prev_level_used || stack.size() == 1) break; + RemoveUpdateNodesAtCurrentLevel(); // this level is about to be deleted -- remove its children from update list + std::vector &unused_ex = stack.back().expansions; + for(unsigned i = 0; i < unused_ex.size(); i++) + heuristic->Update(unused_ex[i]->map); // make it less likely to expand this node in future + } + HandleUpdatedNodes(); + if(stack.size() == 1) + return false; + } + else { + tree->Push(); + std::vector &expansions = stack.back().expansions; + for(unsigned i = 0; i < expansions.size(); i++){ + tree->FixCurrentState(expansions[i]->Outgoing); + } +#if 0 + if(tree->slvr.check() == unsat) + throw "help!"; +#endif + stack.push_back(stack_entry()); + stack.back().level = tree->slvr.get_scope_level(); + if(ExpandSomeNodes(false,1)){ + continue; + } + while(stack.size() > 1){ + tree->Pop(1); + stack.pop_back(); + } + return true; + } + } + } + + bool NodeTooComplicated(Node *node){ + return tree->CountOperators(node->Annotation.Formula) > 3; + } + + void SimplifyNode(Node *node){ + // have to destroy the old proof to get a new interpolant + tree->PopPush(); + tree->InterpolateByCases(top,node); + } + + bool LevelUsedInProof(unsigned level){ + std::vector &expansions = stack[level].expansions; + for(unsigned i = 0; i < expansions.size(); i++) + if(tree->EdgeUsedInProof(expansions[i]->Outgoing)) + return true; + return false; + } + + void RemoveUpdateNodesAtCurrentLevel() { + for(std::list::iterator it = updated_nodes.begin(), en = updated_nodes.end(); it != en;){ + Node *node = *it; + if(AtCurrentStackLevel(node->Incoming[0]->Parent)){ + std::list::iterator victim = it; + ++it; + updated_nodes.erase(victim); + } + else + ++it; + } + } + + void RemoveLeaves(hash_set &leaves_to_remove){ + std::list leaves_copy; + leaves_copy.swap(leaves); + for(std::list::iterator it = leaves_copy.begin(), en = leaves_copy.end(); it != en; ++it){ + if(leaves_to_remove.find(*it) == leaves_to_remove.end()) + leaves.push_back(*it); + } + } + + hash_map > node_map; + std::list updated_nodes; + + virtual void ExpandNode(RPFP::Node *p){ + stack.back().expansions.push_back(p); + DerivationTree::ExpandNode(p); + std::vector &new_nodes = p->Outgoing->Children; + for(unsigned i = 0; i < new_nodes.size(); i++){ + Node *n = new_nodes[i]; + node_map[n->map].push_back(n); + } + } + + bool RecordUpdate(Node *node){ + bool res = duality->UpdateNodeToNode(node->map,node); + if(res){ + std::vector to_update = node_map[node->map]; + for(unsigned i = 0; i < to_update.size(); i++){ + Node *node2 = to_update[i]; + // maintain invariant that no nodes on updated list are created at current stack level + if(node2 == node || !(node->Incoming.size() > 0 && AtCurrentStackLevel(node2->Incoming[0]->Parent))){ + updated_nodes.push_back(node2); + if(node2 != node) + node2->Annotation = node->Annotation; + } + } + } + return res; + } + + void HandleUpdatedNodes(){ + for(std::list::iterator it = updated_nodes.begin(), en = updated_nodes.end(); it != en;){ + Node *node = *it; + node->Annotation = node->map->Annotation; + if(node->Incoming.size() > 0) + tree->ConstrainParent(node->Incoming[0],node); + if(AtCurrentStackLevel(node->Incoming[0]->Parent)){ + std::list::iterator victim = it; + ++it; + updated_nodes.erase(victim); + } + else + ++it; + } + } + + bool AtCurrentStackLevel(Node *node){ + std::vector vec = stack.back().expansions; + for(unsigned i = 0; i < vec.size(); i++) + if(vec[i] == node) + return true; + return false; + } + + void UnmapNode(Node *node){ + std::vector &vec = node_map[node->map]; + for(unsigned i = 0; i < vec.size(); i++){ + if(vec[i] == node){ + std::swap(vec[i],vec.back()); + vec.pop_back(); + return; + } + } + throw "can't unmap node"; + } + + }; + + + class Covering { + + struct cover_info { + Node *covered_by; + std::list covers; + bool dominated; + std::set dominates; + cover_info(){ + covered_by = 0; + dominated = false; + } + }; + + typedef hash_map cover_map; + cover_map cm; + Duality *parent; + bool some_updates; + + Node *&covered_by(Node *node){ + return cm[node].covered_by; + } + + std::list &covers(Node *node){ + return cm[node].covers; + } + + std::vector &insts_of_node(Node *node){ + return parent->insts_of_node[node]; + } + + Reporter *reporter(){ + return parent->reporter; + } + + std::set &dominates(Node *x){ + return cm[x].dominates; + } + + bool dominates(Node *x, Node *y){ + std::set &d = cm[x].dominates; + return d.find(y) != d.end(); + } + + bool &dominated(Node *x){ + return cm[x].dominated; + } + + public: + + Covering(Duality *_parent){ + parent = _parent; + some_updates = false; + } + + bool IsCoveredRec(hash_set &memo, Node *node){ + if(memo.find(node) != memo.end()) + return false; + memo.insert(node); + if(covered_by(node)) return true; + for(unsigned i = 0; i < node->Outgoing->Children.size(); i++) + if(IsCoveredRec(memo,node->Outgoing->Children[i])) + return true; + return false; + } + + bool IsCovered(Node *node){ + hash_set memo; + return IsCoveredRec(memo,node); + } + +#ifndef UNDERAPPROX_NODES + void RemoveCoveringsBy(Node *node){ + std::list &cs = covers(node); + for(std::list::iterator it = cs.begin(), en = cs.end(); it != en; it++){ + covered_by(*it) = 0; + reporter()->RemoveCover(*it,node); + } + cs.clear(); + } +#else + void RemoveCoveringsBy(Node *node){ + std::vector &cs = parent->all_of_node[node->map]; + for(std::vector::iterator it = cs.begin(), en = cs.end(); it != en; it++){ + Node *other = *it; + if(covered_by(other) && CoverOrder(node,other)){ + covered_by(other) = 0; + reporter()->RemoveCover(*it,node); + } + } + } +#endif + + void RemoveAscendantCoveringsRec(hash_set &memo, Node *node){ + if(memo.find(node) != memo.end()) + return; + memo.insert(node); + RemoveCoveringsBy(node); + for(std::vector::iterator it = node->Incoming.begin(), en = node->Incoming.end(); it != en; ++it) + RemoveAscendantCoveringsRec(memo,(*it)->Parent); + } + + void RemoveAscendantCoverings(Node *node){ + hash_set memo; + RemoveAscendantCoveringsRec(memo,node); + } + + bool CoverOrder(Node *covering, Node *covered){ +#ifdef UNDERAPPROX_NODES + if(parent->underapprox_map.find(covered) != parent->underapprox_map.end()) + return false; + if(parent->underapprox_map.find(covering) != parent->underapprox_map.end()) + return covering->number < covered->number || parent->underapprox_map[covering] == covered; +#endif + return covering->number < covered->number; + } + + bool CheckCover(Node *covered, Node *covering){ + return + CoverOrder(covering,covered) + && covered->Annotation.SubsetEq(covering->Annotation) + && !IsCovered(covering); + } + + bool CoverByNode(Node *covered, Node *covering){ + if(CheckCover(covered,covering)){ + covered_by(covered) = covering; + covers(covering).push_back(covered); + std::vector others; others.push_back(covering); + reporter()->AddCover(covered,others); + RemoveAscendantCoverings(covered); + return true; + } + else + return false; + } + +#ifdef UNDERAPPROX_NODES + bool CoverByAll(Node *covered){ + RPFP::Transformer all = covered->Annotation; + all.SetEmpty(); + std::vector &insts = parent->insts_of_node[covered->map]; + std::vector others; + for(unsigned i = 0; i < insts.size(); i++){ + Node *covering = insts[i]; + if(CoverOrder(covering,covered) && !IsCovered(covering)){ + others.push_back(covering); + all.UnionWith(covering->Annotation); + } + } + if(others.size() && covered->Annotation.SubsetEq(all)){ + covered_by(covered) = covered; // anything non-null will do + reporter()->AddCover(covered,others); + RemoveAscendantCoverings(covered); + return true; + } + else + return false; + } +#endif + + bool Close(Node *node){ + if(covered_by(node)) + return true; +#ifndef UNDERAPPROX_NODES + std::vector &insts = insts_of_node(node->map); + for(unsigned i = 0; i < insts.size(); i++) + if(CoverByNode(node,insts[i])) + return true; +#else + if(CoverByAll(node)) + return true; +#endif + return false; + } + + bool CloseDescendantsRec(hash_set &memo, Node *node){ + if(memo.find(node) != memo.end()) + return false; + for(unsigned i = 0; i < node->Outgoing->Children.size(); i++) + if(CloseDescendantsRec(memo,node->Outgoing->Children[i])) + return true; + if(Close(node)) + return true; + memo.insert(node); + return false; + } + + bool CloseDescendants(Node *node){ + timer_start("CloseDescendants"); + hash_set memo; + bool res = CloseDescendantsRec(memo,node); + timer_stop("CloseDescendants"); + return res; + } + + bool Contains(Node *node){ + timer_start("Contains"); + bool res = !IsCovered(node); + timer_stop("Contains"); + return res; + } + + bool Candidate(Node *node){ + timer_start("Candidate"); + bool res = !IsCovered(node) && !dominated(node); + timer_stop("Candidate"); + return res; + } + + void SetDominated(Node *node){ + dominated(node) = true; + } + + bool CouldCover(Node *covered, Node *covering){ +#ifdef UNDERAPPROX_NODES + // if(parent->underapprox_map.find(covering) != parent->underapprox_map.end()) + // return parent->underapprox_map[covering] == covered; +#endif + if(CoverOrder(covering,covered) + && !IsCovered(covering)){ + RPFP::Transformer f(covering->Annotation); f.SetEmpty(); +#if defined(TOP_DOWN) || defined(EFFORT_BOUNDED_STRAT) + if(parent->StratifiedInlining) + return true; +#endif + return !covering->Annotation.SubsetEq(f); + } + return false; + } + + bool ContainsCex(Node *node, Counterexample &cex){ + expr val = cex.tree->Eval(cex.root->Outgoing,node->Annotation.Formula); + return eq(val,parent->ctx.bool_val(true)); + } + + /** We conjecture that the annotations of similar nodes may be + true of this one. We start with later nodes, on the + principle that their annotations are likely weaker. We save + a counterexample -- if annotations of other nodes are true + in this counterexample, we don't need to check them. + */ + +#ifndef UNDERAPPROX_NODES + bool Conjecture(Node *node){ + std::vector &insts = insts_of_node(node->map); + Counterexample cex; + for(int i = insts.size() - 1; i >= 0; i--){ + Node *other = insts[i]; + if(CouldCover(node,other)){ + reporter()->Forcing(node,other); + if(cex.tree && !ContainsCex(other,cex)) + continue; + if(cex.tree) {delete cex.tree; cex.tree = 0;} + if(parent->ProveConjecture(node,other->Annotation,other,&cex)) + if(CloseDescendants(node)) + return true; + } + } + if(cex.tree) {delete cex.tree; cex.tree = 0;} + return false; + } +#else + bool Conjecture(Node *node){ + std::vector &insts = insts_of_node(node->map); + Counterexample cex; + RPFP::Transformer Bound = node->Annotation; + Bound.SetEmpty(); + bool some_other = false; + for(int i = insts.size() - 1; i >= 0; i--){ + Node *other = insts[i]; + if(CouldCover(node,other)){ + reporter()->Forcing(node,other); + Bound.UnionWith(other->Annotation); + some_other = true; + } + } + if(some_other && parent->ProveConjecture(node,Bound)){ + CloseDescendants(node); + return true; + } + return false; + } +#endif + + void Update(Node *node, const RPFP::Transformer &update){ + RemoveCoveringsBy(node); + some_updates = true; + } + +#ifndef UNDERAPPROX_NODES + Node *GetSimilarNode(Node *node){ + if(!some_updates) + return 0; + std::vector &insts = insts_of_node(node->map); + for(int i = insts.size()-1; i >= 0; i--){ + Node *other = insts[i]; + if(dominates(node,other)) + if(CoverOrder(other,node) + && !IsCovered(other)) + return other; + } + return 0; + } +#else + Node *GetSimilarNode(Node *node){ + if(!some_updates) + return 0; + std::vector &insts = insts_of_node(node->map); + for(int i = insts.size() - 1; i >= 0; i--){ + Node *other = insts[i]; + if(CoverOrder(other,node) + && !IsCovered(other)) + return other; + } + return 0; + } +#endif + + bool Dominates(Node * node, Node *other){ + if(node == other) return false; + if(other->Outgoing->map == 0) return true; + if(node->Outgoing->map == other->Outgoing->map){ + assert(node->Outgoing->Children.size() == other->Outgoing->Children.size()); + for(unsigned i = 0; i < node->Outgoing->Children.size(); i++){ + Node *nc = node->Outgoing->Children[i]; + Node *oc = other->Outgoing->Children[i]; + if(!(nc == oc || oc->Outgoing->map ==0 || dominates(nc,oc))) + return false; + } + return true; + } + return false; + } + + void Add(Node *node){ + std::vector &insts = insts_of_node(node->map); + for(unsigned i = 0; i < insts.size(); i++){ + Node *other = insts[i]; + if(Dominates(node,other)){ + cm[node].dominates.insert(other); + cm[other].dominated = true; + reporter()->Dominates(node, other); + } + } + } + + }; + + /* This expansion heuristic makes use of a previuosly obtained + counterexample as a guide. This is for use in abstraction + refinement schemes.*/ + + class ReplayHeuristic : public Heuristic { + + Counterexample old_cex; + public: + ReplayHeuristic(RPFP *_rpfp, Counterexample &_old_cex) + : Heuristic(_rpfp), old_cex(_old_cex) + { + } + + ~ReplayHeuristic(){ + if(old_cex.tree) + delete old_cex.tree; + } + + // Maps nodes of derivation tree into old cex + hash_map cex_map; + + void Done() { + cex_map.clear(); + if(old_cex.tree) + delete old_cex.tree; + old_cex.tree = 0; // only replay once! + } + + void ShowNodeAndChildren(Node *n){ + std::cout << n->Name.name() << ": "; + std::vector &chs = n->Outgoing->Children; + for(unsigned i = 0; i < chs.size(); i++) + std::cout << chs[i]->Name.name() << " " ; + std::cout << std::endl; + } + + // HACK: When matching relation names, we drop suffixes used to + // make the names unique between runs. For compatibility + // with boggie, we drop suffixes beginning with @@ + std::string BaseName(const std::string &name){ + int pos = name.find("@@"); + if(pos >= 1) + return name.substr(0,pos); + return name; + } + + virtual void ChooseExpand(const std::set &choices, std::set &best, bool high_priority, bool best_only){ + if(!high_priority || !old_cex.tree){ + Heuristic::ChooseExpand(choices,best,false); + return; + } + // first, try to match the derivatino tree nodes to the old cex + std::set matched, unmatched; + for(std::set::iterator it = choices.begin(), en = choices.end(); it != en; ++it){ + Node *node = (*it); + if(cex_map.empty()) + cex_map[node] = old_cex.root; // match the root nodes + if(cex_map.find(node) == cex_map.end()){ // try to match an unmatched node + Node *parent = node->Incoming[0]->Parent; // assumes we are a tree! + if(cex_map.find(parent) == cex_map.end()) + throw "catastrophe in ReplayHeuristic::ChooseExpand"; + Node *old_parent = cex_map[parent]; + std::vector &chs = parent->Outgoing->Children; + if(old_parent && old_parent->Outgoing){ + std::vector &old_chs = old_parent->Outgoing->Children; + for(unsigned i = 0, j=0; i < chs.size(); i++){ + if(j < old_chs.size() && BaseName(chs[i]->Name.name().str()) == BaseName(old_chs[j]->Name.name().str())) + cex_map[chs[i]] = old_chs[j++]; + else { + std::cerr << "WARNING: duality: unmatched child: " << chs[i]->Name.name() << std::endl; + cex_map[chs[i]] = 0; + } + } + goto matching_done; + } + for(unsigned i = 0; i < chs.size(); i++) + cex_map[chs[i]] = 0; + } + matching_done: + Node *old_node = cex_map[node]; + if(!old_node) + unmatched.insert(node); + else if(old_cex.tree->Empty(old_node)) + unmatched.insert(node); + else + matched.insert(node); + } + if (matched.empty() && !high_priority) + Heuristic::ChooseExpand(unmatched,best,false); + else + Heuristic::ChooseExpand(matched,best,false); + } + }; + + + class LocalHeuristic : public Heuristic { + + RPFP::Node *old_node; + public: + LocalHeuristic(RPFP *_rpfp) + : Heuristic(_rpfp) + { + old_node = 0; + } + + void SetOldNode(RPFP::Node *_old_node){ + old_node = _old_node; + cex_map.clear(); + } + + // Maps nodes of derivation tree into old subtree + hash_map cex_map; + + virtual void ChooseExpand(const std::set &choices, std::set &best){ + if(old_node == 0){ + Heuristic::ChooseExpand(choices,best); + return; + } + // first, try to match the derivatino tree nodes to the old cex + std::set matched, unmatched; + for(std::set::iterator it = choices.begin(), en = choices.end(); it != en; ++it){ + Node *node = (*it); + if(cex_map.empty()) + cex_map[node] = old_node; // match the root nodes + if(cex_map.find(node) == cex_map.end()){ // try to match an unmatched node + Node *parent = node->Incoming[0]->Parent; // assumes we are a tree! + if(cex_map.find(parent) == cex_map.end()) + throw "catastrophe in ReplayHeuristic::ChooseExpand"; + Node *old_parent = cex_map[parent]; + std::vector &chs = parent->Outgoing->Children; + if(old_parent && old_parent->Outgoing){ + std::vector &old_chs = old_parent->Outgoing->Children; + if(chs.size() == old_chs.size()){ + for(unsigned i = 0; i < chs.size(); i++) + cex_map[chs[i]] = old_chs[i]; + goto matching_done; + } + else + std::cout << "derivation tree does not match old cex" << std::endl; + } + for(unsigned i = 0; i < chs.size(); i++) + cex_map[chs[i]] = 0; + } + matching_done: + Node *old_node = cex_map[node]; + if(!old_node) + unmatched.insert(node); + else if(old_node != node->map) + unmatched.insert(node); + else + matched.insert(node); + } + Heuristic::ChooseExpand(unmatched,best); + } + }; + + + }; + + + class StreamReporter : public Reporter { + std::ostream &s; + public: + StreamReporter(RPFP *_rpfp, std::ostream &_s) + : Reporter(_rpfp), s(_s) {event = 0;} + int event; + void ev(){ + s << "[" << event++ << "]" ; + } + virtual void Extend(RPFP::Node *node){ + ev(); s << "node " << node->number << ": " << node->Name.name(); + std::vector &rps = node->Outgoing->Children; + for(unsigned i = 0; i < rps.size(); i++) + s << " " << rps[i]->number; + s << std::endl; + } + virtual void Update(RPFP::Node *node, const RPFP::Transformer &update){ + ev(); s << "update " << node->number << " " << node->Name.name() << ": "; + rpfp->Summarize(update.Formula); + std::cout << std::endl; + } + virtual void Bound(RPFP::Node *node){ + ev(); s << "check " << node->number << std::endl; + } + virtual void Expand(RPFP::Edge *edge){ + RPFP::Node *node = edge->Parent; + ev(); s << "expand " << node->map->number << " " << node->Name.name() << std::endl; + } + virtual void AddCover(RPFP::Node *covered, std::vector &covering){ + ev(); s << "cover " << covered->Name.name() << ": " << covered->number << " by "; + for(unsigned i = 0; i < covering.size(); i++) + std::cout << covering[i]->number << " "; + std::cout << std::endl; + } + virtual void RemoveCover(RPFP::Node *covered, RPFP::Node *covering){ + ev(); s << "uncover " << covered->Name.name() << ": " << covered->number << " by " << covering->number << std::endl; + } + virtual void Forcing(RPFP::Node *covered, RPFP::Node *covering){ + ev(); s << "forcing " << covered->Name.name() << ": " << covered->number << " by " << covering->number << std::endl; + } + virtual void Conjecture(RPFP::Node *node, const RPFP::Transformer &t){ + ev(); s << "conjecture " << node->number << " " << node->Name.name() << ": "; + rpfp->Summarize(t.Formula); + std::cout << std::endl; + } + virtual void Dominates(RPFP::Node *node, RPFP::Node *other){ + ev(); s << "dominates " << node->Name.name() << ": " << node->number << " > " << other->number << std::endl; + } + virtual void InductionFailure(RPFP::Edge *edge, const std::vector &children){ + ev(); s << "induction failure: " << edge->Parent->Name.name() << ", children ="; + for(unsigned i = 0; i < children.size(); i++) + s << " " << children[i]->number; + s << std::endl; + } + virtual void UpdateUnderapprox(RPFP::Node *node, const RPFP::Transformer &update){ + ev(); s << "underapprox " << node->number << " " << node->Name.name() << ": " << update.Formula << std::endl; + } + virtual void Reject(RPFP::Edge *edge, const std::vector &children){ + ev(); s << "reject " << edge->Parent->number << " " << edge->Parent->Name.name() << ": "; + for(unsigned i = 0; i < children.size(); i++) + s << " " << children[i]->number; + s << std::endl; + } + virtual void Message(const std::string &msg){ + ev(); s << "msg " << msg << std::endl; + } + + }; + + Solver *Solver::Create(const std::string &solver_class, RPFP *rpfp){ + Duality *s = alloc(Duality,rpfp); + return s; + } + + Reporter *CreateStdoutReporter(RPFP *rpfp){ + return new StreamReporter(rpfp, std::cout); + } +} diff --git a/src/duality/duality_wrapper.cpp b/src/duality/duality_wrapper.cpp new file mode 100644 index 000000000..55883202f --- /dev/null +++ b/src/duality/duality_wrapper.cpp @@ -0,0 +1,677 @@ +/*++ +Copyright (c) 2012 Microsoft Corporation + +Module Name: + + wrapper.cpp + +Abstract: + + wrap various objects in the style expected by duality + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + + +--*/ + +#include "duality_wrapper.h" +#include +#include "smt_solver.h" +#include "iz3interp.h" +#include "statistics.h" +#include "expr_abstract.h" +#include "stopwatch.h" +#include "model_smt2_pp.h" +#include "qe_lite.h" + +namespace Duality { + + solver::solver(Duality::context& c, bool extensional) : object(c), the_model(c) { + params_ref p; + p.set_bool("proof", true); // this is currently useless + p.set_bool("model", true); + p.set_bool("unsat_core", true); + p.set_bool("mbqi",true); + p.set_str("mbqi.id","itp"); // use mbqi for quantifiers in interpolants + p.set_uint("mbqi.max_iterations",1); // use mbqi for quantifiers in interpolants + if(true || extensional) + p.set_bool("array.extensional",true); + scoped_ptr sf = mk_smt_solver_factory(); + m_solver = (*sf)(m(), p, true, true, true, ::symbol::null); + m_solver->updt_params(p); // why do we have to do this? + canceled = false; + } + +expr context::constant(const std::string &name, const sort &ty){ + symbol s = str_symbol(name.c_str()); + return cook(m().mk_const(m().mk_const_decl(s, ty))); +} + +expr context::make(decl_kind op, int n, ::expr **args){ + switch(op) { + case True: return mki(m_basic_fid,OP_TRUE,n,args); + case False: return mki(m_basic_fid,OP_FALSE,n,args); + case Equal: return mki(m_basic_fid,OP_EQ,n,args); + case Distinct: return mki(m_basic_fid,OP_DISTINCT,n,args); + case Ite: return mki(m_basic_fid,OP_ITE,n,args); + case And: return mki(m_basic_fid,OP_AND,n,args); + case Or: return mki(m_basic_fid,OP_OR,n,args); + case Iff: return mki(m_basic_fid,OP_IFF,n,args); + case Xor: return mki(m_basic_fid,OP_XOR,n,args); + case Not: return mki(m_basic_fid,OP_NOT,n,args); + case Implies: return mki(m_basic_fid,OP_IMPLIES,n,args); + case Oeq: return mki(m_basic_fid,OP_OEQ,n,args); + case Interp: return mki(m_basic_fid,OP_INTERP,n,args); + case Leq: return mki(m_arith_fid,OP_LE,n,args); + case Geq: return mki(m_arith_fid,OP_GE,n,args); + case Lt: return mki(m_arith_fid,OP_LT,n,args); + case Gt: return mki(m_arith_fid,OP_GT,n,args); + case Plus: return mki(m_arith_fid,OP_ADD,n,args); + case Sub: return mki(m_arith_fid,OP_SUB,n,args); + case Uminus: return mki(m_arith_fid,OP_UMINUS,n,args); + case Times: return mki(m_arith_fid,OP_MUL,n,args); + case Div: return mki(m_arith_fid,OP_DIV,n,args); + case Idiv: return mki(m_arith_fid,OP_IDIV,n,args); + case Rem: return mki(m_arith_fid,OP_REM,n,args); + case Mod: return mki(m_arith_fid,OP_MOD,n,args); + case Power: return mki(m_arith_fid,OP_POWER,n,args); + case ToReal: return mki(m_arith_fid,OP_TO_REAL,n,args); + case ToInt: return mki(m_arith_fid,OP_TO_INT,n,args); + case IsInt: return mki(m_arith_fid,OP_IS_INT,n,args); + case Store: return mki(m_array_fid,OP_STORE,n,args); + case Select: return mki(m_array_fid,OP_SELECT,n,args); + case ConstArray: return mki(m_array_fid,OP_CONST_ARRAY,n,args); + case ArrayDefault: return mki(m_array_fid,OP_ARRAY_DEFAULT,n,args); + case ArrayMap: return mki(m_array_fid,OP_ARRAY_MAP,n,args); + case SetUnion: return mki(m_array_fid,OP_SET_UNION,n,args); + case SetIntersect: return mki(m_array_fid,OP_SET_INTERSECT,n,args); + case SetDifference: return mki(m_array_fid,OP_SET_DIFFERENCE,n,args); + case SetComplement: return mki(m_array_fid,OP_SET_COMPLEMENT,n,args); + case SetSubSet: return mki(m_array_fid,OP_SET_SUBSET,n,args); + case AsArray: return mki(m_array_fid,OP_AS_ARRAY,n,args); + default: + assert(0); + return expr(*this); + } +} + + expr context::mki(family_id fid, ::decl_kind dk, int n, ::expr **args){ + return cook(m().mk_app(fid, dk, 0, 0, n, (::expr **)args)); +} + +expr context::make(decl_kind op, const std::vector &args){ + static std::vector< ::expr*> a(10); + if(a.size() < args.size()) + a.resize(args.size()); + for(unsigned i = 0; i < args.size(); i++) + a[i] = to_expr(args[i].raw()); + return make(op,args.size(), args.size() ? &a[0] : 0); +} + +expr context::make(decl_kind op){ + return make(op,0,0); +} + +expr context::make(decl_kind op, const expr &arg0){ + ::expr *a = to_expr(arg0.raw()); + return make(op,1,&a); +} + +expr context::make(decl_kind op, const expr &arg0, const expr &arg1){ + ::expr *args[2]; + args[0] = to_expr(arg0.raw()); + args[1] = to_expr(arg1.raw()); + return make(op,2,args); +} + +expr context::make(decl_kind op, const expr &arg0, const expr &arg1, const expr &arg2){ + ::expr *args[3]; + args[0] = to_expr(arg0.raw()); + args[1] = to_expr(arg1.raw()); + args[2] = to_expr(arg2.raw()); + return make(op,3,args); +} + +expr context::make_quant(decl_kind op, const std::vector &bvs, const expr &body){ + if(bvs.size() == 0) return body; + std::vector< ::expr *> foo(bvs.size()); + + + std::vector< ::symbol> names; + std::vector< ::sort *> types; + std::vector< ::expr *> bound_asts; + unsigned num_bound = bvs.size(); + + for (unsigned i = 0; i < num_bound; ++i) { + app* a = to_app(bvs[i].raw()); + ::symbol s(to_app(a)->get_decl()->get_name()); + names.push_back(s); + types.push_back(m().get_sort(a)); + bound_asts.push_back(a); + } + expr_ref abs_body(m()); + expr_abstract(m(), 0, num_bound, &bound_asts[0], to_expr(body.raw()), abs_body); + expr_ref result(m()); + result = m().mk_quantifier( + op == Forall, + names.size(), &types[0], &names[0], abs_body.get(), + 0, + ::symbol(), + ::symbol(), + 0, 0, + 0, 0 + ); + return cook(result.get()); +} + +expr context::make_quant(decl_kind op, const std::vector &_sorts, const std::vector &_names, const expr &body){ + if(_sorts.size() == 0) return body; + + + std::vector< ::symbol> names; + std::vector< ::sort *> types; + std::vector< ::expr *> bound_asts; + unsigned num_bound = _sorts.size(); + + for (unsigned i = 0; i < num_bound; ++i) { + names.push_back(_names[i]); + types.push_back(to_sort(_sorts[i].raw())); + } + expr_ref result(m()); + result = m().mk_quantifier( + op == Forall, + names.size(), &types[0], &names[0], to_expr(body.raw()), + 0, + ::symbol(), + ::symbol(), + 0, 0, + 0, 0 + ); + return cook(result.get()); +} + + + decl_kind func_decl::get_decl_kind() const { + return ctx().get_decl_kind(*this); + } + + decl_kind context::get_decl_kind(const func_decl &t){ + ::func_decl *d = to_func_decl(t.raw()); + if (null_family_id == d->get_family_id()) + return Uninterpreted; + // return (opr)d->get_decl_kind(); + if (m_basic_fid == d->get_family_id()) { + switch(d->get_decl_kind()) { + case OP_TRUE: return True; + case OP_FALSE: return False; + case OP_EQ: return Equal; + case OP_DISTINCT: return Distinct; + case OP_ITE: return Ite; + case OP_AND: return And; + case OP_OR: return Or; + case OP_IFF: return Iff; + case OP_XOR: return Xor; + case OP_NOT: return Not; + case OP_IMPLIES: return Implies; + case OP_OEQ: return Oeq; + case OP_INTERP: return Interp; + default: + return OtherBasic; + } + } + if (m_arith_fid == d->get_family_id()) { + switch(d->get_decl_kind()) { + case OP_LE: return Leq; + case OP_GE: return Geq; + case OP_LT: return Lt; + case OP_GT: return Gt; + case OP_ADD: return Plus; + case OP_SUB: return Sub; + case OP_UMINUS: return Uminus; + case OP_MUL: return Times; + case OP_DIV: return Div; + case OP_IDIV: return Idiv; + case OP_REM: return Rem; + case OP_MOD: return Mod; + case OP_POWER: return Power; + case OP_TO_REAL: return ToReal; + case OP_TO_INT: return ToInt; + case OP_IS_INT: return IsInt; + default: + return OtherArith; + } + } + if (m_array_fid == d->get_family_id()) { + switch(d->get_decl_kind()) { + case OP_STORE: return Store; + case OP_SELECT: return Select; + case OP_CONST_ARRAY: return ConstArray; + case OP_ARRAY_DEFAULT: return ArrayDefault; + case OP_ARRAY_MAP: return ArrayMap; + case OP_SET_UNION: return SetUnion; + case OP_SET_INTERSECT: return SetIntersect; + case OP_SET_DIFFERENCE: return SetDifference; + case OP_SET_COMPLEMENT: return SetComplement; + case OP_SET_SUBSET: return SetSubSet; + case OP_AS_ARRAY: return AsArray; + default: + return OtherArray; + } + } + + return Other; + } + + + sort_kind context::get_sort_kind(const sort &s){ + family_id fid = to_sort(s.raw())->get_family_id(); + ::decl_kind k = to_sort(s.raw())->get_decl_kind(); + if (m().is_uninterp(to_sort(s.raw()))) { + return UninterpretedSort; + } + else if (fid == m_basic_fid && k == BOOL_SORT) { + return BoolSort; + } + else if (fid == m_arith_fid && k == INT_SORT) { + return IntSort; + } + else if (fid == m_arith_fid && k == REAL_SORT) { + return RealSort; + } + else if (fid == m_array_fid && k == ARRAY_SORT) { + return ArraySort; + } + else { + return UnknownSort; + } + } + + expr func_decl::operator()(unsigned n, expr const * args) const { + std::vector< ::expr *> _args(n); + for(unsigned i = 0; i < n; i++) + _args[i] = to_expr(args[i].raw()); + return ctx().cook(m().mk_app(to_func_decl(raw()),n,&_args[0])); + } + + int solver::get_num_decisions(){ + ::statistics st; + m_solver->collect_statistics(st); + std::ostringstream ss; + st.display(ss); + std::string stats = ss.str(); + int pos = stats.find("decisions:"); + if(pos < 0) return 0; // for some reason, decisions are not reported if there are none + pos += 10; + int end = stats.find('\n',pos); + std::string val = stats.substr(pos,end-pos); + return atoi(val.c_str()); + } + + void context::print_expr(std::ostream &s, const ast &e){ + s << mk_pp(e.raw(), m()); + } + + + expr expr::simplify(const params &_p) const { + ::expr * a = to_expr(raw()); + params_ref p = _p.get(); + th_rewriter m_rw(m(), p); + expr_ref result(m()); + m_rw(a, result); + return ctx().cook(result); + } + + expr expr::simplify() const { + params p; + return simplify(p); + } + + expr expr::qe_lite() const { + ::qe_lite qe(m()); + expr_ref result(to_expr(raw()),m()); + proof_ref pf(m()); + qe(result,pf); + return ctx().cook(result); + } + + expr clone_quantifier(const expr &q, const expr &b){ + return q.ctx().cook(q.m().update_quantifier(to_quantifier(q.raw()), to_expr(b.raw()))); + } + + expr clone_quantifier(const expr &q, const expr &b, const std::vector &patterns){ + quantifier *thing = to_quantifier(q.raw()); + bool is_forall = thing->is_forall(); + unsigned num_patterns = patterns.size(); + std::vector< ::expr *> _patterns(num_patterns); + for(unsigned i = 0; i < num_patterns; i++) + _patterns[i] = to_expr(patterns[i].raw()); + return q.ctx().cook(q.m().update_quantifier(thing, is_forall, num_patterns, &_patterns[0], to_expr(b.raw()))); + } + + void expr::get_patterns(std::vector &pats) const { + quantifier *thing = to_quantifier(raw()); + unsigned num_patterns = thing->get_num_patterns(); + :: expr * const *it = thing->get_patterns(); + pats.resize(num_patterns); + for(unsigned i = 0; i < num_patterns; i++) + pats[i] = expr(ctx(),it[i]); + } + + + func_decl context::fresh_func_decl(char const * prefix, const std::vector &domain, sort const & range){ + std::vector < ::sort * > _domain(domain.size()); + for(unsigned i = 0; i < domain.size(); i++) + _domain[i] = to_sort(domain[i].raw()); + ::func_decl* d = m().mk_fresh_func_decl(prefix, + _domain.size(), + &_domain[0], + to_sort(range.raw())); + return func_decl(*this,d); + } + + func_decl context::fresh_func_decl(char const * prefix, sort const & range){ + ::func_decl* d = m().mk_fresh_func_decl(prefix, + 0, + 0, + to_sort(range.raw())); + return func_decl(*this,d); + } + + + +#if 0 + + + lbool interpolating_solver::interpolate( + const std::vector &assumptions, + std::vector &interpolants, + model &model, + Z3_literals &labels, + bool incremental) + { + Z3_model _model = 0; + Z3_literals _labels = 0; + Z3_lbool lb; + std::vector _assumptions(assumptions.size()); + std::vector _interpolants(assumptions.size()-1); + for(unsigned i = 0; i < assumptions.size(); i++) + _assumptions[i] = assumptions[i]; + std::vector _theory(theory.size()); + for(unsigned i = 0; i < theory.size(); i++) + _theory[i] = theory[i]; + + lb = Z3_interpolate( + ctx(), + _assumptions.size(), + &_assumptions[0], + 0, + 0, + &_interpolants[0], + &_model, + &_labels, + incremental, + _theory.size(), + &_theory[0] + ); + + if(lb == Z3_L_FALSE){ + interpolants.resize(_interpolants.size()); + for (unsigned i = 0; i < _interpolants.size(); ++i) { + interpolants[i] = expr(ctx(),_interpolants[i]); + } + } + + if (_model) { + model = iz3wrapper::model(ctx(), _model); + } + + if(_labels){ + labels = _labels; + } + + return lb; + } + +#endif + + static int linearize_assumptions(int num, + TermTree *assumptions, + std::vector > &linear_assumptions, + std::vector &parents){ + for(unsigned i = 0; i < assumptions->getChildren().size(); i++) + num = linearize_assumptions(num, assumptions->getChildren()[i], linear_assumptions, parents); + // linear_assumptions[num].push_back(assumptions->getTerm()); + for(unsigned i = 0; i < assumptions->getChildren().size(); i++) + parents[assumptions->getChildren()[i]->getNumber()] = num; + parents[num] = SHRT_MAX; // in case we have no parent + linear_assumptions[num].push_back(assumptions->getTerm()); + std::vector &ts = assumptions->getTerms(); + for(unsigned i = 0; i < ts.size(); i++) + linear_assumptions[num].push_back(ts[i]); + return num + 1; + } + + static int unlinearize_interpolants(int num, + TermTree* assumptions, + const std::vector &interpolant, + TermTree * &tree_interpolant) + { + std::vector chs(assumptions->getChildren().size()); + for(unsigned i = 0; i < assumptions->getChildren().size(); i++) + num = unlinearize_interpolants(num, assumptions->getChildren()[i], interpolant,chs[i]); + expr f; + if(num < (int)interpolant.size()) // last interpolant is missing, presumed false + f = interpolant[num]; + tree_interpolant = new TermTree(f,chs); + return num + 1; + } + + + lbool interpolating_solver::interpolate_tree(TermTree *assumptions, + TermTree *&interpolant, + model &model, + literals &labels, + bool incremental + ) + + { + int size = assumptions->number(0); + std::vector > linear_assumptions(size); + std::vector parents(size); + linearize_assumptions(0,assumptions,linear_assumptions,parents); + + ptr_vector< ::ast> _interpolants(size-1); + vector >_assumptions(size); + for(int i = 0; i < size; i++) + for(unsigned j = 0; j < linear_assumptions[i].size(); j++) + _assumptions[i].push_back(linear_assumptions[i][j]); + ::vector _parents; _parents.resize(parents.size()); + for(unsigned i = 0; i < parents.size(); i++) + _parents[i] = parents[i]; + ptr_vector< ::ast> _theory(theory.size()); + for(unsigned i = 0; i < theory.size(); i++) + _theory[i] = theory[i]; + + + if(!incremental){ + push(); + for(unsigned i = 0; i < linear_assumptions.size(); i++) + for(unsigned j = 0; j < linear_assumptions[i].size(); j++) + add(linear_assumptions[i][j]); + } + + check_result res = check(); + + if(res == unsat){ + + interpolation_options_struct opts; + if(weak_mode) + opts.set("weak","1"); + + ::ast *proof = m_solver->get_proof(); + iz3interpolate(m(),proof,_assumptions,_parents,_interpolants,_theory,&opts); + + std::vector linearized_interpolants(_interpolants.size()); + for(unsigned i = 0; i < _interpolants.size(); i++) + linearized_interpolants[i] = expr(ctx(),_interpolants[i]); + + // since iz3interpolant returns interpolants with one ref count, we decrement here + for(unsigned i = 0; i < _interpolants.size(); i++) + m().dec_ref(_interpolants[i]); + + unlinearize_interpolants(0,assumptions,linearized_interpolants,interpolant); + interpolant->setTerm(ctx().bool_val(false)); + } + + model_ref _m; + m_solver->get_model(_m); + model = Duality::model(ctx(),_m.get()); + +#if 0 + if(_labels){ + labels = _labels; + } +#endif + + if(!incremental) + pop(); + + return (res == unsat) ? l_false : ((res == sat) ? l_true : l_undef); + + } + + + void interpolating_solver::SetWeakInterpolants(bool weak){ + weak_mode = weak; + } + + + void interpolating_solver::SetPrintToFile(const std::string &filename){ + print_filename = filename; + } + + void interpolating_solver::AssertInterpolationAxiom(const expr & t){ + add(t); + theory.push_back(t); + } + + + void interpolating_solver::RemoveInterpolationAxiom(const expr & t){ + // theory.remove(t); + } + + + const char *interpolating_solver::profile(){ + // return Z3_interpolation_profile(ctx()); + return ""; + } + + + static void get_assumptions_rec(stl_ext::hash_set &memo, const proof &pf, std::vector &assumps){ + if(memo.find(pf) != memo.end())return; + memo.insert(pf); + pfrule dk = pf.rule(); + if(dk == PR_ASSERTED){ + expr con = pf.conc(); + assumps.push_back(con); + } + else { + unsigned nprems = pf.num_prems(); + for(unsigned i = 0; i < nprems; i++){ + proof arg = pf.prem(i); + get_assumptions_rec(memo,arg,assumps); + } + } + } + + void proof::get_assumptions(std::vector &assumps){ + stl_ext::hash_set memo; + get_assumptions_rec(memo,*this,assumps); + } + + + void ast::show() const{ + std::cout << mk_pp(raw(), m()) << std::endl; + } + + void model::show() const { + model_smt2_pp(std::cout, m(), *m_model, 0); + std::cout << std::endl; + } + + void model::show_hash() const { + std::ostringstream ss; + model_smt2_pp(ss, m(), *m_model, 0); + hash_space::hash hasher; + unsigned h = hasher(ss.str()); + std::cout << "model hash: " << h << "\n"; + } + + void solver::show() { + unsigned n = m_solver->get_num_assertions(); + if(!n) + return; + ast_smt_pp pp(m()); + for (unsigned i = 0; i < n-1; ++i) + pp.add_assumption(m_solver->get_assertion(i)); + pp.display_smt2(std::cout, m_solver->get_assertion(n-1)); + } + + void solver::show_assertion_ids() { +#if 0 + unsigned n = m_solver->get_num_assertions(); + std::cerr << "assertion ids: "; + for (unsigned i = 0; i < n-1; ++i) + std::cerr << " " << m_solver->get_assertion(i)->get_id(); + std::cerr << "\n"; +#else + unsigned n = m_solver->get_num_assertions(); + std::cerr << "assertion ids hash: "; + unsigned h = 0; + for (unsigned i = 0; i < n-1; ++i) + h += m_solver->get_assertion(i)->get_id(); + std::cerr << h << "\n"; +#endif + } + + void include_ast_show(ast &a){ + a.show(); + } + + void include_model_show(model &a){ + a.show(); + } + + void show_ast(::ast *a, ast_manager &m) { + std::cout << mk_pp(a, m) << std::endl; + } + + bool expr::is_label (bool &pos,std::vector &names) const { + buffer< ::symbol> _names; + bool res = m().is_label(to_expr(raw()),pos,_names); + if(res) + for(unsigned i = 0; i < _names.size(); i++) + names.push_back(symbol(ctx(),_names[i])); + return res; + } + + double current_time() + { + static stopwatch sw; + static bool started = false; + if(!started){ + sw.start(); + started = true; + } + return sw.get_current_seconds(); + } + +} + + + + diff --git a/src/duality/duality_wrapper.h b/src/duality/duality_wrapper.h new file mode 100755 index 000000000..3515194c0 --- /dev/null +++ b/src/duality/duality_wrapper.h @@ -0,0 +1,1430 @@ +/*++ +Copyright (c) 2012 Microsoft Corporation + +Module Name: + + duality_wrapper.h + +Abstract: + + wrap various Z3 classes in the style expected by duality + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + + +--*/ +#ifndef __DUALITY_WRAPPER_H_ +#define __DUALITY_WRAPPER_H_ + +#include +#include +#include +#include +#include +#include +#include"version.h" +#include + +#include "iz3hash.h" +#include "model.h" +#include "solver.h" + +#include"well_sorted.h" +#include"arith_decl_plugin.h" +#include"bv_decl_plugin.h" +#include"datatype_decl_plugin.h" +#include"array_decl_plugin.h" +#include"ast_translation.h" +#include"ast_pp.h" +#include"ast_ll_pp.h" +#include"ast_smt_pp.h" +#include"ast_smt2_pp.h" +#include"th_rewriter.h" +#include"var_subst.h" +#include"expr_substitution.h" +#include"pp.h" +#include"scoped_ctrl_c.h" +#include"cancel_eh.h" +#include"scoped_timer.h" + +namespace Duality { + + class exception; + class config; + class context; + class symbol; + class params; + class ast; + class sort; + class func_decl; + class expr; + class solver; + class goal; + class tactic; + class probe; + class model; + class func_interp; + class func_entry; + class statistics; + class apply_result; + class fixedpoint; + class literals; + + /** + Duality global configuration object. + */ + + class config { + params_ref m_params; + config & operator=(config const & s); + public: + config(config const & s) : m_params(s.m_params) {} + config(const params_ref &_params) : m_params(_params) {} + config() { } // TODO: create a new params object here + ~config() { } + void set(char const * param, char const * value) { m_params.set_str(param,value); } + void set(char const * param, bool value) { m_params.set_bool(param,value); } + void set(char const * param, int value) { m_params.set_uint(param,value); } + params_ref &get() {return m_params;} + const params_ref &get() const {return m_params;} + }; + + enum decl_kind { + True, + False, + And, + Or, + Not, + Iff, + Ite, + Equal, + Implies, + Distinct, + Xor, + Oeq, + Interp, + Leq, + Geq, + Lt, + Gt, + Plus, + Sub, + Uminus, + Times, + Div, + Idiv, + Rem, + Mod, + Power, + ToReal, + ToInt, + IsInt, + Select, + Store, + ConstArray, + ArrayDefault, + ArrayMap, + SetUnion, + SetIntersect, + SetDifference, + SetComplement, + SetSubSet, + AsArray, + Numeral, + Forall, + Exists, + Variable, + Uninterpreted, + OtherBasic, + OtherArith, + OtherArray, + Other + }; + + enum sort_kind {BoolSort,IntSort,RealSort,ArraySort,UninterpretedSort,UnknownSort}; + + /** + A context has an ast manager global configuration options, etc. + */ + + class context { + protected: + ast_manager &mgr; + config m_config; + + family_id m_basic_fid; + family_id m_array_fid; + family_id m_arith_fid; + family_id m_bv_fid; + family_id m_dt_fid; + family_id m_datalog_fid; + arith_util m_arith_util; + + public: + context(ast_manager &_manager, const config &_config) : mgr(_manager), m_config(_config), m_arith_util(_manager) { + m_basic_fid = m().get_basic_family_id(); + m_arith_fid = m().mk_family_id("arith"); + m_bv_fid = m().mk_family_id("bv"); + m_array_fid = m().mk_family_id("array"); + m_dt_fid = m().mk_family_id("datatype"); + m_datalog_fid = m().mk_family_id("datalog_relation"); + } + ~context() { } + + ast_manager &m() const {return *(ast_manager *)&mgr;} + + void set(char const * param, char const * value) { m_config.set(param,value); } + void set(char const * param, bool value) { m_config.set(param,value); } + void set(char const * param, int value) { m_config.set(param,value); } + + symbol str_symbol(char const * s); + symbol int_symbol(int n); + + sort bool_sort(); + sort int_sort(); + sort real_sort(); + sort bv_sort(unsigned sz); + sort array_sort(sort d, sort r); + + func_decl function(symbol const & name, unsigned arity, sort const * domain, sort const & range); + func_decl function(char const * name, unsigned arity, sort const * domain, sort const & range); + func_decl function(char const * name, sort const & domain, sort const & range); + func_decl function(char const * name, sort const & d1, sort const & d2, sort const & range); + func_decl function(char const * name, sort const & d1, sort const & d2, sort const & d3, sort const & range); + func_decl function(char const * name, sort const & d1, sort const & d2, sort const & d3, sort const & d4, sort const & range); + func_decl function(char const * name, sort const & d1, sort const & d2, sort const & d3, sort const & d4, sort const & d5, sort const & range); + func_decl fresh_func_decl(char const * name, const std::vector &domain, sort const & range); + func_decl fresh_func_decl(char const * name, sort const & range); + + expr constant(symbol const & name, sort const & s); + expr constant(char const * name, sort const & s); + expr constant(const std::string &name, sort const & s); + expr bool_const(char const * name); + expr int_const(char const * name); + expr real_const(char const * name); + expr bv_const(char const * name, unsigned sz); + + expr bool_val(bool b); + + expr int_val(int n); + expr int_val(unsigned n); + expr int_val(char const * n); + + expr real_val(int n, int d); + expr real_val(int n); + expr real_val(unsigned n); + expr real_val(char const * n); + + expr bv_val(int n, unsigned sz); + expr bv_val(unsigned n, unsigned sz); + expr bv_val(char const * n, unsigned sz); + + expr num_val(int n, sort const & s); + + expr mki(family_id fid, ::decl_kind dk, int n, ::expr **args); + expr make(decl_kind op, int n, ::expr **args); + expr make(decl_kind op, const std::vector &args); + expr make(decl_kind op); + expr make(decl_kind op, const expr &arg0); + expr make(decl_kind op, const expr &arg0, const expr &arg1); + expr make(decl_kind op, const expr &arg0, const expr &arg1, const expr &arg2); + + expr make_quant(decl_kind op, const std::vector &bvs, const expr &body); + expr make_quant(decl_kind op, const std::vector &_sorts, const std::vector &_names, const expr &body); + + + decl_kind get_decl_kind(const func_decl &t); + + sort_kind get_sort_kind(const sort &s); + + void print_expr(std::ostream &s, const ast &e); + + fixedpoint mk_fixedpoint(); + + expr cook(::expr *a); + std::vector cook(ptr_vector< ::expr> v); + ::expr *uncook(const expr &a); + }; + + template + class z3array { + T * m_array; + unsigned m_size; + z3array(z3array const & s); + z3array & operator=(z3array const & s); + public: + z3array(unsigned sz):m_size(sz) { m_array = new T[sz]; } + ~z3array() { delete[] m_array; } + unsigned size() const { return m_size; } + T & operator[](unsigned i) { assert(i < m_size); return m_array[i]; } + T const & operator[](unsigned i) const { assert(i < m_size); return m_array[i]; } + T const * ptr() const { return m_array; } + T * ptr() { return m_array; } + }; + + class object { + protected: + context * m_ctx; + public: + object(): m_ctx((context *)0) {} + object(context & c):m_ctx(&c) {} + object(object const & s):m_ctx(s.m_ctx) {} + context & ctx() const { return *m_ctx; } + friend void check_context(object const & a, object const & b) { assert(a.m_ctx == b.m_ctx); } + ast_manager &m() const {return m_ctx->m();} + }; + + class symbol : public object { + ::symbol m_sym; + public: + symbol(context & c, ::symbol s):object(c), m_sym(s) {} + symbol(symbol const & s):object(s), m_sym(s.m_sym) {} + symbol & operator=(symbol const & s) { m_ctx = s.m_ctx; m_sym = s.m_sym; return *this; } + operator ::symbol() const {return m_sym;} + std::string str() const { + if (m_sym.is_numerical()) { + std::ostringstream buffer; + buffer << m_sym.get_num(); + return buffer.str(); + } + else { + return m_sym.bare_str(); + } + } + friend std::ostream & operator<<(std::ostream & out, symbol const & s){ + return out << s.str(); + } + friend bool operator==(const symbol &x, const symbol &y){ + return x.m_sym == y.m_sym; + } + }; + + class params : public config {}; + + /** Wrapper around an ast pointer */ + class ast_i : public object { + protected: + ::ast *_ast; + public: + ::ast * const &raw() const {return _ast;} + ast_i(context & c, ::ast *a = 0) : object(c) {_ast = a;} + + ast_i(){_ast = 0;} + bool eq(const ast_i &other) const { + return _ast == other._ast; + } + bool lt(const ast_i &other) const { + return _ast < other._ast; + } + friend bool operator==(const ast_i &x, const ast_i&y){ + return x.eq(y); + } + friend bool operator!=(const ast_i &x, const ast_i&y){ + return !x.eq(y); + } + friend bool operator<(const ast_i &x, const ast_i&y){ + return x.lt(y); + } + size_t hash() const {return (size_t)_ast;} + bool null() const {return !_ast;} + }; + + /** Reference counting verison of above */ + class ast : public ast_i { + public: + operator ::ast*() const { return raw(); } + friend bool eq(ast const & a, ast const & b) { return a.raw() == b.raw(); } + + + ast(context &c, ::ast *a = 0) : ast_i(c,a) { + if(_ast) + m().inc_ref(a); + } + + ast() {} + + ast(const ast &other) : ast_i(other) { + if(_ast) + m().inc_ref(_ast); + } + + ast &operator=(const ast &other) { + if(_ast) + m().dec_ref(_ast); + _ast = other._ast; + m_ctx = other.m_ctx; + if(_ast) + m().inc_ref(_ast); + return *this; + } + + ~ast(){ + if(_ast) + m().dec_ref(_ast); + } + + void show() const; + + }; + + class sort : public ast { + public: + sort(context & c):ast(c) {} + sort(context & c, ::sort *s):ast(c, s) {} + sort(sort const & s):ast(s) {} + operator ::sort*() const { return to_sort(raw()); } + sort & operator=(sort const & s) { return static_cast(ast::operator=(s)); } + + bool is_bool() const { return m().is_bool(*this); } + bool is_int() const { return ctx().get_sort_kind(*this) == IntSort; } + bool is_real() const { return ctx().get_sort_kind(*this) == RealSort; } + bool is_arith() const; + bool is_array() const { return ctx().get_sort_kind(*this) == ArraySort; } + bool is_datatype() const; + bool is_relation() const; + bool is_finite_domain() const; + + + sort array_domain() const; + sort array_range() const; + }; + + + class func_decl : public ast { + public: + func_decl() : ast() {} + func_decl(context & c):ast(c) {} + func_decl(context & c, ::func_decl *n):ast(c, n) {} + func_decl(func_decl const & s):ast(s) {} + operator ::func_decl*() const { return to_func_decl(*this); } + func_decl & operator=(func_decl const & s) { return static_cast(ast::operator=(s)); } + + unsigned arity() const; + sort domain(unsigned i) const; + sort range() const; + symbol name() const {return symbol(ctx(),to_func_decl(raw())->get_name());} + decl_kind get_decl_kind() const; + + bool is_const() const { return arity() == 0; } + + expr operator()(unsigned n, expr const * args) const; + expr operator()(const std::vector &args) const; + expr operator()() const; + expr operator()(expr const & a) const; + expr operator()(int a) const; + expr operator()(expr const & a1, expr const & a2) const; + expr operator()(expr const & a1, int a2) const; + expr operator()(int a1, expr const & a2) const; + expr operator()(expr const & a1, expr const & a2, expr const & a3) const; + expr operator()(expr const & a1, expr const & a2, expr const & a3, expr const & a4) const; + expr operator()(expr const & a1, expr const & a2, expr const & a3, expr const & a4, expr const & a5) const; + + func_decl get_func_decl_parameter(unsigned idx){ + return func_decl(ctx(),to_func_decl(to_func_decl(raw())->get_parameters()[idx].get_ast())); + } + + }; + + class expr : public ast { + public: + expr() : ast() {} + expr(context & c):ast(c) {} + expr(context & c, ::ast *n):ast(c, n) {} + expr(expr const & n):ast(n) {} + expr & operator=(expr const & n) { return static_cast(ast::operator=(n)); } + operator ::expr*() const { return to_expr(raw()); } + unsigned get_id() const {return to_expr(raw())->get_id();} + + sort get_sort() const { return sort(ctx(),m().get_sort(to_expr(raw()))); } + + bool is_bool() const { return get_sort().is_bool(); } + bool is_int() const { return get_sort().is_int(); } + bool is_real() const { return get_sort().is_real(); } + bool is_arith() const { return get_sort().is_arith(); } + bool is_array() const { return get_sort().is_array(); } + bool is_datatype() const { return get_sort().is_datatype(); } + bool is_relation() const { return get_sort().is_relation(); } + bool is_finite_domain() const { return get_sort().is_finite_domain(); } + + bool is_numeral() const { + return is_app() && decl().get_decl_kind() == OtherArith && m().is_unique_value(to_expr(raw())); + } + bool is_app() const {return raw()->get_kind() == AST_APP;} + bool is_quantifier() const {return raw()->get_kind() == AST_QUANTIFIER;} + bool is_var() const {return raw()->get_kind() == AST_VAR;} + bool is_label (bool &pos,std::vector &names) const ; + bool is_ground() const {return to_app(raw())->is_ground();} + bool has_quantifiers() const {return to_app(raw())->has_quantifiers();} + + // operator Z3_app() const { assert(is_app()); return reinterpret_cast(m_ast); } + func_decl decl() const {return func_decl(ctx(),to_app(raw())->get_decl());} + unsigned num_args() const { + ast_kind dk = raw()->get_kind(); + switch(dk){ + case AST_APP: + return to_app(raw())->get_num_args(); + case AST_QUANTIFIER: + return 1; + case AST_VAR: + return 0; + default:; + } + SASSERT(0); + return 0; + } + expr arg(unsigned i) const { + ast_kind dk = raw()->get_kind(); + switch(dk){ + case AST_APP: + return ctx().cook(to_app(raw())->get_arg(i)); + case AST_QUANTIFIER: + return ctx().cook(to_quantifier(raw())->get_expr()); + default:; + } + assert(0); + return expr(); + } + + expr body() const { + return ctx().cook(to_quantifier(raw())->get_expr()); + } + + friend expr operator!(expr const & a) { + // ::expr *e = a; + return expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_NOT,a)); + } + + friend expr operator&&(expr const & a, expr const & b) { + return expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_AND,a,b)); + } + + friend expr operator||(expr const & a, expr const & b) { + return expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_OR,a,b)); + } + + friend expr implies(expr const & a, expr const & b) { + return expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_IMPLIES,a,b)); + } + + friend expr operator==(expr const & a, expr const & b) { + return expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_EQ,a,b)); + } + + friend expr operator!=(expr const & a, expr const & b) { + return expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_DISTINCT,a,b)); + } + + friend expr operator+(expr const & a, expr const & b) { + return a.ctx().make(Plus,a,b); // expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_ADD,a,b)); + } + + friend expr operator*(expr const & a, expr const & b) { + return a.ctx().make(Times,a,b); // expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_MUL,a,b)); + } + + friend expr operator/(expr const & a, expr const & b) { + return a.ctx().make(Div,a,b); // expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_DIV,a,b)); + } + + friend expr operator-(expr const & a) { + return a.ctx().make(Uminus,a); // expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_UMINUS,a)); + } + + friend expr operator-(expr const & a, expr const & b) { + return a.ctx().make(Sub,a,b); // expr(a.ctx(),a.m().mk_app(a.ctx().m_arith_fid,OP_SUB,a,b)); + } + + friend expr operator<=(expr const & a, expr const & b) { + return a.ctx().make(Leq,a,b); // expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_LE,a,b)); + } + + friend expr operator>=(expr const & a, expr const & b) { + return a.ctx().make(Geq,a,b); //expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_GE,a,b)); + } + + friend expr operator<(expr const & a, expr const & b) { + return a.ctx().make(Lt,a,b); expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_LT,a,b)); + } + + friend expr operator>(expr const & a, expr const & b) { + return a.ctx().make(Gt,a,b); expr(a.ctx(),a.m().mk_app(a.m().get_basic_family_id(),OP_GT,a,b)); + } + + expr simplify() const; + + expr simplify(params const & p) const; + + expr qe_lite() const; + + friend expr clone_quantifier(const expr &, const expr &); + + friend expr clone_quantifier(const expr &q, const expr &b, const std::vector &patterns); + + friend std::ostream & operator<<(std::ostream & out, expr const & m){ + m.ctx().print_expr(out,m); + return out; + } + + void get_patterns(std::vector &pats) const ; + + unsigned get_quantifier_num_bound() const { + return to_quantifier(raw())->get_num_decls(); + } + + unsigned get_index_value() const { + var* va = to_var(raw()); + return va->get_idx(); + } + + bool is_quantifier_forall() const { + return to_quantifier(raw())->is_forall(); + } + + sort get_quantifier_bound_sort(unsigned n) const { + return sort(ctx(),to_quantifier(raw())->get_decl_sort(n)); + } + + symbol get_quantifier_bound_name(unsigned n) const { + return symbol(ctx(),to_quantifier(raw())->get_decl_names()[n]); + } + + friend expr forall(const std::vector &quants, const expr &body); + + friend expr exists(const std::vector &quants, const expr &body); + + }; + + + typedef ::decl_kind pfrule; + + class proof : public ast { + public: + proof(context & c):ast(c) {} + proof(context & c, ::proof *s):ast(c, s) {} + proof(proof const & s):ast(s) {} + operator ::proof*() const { return to_app(raw()); } + proof & operator=(proof const & s) { return static_cast(ast::operator=(s)); } + + pfrule rule() const { + ::func_decl *d = to_app(raw())->get_decl(); + return d->get_decl_kind(); + } + + unsigned num_prems() const { + return to_app(raw())->get_num_args() - 1; + } + + expr conc() const { + return ctx().cook(to_app(raw())->get_arg(num_prems())); + } + + proof prem(unsigned i) const { + return proof(ctx(),to_app(to_app(raw())->get_arg(i))); + } + + void get_assumptions(std::vector &assumps); + }; + +#if 0 + +#if Z3_MAJOR_VERSION > 4 || Z3_MAJOR_VERSION == 4 && Z3_MINOR_VERSION >= 3 + template + class ast_vector_tpl : public object { + Z3_ast_vector m_vector; + void init(Z3_ast_vector v) { Z3_ast_vector_inc_ref(ctx(), v); m_vector = v; } + public: + ast_vector_tpl(context & c):object(c) { init(Z3_mk_ast_vector(c)); } + ast_vector_tpl(context & c, Z3_ast_vector v):object(c) { init(v); } + ast_vector_tpl(ast_vector_tpl const & s):object(s), m_vector(s.m_vector) { Z3_ast_vector_inc_ref(ctx(), m_vector); } + ~ast_vector_tpl() { /* Z3_ast_vector_dec_ref(ctx(), m_vector); */ } + operator Z3_ast_vector() const { return m_vector; } + unsigned size() const { return Z3_ast_vector_size(ctx(), m_vector); } + T operator[](unsigned i) const { Z3_ast r = Z3_ast_vector_get(ctx(), m_vector, i); check_error(); return cast_ast()(ctx(), r); } + void push_back(T const & e) { Z3_ast_vector_push(ctx(), m_vector, e); check_error(); } + void resize(unsigned sz) { Z3_ast_vector_resize(ctx(), m_vector, sz); check_error(); } + T back() const { return operator[](size() - 1); } + void pop_back() { assert(size() > 0); resize(size() - 1); } + bool empty() const { return size() == 0; } + ast_vector_tpl & operator=(ast_vector_tpl const & s) { + Z3_ast_vector_inc_ref(s.ctx(), s.m_vector); + // Z3_ast_vector_dec_ref(ctx(), m_vector); + m_ctx = s.m_ctx; + m_vector = s.m_vector; + return *this; + } + friend std::ostream & operator<<(std::ostream & out, ast_vector_tpl const & v) { out << Z3_ast_vector_to_string(v.ctx(), v); return out; } + }; + + typedef ast_vector_tpl ast_vector; + typedef ast_vector_tpl expr_vector; + typedef ast_vector_tpl sort_vector; + typedef ast_vector_tpl func_decl_vector; + +#endif + +#endif + + class func_interp : public object { + ::func_interp * m_interp; + void init(::func_interp * e) { + m_interp = e; + } + public: + func_interp(context & c, ::func_interp * e):object(c) { init(e); } + func_interp(func_interp const & s):object(s) { init(s.m_interp); } + ~func_interp() { } + operator ::func_interp *() const { return m_interp; } + func_interp & operator=(func_interp const & s) { + m_ctx = s.m_ctx; + m_interp = s.m_interp; + return *this; + } + unsigned num_entries() const { return m_interp->num_entries(); } + expr get_arg(unsigned ent, unsigned arg) const { + return expr(ctx(),m_interp->get_entry(ent)->get_arg(arg)); + } + expr get_value(unsigned ent) const { + return expr(ctx(),m_interp->get_entry(ent)->get_result()); + } + expr else_value() const { + return expr(ctx(),m_interp->get_else()); + } + }; + + + + class model : public object { + model_ref m_model; + void init(::model *m) { + m_model = m; + } + public: + model(context & c, ::model * m = 0):object(c), m_model(m) { } + model(model const & s):object(s), m_model(s.m_model) { } + ~model() { } + operator ::model *() const { return m_model.get(); } + model & operator=(model const & s) { + // ::model *_inc_ref(s.ctx(), s.m_model); + // ::model *_dec_ref(ctx(), m_model); + m_ctx = s.m_ctx; + m_model = s.m_model.get(); + return *this; + } + model & operator=(::model *s) { + m_model = s; + return *this; + } + + expr eval(expr const & n, bool model_completion=true) const { + ::model * _m = m_model.get(); + expr_ref result(ctx().m()); + _m->eval(n, result, model_completion); + return expr(ctx(), result); + } + + void show() const; + void show_hash() const; + + unsigned num_consts() const {return m_model.get()->get_num_constants();} + unsigned num_funcs() const {return m_model.get()->get_num_functions();} + func_decl get_const_decl(unsigned i) const {return func_decl(ctx(),m_model.get()->get_constant(i));} + func_decl get_func_decl(unsigned i) const {return func_decl(ctx(),m_model.get()->get_function(i));} + unsigned size() const; + func_decl operator[](unsigned i) const; + + expr get_const_interp(func_decl f) const { + return ctx().cook(m_model->get_const_interp(to_func_decl(f.raw()))); + } + + func_interp get_func_interp(func_decl f) const { + return func_interp(ctx(),m_model->get_func_interp(to_func_decl(f.raw()))); + } + +#if 0 + friend std::ostream & operator<<(std::ostream & out, model const & m) { out << Z3_model_to_string(m.ctx(), m); return out; } +#endif + }; + +#if 0 + class stats : public object { + Z3_stats m_stats; + void init(Z3_stats e) { + m_stats = e; + Z3_stats_inc_ref(ctx(), m_stats); + } + public: + stats(context & c):object(c), m_stats(0) {} + stats(context & c, Z3_stats e):object(c) { init(e); } + stats(stats const & s):object(s) { init(s.m_stats); } + ~stats() { if (m_stats) Z3_stats_dec_ref(ctx(), m_stats); } + operator Z3_stats() const { return m_stats; } + stats & operator=(stats const & s) { + Z3_stats_inc_ref(s.ctx(), s.m_stats); + if (m_stats) Z3_stats_dec_ref(ctx(), m_stats); + m_ctx = s.m_ctx; + m_stats = s.m_stats; + return *this; + } + unsigned size() const { return Z3_stats_size(ctx(), m_stats); } + std::string key(unsigned i) const { Z3_string s = Z3_stats_get_key(ctx(), m_stats, i); check_error(); return s; } + bool is_uint(unsigned i) const { Z3_bool r = Z3_stats_is_uint(ctx(), m_stats, i); check_error(); return r != 0; } + bool is_double(unsigned i) const { Z3_bool r = Z3_stats_is_double(ctx(), m_stats, i); check_error(); return r != 0; } + unsigned uint_value(unsigned i) const { unsigned r = Z3_stats_get_uint_value(ctx(), m_stats, i); check_error(); return r; } + double double_value(unsigned i) const { double r = Z3_stats_get_double_value(ctx(), m_stats, i); check_error(); return r; } + friend std::ostream & operator<<(std::ostream & out, stats const & s) { out << Z3_stats_to_string(s.ctx(), s); return out; } + }; +#endif + + enum check_result { + unsat, sat, unknown + }; + + class fixedpoint : public object { + + public: + void get_rules(std::vector &rules); + void get_assertions(std::vector &rules); + void register_relation(const func_decl &rela); + void add_rule(const expr &clause, const symbol &name); + void assert_cnst(const expr &cnst); + }; + + inline std::ostream & operator<<(std::ostream & out, check_result r) { + if (r == unsat) out << "unsat"; + else if (r == sat) out << "sat"; + else out << "unknown"; + return out; + } + + inline check_result to_check_result(lbool l) { + if (l == l_true) return sat; + else if (l == l_false) return unsat; + return unknown; + } + + class solver : public object { + protected: + ::solver *m_solver; + model the_model; + bool canceled; + public: + solver(context & c, bool extensional = false); + solver(context & c, ::solver *s):object(c),the_model(c) { m_solver = s; canceled = false;} + solver(solver const & s):object(s), the_model(s.the_model) { m_solver = s.m_solver; canceled = false;} + ~solver() { + if(m_solver) + dealloc(m_solver); + } + operator ::solver*() const { return m_solver; } + solver & operator=(solver const & s) { + m_ctx = s.m_ctx; + m_solver = s.m_solver; + the_model = s.the_model; + return *this; + } + struct cancel_exception {}; + void checkpoint(){ + if(canceled) + throw(cancel_exception()); + } + // void set(params const & p) { Z3_solver_set_params(ctx(), m_solver, p); check_error(); } + void push() { m_solver->push(); } + void pop(unsigned n = 1) { m_solver->pop(n); } + // void reset() { Z3_solver_reset(ctx(), m_solver); check_error(); } + void add(expr const & e) { m_solver->assert_expr(e); } + check_result check() { + checkpoint(); + lbool r = m_solver->check_sat(0,0); + model_ref m; + m_solver->get_model(m); + the_model = m.get(); + return to_check_result(r); + } + check_result check_keep_model(unsigned n, expr * const assumptions, unsigned *core_size = 0, expr *core = 0) { + model old_model(the_model); + check_result res = check(n,assumptions,core_size,core); + if(the_model == 0) + the_model = old_model; + return res; + } + check_result check(unsigned n, expr * const assumptions, unsigned *core_size = 0, expr *core = 0) { + checkpoint(); + std::vector< ::expr *> _assumptions(n); + for (unsigned i = 0; i < n; i++) { + _assumptions[i] = to_expr(assumptions[i]); + } + the_model = 0; + lbool r = m_solver->check_sat(n,&_assumptions[0]); + + if(core_size && core){ + ptr_vector< ::expr> _core; + m_solver->get_unsat_core(_core); + *core_size = _core.size(); + for(unsigned i = 0; i < *core_size; i++) + core[i] = expr(ctx(),_core[i]); + } + + model_ref m; + m_solver->get_model(m); + the_model = m.get(); + + return to_check_result(r); + } +#if 0 + check_result check(expr_vector assumptions) { + unsigned n = assumptions.size(); + z3array _assumptions(n); + for (unsigned i = 0; i < n; i++) { + check_context(*this, assumptions[i]); + _assumptions[i] = assumptions[i]; + } + Z3_lbool r = Z3_check_assumptions(ctx(), m_solver, n, _assumptions.ptr()); + check_error(); + return to_check_result(r); + } +#endif + model get_model() const { return model(ctx(), the_model); } + // std::string reason_unknown() const { Z3_string r = Z3_solver_get_reason_unknown(ctx(), m_solver); check_error(); return r; } + // stats statistics() const { Z3_stats r = Z3_solver_get_statistics(ctx(), m_solver); check_error(); return stats(ctx(), r); } +#if 0 + expr_vector unsat_core() const { Z3_ast_vector r = Z3_solver_get_unsat_core(ctx(), m_solver); check_error(); return expr_vector(ctx(), r); } + expr_vector assertions() const { Z3_ast_vector r = Z3_solver_get_assertions(ctx(), m_solver); check_error(); return expr_vector(ctx(), r); } +#endif + // expr proof() const { Z3_ast r = Z3_solver_proof(ctx(), m_solver); check_error(); return expr(ctx(), r); } + // friend std::ostream & operator<<(std::ostream & out, solver const & s) { out << Z3_solver_to_string(s.ctx(), s); return out; } + + int get_num_decisions(); + + void cancel(){ + canceled = true; + if(m_solver) + m_solver->cancel(); + } + + unsigned get_scope_level(){return m_solver->get_scope_level();} + + void show(); + void show_assertion_ids(); + + proof get_proof(){ + return proof(ctx(),m_solver->get_proof()); + } + + }; + +#if 0 + class goal : public object { + Z3_goal m_goal; + void init(Z3_goal s) { + m_goal = s; + Z3_goal_inc_ref(ctx(), s); + } + public: + goal(context & c, bool models=true, bool unsat_cores=false, bool proofs=false):object(c) { init(Z3_mk_goal(c, models, unsat_cores, proofs)); } + goal(context & c, Z3_goal s):object(c) { init(s); } + goal(goal const & s):object(s) { init(s.m_goal); } + ~goal() { Z3_goal_dec_ref(ctx(), m_goal); } + operator Z3_goal() const { return m_goal; } + goal & operator=(goal const & s) { + Z3_goal_inc_ref(s.ctx(), s.m_goal); + Z3_goal_dec_ref(ctx(), m_goal); + m_ctx = s.m_ctx; + m_goal = s.m_goal; + return *this; + } + void add(expr const & f) { check_context(*this, f); Z3_goal_assert(ctx(), m_goal, f); check_error(); } + unsigned size() const { return Z3_goal_size(ctx(), m_goal); } + expr operator[](unsigned i) const { Z3_ast r = Z3_goal_formula(ctx(), m_goal, i); check_error(); return expr(ctx(), r); } + Z3_goal_prec precision() const { return Z3_goal_precision(ctx(), m_goal); } + bool inconsistent() const { return Z3_goal_inconsistent(ctx(), m_goal) != 0; } + unsigned depth() const { return Z3_goal_depth(ctx(), m_goal); } + void reset() { Z3_goal_reset(ctx(), m_goal); } + unsigned num_exprs() const { Z3_goal_num_exprs(ctx(), m_goal); } + bool is_decided_sat() const { return Z3_goal_is_decided_sat(ctx(), m_goal) != 0; } + bool is_decided_unsat() const { return Z3_goal_is_decided_unsat(ctx(), m_goal) != 0; } + friend std::ostream & operator<<(std::ostream & out, goal const & g) { out << Z3_goal_to_string(g.ctx(), g); return out; } + }; + + class apply_result : public object { + Z3_apply_result m_apply_result; + void init(Z3_apply_result s) { + m_apply_result = s; + Z3_apply_result_inc_ref(ctx(), s); + } + public: + apply_result(context & c, Z3_apply_result s):object(c) { init(s); } + apply_result(apply_result const & s):object(s) { init(s.m_apply_result); } + ~apply_result() { Z3_apply_result_dec_ref(ctx(), m_apply_result); } + operator Z3_apply_result() const { return m_apply_result; } + apply_result & operator=(apply_result const & s) { + Z3_apply_result_inc_ref(s.ctx(), s.m_apply_result); + Z3_apply_result_dec_ref(ctx(), m_apply_result); + m_ctx = s.m_ctx; + m_apply_result = s.m_apply_result; + return *this; + } + unsigned size() const { return Z3_apply_result_get_num_subgoals(ctx(), m_apply_result); } + goal operator[](unsigned i) const { Z3_goal r = Z3_apply_result_get_subgoal(ctx(), m_apply_result, i); check_error(); return goal(ctx(), r); } + goal operator[](int i) const { assert(i >= 0); return this->operator[](static_cast(i)); } + model convert_model(model const & m, unsigned i = 0) const { + check_context(*this, m); + Z3_model new_m = Z3_apply_result_convert_model(ctx(), m_apply_result, i, m); + check_error(); + return model(ctx(), new_m); + } + friend std::ostream & operator<<(std::ostream & out, apply_result const & r) { out << Z3_apply_result_to_string(r.ctx(), r); return out; } + }; + + class tactic : public object { + Z3_tactic m_tactic; + void init(Z3_tactic s) { + m_tactic = s; + Z3_tactic_inc_ref(ctx(), s); + } + public: + tactic(context & c, char const * name):object(c) { Z3_tactic r = Z3_mk_tactic(c, name); check_error(); init(r); } + tactic(context & c, Z3_tactic s):object(c) { init(s); } + tactic(tactic const & s):object(s) { init(s.m_tactic); } + ~tactic() { Z3_tactic_dec_ref(ctx(), m_tactic); } + operator Z3_tactic() const { return m_tactic; } + tactic & operator=(tactic const & s) { + Z3_tactic_inc_ref(s.ctx(), s.m_tactic); + Z3_tactic_dec_ref(ctx(), m_tactic); + m_ctx = s.m_ctx; + m_tactic = s.m_tactic; + return *this; + } + solver mk_solver() const { Z3_solver r = Z3_mk_solver_from_tactic(ctx(), m_tactic); check_error(); return solver(ctx(), r); } + apply_result apply(goal const & g) const { + check_context(*this, g); + Z3_apply_result r = Z3_tactic_apply(ctx(), m_tactic, g); + check_error(); + return apply_result(ctx(), r); + } + apply_result operator()(goal const & g) const { + return apply(g); + } + std::string help() const { char const * r = Z3_tactic_get_help(ctx(), m_tactic); check_error(); return r; } + friend tactic operator&(tactic const & t1, tactic const & t2) { + check_context(t1, t2); + Z3_tactic r = Z3_tactic_and_then(t1.ctx(), t1, t2); + t1.check_error(); + return tactic(t1.ctx(), r); + } + friend tactic operator|(tactic const & t1, tactic const & t2) { + check_context(t1, t2); + Z3_tactic r = Z3_tactic_or_else(t1.ctx(), t1, t2); + t1.check_error(); + return tactic(t1.ctx(), r); + } + friend tactic repeat(tactic const & t, unsigned max=UINT_MAX) { + Z3_tactic r = Z3_tactic_repeat(t.ctx(), t, max); + t.check_error(); + return tactic(t.ctx(), r); + } + friend tactic with(tactic const & t, params const & p) { + Z3_tactic r = Z3_tactic_using_params(t.ctx(), t, p); + t.check_error(); + return tactic(t.ctx(), r); + } + friend tactic try_for(tactic const & t, unsigned ms) { + Z3_tactic r = Z3_tactic_try_for(t.ctx(), t, ms); + t.check_error(); + return tactic(t.ctx(), r); + } + }; + + class probe : public object { + Z3_probe m_probe; + void init(Z3_probe s) { + m_probe = s; + Z3_probe_inc_ref(ctx(), s); + } + public: + probe(context & c, char const * name):object(c) { Z3_probe r = Z3_mk_probe(c, name); check_error(); init(r); } + probe(context & c, double val):object(c) { Z3_probe r = Z3_probe_const(c, val); check_error(); init(r); } + probe(context & c, Z3_probe s):object(c) { init(s); } + probe(probe const & s):object(s) { init(s.m_probe); } + ~probe() { Z3_probe_dec_ref(ctx(), m_probe); } + operator Z3_probe() const { return m_probe; } + probe & operator=(probe const & s) { + Z3_probe_inc_ref(s.ctx(), s.m_probe); + Z3_probe_dec_ref(ctx(), m_probe); + m_ctx = s.m_ctx; + m_probe = s.m_probe; + return *this; + } + double apply(goal const & g) const { double r = Z3_probe_apply(ctx(), m_probe, g); check_error(); return r; } + double operator()(goal const & g) const { return apply(g); } + friend probe operator<=(probe const & p1, probe const & p2) { + check_context(p1, p2); Z3_probe r = Z3_probe_le(p1.ctx(), p1, p2); p1.check_error(); return probe(p1.ctx(), r); + } + friend probe operator<=(probe const & p1, double p2) { return p1 <= probe(p1.ctx(), p2); } + friend probe operator<=(double p1, probe const & p2) { return probe(p2.ctx(), p1) <= p2; } + friend probe operator>=(probe const & p1, probe const & p2) { + check_context(p1, p2); Z3_probe r = Z3_probe_ge(p1.ctx(), p1, p2); p1.check_error(); return probe(p1.ctx(), r); + } + friend probe operator>=(probe const & p1, double p2) { return p1 >= probe(p1.ctx(), p2); } + friend probe operator>=(double p1, probe const & p2) { return probe(p2.ctx(), p1) >= p2; } + friend probe operator<(probe const & p1, probe const & p2) { + check_context(p1, p2); Z3_probe r = Z3_probe_lt(p1.ctx(), p1, p2); p1.check_error(); return probe(p1.ctx(), r); + } + friend probe operator<(probe const & p1, double p2) { return p1 < probe(p1.ctx(), p2); } + friend probe operator<(double p1, probe const & p2) { return probe(p2.ctx(), p1) < p2; } + friend probe operator>(probe const & p1, probe const & p2) { + check_context(p1, p2); Z3_probe r = Z3_probe_gt(p1.ctx(), p1, p2); p1.check_error(); return probe(p1.ctx(), r); + } + friend probe operator>(probe const & p1, double p2) { return p1 > probe(p1.ctx(), p2); } + friend probe operator>(double p1, probe const & p2) { return probe(p2.ctx(), p1) > p2; } + friend probe operator==(probe const & p1, probe const & p2) { + check_context(p1, p2); Z3_probe r = Z3_probe_eq(p1.ctx(), p1, p2); p1.check_error(); return probe(p1.ctx(), r); + } + friend probe operator==(probe const & p1, double p2) { return p1 == probe(p1.ctx(), p2); } + friend probe operator==(double p1, probe const & p2) { return probe(p2.ctx(), p1) == p2; } + friend probe operator&&(probe const & p1, probe const & p2) { + check_context(p1, p2); Z3_probe r = Z3_probe_and(p1.ctx(), p1, p2); p1.check_error(); return probe(p1.ctx(), r); + } + friend probe operator||(probe const & p1, probe const & p2) { + check_context(p1, p2); Z3_probe r = Z3_probe_or(p1.ctx(), p1, p2); p1.check_error(); return probe(p1.ctx(), r); + } + friend probe operator!(probe const & p) { + Z3_probe r = Z3_probe_not(p.ctx(), p); p.check_error(); return probe(p.ctx(), r); + } + }; + + inline tactic fail_if(probe const & p) { + Z3_tactic r = Z3_tactic_fail_if(p.ctx(), p); + p.check_error(); + return tactic(p.ctx(), r); + } + inline tactic when(probe const & p, tactic const & t) { + check_context(p, t); + Z3_tactic r = Z3_tactic_when(t.ctx(), p, t); + t.check_error(); + return tactic(t.ctx(), r); + } + inline tactic cond(probe const & p, tactic const & t1, tactic const & t2) { + check_context(p, t1); check_context(p, t2); + Z3_tactic r = Z3_tactic_cond(t1.ctx(), p, t1, t2); + t1.check_error(); + return tactic(t1.ctx(), r); + } + +#endif + + inline expr context::bool_val(bool b){return b ? make(True) : make(False);} + + inline symbol context::str_symbol(char const * s) { ::symbol r = ::symbol(s); return symbol(*this, r); } + inline symbol context::int_symbol(int n) { ::symbol r = ::symbol(n); return symbol(*this, r); } + + inline sort context::bool_sort() { + ::sort *s = m().mk_sort(m_basic_fid, BOOL_SORT); + return sort(*this, s); + } + inline sort context::int_sort() { + ::sort *s = m().mk_sort(m_arith_fid, INT_SORT); + return sort(*this, s); + } + inline sort context::real_sort() { + ::sort *s = m().mk_sort(m_arith_fid, REAL_SORT); + return sort(*this, s); + } + inline sort context::array_sort(sort d, sort r) { + parameter params[2] = { parameter(d), parameter(to_sort(r)) }; + ::sort * s = m().mk_sort(m_array_fid, ARRAY_SORT, 2, params); + return sort(*this, s); + } + + + inline func_decl context::function(symbol const & name, unsigned arity, sort const * domain, sort const & range) { + std::vector< ::sort *> sv(arity); + for(unsigned i = 0; i < arity; i++) + sv[i] = domain[i]; + ::func_decl* d = m().mk_func_decl(name,arity,&sv[0],range); + return func_decl(*this,d); + } + + inline func_decl context::function(char const * name, unsigned arity, sort const * domain, sort const & range) { + return function(str_symbol(name), arity, domain, range); + } + + inline func_decl context::function(char const * name, sort const & domain, sort const & range) { + sort args[1] = { domain }; + return function(name, 1, args, range); + } + + inline func_decl context::function(char const * name, sort const & d1, sort const & d2, sort const & range) { + sort args[2] = { d1, d2 }; + return function(name, 2, args, range); + } + + inline func_decl context::function(char const * name, sort const & d1, sort const & d2, sort const & d3, sort const & range) { + sort args[3] = { d1, d2, d3 }; + return function(name, 3, args, range); + } + + inline func_decl context::function(char const * name, sort const & d1, sort const & d2, sort const & d3, sort const & d4, sort const & range) { + sort args[4] = { d1, d2, d3, d4 }; + return function(name, 4, args, range); + } + + inline func_decl context::function(char const * name, sort const & d1, sort const & d2, sort const & d3, sort const & d4, sort const & d5, sort const & range) { + sort args[5] = { d1, d2, d3, d4, d5 }; + return function(name, 5, args, range); + } + + + inline expr context::constant(symbol const & name, sort const & s) { + ::expr *r = m().mk_const(m().mk_const_decl(name, s)); + return expr(*this, r); + } + inline expr context::constant(char const * name, sort const & s) { return constant(str_symbol(name), s); } + inline expr context::bool_const(char const * name) { return constant(name, bool_sort()); } + inline expr context::int_const(char const * name) { return constant(name, int_sort()); } + inline expr context::real_const(char const * name) { return constant(name, real_sort()); } + inline expr context::bv_const(char const * name, unsigned sz) { return constant(name, bv_sort(sz)); } + + inline expr func_decl::operator()(const std::vector &args) const { + return operator()(static_cast(args.size()),&args[0]); + } + inline expr func_decl::operator()() const { + return operator()(0,0); + } + inline expr func_decl::operator()(expr const & a) const { + return operator()(1,&a); + } + inline expr func_decl::operator()(expr const & a1, expr const & a2) const { + expr args[2] = {a1,a2}; + return operator()(2,args); + } + inline expr func_decl::operator()(expr const & a1, expr const & a2, expr const & a3) const { + expr args[3] = {a1,a2,a3}; + return operator()(3,args); + } + inline expr func_decl::operator()(expr const & a1, expr const & a2, expr const & a3, expr const & a4) const { + expr args[4] = {a1,a2,a3,a4}; + return operator()(4,args); + } + inline expr func_decl::operator()(expr const & a1, expr const & a2, expr const & a3, expr const & a4, expr const & a5) const { + expr args[5] = {a1,a2,a3,a4,a5}; + return operator()(5,args); + } + + + inline expr select(expr const & a, expr const & i) { return a.ctx().make(Select,a,i); } + inline expr store(expr const & a, expr const & i, expr const & v) { return a.ctx().make(Store,a,i,v); } + + inline expr forall(const std::vector &quants, const expr &body){ + return body.ctx().make_quant(Forall,quants,body); + } + + inline expr exists(const std::vector &quants, const expr &body){ + return body.ctx().make_quant(Exists,quants,body); + } + + inline expr context::int_val(int n){ + :: sort *r = m().mk_sort(m_arith_fid, INT_SORT); + return cook(m_arith_util.mk_numeral(rational(n),r)); + } + + + class literals : public object { + }; + + class TermTree { + public: + + TermTree(expr _term){ + term = _term; + } + + TermTree(expr _term, const std::vector &_children){ + term = _term; + children = _children; + } + + inline expr getTerm(){return term;} + + inline std::vector &getTerms(){return terms;} + + inline std::vector &getChildren(){ + return children; + } + + inline int number(int from){ + for(unsigned i = 0; i < children.size(); i++) + from = children[i]->number(from); + num = from; + return from + 1; + } + + inline int getNumber(){ + return num; + } + + inline void setTerm(expr t){term = t;} + + inline void addTerm(expr t){terms.push_back(t);} + + inline void setChildren(const std::vector & _children){ + children = _children; + } + + inline void setNumber(int _num){ + num = _num; + } + + ~TermTree(){ + for(unsigned i = 0; i < children.size(); i++) + delete children[i]; + } + + private: + expr term; + std::vector terms; + std::vector children; + int num; + }; + + typedef context interpolating_context; + + class interpolating_solver : public solver { + public: + interpolating_solver(context &ctx) + : solver(ctx) + { + weak_mode = false; + } + + public: + lbool interpolate(const std::vector &assumptions, + std::vector &interpolants, + model &_model, + literals &lits, + bool incremental + ); + + lbool interpolate_tree(TermTree *assumptions, + TermTree *&interpolants, + model &_model, + literals &lits, + bool incremental + ); + + bool read_interpolation_problem(const std::string &file_name, + std::vector &assumptions, + std::vector &theory, + std::string &error_message + ); + + void write_interpolation_problem(const std::string &file_name, + const std::vector &assumptions, + const std::vector &theory + ); + + void AssertInterpolationAxiom(const expr &expr); + void RemoveInterpolationAxiom(const expr &expr); + + void SetWeakInterpolants(bool weak); + void SetPrintToFile(const std::string &file_name); + + const std::vector &GetInterpolationAxioms() {return theory;} + const char *profile(); + + private: + bool weak_mode; + std::string print_filename; + std::vector theory; + }; + + + inline expr context::cook(::expr *a) {return expr(*this,a);} + + inline std::vector context::cook(ptr_vector< ::expr> v) { + std::vector _v(v.size()); + for(unsigned i = 0; i < v.size(); i++) + _v[i] = cook(v[i]); + return _v; + } + + inline ::expr *context::uncook(const expr &a) { + m().inc_ref(a.raw()); + return to_expr(a.raw()); + } + + typedef double clock_t; + clock_t current_time(); + inline void output_time(std::ostream &os, clock_t time){os << time;} +}; + +// to make Duality::ast hashable +namespace hash_space { + template <> + class hash { + public: + size_t operator()(const Duality::ast &s) const { + return s.raw()->get_id(); + } + }; +} + +// to make Duality::ast hashable in windows +#ifdef WIN32 +template <> inline +size_t stdext::hash_value(const Duality::ast& s) +{ + return s.raw()->get_id(); +} +#endif + +// to make Duality::ast usable in ordered collections +namespace std { + template <> + class less { + public: + bool operator()(const Duality::ast &s, const Duality::ast &t) const { + // return s.raw() < t.raw(); + return s.raw()->get_id() < t.raw()->get_id(); + } + }; +} + +// to make Duality::func_decl hashable +namespace hash_space { + template <> + class hash { + public: + size_t operator()(const Duality::func_decl &s) const { + return s.raw()->get_id(); + } + }; +} + +// to make Duality::func_decl hashable in windows +#ifdef WIN32 +template <> inline +size_t stdext::hash_value(const Duality::func_decl& s) +{ + return s.raw()->get_id(); +} +#endif + +// to make Duality::func_decl usable in ordered collections +namespace std { + template <> + class less { + public: + bool operator()(const Duality::func_decl &s, const Duality::func_decl &t) const { + // return s.raw() < t.raw(); + return s.raw()->get_id() < t.raw()->get_id(); + } + }; +} + + +#endif + diff --git a/src/interp/foci2.h b/src/interp/foci2.h new file mode 100755 index 000000000..261dd05dc --- /dev/null +++ b/src/interp/foci2.h @@ -0,0 +1,75 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + foci2.h + +Abstract: + + An interface class for foci2. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef FOCI2_H +#define FOCI2_H + +#include +#include + +#ifdef WIN32 +#define FOCI2_EXPORT __declspec(dllexport) +#else +#define FOCI2_EXPORT __attribute__ ((visibility ("default"))) +#endif + +class foci2 { + public: + virtual ~foci2(){} + + typedef int ast; + typedef int symb; + + /** Built-in operators */ + enum ops { + And = 0, Or, Not, Iff, Ite, Equal, Plus, Times, Floor, Leq, Div, Bool, Int, Array, Tsym, Fsym, Forall, Exists, Distinct, LastOp + }; + + virtual symb mk_func(const std::string &s) = 0; + virtual symb mk_pred(const std::string &s) = 0; + virtual ast mk_op(ops op, const std::vector args) = 0; + virtual ast mk_op(ops op, ast) = 0; + virtual ast mk_op(ops op, ast, ast) = 0; + virtual ast mk_op(ops op, ast, ast, ast) = 0; + virtual ast mk_int(const std::string &) = 0; + virtual ast mk_rat(const std::string &) = 0; + virtual ast mk_true() = 0; + virtual ast mk_false() = 0; + virtual ast mk_app(symb,const std::vector args) = 0; + + virtual bool get_func(ast, symb &) = 0; + virtual bool get_pred(ast, symb &) = 0; + virtual bool get_op(ast, ops &) = 0; + virtual bool get_true(ast id) = 0; + virtual bool get_false(ast id) = 0; + virtual bool get_int(ast id, std::string &res) = 0; + virtual bool get_rat(ast id, std::string &res) = 0; + virtual const std::string &get_symb(symb) = 0; + + virtual int get_num_args(ast) = 0; + virtual ast get_arg(ast, int) = 0; + + virtual void show_ast(ast) = 0; + + virtual bool interpolate(const std::vector &frames, std::vector &itps, std::vector parents) = 0; + + FOCI2_EXPORT static foci2 *create(const std::string &); +}; + +#endif diff --git a/src/interp/foci2stub/foci2.cpp b/src/interp/foci2stub/foci2.cpp new file mode 100644 index 000000000..31908855b --- /dev/null +++ b/src/interp/foci2stub/foci2.cpp @@ -0,0 +1,25 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + foci2.cpp + +Abstract: + + Fake foci2, to be replaced + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + + +#include "foci2.h" + +FOCI2_EXPORT foci2 *foci2::create(const std::string &){ + return 0; +} diff --git a/src/interp/foci2stub/foci2.h b/src/interp/foci2stub/foci2.h new file mode 100755 index 000000000..261dd05dc --- /dev/null +++ b/src/interp/foci2stub/foci2.h @@ -0,0 +1,75 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + foci2.h + +Abstract: + + An interface class for foci2. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef FOCI2_H +#define FOCI2_H + +#include +#include + +#ifdef WIN32 +#define FOCI2_EXPORT __declspec(dllexport) +#else +#define FOCI2_EXPORT __attribute__ ((visibility ("default"))) +#endif + +class foci2 { + public: + virtual ~foci2(){} + + typedef int ast; + typedef int symb; + + /** Built-in operators */ + enum ops { + And = 0, Or, Not, Iff, Ite, Equal, Plus, Times, Floor, Leq, Div, Bool, Int, Array, Tsym, Fsym, Forall, Exists, Distinct, LastOp + }; + + virtual symb mk_func(const std::string &s) = 0; + virtual symb mk_pred(const std::string &s) = 0; + virtual ast mk_op(ops op, const std::vector args) = 0; + virtual ast mk_op(ops op, ast) = 0; + virtual ast mk_op(ops op, ast, ast) = 0; + virtual ast mk_op(ops op, ast, ast, ast) = 0; + virtual ast mk_int(const std::string &) = 0; + virtual ast mk_rat(const std::string &) = 0; + virtual ast mk_true() = 0; + virtual ast mk_false() = 0; + virtual ast mk_app(symb,const std::vector args) = 0; + + virtual bool get_func(ast, symb &) = 0; + virtual bool get_pred(ast, symb &) = 0; + virtual bool get_op(ast, ops &) = 0; + virtual bool get_true(ast id) = 0; + virtual bool get_false(ast id) = 0; + virtual bool get_int(ast id, std::string &res) = 0; + virtual bool get_rat(ast id, std::string &res) = 0; + virtual const std::string &get_symb(symb) = 0; + + virtual int get_num_args(ast) = 0; + virtual ast get_arg(ast, int) = 0; + + virtual void show_ast(ast) = 0; + + virtual bool interpolate(const std::vector &frames, std::vector &itps, std::vector parents) = 0; + + FOCI2_EXPORT static foci2 *create(const std::string &); +}; + +#endif diff --git a/src/interp/iz3base.cpp b/src/interp/iz3base.cpp new file mode 100755 index 000000000..f316c22cf --- /dev/null +++ b/src/interp/iz3base.cpp @@ -0,0 +1,352 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3base.cpp + +Abstract: + + Base class for interpolators. Includes an AST manager and a scoping + object as bases. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + + +#include "iz3base.h" +#include +#include +#include +#include +#include "solver.h" +#include "../smt/smt_solver.h" + + +#ifndef WIN32 +using namespace stl_ext; +#endif + + +iz3base::range &iz3base::ast_range(ast t){ + return ast_ranges_hash[t].rng; +} + +iz3base::range &iz3base::sym_range(symb d){ + return sym_range_hash[d]; +} + +void iz3base::add_frame_range(int frame, ast t){ + range &rng = ast_range(t); + if(!in_range(frame,rng)){ + range_add(frame,rng); + for(int i = 0, n = num_args(t); i < n; ++i) + add_frame_range(frame,arg(t,i)); + if(op(t) == Uninterpreted) + range_add(frame,sym_range(sym(t))); + } +} + +#if 1 +iz3base::range &iz3base::ast_scope(ast t){ + ranges &rngs = ast_ranges_hash[t]; + range &rng = rngs.scp; + if(!rngs.scope_computed){ // not computed yet + rng = range_full(); + for(int i = 0, n = num_args(t); i < n; ++i) + rng = range_glb(rng,ast_scope(arg(t,i))); + if(op(t) == Uninterpreted) + if(parents.empty() || num_args(t) == 0) // in tree mode, all function syms are global + rng = range_glb(rng,sym_range(sym(t))); + rngs.scope_computed = true; + } + return rng; +} +#else +iz3base::range &iz3base::ast_scope(ast t){ + ranges &rngs = ast_ranges_hash[t]; + if(rngs.scope_computed) return rngs.scp; + range rng = range_full(); + for(int i = 0, n = num_args(t); i < n; ++i) + rng = range_glb(rng,ast_scope(arg(t,i))); + if(op(t) == Uninterpreted) + if(parents.empty() || num_args(t) == 0) // in tree mode, all function syms are global + rng = range_glb(rng,sym_range(sym(t))); + rngs = ast_ranges_hash[t]; + rngs.scope_computed = true; + rngs.scp = rng; + return rngs.scp; +} +#endif + +void iz3base::print(const std::string &filename){ + ast t = make(And,cnsts); + std::ofstream f(filename.c_str()); + print_sat_problem(f,t); + f.close(); +} + + +void iz3base::gather_conjuncts_rec(ast n, std::vector &conjuncts, stl_ext::hash_set &memo){ + if(memo.find(n) == memo.end()){ + memo.insert(n); + if(op(n) == And){ + int nargs = num_args(n); + for(int i = 0; i < nargs; i++) + gather_conjuncts_rec(arg(n,i),conjuncts,memo); + } + else + conjuncts.push_back(n); + } +} + +void iz3base::gather_conjuncts(ast n, std::vector &conjuncts){ + hash_set memo; + gather_conjuncts_rec(n,conjuncts,memo); +} + +bool iz3base::is_literal(ast n){ + if(is_not(n))n = arg(n,0); + if(is_true(n) || is_false(n)) return false; + if(op(n) == And) return false; + return true; +} + +iz3base::ast iz3base::simplify_and(std::vector &conjuncts){ + hash_set memo; + for(unsigned i = 0; i < conjuncts.size(); i++){ + if(is_false(conjuncts[i])) + return conjuncts[i]; + if(is_true(conjuncts[i]) || memo.find(conjuncts[i]) != memo.end()){ + std::swap(conjuncts[i],conjuncts.back()); + conjuncts.pop_back(); + } + else if(memo.find(mk_not(conjuncts[i])) != memo.end()) + return mk_false(); // contradiction! + else + memo.insert(conjuncts[i]); + } + if(conjuncts.empty())return mk_true(); + return make(And,conjuncts); +} + +iz3base::ast iz3base::simplify_with_lit_rec(ast n, ast lit, stl_ext::hash_map &memo, int depth){ + if(is_not(n))return mk_not(simplify_with_lit_rec(mk_not(n),lit,memo,depth)); + if(n == lit) return mk_true(); + ast not_lit = mk_not(lit); + if(n == not_lit) return mk_false(); + if(op(n) != And || depth <= 0) return n; + std::pair foo(n,ast()); + std::pair::iterator,bool> bar = memo.insert(foo); + ast &res = bar.first->second; + if(!bar.second) return res; + int nargs = num_args(n); + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = simplify_with_lit_rec(arg(n,i),lit,memo,depth-1); + res = simplify_and(args); + return res; +} + +iz3base::ast iz3base::simplify_with_lit(ast n, ast lit){ + hash_map memo; + return simplify_with_lit_rec(n,lit,memo,1); +} + +iz3base::ast iz3base::simplify(ast n){ + if(is_not(n)) return mk_not(simplify(mk_not(n))); + std::pair memo_obj(n,ast()); + std::pair::iterator,bool> memo = simplify_memo.insert(memo_obj); + ast &res = memo.first->second; + if(!memo.second) return res; + switch(op(n)){ + case And: { + std::vector conjuncts; + gather_conjuncts(n,conjuncts); + for(unsigned i = 0; i < conjuncts.size(); i++) + conjuncts[i] = simplify(conjuncts[i]); +#if 0 + for(unsigned i = 0; i < conjuncts.size(); i++) + if(is_literal(conjuncts[i])) + for(unsigned j = 0; j < conjuncts.size(); j++) + if(j != i) + conjuncts[j] = simplify_with_lit(conjuncts[j],conjuncts[i]); +#endif + res = simplify_and(conjuncts); + } + break; + case Equal: { + ast x = arg(n,0); + ast y = arg(n,1); + if(ast_id(x) > ast_id(y)) + std::swap(x,y); + res = make(Equal,x,y); + break; + } + default: + res = n; + } + return res; +} + +void iz3base::initialize(const std::vector &_parts, const std::vector &_parents, const std::vector &_theory){ + cnsts = _parts; + theory = _theory; + for(unsigned i = 0; i < cnsts.size(); i++) + add_frame_range(i, cnsts[i]); + for(unsigned i = 0; i < _theory.size(); i++){ + add_frame_range(SHRT_MIN, _theory[i]); + add_frame_range(SHRT_MAX, _theory[i]); + } + for(unsigned i = 0; i < cnsts.size(); i++) + frame_map[cnsts[i]] = i; + for(unsigned i = 0; i < theory.size(); i++) + frame_map[theory[i]] = INT_MAX; +} + +void iz3base::initialize(const std::vector > &_parts, const std::vector &_parents, const std::vector &_theory){ + cnsts.resize(_parts.size()); + theory = _theory; + for(unsigned i = 0; i < _parts.size(); i++) + for(unsigned j = 0; j < _parts[i].size(); j++){ + cnsts[i] = make(And,_parts[i]); + add_frame_range(i, _parts[i][j]); + frame_map[_parts[i][j]] = i; + } + for(unsigned i = 0; i < _theory.size(); i++){ + add_frame_range(SHRT_MIN, _theory[i]); + add_frame_range(SHRT_MAX, _theory[i]); + frame_map[theory[i]] = INT_MAX; + } +} + +void iz3base::check_interp(const std::vector &itps, std::vector &theory){ +#if 0 + Z3_config config = Z3_mk_config(); + Z3_context vctx = Z3_mk_context(config); + int frames = cnsts.size(); + std::vector foocnsts(cnsts); + for(unsigned i = 0; i < frames; i++) + foocnsts[i] = Z3_mk_implies(ctx,Z3_mk_true(ctx),cnsts[i]); + Z3_write_interpolation_problem(ctx,frames,&foocnsts[0],0, "temp_lemma.smt", theory.size(), &theory[0]); + int vframes,*vparents; + Z3_ast *vcnsts; + const char *verror; + bool ok = Z3_read_interpolation_problem(vctx,&vframes,&vcnsts,0,"temp_lemma.smt",&verror); + assert(ok); + std::vector vvcnsts(vframes); + std::copy(vcnsts,vcnsts+vframes,vvcnsts.begin()); + std::vector vitps(itps.size()); + for(unsigned i = 0; i < itps.size(); i++) + vitps[i] = Z3_mk_implies(ctx,Z3_mk_true(ctx),itps[i]); + Z3_write_interpolation_problem(ctx,itps.size(),&vitps[0],0,"temp_interp.smt"); + int iframes,*iparents; + Z3_ast *icnsts; + const char *ierror; + ok = Z3_read_interpolation_problem(vctx,&iframes,&icnsts,0,"temp_interp.smt",&ierror); + assert(ok); + const char *error = 0; + bool iok = Z3_check_interpolant(vctx, frames, &vvcnsts[0], parents.size() ? &parents[0] : 0, icnsts, &error); + assert(iok); +#endif +} + +bool iz3base::is_sat(const std::vector &q, ast &_proof){ + + params_ref p; + p.set_bool("proof", true); // this is currently useless + p.set_bool("model", true); + p.set_bool("unsat_core", true); + scoped_ptr sf = mk_smt_solver_factory(); + ::solver *m_solver = (*sf)(m(), p, true, true, true, ::symbol::null); + ::solver &s = *m_solver; + + for(unsigned i = 0; i < q.size(); i++) + s.assert_expr(to_expr(q[i].raw())); + lbool res = s.check_sat(0,0); + if(res == l_false){ + ::ast *proof = s.get_proof(); + _proof = cook(proof); + } + dealloc(m_solver); + return res != l_false; +} + + +void iz3base::find_children(const stl_ext::hash_set &cnsts_set, + const ast &tree, + std::vector &cnsts, + std::vector &parents, + std::vector &conjuncts, + std::vector &children, + std::vector &pos_map, + bool merge + ){ + std::vector my_children; + std::vector my_conjuncts; + if(op(tree) == Interp){ // if we've hit an interpolation position... + find_children(cnsts_set,arg(tree,0),cnsts,parents,my_conjuncts,my_children,pos_map,merge); + if(my_conjuncts.empty()) + my_conjuncts.push_back(mk_true()); // need at least one conjunct + int root = cnsts.size() + my_conjuncts.size() - 1; + for(unsigned i = 0; i < my_conjuncts.size(); i++){ + parents.push_back(root); + cnsts.push_back(my_conjuncts[i]); + } + for(unsigned i = 0; i < my_children.size(); i++) + parents[my_children[i]] = root; + children.push_back(root); + pos_map.push_back(root); + } + else { + if(op(tree) == And){ + int nargs = num_args(tree); + for(int i = 0; i < nargs; i++) + find_children(cnsts_set,arg(tree,i),cnsts,parents,my_conjuncts,my_children,pos_map,merge); + } + if(cnsts_set.find(tree) != cnsts_set.end()){ + if(merge && !my_conjuncts.empty()) + my_conjuncts.back() = mk_and(my_conjuncts.back(),tree); + else + my_conjuncts.push_back(tree); + } + for(unsigned i = 0; i < my_children.size(); i++) + children.push_back(my_children[i]); + for(unsigned i = 0; i < my_conjuncts.size(); i++) + conjuncts.push_back(my_conjuncts[i]); + } +} + +void iz3base::to_parents_vec_representation(const std::vector &_cnsts, + const ast &tree, + std::vector &cnsts, + std::vector &parents, + std::vector &theory, + std::vector &pos_map, + bool merge + ){ + std::vector my_children; + std::vector my_conjuncts; + hash_set cnsts_set; + for(unsigned i = 0; i < _cnsts.size(); i++) + cnsts_set.insert(_cnsts[i]); + ast _tree = (op(tree) != Interp) ? make(Interp,tree) : tree; + find_children(cnsts_set,_tree,cnsts,parents,my_conjuncts,my_children,pos_map,merge); + if(op(tree) != Interp) pos_map.pop_back(); + parents[parents.size()-1] = SHRT_MAX; + + // rest of the constraints are the background theory + + hash_set used_set; + for(unsigned i = 0; i < cnsts.size(); i++) + used_set.insert(cnsts[i]); + for(unsigned i = 0; i < _cnsts.size(); i++) + if(used_set.find(_cnsts[i]) == used_set.end()) + theory.push_back(_cnsts[i]); +} + diff --git a/src/interp/iz3base.h b/src/interp/iz3base.h new file mode 100755 index 000000000..6bf09bb85 --- /dev/null +++ b/src/interp/iz3base.h @@ -0,0 +1,195 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3base.h + +Abstract: + + Base class for interpolators. Includes an AST manager and a scoping + object as bases. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef IZ3BASE_H +#define IZ3BASE_H + +#include "iz3mgr.h" +#include "iz3scopes.h" + +namespace hash_space { + template <> + class hash { + public: + size_t operator()(func_decl * const &s) const { + return (size_t) s; + } + }; +} + +/* Base class for interpolators. Includes an AST manager and a scoping + object as bases. */ + +class iz3base : public iz3mgr, public scopes { + + public: + + /** Get the range in which an expression occurs. This is the + smallest subtree containing all occurrences of the + expression. */ + range &ast_range(ast); + + /** Get the scope of an expression. This is the set of tree nodes in + which all of the expression's symbols are in scope. */ + range &ast_scope(ast); + + /** Get the range of a symbol. This is the smallest subtree containing + all occurrences of the symbol. */ + range &sym_range(symb); + + /** Is an expression local (in scope in some frame)? */ + + bool is_local(ast node){ + return !range_is_empty(ast_scope(node)); + } + + /** Simplify an expression */ + + ast simplify(ast); + + /** Constructor */ + + iz3base(ast_manager &_m_manager, + const std::vector &_cnsts, + const std::vector &_parents, + const std::vector &_theory) + : iz3mgr(_m_manager), scopes(_parents) { + initialize(_cnsts,_parents,_theory); + weak = false; + } + + iz3base(const iz3mgr& other, + const std::vector &_cnsts, + const std::vector &_parents, + const std::vector &_theory) + : iz3mgr(other), scopes(_parents) { + initialize(_cnsts,_parents,_theory); + weak = false; + } + + iz3base(const iz3mgr& other, + const std::vector > &_cnsts, + const std::vector &_parents, + const std::vector &_theory) + : iz3mgr(other), scopes(_parents) { + initialize(_cnsts,_parents,_theory); + weak = false; + } + + iz3base(const iz3mgr& other) + : iz3mgr(other), scopes() { + weak = false; + } + + /* Set our options */ + void set_option(const std::string &name, const std::string &value){ + if(name == "weak" && value == "1") weak = true; + } + + /* Are we doing weak interpolants? */ + bool weak_mode(){return weak;} + + /** Print interpolation problem to an SMTLIB format file */ + void print(const std::string &filename); + + /** Check correctness of a solutino to this problem. */ + void check_interp(const std::vector &itps, std::vector &theory); + + /** For convenience -- is this formula SAT? */ + bool is_sat(const std::vector &consts, ast &_proof); + + /** Interpolator for clauses, to be implemented */ + virtual void interpolate_clause(std::vector &lits, std::vector &itps){ + throw "no interpolator"; + } + + ast get_proof_check_assump(range &rng){ + std::vector cs(theory); + cs.push_back(cnsts[rng.hi]); + return make(And,cs); + } + + int frame_of_assertion(const ast &ass){ + stl_ext::hash_map::iterator it = frame_map.find(ass); + if(it == frame_map.end()) + throw "unknown assertion"; + return it->second; + } + + + void to_parents_vec_representation(const std::vector &_cnsts, + const ast &tree, + std::vector &cnsts, + std::vector &parents, + std::vector &theory, + std::vector &pos_map, + bool merge = false + ); + + protected: + std::vector cnsts; + std::vector theory; + + private: + + struct ranges { + range rng; + range scp; + bool scope_computed; + ranges(){scope_computed = false;} + }; + + stl_ext::hash_map sym_range_hash; + stl_ext::hash_map ast_ranges_hash; + stl_ext::hash_map simplify_memo; + stl_ext::hash_map frame_map; // map assertions to frames + + int frames; // number of frames + + protected: + void add_frame_range(int frame, ast t); + + private: + void initialize(const std::vector &_parts, const std::vector &_parents, const std::vector &_theory); + + void initialize(const std::vector > &_parts, const std::vector &_parents, const std::vector &_theory); + + bool is_literal(ast n); + void gather_conjuncts_rec(ast n, std::vector &conjuncts, stl_ext::hash_set &memo); + void gather_conjuncts(ast n, std::vector &conjuncts); + ast simplify_and(std::vector &conjuncts); + ast simplify_with_lit_rec(ast n, ast lit, stl_ext::hash_map &memo, int depth); + ast simplify_with_lit(ast n, ast lit); + void find_children(const stl_ext::hash_set &cnsts_set, + const ast &tree, + std::vector &cnsts, + std::vector &parents, + std::vector &conjuncts, + std::vector &children, + std::vector &pos_map, + bool merge + ); + bool weak; + +}; + + + +#endif diff --git a/src/interp/iz3checker.cpp b/src/interp/iz3checker.cpp new file mode 100755 index 000000000..b4e55af20 --- /dev/null +++ b/src/interp/iz3checker.cpp @@ -0,0 +1,222 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3checker.cpp + +Abstract: + + check correctness of interpolant + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#include "iz3base.h" +#include "iz3checker.h" + +#include +#include +#include +#include +#include +#include +#include + + +#ifndef WIN32 +using namespace stl_ext; +#endif + +struct iz3checker : iz3base { + + /* HACK: for tree interpolants, we assume that uninterpreted functions + are global. This is because in the current state of the tree interpolation + code, symbols that appear in sibling sub-trees have to be global, and + we have no way to eliminate such function symbols. When tree interpoaltion is + fixed, we can tree function symbols the same as constant symbols. */ + + bool is_tree; + + void support(const ast &t, std::set &res, hash_set &memo){ + if(memo.find(t) != memo.end()) return; + memo.insert(t); + + int nargs = num_args(t); + for(int i = 0; i < nargs; i++) + support(arg(t,i),res,memo); + + switch(op(t)){ + case Uninterpreted: + if(nargs == 0 || !is_tree) { + std::string name = string_of_symbol(sym(t)); + res.insert(name); + } + break; + case Forall: + case Exists: + support(get_quantifier_body(t),res,memo); + break; + default:; + } + } + + bool check(solver *s, std::ostream &err, + const std::vector &cnsts, + const std::vector &parents, + const std::vector &itp, + const std::vector &theory){ + + is_tree = !parents.empty(); + int num = cnsts.size(); + std::vector > children(num); + + for(int i = 0; i < num-1; i++){ + if(parents.size()) + children[parents[i]].push_back(i); + else + children[i+1].push_back(i); + } + + for(int i = 0; i < num; i++){ + s->push(); + for(unsigned j = 0; j < theory.size(); j++) + s->assert_expr(to_expr(theory[j].raw())); + s->assert_expr(to_expr(cnsts[i].raw())); + std::vector &cs = children[i]; + for(unsigned j = 0; j < cs.size(); j++) + s->assert_expr(to_expr(itp[cs[j]].raw())); + if(i != num-1) + s->assert_expr(to_expr(mk_not(itp[i]).raw())); + lbool result = s->check_sat(0,0); + if(result != l_false){ + err << "interpolant " << i << " is incorrect"; + + s->pop(1); + for(unsigned j = 0; j < theory.size(); j++) + s->assert_expr(to_expr(theory[j].raw())); + for(unsigned j = 0; j < cnsts.size(); j++) + if(in_range(j,range_downward(i))) + s->assert_expr(to_expr(cnsts[j].raw())); + if(i != num-1) + s->assert_expr(to_expr(mk_not(itp[i]).raw())); + lbool result = s->check_sat(0,0); + if(result != l_false) + err << "interpolant " << i << " is not implied by its downeard closurn"; + + return false; + } + s->pop(1); + } + + std::vector > supports(num); + for(int i = 0; i < num; i++){ + hash_set memo; + support(cnsts[i],supports[i],memo); + } + for(int i = 0; i < num-1; i++){ + std::vector Bside(num); + for(int j = num-1; j >= 0; j--) + Bside[j] = j != i; + for(int j = num-1; j >= 0; j--) + if(!Bside[j]){ + std::vector &cs = children[i]; + for(unsigned k = 0; k < cs.size(); k++) + Bside[cs[k]] = false; + } + std::set Asup, Bsup,common,Isup,bad; + for(int j = num-1; j >= 0; j--){ + std::set &side = Bside[j] ? Bsup : Asup; + side.insert(supports[j].begin(),supports[j].end()); + } + std::set_intersection(Asup.begin(),Asup.end(),Bsup.begin(),Bsup.end(),std::inserter(common,common.begin())); + { + hash_set tmemo; + for(unsigned j = 0; j < theory.size(); j++) + support(theory[j],common,tmemo); // all theory symbols allowed in interps + } + hash_set memo; + support(itp[i],Isup,memo); + std::set_difference(Isup.begin(),Isup.end(),common.begin(),common.end(),std::inserter(bad,bad.begin())); + if(!bad.empty()){ + err << "bad symbols in interpolant " << i << ":"; + std::copy(bad.begin(),bad.end(),std::ostream_iterator(err,",")); + return false; + } + } + return true; + } + + bool check(solver *s, std::ostream &err, + const std::vector &_cnsts, + const ast &tree, + const std::vector &itp){ + + std::vector pos_map; + + // convert to the parents vector representation + + to_parents_vec_representation(_cnsts, tree, cnsts, parents, theory, pos_map); + + //use the parents vector representation to compute interpolant + return check(s,err,cnsts,parents,itp,theory); + } + + iz3checker(ast_manager &_m) + : iz3base(_m) { + } + + iz3checker(iz3mgr &_m) + : iz3base(_m) { + } + +}; + +template +std::vector to_std_vector(const ::vector &v){ + std::vector _v(v.size()); + for(unsigned i = 0; i < v.size(); i++) + _v[i] = v[i]; + return _v; +} + + +bool iz3check(ast_manager &_m_manager, + solver *s, + std::ostream &err, + const ptr_vector &cnsts, + const ::vector &parents, + const ptr_vector &interps, + const ptr_vector &theory) +{ + iz3checker chk(_m_manager); + return chk.check(s,err,chk.cook(cnsts),to_std_vector(parents),chk.cook(interps),chk.cook(theory)); +} + +bool iz3check(iz3mgr &mgr, + solver *s, + std::ostream &err, + const std::vector &cnsts, + const std::vector &parents, + const std::vector &interps, + const std::vector &theory) +{ + iz3checker chk(mgr); + return chk.check(s,err,cnsts,parents,interps,theory); +} + +bool iz3check(ast_manager &_m_manager, + solver *s, + std::ostream &err, + const ptr_vector &_cnsts, + ast *tree, + const ptr_vector &interps) +{ + iz3checker chk(_m_manager); + return chk.check(s,err,chk.cook(_cnsts),chk.cook(tree),chk.cook(interps)); +} diff --git a/src/interp/iz3checker.h b/src/interp/iz3checker.h new file mode 100644 index 000000000..d46ea0654 --- /dev/null +++ b/src/interp/iz3checker.h @@ -0,0 +1,49 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3checker.h + +Abstract: + + check correctness of an interpolant + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef IZ3_CHECKER_H +#define IZ3_CHECKER_H + +#include "iz3mgr.h" +#include "solver.h" + +bool iz3check(ast_manager &_m_manager, + solver *s, + std::ostream &err, + const ptr_vector &cnsts, + const ::vector &parents, + const ptr_vector &interps, + const ptr_vector &theory); + +bool iz3check(ast_manager &_m_manager, + solver *s, + std::ostream &err, + const ptr_vector &cnsts, + ast *tree, + const ptr_vector &interps); + +bool iz3check(iz3mgr &mgr, + solver *s, + std::ostream &err, + const std::vector &cnsts, + const std::vector &parents, + const std::vector &interps, + const ptr_vector &theory); + +#endif diff --git a/src/interp/iz3foci.cpp b/src/interp/iz3foci.cpp new file mode 100755 index 000000000..1d81a1b15 --- /dev/null +++ b/src/interp/iz3foci.cpp @@ -0,0 +1,359 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3foci.cpp + +Abstract: + + Implements a secondary solver using foci2. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#include +#include +#include + +#include "iz3hash.h" +#include "foci2.h" +#include "iz3foci.h" + +#ifndef WIN32 +using namespace stl_ext; +#endif + +class iz3foci_impl : public iz3secondary { + + int frames; + int *parents; + foci2 *foci; + foci2::symb select_op; + foci2::symb store_op; + foci2::symb mod_op; + +public: + iz3foci_impl(iz3mgr *mgr, int _frames, int *_parents) : iz3secondary(*mgr) { + frames = _frames; + parents = _parents; + foci = 0; + } + + typedef hash_map AstToNode; + AstToNode ast_to_node; // maps Z3 ast's to foci expressions + + typedef hash_map NodeToAst; + NodeToAst node_to_ast; // maps Z3 ast's to foci expressions + + // We only use this for FuncDeclToSymbol, which has no range destructor + struct symb_hash { + size_t operator()(const symb &s) const { + return (size_t) s; + } + }; + + typedef hash_map FuncDeclToSymbol; + FuncDeclToSymbol func_decl_to_symbol; // maps Z3 func decls to symbols + + typedef hash_map SymbolToFuncDecl; + SymbolToFuncDecl symbol_to_func_decl; // maps symbols to Z3 func decls + + int from_symb(symb func){ + std::string name = string_of_symbol(func); + bool is_bool = is_bool_type(get_range_type(func)); + foci2::symb f; + if(is_bool) + f = foci->mk_pred(name); + else + f = foci->mk_func(name); + symbol_to_func_decl[f] = func; + func_decl_to_symbol[func] = f; + return f; + } + + // create a symbol corresponding to a DeBruijn index of a particular type + // the type has to be encoded into the name because the same index can + // occur with different types + foci2::symb make_deBruijn_symbol(int index, type ty){ + std::ostringstream s; + // s << "#" << index << "#" << type; + return foci->mk_func(s.str()); + } + + int from_Z3_ast(ast t){ + std::pair foo(t,0); + std::pair bar = ast_to_node.insert(foo); + int &res = bar.first->second; + if(!bar.second) return res; + int nargs = num_args(t); + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = from_Z3_ast(arg(t,i)); + + switch(op(t)){ + case True: + res = foci->mk_true(); break; + case False: + res = foci->mk_false(); break; + case And: + res = foci->mk_op(foci2::And,args); break; + case Or: + res = foci->mk_op(foci2::Or,args); break; + case Not: + res = foci->mk_op(foci2::Not,args[0]); break; + case Iff: + res = foci->mk_op(foci2::Iff,args); break; + case OP_OEQ: // bit of a mystery, this one... + if(args[0] == args[1]) res = foci->mk_true(); + else res = foci->mk_op(foci2::Iff,args); + break; + case Ite: + if(is_bool_type(get_type(t))) + res = foci->mk_op(foci2::And,foci->mk_op(foci2::Or,foci->mk_op(foci2::Not,args[0]),args[1]),foci->mk_op(foci2::Or,args[0],args[2])); + else + res = foci->mk_op(foci2::Ite,args); + break; + case Equal: + res = foci->mk_op(foci2::Equal,args); break; + case Implies: + args[0] = foci->mk_op(foci2::Not,args[0]); res = foci->mk_op(foci2::Or,args); break; + case Xor: + res = foci->mk_op(foci2::Not,foci->mk_op(foci2::Iff,args)); break; + case Leq: + res = foci->mk_op(foci2::Leq,args); break; + case Geq: + std::swap(args[0],args[1]); res = foci->mk_op(foci2::Leq,args); break; + case Gt: + res = foci->mk_op(foci2::Not,foci->mk_op(foci2::Leq,args)); break; + case Lt: + std::swap(args[0],args[1]); res = foci->mk_op(foci2::Not,foci->mk_op(foci2::Leq,args)); break; + case Plus: + res = foci->mk_op(foci2::Plus,args); break; + case Sub: + args[1] = foci->mk_op(foci2::Times,foci->mk_int("-1"),args[1]); res = foci->mk_op(foci2::Plus,args); break; + case Uminus: + res = foci->mk_op(foci2::Times,foci->mk_int("-1"),args[0]); break; + case Times: + res = foci->mk_op(foci2::Times,args); break; + case Idiv: + res = foci->mk_op(foci2::Div,args); break; + case Mod: + res = foci->mk_app(mod_op,args); break; + case Select: + res = foci->mk_app(select_op,args); break; + case Store: + res = foci->mk_app(store_op,args); break; + case Distinct: + res = foci->mk_op(foci2::Distinct,args); break; + case Uninterpreted: { + symb func = sym(t); + FuncDeclToSymbol::iterator it = func_decl_to_symbol.find(func); + foci2::symb f = (it == func_decl_to_symbol.end()) ? from_symb(func) : it->second; + if(foci->get_symb(f).substr(0,3) == "lbl" && args.size()==1) // HACK to handle Z3 labels + res = args[0]; + else if(foci->get_symb(f).substr(0,3) == "lbl" && args.size()==0) // HACK to handle Z3 labels + res = foci->mk_true(); + else res = foci->mk_app(f,args); + break; + } + case Numeral: { + std::string s = string_of_numeral(t); + res = foci->mk_int(s); + break; + } + case Forall: + case Exists: { + bool is_forall = op(t) == Forall; + foci2::ops qop = is_forall ? foci2::Forall : foci2::Exists; + int bvs = get_quantifier_num_bound(t); + std::vector foci_bvs(bvs); + for(int i = 0; i < bvs; i++){ + std::string name = get_quantifier_bound_name(t,i); + //Z3_string name = Z3_get_symbol_string(ctx,sym); + // type ty = get_quantifier_bound_type(t,i); + foci2::symb f = foci->mk_func(name); + foci2::ast v = foci->mk_app(f,std::vector()); + foci_bvs[i] = v; + } + foci2::ast body = from_Z3_ast(get_quantifier_body(t)); + foci_bvs.push_back(body); + res = foci->mk_op(qop,foci_bvs); + node_to_ast[res] = t; // desperate + break; + } + case Variable: { // a deBruijn index + int index = get_variable_index_value(t); + type ty = get_type(t); + foci2::symb symbol = make_deBruijn_symbol(index,ty); + res = foci->mk_app(symbol,std::vector()); + break; + } + default: + { + std::cerr << "iZ3: unsupported Z3 operator in expression\n "; + print_expr(std::cerr,t); + std::cerr << "\n"; + assert(0 && "iZ3: unsupported Z3 operator"); + } + } + return res; + } + + // convert an expr to Z3 ast + ast to_Z3_ast(foci2::ast i){ + std::pair foo(i,ast()); + std::pair bar = node_to_ast.insert(foo); + if(!bar.second) return bar.first->second; + ast &res = bar.first->second; + + if(i < 0){ + res = mk_not(to_Z3_ast(-i)); + return res; + } + + // get the arguments + unsigned n = foci->get_num_args(i); + std::vector args(n); + for(unsigned j = 0; j < n; j++) + args[j] = to_Z3_ast(foci->get_arg(i,j)); + + // handle operators + foci2::ops o; + foci2::symb f; + std::string nval; + if(foci->get_true(i)) + res = mk_true(); + else if(foci->get_false(i)) + res = mk_false(); + else if(foci->get_op(i,o)){ + switch(o){ + case foci2::And: + res = make(And,args); break; + case foci2::Or: + res = make(Or,args); break; + case foci2::Not: + res = mk_not(args[0]); break; + case foci2::Iff: + res = make(Iff,args[0],args[1]); break; + case foci2::Ite: + res = make(Ite,args[0],args[1],args[2]); break; + case foci2::Equal: + res = make(Equal,args[0],args[1]); break; + case foci2::Plus: + res = make(Plus,args); break; + case foci2::Times: + res = make(Times,args); break; + case foci2::Div: + res = make(Idiv,args[0],args[1]); break; + case foci2::Leq: + res = make(Leq,args[0],args[1]); break; + case foci2::Distinct: + res = make(Distinct,args); + break; + case foci2::Tsym: + res = mk_true(); + break; + case foci2::Fsym: + res = mk_false(); + break; + case foci2::Forall: + case foci2::Exists: + { + int nargs = n; + std::vector bounds(nargs-1); + for(int i = 0; i < nargs-1; i++) + bounds[i] = args[i]; + opr oz = o == foci2::Forall ? Forall : Exists; + res = make_quant(oz,bounds,args[nargs-1]); + } + break; + default: + assert("unknown built-in op"); + } + } + else if(foci->get_int(i,nval)){ + res = make_int(nval); + } + else if(foci->get_func(i,f)){ + if(f == select_op){ + assert(n == 2); + res = make(Select,args[0],args[1]); + } + else if(f == store_op){ + assert(n == 3); + res = make(Store,args[0],args[1],args[2]); + } + else if(f == mod_op){ + assert(n == 2); + res = make(Mod,args[0],args[1]); + } + else { + std::pair foo(f,(symb)0); + std::pair bar = symbol_to_func_decl.insert(foo); + symb &func_decl = bar.first->second; + if(bar.second){ + std::cout << "unknown function symbol:\n"; + foci->show_ast(i); + assert(0); + } + res = make(func_decl,args); + } + } + else { + std::cerr << "iZ3: unknown FOCI expression kind\n"; + assert(0 && "iZ3: unknown FOCI expression kind"); + } + return res; + } + + int interpolate(const std::vector &cnsts, std::vector &itps){ + assert((int)cnsts.size() == frames); + std::string lia("lia"); +#ifdef _FOCI2 + foci = foci2::create(lia); +#else + foci = 0; +#endif + if(!foci){ + std::cerr << "iZ3: cannot find foci lia solver.\n"; + assert(0); + } + select_op = foci->mk_func("select"); + store_op = foci->mk_func("store"); + mod_op = foci->mk_func("mod"); + std::vector foci_cnsts(frames), foci_itps(frames-1), foci_parents; + if(parents) + foci_parents.resize(frames); + for(int i = 0; i < frames; i++){ + foci_cnsts[i] = from_Z3_ast(cnsts[i]); + if(parents) + foci_parents[i] = parents[i]; + } + int res = foci->interpolate(foci_cnsts, foci_itps, foci_parents); + if(res == 0){ + assert((int)foci_itps.size() == frames-1); + itps.resize(frames-1); + for(int i = 0; i < frames-1; i++){ + // foci->show_ast(foci_itps[i]); + itps[i] = to_Z3_ast(foci_itps[i]); + } + } + ast_to_node.clear(); + node_to_ast.clear(); + func_decl_to_symbol.clear(); + symbol_to_func_decl.clear(); + delete foci; + return res; + } + +}; + +iz3secondary *iz3foci::create(iz3mgr *mgr, int num, int *parents){ + return new iz3foci_impl(mgr,num,parents); +} diff --git a/src/interp/iz3foci.h b/src/interp/iz3foci.h new file mode 100755 index 000000000..86b049f24 --- /dev/null +++ b/src/interp/iz3foci.h @@ -0,0 +1,32 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3foci.h + +Abstract: + + Implements a secondary solver using foci2. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef IZ3FOCI_H +#define IZ3FOCI_H + +#include "iz3secondary.h" + +/** Secondary prover based on Cadence FOCI. */ + +class iz3foci { + public: + static iz3secondary *create(iz3mgr *mgr, int num, int *parents); +}; + +#endif diff --git a/src/interp/iz3hash.h b/src/interp/iz3hash.h new file mode 100755 index 000000000..75d9aa604 --- /dev/null +++ b/src/interp/iz3hash.h @@ -0,0 +1,171 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3hash.h + +Abstract: + + Wrapper for stl hash tables + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +// pull in the headers for has_map and hash_set +// these live in non-standard places + +#ifndef IZ3_HASH_H +#define IZ3_HASH_H + +//#define USE_UNORDERED_MAP +#ifdef USE_UNORDERED_MAP + +#define stl_ext std +#define hash_space std +#include +#include +#define hash_map unordered_map +#define hash_set unordered_set + +#else + +#if __GNUC__ >= 3 +#undef __DEPRECATED +#define stl_ext __gnu_cxx +#define hash_space stl_ext +#include +#include +#else +#ifdef WIN32 +#define stl_ext stdext +#define hash_space std +#include +#include +#else +#define stl_ext std +#define hash_space std +#include +#include +#endif +#endif + +#endif + +#include + +// stupid STL doesn't include hash function for class string + +#ifndef WIN32 + +namespace stl_ext { + template <> + class hash { + stl_ext::hash H; + public: + size_t operator()(const std::string &s) const { + return H(s.c_str()); + } + }; +} + +#endif + +namespace hash_space { + template <> + class hash > { + public: + size_t operator()(const std::pair &p) const { + return p.first + p.second; + } + }; +} + +#ifdef WIN32 +template <> inline +size_t stdext::hash_value >(const std::pair& p) +{ // hash _Keyval to size_t value one-to-one + return p.first + p.second; +} +#endif + +namespace hash_space { + template + class hash > { + public: + size_t operator()(const std::pair &p) const { + return (size_t)p.first + (size_t)p.second; + } + }; +} + +#if 0 +template inline +size_t stdext::hash_value >(const std::pair& p) +{ // hash _Keyval to size_t value one-to-one + return (size_t)p.first + (size_t)p.second; +} +#endif + +#ifdef WIN32 + +namespace std { + template <> + class less > { + public: + bool operator()(const pair &x, const pair &y) const { + return x.first < y.first || x.first == y.first && x.second < y.second; + } + }; + +} + +namespace std { + template + class less > { + public: + bool operator()(const pair &x, const pair &y) const { + return (size_t)x.first < (size_t)y.first || (size_t)x.first == (size_t)y.first && (size_t)x.second < (size_t)y.second; + } + }; + +} + +#endif + + +#ifndef WIN32 + +#if 0 +namespace stl_ext { + template + class hash { + public: + size_t operator()(const T *p) const { + return (size_t) p; + } + }; +} +#endif + +#endif + +#ifdef WIN32 + + + + +template +class hash_map : public stl_ext::hash_map > > {}; + +template +class hash_set : public stl_ext::hash_set > > {}; + +#endif + +#endif diff --git a/src/interp/iz3interp.cpp b/src/interp/iz3interp.cpp new file mode 100755 index 000000000..56dc1ccec --- /dev/null +++ b/src/interp/iz3interp.cpp @@ -0,0 +1,504 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3interp.cpp + +Abstract: + + Interpolation based on proof translation. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +/* Copyright 2011 Microsoft Research. */ +#include +#include +#include +#include +#include +#include +#include + +#include "iz3profiling.h" +#include "iz3translate.h" +#include "iz3foci.h" +#include "iz3proof.h" +#include "iz3hash.h" +#include "iz3interp.h" + + + +#ifndef WIN32 +using namespace stl_ext; +#endif + + + +#if 1 + +struct frame_reducer : public iz3mgr { + + int frames; + hash_map frame_map; + std::vector assertions_map; + std::vector orig_parents_copy; + std::vector used_frames; + + + frame_reducer(const iz3mgr &other) + : iz3mgr(other) {} + + void get_proof_assumptions_rec(z3pf proof, hash_set &memo, std::vector &used_frames){ + if(memo.count(proof))return; + memo.insert(proof); + pfrule dk = pr(proof); + if(dk == PR_ASSERTED){ + ast con = conc(proof); + if(frame_map.find(con) != frame_map.end()){ // false for theory facts + int frame = frame_map[con]; + used_frames[frame] = true; + } + } + else { + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + z3pf arg = prem(proof,i); + get_proof_assumptions_rec(arg,memo,used_frames); + } + } + } + + void get_frames(const std::vector >&z3_preds, + const std::vector &orig_parents, + std::vector >&assertions, + std::vector &parents, + z3pf proof){ + frames = z3_preds.size(); + orig_parents_copy = orig_parents; + for(unsigned i = 0; i < z3_preds.size(); i++) + for(unsigned j = 0; j < z3_preds[i].size(); j++) + frame_map[z3_preds[i][j]] = i; + used_frames.resize(frames); + hash_set memo; + get_proof_assumptions_rec(proof,memo,used_frames); + std::vector assertions_back_map(frames); + + for(unsigned i = 0; i < z3_preds.size(); i++) + if(used_frames[i] || i == z3_preds.size() - 1){ + assertions.push_back(z3_preds[i]); + assertions_map.push_back(i); + assertions_back_map[i] = assertions.size() - 1; + } + + if(orig_parents.size()){ + parents.resize(assertions.size()); + for(unsigned i = 0; i < assertions.size(); i++){ + int p = orig_parents[assertions_map[i]]; + while(p != SHRT_MAX && !used_frames[p]) + p = orig_parents[p]; + parents[i] = p == SHRT_MAX ? p : assertions_back_map[p]; + } + } + + // std::cout << "used frames = " << frames << "\n"; + } + + void fix_interpolants(std::vector &interpolants){ + std::vector unfixed = interpolants; + interpolants.resize(frames - 1); + for(int i = 0; i < frames - 1; i++) + interpolants[i] = mk_true(); + for(unsigned i = 0; i < unfixed.size(); i++) + interpolants[assertions_map[i]] = unfixed[i]; + for(int i = 0; i < frames-2; i++){ + int p = orig_parents_copy.size() == 0 ? i+1 : orig_parents_copy[i]; + if(p < frames - 1 && !used_frames[p]) + interpolants[p] = interpolants[i]; + } + } +}; + +#else +struct frame_reducer { + + + + frame_reducer(context _ctx){ + } + + void get_frames(const std::vector &z3_preds, + const std::vector &orig_parents, + std::vector &assertions, + std::vector &parents, + ast proof){ + assertions = z3_preds; + parents = orig_parents; + } + + void fix_interpolants(std::vector &interpolants){ + } +}; + +#endif + + +#if 0 +static lbool test_secondary(context ctx, + int num, + ast *cnsts, + ast *interps, + int *parents = 0 + ){ + iz3secondary *sp = iz3foci::create(ctx,num,parents); + std::vector frames(num), interpolants(num-1); + std::copy(cnsts,cnsts+num,frames.begin()); + int res = sp->interpolate(frames,interpolants); + if(res == 0) + std::copy(interpolants.begin(),interpolants.end(),interps); + return res ? L_TRUE : L_FALSE; +} +#endif + +template +struct killme { + T *p; + killme(){p = 0;} + void set(T *_p) {p = _p;} + ~killme(){ + if(p) + delete p; + } +}; + + +class iz3interp : public iz3base { +public: + + killme sp_killer; + killme tr_killer; + + bool is_linear(std::vector &parents){ + for(int i = 0; i < ((int)parents.size())-1; i++) + if(parents[i] != i+1) + return false; + return true; + } + + void test_secondary(const std::vector &cnsts, + const std::vector &parents, + std::vector &interps + ){ + int num = cnsts.size(); + iz3secondary *sp = iz3foci::create(this,num,(int *)(parents.empty()?0:&parents[0])); + int res = sp->interpolate(cnsts, interps); + if(res != 0) + throw "secondary failed"; + } + + void proof_to_interpolant(z3pf proof, + const std::vector > &cnsts, + const std::vector &parents, + std::vector &interps, + const std::vector &theory, + interpolation_options_struct *options = 0 + ){ +#if 0 + test_secondary(cnsts,parents,interps); + return; +#endif + + profiling::timer_start("Interpolation prep"); + + // get rid of frames not used in proof + + std::vector > cnsts_vec; + std::vector parents_vec; + frame_reducer fr(*this); + fr.get_frames(cnsts,parents,cnsts_vec,parents_vec,proof); + + int num = cnsts_vec.size(); + std::vector interps_vec(num-1); + + // if this is really a sequence problem, we can make it easier + if(is_linear(parents_vec)) + parents_vec.clear(); + + // create a secondary prover + iz3secondary *sp = iz3foci::create(this,num,parents_vec.empty()?0:&parents_vec[0]); + sp_killer.set(sp); // kill this on exit + +#define BINARY_INTERPOLATION +#ifndef BINARY_INTERPOLATION + // create a translator + iz3translation *tr = iz3translation::create(*this,sp,cnsts_vec,parents_vec,theory); + tr_killer.set(tr); + + // set the translation options, if needed + if(options) + for(hash_map::iterator it = options->map.begin(), en = options->map.end(); it != en; ++it) + tr->set_option(it->first, it->second); + + // create a proof object to hold the translation + iz3proof pf(tr); + + profiling::timer_stop("Interpolation prep"); + + // translate into an interpolatable proof + profiling::timer_start("Proof translation"); + tr->translate(proof,pf); + profiling::timer_stop("Proof translation"); + + // translate the proof into interpolants + profiling::timer_start("Proof interpolation"); + for(int i = 0; i < num-1; i++){ + interps_vec[i] = pf.interpolate(tr->range_downward(i),tr->weak_mode()); + interps_vec[i] = tr->quantify(interps_vec[i],tr->range_downward(i)); + } + profiling::timer_stop("Proof interpolation"); +#else + iz3base the_base(*this,cnsts_vec,parents_vec,theory); + + profiling::timer_stop("Interpolation prep"); + + for(int i = 0; i < num-1; i++){ + range rng = the_base.range_downward(i); + std::vector > cnsts_vec_vec(2); + for(unsigned j = 0; j < cnsts_vec.size(); j++){ + bool is_A = the_base.in_range(j,rng); + for(unsigned k = 0; k < cnsts_vec[j].size(); k++) + cnsts_vec_vec[is_A ? 0 : 1].push_back(cnsts_vec[j][k]); + } + + killme tr_killer_i; + iz3translation *tr = iz3translation::create(*this,sp,cnsts_vec_vec,std::vector(),theory); + tr_killer_i.set(tr); + + // set the translation options, if needed + if(options) + for(hash_map::iterator it = options->map.begin(), en = options->map.end(); it != en; ++it) + tr->set_option(it->first, it->second); + + // create a proof object to hold the translation + iz3proof pf(tr); + + // translate into an interpolatable proof + profiling::timer_start("Proof translation"); + tr->translate(proof,pf); + profiling::timer_stop("Proof translation"); + + // translate the proof into interpolants + profiling::timer_start("Proof interpolation"); + interps_vec[i] = pf.interpolate(tr->range_downward(0),tr->weak_mode()); + interps_vec[i] = tr->quantify(interps_vec[i],tr->range_downward(0)); + profiling::timer_stop("Proof interpolation"); + } +#endif + // put back in the removed frames + fr.fix_interpolants(interps_vec); + + interps = interps_vec; + + } + + + void proof_to_interpolant(z3pf proof, + std::vector &cnsts, + const std::vector &parents, + std::vector &interps, + const std::vector &theory, + interpolation_options_struct *options = 0 + ){ + std::vector > cnsts_vec(cnsts.size()); + for(unsigned i = 0; i < cnsts.size(); i++) + cnsts_vec[i].push_back(cnsts[i]); + proof_to_interpolant(proof,cnsts_vec,parents,interps,theory,options); + } + + // same as above, but represents the tree using an ast + + void proof_to_interpolant(const z3pf &proof, + const std::vector &_cnsts, + const ast &tree, + std::vector &interps, + interpolation_options_struct *options = 0 + ){ + std::vector pos_map; + + // convert to the parents vector representation + + to_parents_vec_representation(_cnsts, tree, cnsts, parents, theory, pos_map); + + //use the parents vector representation to compute interpolant + proof_to_interpolant(proof,cnsts,parents,interps,theory,options); + + // get the interps for the tree positions + std::vector _interps = interps; + interps.resize(pos_map.size()); + for(unsigned i = 0; i < pos_map.size(); i++) + interps[i] = i < _interps.size() ? _interps[i] : mk_false(); + } + + bool has_interp(hash_map &memo, const ast &t){ + if(memo.find(t) != memo.end()) + return memo[t]; + bool res = false; + if(op(t) == Interp) + res = true; + else if(op(t) == And){ + int nargs = num_args(t); + for(int i = 0; i < nargs; i++) + res |= has_interp(memo, arg(t,i)); + } + memo[t] = res; + return res; + } + + void collect_conjuncts(std::vector &cnsts, hash_map &memo, const ast &t){ + if(!has_interp(memo,t)) + cnsts.push_back(t); + else { + int nargs = num_args(t); + for(int i = 0; i < nargs; i++) + collect_conjuncts(cnsts, memo, arg(t,i)); + } + } + + void assert_conjuncts(solver &s, std::vector &cnsts, const ast &t){ + hash_map memo; + collect_conjuncts(cnsts,memo,t); + for(unsigned i = 0; i < cnsts.size(); i++) + s.assert_expr(to_expr(cnsts[i].raw())); + } + + iz3interp(ast_manager &_m_manager) + : iz3base(_m_manager) {} +}; + +void iz3interpolate(ast_manager &_m_manager, + ast *proof, + const ptr_vector &cnsts, + const ::vector &parents, + ptr_vector &interps, + const ptr_vector &theory, + interpolation_options_struct * options) +{ + iz3interp itp(_m_manager); + if(options) + options->apply(itp); + std::vector _cnsts(cnsts.size()); + std::vector _parents(parents.size()); + std::vector _interps; + std::vector _theory(theory.size()); + for(unsigned i = 0; i < cnsts.size(); i++) + _cnsts[i] = itp.cook(cnsts[i]); + for(unsigned i = 0; i < parents.size(); i++) + _parents[i] = parents[i]; + for(unsigned i = 0; i < theory.size(); i++) + _theory[i] = itp.cook(theory[i]); + iz3mgr::ast _proof = itp.cook(proof); + itp.proof_to_interpolant(_proof,_cnsts,_parents,_interps,_theory,options); + interps.resize(_interps.size()); + for(unsigned i = 0; i < interps.size(); i++) + interps[i] = itp.uncook(_interps[i]); +} + +void iz3interpolate(ast_manager &_m_manager, + ast *proof, + const ::vector > &cnsts, + const ::vector &parents, + ptr_vector &interps, + const ptr_vector &theory, + interpolation_options_struct * options) +{ + iz3interp itp(_m_manager); + if(options) + options->apply(itp); + std::vector > _cnsts(cnsts.size()); + std::vector _parents(parents.size()); + std::vector _interps; + std::vector _theory(theory.size()); + for(unsigned i = 0; i < cnsts.size(); i++) + for(unsigned j = 0; j < cnsts[i].size(); j++) + _cnsts[i].push_back(itp.cook(cnsts[i][j])); + for(unsigned i = 0; i < parents.size(); i++) + _parents[i] = parents[i]; + for(unsigned i = 0; i < theory.size(); i++) + _theory[i] = itp.cook(theory[i]); + iz3mgr::ast _proof = itp.cook(proof); + itp.proof_to_interpolant(_proof,_cnsts,_parents,_interps,_theory,options); + interps.resize(_interps.size()); + for(unsigned i = 0; i < interps.size(); i++) + interps[i] = itp.uncook(_interps[i]); +} + +void iz3interpolate(ast_manager &_m_manager, + ast *proof, + const ptr_vector &cnsts, + ast *tree, + ptr_vector &interps, + interpolation_options_struct * options) +{ + iz3interp itp(_m_manager); + if(options) + options->apply(itp); + std::vector _cnsts(cnsts.size()); + std::vector _interps; + for(unsigned i = 0; i < cnsts.size(); i++) + _cnsts[i] = itp.cook(cnsts[i]); + iz3mgr::ast _proof = itp.cook(proof); + iz3mgr::ast _tree = itp.cook(tree); + itp.proof_to_interpolant(_proof,_cnsts,_tree,_interps,options); + interps.resize(_interps.size()); + for(unsigned i = 0; i < interps.size(); i++) + interps[i] = itp.uncook(_interps[i]); +} + +lbool iz3interpolate(ast_manager &_m_manager, + solver &s, + ast *tree, + ptr_vector &cnsts, + ptr_vector &interps, + model_ref &m, + interpolation_options_struct * options) +{ + iz3interp itp(_m_manager); + if(options) + options->apply(itp); + iz3mgr::ast _tree = itp.cook(tree); + std::vector _cnsts; + itp.assert_conjuncts(s,_cnsts,_tree); + profiling::timer_start("solving"); + lbool res = s.check_sat(0,0); + profiling::timer_stop("solving"); + if(res == l_false){ + ast *proof = s.get_proof(); + iz3mgr::ast _proof = itp.cook(proof); + std::vector _interps; + itp.proof_to_interpolant(_proof,_cnsts,_tree,_interps,options); + interps.resize(_interps.size()); + for(unsigned i = 0; i < interps.size(); i++) + interps[i] = itp.uncook(_interps[i]); + } + else if(m){ + s.get_model(m); + } + cnsts.resize(_cnsts.size()); + for(unsigned i = 0; i < cnsts.size(); i++) + cnsts[i] = itp.uncook(_cnsts[i]); + return res; +} + +void interpolation_options_struct::apply(iz3base &b){ + for(stl_ext::hash_map::iterator it = map.begin(), en = map.end(); + it != en; + ++it) + b.set_option((*it).first,(*it).second); +} + diff --git a/src/interp/iz3interp.h b/src/interp/iz3interp.h new file mode 100644 index 000000000..52aa716c3 --- /dev/null +++ b/src/interp/iz3interp.h @@ -0,0 +1,93 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3interp.h + +Abstract: + + Interpolation based on proof translation. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef IZ3_INTERP_H +#define IZ3_INTERP_H + +#include "iz3hash.h" +#include "solver.h" + +class iz3base; + +struct interpolation_options_struct { + stl_ext::hash_map map; +public: + void set(const std::string &name, const std::string &value){ + map[name] = value; + } + void apply(iz3base &b); +}; + +/** This object is thrown if a tree interpolation problem is mal-formed */ +struct iz3_bad_tree { +}; + +/** This object is thrown when iz3 fails due to an incompleteness in + the secondary solver. */ +struct iz3_incompleteness { +}; + +typedef interpolation_options_struct *interpolation_options; + +/* Compute an interpolant from a proof. This version uses the parents vector + representation, for compatibility with the old API. */ + +void iz3interpolate(ast_manager &_m_manager, + ast *proof, + const ptr_vector &cnsts, + const ::vector &parents, + ptr_vector &interps, + const ptr_vector &theory, + interpolation_options_struct * options = 0); + +/* Same as above, but each constraint is a vector of formulas. */ + +void iz3interpolate(ast_manager &_m_manager, + ast *proof, + const vector > &cnsts, + const ::vector &parents, + ptr_vector &interps, + const ptr_vector &theory, + interpolation_options_struct * options = 0); + +/* Compute an interpolant from a proof. This version uses the ast + representation, for compatibility with the new API. */ + +void iz3interpolate(ast_manager &_m_manager, + ast *proof, + const ptr_vector &cnsts, + ast *tree, + ptr_vector &interps, + interpolation_options_struct * options); + +/* Compute an interpolant from an ast representing an interpolation + problem, if unsat, else return a model (if enabled). Uses the + given solver to produce the proof/model. Also returns a vector + of the constraints in the problem, helpful for checking correctness. +*/ + +lbool iz3interpolate(ast_manager &_m_manager, + solver &s, + ast *tree, + ptr_vector &cnsts, + ptr_vector &interps, + model_ref &m, + interpolation_options_struct * options); + +#endif diff --git a/src/interp/iz3mgr.cpp b/src/interp/iz3mgr.cpp new file mode 100644 index 000000000..f35dae93f --- /dev/null +++ b/src/interp/iz3mgr.cpp @@ -0,0 +1,871 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3mgr.cpp + +Abstract: + + A wrapper around an ast manager, providing convenience methods. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + + +#include "iz3mgr.h" + +#include +#include +#include +#include + +#include "expr_abstract.h" +#include "params.h" + + +#ifndef WIN32 +using namespace stl_ext; +#endif + + +std::ostream &operator <<(std::ostream &s, const iz3mgr::ast &a){ + return s; +} + + +iz3mgr::ast iz3mgr::make_var(const std::string &name, type ty){ + symbol s = symbol(name.c_str()); + return cook(m().mk_const(m().mk_const_decl(s, ty))); +} + +iz3mgr::ast iz3mgr::make(opr op, int n, raw_ast **args){ + switch(op) { + case True: return mki(m_basic_fid,OP_TRUE,n,args); + case False: return mki(m_basic_fid,OP_FALSE,n,args); + case Equal: return mki(m_basic_fid,OP_EQ,n,args); + case Distinct: return mki(m_basic_fid,OP_DISTINCT,n,args); + case Ite: return mki(m_basic_fid,OP_ITE,n,args); + case And: return mki(m_basic_fid,OP_AND,n,args); + case Or: return mki(m_basic_fid,OP_OR,n,args); + case Iff: return mki(m_basic_fid,OP_IFF,n,args); + case Xor: return mki(m_basic_fid,OP_XOR,n,args); + case Not: return mki(m_basic_fid,OP_NOT,n,args); + case Implies: return mki(m_basic_fid,OP_IMPLIES,n,args); + case Oeq: return mki(m_basic_fid,OP_OEQ,n,args); + case Interp: return mki(m_basic_fid,OP_INTERP,n,args); + case Leq: return mki(m_arith_fid,OP_LE,n,args); + case Geq: return mki(m_arith_fid,OP_GE,n,args); + case Lt: return mki(m_arith_fid,OP_LT,n,args); + case Gt: return mki(m_arith_fid,OP_GT,n,args); + case Plus: return mki(m_arith_fid,OP_ADD,n,args); + case Sub: return mki(m_arith_fid,OP_SUB,n,args); + case Uminus: return mki(m_arith_fid,OP_UMINUS,n,args); + case Times: return mki(m_arith_fid,OP_MUL,n,args); + case Div: return mki(m_arith_fid,OP_DIV,n,args); + case Idiv: return mki(m_arith_fid,OP_IDIV,n,args); + case Rem: return mki(m_arith_fid,OP_REM,n,args); + case Mod: return mki(m_arith_fid,OP_MOD,n,args); + case Power: return mki(m_arith_fid,OP_POWER,n,args); + case ToReal: return mki(m_arith_fid,OP_TO_REAL,n,args); + case ToInt: return mki(m_arith_fid,OP_TO_INT,n,args); + case IsInt: return mki(m_arith_fid,OP_IS_INT,n,args); + case Store: return mki(m_array_fid,OP_STORE,n,args); + case Select: return mki(m_array_fid,OP_SELECT,n,args); + case ConstArray: return mki(m_array_fid,OP_CONST_ARRAY,n,args); + case ArrayDefault: return mki(m_array_fid,OP_ARRAY_DEFAULT,n,args); + case ArrayMap: return mki(m_array_fid,OP_ARRAY_MAP,n,args); + case SetUnion: return mki(m_array_fid,OP_SET_UNION,n,args); + case SetIntersect: return mki(m_array_fid,OP_SET_INTERSECT,n,args); + case SetDifference: return mki(m_array_fid,OP_SET_DIFFERENCE,n,args); + case SetComplement: return mki(m_array_fid,OP_SET_COMPLEMENT,n,args); + case SetSubSet: return mki(m_array_fid,OP_SET_SUBSET,n,args); + case AsArray: return mki(m_array_fid,OP_AS_ARRAY,n,args); + default: + assert(0); + return ast(); + } +} + +iz3mgr::ast iz3mgr::mki(family_id fid, decl_kind dk, int n, raw_ast **args){ + return cook(m().mk_app(fid, dk, 0, 0, n, (expr **)args)); +} + +iz3mgr::ast iz3mgr::make(opr op, const std::vector &args){ + static std::vector a(10); + if(a.size() < args.size()) + a.resize(args.size()); + for(unsigned i = 0; i < args.size(); i++) + a[i] = args[i].raw(); + return make(op,args.size(), args.size() ? &a[0] : 0); +} + +iz3mgr::ast iz3mgr::make(opr op){ + return make(op,0,0); +} + +iz3mgr::ast iz3mgr::make(opr op, const ast &arg0){ + raw_ast *a = arg0.raw(); + return make(op,1,&a); +} + +iz3mgr::ast iz3mgr::make(opr op, const ast &arg0, const ast &arg1){ + raw_ast *args[2]; + args[0] = arg0.raw(); + args[1] = arg1.raw(); + return make(op,2,args); +} + +iz3mgr::ast iz3mgr::make(opr op, const ast &arg0, const ast &arg1, const ast &arg2){ + raw_ast *args[3]; + args[0] = arg0.raw(); + args[1] = arg1.raw(); + args[2] = arg2.raw(); + return make(op,3,args); +} + +iz3mgr::ast iz3mgr::make(symb sym, int n, raw_ast **args){ + return cook(m().mk_app(sym, n, (expr **) args)); +} + +iz3mgr::ast iz3mgr::make(symb sym, const std::vector &args){ + static std::vector a(10); + if(a.size() < args.size()) + a.resize(args.size()); + for(unsigned i = 0; i < args.size(); i++) + a[i] = args[i].raw(); + return make(sym,args.size(), args.size() ? &a[0] : 0); +} + +iz3mgr::ast iz3mgr::make(symb sym){ + return make(sym,0,0); +} + +iz3mgr::ast iz3mgr::make(symb sym, const ast &arg0){ + raw_ast *a = arg0.raw(); + return make(sym,1,&a); +} + +iz3mgr::ast iz3mgr::make(symb sym, const ast &arg0, const ast &arg1){ + raw_ast *args[2]; + args[0] = arg0.raw(); + args[1] = arg1.raw(); + return make(sym,2,args); +} + +iz3mgr::ast iz3mgr::make(symb sym, const ast &arg0, const ast &arg1, const ast &arg2){ + raw_ast *args[3]; + args[0] = arg0.raw(); + args[1] = arg1.raw(); + args[2] = arg2.raw(); + return make(sym,3,args); +} + +iz3mgr::ast iz3mgr::make_quant(opr op, const std::vector &bvs, ast &body){ + if(bvs.size() == 0) return body; + std::vector foo(bvs.size()); + + + std::vector names; + std::vector types; + std::vector bound_asts; + unsigned num_bound = bvs.size(); + + for (unsigned i = 0; i < num_bound; ++i) { + app* a = to_app(bvs[i].raw()); + symbol s(to_app(a)->get_decl()->get_name()); + names.push_back(s); + types.push_back(m().get_sort(a)); + bound_asts.push_back(a); + } + expr_ref abs_body(m()); + expr_abstract(m(), 0, num_bound, &bound_asts[0], to_expr(body.raw()), abs_body); + expr_ref result(m()); + result = m().mk_quantifier( + op == Forall, + names.size(), &types[0], &names[0], abs_body.get(), + 0, + symbol("itp"), + symbol(), + 0, 0, + 0, 0 + ); + return cook(result.get()); +} + +// FIXME replace this with existing Z3 functionality + +iz3mgr::ast iz3mgr::clone(const ast &t, const std::vector &_args){ + if(_args.size() == 0) + return t; + + ast_manager& m = m_manager; + expr* a = to_expr(t.raw()); + static std::vector rargs(10); + if(rargs.size() < _args.size()) + rargs.resize(_args.size()); + for(unsigned i = 0; i < _args.size(); i++) + rargs[i] = _args[i].raw(); + expr* const* args = (expr **)&rargs[0]; + switch(a->get_kind()) { + case AST_APP: { + app* e = to_app(a); + if (e->get_num_args() != _args.size()) { + assert(0); + } + else { + a = m.mk_app(e->get_decl(), _args.size(), args); + } + break; + } + case AST_QUANTIFIER: { + if (_args.size() != 1) { + assert(0); + } + else { + a = m.update_quantifier(to_quantifier(a), args[0]); + } + break; + } + default: + break; + } + return cook(a); +} + + +void iz3mgr::show(ast t){ + params_ref p; + p.set_bool("flat_assoc",false); + std::cout << mk_pp(t.raw(), m(), p) << std::endl; +} + +void iz3mgr::show_symb(symb s){ + std::cout << mk_pp(s, m()) << std::endl; +} + +void iz3mgr::print_expr(std::ostream &s, const ast &e){ + params_ref p; + p.set_bool("flat_assoc",false); + s << mk_pp(e.raw(), m(), p); +} + + +void iz3mgr::print_clause(std::ostream &s, std::vector &cls){ + s << "("; + for(unsigned i = 0; i < cls.size(); i++){ + if(i > 0) s << ","; + print_expr(s,cls[i]); + } + s << ")"; +} + +void iz3mgr::show_clause(std::vector &cls){ + print_clause(std::cout,cls); + std::cout << std::endl; +} + +void iz3mgr::print_lit(ast lit){ + ast abslit = is_not(lit) ? arg(lit,0) : lit; + int f = op(abslit); + if(f == And || f == Or || f == Iff){ + if(is_not(lit)) std::cout << "~"; + std::cout << "[" << abslit << "]"; + } + else + std::cout << lit; +} + + +static int pretty_cols = 79; +static int pretty_indent_chars = 2; + +static int pretty_find_delim(const std::string &s, int pos){ + int level = 0; + int end = s.size(); + for(; pos < end; pos++){ + int ch = s[pos]; + if(ch == '(')level++; + if(ch == ')')level--; + if(level < 0 || (level == 0 && ch == ','))break; + } + return pos; +} + +static void pretty_newline(std::ostream &f, int indent){ + f << std::endl; + for(int i = 0; i < indent; i++) + f << " "; +} + +void iz3mgr::pretty_print(std::ostream &f, const std::string &s){ + int cur_indent = 0; + int indent = 0; + int col = 0; + int pos = 0; + while(pos < (int)s.size()){ + int delim = pretty_find_delim(s,pos); + if(s[pos] != ')' && s[pos] != ',' && cur_indent > indent){ + pretty_newline(f,indent); + cur_indent = indent; + col = indent; + continue; + } + if (col + delim - pos > pretty_cols) { + if (col > indent) { + pretty_newline(f,indent); + cur_indent = indent; + col = indent; + continue; + } + int paren = s.find('(',pos); + if(paren != (int)std::string::npos){ + int chars = paren - pos + 1; + f << s.substr(pos,chars); + indent += pretty_indent_chars; + if(col) pretty_newline(f,indent); + cur_indent = indent; + pos += chars; + col = indent; + continue; + } + } + int chars = delim - pos + 1; + f << s.substr(pos,chars); + pos += chars; + col += chars; + if(s[delim] == ')') + indent -= pretty_indent_chars; + } +} + + +iz3mgr::opr iz3mgr::op(const ast &t){ + ast_kind dk = t.raw()->get_kind(); + switch(dk){ + case AST_APP: { + expr * e = to_expr(t.raw()); + func_decl *d = to_app(t.raw())->get_decl(); + if (null_family_id == d->get_family_id()) + return Uninterpreted; + // return (opr)d->get_decl_kind(); + if (m_basic_fid == d->get_family_id()) { + switch(d->get_decl_kind()) { + case OP_TRUE: return True; + case OP_FALSE: return False; + case OP_EQ: return Equal; + case OP_DISTINCT: return Distinct; + case OP_ITE: return Ite; + case OP_AND: return And; + case OP_OR: return Or; + case OP_IFF: return Iff; + case OP_XOR: return Xor; + case OP_NOT: return Not; + case OP_IMPLIES: return Implies; + case OP_OEQ: return Oeq; + case OP_INTERP: return Interp; + default: + return Other; + } + } + if (m_arith_fid == d->get_family_id()) { + switch(d->get_decl_kind()) { + case OP_LE: return Leq; + case OP_GE: return Geq; + case OP_LT: return Lt; + case OP_GT: return Gt; + case OP_ADD: return Plus; + case OP_SUB: return Sub; + case OP_UMINUS: return Uminus; + case OP_MUL: return Times; + case OP_DIV: return Div; + case OP_IDIV: return Idiv; + case OP_REM: return Rem; + case OP_MOD: return Mod; + case OP_POWER: return Power; + case OP_TO_REAL: return ToReal; + case OP_TO_INT: return ToInt; + case OP_IS_INT: return IsInt; + default: + if (m().is_unique_value(e)) + return Numeral; + return Other; + } + } + if (m_array_fid == d->get_family_id()) { + switch(d->get_decl_kind()) { + case OP_STORE: return Store; + case OP_SELECT: return Select; + case OP_CONST_ARRAY: return ConstArray; + case OP_ARRAY_DEFAULT: return ArrayDefault; + case OP_ARRAY_MAP: return ArrayMap; + case OP_SET_UNION: return SetUnion; + case OP_SET_INTERSECT: return SetIntersect; + case OP_SET_DIFFERENCE: return SetDifference; + case OP_SET_COMPLEMENT: return SetComplement; + case OP_SET_SUBSET: return SetSubSet; + case OP_AS_ARRAY: return AsArray; + default: + return Other; + } + } + + return Other; + } + + + case AST_QUANTIFIER: + return to_quantifier(t.raw())->is_forall() ? Forall : Exists; + case AST_VAR: + return Variable; + default:; + } + return Other; +} + + +iz3mgr::pfrule iz3mgr::pr(const ast &t){ + func_decl *d = to_app(t.raw())->get_decl(); + assert(m_basic_fid == d->get_family_id()); + return d->get_decl_kind(); +} + +void iz3mgr::print_sat_problem(std::ostream &out, const ast &t){ + ast_smt_pp pp(m()); + pp.set_simplify_implies(false); + pp.display_smt2(out, to_expr(t.raw())); +} + +iz3mgr::ast iz3mgr::z3_simplify(const ast &e){ + ::expr * a = to_expr(e.raw()); + params_ref p; + th_rewriter m_rw(m(), p); + expr_ref result(m()); + m_rw(a, result); + return cook(result); +} + +iz3mgr::ast iz3mgr::z3_really_simplify(const ast &e){ + ::expr * a = to_expr(e.raw()); + params_ref simp_params; + simp_params.set_bool(":som",true); + simp_params.set_bool(":sort-sums",true); + th_rewriter m_rw(m(), simp_params); + expr_ref result(m()); + m_rw(a, result); + return cook(result); +} + + +#if 0 +static rational lcm(const rational &x, const rational &y){ + int a = x.numerator(); + int b = y.numerator(); + return rational(a * b / gcd(a, b)); +} +#endif + +static rational extract_lcd(std::vector &res){ + if(res.size() == 0) return rational(1); // shouldn't happen + rational lcd = denominator(res[0]); + for(unsigned i = 1; i < res.size(); i++) + lcd = lcm(lcd,denominator(res[i])); + for(unsigned i = 0; i < res.size(); i++) + res[i] *= lcd; + return lcd; +} + +void iz3mgr::get_farkas_coeffs(const ast &proof, std::vector& coeffs){ + std::vector rats; + get_farkas_coeffs(proof,rats); + coeffs.resize(rats.size()); + for(unsigned i = 0; i < rats.size(); i++){ + sort *is = m().mk_sort(m_arith_fid, INT_SORT); + ast coeff = cook(m_arith_util.mk_numeral(rats[i],is)); + coeffs[i] = coeff; + } +} + +static void abs_rat(std::vector &rats){ + // check that they are all non-neg -- if neg, take abs val and warn! + for(unsigned i = 0; i < rats.size(); i++) + if(rats[i].is_neg()){ + // std::cout << "negative Farkas coeff!\n"; + rats[i] = -rats[i]; + } +} + +bool iz3mgr::is_farkas_coefficient_negative(const ast &proof, int n){ + rational r; + symb s = sym(proof); + bool ok = s->get_parameter(n+2).is_rational(r); + if(!ok) + throw "Bad Farkas coefficient"; + return r.is_neg(); +} + +void iz3mgr::get_farkas_coeffs(const ast &proof, std::vector& rats){ + symb s = sym(proof); + int numps = s->get_num_parameters(); + rats.resize(numps-2); +#if 0 + if(num_prems(proof) < numps-2){ + std::cout << "bad farkas rule: " << num_prems(proof) << " premises should be " << numps-2 << "\n"; + } +#endif + for(int i = 2; i < numps; i++){ + rational r; + bool ok = s->get_parameter(i).is_rational(r); + if(!ok) + throw "Bad Farkas coefficient"; +#if 0 + { + ast con = conc(prem(proof,i-2)); + ast temp = make_real(r); // for debugging + opr o = is_not(con) ? op(arg(con,0)) : op(con); + if(is_not(con) ? (o == Leq || o == Lt) : (o == Geq || o == Gt)) + r = -r; + } +#endif + rats[i-2] = r; + } +#if 0 + if(rats.size() != 0 && rats[0].is_neg()){ + for(unsigned i = 0; i < rats.size(); i++){ + assert(rats[i].is_neg()); + rats[i] = -rats[i]; + } + } +#endif + abs_rat(rats); + extract_lcd(rats); +} + +void iz3mgr::get_assign_bounds_coeffs(const ast &proof, std::vector& coeffs){ + std::vector rats; + get_assign_bounds_coeffs(proof,rats); + coeffs.resize(rats.size()); + for(unsigned i = 0; i < rats.size(); i++){ + coeffs[i] = make_int(rats[i]); + } +} + +void iz3mgr::get_assign_bounds_coeffs(const ast &proof, std::vector& rats){ + symb s = sym(proof); + int numps = s->get_num_parameters(); + rats.resize(numps-1); + rats[0] = rational(1); + ast conseq = arg(conc(proof),0); + opr conseq_o = is_not(conseq) ? op(arg(conseq,0)) : op(conseq); + bool conseq_neg = is_not(conseq) ? (conseq_o == Leq || conseq_o == Lt) : (conseq_o == Geq || conseq_o == Gt); + for(int i = 2; i < numps; i++){ + rational r; + bool ok = s->get_parameter(i).is_rational(r); + if(!ok) + throw "Bad Farkas coefficient"; + { + ast con = arg(conc(proof),i-1); + ast temp = make_real(r); // for debugging + opr o = is_not(con) ? op(arg(con,0)) : op(con); + if(is_not(con) ? (o == Leq || o == Lt) : (o == Geq || o == Gt)) + r = -r; + if(conseq_neg) + r = -r; + } + rats[i-1] = r; + } +#if 0 + if(rats[1].is_neg()){ // work around bug -- if all coeffs negative, negate them + for(unsigned i = 1; i < rats.size(); i++){ + if(!rats[i].is_neg()) + throw "Bad Farkas coefficients"; + rats[i] = -rats[i]; + } + } +#endif + abs_rat(rats); + extract_lcd(rats); +} + +void iz3mgr::get_assign_bounds_rule_coeffs(const ast &proof, std::vector& coeffs){ + std::vector rats; + get_assign_bounds_rule_coeffs(proof,rats); + coeffs.resize(rats.size()); + for(unsigned i = 0; i < rats.size(); i++){ + coeffs[i] = make_int(rats[i]); + } +} + +void iz3mgr::get_assign_bounds_rule_coeffs(const ast &proof, std::vector& rats){ + symb s = sym(proof); + int numps = s->get_num_parameters(); + rats.resize(numps-1); + rats[0] = rational(1); + ast conseq = arg(conc(proof),0); + opr conseq_o = is_not(conseq) ? op(arg(conseq,0)) : op(conseq); + bool conseq_neg = is_not(conseq) ? (conseq_o == Leq || conseq_o == Lt) : (conseq_o == Geq || conseq_o == Gt); + for(int i = 2; i < numps; i++){ + rational r; + bool ok = s->get_parameter(i).is_rational(r); + if(!ok) + throw "Bad Farkas coefficient"; + { + ast con = conc(prem(proof,i-2)); + ast temp = make_real(r); // for debugging + opr o = is_not(con) ? op(arg(con,0)) : op(con); + if(is_not(con) ? (o == Leq || o == Lt) : (o == Geq || o == Gt)) + r = -r; + if(conseq_neg) + r = -r; + } + rats[i-1] = r; + } +#if 0 + if(rats[1].is_neg()){ // work around bug -- if all coeffs negative, negate them + for(unsigned i = 1; i < rats.size(); i++){ + if(!rats[i].is_neg()) + throw "Bad Farkas coefficients"; + rats[i] = -rats[i]; + } + } +#endif + abs_rat(rats); + extract_lcd(rats); +} + + /** Set P to P + cQ, where P and Q are linear inequalities. Assumes P is 0 <= y or 0 < y. */ + +void iz3mgr::linear_comb(ast &P, const ast &c, const ast &Q){ + ast Qrhs; + bool strict = op(P) == Lt; + if(is_not(Q)){ + ast nQ = arg(Q,0); + switch(op(nQ)){ + case Gt: + Qrhs = make(Sub,arg(nQ,1),arg(nQ,0)); + break; + case Lt: + Qrhs = make(Sub,arg(nQ,0),arg(nQ,1)); + break; + case Geq: + Qrhs = make(Sub,arg(nQ,1),arg(nQ,0)); + strict = true; + break; + case Leq: + Qrhs = make(Sub,arg(nQ,0),arg(nQ,1)); + strict = true; + break; + default: + throw "not an inequality"; + } + } + else { + switch(op(Q)){ + case Leq: + Qrhs = make(Sub,arg(Q,1),arg(Q,0)); + break; + case Geq: + Qrhs = make(Sub,arg(Q,0),arg(Q,1)); + break; + case Lt: + Qrhs = make(Sub,arg(Q,1),arg(Q,0)); + strict = true; + break; + case Gt: + Qrhs = make(Sub,arg(Q,0),arg(Q,1)); + strict = true; + break; + default: + throw "not an inequality"; + } + } + Qrhs = make(Times,c,Qrhs); + if(strict) + P = make(Lt,arg(P,0),make(Plus,arg(P,1),Qrhs)); + else + P = make(Leq,arg(P,0),make(Plus,arg(P,1),Qrhs)); +} + +iz3mgr::ast iz3mgr::sum_inequalities(const std::vector &coeffs, const std::vector &ineqs){ + ast zero = make_int("0"); + ast thing = make(Leq,zero,zero); + for(unsigned i = 0; i < ineqs.size(); i++){ + linear_comb(thing,coeffs[i],ineqs[i]); + } + thing = simplify_ineq(thing); + return thing; +} + +void iz3mgr::mk_idiv(const ast& t, const rational &d, ast &whole, ast &frac){ + opr o = op(t); + if(o == Plus){ + int nargs = num_args(t); + for(int i = 0; i < nargs; i++) + mk_idiv(arg(t,i),d,whole,frac); + return; + } + else if(o == Times){ + rational coeff; + if(is_numeral(arg(t,0),coeff)){ + if(gcd(coeff,d) == d){ + whole = make(Plus,whole,make(Times,make_int(coeff/d),arg(t,1))); + return; + } + } + } + frac = make(Plus,frac,t); +} + +iz3mgr::ast iz3mgr::mk_idiv(const ast& q, const rational &d){ + ast t = z3_simplify(q); + if(d == rational(1)) + return t; + else { + ast whole = make_int("0"); + ast frac = whole; + mk_idiv(t,d,whole,frac); + return z3_simplify(make(Plus,whole,make(Idiv,z3_simplify(frac),make_int(d)))); + } +} + +iz3mgr::ast iz3mgr::mk_idiv(const ast& t, const ast &d){ + rational r; + if(is_numeral(d,r)) + return mk_idiv(t,r); + return make(Idiv,t,d); +} + + +// does variable occur in expression? +int iz3mgr::occurs_in1(stl_ext::hash_map &occurs_in_memo,ast var, ast e){ + std::pair foo(e,false); + std::pair::iterator,bool> bar = occurs_in_memo.insert(foo); + bool &res = bar.first->second; + if(bar.second){ + if(e == var) res = true; + int nargs = num_args(e); + for(int i = 0; i < nargs; i++) + res |= occurs_in1(occurs_in_memo,var,arg(e,i)); + } + return res; +} + +int iz3mgr::occurs_in(ast var, ast e){ + hash_map memo; + return occurs_in1(memo,var,e); +} + + +bool iz3mgr::solve_arith(const ast &v, const ast &x, const ast &y, ast &res){ + if(op(x) == Plus){ + int n = num_args(x); + for(int i = 0; i < n; i++){ + if(arg(x,i) == v){ + res = z3_simplify(make(Sub, y, make(Sub, x, v))); + return true; + } + } + } + return false; +} + +// find a controlling equality for a given variable v in a term +// a controlling equality is of the form v = t, which, being +// false would force the formula to have the specifid truth value +// returns t, or null if no such + +iz3mgr::ast iz3mgr::cont_eq(stl_ext::hash_set &cont_eq_memo, bool truth, ast v, ast e){ + if(is_not(e)) return cont_eq(cont_eq_memo, !truth,v,arg(e,0)); + if(cont_eq_memo.find(e) != cont_eq_memo.end()) + return ast(); + cont_eq_memo.insert(e); + if(!truth && op(e) == Equal){ + if(arg(e,0) == v) return(arg(e,1)); + if(arg(e,1) == v) return(arg(e,0)); + ast res; + if(solve_arith(v,arg(e,0),arg(e,1),res)) return res; + if(solve_arith(v,arg(e,1),arg(e,0),res)) return res; + } + if((!truth && op(e) == And) || (truth && op(e) == Or)){ + int nargs = num_args(e); + for(int i = 0; i < nargs; i++){ + ast res = cont_eq(cont_eq_memo, truth, v, arg(e,i)); + if(!res.null()) return res; + } + } + if(truth && op(e) == Implies){ + ast res = cont_eq(cont_eq_memo, !truth, v, arg(e,0)); + if(!res.null()) return res; + res = cont_eq(cont_eq_memo, truth, v, arg(e,1)); + if(!res.null()) return res; + } + return ast(); +} + + // substitute a term t for unbound occurrences of variable v in e + +iz3mgr::ast iz3mgr::subst(stl_ext::hash_map &subst_memo, ast var, ast t, ast e){ + if(e == var) return t; + std::pair foo(e,ast()); + std::pair::iterator,bool> bar = subst_memo.insert(foo); + ast &res = bar.first->second; + if(bar.second){ + int nargs = num_args(e); + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = subst(subst_memo,var,t,arg(e,i)); + opr f = op(e); + if(f == Equal && args[0] == args[1]) res = mk_true(); + else res = clone(e,args); + } + return res; +} + +iz3mgr::ast iz3mgr::subst(ast var, ast t, ast e){ + hash_map memo; + return subst(memo,var,t,e); +} + +iz3mgr::ast iz3mgr::subst(stl_ext::hash_map &subst_memo,ast e){ + std::pair foo(e,ast()); + std::pair::iterator,bool> bar = subst_memo.insert(foo); + ast &res = bar.first->second; + if(bar.second){ + int nargs = num_args(e); + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = subst(subst_memo,arg(e,i)); + opr f = op(e); + if(f == Equal && args[0] == args[1]) res = mk_true(); + else res = clone(e,args); + } + return res; +} + + // apply a quantifier to a formula, with some optimizations + // 1) bound variable does not occur -> no quantifier + // 2) bound variable must be equal to some term -> substitute + +iz3mgr::ast iz3mgr::apply_quant(opr quantifier, ast var, ast e){ + if((quantifier == Forall && op(e) == And) + || (quantifier == Exists && op(e) == Or)){ + int n = num_args(e); + std::vector args(n); + for(int i = 0; i < n; i++) + args[i] = apply_quant(quantifier,var,arg(e,i)); + return make(op(e),args); + } + if(!occurs_in(var,e))return e; + hash_set cont_eq_memo; + ast cterm = cont_eq(cont_eq_memo, quantifier == Forall, var, e); + if(!cterm.null()){ + return subst(var,cterm,e); + } + std::vector bvs; bvs.push_back(var); + return make_quant(quantifier,bvs,e); +} diff --git a/src/interp/iz3mgr.h b/src/interp/iz3mgr.h new file mode 100644 index 000000000..645c72ccb --- /dev/null +++ b/src/interp/iz3mgr.h @@ -0,0 +1,718 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3mgr.h + +Abstract: + + A wrapper around an ast manager, providing convenience methods. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef IZ3MGR_H +#define IZ3MGR_H + + +#include +#include "iz3hash.h" + +#include"well_sorted.h" +#include"arith_decl_plugin.h" +#include"bv_decl_plugin.h" +#include"datatype_decl_plugin.h" +#include"array_decl_plugin.h" +#include"ast_translation.h" +#include"ast_pp.h" +#include"ast_ll_pp.h" +#include"ast_smt_pp.h" +#include"ast_smt2_pp.h" +#include"th_rewriter.h" +#include"var_subst.h" +#include"expr_substitution.h" +#include"pp.h" +#include"scoped_ctrl_c.h" +#include"cancel_eh.h" +#include"scoped_timer.h" +//# include"pp_params.hpp" + +/* A wrapper around an ast manager, providing convenience methods. */ + +/** Shorthands for some built-in operators. */ + + + +// rename this to keep it accessible, as we use ast for something else +typedef ast raw_ast; + +/** Wrapper around an ast pointer */ +class ast_i { + protected: + raw_ast *_ast; + public: + raw_ast * const &raw() const {return _ast;} + ast_i(raw_ast *a){_ast = a;} + + ast_i(){_ast = 0;} + bool eq(const ast_i &other) const { + return _ast == other._ast; + } + bool lt(const ast_i &other) const { + return _ast->get_id() < other._ast->get_id(); + } + friend bool operator==(const ast_i &x, const ast_i&y){ + return x.eq(y); + } + friend bool operator!=(const ast_i &x, const ast_i&y){ + return !x.eq(y); + } + friend bool operator<(const ast_i &x, const ast_i&y){ + return x.lt(y); + } + size_t hash() const {return _ast->get_id();} + bool null() const {return !_ast;} +}; + +/** Reference counting verison of above */ +class ast_r : public ast_i { + ast_manager *_m; + public: + ast_r(ast_manager *m, raw_ast *a) : ast_i(a) { + _m = m; + m->inc_ref(a); + } + + ast_r() {_m = 0;} + + ast_r(const ast_r &other) : ast_i(other) { + _m = other._m; + _m->inc_ref(_ast); + } + + ast_r &operator=(const ast_r &other) { + if(_ast) + _m->dec_ref(_ast); + _ast = other._ast; + _m = other._m; + _m->inc_ref(_ast); + return *this; + } + + ~ast_r(){ + if(_ast) + _m->dec_ref(_ast); + } + + ast_manager *mgr() const {return _m;} + +}; + +// to make ast_r hashable +namespace hash_space { + template <> + class hash { + public: + size_t operator()(const ast_r &s) const { + return s.raw()->get_id(); + } + }; +} + +// to make ast_r hashable in windows +#ifdef WIN32 +template <> inline +size_t stdext::hash_value(const ast_r& s) +{ + return s.raw()->get_id(); +} +#endif + +// to make ast_r usable in ordered collections +namespace std { + template <> + class less { + public: + bool operator()(const ast_r &s, const ast_r &t) const { + // return s.raw() < t.raw(); + return s.raw()->get_id() < t.raw()->get_id(); + } + }; +} + + +/** Wrapper around an AST manager, providing convenience methods. */ + +class iz3mgr { + + public: + typedef ast_r ast; + // typedef decl_kind opr; + typedef func_decl *symb; + typedef sort *type; + typedef ast_r z3pf; + typedef decl_kind pfrule; + + enum opr { + True, + False, + And, + Or, + Not, + Iff, + Ite, + Equal, + Implies, + Distinct, + Xor, + Oeq, + Interp, + Leq, + Geq, + Lt, + Gt, + Plus, + Sub, + Uminus, + Times, + Div, + Idiv, + Rem, + Mod, + Power, + ToReal, + ToInt, + IsInt, + Select, + Store, + ConstArray, + ArrayDefault, + ArrayMap, + SetUnion, + SetIntersect, + SetDifference, + SetComplement, + SetSubSet, + AsArray, + Numeral, + Forall, + Exists, + Variable, + Uninterpreted, + Other + }; + + opr op(const ast &t); + + unsigned ast_id(const ast &x) + { + return to_expr(x.raw())->get_id(); + } + + /** Overloads for constructing ast. */ + + ast make_var(const std::string &name, type ty); + ast make(opr op, const std::vector &args); + ast make(opr op); + ast make(opr op, const ast &arg0); + ast make(opr op, const ast &arg0, const ast &arg1); + ast make(opr op, const ast &arg0, const ast &arg1, const ast &arg2); + ast make(symb sym, const std::vector &args); + ast make(symb sym); + ast make(symb sym, const ast &arg0); + ast make(symb sym, const ast &arg0, const ast &arg1); + ast make(symb sym, const ast &arg0, const ast &arg1, const ast &arg2); + ast make_quant(opr op, const std::vector &bvs, ast &body); + ast clone(const ast &t, const std::vector &args); + + ast_manager &m() {return m_manager;} + + ast cook(raw_ast *a) {return ast(&m_manager,a);} + + std::vector cook(ptr_vector v) { + std::vector _v(v.size()); + for(unsigned i = 0; i < v.size(); i++) + _v[i] = cook(v[i]); + return _v; + } + + raw_ast *uncook(const ast &a) { + m_manager.inc_ref(a.raw()); + return a.raw(); + } + + /** Methods for destructing ast. */ + + + int num_args(ast t){ + ast_kind dk = t.raw()->get_kind(); + switch(dk){ + case AST_APP: + return to_app(t.raw())->get_num_args(); + case AST_QUANTIFIER: + return 1; + case AST_VAR: + return 0; + default:; + } + assert(0); + } + + ast arg(const ast &t, int i){ + ast_kind dk = t.raw()->get_kind(); + switch(dk){ + case AST_APP: + return cook(to_app(t.raw())->get_arg(i)); + case AST_QUANTIFIER: + return cook(to_quantifier(t.raw())->get_expr()); + default:; + } + assert(0); + return ast(); + } + + void get_args(const ast &t, std::vector &res){ + res.resize(num_args(t)); + for(unsigned i = 0; i < res.size(); i++) + res[i] = arg(t,i); + } + + symb sym(ast t){ + return to_app(t.raw())->get_decl(); + } + + std::string string_of_symbol(symb s){ + symbol _s = s->get_name(); + if (_s.is_numerical()) { + std::ostringstream buffer; + buffer << _s.get_num(); + return buffer.str(); + } + else { + return _s.bare_str(); + } + } + + type get_type(ast t){ + return m().get_sort(to_expr(t.raw())); + } + + std::string string_of_numeral(const ast& t){ + rational r; + expr* e = to_expr(t.raw()); + assert(e); + if (m_arith_util.is_numeral(e, r)) + return r.to_string(); + assert(0); + return "NaN"; + } + + bool is_numeral(const ast& t, rational &r){ + expr* e = to_expr(t.raw()); + assert(e); + return m_arith_util.is_numeral(e, r); + } + + rational get_coeff(const ast& t){ + rational res; + if(op(t) == Times && is_numeral(arg(t,0),res)) + return res; + return rational(1); + } + + ast get_linear_var(const ast& t){ + rational res; + if(op(t) == Times && is_numeral(arg(t,0),res)) + return arg(t,1); + return t; + } + + int get_quantifier_num_bound(const ast &t) { + return to_quantifier(t.raw())->get_num_decls(); + } + + std::string get_quantifier_bound_name(const ast &t, unsigned i) { + return to_quantifier(t.raw())->get_decl_names()[i].bare_str(); + } + + type get_quantifier_bound_type(const ast &t, unsigned i) { + return to_quantifier(t.raw())->get_decl_sort(i); + } + + ast get_quantifier_body(const ast &t) { + return cook(to_quantifier(t.raw())->get_expr()); + } + + unsigned get_variable_index_value(const ast &t) { + var* va = to_var(t.raw()); + return va->get_idx(); + } + + bool is_bool_type(type t){ + family_id fid = to_sort(t)->get_family_id(); + decl_kind k = to_sort(t)->get_decl_kind(); + return fid == m().get_basic_family_id() && k == BOOL_SORT; + } + + bool is_array_type(type t){ + family_id fid = to_sort(t)->get_family_id(); + decl_kind k = to_sort(t)->get_decl_kind(); + return fid == m_array_fid && k == ARRAY_SORT; + } + + type get_range_type(symb s){ + return to_func_decl(s)->get_range(); + } + + int get_num_parameters(const symb &s){ + return to_func_decl(s)->get_num_parameters(); + } + + ast get_ast_parameter(const symb &s, int idx){ + return cook(to_func_decl(s)->get_parameters()[idx].get_ast()); + } + + enum lemma_theory {ArithTheory,ArrayTheory,UnknownTheory}; + + lemma_theory get_theory_lemma_theory(const ast &proof){ + symb s = sym(proof); + ::symbol p0; + bool ok = s->get_parameter(0).is_symbol(p0); + if(!ok) return UnknownTheory; + std::string foo(p0.bare_str()); + if(foo == "arith") + return ArithTheory; + if(foo == "array") + return ArrayTheory; + return UnknownTheory; + } + + enum lemma_kind {FarkasKind,Leq2EqKind,Eq2LeqKind,GCDTestKind,AssignBoundsKind,EqPropagateKind,UnknownKind}; + + lemma_kind get_theory_lemma_kind(const ast &proof){ + symb s = sym(proof); + ::symbol p0; + bool ok = s->get_parameter(1).is_symbol(p0); + if(!ok) return UnknownKind; + std::string foo(p0.bare_str()); + if(foo == "farkas") + return FarkasKind; + if(foo == "triangle-eq") + return is_not(arg(conc(proof),0)) ? Eq2LeqKind : Leq2EqKind; + if(foo == "gcd-test") + return GCDTestKind; + if(foo == "assign-bounds") + return AssignBoundsKind; + if(foo == "eq-propagate") + return EqPropagateKind; + return UnknownKind; + } + + void get_farkas_coeffs(const ast &proof, std::vector& coeffs); + + void get_farkas_coeffs(const ast &proof, std::vector& rats); + + void get_assign_bounds_coeffs(const ast &proof, std::vector& rats); + + void get_assign_bounds_coeffs(const ast &proof, std::vector& rats); + + void get_assign_bounds_rule_coeffs(const ast &proof, std::vector& rats); + + void get_assign_bounds_rule_coeffs(const ast &proof, std::vector& rats); + + bool is_farkas_coefficient_negative(const ast &proof, int n); + + bool is_true(ast t){ + return op(t) == True; + } + + bool is_false(ast t){ + return op(t) == False; + } + + bool is_iff(ast t){ + return op(t) == Iff; + } + + bool is_or(ast t){ + return op(t) == Or; + } + + bool is_not(ast t){ + return op(t) == Not; + } + + /** Simplify an expression using z3 simplifier */ + + ast z3_simplify(const ast& e); + + /** Simplify, sorting sums */ + ast z3_really_simplify(const ast &e); + + + // Some constructors that simplify things + + ast mk_not(ast x){ + opr o = op(x); + if(o == True) return make(False); + if(o == False) return make(True); + if(o == Not) return arg(x,0); + return make(Not,x); + } + + ast mk_and(ast x, ast y){ + opr ox = op(x); + opr oy = op(y); + if(ox == True) return y; + if(oy == True) return x; + if(ox == False) return x; + if(oy == False) return y; + if(x == y) return x; + return make(And,x,y); + } + + ast mk_or(ast x, ast y){ + opr ox = op(x); + opr oy = op(y); + if(ox == False) return y; + if(oy == False) return x; + if(ox == True) return x; + if(oy == True) return y; + if(x == y) return x; + return make(Or,x,y); + } + + ast mk_implies(ast x, ast y){ + opr ox = op(x); + opr oy = op(y); + if(ox == True) return y; + if(oy == False) return mk_not(x); + if(ox == False) return mk_true(); + if(oy == True) return y; + if(x == y) return mk_true(); + return make(Implies,x,y); + } + + ast mk_or(const std::vector &x){ + ast res = mk_false(); + for(unsigned i = 0; i < x.size(); i++) + res = mk_or(res,x[i]); + return res; + } + + ast mk_and(const std::vector &x){ + std::vector conjs; + for(unsigned i = 0; i < x.size(); i++){ + const ast &e = x[i]; + opr o = op(e); + if(o == False) + return mk_false(); + if(o != True) + conjs.push_back(e); + } + if(conjs.size() == 0) + return mk_true(); + if(conjs.size() == 1) + return conjs[0]; + return make(And,conjs); + } + + ast mk_equal(ast x, ast y){ + if(x == y) return make(True); + opr ox = op(x); + opr oy = op(y); + if(ox == True) return y; + if(oy == True) return x; + if(ox == False) return mk_not(y); + if(oy == False) return mk_not(x); + if(ox == False && oy == True) return make(False); + if(oy == False && ox == True) return make(False); + return make(Equal,x,y); + } + + ast z3_ite(ast x, ast y, ast z){ + opr ox = op(x); + opr oy = op(y); + opr oz = op(z); + if(ox == True) return y; + if(ox == False) return z; + if(y == z) return y; + if(oy == True && oz == False) return x; + if(oz == True && oy == False) return mk_not(x); + return make(Ite,x,y,z); + } + + ast make_int(const std::string &s) { + sort *r = m().mk_sort(m_arith_fid, INT_SORT); + return cook(m_arith_util.mk_numeral(rational(s.c_str()),r)); + } + + ast make_int(const rational &s) { + sort *r = m().mk_sort(m_arith_fid, INT_SORT); + return cook(m_arith_util.mk_numeral(s,r)); + } + + ast make_real(const std::string &s) { + sort *r = m().mk_sort(m_arith_fid, REAL_SORT); + return cook(m_arith_util.mk_numeral(rational(s.c_str()),r)); + } + + ast make_real(const rational &s) { + sort *r = m().mk_sort(m_arith_fid, REAL_SORT); + return cook(m_arith_util.mk_numeral(s,r)); + } + + ast mk_false() { return make(False); } + + ast mk_true() { return make(True); } + + ast mk_fresh_constant(char const * prefix, type s){ + return cook(m().mk_fresh_const(prefix, s)); + } + + type bool_type() { + ::sort *s = m().mk_sort(m_basic_fid, BOOL_SORT); + return s; + } + + type int_type() { + ::sort *s = m().mk_sort(m_arith_fid, INT_SORT); + return s; + } + + type real_type() { + ::sort *s = m().mk_sort(m_arith_fid, REAL_SORT); + return s; + } + + type array_type(type d, type r) { + parameter params[2] = { parameter(d), parameter(to_sort(r)) }; + ::sort * s = m().mk_sort(m_array_fid, ARRAY_SORT, 2, params); + return s; + } + + symb function(const std::string &str_name, unsigned arity, type *domain, type range) { + ::symbol name = ::symbol(str_name.c_str()); + std::vector< ::sort *> sv(arity); + for(unsigned i = 0; i < arity; i++) + sv[i] = domain[i]; + ::func_decl* d = m().mk_func_decl(name,arity,&sv[0],range); + return d; + } + + void linear_comb(ast &P, const ast &c, const ast &Q); + + ast sum_inequalities(const std::vector &coeffs, const std::vector &ineqs); + + ast simplify_ineq(const ast &ineq){ + ast res = make(op(ineq),arg(ineq,0),z3_simplify(arg(ineq,1))); + return res; + } + + void mk_idiv(const ast& t, const rational &d, ast &whole, ast &frac); + + ast mk_idiv(const ast& t, const rational &d); + + ast mk_idiv(const ast& t, const ast &d); + + /** methods for destructing proof terms */ + + pfrule pr(const z3pf &t); + + int num_prems(const z3pf &t){return to_app(t.raw())->get_num_args()-1;} + + z3pf prem(const z3pf &t, int n){return arg(t,n);} + + z3pf conc(const z3pf &t){return arg(t,num_prems(t));} + + + /* quantifier handling */ + + // substitute a term t for unbound occurrences of variable v in e + + ast subst(ast var, ast t, ast e); + + // apply a substitution defined by a map + ast subst(stl_ext::hash_map &map, ast e); + + // apply a quantifier to a formula, with some optimizations + // 1) bound variable does not occur -> no quantifier + // 2) bound variable must be equal to some term -> substitute + + ast apply_quant(opr quantifier, ast var, ast e); + + + /** For debugging */ + void show(ast); + + void show_symb(symb s); + + /** Constructor */ + + void print_lit(ast lit); + + void print_expr(std::ostream &s, const ast &e); + + void print_clause(std::ostream &s, std::vector &cls); + + void print_sat_problem(std::ostream &out, const ast &t); + + void show_clause(std::vector &cls); + + static void pretty_print(std::ostream &f, const std::string &s); + + iz3mgr(ast_manager &_m_manager) + : m_manager(_m_manager), + m_arith_util(_m_manager) + { + m_basic_fid = m().get_basic_family_id(); + m_arith_fid = m().mk_family_id("arith"); + m_bv_fid = m().mk_family_id("bv"); + m_array_fid = m().mk_family_id("array"); + m_dt_fid = m().mk_family_id("datatype"); + m_datalog_fid = m().mk_family_id("datalog_relation"); + } + + iz3mgr(const iz3mgr& other) + : m_manager(other.m_manager), + m_arith_util(other.m_manager) + { + m_basic_fid = m().get_basic_family_id(); + m_arith_fid = m().mk_family_id("arith"); + m_bv_fid = m().mk_family_id("bv"); + m_array_fid = m().mk_family_id("array"); + m_dt_fid = m().mk_family_id("datatype"); + m_datalog_fid = m().mk_family_id("datalog_relation"); + } + + protected: + ast_manager &m_manager; + int occurs_in(ast var, ast e); + + private: + ast mki(family_id fid, decl_kind sk, int n, raw_ast **args); + ast make(opr op, int n, raw_ast **args); + ast make(symb sym, int n, raw_ast **args); + int occurs_in1(stl_ext::hash_map &occurs_in_memo, ast var, ast e); + bool solve_arith(const ast &v, const ast &x, const ast &y, ast &res); + ast cont_eq(stl_ext::hash_set &cont_eq_memo, bool truth, ast v, ast e); + ast subst(stl_ext::hash_map &subst_memo, ast var, ast t, ast e); + + + family_id m_basic_fid; + family_id m_array_fid; + family_id m_arith_fid; + family_id m_bv_fid; + family_id m_dt_fid; + family_id m_datalog_fid; + arith_util m_arith_util; +}; + +#endif + diff --git a/src/interp/iz3params.pyg b/src/interp/iz3params.pyg new file mode 100644 index 000000000..5b574c859 --- /dev/null +++ b/src/interp/iz3params.pyg @@ -0,0 +1,5 @@ +def_module_params('interp', + description='interpolation parameters', + export=True, + params=(('profile', BOOL, False, '(INTERP) profile interpolation'), + )) diff --git a/src/interp/iz3pp.cpp b/src/interp/iz3pp.cpp new file mode 100644 index 000000000..df6fcaf53 --- /dev/null +++ b/src/interp/iz3pp.cpp @@ -0,0 +1,189 @@ +/*++ +Copyright (c) 2013 Microsoft Corporation + +Module Name: + + iz3pp.cpp + +Abstract: + + Pretty-print interpolation problems + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +/* Copyright 2011 Microsoft Research. */ +#include +#include +#include +#include +#include +#include +#include + +#include "iz3mgr.h" +#include "iz3pp.h" +#include "func_decl_dependencies.h" +#include"for_each_expr.h" +#include"ast_smt_pp.h" +#include"ast_smt2_pp.h" +#include"expr_functors.h" +#include"expr_abstract.h" + + +#ifndef WIN32 +using namespace stl_ext; +#endif + +#ifndef WIN32 +// We promise not to use this for hash_map with range destructor +namespace stl_ext { + template <> + class hash { + public: + size_t operator()(const expr *p) const { + return (size_t) p; + } + }; +} +#endif + + +// TBD: algebraic data-types declarations will not be printed. +class free_func_visitor { + ast_manager& m; + func_decl_set m_funcs; + obj_hashtable m_sorts; +public: + free_func_visitor(ast_manager& m): m(m) {} + void operator()(var * n) { } + void operator()(app * n) { + m_funcs.insert(n->get_decl()); + sort* s = m.get_sort(n); + if (s->get_family_id() == null_family_id) { + m_sorts.insert(s); + } + } + void operator()(quantifier * n) { } + func_decl_set& funcs() { return m_funcs; } + obj_hashtable& sorts() { return m_sorts; } +}; + +class iz3pp_helper : public iz3mgr { +public: + + void print_tree(const ast &tree, hash_map &cnames, std::ostream &out){ + hash_map::iterator foo = cnames.find(to_expr(tree.raw())); + if(foo != cnames.end()){ + symbol nm = foo->second; + if (is_smt2_quoted_symbol(nm)) { + out << mk_smt2_quoted_symbol(nm); + } + else { + out << nm; + } + } + else if(op(tree) == And){ + out << "(and"; + int nargs = num_args(tree); + for(int i = 0; i < nargs; i++){ + out << " "; + print_tree(arg(tree,i), cnames, out); + } + out << ")"; + } + else if(op(tree) == Interp){ + out << "(interp "; + print_tree(arg(tree,0), cnames, out); + out << ")"; + } + else throw iz3pp_bad_tree(); + } + + + iz3pp_helper(ast_manager &_m_manager) + : iz3mgr(_m_manager) {} +}; + +void iz3pp(ast_manager &m, + const ptr_vector &cnsts_vec, + expr *tree, + std::ostream& out) { + + unsigned sz = cnsts_vec.size(); + expr* const* cnsts = &cnsts_vec[0]; + + out << "(set-option :produce-interpolants true)\n"; + + free_func_visitor visitor(m); + expr_mark visited; + bool print_low_level = true; // m_params.print_low_level_smt2(); + +#define PP(_e_) if (print_low_level) out << mk_smt_pp(_e_, m); else ast_smt2_pp(out, _e_, env); + + smt2_pp_environment_dbg env(m); + + for (unsigned i = 0; i < sz; ++i) { + expr* e = cnsts[i]; + for_each_expr(visitor, visited, e); + } + + // name all the constraints + hash_map cnames; + int ctr = 1; + for(unsigned i = 0; i < sz; i++){ + symbol nm; + std::ostringstream s; + s << "f!" << (ctr++); + cnames[cnsts[i]] = symbol(s.str().c_str()); + } + + func_decl_set &funcs = visitor.funcs(); + func_decl_set::iterator it = funcs.begin(), end = funcs.end(); + + obj_hashtable& sorts = visitor.sorts(); + obj_hashtable::iterator sit = sorts.begin(), send = sorts.end(); + + + + for (; sit != send; ++sit) { + PP(*sit); + } + + for (; it != end; ++it) { + func_decl* f = *it; + if(f->get_family_id() == null_family_id){ + PP(f); + out << "\n"; + } + } + + for (unsigned i = 0; i < sz; ++i) { + out << "(assert "; + expr* r = cnsts[i]; + symbol nm = cnames[r]; + out << "(! "; + PP(r); + out << " :named "; + if (is_smt2_quoted_symbol(nm)) { + out << mk_smt2_quoted_symbol(nm); + } + else { + out << nm; + } + out << ")"; + out << ")\n"; + } + out << "(check-sat)\n"; + out << "(get-interpolant "; + iz3pp_helper pp(m); + pp.print_tree(pp.cook(tree),cnames,out); + out << ")\n"; +} + + diff --git a/src/interp/iz3pp.h b/src/interp/iz3pp.h new file mode 100644 index 000000000..3ea5925e8 --- /dev/null +++ b/src/interp/iz3pp.h @@ -0,0 +1,35 @@ +/*++ +Copyright (c) 2013 Microsoft Corporation + +Module Name: + + iz3pp.cpp + +Abstract: + + Pretty-print interpolation problems + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef IZ3_PP_H +#define IZ3_PP_H + +#include "iz3mgr.h" + +/** Exception thrown in case of mal-formed tree interpoloation + specification */ + +struct iz3pp_bad_tree { +}; + +void iz3pp(ast_manager &m, + const ptr_vector &cnsts_vec, + expr *tree, + std::ostream& out); +#endif diff --git a/src/interp/iz3profiling.cpp b/src/interp/iz3profiling.cpp new file mode 100755 index 000000000..262254271 --- /dev/null +++ b/src/interp/iz3profiling.cpp @@ -0,0 +1,146 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3profiling.h + +Abstract: + + Some routines for measuring performance. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#include "iz3profiling.h" + +#include +#include +#include +#include +#include +#include "stopwatch.h" + + +// FIXME fill in these stubs + +#define clock_t double + +static double current_time() +{ + static stopwatch sw; + static bool started = false; + if(!started){ + sw.start(); + started = true; + } + return sw.get_current_seconds(); +} + +static void output_time(std::ostream &os, clock_t time){ + os << time; +} + + +namespace profiling { + + void show_time(){ + output_time(std::cout,current_time()); + std::cout << "\n"; + } + + typedef std::map nmap; + + struct node { + std::string name; + clock_t time; + clock_t start_time; + nmap sub; + struct node *parent; + + node(); + } top; + + node::node(){ + time = 0; + parent = 0; + } + + struct node *current; + + struct init { + init(){ + top.name = "TOTAL"; + current = ⊤ + } + } initializer; + + struct time_entry { + clock_t t; + time_entry(){t = 0;}; + void add(clock_t incr){t += incr;} + }; + + struct ltstr + { + bool operator()(const char* s1, const char* s2) const + { + return strcmp(s1, s2) < 0; + } + }; + + typedef std::map tmap; + + static std::ostream *pfs; + + void print_node(node &top, int indent, tmap &totals){ + for(int i = 0; i < indent; i++) (*pfs) << " "; + (*pfs) << top.name; + int dots = 70 - 2 * indent - top.name.size(); + for(int i = 0; i second,indent+1,totals); + } + + void print(std::ostream &os) { + pfs = &os; + top.time = 0; + for(nmap::iterator it = top.sub.begin(); it != top.sub.end(); it++) + top.time += it->second.time; + tmap totals; + print_node(top,0,totals); + (*pfs) << "TOTALS:" << std::endl; + for(tmap::iterator it = totals.begin(); it != totals.end(); it++){ + (*pfs) << (it->first) << " "; + output_time(*pfs, it->second.t); + (*pfs) << std::endl; + } + } + + void timer_start(const char *name){ + node &child = current->sub[name]; + if(child.name.empty()){ // a new node + child.parent = current; + child.name = name; + } + child.start_time = current_time(); + current = &child; + } + + void timer_stop(const char *name){ + if(current->name != name || !current->parent){ + std::cerr << "imbalanced timer_start and timer_stop"; + exit(1); + } + current->time += (current_time() - current->start_time); + current = current->parent; + } +} diff --git a/src/interp/iz3profiling.h b/src/interp/iz3profiling.h new file mode 100755 index 000000000..cdb7066e3 --- /dev/null +++ b/src/interp/iz3profiling.h @@ -0,0 +1,37 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3profiling.h + +Abstract: + + Some routines for measuring performance. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef IZ3PROFILING_H +#define IZ3PROFILING_H + +#include + +namespace profiling { + /** Start a timer with given name */ + void timer_start(const char *); + /** Stop a timer with given name */ + void timer_stop(const char *); + /** Print out timings */ + void print(std::ostream &s); + /** Show the current time. */ + void show_time(); +} + +#endif + diff --git a/src/interp/iz3proof.cpp b/src/interp/iz3proof.cpp new file mode 100755 index 000000000..968a4b25a --- /dev/null +++ b/src/interp/iz3proof.cpp @@ -0,0 +1,622 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3proof.cpp + +Abstract: + + This class defines a simple interpolating proof system. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + + + +#include "iz3proof.h" +#include "iz3profiling.h" + +#include +#include +#include +#include + +// #define FACTOR_INTERPS +// #define CHECK_PROOFS + + +void iz3proof::resolve(ast pivot, std::vector &cls1, const std::vector &cls2){ +#ifdef CHECK_PROOFS + std::vector orig_cls1 = cls1; +#endif + ast neg_pivot = pv->mk_not(pivot); + bool found_pivot1 = false, found_pivot2 = false; + for(unsigned i = 0; i < cls1.size(); i++){ + if(cls1[i] == neg_pivot){ + cls1[i] = cls1.back(); + cls1.pop_back(); + found_pivot1 = true; + break; + } + } + { + std::set memo; + memo.insert(cls1.begin(),cls1.end()); + for(unsigned j = 0; j < cls2.size(); j++){ + if(cls2[j] == pivot) + found_pivot2 = true; + else + if(memo.find(cls2[j]) == memo.end()) + cls1.push_back(cls2[j]); + } + } + if(found_pivot1 && found_pivot2) + return; + +#ifdef CHECK_PROOFS + std::cerr << "resolution anomaly: " << nodes.size()-1 << "\n"; +#if 0 + std::cerr << "pivot: "; {pv->print_lit(pivot); std::cout << "\n";} + std::cerr << "left clause:\n"; + for(unsigned i = 0; i < orig_cls1.size(); i++) + {pv->print_lit(orig_cls1[i]); std::cout << "\n";} + std::cerr << "right clause:\n"; + for(unsigned i = 0; i < cls2.size(); i++) + {pv->print_lit(cls2[i]); std::cout << "\n";} + throw proof_error(); +#endif +#endif +} + +iz3proof::node iz3proof::make_resolution(ast pivot, node premise1, node premise2) +{ + if(nodes[premise1].rl == Hypothesis) return premise2; // resolve with hyp is noop + if(nodes[premise2].rl == Hypothesis) return premise1; + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Resolution; + n.aux = pivot; + n.premises.resize(2); + n.premises[0] = (premise1); + n.premises[1] = (premise2); +#ifdef CHECK_PROOFS + n.conclusion = nodes[premise1].conclusion; + resolve(pivot,n.conclusion,nodes[premise2].conclusion); + n.frame = 1; +#else + n.frame = 0; // compute conclusion lazily +#endif + return res; +} + +iz3proof::node iz3proof::resolve_lemmas(ast pivot, node premise1, node premise2) +{ + std::vector lits(nodes[premise1].conclusion), itp; // no interpolant + resolve(pivot,lits,nodes[premise2].conclusion); + return make_lemma(lits,itp); +} + + +iz3proof::node iz3proof::make_assumption(int frame, const std::vector &assumption){ +#if 0 + std::cout << "assumption: \n"; + for(unsigned i = 0; i < assumption.size(); i++) + pv->show(assumption[i]); + std::cout << "\n"; +#endif + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Assumption; + n.conclusion.resize(1); + n.conclusion = assumption; + n.frame = frame; + return res; +} + +iz3proof::node iz3proof::make_hypothesis(ast hypothesis){ + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Hypothesis; + n.conclusion.resize(2); + n.conclusion[0] = hypothesis; + n.conclusion[1] = pv->mk_not(hypothesis); + return res; +} + +iz3proof::node iz3proof::make_theory(const std::vector &conclusion, std::vector premises){ + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Theory; + n.conclusion = conclusion; + n.premises = premises; + return res; +} + +iz3proof::node iz3proof::make_axiom(const std::vector &conclusion){ + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Axiom; + n.conclusion = conclusion; + return res; +} + +iz3proof::node iz3proof::make_contra(node prem, const std::vector &conclusion){ + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Contra; + n.conclusion = conclusion; +#ifdef CHECK_PROOFS + //if(!(conclusion == nodes[prem].conclusion)){ + //std::cerr << "internal error: proof error\n"; + //assert(0 && "proof error"); + //} +#endif + n.premises.push_back(prem); + return res; +} + + +iz3proof::node iz3proof::make_lemma(const std::vector &conclusion, const std::vector &interpolation){ + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Lemma; + n.conclusion = conclusion; + n.frame = interps.size(); + interps.push_back(interpolation); + return res; +} + +/** Make a Reflexivity node. This rule produces |- x = x */ + +iz3proof::node iz3proof::make_reflexivity(ast con){ + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Reflexivity; + n.conclusion.push_back(con); + return res; +} + +/** Make a Symmetry node. This takes a derivation of |- x = y and + produces | y = x */ + +iz3proof::node iz3proof::make_symmetry(ast con, node prem){ + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Reflexivity; + n.conclusion.push_back(con); + n.premises.push_back(prem); + return res; +} + +/** Make a transitivity node. This takes derivations of |- x = y + and |- y = z produces | x = z */ + +iz3proof::node iz3proof::make_transitivity(ast con, node prem1, node prem2){ + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Transitivity; + n.conclusion.push_back(con); + n.premises.push_back(prem1); + n.premises.push_back(prem2); + return res; +} + + +/** Make a congruence node. This takes derivations of |- x_i = y_i + and produces |- f(x_1,...,x_n) = f(y_1,...,y_n) */ + +iz3proof::node iz3proof::make_congruence(ast con, const std::vector &prems){ + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = Congruence; + n.conclusion.push_back(con); + n.premises = prems; + return res; +} + + +/** Make an equality contradicition node. This takes |- x = y + and |- !(x = y) and produces false. */ + +iz3proof::node iz3proof::make_eqcontra(node prem1, node prem2){ + node res = make_node(); + node_struct &n = nodes[res]; + n.rl = EqContra; + n.premises.push_back(prem1); + n.premises.push_back(prem2); + return res; +} + +iz3proof::node iz3proof::copy_rec(stl_ext::hash_map &memo, iz3proof &src, node n){ + stl_ext::hash_map::iterator it = memo.find(n); + if(it != memo.end()) return (*it).second; + node_struct &ns = src.nodes[n]; + std::vector prems(ns.premises.size()); + for(unsigned i = 0; i < prems.size(); i++) + prems[i] = copy_rec(memo,src,ns.premises[i]); + nodes.push_back(ns); + nodes.back().premises.swap(prems); + if(ns.rl == Lemma){ + nodes.back().frame = interps.size(); + interps.push_back(src.interps[ns.frame]); + } + int res = nodes.size()-1; + memo[n] = res; + return res; +} + +iz3proof::node iz3proof::copy(iz3proof &src, node n){ + stl_ext::hash_map memo; + return copy_rec(memo, src, n); +} + +bool iz3proof::pred_in_A(ast id){ + return weak + ? pv->ranges_intersect(pv->ast_range(id),rng) : + pv->range_contained(pv->ast_range(id),rng); +} + +bool iz3proof::term_in_B(ast id){ + prover::range r = pv->ast_scope(id); + if(weak) { + if(pv->range_min(r) == SHRT_MIN) + return !pv->range_contained(r,rng); + else + return !pv->ranges_intersect(r,rng); + } + else + return !pv->range_contained(r,rng); +} + +bool iz3proof::frame_in_A(int frame){ + return pv->in_range(frame,rng); +} + +bool iz3proof::lit_in_B(ast lit){ + return + b_lits.find(lit) != b_lits.end() + || b_lits.find(pv->mk_not(lit)) != b_lits.end(); +} + +iz3proof::ast iz3proof::my_or(ast x, ast y){ + return pv->mk_not(pv->mk_and(pv->mk_not(x),pv->mk_not(y))); +} + +iz3proof::ast iz3proof::get_A_lits(std::vector &cls){ + ast foo = pv->mk_false(); + for(unsigned i = 0; i < cls.size(); i++){ + ast lit = cls[i]; + if(b_lits.find(pv->mk_not(lit)) == b_lits.end()){ + if(pv->range_max(pv->ast_scope(lit)) == pv->range_min(pv->ast_scope(lit))){ + std::cout << "bad lit: " << pv->range_max(rng) << " : " << pv->range_max(pv->ast_scope(lit)) << " : " << (pv->ast_id(lit)) << " : "; + pv->show(lit); + } + foo = my_or(foo,lit); + } + } + return foo; +} + +iz3proof::ast iz3proof::get_B_lits(std::vector &cls){ + ast foo = pv->mk_false(); + for(unsigned i = 0; i < cls.size(); i++){ + ast lit = cls[i]; + if(b_lits.find(pv->mk_not(lit)) != b_lits.end()) + foo = my_or(foo,lit); + } + return foo; +} + +void iz3proof::set_of_B_lits(std::vector &cls, std::set &res){ + for(unsigned i = 0; i < cls.size(); i++){ + ast lit = cls[i]; + if(b_lits.find(pv->mk_not(lit)) != b_lits.end()) + res.insert(lit); + } +} + +void iz3proof::set_of_A_lits(std::vector &cls, std::set &res){ + for(unsigned i = 0; i < cls.size(); i++){ + ast lit = cls[i]; + if(b_lits.find(pv->mk_not(lit)) == b_lits.end()) + res.insert(lit); + } +} + +void iz3proof::find_B_lits(){ + b_lits.clear(); + for(unsigned i = 0; i < nodes.size(); i++){ + node_struct &n = nodes[i]; + std::vector &cls = n.conclusion; + if(n.rl == Assumption){ + if(weak) goto lemma; + if(!frame_in_A(n.frame)) + for(unsigned j = 0; j < cls.size(); j++) + b_lits.insert(cls[j]); + } + else if(n.rl == Lemma) { + lemma: + for(unsigned j = 0; j < cls.size(); j++) + if(term_in_B(cls[j])) + b_lits.insert(cls[j]); + } + } +} + +iz3proof::ast iz3proof::disj_of_set(std::set &s){ + ast res = pv->mk_false(); + for(std::set::iterator it = s.begin(), en = s.end(); it != en; ++it) + res = my_or(*it,res); + return res; +} + +void iz3proof::mk_and_factor(int p1, int p2, int i, std::vector &itps, std::vector > &disjs){ +#ifdef FACTOR_INTERPS + std::set &d1 = disjs[p1]; + std::set &d2 = disjs[p2]; + if(!weak){ + if(pv->is_true(itps[p1])){ + itps[i] = itps[p2]; + disjs[i] = disjs[p2]; + } + else if(pv->is_true(itps[p2])){ + itps[i] = itps[p1]; + disjs[i] = disjs[p1]; + } + else { + std::set p1only,p2only; + std::insert_iterator > p1o(p1only,p1only.begin()); + std::insert_iterator > p2o(p2only,p2only.begin()); + std::insert_iterator > dio(disjs[i],disjs[i].begin()); + std::set_difference(d1.begin(),d1.end(),d2.begin(),d2.end(),p1o); + std::set_difference(d2.begin(),d2.end(),d1.begin(),d1.end(),p2o); + std::set_intersection(d1.begin(),d1.end(),d2.begin(),d2.end(),dio); + ast p1i = my_or(itps[p1],disj_of_set(p1only)); + ast p2i = my_or(itps[p2],disj_of_set(p2only)); + itps[i] = pv->mk_and(p1i,p2i); + } + } + else { + itps[i] = pv->mk_and(itps[p1],itps[p2]); + std::insert_iterator > dio(disjs[i],disjs[i].begin()); + std::set_union(d1.begin(),d1.end(),d2.begin(),d2.end(),dio); + } +#endif +} + +void iz3proof::mk_or_factor(int p1, int p2, int i, std::vector &itps, std::vector > &disjs){ +#ifdef FACTOR_INTERPS + std::set &d1 = disjs[p1]; + std::set &d2 = disjs[p2]; + if(weak){ + if(pv->is_false(itps[p1])){ + itps[i] = itps[p2]; + disjs[i] = disjs[p2]; + } + else if(pv->is_false(itps[p2])){ + itps[i] = itps[p1]; + disjs[i] = disjs[p1]; + } + else { + std::set p1only,p2only; + std::insert_iterator > p1o(p1only,p1only.begin()); + std::insert_iterator > p2o(p2only,p2only.begin()); + std::insert_iterator > dio(disjs[i],disjs[i].begin()); + std::set_difference(d1.begin(),d1.end(),d2.begin(),d2.end(),p1o); + std::set_difference(d2.begin(),d2.end(),d1.begin(),d1.end(),p2o); + std::set_intersection(d1.begin(),d1.end(),d2.begin(),d2.end(),dio); + ast p1i = pv->mk_and(itps[p1],pv->mk_not(disj_of_set(p1only))); + ast p2i = pv->mk_and(itps[p2],pv->mk_not(disj_of_set(p2only))); + itps[i] = my_or(p1i,p2i); + } + } + else { + itps[i] = my_or(itps[p1],itps[p2]); + std::insert_iterator > dio(disjs[i],disjs[i].begin()); + std::set_union(d1.begin(),d1.end(),d2.begin(),d2.end(),dio); + } +#endif +} + +void iz3proof::interpolate_lemma(node_struct &n){ + if(interps[n.frame].size()) + return; // already computed + pv->interpolate_clause(n.conclusion,interps[n.frame]); +} + + iz3proof::ast iz3proof::interpolate(const prover::range &_rng, bool _weak +#ifdef CHECK_PROOFS + , ast assump + , std::vector *parents +#endif +){ + // std::cout << "proof size: " << nodes.size() << "\n"; + rng = _rng; + weak = _weak; +#ifdef CHECK_PROOFS + if(nodes[nodes.size()-1].conclusion.size() != 0) + std::cerr << "internal error: proof conclusion is not empty clause\n"; + if(!child_interps.size()){ + child_interps.resize(nodes.size()); + for(unsigned j = 0; j < nodes.size(); j++) + child_interps[j] = pv->mk_true(); + } +#endif + std::vector itps(nodes.size()); +#ifdef FACTOR_INTERPS + std::vector > disjs(nodes.size()); +#endif + profiling::timer_start("Blits"); + find_B_lits(); + profiling::timer_stop("Blits"); + profiling::timer_start("interp_proof"); + // strengthen(); + for(unsigned i = 0; i < nodes.size(); i++){ + node_struct &n = nodes[i]; + ast &q = itps[i]; + switch(n.rl){ + case Assumption: { + + if(frame_in_A(n.frame)){ + /* HypC-A */ + if(!weak) +#ifdef FACTOR_INTERPS + { + q = pv->mk_false(); + set_of_B_lits(n.conclusion,disjs[i]); + } +#else + q = get_B_lits(n.conclusion); +#endif + else + q = pv->mk_false(); + } + else { + /* HypEq-B */ + if(!weak) + q = pv->mk_true(); + else +#ifdef FACTOR_INTERPS + { + q = pv->mk_true(); + set_of_A_lits(n.conclusion,disjs[i]); + } +#else + q = pv->mk_not(get_A_lits(n.conclusion)); +#endif + } + break; + } + case Resolution: { + ast p = n.aux; + p = pv->is_not(p) ? pv->mk_not(p) : p; // should be positive, but just in case + if(lit_in_B(p)) +#ifdef FACTOR_INTERPS + mk_and_factor(n.premises[0],n.premises[1],i,itps,disjs); +#else + q = pv->mk_and(itps[n.premises[0]],itps[n.premises[1]]); +#endif + else +#ifdef FACTOR_INTERPS + mk_or_factor(n.premises[0],n.premises[1],i,itps,disjs); +#else + q = my_or(itps[n.premises[0]],itps[n.premises[1]]); +#endif + break; + } + case Lemma: { + interpolate_lemma(n); // make sure lemma interpolants have been computed + q = interps[n.frame][pv->range_max(rng)]; + break; + } + case Contra: { + q = itps[n.premises[0]]; +#ifdef FACTOR_INTERPS + disjs[i] = disjs[n.premises[0]]; +#endif + break; + } + default: + assert(0 && "rule not allowed in interpolated proof"); + } +#ifdef CHECK_PROOFS + int this_frame = pv->range_max(rng); + if(0 && this_frame == 39) { + std::vector alits; + ast s = pv->mk_true(); + for(unsigned j = 0; j < n.conclusion.size(); j++) + if(pred_in_A(n.conclusion[j])){ + int scpmax = pv->range_max(pv->ast_scope(n.conclusion[j])); + if(scpmax == this_frame) + s = pv->mk_and(s,pv->mk_not(n.conclusion[j])); + } + ast ci = child_interps[i]; + s = pv->mk_and(pv->mk_and(s,pv->mk_and(assump,pv->mk_not(q))),ci); + if(pv->is_sat(s)){ + std::cout << "interpolation invariant violated at step " << i << "\n"; + assert(0 && "interpolation invariant violated"); + } + } + if((*parents)[this_frame] == 39) + child_interps[i] = pv->mk_and(child_interps[i],q); +#endif + } + ast &bar = itps[nodes.size()-1]; +#ifdef FACTOR_INTERPS + if(!weak) + bar = my_or(bar,disj_of_set(disjs[nodes.size()-1])); + else + bar = pv->mk_and(bar,pv->mk_not(disj_of_set(disjs[nodes.size()-1]))); +#endif + profiling::timer_stop("interp_proof"); + profiling::timer_start("simplifying"); + bar = pv->simplify(bar); + profiling::timer_stop("simplifying"); + return bar; +} + + +void iz3proof::print(std::ostream &s, int id){ + node_struct &n = nodes[id]; + switch(n.rl){ + case Assumption: + s << "Assumption("; + pv->print_clause(s,n.conclusion); + s << ")"; + break; + case Hypothesis: + s << "Hyp("; pv->print_expr(s,n.conclusion[0]); s << ")"; break; + case Reflexivity: + s << "Refl("; pv->print_expr(s,n.conclusion[0]); s << ")"; break; + case Symmetry: + s << "Symm("; print(s,n.premises[0]); s << ")"; break; + case Transitivity: + s << "Trans("; print(s,n.premises[0]); s << ","; print(s,n.premises[1]); s << ")"; break; + case Congruence: + s << "Cong("; pv->print_expr(s,n.conclusion[0]); + for(unsigned i = 0; i < n.premises.size(); i++){ + s << ","; + print(s,n.premises[i]); + } + s << ")"; break; + case EqContra: + s << "EqContra("; print(s,n.premises[0]); s << ","; print(s,n.premises[1]); s << ")"; break; + case Resolution: + s << "Res("; + pv->print_expr(s,n.aux); s << ","; + print(s,n.premises[0]); s << ","; print(s,n.premises[1]); s << ")"; + break; + case Lemma: + s << "Lemma("; + pv->print_clause(s,n.conclusion); + for(unsigned i = 0; i < n.premises.size(); i++){ + s << ","; + print(s,n.premises[i]); + } + s << ")"; + break; + case Contra: + s << "Contra("; + print(s,n.premises[0]); + s << ")"; + break; + default:; + } +} + + +void iz3proof::show(int id){ + std::ostringstream ss; + print(ss,id); + iz3base::pretty_print(std::cout,ss.str()); + // std::cout << ss.str(); + std::cout << "\n"; +} + + diff --git a/src/interp/iz3proof.h b/src/interp/iz3proof.h new file mode 100755 index 000000000..ae08e9d36 --- /dev/null +++ b/src/interp/iz3proof.h @@ -0,0 +1,272 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3proof.h + +Abstract: + + This class defines a simple interpolating proof system. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef IZ3PROOF_H +#define IZ3PROOF_H + +#include + +#include "iz3base.h" +#include "iz3secondary.h" + +// #define CHECK_PROOFS + +/** This class defines a simple proof system. + + A proof is a dag consisting of "nodes". The children of each node + are its "premises". Each node has a "conclusion" that is a clause, + represented as a vector of literals. + + The literals are represented by abstract syntax trees. Operations + on these, including computation of scopes are provided by iz3base. + + A proof can be interpolated, provided it is restricted to the + rules Resolution, Assumption, Contra and Lemma, and that all + clauses are strict (i.e., each literal in each clause is local). + + */ + +class iz3proof { + public: + /** The type of proof nodes (nodes in the derivation tree). */ + typedef int node; + + /** Enumeration of proof rules. */ + enum rule {Resolution,Assumption,Hypothesis,Theory,Axiom,Contra,Lemma,Reflexivity,Symmetry,Transitivity,Congruence,EqContra}; + + /** Interface to prover. */ + typedef iz3base prover; + + /** Ast type. */ + typedef prover::ast ast; + + /** Object thrown in case of a proof error. */ + struct proof_error {}; + + /* Null proof node */ + static const node null = -1; + + /** Make a resolution node with given pivot liter and premises. + The conclusion of premise1 should contain the negation of the + pivot literal, while the conclusion of premise2 should containe the + pivot literal. + */ + node make_resolution(ast pivot, node premise1, node premise2); + + /** Make an assumption node. The given clause is assumed in the given frame. */ + node make_assumption(int frame, const std::vector &assumption); + + /** Make a hypothesis node. If phi is the hypothesis, this is + effectively phi |- phi. */ + node make_hypothesis(ast hypothesis); + + /** Make a theory node. This can be any inference valid in the theory. */ + node make_theory(const std::vector &conclusion, std::vector premises); + + /** Make an axiom node. The conclusion must be an instance of an axiom. */ + node make_axiom(const std::vector &conclusion); + + /** Make a Contra node. This rule takes a derivation of the form + Gamma |- False and produces |- \/~Gamma. */ + + node make_contra(node prem, const std::vector &conclusion); + + /** Make a lemma node. A lemma node must have an interpolation. */ + node make_lemma(const std::vector &conclusion, const std::vector &interpolation); + + /** Make a Reflexivity node. This rule produces |- x = x */ + + node make_reflexivity(ast con); + + /** Make a Symmetry node. This takes a derivation of |- x = y and + produces | y = x */ + + node make_symmetry(ast con, node prem); + + /** Make a transitivity node. This takes derivations of |- x = y + and |- y = z produces | x = z */ + + node make_transitivity(ast con, node prem1, node prem2); + + /** Make a congruence node. This takes derivations of |- x_i = y_i + and produces |- f(x_1,...,x_n) = f(y_1,...,y_n) */ + + node make_congruence(ast con, const std::vector &prems); + + /** Make an equality contradicition node. This takes |- x = y + and |- !(x = y) and produces false. */ + + node make_eqcontra(node prem1, node prem2); + + /** Get the rule of a node in a proof. */ + rule get_rule(node n){ + return nodes[n].rl; + } + + /** Get the pivot of a resolution node. */ + ast get_pivot(node n){ + return nodes[n].aux; + } + + /** Get the frame of an assumption node. */ + int get_frame(node n){ + return nodes[n].frame; + } + + /** Get the number of literals of the conclusion of a node. */ + int get_num_conclusion_lits(node n){ + return get_conclusion(n).size(); + } + + /** Get the nth literal of the conclusion of a node. */ + ast get_nth_conclusion_lit(node n, int i){ + return get_conclusion(n)[i]; + } + + /** Get the conclusion of a node. */ + void get_conclusion(node n, std::vector &result){ + result = get_conclusion(n); + } + + /** Get the number of premises of a node. */ + int get_num_premises(node n){ + return nodes[n].premises.size(); + } + + /** Get the nth premise of a node. */ + int get_nth_premise(node n, int i){ + return nodes[n].premises[i]; + } + + /** Get all the premises of a node. */ + void get_premises(node n, std::vector &result){ + result = nodes[n].premises; + } + + /** Create a new proof node, replacing the premises of an old + one. */ + + node clone(node n, std::vector &premises){ + if(premises == nodes[n].premises) + return n; + nodes.push_back(nodes[n]); + nodes.back().premises = premises; + return nodes.size()-1; + } + + /** Copy a proof node from src */ + node copy(iz3proof &src, node n); + + /** Resolve two lemmas on a given literal. */ + + node resolve_lemmas(ast pivot, node left, node right); + + /** Swap two proofs. */ + void swap(iz3proof &other){ + std::swap(pv,other.pv); + nodes.swap(other.nodes); + interps.swap(other.interps); + } + + /** Compute an interpolant for a proof, where the "A" side is defined by + the given range of frames. Parameter "weak", when true, uses different + interpolation system that resutls in generally weaker interpolants. + */ + ast interpolate(const prover::range &_rng, bool weak = false +#ifdef CHECK_PROOFS + , Z3_ast assump = (Z3_ast)0, std::vector *parents = 0 + +#endif + ); + + /** print proof node to a stream */ + + void print(std::ostream &s, node n); + + /** show proof node on stdout */ + void show(node n); + + /** Construct a proof, with a given prover. */ + iz3proof(prover *p){ + pv = p; + } + + /** Default constructor */ + iz3proof(){pv = 0;} + + + protected: + + struct node_struct { + rule rl; + ast aux; + int frame; + std::vector conclusion; + std::vector premises; + }; + + std::vector nodes; + std::vector > interps; // interpolations of lemmas + prover *pv; + + node make_node(){ + nodes.push_back(node_struct()); + return nodes.size()-1; + } + + void resolve(ast pivot, std::vector &cls1, const std::vector &cls2); + + node copy_rec(stl_ext::hash_map &memo, iz3proof &src, node n); + + void interpolate_lemma(node_struct &n); + + // lazily compute the result of resolution + // the node member "frame" indicates result is computed + const std::vector &get_conclusion(node x){ + node_struct &n = nodes[x]; + if(n.rl == Resolution && !n.frame){ + n.conclusion = get_conclusion(n.premises[0]); + resolve(n.aux,n.conclusion,get_conclusion(n.premises[1])); + n.frame = 1; + } + return n.conclusion; + } + + prover::range rng; + bool weak; + stl_ext::hash_set b_lits; + ast my_or(ast x, ast y); +#ifdef CHECK_PROOFS + std::vector child_interps; +#endif + bool pred_in_A(ast id); + bool term_in_B(ast id); + bool frame_in_A(int frame); + bool lit_in_B(ast lit); + ast get_A_lits(std::vector &cls); + ast get_B_lits(std::vector &cls); + void find_B_lits(); + ast disj_of_set(std::set &s); + void mk_or_factor(int p1, int p2, int i, std::vector &itps, std::vector > &disjs); + void mk_and_factor(int p1, int p2, int i, std::vector &itps, std::vector > &disjs); + void set_of_B_lits(std::vector &cls, std::set &res); + void set_of_A_lits(std::vector &cls, std::set &res); +}; + +#endif diff --git a/src/interp/iz3proof_itp.cpp b/src/interp/iz3proof_itp.cpp new file mode 100644 index 000000000..75d14bca1 --- /dev/null +++ b/src/interp/iz3proof_itp.cpp @@ -0,0 +1,2676 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3proof.cpp + +Abstract: + + This class defines a simple interpolating proof system. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#include "iz3proof_itp.h" + +#ifndef WIN32 +using namespace stl_ext; +#endif + +// #define INVARIANT_CHECKING + +class iz3proof_itp_impl : public iz3proof_itp { + + prover *pv; + prover::range rng; + bool weak; + + enum LitType {LitA,LitB,LitMixed}; + + hash_map placeholders; + + // These symbols represent deduction rules + + /* This symbol represents a proof by contradiction. That is, + contra(p,l1 /\ ... /\ lk) takes a proof p of + + l1,...,lk |- false + + and returns a proof of + + |- ~l1,...,~l2 + */ + symb contra; + + /* The summation rule. The term sum(p,c,i) takes a proof p of an + inequality i', an integer coefficient c and an inequality i, and + yieds a proof of i' + ci. */ + symb sum; + + /* Proof rotation. The proof term rotate(q,p) takes a + proof p of: + + Gamma, q |- false + + and yields a proof of: + + Gamma |- ~q + */ + symb rotate_sum; + + /* Inequalities to equality. leq2eq(p, q, r) takes a proof + p of ~x=y, a proof q of x <= y and a proof r of y <= x + and yields a proof of false. */ + symb leq2eq; + + /* Equality to inequality. eq2leq(p, q) takes a proof p of x=y, and + a proof q ~(x <= y) and and yields a proof of false. */ + symb eq2leq; + + /* Proof term cong(p,q) takes a proof p of x=y and a proof + q of t != t and returns a proof of false. */ + symb cong; + + + /* Excluded middle. exmid(phi,p,q) takes a proof p of phi and a + proof q of ~\phi and returns a proof of false. */ + symb exmid; + + /* Symmetry. symm(p) takes a proof p of x=y and produces + a proof of y=x. */ + symb symm; + + /* Modus ponens. modpon(p,e,q) takes proofs p of P, e of P=Q + and q of ~Q and returns a proof of false. */ + symb modpon; + + /* This oprerator represents a concatenation of rewrites. The term + a=b;c=d represents an A rewrite from a to b, followed by a B + rewrite fron b to c, followed by an A rewrite from c to d. + */ + symb concat; + + /* This represents a lack of a proof */ + ast no_proof; + + // This is used to represent an infinitessimal value + ast epsilon; + + // Represents the top position of a term + ast top_pos; + + // add_pos(i,pos) represents position pos if the ith argument + symb add_pos; + + // rewrite proof rules + + /* rewrite_A(pos,cond,x=y) derives A |- cond => t[x]_p = t[y]_p + where t is an arbitrary term */ + symb rewrite_A; + + /* rewrite_B(pos,cond,x=y) derives B |- cond => t[x]_p = t[y]_p, + where t is an arbitrary term */ + symb rewrite_B; + + /* a normalization step is of the form (lhs=rhs) : proof, where "proof" + is a proof of lhs=rhs and lhs is a mixed term. If rhs is a mixed term + then it must have a greater index than lhs. */ + symb normal_step; + + /* A chain of normalization steps is either "true" (the null chain) + or normal_chain( ), where step is a normalization step + and tail is a normalization chain. The lhs of must have + a less term index than any lhs in the chain. Moreover, the rhs of + may not occur as the lhs of step in . If we wish to + add lhs=rhs to the beginning of and rhs=rhs' occurs in + we must apply transitivity, transforming to lhs=rhs'. */ + + symb normal_chain; + + /* If p is a proof of Q and c is a normalization chain, then normal(p,c) + is a proof of Q(c) (that is, Q with all substitutions in c performed). */ + + symb normal; + + /** Stand-ins for quantifiers */ + + symb sforall, sexists; + + + ast get_placeholder(ast t){ + hash_map::iterator it = placeholders.find(t); + if(it != placeholders.end()) + return it->second; + ast &res = placeholders[t]; + res = mk_fresh_constant("@p",get_type(t)); +#if 0 + std::cout << "placeholder "; + print_expr(std::cout,res); + std::cout << " = "; + print_expr(std::cout,t); + std::cout << std::endl; +#endif + return res; + } + + ast make_contra_node(const ast &pf, const std::vector &lits, int pfok = -1){ + if(lits.size() == 0) + return pf; + std::vector reslits; + reslits.push_back(make(contra,pf,mk_false())); + for(unsigned i = 0; i < lits.size(); i++){ + ast bar; + if(pfok & (1 << i)) bar = make(rotate_sum,lits[i],pf); + else bar = no_proof; + ast foo = make(contra,bar,lits[i]); + reslits.push_back(foo); + } + return make(And,reslits); + } + + LitType get_term_type(const ast &lit){ + prover::range r = pv->ast_scope(lit); + if(pv->range_is_empty(r)) + return LitMixed; + if(weak) { + if(pv->range_min(r) == SHRT_MIN) + return pv->range_contained(r,rng) ? LitA : LitB; + else + return pv->ranges_intersect(r,rng) ? LitA : LitB; + } + else + return pv->range_contained(r,rng) ? LitA : LitB; + } + + bool term_common(const ast &t){ + prover::range r = pv->ast_scope(t); + return pv->ranges_intersect(r,rng) && !pv->range_contained(r,rng); + } + + bool term_in_vocab(LitType ty, const ast &lit){ + prover::range r = pv->ast_scope(lit); + if(ty == LitA){ + return pv->ranges_intersect(r,rng); + } + return !pv->range_contained(r,rng); + } + + /** Make a resolution node with given pivot literal and premises. + The conclusion of premise1 should contain the negation of the + pivot literal, while the conclusion of premise2 should contain the + pivot literal. + */ + node make_resolution(ast pivot, const std::vector &conc, node premise1, node premise2) { + LitType lt = get_term_type(pivot); + if(lt == LitA) + return my_or(premise1,premise2); + if(lt == LitB) + return my_and(premise1,premise2); + + /* the mixed case is a bit complicated */ + + static int non_local_count = 0; + ast res = resolve_arith(pivot,conc,premise1,premise2); +#ifdef INVARIANT_CHECKING + check_contra(conc,res); +#endif + non_local_count++; + return res; + } + + + /* Handles the case of resolution on a mixed arith atom. */ + + ast resolve_arith(const ast &pivot, const std::vector &conc, node premise1, node premise2){ + ast atom = get_lit_atom(pivot); + hash_map memo; + ast neg_pivot_lit = mk_not(atom); + if(op(pivot) != Not) + std::swap(premise1,premise2); + if(op(pivot) == Equal && op(arg(pivot,0)) == Select && op(arg(pivot,1)) == Select){ + neg_pivot_lit = mk_not(neg_pivot_lit); + std::swap(premise1,premise2); + } + return resolve_arith_rec1(memo, neg_pivot_lit, premise1, premise2); + } + + + ast apply_coeff(const ast &coeff, const ast &t){ +#if 0 + rational r; + if(!is_integer(coeff,r)) + throw "ack!"; + ast n = make_int(r.numerator()); + ast res = make(Times,n,t); + if(!r.is_int()) { + ast d = make_int(r.numerator()); + res = mk_idiv(res,d); + } + return res; +#endif + return make(Times,coeff,t); + } + + ast sum_ineq(const ast &coeff1, const ast &ineq1, const ast &coeff2, const ast &ineq2){ + opr sum_op = Leq; + if(op(ineq1) == Lt || op(ineq2) == Lt) + sum_op = Lt; + ast sum_sides[2]; + for(int i = 0; i < 2; i++){ + sum_sides[i] = make(Plus,apply_coeff(coeff1,arg(ineq1,i)),apply_coeff(coeff2,arg(ineq2,i))); + sum_sides[i] = z3_simplify(sum_sides[i]); + } + return make(sum_op,sum_sides[0],sum_sides[1]); + } + + + void collect_contra_resolvents(int from, const ast &pivot1, const ast &pivot, const ast &conj, std::vector &res){ + int nargs = num_args(conj); + for(int i = from; i < nargs; i++){ + ast f = arg(conj,i); + if(!(f == pivot)){ + ast ph = get_placeholder(mk_not(arg(pivot1,1))); + ast pf = arg(pivot1,0); + ast thing = pf == no_proof ? no_proof : subst_term_and_simp(ph,pf,arg(f,0)); + ast newf = make(contra,thing,arg(f,1)); + res.push_back(newf); + } + } + } + + bool is_negative_equality(const ast &e){ + if(op(e) == Not){ + opr o = op(arg(e,0)); + return o == Equal || o == Iff; + } + return false; + } + + int count_negative_equalities(const std::vector &resolvent){ + int res = 0; + for(unsigned i = 0; i < resolvent.size(); i++) + if(is_negative_equality(arg(resolvent[i],1))) + res++; + return res; + } + + ast resolve_contra_nf(const ast &pivot1, const ast &conj1, + const ast &pivot2, const ast &conj2){ + std::vector resolvent; + collect_contra_resolvents(0,pivot1,pivot2,conj2,resolvent); + collect_contra_resolvents(1,pivot2,pivot1,conj1,resolvent); + if(count_negative_equalities(resolvent) > 1) + throw proof_error(); + if(resolvent.size() == 1) // we have proved a contradiction + return simplify(arg(resolvent[0],0)); // this is the proof -- get interpolant + return make(And,resolvent); + } + + ast resolve_contra(const ast &pivot1, const ast &conj1, + const ast &pivot2, const ast &conj2){ + if(arg(pivot1,0) != no_proof) + return resolve_contra_nf(pivot1, conj1, pivot2, conj2); + if(arg(pivot2,0) != no_proof) + return resolve_contra_nf(pivot2, conj2, pivot1, conj1); + return resolve_with_quantifier(pivot1, conj1, pivot2, conj2); + } + + + bool is_contra_itp(const ast &pivot1, ast itp2, ast &pivot2){ + if(op(itp2) == And){ + int nargs = num_args(itp2); + for(int i = 1; i < nargs; i++){ + ast foo = arg(itp2,i); + if(op(foo) == Uninterpreted && sym(foo) == contra){ + if(arg(foo,1) == pivot1){ + pivot2 = foo; + return true; + } + } + else break; + } + } + return false; + } + + ast resolve_arith_rec2(hash_map &memo, const ast &pivot1, const ast &conj1, const ast &itp2){ + ast &res = memo[itp2]; + if(!res.null()) + return res; + + ast pivot2; + if(is_contra_itp(mk_not(arg(pivot1,1)),itp2,pivot2)) + res = resolve_contra(pivot1,conj1,pivot2,itp2); + else { + switch(op(itp2)){ + case Or: + case And: + case Implies: { + unsigned nargs = num_args(itp2); + std::vector args; args.resize(nargs); + for(unsigned i = 0; i < nargs; i++) + args[i] = resolve_arith_rec2(memo, pivot1, conj1, arg(itp2,i)); + ast foo = itp2; // get rid of const + res = clone(foo,args); + break; + } + default: + { + symb s = sym(itp2); + if(s == sforall || s == sexists) + res = make(s,arg(itp2,0),resolve_arith_rec2(memo, pivot1, conj1, arg(itp2,1))); + else + res = itp2; + } + } + } + return res; + } + + + ast resolve_arith_rec1(hash_map &memo, const ast &neg_pivot_lit, const ast &itp1, const ast &itp2){ + ast &res = memo[itp1]; + if(!res.null()) + return res; + ast pivot1; + if(is_contra_itp(neg_pivot_lit,itp1,pivot1)){ + hash_map memo2; + res = resolve_arith_rec2(memo2,pivot1,itp1,itp2); + } + else { + switch(op(itp1)){ + case Or: + case And: + case Implies: { + unsigned nargs = num_args(itp1); + std::vector args; args.resize(nargs); + for(unsigned i = 0; i < nargs; i++) + args[i] = resolve_arith_rec1(memo, neg_pivot_lit, arg(itp1,i), itp2); + ast foo = itp1; // get rid of const + res = clone(foo,args); + break; + } + default: + { + symb s = sym(itp1); + if(s == sforall || s == sexists) + res = make(s,arg(itp1,0),resolve_arith_rec1(memo, neg_pivot_lit, arg(itp1,1), itp2)); + else + res = itp1; + } + } + } + return res; + } + + void check_contra(hash_set &memo, hash_set &neg_lits, const ast &foo){ + if(memo.find(foo) != memo.end()) + return; + memo.insert(foo); + if(op(foo) == Uninterpreted && sym(foo) == contra){ + ast neg_lit = arg(foo,1); + if(!is_false(neg_lit) && neg_lits.find(neg_lit) == neg_lits.end()) + throw "lost a literal"; + return; + } + else { + switch(op(foo)){ + case Or: + case And: + case Implies: { + unsigned nargs = num_args(foo); + std::vector args; args.resize(nargs); + for(unsigned i = 0; i < nargs; i++) + check_contra(memo, neg_lits, arg(foo,i)); + break; + } + default: break; + } + } + } + + void check_contra(const std::vector &neg_lits, const ast &foo){ + hash_set memo; + hash_set neg_lits_set; + for(unsigned i = 0; i < neg_lits.size(); i++) + if(get_term_type(neg_lits[i]) == LitMixed) + neg_lits_set.insert(mk_not(neg_lits[i])); + check_contra(memo,neg_lits_set,foo); + } + + hash_map subst_memo; // memo of subst function + + ast subst_term_and_simp(const ast &var, const ast &t, const ast &e){ + subst_memo.clear(); + return subst_term_and_simp_rec(var,t,e); + } + + ast subst_term_and_simp_rec(const ast &var, const ast &t, const ast &e){ + if(e == var) return t; + std::pair foo(e,ast()); + std::pair::iterator,bool> bar = subst_memo.insert(foo); + ast &res = bar.first->second; + if(bar.second){ + symb g = sym(e); + if(g == rotate_sum){ + if(var == get_placeholder(arg(e,0))){ + res = e; + } + else + res = make(rotate_sum,arg(e,0),subst_term_and_simp_rec(var,t,arg(e,1))); + return res; + } + if(g == concat){ + res = e; + return res; + } + int nargs = num_args(e); + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = subst_term_and_simp_rec(var,t,arg(e,i)); + opr f = op(e); + if(f == Equal && args[0] == args[1]) res = mk_true(); + else if(f == And) res = my_and(args); + else if(f == Or) res = my_or(args); + else if(f == Idiv) res = mk_idiv(args[0],args[1]); + else res = clone(e,args); + } + return res; + } + + /* This is where the real work happens. Here, we simplify the + proof obtained by cut elimination, obtaining an interpolant. */ + + struct cannot_simplify {}; + hash_map simplify_memo; + + ast simplify(const ast &t){ + ast res = normalize(simplify_rec(t)); +#ifdef BOGUS_QUANTS + if(localization_vars.size()) + res = add_quants(z3_simplify(res)); +#endif + return res; + } + + ast simplify_rec(const ast &e){ + std::pair foo(e,ast()); + std::pair::iterator,bool> bar = simplify_memo.insert(foo); + ast &res = bar.first->second; + if(bar.second){ + int nargs = num_args(e); + std::vector args(nargs); + bool placeholder_arg = false; + symb g = sym(e); + if(g == concat){ + res = e; + return res; + } + for(int i = 0; i < nargs; i++){ + if(i == 0 && g == rotate_sum) + args[i] = arg(e,i); + else + args[i] = simplify_rec(arg(e,i)); + placeholder_arg |= is_placeholder(args[i]); + } + try { + opr f = op(e); + if(f == Equal && args[0] == args[1]) res = mk_true(); + else if(f == And) res = my_and(args); + else if(f == Or) + res = my_or(args); + else if(f == Idiv) res = mk_idiv(args[0],args[1]); + else if(f == Uninterpreted && !placeholder_arg){ + if(g == rotate_sum) res = simplify_rotate(args); + else if(g == symm) res = simplify_symm(args); + else if(g == modpon) res = simplify_modpon(args); + else if(g == sum) res = simplify_sum(args); +#if 0 + else if(g == cong) res = simplify_cong(args); + else if(g == modpon) res = simplify_modpon(args); + else if(g == leq2eq) res = simplify_leq2eq(args); + else if(g == eq2leq) res = simplify_eq2leq(args); +#endif + else res = clone(e,args); + } + else res = clone(e,args); + } + catch (const cannot_simplify &){ + res = clone(e,args); + } + } + return res; + } + + + ast simplify_rotate(const std::vector &args){ + const ast &pf = args[1]; + ast pl = get_placeholder(args[0]); + if(op(pf) == Uninterpreted){ + symb g = sym(pf); + if(g == sum) return simplify_rotate_sum(pl,pf); + if(g == leq2eq) return simplify_rotate_leq2eq(pl,args[0],pf); + if(g == eq2leq) return simplify_rotate_eq2leq(pl,args[0],pf); + if(g == cong) return simplify_rotate_cong(pl,args[0],pf); + if(g == modpon) return simplify_rotate_modpon(pl,args[0],pf); + // if(g == symm) return simplify_rotate_symm(pl,args[0],pf); + } + if(op(pf) == Leq) + throw "foo!"; + throw cannot_simplify(); + } + + bool is_normal_ineq(const ast &ineq){ + if(sym(ineq) == normal) + return is_ineq(arg(ineq,0)); + return is_ineq(ineq); + } + + ast simplify_sum(std::vector &args){ + ast Aproves = mk_true(), Bproves = mk_true(); + ast ineq = args[0]; + if(!is_normal_ineq(ineq)) throw cannot_simplify(); + sum_cond_ineq(ineq,args[1],args[2],Aproves,Bproves); + return my_and(Aproves,my_implies(Bproves,ineq)); + } + + ast simplify_rotate_sum(const ast &pl, const ast &pf){ + ast cond = mk_true(); + ast ineq = make(Leq,make_int("0"),make_int("0")); + return rotate_sum_rec(pl,pf,cond,ineq); + } + + bool is_rewrite_chain(const ast &chain){ + return sym(chain) == concat; + } + +#if 0 + ast ineq_from_chain_simple(const ast &chain, ast &cond){ + if(is_true(chain)) + return chain; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + if(is_true(rest) && is_rewrite_side(LitA,last) + && is_true(rewrite_lhs(last))){ + cond = my_and(cond,rewrite_cond(last)); + return rewrite_rhs(last); + } + if(is_rewrite_side(LitB,last) && is_true(rewrite_cond(last))) + return ineq_from_chain_simple(rest,cond); + return chain; + } +#endif + + ast ineq_from_chain(const ast &chain, ast &Aproves, ast &Bproves){ + if(is_rewrite_chain(chain)) + return rewrite_chain_to_normal_ineq(chain,Aproves,Bproves); + return chain; + } + + + void sum_cond_ineq(ast &ineq, const ast &coeff2, const ast &ineq2, ast &Aproves, ast &Bproves){ + opr o = op(ineq2); + if(o == Implies){ + sum_cond_ineq(ineq,coeff2,arg(ineq2,1),Aproves,Bproves); + Bproves = my_and(Bproves,arg(ineq2,0)); + } + else { + ast the_ineq = ineq_from_chain(ineq2,Aproves,Bproves); + if(sym(ineq) == normal || sym(the_ineq) == normal){ + sum_normal_ineq(ineq,coeff2,the_ineq,Aproves,Bproves); + return; + } + if(is_ineq(the_ineq)) + linear_comb(ineq,coeff2,the_ineq); + else + throw cannot_simplify(); + } + } + + void destruct_normal(const ast &pf, ast &p, ast &n){ + if(sym(pf) == normal){ + p = arg(pf,0); + n = arg(pf,1); + } + else { + p = pf; + n = mk_true(); + } + } + + void sum_normal_ineq(ast &ineq, const ast &coeff2, const ast &ineq2, ast &Aproves, ast &Bproves){ + ast in1,in2,n1,n2; + destruct_normal(ineq,in1,n1); + destruct_normal(ineq2,in2,n2); + ast dummy1, dummy2; + sum_cond_ineq(in1,coeff2,in2,dummy1,dummy2); + n1 = merge_normal_chains(n1,n2, Aproves, Bproves); + ineq = make_normal(in1,n1); + } + + bool is_ineq(const ast &ineq){ + opr o = op(ineq); + if(o == Not) o = op(arg(ineq,0)); + return o == Leq || o == Lt || o == Geq || o == Gt; + } + + // divide both sides of inequality by a non-negative integer divisor + ast idiv_ineq(const ast &ineq1, const ast &divisor){ + if(sym(ineq1) == normal){ + ast in1,n1; + destruct_normal(ineq1,in1,n1); + in1 = idiv_ineq(in1,divisor); + return make_normal(in1,n1); + } + if(divisor == make_int(rational(1))) + return ineq1; + ast ineq = ineq1; + if(op(ineq) == Lt) + ineq = simplify_ineq(make(Leq,arg(ineq,0),make(Sub,arg(ineq,1),make_int("1")))); + return make(op(ineq),mk_idiv(arg(ineq,0),divisor),mk_idiv(arg(ineq,1),divisor)); + } + + ast rotate_sum_rec(const ast &pl, const ast &pf, ast &Bproves, ast &ineq){ + if(pf == pl) + return my_implies(Bproves,simplify_ineq(ineq)); + if(op(pf) == Uninterpreted && sym(pf) == sum){ + if(arg(pf,2) == pl){ + ast Aproves = mk_true(); + sum_cond_ineq(ineq,make_int("1"),arg(pf,0),Aproves,Bproves); + if(!is_true(Aproves)) + throw "help!"; + ineq = idiv_ineq(ineq,arg(pf,1)); + return my_implies(Bproves,ineq); + } + ast Aproves = mk_true(); + sum_cond_ineq(ineq,arg(pf,1),arg(pf,2),Aproves,Bproves); + if(!is_true(Aproves)) + throw "help!"; + return rotate_sum_rec(pl,arg(pf,0),Bproves,ineq); + } + throw cannot_simplify(); + } + + ast simplify_rotate_leq2eq(const ast &pl, const ast &neg_equality, const ast &pf){ + if(pl == arg(pf,0)){ + ast equality = arg(neg_equality,0); + ast x = arg(equality,0); + ast y = arg(equality,1); + ast Aproves1 = mk_true(), Bproves1 = mk_true(); + ast xleqy = round_ineq(ineq_from_chain(arg(pf,1),Aproves1,Bproves1)); + ast yleqx = round_ineq(ineq_from_chain(arg(pf,2),Aproves1,Bproves1)); + ast ineq1 = make(Leq,make_int("0"),make_int("0")); + sum_cond_ineq(ineq1,make_int("-1"),xleqy,Aproves1,Bproves1); + sum_cond_ineq(ineq1,make_int("-1"),yleqx,Aproves1,Bproves1); + Bproves1 = my_and(Bproves1,z3_simplify(ineq1)); + ast Aproves2 = mk_true(), Bproves2 = mk_true(); + ast ineq2 = make(Leq,make_int("0"),make_int("0")); + sum_cond_ineq(ineq2,make_int("1"),xleqy,Aproves2,Bproves2); + sum_cond_ineq(ineq2,make_int("1"),yleqx,Aproves2,Bproves2); + Bproves2 = z3_simplify(ineq2); + if(!is_true(Aproves1) || !is_true(Aproves2)) + throw "help!"; + if(get_term_type(x) == LitA){ + ast iter = z3_simplify(make(Plus,x,get_ineq_rhs(xleqy))); + ast rewrite1 = make_rewrite(LitA,top_pos,Bproves1,make(Equal,x,iter)); + ast rewrite2 = make_rewrite(LitB,top_pos,Bproves2,make(Equal,iter,y)); + return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2); + } + if(get_term_type(y) == LitA){ + ast iter = z3_simplify(make(Plus,y,get_ineq_rhs(yleqx))); + ast rewrite2 = make_rewrite(LitA,top_pos,Bproves1,make(Equal,iter,y)); + ast rewrite1 = make_rewrite(LitB,top_pos,Bproves2,make(Equal,x,iter)); + return chain_cons(chain_cons(mk_true(),rewrite1),rewrite2); + } + throw cannot_simplify(); + } + throw cannot_simplify(); + } + + ast round_ineq(const ast &ineq){ + if(!is_ineq(ineq)) + throw cannot_simplify(); + ast res = simplify_ineq(ineq); + if(op(res) == Lt) + res = make(Leq,arg(res,0),make(Sub,arg(res,1),make_int("1"))); + return res; + } + + ast simplify_rotate_eq2leq(const ast &pl, const ast &neg_equality, const ast &pf){ + if(pl == arg(pf,1)){ + ast cond = mk_true(); + ast equa = sep_cond(arg(pf,0),cond); + if(is_equivrel_chain(equa)){ + ast lhs,rhs; eq_from_ineq(arg(neg_equality,0),lhs,rhs); // get inequality we need to prove + LitType lhst = get_term_type(lhs), rhst = get_term_type(rhs); + if(lhst != LitMixed && rhst != LitMixed){ + ast ineqs= chain_ineqs(op(arg(neg_equality,0)),LitA,equa,lhs,rhs); // chain must be from lhs to rhs + cond = my_and(cond,chain_conditions(LitA,equa)); + ast Bconds = z3_simplify(chain_conditions(LitB,equa)); + if(is_true(Bconds) && op(ineqs) != And) + return my_implies(cond,ineqs); + } + else { + ast itp = make(Leq,make_int(rational(0)),make_int(rational(0))); + return make_normal(itp,cons_normal(fix_normal(lhs,rhs,equa),mk_true())); + } + } + } + throw cannot_simplify(); + } + + void reverse_modpon(std::vector &args){ + std::vector sargs(1); sargs[0] = args[1]; + args[1] = simplify_symm(sargs); + if(is_equivrel_chain(args[2])) + args[1] = down_chain(args[1]); + std::swap(args[0],args[2]); + } + + ast simplify_rotate_modpon(const ast &pl, const ast &neg_equality, const ast &pf){ + std::vector args; args.resize(3); + args[0] = arg(pf,0); + args[1] = arg(pf,1); + args[2] = arg(pf,2); + if(pl == args[0]) + reverse_modpon(args); + if(pl == args[2]){ + ast cond = mk_true(); + ast chain = simplify_modpon_fwd(args, cond); + return my_implies(cond,chain); + } + throw cannot_simplify(); + } + + ast get_ineq_rhs(const ast &ineq2){ + opr o = op(ineq2); + if(o == Implies) + return get_ineq_rhs(arg(ineq2,1)); + else if(o == Leq || o == Lt) + return arg(ineq2,1); + throw cannot_simplify(); + } + + ast simplify_rotate_cong(const ast &pl, const ast &neg_equality, const ast &pf){ + if(pl == arg(pf,2)){ + if(op(arg(pf,0)) == True) + return mk_true(); + rational pos; + if(is_numeral(arg(pf,1),pos)){ + int ipos = pos.get_unsigned(); + ast cond = mk_true(); + ast equa = sep_cond(arg(pf,0),cond); +#if 0 + if(op(equa) == Equal){ + ast pe = mk_not(neg_equality); + ast lhs = subst_in_arg_pos(ipos,arg(equa,0),arg(pe,0)); + ast rhs = subst_in_arg_pos(ipos,arg(equa,1),arg(pe,1)); + ast res = make(Equal,lhs,rhs); + return my_implies(cond,res); + } +#endif + ast res = chain_pos_add(ipos,equa); + return my_implies(cond,res); + } + } + throw cannot_simplify(); + } + + ast simplify_symm(const std::vector &args){ + if(op(args[0]) == True) + return mk_true(); + ast cond = mk_true(); + ast equa = sep_cond(args[0],cond); + if(is_equivrel_chain(equa)) + return my_implies(cond,reverse_chain(equa)); + throw cannot_simplify(); + } + + ast simplify_modpon_fwd(const std::vector &args, ast &cond){ + ast P = sep_cond(args[0],cond); + ast PeqQ = sep_cond(args[1],cond); + ast chain; + if(is_equivrel_chain(P)){ + try { + ast split[2]; + split_chain(PeqQ,split); + chain = reverse_chain(split[0]); + chain = concat_rewrite_chain(chain,P); + chain = concat_rewrite_chain(chain,split[1]); + } + catch(const cannot_split &){ + static int this_count = 0; + this_count++; + ast tail, pref = get_head_chain(PeqQ,tail,false); // pref is x=y, tail is x=y -> x'=y' + ast split[2]; split_chain(tail,split); // rewrites from x to x' and y to y' + ast head = chain_last(pref); + ast prem = make_rewrite(rewrite_side(head),top_pos,rewrite_cond(head),make(Iff,mk_true(),mk_not(rewrite_lhs(head)))); + ast back_chain = chain_cons(mk_true(),prem); + back_chain = concat_rewrite_chain(back_chain,chain_pos_add(0,reverse_chain(chain_rest(pref)))); + ast cond = contra_chain(back_chain,P); + if(is_rewrite_side(LitA,head)) + cond = mk_not(cond); + ast fwd_rewrite = make_rewrite(rewrite_side(head),top_pos,cond,rewrite_rhs(head)); + P = chain_cons(mk_true(),fwd_rewrite); + chain = reverse_chain(split[0]); + chain = concat_rewrite_chain(chain,P); + chain = concat_rewrite_chain(chain,split[1]); + } + } + else { // if not an equivalence, must be of form T <-> pred + chain = concat_rewrite_chain(P,PeqQ); + } + return chain; + } + + void get_subterm_normals(const ast &ineq1, const ast &ineq2, const ast &chain, ast &normals, + const ast &pos, hash_set &memo, ast &Aproves, ast &Bproves){ + opr o1 = op(ineq1); + opr o2 = op(ineq2); + if(o1 == Not || o1 == Leq || o1 == Lt || o1 == Geq || o1 == Gt || o1 == Plus || o1 == Times){ + int n = num_args(ineq1); + if(o2 != o1 || num_args(ineq2) != n) + throw "bad inequality rewriting"; + for(int i = 0; i < n; i++){ + ast new_pos = add_pos_to_end(pos,i); + get_subterm_normals(arg(ineq1,i), arg(ineq2,i), chain, normals, new_pos, memo, Aproves, Bproves); + } + } + else if(get_term_type(ineq2) == LitMixed && memo.find(ineq2) == memo.end()){ + memo.insert(ineq2); + ast sub_chain = extract_rewrites(chain,pos); + if(is_true(sub_chain)) + throw "bad inequality rewriting"; + ast new_normal = make_normal_step(ineq2,ineq1,reverse_chain(sub_chain)); + normals = merge_normal_chains(normals,cons_normal(new_normal,mk_true()), Aproves, Bproves); + } + } + + ast rewrite_chain_to_normal_ineq(const ast &chain, ast &Aproves, ast &Bproves){ + ast tail, pref = get_head_chain(chain,tail,false); // pref is x=y, tail is x=y -> x'=y' + ast head = chain_last(pref); + ast ineq1 = rewrite_rhs(head); + ast ineq2 = apply_rewrite_chain(ineq1,tail); + ast nc = mk_true(); + hash_set memo; + get_subterm_normals(ineq1,ineq2,tail,nc,top_pos,memo, Aproves, Bproves); + ast itp; + if(is_rewrite_side(LitA,head)){ + itp = ineq1; + ast mc = z3_simplify(chain_side_proves(LitB,pref)); + Bproves = my_and(Bproves,mc); + } + else { + itp = make(Leq,make_int(rational(0)),make_int(rational(0))); + ast mc = z3_simplify(chain_side_proves(LitA,pref)); + Aproves = my_and(Aproves,mc); + } + return make_normal(itp,nc); + } + + /* Given a chain rewrite chain deriving not P and a rewrite chain deriving P, return an interpolant. */ + ast contra_chain(const ast &neg_chain, const ast &pos_chain){ + // equality is a special case. we use the derivation of x=y to rewrite not(x=y) to not(y=y) + if(is_equivrel_chain(pos_chain)){ + ast tail, pref = get_head_chain(neg_chain,tail); // pref is not(x=y), tail is not(x,y) -> not(x',y') + ast split[2]; split_chain(down_chain(tail),split); // rewrites from x to x' and y to y' + ast chain = split[0]; + chain = concat_rewrite_chain(chain,pos_chain); // rewrites from x to y' + chain = concat_rewrite_chain(chain,reverse_chain(split[1])); // rewrites from x to y + chain = concat_rewrite_chain(pref,chain_pos_add(0,chain_pos_add(0,chain))); // rewrites t -> not(y=y) + ast head = chain_last(pref); + if(is_rewrite_side(LitB,head)){ + ast condition = chain_conditions(LitB,chain); + return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),condition); + } + else { + ast condition = chain_conditions(LitA,chain); + return my_and(chain_conditions(LitB,chain),my_implies(condition,mk_not(chain_formulas(LitB,chain)))); + } + // ast chain = concat_rewrite_chain(neg_chain,chain_pos_add(0,chain_pos_add(0,pos_chain))); + // return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain)); + } + // otherwise, we reverse the derivation of t = P and use it to rewrite not(P) to not(t) + ast chain = concat_rewrite_chain(neg_chain,chain_pos_add(0,reverse_chain(pos_chain))); + return my_and(my_implies(chain_conditions(LitA,chain),chain_formulas(LitA,chain)),chain_conditions(LitB,chain)); + } + + ast simplify_modpon(const std::vector &args){ + ast Aproves = mk_true(), Bproves = mk_true(); + ast chain = simplify_modpon_fwd(args,Bproves); + ast Q2 = sep_cond(args[2],Bproves); + ast interp; + if(is_normal_ineq(Q2)){ // inequalities are special + ast nQ2 = rewrite_chain_to_normal_ineq(chain,Aproves,Bproves); + sum_cond_ineq(nQ2,make_int(rational(1)),Q2,Aproves,Bproves); + interp = normalize(nQ2); + } + else + interp = is_negation_chain(chain) ? contra_chain(chain,Q2) : contra_chain(Q2,chain); + return my_and(Aproves,my_implies(Bproves,interp)); + } + + + bool is_equivrel(const ast &p){ + opr o = op(p); + return o == Equal || o == Iff; + } + + struct rewrites_failed{}; + + /* Suppose p in Lang(B) and A |- p -> q and B |- q -> r. Return a z in Lang(B) such that + B |- p -> z and A |- z -> q. Collect any side conditions in "rules". */ + + ast commute_rewrites(const ast &p, const ast &q, const ast &r, ast &rules){ + if(q == r) + return p; + if(p == q) + return r; + else { + ast rew = make(Equal,q,r); + if(get_term_type(rew) == LitB){ + apply_common_rewrites(p,p,q,rules); // A rewrites must be over comon vocab + return r; + } + } + if(sym(p) != sym(q) || sym(q) != sym(r)) + throw rewrites_failed(); + int nargs = num_args(p); + if(nargs != num_args(q) || nargs != num_args(r)) + throw rewrites_failed(); + std::vector args; args.resize(nargs); + for(int i = 0; i < nargs; i++) + args[i] = commute_rewrites(arg(p,i),arg(q,i),arg(r,i),rules); + return clone(p,args); + } + + ast apply_common_rewrites(const ast &p, const ast &q, const ast &r, ast &rules){ + if(q == r) + return p; + ast rew = make(Equal,q,r); + if(term_common(rew)){ + if(p != q) + throw rewrites_failed(); + rules = my_and(rules,rew); + return r; + } + if(sym(p) != sym(q) || sym(q) != sym(r)) + return p; + int nargs = num_args(p); + if(nargs != num_args(q) || nargs != num_args(r)) + return p; + std::vector args; args.resize(nargs); + for(int i = 0; i < nargs; i++) + args[i] = apply_common_rewrites(arg(p,i),arg(q,i),arg(r,i),rules); + return clone(p,args); + } + + ast apply_all_rewrites(const ast &p, const ast &q, const ast &r){ + if(q == r) + return p; + if(p == q) + return r; + if(sym(p) != sym(q) || sym(q) != sym(r)) + throw rewrites_failed(); + int nargs = num_args(p); + if(nargs != num_args(q) || nargs != num_args(r)) + throw rewrites_failed(); + std::vector args; args.resize(nargs); + for(int i = 0; i < nargs; i++) + args[i] = apply_all_rewrites(arg(p,i),arg(q,i),arg(r,i)); + return clone(p,args); + } + + ast delta(const ast &x, const ast &y){ + if(op(x) != op(y) || (op(x) == Uninterpreted && sym(x) != sym(y)) || num_args(x) != num_args(y)) + return make(Equal,x,y); + ast res = mk_true(); + int nargs = num_args(x); + for(int i = 0; i < nargs; i++) + res = my_and(res,delta(arg(x,i),arg(y,i))); + return res; + } + + bool diff_rec(LitType t, const ast &p, const ast &q, ast &pd, ast &qd){ + if(p == q) + return false; + if(term_in_vocab(t,p) && term_in_vocab(t,q)){ + pd = p; + qd = q; + return true; + } + else { + if(sym(p) != sym(q)) return false; + int nargs = num_args(p); + if(num_args(q) != nargs) return false; + for(int i = 0; i < nargs; i++) + if(diff_rec(t,arg(p,i),arg(q,i),pd,qd)) + return true; + return false; + } + } + + void diff(LitType t, const ast &p, const ast &q, ast &pd, ast &qd){ + if(!diff_rec(t,p,q,pd,qd)) + throw cannot_simplify(); + } + + bool apply_diff_rec(LitType t, const ast &inp, const ast &p, const ast &q, ast &out){ + if(p == q) + return false; + if(term_in_vocab(t,p) && term_in_vocab(t,q)){ + if(inp != p) + throw cannot_simplify(); + out = q; + return true; + } + else { + int nargs = num_args(p); + if(sym(p) != sym(q)) throw cannot_simplify(); + if(num_args(q) != nargs) throw cannot_simplify(); + if(sym(p) != sym(inp)) throw cannot_simplify(); + if(num_args(inp) != nargs) throw cannot_simplify(); + for(int i = 0; i < nargs; i++) + if(apply_diff_rec(t,arg(inp,i),arg(p,i),arg(q,i),out)) + return true; + return false; + } + } + + ast apply_diff(LitType t, const ast &inp, const ast &p, const ast &q){ + ast out; + if(!apply_diff_rec(t,inp,p,q,out)) + throw cannot_simplify(); + return out; + } + + bool merge_A_rewrites(const ast &A1, const ast &A2, ast &merged) { + if(arg(A1,1) == arg(A2,0)){ + merged = make(op(A1),arg(A1,0),arg(A2,1)); + return true; + } + ast diff1l, diff1r, diff2l, diff2r,diffBl,diffBr; + diff(LitA,arg(A1,0),arg(A1,1),diff1l,diff1r); + diff(LitA,arg(A2,0),arg(A2,1),diff2l,diff2r); + diff(LitB,arg(A1,1),arg(A2,0),diffBl,diffBr); + if(!term_common(diff2l) && !term_common(diffBr)){ + ast A1r = apply_diff(LitB,arg(A2,1),arg(A2,0),arg(A1,1)); + merged = make(op(A1),arg(A1,0),A1r); + return true; + } + if(!term_common(diff1r) && !term_common(diffBl)){ + ast A2l = apply_diff(LitB,arg(A1,0),arg(A1,1),arg(A2,0)); + merged = make(op(A1),A2l,arg(A2,1)); + return true; + } + return false; + } + + void collect_A_rewrites(const ast &t, std::vector &res){ + if(is_true(t)) + return; + if(sym(t) == concat){ + res.push_back(arg(t,0)); + collect_A_rewrites(arg(t,0),res); + return; + } + res.push_back(t); + } + + ast concat_A_rewrites(const std::vector &rew){ + if(rew.size() == 0) + return mk_true(); + ast res = rew[0]; + for(unsigned i = 1; i < rew.size(); i++) + res = make(concat,res,rew[i]); + return res; + } + + ast merge_concat_rewrites(const ast &A1, const ast &A2){ + std::vector rew; + collect_A_rewrites(A1,rew); + int first = rew.size(), last = first; // range that might need merging + collect_A_rewrites(A2,rew); + while(first > 0 && first < (int)rew.size() && first <= last){ + ast merged; + if(merge_A_rewrites(rew[first-1],rew[first],merged)){ + rew[first] = merged; + first--; + rew.erase(rew.begin()+first); + last--; + if(first >= last) last = first+1; + } + else + first++; + } + return concat_A_rewrites(rew); + } + + ast sep_cond(const ast &t, ast &cond){ + if(op(t) == Implies){ + cond = my_and(cond,arg(t,0)); + return arg(t,1); + } + return t; + } + + + /* operations on term positions */ + + /** Finds the difference between two positions. If p1 < p2 (p1 is a + position below p2), returns -1 and sets diff to p2-p1 (the psath + from position p2 to position p1). If p2 < p1 (p2 is a position + below p1), returns 1 and sets diff to p1-p2 (the psath from + position p1 to position p2). If equal, return 0 and set diff to + top_pos. Else (if p1 and p2 are independent) returns 2 and + leaves diff unchanged. */ + + int pos_diff(const ast &p1, const ast &p2, ast &diff){ + if(p1 == top_pos && p2 != top_pos){ + diff = p2; + return 1; + } + if(p2 == top_pos && p1 != top_pos){ + diff = p1; + return -1; + } + if(p1 == top_pos && p2 == top_pos){ + diff = p1; + return 0; + } + if(arg(p1,0) == arg(p2,0)) // same argument position, recur + return pos_diff(arg(p1,1),arg(p2,1),diff); + return 2; + } + + /* return the position of pos in the argth argument */ + ast pos_add(int arg, const ast &pos){ + return make(add_pos,make_int(rational(arg)),pos); + } + + ast add_pos_to_end(const ast &pos, int i){ + if(pos == top_pos) + return pos_add(i,pos); + return make(add_pos,arg(pos,0),add_pos_to_end(arg(pos,1),i)); + } + + /* return the argument number of position, if not top */ + int pos_arg(const ast &pos){ + rational r; + if(is_numeral(arg(pos,0),r)) + return r.get_unsigned(); + throw "bad position!"; + } + + /* substitute y into position pos in x */ + ast subst_in_pos(const ast &x, const ast &pos, const ast &y){ + if(pos == top_pos) + return y; + int p = pos_arg(pos); + int nargs = num_args(x); + if(p >= 0 && p < nargs){ + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = i == p ? subst_in_pos(arg(x,i),arg(pos,1),y) : arg(x,i); + return clone(x,args); + } + throw "bad term position!"; + } + + ast diff_chain(LitType t, const ast &pos, const ast &x, const ast &y, const ast &prefix){ + int nargs = num_args(x); + if(x == y) return prefix; + if(sym(x) == sym(y) && nargs == num_args(y)){ + ast res = prefix; + for(int i = 0; i < nargs; i++) + res = diff_chain(t,pos_add(i,pos),arg(x,i),arg(y,i),res); + return res; + } + return chain_cons(prefix,make_rewrite(t,pos,mk_true(),make_equiv_rel(x,y))); + } + + /* operations on rewrites */ + ast make_rewrite(LitType t, const ast &pos, const ast &cond, const ast &equality){ +#if 0 + if(pos == top_pos && op(equality) == Iff && !is_true(arg(equality,0))) + throw "bad rewrite"; +#endif + return make(t == LitA ? rewrite_A : rewrite_B, pos, cond, equality); + } + + ast rewrite_pos(const ast &rew){ + return arg(rew,0); + } + + ast rewrite_cond(const ast &rew){ + return arg(rew,1); + } + + ast rewrite_equ(const ast &rew){ + return arg(rew,2); + } + + ast rewrite_lhs(const ast &rew){ + return arg(arg(rew,2),0); + } + + ast rewrite_rhs(const ast &rew){ + return arg(arg(rew,2),1); + } + + /* operations on rewrite chains */ + + ast chain_cons(const ast &chain, const ast &elem){ + return make(concat,chain,elem); + } + + ast chain_rest(const ast &chain){ + return arg(chain,0); + } + + ast chain_last(const ast &chain){ + return arg(chain,1); + } + + ast rewrite_update_rhs(const ast &rew, const ast &pos, const ast &new_rhs, const ast &new_cond){ + ast foo = subst_in_pos(rewrite_rhs(rew),pos,new_rhs); + ast equality = arg(rew,2); + return make(sym(rew),rewrite_pos(rew),my_and(rewrite_cond(rew),new_cond),make(op(equality),arg(equality,0),foo)); + } + + ast rewrite_update_lhs(const ast &rew, const ast &pos, const ast &new_lhs, const ast &new_cond){ + ast foo = subst_in_pos(rewrite_lhs(rew),pos,new_lhs); + ast equality = arg(rew,2); + return make(sym(rew),rewrite_pos(rew),my_and(rewrite_cond(rew),new_cond),make(op(equality),foo,arg(equality,1))); + } + + bool is_common_rewrite(const ast &rew){ + return term_common(arg(rew,2)); + } + + bool is_right_mover(const ast &rew){ + return term_common(rewrite_lhs(rew)) && !term_common(rewrite_rhs(rew)); + } + + bool is_left_mover(const ast &rew){ + return term_common(rewrite_rhs(rew)) && !term_common(rewrite_lhs(rew)); + } + + bool same_side(const ast &rew1, const ast &rew2){ + return sym(rew1) == sym(rew2); + } + + bool is_rewrite_side(LitType t, const ast &rew){ + if(t == LitA) + return sym(rew) == rewrite_A; + return sym(rew) == rewrite_B; + } + + LitType rewrite_side(const ast &rew){ + return (sym(rew) == rewrite_A) ? LitA : LitB; + } + + ast rewrite_to_formula(const ast &rew){ + return my_implies(arg(rew,1),arg(rew,2)); + } + + // make rewrite rew conditon on rewrite cond + ast rewrite_conditional(const ast &cond, const ast &rew){ + ast cf = rewrite_to_formula(cond); + return make(sym(rew),arg(rew,0),my_and(arg(rew,1),cf),arg(rew,2)); + } + + ast reverse_rewrite(const ast &rew){ + ast equ = arg(rew,2); + return make(sym(rew),arg(rew,0),arg(rew,1),make(op(equ),arg(equ,1),arg(equ,0))); + } + + ast rewrite_pos_add(int apos, const ast &rew){ + return make(sym(rew),pos_add(apos,arg(rew,0)),arg(rew,1),arg(rew,2)); + } + + ast rewrite_pos_set(const ast &pos, const ast &rew){ + return make(sym(rew),pos,arg(rew,1),arg(rew,2)); + } + + ast rewrite_up(const ast &rew){ + return make(sym(rew),arg(arg(rew,0),1),arg(rew,1),arg(rew,2)); + } + + /** Adds a rewrite to a chain of rewrites, keeping the chain in + normal form. An empty chain is represented by true.*/ + + ast add_rewrite_to_chain(const ast &chain, const ast &rewrite){ + if(is_true(chain)) + return chain_cons(chain,rewrite); + ast last = chain_last(chain); + ast rest = chain_rest(chain); + if(same_side(last,rewrite)){ + ast p1 = rewrite_pos(last); + ast p2 = rewrite_pos(rewrite); + ast diff; + switch(pos_diff(p1,p2,diff)){ + case 1: { + ast absorb = rewrite_update_rhs(last,diff,rewrite_rhs(rewrite),rewrite_cond(rewrite)); + return add_rewrite_to_chain(rest,absorb); + } + case 0: + case -1: { + ast absorb = rewrite_update_lhs(rewrite,diff,rewrite_lhs(last),rewrite_cond(last)); + return add_rewrite_to_chain(rest,absorb); + } + default: {// independent case + bool rm = is_right_mover(last); + bool lm = is_left_mover(rewrite); + if((lm && !rm) || (rm && !lm)) + return chain_swap(rest,last,rewrite); + } + } + } + else { + if(is_left_mover(rewrite)){ + if(is_common_rewrite(last)) + return add_rewrite_to_chain(chain_cons(rest,flip_rewrite(last)),rewrite); + if(!is_left_mover(last)) + return chain_swap(rest,last,rewrite); + } + if(is_right_mover(last)){ + if(is_common_rewrite(rewrite)) + return add_rewrite_to_chain(chain,flip_rewrite(rewrite)); + if(!is_right_mover(rewrite)) + return chain_swap(rest,last,rewrite); + } + } + return chain_cons(chain,rewrite); + } + + ast chain_swap(const ast &rest, const ast &last, const ast &rewrite){ + return chain_cons(add_rewrite_to_chain(rest,rewrite),last); + } + + ast flip_rewrite(const ast &rew){ + symb flip_sym = (sym(rew) == rewrite_A) ? rewrite_B : rewrite_A; + ast cf = rewrite_to_formula(rew); + return make(flip_sym,arg(rew,0),my_implies(arg(rew,1),cf),arg(rew,2)); + } + + /** concatenates two rewrite chains, keeping result in normal form. */ + + ast concat_rewrite_chain(const ast &chain1, const ast &chain2){ + if(is_true(chain2)) return chain1; + if(is_true(chain1)) return chain2; + ast foo = concat_rewrite_chain(chain1,chain_rest(chain2)); + return add_rewrite_to_chain(foo,chain_last(chain2)); + } + + /** reverse a chain of rewrites */ + + ast reverse_chain_rec(const ast &chain, const ast &prefix){ + if(is_true(chain)) + return prefix; + ast last = reverse_rewrite(chain_last(chain)); + ast rest = chain_rest(chain); + return reverse_chain_rec(rest,chain_cons(prefix,last)); + } + + ast reverse_chain(const ast &chain){ + return reverse_chain_rec(chain,mk_true()); + } + + bool is_equivrel_chain(const ast &chain){ + if(is_true(chain)) + return true; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + if(is_true(rest)) + return !is_true(rewrite_lhs(last)); + return is_equivrel_chain(rest); + } + + bool is_negation_chain(const ast &chain){ + if(is_true(chain)) + return false; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + if(is_true(rest)) + return op(rewrite_rhs(last)) == Not; + return is_negation_chain(rest); + } + + // split a rewrite chain into head and tail at last top-level rewrite + ast get_head_chain(const ast &chain, ast &tail, bool is_not = true){ + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast pos = rewrite_pos(last); + if(pos == top_pos || (is_not && arg(pos,1) == top_pos)){ + tail = mk_true(); + return chain; + } + if(is_true(rest)) + throw "bad rewrite chain"; + ast head = get_head_chain(rest,tail,is_not); + tail = chain_cons(tail,last); + return head; + } + + struct cannot_split {}; + + /** Split a chain of rewrites two chains, operating on positions 0 and 1. + Fail if any rewrite in the chain operates on top position. */ + void split_chain_rec(const ast &chain, ast *res){ + if(is_true(chain)) + return; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + split_chain_rec(rest,res); + ast pos = rewrite_pos(last); + if(pos == top_pos){ + if(rewrite_lhs(last) == rewrite_rhs(last)) + return; // skip if it's a noop + throw cannot_split(); + } + int arg = pos_arg(pos); + if(arg<0 || arg > 1) + throw cannot_split(); + res[arg] = chain_cons(res[arg],rewrite_up(last)); + } + + void split_chain(const ast &chain, ast *res){ + res[0] = res[1] = mk_true(); + split_chain_rec(chain,res); + } + + ast extract_rewrites(const ast &chain, const ast &pos){ + if(is_true(chain)) + return chain; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast new_rest = extract_rewrites(rest,pos); + ast p1 = rewrite_pos(last); + ast diff; + switch(pos_diff(p1,pos,diff)){ + case -1: { + ast new_last = rewrite_pos_set(diff, last); + return chain_cons(new_rest,new_last); + } + case 1: + if(rewrite_lhs(last) != rewrite_rhs(last)) + throw "bad rewrite chain"; + break; + default:; + } + return new_rest; + } + + ast down_chain(const ast &chain){ + ast split[2]; + split_chain(chain,split); + return split[0]; + } + + ast chain_conditions(LitType t, const ast &chain){ + if(is_true(chain)) + return mk_true(); + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast cond = chain_conditions(t,rest); + if(is_rewrite_side(t,last)) + cond = my_and(cond,rewrite_cond(last)); + return cond; + } + + ast chain_formulas(LitType t, const ast &chain){ + if(is_true(chain)) + return mk_true(); + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast cond = chain_formulas(t,rest); + if(is_rewrite_side(t,last)) + cond = my_and(cond,rewrite_equ(last)); + return cond; + } + + + ast chain_ineqs(opr comp_op, LitType t, const ast &chain, const ast &lhs, const ast &rhs){ + if(is_true(chain)){ + if(lhs != rhs) + throw "bad ineq inference"; + return make(Leq,make_int(rational(0)),make_int(rational(0))); + } + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast mid = subst_in_pos(rhs,rewrite_pos(last),rewrite_lhs(last)); + ast cond = chain_ineqs(comp_op,t,rest,lhs,mid); + if(is_rewrite_side(t,last)){ + ast diff; + if(comp_op == Leq) diff = make(Sub,rhs,mid); + else diff = make(Sub,mid,rhs); + ast foo = z3_simplify(make(Leq,make_int("0"),diff)); + if(is_true(cond)) + cond = foo; + else { + linear_comb(cond,make_int(rational(1)),foo); + cond = simplify_ineq(cond); + } + } + return cond; + } + + ast ineq_to_lhs(const ast &ineq){ + ast s = make(Leq,make_int(rational(0)),make_int(rational(0))); + linear_comb(s,make_int(rational(1)),ineq); + return simplify_ineq(s); + } + + void eq_from_ineq(const ast &ineq, ast &lhs, ast &rhs){ + // ast s = ineq_to_lhs(ineq); + // ast srhs = arg(s,1); + ast srhs = arg(ineq,0); + if(op(srhs) == Plus && num_args(srhs) == 2 && arg(ineq,1) == make_int(rational(0))){ + lhs = arg(srhs,0); + rhs = arg(srhs,1); + // if(op(lhs) == Times) + // std::swap(lhs,rhs); + if(op(rhs) == Times){ + rhs = arg(rhs,1); + // if(op(ineq) == Leq) + // std::swap(lhs,rhs); + return; + } + } + if(op(ineq) == Leq || op(ineq) == Geq){ + lhs = srhs; + rhs = arg(ineq,1); + return; + } + throw "bad ineq"; + } + + ast chain_pos_add(int arg, const ast &chain){ + if(is_true(chain)) + return mk_true(); + ast last = rewrite_pos_add(arg,chain_last(chain)); + ast rest = chain_pos_add(arg,chain_rest(chain)); + return chain_cons(rest,last); + } + + ast apply_rewrite_chain(const ast &t, const ast &chain){ + if(is_true(chain)) + return t; + ast last = chain_last(chain); + ast rest = chain_rest(chain); + ast mid = apply_rewrite_chain(t,rest); + ast res = subst_in_pos(mid,rewrite_pos(last),rewrite_rhs(last)); + return res; + } + + ast drop_rewrites(LitType t, const ast &chain, ast &remainder){ + if(!is_true(chain)){ + ast last = chain_last(chain); + ast rest = chain_rest(chain); + if(is_rewrite_side(t,last)){ + ast res = drop_rewrites(t,rest,remainder); + remainder = chain_cons(remainder,last); + return res; + } + } + remainder = mk_true(); + return chain; + } + + // Normalization chains + + ast cons_normal(const ast &first, const ast &rest){ + return make(normal_chain,first,rest); + } + + ast normal_first(const ast &t){ + return arg(t,0); + } + + ast normal_rest(const ast &t){ + return arg(t,1); + } + + ast normal_lhs(const ast &t){ + return arg(arg(t,0),0); + } + + ast normal_rhs(const ast &t){ + return arg(arg(t,0),1); + } + + ast normal_proof(const ast &t){ + return arg(t,1); + } + + ast make_normal_step(const ast &lhs, const ast &rhs, const ast &proof){ + return make(normal_step,make_equiv(lhs,rhs),proof); + } + + ast make_normal(const ast &ineq, const ast &nrml){ + if(!is_ineq(ineq)) + throw "what?"; + return make(normal,ineq,nrml); + } + + ast fix_normal(const ast &lhs, const ast &rhs, const ast &proof){ + LitType rhst = get_term_type(rhs); + if(rhst != LitMixed || ast_id(lhs) < ast_id(rhs)) + return make_normal_step(lhs,rhs,proof); + else + return make_normal_step(rhs,lhs,reverse_chain(proof)); + } + + ast chain_side_proves(LitType side, const ast &chain){ + LitType other_side = side == LitA ? LitB : LitA; + return my_and(chain_conditions(other_side,chain),my_implies(chain_conditions(side,chain),chain_formulas(side,chain))); + } + + // Merge two normalization chains + ast merge_normal_chains_rec(const ast &chain1, const ast &chain2, hash_map &trans, ast &Aproves, ast &Bproves){ + if(is_true(chain1)) + return chain2; + if(is_true(chain2)) + return chain1; + ast f1 = normal_first(chain1); + ast f2 = normal_first(chain2); + ast lhs1 = normal_lhs(f1); + ast lhs2 = normal_lhs(f2); + int id1 = ast_id(lhs1); + int id2 = ast_id(lhs2); + if(id1 < id2) return cons_normal(f1,merge_normal_chains_rec(normal_rest(chain1),chain2,trans,Aproves,Bproves)); + if(id2 < id1) return cons_normal(f2,merge_normal_chains_rec(chain1,normal_rest(chain2),trans,Aproves,Bproves)); + ast rhs1 = normal_rhs(f1); + ast rhs2 = normal_rhs(f2); + LitType t1 = get_term_type(rhs1); + LitType t2 = get_term_type(rhs2); + int tid1 = ast_id(rhs1); + int tid2 = ast_id(rhs2); + ast pf1 = normal_proof(f1); + ast pf2 = normal_proof(f2); + ast new_normal; + if(t1 == LitMixed && (t2 != LitMixed || tid2 > tid1)){ + ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); + new_normal = f2; + trans[rhs1] = make_normal_step(rhs1,rhs2,new_proof); + } + else if(t2 == LitMixed && (t1 != LitMixed || tid1 > tid2)) + return merge_normal_chains_rec(chain2,chain1,trans,Aproves,Bproves); + else if(t1 == LitA && t2 == LitB){ + ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); + ast Bproof, Aproof = drop_rewrites(LitB,new_proof,Bproof); + ast mcA = chain_side_proves(LitB,Aproof); + Bproves = my_and(Bproves,mcA); + ast mcB = chain_side_proves(LitA,Bproof); + Aproves = my_and(Aproves,mcB); + ast rep = apply_rewrite_chain(rhs1,Aproof); + new_proof = concat_rewrite_chain(pf1,Aproof); + new_normal = make_normal_step(rhs1,rep,new_proof); + } + else if(t1 == LitA && t2 == LitB) + return merge_normal_chains_rec(chain2,chain1,trans,Aproves,Bproves); + else if(t1 == LitA) { + ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); + ast mc = chain_side_proves(LitB,new_proof); + Bproves = my_and(Bproves,mc); + new_normal = f1; // choice is arbitrary + } + else { /* t1 = t2 = LitB */ + ast new_proof = concat_rewrite_chain(reverse_chain(pf1),pf2); + ast mc = chain_side_proves(LitA,new_proof); + Aproves = my_and(Aproves,mc); + new_normal = f1; // choice is arbitrary + } + return cons_normal(new_normal,merge_normal_chains_rec(normal_rest(chain1),normal_rest(chain2),trans,Aproves,Bproves)); + } + + ast trans_normal_chain(const ast &chain, hash_map &trans){ + if(is_true(chain)) + return chain; + ast f = normal_first(chain); + ast r = normal_rest(chain); + ast rhs = normal_rhs(f); + hash_map::iterator it = trans.find(rhs); + ast new_normal; + if(it != trans.end()){ + const ast &f2 = it->second; + ast pf = concat_rewrite_chain(normal_proof(f),normal_proof(f2)); + new_normal = make_normal_step(normal_lhs(f),normal_rhs(f2),pf); + } + else + new_normal = f; + return cons_normal(new_normal,trans_normal_chain(r,trans)); + } + + ast merge_normal_chains(const ast &chain1, const ast &chain2, ast &Aproves, ast &Bproves){ + hash_map trans; + ast res = merge_normal_chains_rec(chain1,chain2,trans,Aproves,Bproves); + res = trans_normal_chain(res,trans); + return res; + } + + bool destruct_cond_ineq(ast t, ast &Aproves, ast &Bproves, ast&ineq){ + if(op(t) == And){ + Aproves = arg(t,0); + t = arg(t,1); + } + else + Aproves = mk_true(); + if(op(t) == Implies){ + Bproves = arg(t,0); + t = arg(t,1); + } + else + Bproves = mk_true(); + if(is_normal_ineq(t)){ + ineq = t; + return true; + } + return false; + } + + ast cons_cond_ineq(const ast &Aproves, const ast &Bproves, const ast &ineq){ + return my_and(Aproves,my_implies(Bproves,ineq)); + } + + ast normalize(const ast &ct){ + ast Aproves,Bproves,t; + if(!destruct_cond_ineq(ct,Aproves,Bproves,t)) + return ct; + if(sym(t) != normal) + return ct; + ast chain = arg(t,1); + hash_map map; + for(ast c = chain; !is_true(c); c = normal_rest(c)){ + ast first = normal_first(c); + ast lhs = normal_lhs(first); + ast rhs = normal_rhs(first); + map[lhs] = rhs; + } + ast res = subst(map,arg(t,0)); + return cons_cond_ineq(Aproves,Bproves,res); + } + + /** Make an assumption node. The given clause is assumed in the given frame. */ + virtual node make_assumption(int frame, const std::vector &assumption){ + if(!weak){ + if(pv->in_range(frame,rng)){ + std::vector itp_clause; + for(unsigned i = 0; i < assumption.size(); i++) + if(get_term_type(assumption[i]) != LitA) + itp_clause.push_back(assumption[i]); + ast res = my_or(itp_clause); + return res; + } + else { + return mk_true(); + } + } + else { + if(pv->in_range(frame,rng)){ + return mk_false(); + } + else { + std::vector itp_clause; + for(unsigned i = 0; i < assumption.size(); i++) + if(get_term_type(assumption[i]) != LitB) + itp_clause.push_back(assumption[i]); + ast res = my_or(itp_clause); + return mk_not(res); + } + } + } + + ast make_local_rewrite(LitType t, const ast &p){ + ast rew = is_equivrel(p) ? p : make(Iff,mk_true(),p); +#if 0 + if(op(rew) == Iff && !is_true(arg(rew,0))) + return diff_chain(t,top_pos,arg(rew,0),arg(rew,1), mk_true()); +#endif + return chain_cons(mk_true(),make_rewrite(t, top_pos, mk_true(), rew)); + } + + ast triv_interp(const symb &rule, const std::vector &premises, int mask_in){ + std::vector ps; ps.resize(premises.size()); + std::vector conjs; + int mask = 0; + for(unsigned i = 0; i < ps.size(); i++){ + ast p = premises[i]; + LitType t = get_term_type(p); + switch(t){ + case LitA: + case LitB: + ps[i] = make_local_rewrite(t,p); + break; + default: + ps[i] = get_placeholder(p); // can only prove consequent! + if(mask_in & (1 << i)) + mask |= (1 << conjs.size()); + conjs.push_back(p); + } + } + ast ref = make(rule,ps); + ast res = make_contra_node(ref,conjs,mask); + return res; + } + + ast triv_interp(const symb &rule, const ast &p0, const ast &p1, const ast &p2, int mask){ + std::vector ps; ps.resize(3); + ps[0] = p0; + ps[1] = p1; + ps[2] = p2; + return triv_interp(rule,ps,mask); + } + + /** Make a modus-ponens node. This takes derivations of |- x + and |- x = y and produces |- y */ + + virtual node make_mp(const ast &p_eq_q, const ast &prem1, const ast &prem2){ + + /* Interpolate the axiom p, p=q -> q */ + ast p = arg(p_eq_q,0); + ast q = arg(p_eq_q,1); + ast itp; + if(get_term_type(p_eq_q) == LitMixed){ + int mask = 1 << 2; + if(op(p) == Not && is_equivrel(arg(p,0))) + mask |= 1; // we may need to run this rule backward if first premise is negative equality + itp = triv_interp(modpon,p,p_eq_q,mk_not(q),mask); + } + else { + if(get_term_type(p) == LitA){ + if(get_term_type(q) == LitA) + itp = mk_false(); + else { + if(get_term_type(p_eq_q) == LitA) + itp = q; + else + throw proof_error(); + } + } + else { + if(get_term_type(q) == LitA){ + if(get_term_type(make(Equal,p,q)) == LitA) + itp = mk_not(p); + else + throw proof_error(); + } + else + itp = mk_true(); + } + } + + /* Resolve it with the premises */ + std::vector conc; conc.push_back(q); conc.push_back(mk_not(p_eq_q)); + itp = make_resolution(p,conc,itp,prem1); + conc.pop_back(); + itp = make_resolution(p_eq_q,conc,itp,prem2); + return itp; + } + + ast capture_localization(ast e){ + // #define CAPTURE_LOCALIZATION +#ifdef CAPTURE_LOCALIZATION + for(int i = localization_vars.size() - 1; i >= 0; i--){ + LocVar &lv = localization_vars[i]; + if(occurs_in(lv.var,e)){ + symb q = (pv->in_range(lv.frame,rng)) ? sexists : sforall; + e = make(q,make(Equal,lv.var,lv.term),e); // use Equal because it is polymorphic + } + } +#endif + return e; + } + + /** Make an axiom node. The conclusion must be an instance of an axiom. */ + virtual node make_axiom(const std::vector &conclusion, prover::range frng){ + int nargs = conclusion.size(); + std::vector largs(nargs); + std::vector eqs; + std::vector pfs; + + for(int i = 0; i < nargs; i++){ + ast argpf; + ast lit = conclusion[i]; + largs[i] = localize_term(lit,frng,argpf); + frng = pv->range_glb(frng,pv->ast_scope(largs[i])); + if(largs[i] != lit){ + eqs.push_back(make_equiv(largs[i],lit)); + pfs.push_back(argpf); + } + } + + int frame = pv->range_max(frng); + ast itp = make_assumption(frame,largs); + + for(unsigned i = 0; i < eqs.size(); i++) + itp = make_mp(eqs[i],itp,pfs[i]); + return capture_localization(itp); + } + + virtual node make_axiom(const std::vector &conclusion){ + return make_axiom(conclusion,pv->range_full()); + } + + /** Make a Contra node. This rule takes a derivation of the form + Gamma |- False and produces |- \/~Gamma. */ + + virtual node make_contra(node prem, const std::vector &conclusion){ + return prem; + } + + /** Make hypothesis. Creates a node of the form P |- P. */ + + virtual node make_hypothesis(const ast &P){ + if(is_not(P)) + return make_hypothesis(arg(P,0)); + switch(get_term_type(P)){ + case LitA: + return mk_false(); + case LitB: + return mk_true(); + default: // mixed hypothesis + switch(op(P)){ + case Geq: + case Leq: + case Gt: + case Lt: { + ast zleqz = make(Leq,make_int("0"),make_int("0")); + ast fark1 = make(sum,zleqz,make_int("1"),get_placeholder(P)); + ast fark2 = make(sum,fark1,make_int("1"),get_placeholder(mk_not(P))); + ast res = make(And,make(contra,fark2,mk_false()), + make(contra,get_placeholder(mk_not(P)),P), + make(contra,get_placeholder(P),mk_not(P))); + return res; + } + default: { + ast em = make(exmid,P,get_placeholder(P),get_placeholder(mk_not(P))); + ast res = make(And,make(contra,em,mk_false()), + make(contra,get_placeholder(mk_not(P)),P), + make(contra,get_placeholder(P),mk_not(P))); + return res; + } + } + } + } + + /** Make a Reflexivity node. This rule produces |- x = x */ + + virtual node make_reflexivity(ast con){ + throw proof_error(); +} + + /** Make a Symmetry node. This takes a derivation of |- x = y and + produces | y = x. Ditto for ~(x=y) */ + + virtual node make_symmetry(ast con, const ast &premcon, node prem){ +#if 0 + ast x = arg(con,0); + ast y = arg(con,1); + ast p = make(op(con),y,x); +#endif + if(get_term_type(con) != LitMixed) + return prem; // symmetry shmymmetry... + ast em = make(exmid,con,make(symm,get_placeholder(premcon)),get_placeholder(mk_not(con))); + ast itp = make(And,make(contra,em,mk_false()), + make(contra,make(symm,get_placeholder(mk_not(con))),premcon), + make(contra,make(symm,get_placeholder(premcon)),mk_not(con))); + + std::vector conc; conc.push_back(con); + itp = make_resolution(premcon,conc,itp,prem); + return itp; + } + + ast make_equiv_rel(const ast &x, const ast &y){ + if(is_bool_type(get_type(x))) + return make(Iff,x,y); + return make(Equal,x,y); + } + + /** Make a transitivity node. This takes derivations of |- x = y + and |- y = z produces | x = z */ + + virtual node make_transitivity(const ast &x, const ast &y, const ast &z, node prem1, node prem2){ + + /* Interpolate the axiom x=y,y=z,-> x=z */ + ast p = make_equiv_rel(x,y); + ast q = make_equiv_rel(y,z); + ast r = make_equiv_rel(x,z); + ast equiv = make(Iff,p,r); + ast itp; + + itp = make_congruence(q,equiv,prem2); + itp = make_mp(equiv,prem1,itp); + + return itp; + + } + + /** Make a congruence node. This takes derivations of |- x_i = y_i + and produces |- f(x_1,...,x_n) = f(y_1,...,y_n) */ + + virtual node make_congruence(const ast &p, const ast &con, const ast &prem1){ + ast x = arg(p,0), y = arg(p,1); + ast itp; + LitType con_t = get_term_type(con); + if(get_term_type(p) == LitA){ + if(con_t == LitA) + itp = mk_false(); + else if(con_t == LitB) + itp = p; + else + itp = make_mixed_congruence(x, y, p, con, prem1); + } + else { + if(con_t == LitA) + itp = mk_not(p); + else{ + if(con_t == LitB) + itp = mk_true(); + else + itp = make_mixed_congruence(x, y, p, con, prem1); + } + } + std::vector conc; conc.push_back(con); + itp = make_resolution(p,conc,itp,prem1); + return itp; + } + + int find_congruence_position(const ast &p, const ast &con){ + // find the argument position of x and y + const ast &x = arg(p,0); + const ast &y = arg(p,1); + int nargs = num_args(arg(con,0)); + for(int i = 0; i < nargs; i++) + if(x == arg(arg(con,0),i) && y == arg(arg(con,1),i)) + return i; + throw proof_error(); + } + + /** Make a congruence node. This takes derivations of |- x_i1 = y_i1, |- x_i2 = y_i2,... + and produces |- f(...x_i1...x_i2...) = f(...y_i1...y_i2...) */ + + node make_congruence(const std::vector &p, const ast &con, const std::vector &prems){ + if(p.size() == 0) + throw proof_error(); + if(p.size() == 1) + return make_congruence(p[0],con,prems[0]); + ast thing = con; + ast res = mk_true(); + for(unsigned i = 0; i < p.size(); i++){ + int pos = find_congruence_position(p[i],thing); + ast next = subst_in_arg_pos(pos,arg(p[i],1),arg(thing,0)); + ast goal = make(op(thing),arg(thing,0),next); + ast equa = make_congruence(p[i],goal,prems[i]); + if(i == 0) + res = equa; + else { + ast trace = make(op(con),arg(con,0),arg(thing,0)); + ast equiv = make(Iff,trace,make(op(trace),arg(trace,0),next)); + ast foo = make_congruence(goal,equiv,equa); + res = make_mp(equiv,res,foo); + } + thing = make(op(thing),next,arg(thing,1)); + } + return res; + } + + /* Interpolate a mixed congruence axiom. */ + + virtual ast make_mixed_congruence(const ast &x, const ast &y, const ast &p, const ast &con, const ast &prem1){ + ast foo = p; + std::vector conjs; + LitType t = get_term_type(foo); + switch(t){ + case LitA: + case LitB: + foo = make_local_rewrite(t,foo); + break; + case LitMixed: + conjs.push_back(foo); + foo = get_placeholder(foo); + } + // find the argument position of x and y + int pos = -1; + int nargs = num_args(arg(con,0)); + for(int i = 0; i < nargs; i++) + if(x == arg(arg(con,0),i) && y == arg(arg(con,1),i)) + pos = i; + if(pos == -1) + throw proof_error(); + ast bar = make(cong,foo,make_int(rational(pos)),get_placeholder(mk_not(con))); + conjs.push_back(mk_not(con)); + return make_contra_node(bar,conjs); + } + + ast subst_in_arg_pos(int pos, ast term, ast app){ + std::vector args; + get_args(app,args); + args[pos] = term; + return clone(app,args); + } + + /** Make a farkas proof node. */ + + virtual node make_farkas(ast con, const std::vector &prems, const std::vector &prem_cons, + const std::vector &coeffs){ + + /* Compute the interpolant for the clause */ + + ast zero = make_int("0"); + std::vector conjs; + ast thing = make(Leq,zero,zero); + for(unsigned i = 0; i < prem_cons.size(); i++){ + const ast &lit = prem_cons[i]; + if(get_term_type(lit) == LitA) + linear_comb(thing,coeffs[i],lit); + } + thing = simplify_ineq(thing); + for(unsigned i = 0; i < prem_cons.size(); i++){ + const ast &lit = prem_cons[i]; + if(get_term_type(lit) == LitMixed){ + thing = make(sum,thing,coeffs[i],get_placeholder(lit)); + conjs.push_back(lit); + } + } + thing = make_contra_node(thing,conjs); + + /* Resolve it with the premises */ + std::vector conc; conc.resize(prem_cons.size()); + for(unsigned i = 0; i < prem_cons.size(); i++) + conc[prem_cons.size()-i-1] = prem_cons[i]; + for(unsigned i = 0; i < prem_cons.size(); i++){ + thing = make_resolution(prem_cons[i],conc,thing,prems[i]); + conc.pop_back(); + } + return thing; + } + + /** Set P to P + cQ, where P and Q are linear inequalities. Assumes P is 0 <= y or 0 < y. */ + + void linear_comb(ast &P, const ast &c, const ast &Q){ + ast Qrhs; + bool strict = op(P) == Lt; + if(is_not(Q)){ + ast nQ = arg(Q,0); + switch(op(nQ)){ + case Gt: + Qrhs = make(Sub,arg(nQ,1),arg(nQ,0)); + break; + case Lt: + Qrhs = make(Sub,arg(nQ,0),arg(nQ,1)); + break; + case Geq: + Qrhs = make(Sub,arg(nQ,1),arg(nQ,0)); + strict = true; + break; + case Leq: + Qrhs = make(Sub,arg(nQ,0),arg(nQ,1)); + strict = true; + break; + default: + throw proof_error(); + } + } + else { + switch(op(Q)){ + case Leq: + Qrhs = make(Sub,arg(Q,1),arg(Q,0)); + break; + case Geq: + Qrhs = make(Sub,arg(Q,0),arg(Q,1)); + break; + case Lt: + Qrhs = make(Sub,arg(Q,1),arg(Q,0)); + strict = true; + break; + case Gt: + Qrhs = make(Sub,arg(Q,0),arg(Q,1)); + strict = true; + break; + default: + throw proof_error(); + } + } + Qrhs = make(Times,c,Qrhs); + if(strict) + P = make(Lt,arg(P,0),make(Plus,arg(P,1),Qrhs)); + else + P = make(Leq,arg(P,0),make(Plus,arg(P,1),Qrhs)); + } + + /* Make an axiom instance of the form |- x<=y, y<= x -> x =y */ + virtual node make_leq2eq(ast x, ast y, const ast &xleqy, const ast &yleqx){ + ast con = make(Equal,x,y); + ast itp; + switch(get_term_type(con)){ + case LitA: + itp = mk_false(); + break; + case LitB: + itp = mk_true(); + break; + default: { // mixed equality + if(get_term_type(x) == LitMixed || get_term_type(y) == LitMixed) + std::cerr << "WARNING: mixed term in leq2eq\n"; + std::vector conjs; conjs.resize(3); + conjs[0] = mk_not(con); + conjs[1] = xleqy; + conjs[2] = yleqx; + itp = make_contra_node(make(leq2eq, + get_placeholder(mk_not(con)), + get_placeholder(xleqy), + get_placeholder(yleqx)), + conjs,1); + } + } + return itp; + } + + /* Make an axiom instance of the form |- x = y -> x <= y */ + virtual node make_eq2leq(ast x, ast y, const ast &xleqy){ + ast itp; + switch(get_term_type(xleqy)){ + case LitA: + itp = mk_false(); + break; + case LitB: + itp = mk_true(); + break; + default: { // mixed equality + std::vector conjs; conjs.resize(2); + conjs[0] = make(Equal,x,y); + conjs[1] = mk_not(xleqy); + itp = make(eq2leq,get_placeholder(conjs[0]),get_placeholder(conjs[1])); + itp = make_contra_node(itp,conjs,2); + } + } + return itp; + } + + + /* Make an inference of the form t <= c |- t/d <= floor(c/d) where t + is an affine term divisble by d and c is an integer constant */ + virtual node make_cut_rule(const ast &tleqc, const ast &d, const ast &con, const ast &prem){ + ast itp = mk_false(); + switch(get_term_type(con)){ + case LitA: + itp = mk_false(); + break; + case LitB: + itp = mk_true(); + break; + default: { + std::vector conjs; conjs.resize(2); + conjs[0] = tleqc; + conjs[1] = mk_not(con); + itp = make(sum,get_placeholder(conjs[0]),d,get_placeholder(conjs[1])); + itp = make_contra_node(itp,conjs); + } + } + std::vector conc; conc.push_back(con); + itp = make_resolution(tleqc,conc,itp,prem); + return itp; + } + + + + // create a fresh variable for localization + ast fresh_localization_var(const ast &term, int frame){ + std::ostringstream s; + s << "%" << (localization_vars.size()); + ast var = make_var(s.str().c_str(),get_type(term)); + pv->sym_range(sym(var)) = pv->range_full(); // make this variable global + localization_vars.push_back(LocVar(var,term,frame)); + return var; + } + + struct LocVar { // localization vars + ast var; // a fresh variable + ast term; // term it represents + int frame; // frame in which it's defined + LocVar(ast v, ast t, int f){var=v;term=t;frame=f;} + }; + + std::vector localization_vars; // localization vars in order of creation + hash_map localization_map; // maps terms to their localization vars + hash_map localization_pf_map; // maps terms to proofs of their localizations + + /* "localize" a term e to a given frame range, creating new symbols to + represent non-local subterms. This returns the localized version e_l, + as well as a proof thet e_l = l. + */ + + ast make_refl(const ast &e){ + if(get_term_type(e) == LitA) + return mk_false(); + return mk_true(); // TODO: is this right? + } + + + ast make_equiv(const ast &x, const ast &y){ + if(get_type(x) == bool_type()) + return make(Iff,x,y); + else + return make(Equal,x,y); + } + + ast localize_term(ast e, const prover::range &rng, ast &pf){ + ast orig_e = e; + pf = make_refl(e); // proof that e = e + + prover::range erng = pv->ast_scope(e); + if(!(erng.lo > erng.hi) && pv->ranges_intersect(pv->ast_scope(e),rng)){ + return e; // this term occurs in range, so it's O.K. + } + + hash_map::iterator it = localization_map.find(e); + if(it != localization_map.end()){ + pf = localization_pf_map[e]; + e = it->second; + } + + else { + // if it is non-local, we must first localize the arguments to + // the range of its function symbol + + int nargs = num_args(e); + if(nargs > 0 /* && (!is_local(e) || flo <= hi || fhi >= lo) */){ + prover::range frng = rng; + opr o = op(e); + if(o == Uninterpreted){ + symb f = sym(e); + prover::range srng = pv->sym_range(f); + if(pv->ranges_intersect(srng,rng)) // localize to desired range if possible + frng = pv->range_glb(srng,rng); + else + frng = srng; // this term will be localized + } + else if(o == Plus || o == Times){ // don't want bound variables inside arith ops + frng = erng; // this term will be localized + } + std::vector largs(nargs); + std::vector eqs; + std::vector pfs; + for(int i = 0; i < nargs; i++){ + ast argpf; + largs[i] = localize_term(arg(e,i),frng,argpf); + frng = pv->range_glb(frng,pv->ast_scope(largs[i])); + if(largs[i] != arg(e,i)){ + eqs.push_back(make_equiv(largs[i],arg(e,i))); + pfs.push_back(argpf); + } + } + + e = clone(e,largs); + if(pfs.size()) + pf = make_congruence(eqs,make_equiv(e,orig_e),pfs); + // assert(is_local(e)); + } + + localization_pf_map[orig_e] = pf; + localization_map[orig_e] = e; + } + + if(pv->ranges_intersect(pv->ast_scope(e),rng)) + return e; // this term occurs in range, so it's O.K. + + if(is_array_type(get_type(e))) + throw "help!"; + + // choose a frame for the constraint that is close to range + int frame = pv->range_near(pv->ast_scope(e),rng); + + ast new_var = fresh_localization_var(e,frame); + localization_map[orig_e] = new_var; + std::vector foo; foo.push_back(make_equiv(new_var,e)); + ast bar = make_assumption(frame,foo); + pf = make_transitivity(new_var,e,orig_e,bar,pf); + localization_pf_map[orig_e] = pf; + return new_var; + } + + ast delete_quant(hash_map &memo, const ast &v, const ast &e){ + std::pair foo(e,ast()); + std::pair::iterator,bool> bar = memo.insert(foo); + ast &res = bar.first->second; + if(bar.second){ + opr o = op(e); + switch(o){ + case Or: + case And: + case Implies: { + unsigned nargs = num_args(e); + std::vector args; args.resize(nargs); + for(unsigned i = 0; i < nargs; i++) + args[i] = delete_quant(memo, v, arg(e,i)); + res = make(o,args); + break; + } + case Uninterpreted: { + symb s = sym(e); + ast w = arg(arg(e,0),0); + if(s == sforall || s == sexists){ + res = delete_quant(memo,v,arg(e,1)); + if(w != v) + res = make(s,w,res); + break; + } + } + default: + res = e; + } + } + return res; + } + + ast insert_quants(hash_map &memo, const ast &e){ + std::pair foo(e,ast()); + std::pair::iterator,bool> bar = memo.insert(foo); + ast &res = bar.first->second; + if(bar.second){ + opr o = op(e); + switch(o){ + case Or: + case And: + case Implies: { + unsigned nargs = num_args(e); + std::vector args; args.resize(nargs); + for(unsigned i = 0; i < nargs; i++) + args[i] = insert_quants(memo, arg(e,i)); + res = make(o,args); + break; + } + case Uninterpreted: { + symb s = sym(e); + if(s == sforall || s == sexists){ + opr q = (s == sforall) ? Forall : Exists; + ast v = arg(arg(e,0),0); + hash_map dmemo; + ast body = delete_quant(dmemo,v,arg(e,1)); + body = insert_quants(memo,body); + res = apply_quant(q,v,body); + break; + } + } + default: + res = e; + } + } + return res; + } + + ast add_quants(ast e){ +#ifdef CAPTURE_LOCALIZATION + if(!localization_vars.empty()){ + hash_map memo; + e = insert_quants(memo,e); + } +#else + for(int i = localization_vars.size() - 1; i >= 0; i--){ + LocVar &lv = localization_vars[i]; + opr quantifier = (pv->in_range(lv.frame,rng)) ? Exists : Forall; + e = apply_quant(quantifier,lv.var,e); + } +#endif + return e; + } + + node make_resolution(ast pivot, node premise1, node premise2) { + std::vector lits; + return make_resolution(pivot,lits,premise1,premise2); + } + + /* Return an interpolant from a proof of false */ + ast interpolate(const node &pf){ + // proof of false must be a formula, with quantified symbols +#ifndef BOGUS_QUANTS + return add_quants(z3_simplify(pf)); +#else + return z3_simplify(pf); +#endif + } + + ast resolve_with_quantifier(const ast &pivot1, const ast &conj1, + const ast &pivot2, const ast &conj2){ + if(is_not(arg(pivot1,1))) + return resolve_with_quantifier(pivot2,conj2,pivot1,conj1); + ast eqpf; + ast P = arg(pivot1,1); + ast Ploc = localize_term(P, rng, eqpf); + ast pPloc = make_hypothesis(Ploc); + ast pP = make_mp(make(Iff,Ploc,P),pPloc,eqpf); + ast rP = make_resolution(P,conj1,pP); + ast nP = mk_not(P); + ast nPloc = mk_not(Ploc); + ast neqpf = make_congruence(make(Iff,Ploc,P),make(Iff,nPloc,nP),eqpf); + ast npPloc = make_hypothesis(nPloc); + ast npP = make_mp(make(Iff,nPloc,nP),npPloc,neqpf); + ast nrP = make_resolution(nP,conj2,npP); + ast res = make_resolution(Ploc,rP,nrP); + return capture_localization(res); + } + + ast get_contra_coeff(const ast &f){ + ast c = arg(f,0); + // if(!is_not(arg(f,1))) + // c = make(Uminus,c); + return c; + } + + ast my_or(const ast &a, const ast &b){ + return mk_or(a,b); + } + + ast my_and(const ast &a, const ast &b){ + return mk_and(a,b); + } + + ast my_implies(const ast &a, const ast &b){ + return mk_implies(a,b); + } + + ast my_or(const std::vector &a){ + return mk_or(a); + } + + ast my_and(const std::vector &a){ + return mk_and(a); + } + + ast get_lit_atom(const ast &l){ + if(op(l) == Not) + return arg(l,0); + return l; + } + + bool is_placeholder(const ast &e){ + if(op(e) == Uninterpreted){ + std::string name = string_of_symbol(sym(e)); + if(name.size() > 2 && name[0] == '@' && name[1] == 'p') + return true; + } + return false; + } + +public: + iz3proof_itp_impl(prover *p, const prover::range &r, bool w) + : iz3proof_itp(*p) + { + pv = p; + rng = r; + weak = false ; //w; + type boolintbooldom[3] = {bool_type(),int_type(),bool_type()}; + type booldom[1] = {bool_type()}; + type boolbooldom[2] = {bool_type(),bool_type()}; + type boolboolbooldom[3] = {bool_type(),bool_type(),bool_type()}; + type intbooldom[2] = {int_type(),bool_type()}; + contra = function("@contra",2,boolbooldom,bool_type()); + m().inc_ref(contra); + sum = function("@sum",3,boolintbooldom,bool_type()); + m().inc_ref(sum); + rotate_sum = function("@rotsum",2,boolbooldom,bool_type()); + m().inc_ref(rotate_sum); + leq2eq = function("@leq2eq",3,boolboolbooldom,bool_type()); + m().inc_ref(leq2eq); + eq2leq = function("@eq2leq",2,boolbooldom,bool_type()); + m().inc_ref(eq2leq); + cong = function("@cong",3,boolintbooldom,bool_type()); + m().inc_ref(cong); + exmid = function("@exmid",3,boolboolbooldom,bool_type()); + m().inc_ref(exmid); + symm = function("@symm",1,booldom,bool_type()); + m().inc_ref(symm); + epsilon = make_var("@eps",int_type()); + modpon = function("@mp",3,boolboolbooldom,bool_type()); + m().inc_ref(modpon); + no_proof = make_var("@nop",bool_type()); + concat = function("@concat",2,boolbooldom,bool_type()); + m().inc_ref(concat); + top_pos = make_var("@top_pos",bool_type()); + add_pos = function("@add_pos",2,intbooldom,bool_type()); + m().inc_ref(add_pos); + rewrite_A = function("@rewrite_A",3,boolboolbooldom,bool_type()); + m().inc_ref(rewrite_A); + rewrite_B = function("@rewrite_B",3,boolboolbooldom,bool_type()); + m().inc_ref(rewrite_B); + normal_step = function("@normal_step",2,boolbooldom,bool_type()); + m().inc_ref(normal_step); + normal_chain = function("@normal_chain",2,boolbooldom,bool_type()); + m().inc_ref(normal_chain); + normal = function("@normal",2,boolbooldom,bool_type()); + m().inc_ref(normal); + sforall = function("@sforall",2,boolbooldom,bool_type()); + m().inc_ref(sforall); + sexists = function("@sexists",2,boolbooldom,bool_type()); + m().inc_ref(sexists); + } + + ~iz3proof_itp_impl(){ + m().dec_ref(contra); + m().dec_ref(sum); + m().dec_ref(rotate_sum); + m().dec_ref(leq2eq); + m().dec_ref(eq2leq); + m().dec_ref(cong); + m().dec_ref(exmid); + m().dec_ref(symm); + m().dec_ref(modpon); + m().dec_ref(concat); + m().dec_ref(add_pos); + m().dec_ref(rewrite_A); + m().dec_ref(rewrite_B); + } +}; + +iz3proof_itp *iz3proof_itp::create(prover *p, const prover::range &r, bool w){ + return new iz3proof_itp_impl(p,r,w); +} + diff --git a/src/interp/iz3proof_itp.h b/src/interp/iz3proof_itp.h new file mode 100644 index 000000000..4d76e3a92 --- /dev/null +++ b/src/interp/iz3proof_itp.h @@ -0,0 +1,141 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3proof.h + +Abstract: + + This class defines a simple interpolating proof system. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#ifndef IZ3PROOF_ITP_H +#define IZ3PROOF_ITP_H + +#include + +#include "iz3base.h" +#include "iz3secondary.h" + +// #define CHECK_PROOFS + +/** This class defines a simple proof system. + + As opposed to iz3proof, this class directly computes interpolants, + so the proof representation is just the interpolant itself. + + */ + +class iz3proof_itp : public iz3mgr { + public: + + /** Enumeration of proof rules. */ + enum rule {Resolution,Assumption,Hypothesis,Theory,Axiom,Contra,Lemma,Reflexivity,Symmetry,Transitivity,Congruence,EqContra}; + + /** Interface to prover. */ + typedef iz3base prover; + + /** Ast type. */ + typedef prover::ast ast; + + /** The type of proof nodes (just interpolants). */ + typedef ast node; + + /** Object thrown in case of a proof error. */ + struct proof_error {}; + + + /** Make a resolution node with given pivot literal and premises. + The conclusion of premise1 should contain the negation of the + pivot literal, while the conclusion of premise2 should containe the + pivot literal. + */ + virtual node make_resolution(ast pivot, const std::vector &conc, node premise1, node premise2) = 0; + + /** Make an assumption node. The given clause is assumed in the given frame. */ + virtual node make_assumption(int frame, const std::vector &assumption) = 0; + + /** Make a hypothesis node. If phi is the hypothesis, this is + effectively phi |- phi. */ + virtual node make_hypothesis(const ast &hypothesis) = 0; + + /** Make an axiom node. The conclusion must be an instance of an axiom. */ + virtual node make_axiom(const std::vector &conclusion) = 0; + + /** Make an axiom node. The conclusion must be an instance of an axiom. Localize axiom instance to range*/ + virtual node make_axiom(const std::vector &conclusion, prover::range) = 0; + + /** Make a Contra node. This rule takes a derivation of the form + Gamma |- False and produces |- \/~Gamma. */ + + virtual node make_contra(node prem, const std::vector &conclusion) = 0; + + /** Make a Reflexivity node. This rule produces |- x = x */ + + virtual node make_reflexivity(ast con) = 0; + + /** Make a Symmetry node. This takes a derivation of |- x = y and + produces | y = x */ + + virtual node make_symmetry(ast con, const ast &premcon, node prem) = 0; + + /** Make a transitivity node. This takes derivations of |- x = y + and |- y = z produces | x = z */ + + virtual node make_transitivity(const ast &x, const ast &y, const ast &z, node prem1, node prem2) = 0; + + /** Make a congruence node. This takes a derivation of |- x_i = y_i + and produces |- f(...x_i,...) = f(...,y_i,...) */ + + virtual node make_congruence(const ast &xi_eq_yi, const ast &con, const ast &prem1) = 0; + + /** Make a congruence node. This takes derivations of |- x_i1 = y_i1, |- x_i2 = y_i2,... + and produces |- f(...x_i1...x_i2...) = f(...y_i1...y_i2...) */ + + virtual node make_congruence(const std::vector &xi_eq_yi, const ast &con, const std::vector &prems) = 0; + + /** Make a modus-ponens node. This takes derivations of |- x + and |- x = y and produces |- y */ + + virtual node make_mp(const ast &x_eq_y, const ast &prem1, const ast &prem2) = 0; + + /** Make a farkas proof node. */ + + virtual node make_farkas(ast con, const std::vector &prems, const std::vector &prem_cons, const std::vector &coeffs) = 0; + + /* Make an axiom instance of the form |- x<=y, y<= x -> x =y */ + virtual node make_leq2eq(ast x, ast y, const ast &xleqy, const ast &yleqx) = 0; + + /* Make an axiom instance of the form |- x = y -> x <= y */ + virtual node make_eq2leq(ast x, ast y, const ast &xeqy) = 0; + + /* Make an inference of the form t <= c |- t/d <= floor(c/d) where t + is an affine term divisble by d and c is an integer constant */ + virtual node make_cut_rule(const ast &tleqc, const ast &d, const ast &con, const ast &prem) = 0; + + /* Return an interpolant from a proof of false */ + virtual ast interpolate(const node &pf) = 0; + + /** Create proof object to construct an interpolant. */ + static iz3proof_itp *create(prover *p, const prover::range &r, bool _weak); + + protected: + iz3proof_itp(iz3mgr &m) + : iz3mgr(m) + { + } + + public: + virtual ~iz3proof_itp(){ + } +}; + +#endif diff --git a/src/interp/iz3scopes.cpp b/src/interp/iz3scopes.cpp new file mode 100755 index 000000000..198ecffe3 --- /dev/null +++ b/src/interp/iz3scopes.cpp @@ -0,0 +1,321 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3scopes.cpp + +Abstract: + + Calculations with scopes, for both sequence and tree interpolation. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#include + +#include + +#include "iz3scopes.h" + + +/** computes the least common ancestor of two nodes in the tree, or SHRT_MAX if none */ +int scopes::tree_lca(int n1, int n2){ + if(!tree_mode()) + return std::max(n1,n2); + if(n1 == SHRT_MIN) return n2; + if(n2 == SHRT_MIN) return n1; + if(n1 == SHRT_MAX || n2 == SHRT_MAX) return SHRT_MAX; + while(n1 != n2){ + if(n1 == SHRT_MAX || n2 == SHRT_MAX) return SHRT_MAX; + assert(n1 >= 0 && n2 >= 0 && n1 < (int)parents.size() && n2 < (int)parents.size()); + if(n1 < n2) n1 = parents[n1]; + else n2 = parents[n2]; + } + return n1; +} + +/** computes the greatest common descendant two nodes in the tree, or SHRT_MIN if none */ +int scopes::tree_gcd(int n1, int n2){ + if(!tree_mode()) + return std::min(n1,n2); + int foo = tree_lca(n1,n2); + if(foo == n1) return n2; + if(foo == n2) return n1; + return SHRT_MIN; +} + +#ifndef FULL_TREE + +/** test whether a tree node is contained in a range */ +bool scopes::in_range(int n, const range &rng){ + return tree_lca(rng.lo,n) == n && tree_gcd(rng.hi,n) == n; +} + +/** test whether two ranges of tree nodes intersect */ +bool scopes::ranges_intersect(const range &rng1, const range &rng2){ + return tree_lca(rng1.lo,rng2.hi) == rng2.hi && tree_lca(rng1.hi,rng2.lo) == rng1.hi; +} + + +bool scopes::range_contained(const range &rng1, const range &rng2){ + return tree_lca(rng2.lo,rng1.lo) == rng1.lo + && tree_lca(rng1.hi,rng2.hi) == rng2.hi; +} + +scopes::range scopes::range_lub(const range &rng1, const range &rng2){ + range res; + res.lo = tree_gcd(rng1.lo,rng2.lo); + res.hi = tree_lca(rng1.hi,rng2.hi); + return res; +} + +scopes::range scopes::range_glb(const range &rng1, const range &rng2){ + range res; + res.lo = tree_lca(rng1.lo,rng2.lo); + res.hi = tree_gcd(rng1.hi,rng2.hi); + return res; +} + +#else + + + namespace std { + template <> + class hash { + public: + size_t operator()(const scopes::range_lo &p) const { + return p.lo + (size_t)p.next; + } + }; + } + + template <> inline + size_t stdext::hash_value(const scopes::range_lo& p) + { + std::hash h; + return h(p); + } + + namespace std { + template <> + class less { + public: + bool operator()(const scopes::range_lo &x, const scopes::range_lo &y) const { + return x.lo < y.lo || x.lo == y.lo && (size_t)x.next < (size_t)y.next; + } + }; + } + + + struct range_op { + scopes::range_lo *x, *y; + int hi; + range_op(scopes::range_lo *_x, scopes::range_lo *_y, int _hi){ + x = _x; y = _y; hi = _hi; + } + }; + + namespace std { + template <> + class hash { + public: + size_t operator()(const range_op &p) const { + return (size_t) p.x + (size_t)p.y + p.hi; + } + }; + } + + template <> inline + size_t stdext::hash_value(const range_op& p) + { + std::hash h; + return h(p); + } + + namespace std { + template <> + class less { + public: + bool operator()(const range_op &x, const range_op &y) const { + return (size_t)x.x < (size_t)y.x || x.x == y.x && + ((size_t)x.y < (size_t)y.y || x.y == y.y && x.hi < y.hi); + } + }; + } + + struct range_tables { + hash_map unique; + hash_map lub; + hash_map glb; + }; + + + scopes::range_lo *scopes::find_range_lo(int lo, range_lo *next){ + range_lo foo(lo,next); + std::pair baz(foo,(range_lo *)0); + std::pair::iterator,bool> bar = rt->unique.insert(baz); + if(bar.second) + bar.first->second = new range_lo(lo,next); + return bar.first->second; + //std::pair::iterator,bool> bar = rt->unique.insert(foo); + // const range_lo *baz = &*(bar.first); + // return (range_lo *)baz; // exit const hell + } + + scopes::range_lo *scopes::range_lub_lo(range_lo *rng1, range_lo *rng2){ + if(!rng1) return rng2; + if(!rng2) return rng1; + if(rng1->lo > rng2->lo) + std::swap(rng1,rng2); + std::pair foo(range_op(rng1,rng2,0),(range_lo *)0); + std::pair::iterator,bool> bar = rt->lub.insert(foo); + range_lo *&res = bar.first->second; + if(!bar.second) return res; + if(!(rng1->next && rng1->next->lo <= rng2->lo)){ + for(int lo = rng1->lo; lo <= rng2->lo; lo = parents[lo]) + if(lo == rng2->lo) + {rng2 = rng2->next; break;} + } + range_lo *baz = range_lub_lo(rng1->next,rng2); + res = find_range_lo(rng1->lo,baz); + return res; + } + + + scopes::range_lo *scopes::range_glb_lo(range_lo *rng1, range_lo *rng2, int hi){ + if(!rng1) return rng1; + if(!rng2) return rng2; + if(rng1->lo > rng2->lo) + std::swap(rng1,rng2); + std::pair cand(range_op(rng1,rng2,hi),(range_lo *)0); + std::pair::iterator,bool> bar = rt->glb.insert(cand); + range_lo *&res = bar.first->second; + if(!bar.second) return res; + range_lo *foo; + if(!(rng1->next && rng1->next->lo <= rng2->lo)){ + int lim = hi; + if(rng1->next) lim = std::min(lim,rng1->next->lo); + int a = rng1->lo, b = rng2->lo; + while(a != b && b <= lim){ + a = parents[a]; + if(a > b)std::swap(a,b); + } + if(a == b && b <= lim){ + foo = range_glb_lo(rng1->next,rng2->next,hi); + foo = find_range_lo(b,foo); + } + else + foo = range_glb_lo(rng2,rng1->next,hi); + } + else foo = range_glb_lo(rng1->next,rng2,hi); + res = foo; + return res; + } + + /** computes the lub (smallest containing subtree) of two ranges */ + scopes::range scopes::range_lub(const range &rng1, const range &rng2){ + int hi = tree_lca(rng1.hi,rng2.hi); + if(hi == SHRT_MAX) return range_full(); + range_lo *lo = range_lub_lo(rng1.lo,rng2.lo); + return range(hi,lo); + } + + /** computes the glb (intersection) of two ranges */ + scopes::range scopes::range_glb(const range &rng1, const range &rng2){ + if(rng1.hi == SHRT_MAX) return rng2; + if(rng2.hi == SHRT_MAX) return rng1; + int hi = tree_gcd(rng1.hi,rng2.hi); + range_lo *lo = hi == SHRT_MIN ? 0 : range_glb_lo(rng1.lo,rng2.lo,hi); + if(!lo) hi = SHRT_MIN; + return range(hi,lo); + } + + /** is this range empty? */ + bool scopes::range_is_empty(const range &rng){ + return rng.hi == SHRT_MIN; + } + + /** return an empty range */ + scopes::range scopes::range_empty(){ + return range(SHRT_MIN,0); + } + + /** return a full range */ + scopes::range scopes::range_full(){ + return range(SHRT_MAX,0); + } + + /** return the maximal element of a range */ + int scopes::range_max(const range &rng){ + return rng.hi; + } + + /** return a minimal (not necessarily unique) element of a range */ + int scopes::range_min(const range &rng){ + if(rng.hi == SHRT_MAX) return SHRT_MIN; + return rng.lo ? rng.lo->lo : SHRT_MAX; + } + + + /** return range consisting of downward closure of a point */ + scopes::range scopes::range_downward(int _hi){ + std::vector descendants(parents.size()); + for(int i = descendants.size() - 1; i >= 0 ; i--) + descendants[i] = i == _hi || parents[i] < parents.size() && descendants[parents[i]]; + for(unsigned i = 0; i < descendants.size() - 1; i++) + if(parents[i] < parents.size()) + descendants[parents[i]] = false; + range_lo *foo = 0; + for(int i = descendants.size() - 1; i >= 0; --i) + if(descendants[i]) foo = find_range_lo(i,foo); + return range(_hi,foo); + } + + /** add an element to a range */ + void scopes::range_add(int i, range &n){ + range foo = range(i, find_range_lo(i,0)); + n = range_lub(foo,n); + } + + /** Choose an element of rng1 that is near to rng2 */ + int scopes::range_near(const range &rng1, const range &rng2){ + + int frame; + int thing = tree_lca(rng1.hi,rng2.hi); + if(thing != rng1.hi) return rng1.hi; + range line = range(rng1.hi,find_range_lo(rng2.hi,(range_lo *)0)); + line = range_glb(line,rng1); + return range_min(line); + } + + + /** test whether a tree node is contained in a range */ + bool scopes::in_range(int n, const range &rng){ + range r = range_empty(); + range_add(n,r); + r = range_glb(rng,r); + return !range_is_empty(r); + } + + /** test whether two ranges of tree nodes intersect */ + bool scopes::ranges_intersect(const range &rng1, const range &rng2){ + range r = range_glb(rng1,rng2); + return !range_is_empty(r); + } + + + bool scopes::range_contained(const range &rng1, const range &rng2){ + range r = range_glb(rng1,rng2); + return r.hi == rng1.hi && r.lo == rng1.lo; + } + + +#endif + + diff --git a/src/interp/iz3scopes.h b/src/interp/iz3scopes.h new file mode 100755 index 000000000..54cf89aba --- /dev/null +++ b/src/interp/iz3scopes.h @@ -0,0 +1,197 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3scopes.h + +Abstract: + + Calculations with scopes, for both sequence and tree interpolation. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + + +#ifndef IZ3SOPES_H +#define IZ3SOPES_H + +#include +#include + +class scopes { + + public: + /** Construct from parents vector. */ + scopes(const std::vector &_parents){ + parents = _parents; + } + + scopes(){ + } + + void initialize(const std::vector &_parents){ + parents = _parents; + } + + /** The parents vector defining the tree structure */ + std::vector parents; + + // #define FULL_TREE +#ifndef FULL_TREE + struct range { + range(){ + lo = SHRT_MAX; + hi = SHRT_MIN; + } + short lo, hi; + }; + + /** computes the lub (smallest containing subtree) of two ranges */ + range range_lub(const range &rng1, const range &rng2); + + /** computes the glb (intersection) of two ranges */ + range range_glb(const range &rng1, const range &rng2); + + /** is this range empty? */ + bool range_is_empty(const range &rng){ + return rng.hi < rng.lo; + } + + /** return an empty range */ + range range_empty(){ + range res; + res.lo = SHRT_MAX; + res.hi = SHRT_MIN; + return res; + } + + /** return an empty range */ + range range_full(){ + range res; + res.lo = SHRT_MIN; + res.hi = SHRT_MAX; + return res; + } + + /** return the maximal element of a range */ + int range_max(const range &rng){ + return rng.hi; + } + + /** return a minimal (not necessarily unique) element of a range */ + int range_min(const range &rng){ + return rng.lo; + } + + /** return range consisting of downward closure of a point */ + range range_downward(int _hi){ + range foo; + foo.lo = SHRT_MIN; + foo.hi = _hi; + return foo; + } + + void range_add(int i, range &n){ +#if 0 + if(i < n.lo) n.lo = i; + if(i > n.hi) n.hi = i; +#else + range rng; rng.lo = i; rng.hi = i; + n = range_lub(rng,n); +#endif + } + + /** Choose an element of rng1 that is near to rng2 */ + int range_near(const range &rng1, const range &rng2){ + int frame; + int thing = tree_lca(rng1.lo,rng2.hi); + if(thing == rng1.lo) frame = rng1.lo; + else frame = tree_gcd(thing,rng1.hi); + return frame; + } +#else + + struct range_lo { + int lo; + range_lo *next; + range_lo(int _lo, range_lo *_next){ + lo = _lo; + next = _next; + } + }; + + struct range { + int hi; + range_lo *lo; + range(int _hi, range_lo *_lo){ + hi = _hi; + lo = _lo; + } + range(){ + hi = SHRT_MIN; + lo = 0; + } + }; + + range_tables *rt; + + /** computes the lub (smallest containing subtree) of two ranges */ + range range_lub(const range &rng1, const range &rng2); + + /** computes the glb (intersection) of two ranges */ + range range_glb(const range &rng1, const range &rng2); + + /** is this range empty? */ + bool range_is_empty(const range &rng); + + /** return an empty range */ + range range_empty(); + + /** return a full range */ + range range_full(); + + /** return the maximal element of a range */ + int range_max(const range &rng); + + /** return a minimal (not necessarily unique) element of a range */ + int range_min(const range &rng); + + /** return range consisting of downward closure of a point */ + range range_downward(int _hi); + + /** add an element to a range */ + void range_add(int i, range &n); + + /** Choose an element of rng1 that is near to rng2 */ + int range_near(const range &rng1, const range &rng2); + + range_lo *find_range_lo(int lo, range_lo *next); + range_lo *range_lub_lo(range_lo *rng1, range_lo *rng2); + range_lo *range_glb_lo(range_lo *rng1, range_lo *rng2, int lim); + +#endif + + /** test whether a tree node is contained in a range */ + bool in_range(int n, const range &rng); + + /** test whether two ranges of tree nodes intersect */ + bool ranges_intersect(const range &rng1, const range &rng2); + + /** test whether range rng1 contained in range rng2 */ + bool range_contained(const range &rng1, const range &rng2); + + private: + int tree_lca(int n1, int n2); + int tree_gcd(int n1, int n2); + bool tree_mode(){return parents.size() != 0;} + + + +}; +#endif diff --git a/src/interp/iz3secondary.h b/src/interp/iz3secondary.h new file mode 100755 index 000000000..0eb1b9a93 --- /dev/null +++ b/src/interp/iz3secondary.h @@ -0,0 +1,40 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3secondary + +Abstract: + + Interface for secondary provers. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + + +#ifndef IZ3SECONDARY_H +#define IZ3SECONDARY_H + +/** Interface class for secondary provers. */ + +#include "iz3base.h" +#include + +class iz3secondary : public iz3mgr { + public: + virtual int interpolate(const std::vector &frames, std::vector &interpolants) = 0; + virtual ~iz3secondary(){} + + protected: + iz3secondary(const iz3mgr &mgr) : iz3mgr(mgr) {} +}; + + + +#endif diff --git a/src/interp/iz3translate.cpp b/src/interp/iz3translate.cpp new file mode 100755 index 000000000..fae914292 --- /dev/null +++ b/src/interp/iz3translate.cpp @@ -0,0 +1,1778 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3translate.cpp + +Abstract: + + Translate a Z3 proof to in interpolated proof. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + +#include "iz3translate.h" +#include "iz3proof.h" +#include "iz3profiling.h" +#include "iz3interp.h" +#include "iz3proof_itp.h" + +#include +#include +#include +#include +#include +#include +#include + +//using std::vector; +#ifndef WIN32 +using namespace stl_ext; +#endif + + + +/* This translator goes directly from Z3 proofs to interpolated + proofs without an intermediate representation. No secondary + prover is used. +*/ + +class iz3translation_full : public iz3translation { +public: + + + typedef iz3proof_itp Iproof; + + Iproof *iproof; + + /* Here we have lots of hash tables for memoizing various methods and + other such global data structures. + */ + + typedef hash_map AstToInt; + AstToInt locality; // memoizes locality of Z3 proof terms + + typedef std::pair EquivEntry; + typedef hash_map EquivTab; + EquivTab equivs; // maps non-local terms to equivalent local terms, with proof + + typedef hash_set AstHashSet; + AstHashSet equivs_visited; // proofs already checked for equivalences + + typedef std::pair, hash_map > AstToIpf; + AstToIpf translation; // Z3 proof nodes to Iproof nodes + + int frames; // number of frames + + typedef std::set AstSet; + typedef hash_map AstToAstSet; + AstToAstSet hyp_map; // map proof terms to hypothesis set + + struct LocVar { // localization vars + ast var; // a fresh variable + ast term; // term it represents + int frame; // frame in which it's defined + LocVar(ast v, ast t, int f){var=v;term=t;frame=f;} + }; + + std::vector localization_vars; // localization vars in order of creation + typedef hash_map AstToAst; + AstToAst localization_map; // maps terms to their localization vars + + typedef hash_map AstToBool; + AstToBool occurs_in_memo; // memo of occurs_in function + + AstHashSet cont_eq_memo; // memo of cont_eq function + + AstToAst subst_memo; // memo of subst function + + symb commute; + + public: + + +#define from_ast(x) (x) + + // #define NEW_LOCALITY + +#ifdef NEW_LOCALITY + range rng; // the range of frames in the "A" part of the interpolant +#endif + + /* To handle skolemization, we have to scan the proof for skolem + symbols and assign each to a frame. THe assignment is heuristic. + */ + + int scan_skolems_rec(hash_map &memo, const ast &proof, int frame){ + std::pair foo(proof,INT_MAX); + std::pair bar = memo.insert(foo); + int &res = bar.first->second; + if(!bar.second) return res; + pfrule dk = pr(proof); + if(dk == PR_ASSERTED){ + ast ass = conc(proof); + res = frame_of_assertion(ass); + } + else if(dk == PR_SKOLEMIZE){ + ast quanted = arg(conc(proof),0); + if(op(quanted) == Not) + quanted = arg(quanted,0); + // range r = ast_range(quanted); + // if(range_is_empty(r)) + range r = ast_scope(quanted); + if(range_is_empty(r)) + throw "can't skolemize"; + if(frame == INT_MAX || !in_range(frame,r)) + frame = range_max(r); // this is desperation -- may fail + if(frame >= frames) frame = frames - 1; + add_frame_range(frame,arg(conc(proof),1)); + r = ast_scope(arg(conc(proof),1)); + } + else if(dk==PR_MODUS_PONENS_OEQ){ + frame = scan_skolems_rec(memo,prem(proof,0),frame); + scan_skolems_rec(memo,prem(proof,1),frame); + } + else { + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + int bar = scan_skolems_rec(memo,prem(proof,i),frame); + if(res == INT_MAX || res == bar) res = bar; + else if(bar != INT_MAX) res = -1; + } + } + return res; + } + + void scan_skolems(const ast &proof){ + hash_map memo; + scan_skolems_rec(memo,proof, INT_MAX); + } + + // determine locality of a proof term + // return frame of derivation if local, or -1 if not + // result INT_MAX means the proof term is a tautology + // memoized in hash_map "locality" + + int get_locality_rec(ast proof){ + std::pair foo(proof,INT_MAX); + std::pair bar = locality.insert(foo); + int &res = bar.first->second; + if(!bar.second) return res; + pfrule dk = pr(proof); + if(dk == PR_ASSERTED){ + ast ass = conc(proof); + res = frame_of_assertion(ass); +#ifdef NEW_LOCALITY + if(in_range(res,rng)) + res = range_max(rng); + else + res = frames-1; +#endif + } + else if(dk == PR_QUANT_INST){ + std::vector lits; + ast con = conc(proof); + get_Z3_lits(con, lits); + iproof->make_axiom(lits); + } + else { + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + int bar = get_locality_rec(arg); + if(res == INT_MAX || res == bar) res = bar; + else if(bar != INT_MAX) res = -1; + } + } + return res; + } + + + int get_locality(ast proof){ + // if(lia_z3_axioms_only) return -1; + int res = get_locality_rec(proof); + if(res != -1){ + ast con = conc(proof); + range rng = ast_scope(con); + + // hack: if a clause contains "true", it reduces to "true", + // which means we won't compute the range correctly. we handle + // this case by computing the ranges of the literals separately + + if(is_true(con)){ + std::vector lits; + get_Z3_lits(conc(proof),lits); + for(unsigned i = 0; i < lits.size(); i++) + rng = range_glb(rng,ast_scope(lits[i])); + } + + if(!range_is_empty(rng)){ + AstSet &hyps = get_hyps(proof); + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it){ + ast hyp = *it; + rng = range_glb(rng,ast_scope(hyp)); + } + } + + if(res == INT_MAX){ + if(range_is_empty(rng)) + res = -1; + else res = range_max(rng); + } + else { + if(!in_range(res,rng)) + res = -1; + } + } + return res; + } + + + AstSet &get_hyps(ast proof){ + std::pair foo(proof,AstSet()); + std::pair bar = hyp_map.insert(foo); + AstSet &res = bar.first->second; + if(!bar.second) return res; + pfrule dk = pr(proof); + if(dk == PR_HYPOTHESIS){ + ast con = conc(proof); + res.insert(con); + } + else { + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + AstSet &arg_hyps = get_hyps(arg); + res.insert(arg_hyps.begin(),arg_hyps.end()); + } + if(dk == PR_LEMMA){ + ast con = conc(proof); + res.erase(mk_not(con)); + if(is_or(con)){ + int clause_size = num_args(con); + for(int i = 0; i < clause_size; i++){ + ast neglit = mk_not(arg(con,i)); + res.erase(neglit); + } + } + } + } +#if 0 + AstSet::iterator it = res.begin(), en = res.end(); + if(it != en){ + AstSet::iterator old = it; + ++it; + for(; it != en; ++it, ++old) + if(!(*old < *it)) + std::cout << "foo!"; + } +#endif + return res; + } + + // Find all the judgements of the form p <-> q, where + // p is local and q is non-local, recording them in "equivs" + // the map equivs_visited is used to record the already visited proof terms + + void find_equivs(ast proof){ + if(equivs_visited.find(proof) != equivs_visited.end()) + return; + equivs_visited.insert(proof); + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++) // do all the sub_terms + find_equivs(prem(proof,i)); + ast con = conc(proof); // get the conclusion + if(is_iff(con)){ + ast iff = con; + for(int i = 0; i < 2; i++) + if(!is_local(arg(iff,i)) && is_local(arg(iff,1-i))){ + std::pair > foo(arg(iff,i),std::pair(arg(iff,1-i),proof)); + equivs.insert(foo); + } + } + } + + // get the lits of a Z3 clause + void get_Z3_lits(ast t, std::vector &lits){ + opr dk = op(t); + if(dk == False) + return; // false = empty clause + if(dk == Or){ + unsigned nargs = num_args(t); + lits.resize(nargs); + for(unsigned i = 0; i < nargs; i++) // do all the sub_terms + lits[i] = arg(t,i); + } + else { + lits.push_back(t); + } + } + + // resolve two clauses represented as vectors of lits. replace first clause + void resolve(ast pivot, std::vector &cls1, std::vector &cls2){ + ast neg_pivot = mk_not(pivot); + for(unsigned i = 0; i < cls1.size(); i++){ + if(cls1[i] == pivot){ + cls1[i] = cls1.back(); + cls1.pop_back(); + bool found_pivot2 = false; + for(unsigned j = 0; j < cls2.size(); j++){ + if(cls2[j] == neg_pivot) + found_pivot2 = true; + else + cls1.push_back(cls2[j]); + } + assert(found_pivot2); + return; + } + } + assert(0 && "resolve failed"); + } + + // get lits resulting from unit resolution up to and including "position" + // TODO: this is quadratic -- fix it + void do_unit_resolution(ast proof, int position, std::vector &lits){ + ast orig_clause = conc(prem(proof,0)); + get_Z3_lits(orig_clause,lits); + for(int i = 1; i <= position; i++){ + std::vector unit(1); + unit[0] = conc(prem(proof,i)); + resolve(mk_not(unit[0]),lits,unit); + } + } + +#if 0 + // clear the localization variables + void clear_localization(){ + localization_vars.clear(); + localization_map.clear(); + } + + // create a fresh variable for localization + ast fresh_localization_var(ast term, int frame){ + std::ostringstream s; + s << "%" << (localization_vars.size()); + ast var = make_var(s.str().c_str(),get_type(term)); + sym_range(sym(var)) = range_full(); // make this variable global + localization_vars.push_back(LocVar(var,term,frame)); + return var; + } + + + // "localize" a term to a given frame range by + // creating new symbols to represent non-local subterms + + ast localize_term(ast e, const range &rng){ + if(ranges_intersect(ast_scope(e),rng)) + return e; // this term occurs in range, so it's O.K. + AstToAst::iterator it = localization_map.find(e); + if(it != localization_map.end()) + return it->second; + + // if is is non-local, we must first localize the arguments to + // the range of its function symbol + + int nargs = num_args(e); + if(nargs > 0 /* && (!is_local(e) || flo <= hi || fhi >= lo) */){ + range frng = rng; + if(op(e) == Uninterpreted){ + symb f = sym(e); + range srng = sym_range(f); + if(ranges_intersect(srng,rng)) // localize to desired range if possible + frng = range_glb(srng,rng); + } + std::vector largs(nargs); + for(int i = 0; i < nargs; i++){ + largs[i] = localize_term(arg(e,i),frng); + frng = range_glb(frng,ast_scope(largs[i])); + } + e = clone(e,largs); + assert(is_local(e)); + } + + + if(ranges_intersect(ast_scope(e),rng)) + return e; // this term occurs in range, so it's O.K. + + // choose a frame for the constraint that is close to range + int frame = range_near(ast_scope(e),rng); + + ast new_var = fresh_localization_var(e,frame); + localization_map[e] = new_var; + ast cnst = make(Equal,new_var,e); + // antes.push_back(std::pair(cnst,frame)); + return new_var; + } + + // some patterm matching functions + + // match logical or with nargs arguments + // assumes AIG form + bool match_or(ast e, ast *args, int nargs){ + if(op(e) != Or) return false; + int n = num_args(e); + if(n != nargs) return false; + for(int i = 0; i < nargs; i++) + args[i] = arg(e,i); + return true; + } + + // match operator f with exactly nargs arguments + bool match_op(ast e, opr f, ast *args, int nargs){ + if(op(e) != f) return false; + int n = num_args(e); + if(n != nargs) return false; + for(int i = 0; i < nargs; i++) + args[i] = arg(e,i); + return true; + } + + // see if the given formula can be interpreted as + // an axiom instance (e.g., an array axiom instance). + // if so, add it to "antes" in an appropriate frame. + // this may require "localization" + + void get_axiom_instance(ast e){ + + // "store" axiom + // (or (= w q) (= (select (store a1 w y) q) (select a1 q))) + // std::cout << "ax: "; show(e); + ast lits[2],eq_ops_l[2],eq_ops_r[2],sel_ops[2], sto_ops[3], sel_ops2[2] ; + if(match_or(e,lits,2)) + if(match_op(lits[0],Equal,eq_ops_l,2)) + if(match_op(lits[1],Equal,eq_ops_r,2)) + for(int i = 0; i < 2; i++){ // try the second equality both ways + if(match_op(eq_ops_r[0],Select,sel_ops,2)) + if(match_op(sel_ops[0],Store,sto_ops,3)) + if(match_op(eq_ops_r[1],Select,sel_ops2,2)) + for(int j = 0; j < 2; j++){ // try the first equality both ways + if(eq_ops_l[0] == sto_ops[1] + && eq_ops_l[1] == sel_ops[1] + && eq_ops_l[1] == sel_ops2[1] + && sto_ops[0] == sel_ops2[0]) + if(is_local(sel_ops[0])) // store term must be local + { + ast sto = sel_ops[0]; + ast addr = localize_term(eq_ops_l[1],ast_scope(sto)); + ast res = make(Or, + make(Equal,eq_ops_l[0],addr), + make(Equal, + make(Select,sto,addr), + make(Select,sel_ops2[0],addr))); + // int frame = range_min(ast_scope(res)); TODO + // antes.push_back(std::pair(res,frame)); + return; + } + std::swap(eq_ops_l[0],eq_ops_l[1]); + } + std::swap(eq_ops_r[0],eq_ops_r[1]); + } + } + + // a quantifier instantation looks like (~ forall x. P) \/ P[z/x] + // we need to find a time frame for P, then localize P[z/x] in this frame + + void get_quantifier_instance(ast e){ + ast disjs[2]; + if(match_or(e,disjs,2)){ + if(is_local(disjs[0])){ + ast res = localize_term(disjs[1], ast_scope(disjs[0])); + // int frame = range_min(ast_scope(res)); TODO + // antes.push_back(std::pair(res,frame)); + return; + } + } + } + + ast get_judgement(ast proof){ + ast con = from_ast(conc(proof)); + AstSet &hyps = get_hyps(proof); + std::vector hyps_vec; + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it) + hyps_vec.push_back(*it); + if(hyps_vec.size() == 0) return con; + con = make(Or,mk_not(make(And,hyps_vec)),con); + return con; + } + + // does variable occur in expression? + int occurs_in1(ast var, ast e){ + std::pair foo(e,false); + std::pair bar = occurs_in_memo.insert(foo); + bool &res = bar.first->second; + if(bar.second){ + if(e == var) res = true; + int nargs = num_args(e); + for(int i = 0; i < nargs; i++) + res |= occurs_in1(var,arg(e,i)); + } + return res; + } + + int occurs_in(ast var, ast e){ + occurs_in_memo.clear(); + return occurs_in1(var,e); + } + + // find a controlling equality for a given variable v in a term + // a controlling equality is of the form v = t, which, being + // false would force the formula to have the specifid truth value + // returns t, or null if no such + + ast cont_eq(bool truth, ast v, ast e){ + if(is_not(e)) return cont_eq(!truth,v,arg(e,0)); + if(cont_eq_memo.find(e) != cont_eq_memo.end()) + return ast(); + cont_eq_memo.insert(e); + if(!truth && op(e) == Equal){ + if(arg(e,0) == v) return(arg(e,1)); + if(arg(e,1) == v) return(arg(e,0)); + } + if((!truth && op(e) == And) || (truth && op(e) == Or)){ + int nargs = num_args(e); + for(int i = 0; i < nargs; i++){ + ast res = cont_eq(truth, v, arg(e,i)); + if(!res.null()) return res; + } + } + return ast(); + } + + // substitute a term t for unbound occurrences of variable v in e + + ast subst(ast var, ast t, ast e){ + if(e == var) return t; + std::pair foo(e,ast()); + std::pair bar = subst_memo.insert(foo); + ast &res = bar.first->second; + if(bar.second){ + int nargs = num_args(e); + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = subst(var,t,arg(e,i)); + opr f = op(e); + if(f == Equal && args[0] == args[1]) res = mk_true(); + else res = clone(e,args); + } + return res; + } + + // apply a quantifier to a formula, with some optimizations + // 1) bound variable does not occur -> no quantifier + // 2) bound variable must be equal to some term -> substitute + + ast apply_quant(opr quantifier, ast var, ast e){ + if(!occurs_in(var,e))return e; + cont_eq_memo.clear(); + ast cterm = cont_eq(quantifier == Forall, var, e); + if(!cterm.null()){ + subst_memo.clear(); + return subst(var,cterm,e); + } + std::vector bvs; bvs.push_back(var); + return make_quant(quantifier,bvs,e); + } + + // add quantifiers over the localization vars + // to an interpolant for frames lo-hi + + ast add_quants(ast e, int lo, int hi){ + for(int i = localization_vars.size() - 1; i >= 0; i--){ + LocVar &lv = localization_vars[i]; + opr quantifier = (lv.frame >= lo && lv.frame <= hi) ? Exists : Forall; + e = apply_quant(quantifier,lv.var,e); + } + return e; + } + + int get_lits_locality(std::vector &lits){ + range rng = range_full(); + for(std::vector::iterator it = lits.begin(), en = lits.end(); it != en; ++it){ + ast lit = *it; + rng = range_glb(rng,ast_scope(lit)); + } + if(range_is_empty(rng)) return -1; + int hi = range_max(rng); + if(hi >= frames) return frames - 1; + return hi; + } +#endif + + int num_lits(ast ast){ + opr dk = op(ast); + if(dk == False) + return 0; + if(dk == Or){ + unsigned nargs = num_args(ast); + int n = 0; + for(unsigned i = 0; i < nargs; i++) // do all the sub_terms + n += num_lits(arg(ast,i)); + return n; + } + else + return 1; + } + + void symbols_out_of_scope_rec(hash_set &memo, hash_set &symb_memo, int frame, const ast &t){ + if(memo.find(t) != memo.end()) + return; + memo.insert(t); + if(op(t) == Uninterpreted){ + symb s = sym(t); + range r = sym_range(s); + if(!in_range(frame,r) && symb_memo.find(s) == symb_memo.end()){ + std::cout << string_of_symbol(s) << "\n"; + symb_memo.insert(s); + } + } + int nargs = num_args(t); + for(int i = 0; i < nargs; i++) + symbols_out_of_scope_rec(memo,symb_memo,frame,arg(t,i)); + } + + void symbols_out_of_scope(int frame, const ast &t){ + hash_set memo; + hash_set symb_memo; + symbols_out_of_scope_rec(memo,symb_memo,frame,t); + } + + void conc_symbols_out_of_scope(int frame, const ast &t){ + symbols_out_of_scope(frame,conc(t)); + } + + std::vector lit_trace; + hash_set marked_proofs; + + bool proof_has_lit(const ast &proof, const ast &lit){ + AstSet &hyps = get_hyps(proof); + if(hyps.find(mk_not(lit)) != hyps.end()) + return true; + std::vector lits; + ast con = conc(proof); + get_Z3_lits(con, lits); + for(unsigned i = 0; i < lits.size(); i++) + if(lits[i] == lit) + return true; + return false; + } + + + void trace_lit_rec(const ast &lit, const ast &proof, AstHashSet &memo){ + if(memo.find(proof) == memo.end()){ + memo.insert(proof); + AstSet &hyps = get_hyps(proof); + std::vector lits; + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it) + lits.push_back(mk_not(*it)); + ast con = conc(proof); + get_Z3_lits(con, lits); + for(unsigned i = 0; i < lits.size(); i++){ + if(lits[i] == lit){ + print_expr(std::cout,proof); + std::cout << "\n"; + marked_proofs.insert(proof); + pfrule dk = pr(proof); + if(dk == PR_UNIT_RESOLUTION || dk == PR_LEMMA){ + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + trace_lit_rec(lit,arg,memo); + } + } + else + lit_trace.push_back(proof); + } + } + } + } + + ast traced_lit; + + int trace_lit(const ast &lit, const ast &proof){ + marked_proofs.clear(); + lit_trace.clear(); + traced_lit = lit; + AstHashSet memo; + trace_lit_rec(lit,proof,memo); + return lit_trace.size(); + } + + bool is_literal_or_lit_iff(const ast &lit){ + if(my_is_literal(lit)) return true; + if(op(lit) == Iff){ + return my_is_literal(arg(lit,0)) && my_is_literal(arg(lit,1)); + } + return false; + } + + bool my_is_literal(const ast &lit){ + ast abslit = is_not(lit) ? arg(lit,0) : lit; + int f = op(abslit); + return !(f == And || f == Or || f == Iff); + } + + hash_map asts_by_id; + + void print_lit(const ast &lit){ + ast abslit = is_not(lit) ? arg(lit,0) : lit; + if(!is_literal_or_lit_iff(lit)){ + if(is_not(lit)) std::cout << "~"; + int id = ast_id(abslit); + asts_by_id[id] = abslit; + std::cout << "[" << id << "]"; + } + else + print_expr(std::cout,lit); + } + + void expand(int id){ + if(asts_by_id.find(id) == asts_by_id.end()) + std::cout << "undefined\n"; + else { + ast lit = asts_by_id[id]; + std::string s = string_of_symbol(sym(lit)); + std::cout << "(" << s; + unsigned nargs = num_args(lit); + for(unsigned i = 0; i < nargs; i++){ + std::cout << " "; + print_lit(arg(lit,i)); + } + std::cout << ")\n";; + } + } + + void show_lit(const ast &lit){ + print_lit(lit); + std::cout << "\n"; + } + + void print_z3_lit(const ast &a){ + print_lit(from_ast(a)); + } + + void show_z3_lit(const ast &a){ + print_z3_lit(a); + std::cout << "\n"; + } + + + void show_con(const ast &proof, bool brief){ + if(!traced_lit.null() && proof_has_lit(proof,traced_lit)) + std::cout << "(*) "; + ast con = conc(proof); + AstSet &hyps = get_hyps(proof); + int count = 0; + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it){ + if(brief && ++count > 5){ + std::cout << "... "; + break; + } + print_lit(*it); + std::cout << " "; + } + std::cout << "|- "; + std::vector lits; + get_Z3_lits(con,lits); + for(unsigned i = 0; i < lits.size(); i++){ + print_lit(lits[i]); + std::cout << " "; + } + range r = ast_scope(con); + std::cout << " {" << r.lo << "," << r.hi << "}"; + std::cout << "\n"; + } + + void show_step(const ast &proof){ + std::cout << "\n"; + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + std::cout << "(" << i << ") "; + ast arg = prem(proof,i); + show_con(arg,true); + } + std::cout << "|------ "; + std::cout << string_of_symbol(sym(proof)) << "\n"; + show_con(proof,false); + } + + void show_marked( const ast &proof){ + std::cout << "\n"; + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + if(!traced_lit.null() && proof_has_lit(arg,traced_lit)){ + std::cout << "(" << i << ") "; + show_con(arg,true); + } + } + } + + std::vector pfhist; + int pfhist_pos; + + void pfgoto(const ast &proof){ + if(pfhist.size() == 0) + pfhist_pos = 0; + else pfhist_pos++; + pfhist.resize(pfhist_pos); + pfhist.push_back(proof); + show_step(proof); + } + + + void pfback(){ + if(pfhist_pos > 0){ + pfhist_pos--; + show_step(pfhist[pfhist_pos]); + } + } + + void pffwd(){ + if(pfhist_pos < ((int)pfhist.size()) - 1){ + pfhist_pos++; + show_step(pfhist[pfhist_pos]); + } + } + + void pfprem(int i){ + if(pfhist.size() > 0){ + ast proof = pfhist[pfhist_pos]; + unsigned nprems = num_prems(proof); + if(i >= 0 && i < (int)nprems) + pfgoto(prem(proof,i)); + } + } + + + + // translate a unit resolution sequence + Iproof::node translate_ur(ast proof){ + ast prem0 = prem(proof,0); + Iproof::node itp = translate_main(prem0,true); + std::vector clause; + ast conc0 = conc(prem0); + int nprems = num_prems(proof); + if(nprems == 2 && conc0 == mk_not(conc(prem(proof,1)))) + clause.push_back(conc0); + else + get_Z3_lits(conc0,clause); + for(int position = 1; position < nprems; position++){ + ast ante = prem(proof,position); + ast pnode = conc(ante); + ast pnode_abs = !is_not(pnode) ? pnode : mk_not(pnode); + Iproof::node neg = itp; + Iproof::node pos = translate_main(ante, false); + if(is_not(pnode)){ + pnode = mk_not(pnode); + std::swap(neg,pos); + } + std::vector unit(1); + unit[0] = conc(ante); + resolve(mk_not(conc(ante)),clause,unit); + itp = iproof->make_resolution(pnode,clause,neg,pos); + } + return itp; + } + + // get an inequality in the form 0 <= t where t is a linear term + ast rhs_normalize_inequality(const ast &ineq){ + ast zero = make_int("0"); + ast thing = make(Leq,zero,zero); + linear_comb(thing,make_int("1"),ineq); + thing = simplify_ineq(thing); + return thing; + } + + bool check_farkas(const std::vector &prems, const ast &con){ + ast zero = make_int("0"); + ast thing = make(Leq,zero,zero); + for(unsigned i = 0; i < prems.size(); i++) + linear_comb(thing,make_int(rational(1)),prems[i]); + linear_comb(thing,make_int(rational(-1)),con); + thing = simplify_ineq(thing); + return arg(thing,1) == make_int(rational(0)); + } + + // get an inequality in the form t <= c or t < c, there t is affine and c constant + ast normalize_inequality(const ast &ineq){ + ast zero = make_int("0"); + ast thing = make(Leq,zero,zero); + linear_comb(thing,make_int("1"),ineq); + thing = simplify_ineq(thing); + ast lhs = arg(thing,0); + ast rhs = arg(thing,1); + opr o = op(rhs); + if(o != Numeral){ + if(op(rhs) == Plus){ + int nargs = num_args(rhs); + ast const_term = zero; + int i = 0; + if(nargs > 0 && op(arg(rhs,0)) == Numeral){ + const_term = arg(rhs,0); + i++; + } + if(i < nargs){ + std::vector non_const; + for(; i < nargs; i++) + non_const.push_back(arg(rhs,i)); + lhs = make(Sub,lhs,make(Plus,non_const)); + } + rhs = const_term; + } + else { + lhs = make(Sub,lhs,make(Plus,rhs)); + rhs = zero; + } + lhs = z3_simplify(lhs); + rhs = z3_simplify(rhs); + thing = make(op(thing),lhs,rhs); + } + return thing; + } + + void get_linear_coefficients(const ast &t, std::vector &coeffs){ + if(op(t) == Plus){ + int nargs = num_args(t); + for(int i = 0; i < nargs; i++) + coeffs.push_back(get_coeff(arg(t,i))); + } + else + coeffs.push_back(get_coeff(t)); + } + + /* given an affine term t, get the GCD of the coefficients in t. */ + ast gcd_of_coefficients(const ast &t){ + std::vector coeffs; + get_linear_coefficients(t,coeffs); + if(coeffs.size() == 0) + return make_int("1"); // arbitrary + rational d = coeffs[0]; + for(unsigned i = 1; i < coeffs.size(); i++){ + d = gcd(d,coeffs[i]); + } + return make_int(d); + } + + ast get_bounded_variable(const ast &ineq, bool &lb){ + ast nineq = normalize_inequality(ineq); + ast lhs = arg(nineq,0); + switch(op(lhs)){ + case Uninterpreted: + lb = false; + return lhs; + case Times: + if(arg(lhs,0) == make_int(rational(1))) + lb = false; + else if(arg(lhs,0) == make_int(rational(-1))) + lb = true; + else + throw unsupported(); + return arg(lhs,1); + default: + throw unsupported(); + } + } + + rational get_term_coefficient(const ast &t1, const ast &v){ + ast t = arg(normalize_inequality(t1),0); + if(op(t) == Plus){ + int nargs = num_args(t); + for(int i = 0; i < nargs; i++){ + if(get_linear_var(arg(t,i)) == v) + return get_coeff(arg(t,i)); + } + } + else + if(get_linear_var(t) == v) + return get_coeff(t); + return rational(0); + } + + + Iproof::node GCDtoDivRule(const ast &proof, bool pol, std::vector &coeffs, std::vector &prems, ast &cut_con){ + // gather the summands of the desired polarity + std::vector my_prems; + std::vector my_coeffs; + std::vector my_prem_cons; + for(unsigned i = pol ? 0 : 1; i < coeffs.size(); i+= 2){ + rational &c = coeffs[i]; + if(c.is_pos()){ + my_prems.push_back(prems[i]); + my_coeffs.push_back(make_int(c)); + my_prem_cons.push_back(conc(prem(proof,i))); + } + } + ast my_con = sum_inequalities(my_coeffs,my_prem_cons); + + // handle generalized GCD test. sadly, we dont' get the coefficients... + if(coeffs[0].is_zero()){ + bool lb; + int xtra_prem = 0; + ast bv = get_bounded_variable(conc(prem(proof,0)),lb); + rational bv_coeff = get_term_coefficient(my_con,bv); + if(bv_coeff.is_pos() != lb) + xtra_prem = 1; + if(bv_coeff.is_neg()) + bv_coeff = -bv_coeff; + + my_prems.push_back(prems[xtra_prem]); + my_coeffs.push_back(make_int(bv_coeff)); + my_prem_cons.push_back(conc(prem(proof,xtra_prem))); + my_con = sum_inequalities(my_coeffs,my_prem_cons); + } + + my_con = normalize_inequality(my_con); + Iproof::node hyp = iproof->make_hypothesis(mk_not(my_con)); + my_prems.push_back(hyp); + my_coeffs.push_back(make_int("1")); + my_prem_cons.push_back(mk_not(my_con)); + Iproof::node res = iproof->make_farkas(mk_false(),my_prems,my_prem_cons,my_coeffs); + + ast t = arg(my_con,0); + ast c = arg(my_con,1); + ast d = gcd_of_coefficients(t); + t = z3_simplify(mk_idiv(t,d)); + c = z3_simplify(mk_idiv(c,d)); + cut_con = make(op(my_con),t,c); + return iproof->make_cut_rule(my_con,d,cut_con,res); + } + + + ast divide_inequalities(const ast &x, const ast&y){ + std::vector xcoeffs,ycoeffs; + get_linear_coefficients(arg(x,1),xcoeffs); + get_linear_coefficients(arg(y,1),ycoeffs); + if(xcoeffs.size() != ycoeffs.size() || xcoeffs.size() == 0) + throw "bad assign-bounds lemma"; + rational ratio = xcoeffs[0]/ycoeffs[0]; + return make_int(ratio); // better be integer! + } + + ast AssignBounds2Farkas(const ast &proof, const ast &con){ + std::vector farkas_coeffs; + get_assign_bounds_coeffs(proof,farkas_coeffs); + std::vector lits; + int nargs = num_args(con); + if(nargs != (int)(farkas_coeffs.size())) + throw "bad assign-bounds theory lemma"; +#if 0 + for(int i = 1; i < nargs; i++) + lits.push_back(mk_not(arg(con,i))); + ast sum = sum_inequalities(farkas_coeffs,lits); + ast conseq = rhs_normalize_inequality(arg(con,0)); + ast d = divide_inequalities(sum,conseq); + std::vector my_coeffs; + my_coeffs.push_back(d); + for(unsigned i = 0; i < farkas_coeffs.size(); i++) + my_coeffs.push_back(farkas_coeffs[i]); +#else + std::vector my_coeffs; +#endif + std::vector my_cons; + for(int i = 1; i < nargs; i++){ + my_cons.push_back(mk_not(arg(con,i))); + my_coeffs.push_back(farkas_coeffs[i]); + } + ast farkas_con = normalize_inequality(sum_inequalities(my_coeffs,my_cons)); + my_cons.push_back(mk_not(farkas_con)); + my_coeffs.push_back(make_int("1")); + std::vector my_hyps; + for(int i = 0; i < nargs; i++) + my_hyps.push_back(iproof->make_hypothesis(my_cons[i])); + ast res = iproof->make_farkas(mk_false(),my_hyps,my_cons,my_coeffs); + res = iproof->make_cut_rule(farkas_con,farkas_coeffs[0],arg(con,0),res); + return res; + } + + ast AssignBoundsRule2Farkas(const ast &proof, const ast &con, std::vector prems){ + std::vector farkas_coeffs; + get_assign_bounds_rule_coeffs(proof,farkas_coeffs); + std::vector lits; + int nargs = num_prems(proof)+1; + if(nargs != (int)(farkas_coeffs.size())) + throw "bad assign-bounds theory lemma"; + std::vector my_coeffs; + std::vector my_cons; + for(int i = 1; i < nargs; i++){ + my_cons.push_back(conc(prem(proof,i-1))); + my_coeffs.push_back(farkas_coeffs[i]); + } + ast farkas_con = normalize_inequality(sum_inequalities(my_coeffs,my_cons)); + std::vector my_hyps; + for(int i = 1; i < nargs; i++) + my_hyps.push_back(prems[i-1]); + my_cons.push_back(mk_not(farkas_con)); + my_coeffs.push_back(make_int("1")); + my_hyps.push_back(iproof->make_hypothesis(mk_not(farkas_con))); + ast res = iproof->make_farkas(mk_false(),my_hyps,my_cons,my_coeffs); + res = iproof->make_cut_rule(farkas_con,farkas_coeffs[0],conc(proof),res); + return res; + } + + Iproof::node RewriteClause(Iproof::node clause, const ast &rew){ + if(pr(rew) == PR_MONOTONICITY){ + int nequivs = num_prems(rew); + for(int i = 0; i < nequivs; i++){ + Iproof::node equiv_pf = translate_main(prem(rew,i),false); + ast equiv = conc(prem(rew,i)); + clause = iproof->make_mp(equiv,clause,equiv_pf); + } + return clause; + } + if(pr(rew) == PR_TRANSITIVITY){ + clause = RewriteClause(clause,prem(rew,0)); + clause = RewriteClause(clause,prem(rew,1)); + return clause; + } + if(pr(rew) == PR_REWRITE){ + return clause; // just hope the rewrite does nothing! + } + throw unsupported(); + } + + + // Following code is for elimination of "commutativity" axiom + + Iproof::node make_commuted_modus_ponens(const ast &proof, const std::vector &args){ + ast pf = arg(args[1],0); + ast comm_equiv = arg(args[1],1); // equivalence relation with possible commutations + ast P = conc(prem(proof,0)); + ast Q = conc(proof); + Iproof::node P_pf = args[0]; + ast P_comm = arg(comm_equiv,0); + ast Q_comm = arg(comm_equiv,1); + if(P != P_comm) + P_pf = iproof->make_symmetry(P_comm,P,P_pf); + Iproof::node res = iproof->make_mp(comm_equiv,P_pf,pf); + if(Q != Q_comm) + res = iproof->make_symmetry(Q,Q_comm,res); + return res; + } + + Iproof::node make_commuted_monotonicity(const ast &proof, const std::vector &args){ + ast pf = arg(args[0],0); + ast comm_equiv = arg(args[0],1); // equivalence relation with possible commutations + ast con = make(Iff,make(Not,arg(comm_equiv,0)),make(Not,arg(comm_equiv,1))); + std::vector eqs; eqs.push_back(comm_equiv); + std::vector pfs; pfs.push_back(pf); + ast res = iproof->make_congruence(eqs,con,pfs); + res = make(commute,res,con); + return res; + } + + Iproof::node make_commuted_symmetry(const ast &proof, const std::vector &args){ + ast pf = arg(args[0],0); + ast comm_equiv = arg(args[0],1); // equivalence relation with possible commutations + ast con = make(Iff,arg(comm_equiv,1),arg(comm_equiv,0)); + ast res = iproof->make_symmetry(con,comm_equiv,pf); + res = make(commute,res,con); + return res; + } + + void unpack_commuted(const ast &proof, const ast &cm, ast &pf, ast &comm_equiv){ + if(sym(cm) == commute){ + pf = arg(cm,0); + comm_equiv = arg(cm,1); + } + else { + pf = cm; + comm_equiv = conc(proof); + } + } + + Iproof::node make_commuted_transitivity(const ast &proof, const std::vector &args){ + ast pf[2], comm_equiv[2]; + for(int i = 0; i < 2; i++) + unpack_commuted(prem(proof,i),args[i],pf[i],comm_equiv[i]); + if(!(arg(comm_equiv[0],1) == arg(comm_equiv[1],0))){ + ast tw = twist(prem(proof,1)); + ast np = translate_main(tw,false); + unpack_commuted(tw,np,pf[1],comm_equiv[1]); + } + ast con = make(Iff,arg(comm_equiv[0],0),arg(comm_equiv[1],1)); + ast res = iproof->make_transitivity(arg(comm_equiv[0],0),arg(comm_equiv[0],1),arg(comm_equiv[1],1),pf[0],pf[1]); + res = make(commute,res,con); + return res; + } + + ast commute_equality(const ast &eq){ + return make(Equal,arg(eq,1),arg(eq,0)); + } + + ast commute_equality_iff(const ast &con){ + if(op(con) != Iff || op(arg(con,0)) != Equal) + throw unsupported(); + return make(Iff,commute_equality(arg(con,0)),commute_equality(arg(con,1))); + } + + // convert a proof of a=b <-> c=d into a proof of b=a <-> d=c + // TODO: memoize this? + ast twist(const ast &proof){ + pfrule dk = pr(proof); + ast con = commute_equality_iff(conc(proof)); + int n = num_prems(proof); + std::vector prs(n); + if(dk == PR_MONOTONICITY){ + for(int i = 0; i < n; i++) + prs[i] = prem(proof,i); + } + else + for(int i = 0; i < n; i++) + prs[i] = twist(prem(proof,i)); + switch(dk){ + case PR_MONOTONICITY: + case PR_SYMMETRY: + case PR_TRANSITIVITY: + case PR_COMMUTATIVITY: + prs.push_back(con); + return clone(proof,prs); + default: + throw unsupported(); + } + } + + struct TermLt { + iz3mgr &m; + bool operator()(const ast &x, const ast &y){ + unsigned xid = m.ast_id(x); + unsigned yid = m.ast_id(y); + return xid < yid; + } + TermLt(iz3mgr &_m) : m(_m) {} + }; + + void SortTerms(std::vector &terms){ + TermLt foo(*this); + std::sort(terms.begin(),terms.end(),foo); + } + + ast SortSum(const ast &t){ + if(!(op(t) == Plus)) + return t; + int nargs = num_args(t); + if(nargs < 2) return t; + std::vector args(nargs); + for(int i = 0; i < nargs; i++) + args[i] = arg(t,i); + SortTerms(args); + return make(Plus,args); + } + + ast really_normalize_ineq(const ast &ineq){ + ast res = normalize_inequality(ineq); + res = make(op(res),SortSum(arg(res,0)),arg(res,1)); + return res; + } + + Iproof::node reconstruct_farkas(const std::vector &prems, const std::vector &pfs, const ast &con){ + int nprems = prems.size(); + std::vector pcons(nprems),npcons(nprems); + hash_map pcon_to_pf, npcon_to_pcon; + for(int i = 0; i < nprems; i++){ + pcons[i] = conc(prems[i]); + npcons[i] = really_normalize_ineq(pcons[i]); + pcon_to_pf[npcons[i]] = pfs[i]; + npcon_to_pcon[npcons[i]] = pcons[i]; + } + // ast leq = make(Leq,arg(con,0),arg(con,1)); + ast ncon = really_normalize_ineq(mk_not(con)); + pcons.push_back(mk_not(con)); + npcons.push_back(ncon); + // ast assumps = make(And,pcons); + ast new_proof; + if(is_sat(npcons,new_proof)) + throw "Proof error!"; + pfrule dk = pr(new_proof); + int nnp = num_prems(new_proof); + std::vector my_prems; + std::vector farkas_coeffs, my_pcons; + + if(dk == PR_TH_LEMMA + && get_theory_lemma_theory(new_proof) == ArithTheory + && get_theory_lemma_kind(new_proof) == FarkasKind) + get_farkas_coeffs(new_proof,farkas_coeffs); + else if(dk == PR_UNIT_RESOLUTION && nnp == 2){ + for(int i = 0; i < nprems; i++) + farkas_coeffs.push_back(make_int(rational(1))); + } + else + throw "cannot reconstruct farkas proof"; + + for(int i = 0; i < nnp; i++){ + ast p = conc(prem(new_proof,i)); + p = really_normalize_ineq(p); + if(pcon_to_pf.find(p) != pcon_to_pf.end()){ + my_prems.push_back(pcon_to_pf[p]); + my_pcons.push_back(npcon_to_pcon[p]); + } + else if(p == ncon){ + my_prems.push_back(iproof->make_hypothesis(mk_not(con))); + my_pcons.push_back(mk_not(con)); + } + else + throw "cannot reconstruct farkas proof"; + } + Iproof::node res = iproof->make_farkas(mk_false(),my_prems,my_pcons,farkas_coeffs); + return res; + } + + bool is_eq_propagate(const ast &proof){ + return pr(proof) == PR_TH_LEMMA && get_theory_lemma_theory(proof) == ArithTheory && get_theory_lemma_kind(proof) == EqPropagateKind; + } + + ast EqPropagate(const ast &con, const std::vector &prems, const std::vector &args){ + Iproof::node fps[2]; + ast ineq_con[2]; + for(int i = 0; i < 2; i++){ + opr o = i == 0 ? Leq : Geq; + ineq_con[i] = make(o, arg(con,0), arg(con,1)); + fps[i] = reconstruct_farkas(prems,args,ineq_con[i]); + } + ast res = iproof->make_leq2eq(arg(con,0), arg(con,1), ineq_con[0], ineq_con[1]); + std::vector dummy_clause; + for(int i = 0; i < 2; i++) + res = iproof->make_resolution(ineq_con[i],dummy_clause,res,fps[i]); + return res; + } + + struct CannotCombineEqPropagate {}; + + void CombineEqPropagateRec(const ast &proof, std::vector &prems, std::vector &args, ast &eqprem){ + if(pr(proof) == PR_TRANSITIVITY && is_eq_propagate(prem(proof,1))){ + CombineEqPropagateRec(prem(proof,0), prems, args, eqprem); + ast dummy; + CombineEqPropagateRec(prem(proof,1), prems, args, dummy); + return; + } + if(is_eq_propagate(proof)){ + int nprems = num_prems(proof); + for(int i = 0; i < nprems; i++){ + prems.push_back(prem(proof,i)); + ast ppf = translate_main(prem(proof,i),false); + args.push_back(ppf); + } + return; + } + eqprem = proof; + } + + ast CombineEqPropagate(const ast &proof){ + std::vector prems, args; + ast eq1; + CombineEqPropagateRec(proof, prems, args, eq1); + ast eq2con = conc(proof); + if(!eq1.null()) + eq2con = make(Equal,arg(conc(eq1),1),arg(conc(proof),1)); + ast eq2 = EqPropagate(eq2con,prems,args); + if(!eq1.null()){ + Iproof::node foo = translate_main(eq1,false); + eq2 = iproof->make_transitivity(arg(conc(eq1),0), arg(conc(eq1),1), arg(conc(proof),1), foo, eq2); + } + return eq2; + } + + bool get_store_array(const ast &t, ast &res){ + if(op(t) == Store){ + res = t; + return true; + } + int nargs = num_args(t); + for(int i = 0; i < nargs; i++) + if(get_store_array(arg(t,i),res)) + return true; + return false; + } + + // translate a Z3 proof term into interpolating proof system + + Iproof::node translate_main(ast proof, bool expect_clause = true){ + AstToIpf &tr = translation; + hash_map &trc = expect_clause ? tr.first : tr.second; + std::pair foo(proof,Iproof::node()); + std::pair::iterator, bool> bar = trc.insert(foo); + Iproof::node &res = bar.first->second; + if(!bar.second) return res; + + // Try the locality rule first + + int frame = get_locality(proof); + if(frame != -1){ + ast e = from_ast(conc(proof)); + if(frame >= frames) frame = frames - 1; + std::vector foo; + if(expect_clause) + get_Z3_lits(conc(proof),foo); + else + foo.push_back(e); + AstSet &hyps = get_hyps(proof); + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it) + foo.push_back(mk_not(*it)); + res = iproof->make_assumption(frame,foo); + return res; + } + + // If the proof is not local, break it down by proof rule + + pfrule dk = pr(proof); + unsigned nprems = num_prems(proof); + if(dk == PR_UNIT_RESOLUTION){ + res = translate_ur(proof); + } + else if(dk == PR_LEMMA){ + ast contra = prem(proof,0); // this is a proof of false from some hyps + res = translate_main(contra); + if(!expect_clause){ + std::vector foo; // the negations of the hyps form a clause + foo.push_back(from_ast(conc(proof))); + AstSet &hyps = get_hyps(proof); + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it) + foo.push_back(mk_not(*it)); + res = iproof->make_contra(res,foo); + } + } + else { + std::vector lits; + ast con = conc(proof); + if(expect_clause) + get_Z3_lits(con, lits); + else + lits.push_back(from_ast(con)); + + // pattern match some idioms + if(dk == PR_MODUS_PONENS && pr(prem(proof,0)) == PR_QUANT_INST && pr(prem(proof,1)) == PR_REWRITE ) { + res = iproof->make_axiom(lits); + return res; + } + if(dk == PR_MODUS_PONENS && expect_clause && op(con) == Or){ + Iproof::node clause = translate_main(prem(proof,0),true); + res = RewriteClause(clause,prem(proof,1)); + return res; + } + + if(dk == PR_MODUS_PONENS && expect_clause && op(con) == Or) + std::cout << "foo!\n"; + + if(1 && dk == PR_TRANSITIVITY && pr(prem(proof,1)) == PR_COMMUTATIVITY){ + Iproof::node clause = translate_main(prem(proof,0),true); + res = make(commute,clause,conc(prem(proof,0))); // HACK -- we depend on Iproof::node being same as ast. + return res; + } + + if(dk == PR_TRANSITIVITY && is_eq_propagate(prem(proof,1))){ + try { + res = CombineEqPropagate(proof); + return res; + } + catch(const CannotCombineEqPropagate &){ + } + } + + /* this is the symmetry rule for ~=, that is, takes x ~= y and yields y ~= x. + the proof idiom uses commutativity, monotonicity and mp, but we replace it here + with symmtrey and resolution, that is, we prove y = x |- x = y, then resolve + with the proof of ~(x=y) to get ~y=x. */ + if(dk == PR_MODUS_PONENS && pr(prem(proof,1)) == PR_MONOTONICITY && pr(prem(prem(proof,1),0)) == PR_COMMUTATIVITY && num_prems(prem(proof,1)) == 1){ + Iproof::node ante = translate_main(prem(proof,0),false); + ast eq0 = arg(conc(prem(prem(proof,1),0)),0); + ast eq1 = arg(conc(prem(prem(proof,1),0)),1); + Iproof::node eq1hy = iproof->make_hypothesis(eq1); + Iproof::node eq0pf = iproof->make_symmetry(eq0,eq1,eq1hy); + std::vector clause; // just a dummy + res = iproof->make_resolution(eq0,clause,ante,eq0pf); + return res; + } + + // translate all the premises + std::vector args(nprems); + for(unsigned i = 0; i < nprems; i++) + args[i] = translate_main(prem(proof,i),false); + + for(unsigned i = 0; i < nprems; i++) + if(sym(args[i]) == commute + && !(dk == PR_TRANSITIVITY || dk == PR_MODUS_PONENS || dk == PR_SYMMETRY || (dk == PR_MONOTONICITY && op(arg(con,0)) == Not))) + throw unsupported(); + + switch(dk){ + case PR_TRANSITIVITY: { + if(sym(args[0]) == commute || sym(args[1]) == commute) + res = make_commuted_transitivity(proof,args); + else { + // assume the premises are x = y, y = z + ast x = arg(conc(prem(proof,0)),0); + ast y = arg(conc(prem(proof,0)),1); + ast z = arg(conc(prem(proof,1)),1); + res = iproof->make_transitivity(x,y,z,args[0],args[1]); + } + break; + } + case PR_MONOTONICITY: { + std::vector eqs; eqs.resize(args.size()); + for(unsigned i = 0; i < args.size(); i++) + eqs[i] = conc(prem(proof,i)); + if(op(arg(con,0)) == Not && sym(args[0]) == commute) + res = make_commuted_monotonicity(proof,args); + else + res = iproof->make_congruence(eqs,con,args); + break; + } + case PR_REFLEXIVITY: { + res = iproof->make_reflexivity(con); + break; + } + case PR_SYMMETRY: { + if(sym(args[0]) == commute) + res = make_commuted_symmetry(proof,args); + else + res = iproof->make_symmetry(con,conc(prem(proof,0)),args[0]); + break; + } + case PR_MODUS_PONENS: { + if(sym(args[1]) == commute) + res = make_commuted_modus_ponens(proof,args); + else + res = iproof->make_mp(conc(prem(proof,1)),args[0],args[1]); + break; + } + case PR_TH_LEMMA: { + switch(get_theory_lemma_theory(proof)){ + case ArithTheory: + switch(get_theory_lemma_kind(proof)){ + case FarkasKind: { + std::vector farkas_coeffs, prem_cons; + get_farkas_coeffs(proof,farkas_coeffs); + if(nprems == 0) {// axiom, not rule + int nargs = num_args(con); + if(farkas_coeffs.size() != (unsigned)nargs){ + pfgoto(proof); + throw unsupported(); + } + for(int i = 0; i < nargs; i++){ + ast lit = mk_not(arg(con,i)); + prem_cons.push_back(lit); + args.push_back(iproof->make_hypothesis(lit)); + } + } + else { // rule version (proves false) + prem_cons.resize(nprems); + for(unsigned i = 0; i < nprems; i++) + prem_cons[i] = conc(prem(proof,i)); + } + res = iproof->make_farkas(con,args,prem_cons,farkas_coeffs); + break; + } + case Leq2EqKind: { + // conc should be (or x = y (not (leq x y)) (not(leq y z)) ) + ast xeqy = arg(conc(proof),0); + ast x = arg(xeqy,0); + ast y = arg(xeqy,1); + res = iproof->make_leq2eq(x,y,arg(arg(conc(proof),1),0),arg(arg(conc(proof),2),0)); + break; + } + case Eq2LeqKind: { + // conc should be (or (not (= x y)) (leq x y)) + ast xeqy = arg(arg(conc(proof),0),0); + ast xleqy = arg(conc(proof),1); + ast x = arg(xeqy,0); + ast y = arg(xeqy,1); + res = iproof->make_eq2leq(x,y,xleqy); + break; + } + case GCDTestKind: { + std::vector farkas_coeffs; + get_farkas_coeffs(proof,farkas_coeffs); + if(farkas_coeffs.size() != nprems){ + pfgoto(proof); + throw unsupported(); + } + std::vector my_prems; my_prems.resize(2); + std::vector my_prem_cons; my_prem_cons.resize(2); + std::vector my_farkas_coeffs; my_farkas_coeffs.resize(2); + my_prems[0] = GCDtoDivRule(proof, true, farkas_coeffs, args, my_prem_cons[0]); + my_prems[1] = GCDtoDivRule(proof, false, farkas_coeffs, args, my_prem_cons[1]); + ast con = mk_false(); + my_farkas_coeffs[0] = my_farkas_coeffs[1] = make_int("1"); + res = iproof->make_farkas(con,my_prems,my_prem_cons,my_farkas_coeffs); + break; + } + case AssignBoundsKind: { + if(args.size() > 0) + res = AssignBoundsRule2Farkas(proof, conc(proof), args); + else + res = AssignBounds2Farkas(proof,conc(proof)); + break; + } + case EqPropagateKind: { + std::vector prems(nprems); + for(unsigned i = 0; i < nprems; i++) + prems[i] = prem(proof,i); + res = EqPropagate(con,prems,args); + break; + } + default: + throw unsupported(); + } + break; + case ArrayTheory: {// nothing fancy for this + ast store_array; + if(!get_store_array(con,store_array)) + throw unsupported(); + res = iproof->make_axiom(lits,ast_scope(store_array)); + break; + } + default: + throw unsupported(); + } + break; + } + case PR_HYPOTHESIS: { + res = iproof->make_hypothesis(conc(proof)); + break; + } + case PR_QUANT_INST: { + res = iproof->make_axiom(lits); + break; + } + case PR_DEF_AXIOM: { // this should only happen for formulas resulting from quantifier instantiation + res = iproof->make_axiom(lits); + break; + } + default: + assert(0 && "translate_main: unsupported proof rule"); + throw unsupported(); + } + } + + return res; + } + + void clear_translation(){ + translation.first.clear(); + translation.second.clear(); + } + + // We actually compute the interpolant here and then produce a proof consisting of just a lemma + + iz3proof::node translate(ast proof, iz3proof &dst){ + std::vector itps; + scan_skolems(proof); + for(int i = 0; i < frames -1; i++){ +#ifdef NEW_LOCALITY + rng = range_downward(i); + locality.clear(); +#endif + iproof = iz3proof_itp::create(this,range_downward(i),weak_mode()); + Iproof::node ipf = translate_main(proof); + ast itp = iproof->interpolate(ipf); + itps.push_back(itp); + delete iproof; + clear_translation(); + } + // Very simple proof -- lemma of the empty clause with computed interpolation + iz3proof::node Ipf = dst.make_lemma(std::vector(),itps); // builds result in dst + return Ipf; + } + + iz3translation_full(iz3mgr &mgr, + iz3secondary *_secondary, + const std::vector > &cnsts, + const std::vector &parents, + const std::vector &theory) + : iz3translation(mgr, cnsts, parents, theory) + { + frames = cnsts.size(); + traced_lit = ast(); + type boolbooldom[2] = {bool_type(),bool_type()}; + commute = function("@commute",2,boolbooldom,bool_type()); + m().inc_ref(commute); + } + + ~iz3translation_full(){ + m().dec_ref(commute); + } +}; + + + + +#ifdef IZ3_TRANSLATE_FULL + +iz3translation *iz3translation::create(iz3mgr &mgr, + iz3secondary *secondary, + const std::vector > &cnsts, + const std::vector &parents, + const std::vector &theory){ + return new iz3translation_full(mgr,secondary,cnsts,parents,theory); +} + + +#if 1 + +// This is just to make sure certain methods are compiled, so we can call then from the debugger. + +void iz3translation_full_trace_lit(iz3translation_full *p, iz3mgr::ast lit, iz3mgr::ast proof){ + p->trace_lit(lit, proof); +} + +void iz3translation_full_show_step(iz3translation_full *p, iz3mgr::ast proof){ + p->show_step(proof); +} + +void iz3translation_full_show_marked(iz3translation_full *p, iz3mgr::ast proof){ + p->show_marked(proof); +} + +void iz3translation_full_show_lit(iz3translation_full *p, iz3mgr::ast lit){ + p->show_lit(lit); +} + +void iz3translation_full_show_z3_lit(iz3translation_full *p, iz3mgr::ast a){ + p->show_z3_lit(a); +} + +void iz3translation_full_pfgoto(iz3translation_full *p, iz3mgr::ast proof){ + p->pfgoto(proof); +} + + +void iz3translation_full_pfback(iz3translation_full *p ){ + p->pfback(); +} + +void iz3translation_full_pffwd(iz3translation_full *p ){ + p->pffwd(); +} + +void iz3translation_full_pfprem(iz3translation_full *p, int i){ + p->pfprem(i); +} + +void iz3translation_full_expand(iz3translation_full *p, int i){ + p->expand(i); +} + +void iz3translation_full_symbols_out_of_scope(iz3translation_full *p, int i, const iz3mgr::ast &t){ + p->symbols_out_of_scope(i,t); +} + +void iz3translation_full_conc_symbols_out_of_scope(iz3translation_full *p, int i, const iz3mgr::ast &t){ + p->conc_symbols_out_of_scope(i,t); +} + +struct stdio_fixer { + stdio_fixer(){ + std::cout.rdbuf()->pubsetbuf(0,0); + } + +} my_stdio_fixer; + +#endif + +#endif + + diff --git a/src/interp/iz3translate.h b/src/interp/iz3translate.h new file mode 100755 index 000000000..6dc01dfa7 --- /dev/null +++ b/src/interp/iz3translate.h @@ -0,0 +1,65 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3translate.h + +Abstract: + + Interface for proof translations from Z3 proofs to interpolatable + proofs. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + + +#ifndef IZ3TRANSLATION_H +#define IZ3TRANSLATION_H + +#include "iz3proof.h" +#include "iz3secondary.h" + +// This is a interface class for translation from Z3 proof terms to +// an interpolatable proof + +class iz3translation : public iz3base { +public: + virtual iz3proof::node translate(ast, iz3proof &) = 0; + virtual ast quantify(ast e, const range &rng){return e;} + virtual ~iz3translation(){} + + /** This is thrown when the proof cannot be translated. */ + struct unsupported { + }; + + static iz3translation *create(iz3mgr &mgr, + iz3secondary *secondary, + const std::vector > &frames, + const std::vector &parents, + const std::vector &theory); + + protected: + iz3translation(iz3mgr &mgr, + const std::vector > &_cnsts, + const std::vector &_parents, + const std::vector &_theory) + : iz3base(mgr,_cnsts,_parents,_theory) {} +}; + +//#define IZ3_TRANSLATE_DIRECT2 +#ifdef _FOCI2 +#define IZ3_TRANSLATE_DIRECT +#else +#define IZ3_TRANSLATE_FULL +#endif + +#endif + + + diff --git a/src/interp/iz3translate_direct.cpp b/src/interp/iz3translate_direct.cpp new file mode 100755 index 000000000..4b7d47e0f --- /dev/null +++ b/src/interp/iz3translate_direct.cpp @@ -0,0 +1,1767 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + iz3translate_direct.cpp + +Abstract: + + Translate a Z3 proof into the interpolating proof calculus. + Translation is direct, without transformations on the target proof + representaiton. + +Author: + + Ken McMillan (kenmcmil) + +Revision History: + +--*/ + + +#include "iz3translate.h" +#include "iz3proof.h" +#include "iz3profiling.h" +#include "iz3interp.h" + +#include +#include +#include +#include +#include +#include +#include + +//using std::vector; +#ifndef WIN32 +using namespace stl_ext; +#endif + +#ifndef WIN32 + +/* This can introduce an address dependency if the range type of hash_map has + a destructor. Since the code in this file is not used and only here for + historical comparisons, we allow this non-determinism. + */ +namespace stl_ext { + template + class hash { + public: + size_t operator()(const T *p) const { + return (size_t) p; + } + }; +} + +#endif + + +static int lemma_count = 0; +static int nll_lemma_count = 0; +#define SHOW_LEMMA_COUNT -1 + +// One half of a resolution. We need this to distinguish +// between resolving as a clause and as a unit clause. +// if pivot == conclusion(proof) it is unit. + +struct Z3_resolvent { + iz3base::ast proof; + bool is_unit; + iz3base::ast pivot; + Z3_resolvent(const iz3base::ast &_proof, bool _is_unit, const iz3base::ast &_pivot){ + proof = _proof; + is_unit = _is_unit; + pivot = _pivot; + } +}; + +namespace hash_space { + template <> + class hash { + public: + size_t operator()(const Z3_resolvent &p) const { + return (p.proof.hash() + p.pivot.hash()); + } + }; +} + +#ifdef WIN32 + +template <> inline +size_t stdext::hash_value(const Z3_resolvent& p) +{ + std::hash h; + return h(p); +} + + +namespace std { + template <> + class less { + public: + bool operator()(const Z3_resolvent &x, const Z3_resolvent &y) const { + size_t ixproof = (size_t) x.proof.raw(); + size_t iyproof = (size_t) y.proof.raw(); + if(ixproof < iyproof) return true; + if(ixproof > iyproof) return false; + return x.pivot < y.pivot; + } + }; +} + +#else + +bool operator==(const Z3_resolvent &x, const Z3_resolvent &y) { + return x.proof == y.proof && x.pivot == y.pivot; +} + + +#endif + +typedef std::vector ResolventAppSet; + +struct non_local_lits { + ResolventAppSet proofs; // the proof nodes being raised + non_local_lits(ResolventAppSet &_proofs){ + proofs.swap(_proofs); + } +}; + +namespace hash_space { + template <> + class hash { + public: + size_t operator()(const non_local_lits &p) const { + size_t h = 0; + for(ResolventAppSet::const_iterator it = p.proofs.begin(), en = p.proofs.end(); it != en; ++it) + h += (size_t)*it; + return h; + } + }; +} + +#ifdef WIN32 + +template <> inline +size_t stdext::hash_value(const non_local_lits& p) +{ + std::hash h; + return h(p); +} + +namespace std { + template <> + class less { + public: + bool operator()(const non_local_lits &x, const non_local_lits &y) const { + ResolventAppSet::const_iterator itx = x.proofs.begin(); + ResolventAppSet::const_iterator ity = y.proofs.begin(); + while(true){ + if(ity == y.proofs.end()) return false; + if(itx == x.proofs.end()) return true; + size_t xi = (size_t) *itx; + size_t yi = (size_t) *ity; + if(xi < yi) return true; + if(xi > yi) return false; + ++itx; ++ity; + } + } + }; +} + +#else + +bool operator==(const non_local_lits &x, const non_local_lits &y) { + ResolventAppSet::const_iterator itx = x.proofs.begin(); + ResolventAppSet::const_iterator ity = y.proofs.begin(); + while(true){ + if(ity == y.proofs.end()) return itx == x.proofs.end(); + if(itx == x.proofs.end()) return ity == y.proofs.end(); + if(*itx != *ity) return false; + ++itx; ++ity; + } +} + + +#endif + +/* This translator goes directly from Z3 proofs to interpolatable + proofs without an intermediate representation as an iz3proof. */ + +class iz3translation_direct : public iz3translation { +public: + + typedef ast Zproof; // type of non-interpolating proofs + typedef iz3proof Iproof; // type of interpolating proofs + + /* Here we have lots of hash tables for memoizing various methods and + other such global data structures. + */ + + typedef hash_map AstToInt; + AstToInt locality; // memoizes locality of Z3 proof terms + + typedef std::pair EquivEntry; + typedef hash_map EquivTab; + EquivTab equivs; // maps non-local terms to equivalent local terms, with proof + + typedef hash_set AstHashSet; + AstHashSet equivs_visited; // proofs already checked for equivalences + + + typedef std::pair, hash_map > AstToIpf; + AstToIpf translation; // Zproof nodes to Iproof nodes + + AstHashSet antes_added; // Z3 proof terms whose antecedents have been added to the list + std::vector > antes; // list of antecedent/frame pairs + std::vector local_antes; // list of local antecedents + + Iproof *iproof; // the interpolating proof we are constructing + + int frames; // number of frames + + typedef std::set AstSet; + typedef hash_map AstToAstSet; + AstToAstSet hyp_map; // map proof terms to hypothesis set + + struct LocVar { // localization vars + ast var; // a fresh variable + ast term; // term it represents + int frame; // frame in which it's defined + LocVar(ast v, ast t, int f){var=v;term=t;frame=f;} + }; + + std::vector localization_vars; // localization vars in order of creation + typedef hash_map AstToAst; + AstToAst localization_map; // maps terms to their localization vars + + typedef hash_map AstToBool; + + + + iz3secondary *secondary; // the secondary prover + + // Unique table for sets of non-local resolutions + hash_map non_local_lits_unique; + + // Unique table for resolvents + hash_map Z3_resolvent_unique; + + // Translation memo for case of non-local resolutions + hash_map non_local_translation; + + public: + + +#define from_ast(x) (x) + + // determine locality of a proof term + // return frame of derivation if local, or -1 if not + // result INT_MAX means the proof term is a tautology + // memoized in hash_map "locality" + + int get_locality_rec(ast proof){ + std::pair foo(proof,INT_MAX); + std::pair bar = locality.insert(foo); + int &res = bar.first->second; + if(!bar.second) return res; + if(pr(proof) == PR_ASSERTED){ + ast ass = conc(proof); + res = frame_of_assertion(ass); + } + else { + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + int bar = get_locality_rec(arg); + if(res == INT_MAX || res == bar) res = bar; + else if(bar != INT_MAX) res = -1; + } + } + return res; + } + + + int get_locality(ast proof){ + // if(lia_z3_axioms_only) return -1; + int res = get_locality_rec(proof); + if(res != -1){ + ast con = conc(proof); + range rng = ast_scope(con); + + // hack: if a clause contains "true", it reduces to "true", + // which means we won't compute the range correctly. we handle + // this case by computing the ranges of the literals separately + + if(is_true(con)){ + std::vector lits; + get_Z3_lits(conc(proof),lits); + for(unsigned i = 0; i < lits.size(); i++) + rng = range_glb(rng,ast_scope(lits[i])); + } + + if(!range_is_empty(rng)){ + AstSet &hyps = get_hyps(proof); + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it){ + ast hyp = *it; + rng = range_glb(rng,ast_scope(hyp)); + } + } + + if(res == INT_MAX){ + if(range_is_empty(rng)) + res = -1; + else res = range_max(rng); + } + else { + if(!in_range(res,rng)) + res = -1; + } + } + return res; + } + + AstSet &get_hyps(ast proof){ + std::pair foo(proof,AstSet()); + std::pair bar = hyp_map.insert(foo); + AstSet &res = bar.first->second; + if(!bar.second) return res; + pfrule dk = pr(proof); + if(dk == PR_HYPOTHESIS){ + ast con = conc(proof); + res.insert(con); + } + else { + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + AstSet &arg_hyps = get_hyps(arg); + res.insert(arg_hyps.begin(),arg_hyps.end()); + } + if(dk == PR_LEMMA){ + ast con = conc(proof); + res.erase(mk_not(con)); + if(is_or(con)){ + int clause_size = num_args(con); + for(int i = 0; i < clause_size; i++){ + ast neglit = mk_not(arg(con,i)); + res.erase(neglit); + } + } + } + } +#if 0 + AstSet::iterator it = res.begin(), en = res.end(); + if(it != en){ + AstSet::iterator old = it; + ++it; + for(; it != en; ++it, ++old) + if(!(*old < *it)) + std::cout << "foo!"; + } +#endif + return res; + } + + + // Find all the judgements of the form p <-> q, where + // p is local and q is non-local, recording them in "equivs" + // the map equivs_visited is used to record the already visited proof terms + + void find_equivs(ast proof){ + if(equivs_visited.find(proof) != equivs_visited.end()) + return; + equivs_visited.insert(proof); + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++) // do all the sub_terms + find_equivs(prem(proof,i)); + ast con = conc(proof); // get the conclusion + if(is_iff(con)){ + ast iff = con; + for(int i = 0; i < 2; i++) + if(!is_local(arg(iff,i)) && is_local(arg(iff,1-i))){ + std::pair > foo(arg(iff,i),std::pair(arg(iff,1-i),proof)); + equivs.insert(foo); + } + } + } + + // get the lits of a Z3 clause as foci terms + void get_Z3_lits(ast t, std::vector &lits){ + opr dk = op(t); + if(dk == False) + return; // false = empty clause + if(dk == Or){ + unsigned nargs = num_args(t); + lits.resize(nargs); + for(unsigned i = 0; i < nargs; i++) // do all the sub_terms + lits[i] = arg(t,i); + } + else { + lits.push_back(t); + } + } + + // resolve two clauses represented as vectors of lits. replace first clause + void resolve(ast pivot, std::vector &cls1, std::vector &cls2){ + ast neg_pivot = mk_not(pivot); + for(unsigned i = 0; i < cls1.size(); i++){ + if(cls1[i] == pivot){ + cls1[i] = cls1.back(); + cls1.pop_back(); + bool found_pivot2 = false; + for(unsigned j = 0; j < cls2.size(); j++){ + if(cls2[j] == neg_pivot) + found_pivot2 = true; + else + cls1.push_back(cls2[j]); + } + assert(found_pivot2); + return; + } + } + assert(0 && "resolve failed"); + } + + // get lits resulting from unit resolution up to and including "position" + // TODO: this is quadratic -- fix it + void do_unit_resolution(ast proof, int position, std::vector &lits){ + ast orig_clause = conc(prem(proof,0)); + get_Z3_lits(orig_clause,lits); + for(int i = 1; i <= position; i++){ + std::vector unit(1); + unit[0] = conc(prem(proof,i)); + resolve(mk_not(unit[0]),lits,unit); + } + } + + + // clear the localization variables + void clear_localization(){ + localization_vars.clear(); + localization_map.clear(); + } + + // create a fresh variable for localization + ast fresh_localization_var(ast term, int frame){ + std::ostringstream s; + s << "%" << (localization_vars.size()); + ast var = make_var(s.str().c_str(),get_type(term)); + sym_range(sym(var)) = range_full(); // make this variable global + localization_vars.push_back(LocVar(var,term,frame)); + return var; + } + + + // "localize" a term to a given frame range by + // creating new symbols to represent non-local subterms + + ast localize_term(ast e, const range &rng){ + if(ranges_intersect(ast_scope(e),rng)) + return e; // this term occurs in range, so it's O.K. + AstToAst::iterator it = localization_map.find(e); + if(it != localization_map.end()) + return it->second; + + // if is is non-local, we must first localize the arguments to + // the range of its function symbol + + int nargs = num_args(e); + if(nargs > 0 /* && (!is_local(e) || flo <= hi || fhi >= lo) */){ + range frng = rng; + if(op(e) == Uninterpreted){ + symb f = sym(e); + range srng = sym_range(f); + if(ranges_intersect(srng,rng)) // localize to desired range if possible + frng = range_glb(srng,rng); + } + std::vector largs(nargs); + for(int i = 0; i < nargs; i++){ + largs[i] = localize_term(arg(e,i),frng); + frng = range_glb(frng,ast_scope(largs[i])); + } + e = clone(e,largs); + assert(is_local(e)); + } + + + if(ranges_intersect(ast_scope(e),rng)) + return e; // this term occurs in range, so it's O.K. + + // choose a frame for the constraint that is close to range + int frame = range_near(ast_scope(e),rng); + + ast new_var = fresh_localization_var(e,frame); + localization_map[e] = new_var; + ast cnst = make(Equal,new_var,e); + antes.push_back(std::pair(cnst,frame)); + return new_var; + } + + // some patterm matching functions + + // match logical or with nargs arguments + // assumes AIG form + bool match_or(ast e, ast *args, int nargs){ + if(op(e) != Or) return false; + int n = num_args(e); + if(n != nargs) return false; + for(int i = 0; i < nargs; i++) + args[i] = arg(e,i); + return true; + } + + // match operator f with exactly nargs arguments + bool match_op(ast e, opr f, ast *args, int nargs){ + if(op(e) != f) return false; + int n = num_args(e); + if(n != nargs) return false; + for(int i = 0; i < nargs; i++) + args[i] = arg(e,i); + return true; + } + + // see if the given formula can be interpreted as + // an axiom instance (e.g., an array axiom instance). + // if so, add it to "antes" in an appropriate frame. + // this may require "localization" + + void get_axiom_instance(ast e){ + + // "store" axiom + // (or (= w q) (= (select (store a1 w y) q) (select a1 q))) + // std::cout << "ax: "; show(e); + ast lits[2],eq_ops_l[2],eq_ops_r[2],sel_ops[2], sto_ops[3], sel_ops2[2] ; + if(match_or(e,lits,2)) + if(match_op(lits[0],Equal,eq_ops_l,2)) + if(match_op(lits[1],Equal,eq_ops_r,2)) + for(int i = 0; i < 2; i++){ // try the second equality both ways + if(match_op(eq_ops_r[0],Select,sel_ops,2)) + if(match_op(sel_ops[0],Store,sto_ops,3)) + if(match_op(eq_ops_r[1],Select,sel_ops2,2)) + for(int j = 0; j < 2; j++){ // try the first equality both ways + if(eq_ops_l[0] == sto_ops[1] + && eq_ops_l[1] == sel_ops[1] + && eq_ops_l[1] == sel_ops2[1] + && sto_ops[0] == sel_ops2[0]) + if(is_local(sel_ops[0])) // store term must be local + { + ast sto = sel_ops[0]; + ast addr = localize_term(eq_ops_l[1],ast_scope(sto)); + ast res = make(Or, + make(Equal,eq_ops_l[0],addr), + make(Equal, + make(Select,sto,addr), + make(Select,sel_ops2[0],addr))); + int frame = range_min(ast_scope(res)); + antes.push_back(std::pair(res,frame)); + return; + } + std::swap(eq_ops_l[0],eq_ops_l[1]); + } + std::swap(eq_ops_r[0],eq_ops_r[1]); + } + } + + // a quantifier instantation looks like (~ forall x. P) \/ P[z/x] + // we need to find a time frame for P, then localize P[z/x] in this frame + + void get_quantifier_instance(ast e){ + ast disjs[2]; + if(match_or(e,disjs,2)){ + if(is_local(disjs[0])){ + ast res = localize_term(disjs[1], ast_scope(disjs[0])); + int frame = range_min(ast_scope(res)); + antes.push_back(std::pair(res,frame)); + return; + } + } + } + + ast get_judgement(ast proof){ + ast con = from_ast(conc(proof)); + AstSet &hyps = get_hyps(proof); + std::vector hyps_vec; + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it) + hyps_vec.push_back(*it); + if(hyps_vec.size() == 0) return con; + con = make(Or,mk_not(make(And,hyps_vec)),con); + return con; + } + + // add the premises of a proof term to the "antes" list + + void add_antes(ast proof){ + if(antes_added.find(proof) != antes_added.end()) return; + antes_added.insert(proof); + int frame = get_locality(proof); + if(frame != -1) + if(1){ + ast e = get_judgement(proof); + if(frame >= frames) frame = frames-1; // can happen if there are no symbols + antes.push_back(std::pair(e,frame)); + return; + } + pfrule dk = pr(proof); + if(dk == PR_ASSERTED){ + ast ass = conc(proof); + frame = frame_of_assertion(ass); + if(frame >= frames) frame = frames-1; // can happen if a theory fact + antes.push_back(std::pair(ass,frame)); + return; + } + if(dk == PR_TH_LEMMA && num_prems(proof) == 0){ + get_axiom_instance(conc(proof)); + } + if(dk == PR_QUANT_INST && num_prems(proof) == 0){ + get_quantifier_instance(conc(proof)); + } + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + add_antes(arg); + } + } + + + // add quantifiers over the localization vars + // to an interpolant for frames lo-hi + + ast add_quants(ast e, int lo, int hi){ + for(int i = localization_vars.size() - 1; i >= 0; i--){ + LocVar &lv = localization_vars[i]; + opr quantifier = (lv.frame >= lo && lv.frame <= hi) ? Exists : Forall; + e = apply_quant(quantifier,lv.var,e); + } + return e; + } + + int get_lits_locality(std::vector &lits){ + range rng = range_full(); + for(std::vector::iterator it = lits.begin(), en = lits.end(); it != en; ++it){ + ast lit = *it; + rng = range_glb(rng,ast_scope(lit)); + } + if(range_is_empty(rng)) return -1; + int hi = range_max(rng); + if(hi >= frames) return frames - 1; + return hi; + } + + + struct invalid_lemma {}; + + + + + // prove a lemma (clause) using current antes list + // return proof of the lemma + // use the secondary prover + + int prove_lemma(std::vector &lits){ + + + // first try localization + if(antes.size() == 0){ + int local_frame = get_lits_locality(lits); + if(local_frame != -1) + return iproof->make_assumption(local_frame,lits); // no proof needed for purely local fact + } + + // group the assumptions by frame + std::vector preds(frames); + for(unsigned i = 0; i < preds.size(); i++) + preds[i] = mk_true(); + for(unsigned i = 0; i < antes.size(); i++){ + int frame = antes[i].second; + preds[frame] = mk_and(preds[frame],antes[i].first); // conjoin it to frame + } + + for(unsigned i = 0; i < lits.size(); i++){ + int frame; + if(!weak_mode()){ + frame = range_max(ast_scope(lits[i])); + if(frame >= frames) frame = frames-1; // could happen if contains no symbols + } + else { + frame = range_min(ast_scope(lits[i])); + if(frame < 0){ + frame = range_max(ast_scope(lits[i])); // could happen if contains no symbols + if(frame >= frames) frame = frames-1; + } + } + preds[frame] = mk_and(preds[frame],mk_not(lits[i])); + } + + + std::vector itps; // holds interpolants + + +#if 1 + ++lemma_count; + // std::cout << "lemma: " << lemma_count << std::endl; + if(lemma_count == SHOW_LEMMA_COUNT){ + for(unsigned i = 0; i < lits.size(); i++) + show_lit(lits[i]); + std::cerr << "lemma written to file lemma.smt:\n"; + iz3base foo(*this,preds,std::vector(),std::vector()); + foo.print("lemma.smt"); + throw invalid_lemma(); + } +#endif + +#if 0 + std::cout << "\nLemma:\n"; + for(unsigned i = 0; i < lits.size(); i++) + show_lit(lits[i]); +#endif + + // interpolate using secondary prover + profiling::timer_start("foci"); + int sat = secondary->interpolate(preds,itps); + profiling::timer_stop("foci"); + + std::cout << "lemma done" << std::endl; + + // if sat, lemma isn't valid, something is wrong + if(sat){ +#if 1 + std::cerr << "invalid lemma written to file invalid_lemma.smt:\n"; + iz3base foo(*this,preds,std::vector(),std::vector()); + foo.print("invalid_lemma.smt"); +#endif + throw iz3_incompleteness(); + } + assert(sat == 0); // if sat, lemma doesn't hold! + + // quantifiy the localization vars + for(unsigned i = 0; i < itps.size(); i++) + itps[i] = add_quants(itps[i],0,i); + + // Make a lemma, storing interpolants + Iproof::node res = iproof->make_lemma(lits,itps); + +#if 0 + std::cout << "Lemma interps\n"; + for(unsigned i = 0; i < itps.size(); i++) + show(itps[i]); +#endif + + // Reset state for the next lemma + antes.clear(); + antes_added.clear(); + clear_localization(); // use a fresh localization for each lemma + + return res; + } + + // sanity check: make sure that any non-local lit is really resolved + // with something in the non_local_lits set + + void check_non_local(ast lit, non_local_lits *nll){ + if(nll) + for(ResolventAppSet::iterator it = nll->proofs.begin(), en = nll->proofs.end(); it != en; ++it){ + ast con = (*it)->pivot; + if(con == mk_not(lit)) return; + } + assert(0 && "bug in non-local resolution handling"); + } + + + void get_local_conclusion_lits(ast proof, bool expect_clause, AstSet &lits){ + std::vector reslits; + if(expect_clause) + get_Z3_lits(conc(proof),reslits); + else reslits.push_back(conc(proof)); + for(unsigned i = 0; i < reslits.size(); i++) + if(is_local(reslits[i])) + lits.insert(reslits[i]); + AstSet &pfhyps = get_hyps(proof); + for(AstSet::iterator hit = pfhyps.begin(), hen = pfhyps.end(); hit != hen; ++hit) + if(is_local(*hit)) + lits.insert(mk_not(*hit)); + } + + + void collect_resolvent_lits(Z3_resolvent *res, const AstSet &hyps, std::vector &lits){ + if(!res->is_unit){ + std::vector reslits; + get_Z3_lits(conc(res->proof),reslits); + for(unsigned i = 0; i < reslits.size(); i++) + if(reslits[i] != res->pivot) + lits.push_back(reslits[i]); + } + AstSet &pfhyps = get_hyps(res->proof); + for(AstSet::iterator hit = pfhyps.begin(), hen = pfhyps.end(); hit != hen; ++hit) + if(hyps.find(*hit) == hyps.end()) + lits.push_back(mk_not(*hit)); + } + + void filter_resolvent_lits(non_local_lits *nll, std::vector &lits){ + std::vector orig_lits; orig_lits.swap(lits); + std::set pivs; + for(ResolventAppSet::iterator it = nll->proofs.begin(), en = nll->proofs.end(); it != en; ++it){ + Z3_resolvent *res = *it; + pivs.insert(res->pivot); + pivs.insert(mk_not(res->pivot)); + } + for(unsigned i = 0; i < orig_lits.size(); i++) + if(pivs.find(orig_lits[i]) == pivs.end()) + lits.push_back(orig_lits[i]); + } + + void collect_all_resolvent_lits(non_local_lits *nll, std::vector &lits){ + if(nll){ + std::vector orig_lits; orig_lits.swap(lits); + std::set pivs; + for(ResolventAppSet::iterator it = nll->proofs.begin(), en = nll->proofs.end(); it != en; ++it){ + Z3_resolvent *res = *it; + pivs.insert(res->pivot); + pivs.insert(mk_not(res->pivot)); + } + for(ResolventAppSet::iterator it = nll->proofs.begin(), en = nll->proofs.end(); it != en; ++it){ + Z3_resolvent *res = *it; + { + std::vector reslits; + if(!res->is_unit) get_Z3_lits(conc(res->proof),reslits); + else reslits.push_back(conc(res->proof)); + for(unsigned i = 0; i < reslits.size(); i++) +#if 0 + if(reslits[i] != res->pivot && pivs.find(reslits[i]) == pivs.end()) +#endif + if(is_local(reslits[i])) + lits.push_back(reslits[i]); + } + } + for(unsigned i = 0; i < orig_lits.size(); i++) + if(pivs.find(orig_lits[i]) == pivs.end()) + lits.push_back(orig_lits[i]); + } + } + + void collect_proof_clause(ast proof, bool expect_clause, std::vector &lits){ + if(expect_clause) + get_Z3_lits(conc(proof),lits); + else + lits.push_back(from_ast(conc(proof))); + AstSet &hyps = get_hyps(proof); + for(AstSet::iterator hit = hyps.begin(), hen = hyps.end(); hit != hen; ++hit) + lits.push_back(mk_not(*hit)); + } + + + // turn a bunch of literals into a lemma, replacing + // non-local lits with their local equivalents + // adds the accumulated antecedents (antes) as + // proof obligations of the lemma + + Iproof::node fix_lemma(std::vector &con_lits, AstSet &hyps, non_local_lits *nll){ + std::vector lits(con_lits); + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it) + lits.push_back(mk_not(*it)); + if(nll){ + for(ResolventAppSet::iterator it = nll->proofs.begin(), en = nll->proofs.end(); it != en; ++it){ + Z3_resolvent *res = *it; + collect_resolvent_lits(res,hyps,lits); + add_antes(res->proof); + } + filter_resolvent_lits(nll,lits); + } + for(unsigned int i = 0; i < lits.size(); i++){ + EquivTab::iterator it = equivs.find(lits[i]); + if(it != equivs.end()){ + lits[i] = it->second.first; // replace with local equivalent + add_antes(it->second.second); // collect the premises that prove this + } + else { + if(!is_local(lits[i])){ + check_non_local(lits[i],nll); + lits[i] = mk_false(); + } + } + } + // TODO: should check here that derivation is local? + Iproof::node res = prove_lemma(lits); + return res; + } + + int num_lits(ast ast){ + opr dk = op(ast); + if(dk == False) + return 0; + if(dk == Or){ + unsigned nargs = num_args(ast); + int n = 0; + for(unsigned i = 0; i < nargs; i++) // do all the sub_terms + n += num_lits(arg(ast,i)); + return n; + } + else + return 1; + } + + struct non_lit_local_ante {}; + + bool local_antes_simple; + + bool add_local_antes(ast proof, AstSet &hyps, bool expect_clause = false){ + if(antes_added.find(proof) != antes_added.end()) return true; + antes_added.insert(proof); + ast con = from_ast(conc(proof)); + pfrule dk = pr(proof); + if(is_local(con) || equivs.find(con) != equivs.end()){ + if(!expect_clause || num_lits(conc(proof)) == 1){ + AstSet &this_hyps = get_hyps(proof); + if(std::includes(hyps.begin(),hyps.end(),this_hyps.begin(),this_hyps.end())){ + // if(hyps.find(con) == hyps.end()) +#if 0 + if(/* lemma_count == SHOW_LEMMA_COUNT - 1 && */ !is_literal_or_lit_iff(conc(proof))){ + std::cout << "\nnon-lit local ante\n"; + show_step(proof); + show(conc(proof)); + throw non_lit_local_ante(); + } +#endif + local_antes.push_back(proof); + return true; + } + else + ; //std::cout << "bar!\n"; + } + } + if(dk == PR_ASSERTED + //|| dk == PR_HYPOTHESIS + //|| dk == PR_TH_LEMMA + || dk == PR_QUANT_INST + //|| dk == PR_UNIT_RESOLUTION + //|| dk == PR_LEMMA + ) + return false; + if(dk == PR_HYPOTHESIS && hyps.find(con) != hyps.end()) + ; //std::cout << "blif!\n"; + if(dk == PR_HYPOTHESIS + || dk == PR_LEMMA) + ; //std::cout << "foo!\n"; + if(dk == PR_TH_LEMMA && num_prems(proof) == 0){ + // Check if this is an axiom instance + get_axiom_instance(conc(proof)); + } + + // #define SIMPLE_PROOFS +#ifdef SIMPLE_PROOFS + if(!(dk == PR_TRANSITIVITY + || dk == PR_MONOTONICITY)) + local_antes_simple = false; +#endif + + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + try { + if(!add_local_antes(arg, hyps, dk == PR_UNIT_RESOLUTION && i == 0)) + return false; + } + catch (non_lit_local_ante) { + std::cout << "\n"; + show_step(proof); + show(conc(proof)); + throw non_lit_local_ante(); + } + } + return true; + } + + std::vector lit_trace; + hash_set marked_proofs; + + bool proof_has_lit(const ast &proof, const ast &lit){ + AstSet &hyps = get_hyps(proof); + if(hyps.find(mk_not(lit)) != hyps.end()) + return true; + std::vector lits; + ast con = conc(proof); + get_Z3_lits(con, lits); + for(unsigned i = 0; i < lits.size(); i++) + if(lits[i] == lit) + return true; + return false; + } + + + void trace_lit_rec(const ast &lit, const ast &proof, AstHashSet &memo){ + if(memo.find(proof) == memo.end()){ + memo.insert(proof); + AstSet &hyps = get_hyps(proof); + std::vector lits; + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it) + lits.push_back(mk_not(*it)); + ast con = conc(proof); + get_Z3_lits(con, lits); + for(unsigned i = 0; i < lits.size(); i++){ + if(lits[i] == lit){ + print_expr(std::cout,proof); + std::cout << "\n"; + marked_proofs.insert(proof); + pfrule dk = pr(proof); + if(dk == PR_UNIT_RESOLUTION || dk == PR_LEMMA){ + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + trace_lit_rec(lit,arg,memo); + } + } + else + lit_trace.push_back(proof); + } + } + } + } + + ast traced_lit; + + int trace_lit(const ast &lit, const ast &proof){ + marked_proofs.clear(); + lit_trace.clear(); + traced_lit = lit; + AstHashSet memo; + trace_lit_rec(lit,proof,memo); + return lit_trace.size(); + } + + bool is_literal_or_lit_iff(const ast &lit){ + if(my_is_literal(lit)) return true; + if(op(lit) == Iff){ + return my_is_literal(arg(lit,0)) && my_is_literal(arg(lit,1)); + } + return false; + } + + bool my_is_literal(const ast &lit){ + ast abslit = is_not(lit) ? arg(lit,0) : lit; + int f = op(abslit); + return !(f == And || f == Or || f == Iff); + } + + void print_lit(const ast &lit){ + ast abslit = is_not(lit) ? arg(lit,0) : lit; + if(!is_literal_or_lit_iff(lit)){ + if(is_not(lit)) std::cout << "~"; + std::cout << "["; + print_expr(std::cout,abslit); + std::cout << "]"; + } + else + print_expr(std::cout,lit); + } + + void show_lit(const ast &lit){ + print_lit(lit); + std::cout << "\n"; + } + + void print_z3_lit(const ast &a){ + print_lit(from_ast(a)); + } + + void show_z3_lit(const ast &a){ + print_z3_lit(a); + std::cout << "\n"; + } + + + void show_con(const ast &proof, bool brief){ + if(!traced_lit.null() && proof_has_lit(proof,traced_lit)) + std::cout << "(*) "; + ast con = conc(proof); + AstSet &hyps = get_hyps(proof); + int count = 0; + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it){ + if(brief && ++count > 5){ + std::cout << "... "; + break; + } + print_lit(*it); + std::cout << " "; + } + std::cout << "|- "; + std::vector lits; + get_Z3_lits(con,lits); + for(unsigned i = 0; i < lits.size(); i++){ + print_lit(lits[i]); + std::cout << " "; + } + std::cout << "\n"; + } + + void show_step(const ast &proof){ + std::cout << "\n"; + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + std::cout << "(" << i << ") "; + ast arg = prem(proof,i); + show_con(arg,true); + } + std::cout << "|------ "; + std::cout << string_of_symbol(sym(proof)) << "\n"; + show_con(proof,false); + } + + void show_marked( const ast &proof){ + std::cout << "\n"; + unsigned nprems = num_prems(proof); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + if(!traced_lit.null() && proof_has_lit(arg,traced_lit)){ + std::cout << "(" << i << ") "; + show_con(arg,true); + } + } + } + + std::vector pfhist; + int pfhist_pos; + + void pfgoto(const ast &proof){ + if(pfhist.size() == 0) + pfhist_pos = 0; + else pfhist_pos++; + pfhist.resize(pfhist_pos); + pfhist.push_back(proof); + show_step(proof); + } + + void show_nll(non_local_lits *nll){ + if(!nll)return; + for(ResolventAppSet::iterator it = nll->proofs.begin(), en = nll->proofs.end(); it != en; ++it){ + Z3_resolvent *res = *it; + show_step(res->proof); + std::cout << "Pivot: "; + show(res->pivot); + std::cout << std::endl; + } + } + + void pfback(){ + if(pfhist_pos > 0){ + pfhist_pos--; + show_step(pfhist[pfhist_pos]); + } + } + + void pffwd(){ + if(pfhist_pos < ((int)pfhist.size()) - 1){ + pfhist_pos++; + show_step(pfhist[pfhist_pos]); + } + } + + void pfprem(int i){ + if(pfhist.size() > 0){ + ast proof = pfhist[pfhist_pos]; + unsigned nprems = num_prems(proof); + if(i >= 0 && i < (int)nprems) + pfgoto(prem(proof,i)); + } + } + + int extract_th_lemma_common(std::vector &lits, non_local_lits *nll, bool lemma_nll = true){ + std::vector la = local_antes; + local_antes.clear(); // clear antecedents for next lemma + antes_added.clear(); + // std::vector lits; + AstSet hyps; // no hyps + for(unsigned i = 0; i < la.size(); i++) + lits.push_back(mk_not(from_ast(conc(la[i])))); + // lits.push_back(from_ast(conc(proof))); + Iproof::node res =fix_lemma(lits,hyps, lemma_nll ? nll : 0); + for(unsigned i = 0; i < la.size(); i++){ + Iproof::node q = translate_main(la[i],nll,false); + ast pnode = from_ast(conc(la[i])); + assert(is_local(pnode) || equivs.find(pnode) != equivs.end()); + Iproof::node neg = res; + Iproof::node pos = q; + if(is_not(pnode)){ + pnode = mk_not(pnode); + std::swap(neg,pos); + } + try { + res = iproof->make_resolution(pnode,neg,pos); + } + catch (const iz3proof::proof_error){ + std::cout << "\nresolution error in theory lemma\n"; + std::cout << "lits:\n"; + for(unsigned j = 0; j < lits.size(); j++) + show_lit(lits[j]); + std::cout << "\nstep:\n"; + show_step(la[i]); + throw invalid_lemma(); + } + } + return res; + } + + Iproof::node extract_simple_proof(const ast &proof, hash_set &leaves){ + if(leaves.find(proof) != leaves.end()) + return iproof->make_hypothesis(conc(proof)); + ast con = from_ast(conc(proof)); + pfrule dk = pr(proof); + unsigned nprems = num_prems(proof); + std::vector args(nprems); + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + args[i] = extract_simple_proof(arg,leaves); + } + + switch(dk){ + case PR_TRANSITIVITY: + return iproof->make_transitivity(con,args[0],args[1]); + case PR_MONOTONICITY: + return iproof->make_congruence(con,args); + case PR_REFLEXIVITY: + return iproof->make_reflexivity(con); + case PR_SYMMETRY: + return iproof->make_symmetry(con,args[0]); + } + assert(0 && "extract_simple_proof: unknown op"); + return 0; + } + + int extract_th_lemma_simple(const ast &proof, std::vector &lits){ + std::vector la = local_antes; + local_antes.clear(); // clear antecedents for next lemma + antes_added.clear(); + + hash_set leaves; + for(unsigned i = 0; i < la.size(); i++) + leaves.insert(la[i]); + + Iproof::node ipf = extract_simple_proof(proof,leaves); + ast con = from_ast(conc(proof)); + Iproof::node hyp = iproof->make_hypothesis(mk_not(con)); + ipf = iproof->make_eqcontra(ipf,hyp); + + // std::vector lits; + AstSet hyps; // no hyps + for(unsigned i = 0; i < la.size(); i++) + lits.push_back(mk_not(from_ast(conc(la[i])))); + // lits.push_back(from_ast(conc(proof))); + + Iproof::node res = iproof->make_contra(ipf,lits); + + for(unsigned i = 0; i < la.size(); i++){ + Iproof::node q = translate_main(la[i],0,false); + ast pnode = from_ast(conc(la[i])); + assert(is_local(pnode) || equivs.find(pnode) != equivs.end()); + Iproof::node neg = res; + Iproof::node pos = q; + if(is_not(pnode)){ + pnode = mk_not(pnode); + std::swap(neg,pos); + } + try { + res = iproof->make_resolution(pnode,neg,pos); + } + catch (const iz3proof::proof_error){ + std::cout << "\nresolution error in theory lemma\n"; + std::cout << "lits:\n"; + for(unsigned j = 0; j < lits.size(); j++) + show_lit(lits[j]); + std::cout << "\nstep:\n"; + show_step(la[i]); + throw invalid_lemma(); + } + } + return res; + } + + // #define NEW_EXTRACT_TH_LEMMA + + void get_local_hyps(const ast &proof, std::set &res){ + std::set hyps = get_hyps(proof); + for(std::set::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it){ + ast hyp = *it; + if(is_local(hyp)) + res.insert(hyp); + } + } + + int extract_th_lemma(ast proof, std::vector &lits, non_local_lits *nll){ + pfrule dk = pr(proof); + unsigned nprems = num_prems(proof); +#ifdef NEW_EXTRACT_TH_LEMMA + if(nprems == 0 && !nll) +#else + if(nprems == 0) +#endif + return 0; + if(nprems == 0 && dk == PR_TH_LEMMA) + // Check if this is an axiom instance + get_axiom_instance(conc(proof)); + + local_antes_simple = true; + for(unsigned i = 0; i < nprems; i++){ + ast arg = prem(proof,i); + if(!add_local_antes(arg,get_hyps(proof))){ + local_antes.clear(); // clear antecedents for next lemma + antes_added.clear(); + antes.clear(); + return 0; + } + } +#ifdef NEW_EXTRACT_TH_LEMMA + bool lemma_nll = nprems > 1; + if(nll && !lemma_nll){ + lemma_nll = false; + // std::cout << "lemma count = " << nll_lemma_count << "\n"; + for(ResolventAppSet::iterator it = nll->proofs.begin(), en = nll->proofs.end(); it != en; ++it){ + Z3_resolvent *res = *it; + ast arg = res->proof; + std::set loc_hyps; get_local_hyps(arg,loc_hyps); + if(!add_local_antes(arg,loc_hyps)){ + local_antes.clear(); // clear antecedents for next lemma + antes_added.clear(); + antes.clear(); + return 0; + } + } + collect_all_resolvent_lits(nll,lits); + } + int my_count = nll_lemma_count++; + int res; + try { + res = extract_th_lemma_common(lits,nll,lemma_nll); + } +#if 1 + catch (const invalid_lemma &) { + std::cout << "\n\nlemma: " << my_count; + std::cout << "\n\nproof node: \n"; + show_step(proof); + std::cout << "\n\nnon-local: \n"; + show_nll(nll); + pfgoto(nll->proofs[0]->proof); + show(conc(pfhist.back())); + pfprem(1); + show(conc(pfhist.back())); + pfprem(0); + show(conc(pfhist.back())); + pfprem(0); + show(conc(pfhist.back())); + pfprem(0); + show(conc(pfhist.back())); + std::cout << "\n\nliterals: \n"; + for(int i = 0; i < lits.size(); i++) + show_lit(lits[i]); + throw invalid_lemma(); + } +#endif + + return res; +#else +#ifdef SIMPLE_PROOFS + if(local_antes_simple && !nll) + return extract_th_lemma_simple(proof, lits); +#endif + return extract_th_lemma_common(lits,nll); +#endif + } + + int extract_th_lemma_ur(ast proof, int position, std::vector &lits, non_local_lits *nll){ + for(int i = 0; i <= position; i++){ + ast arg = prem(proof,i); + if(!add_local_antes(arg,get_hyps(proof),i==0)){ + local_antes.clear(); // clear antecedents for next lemma + antes_added.clear(); + antes.clear(); + return 0; + } + } + return extract_th_lemma_common(lits,nll); + } + + // see if any of the pushed resolvents are resolutions + // push the current proof into the latest such + int push_into_resolvent(ast proof, std::vector &lits, non_local_lits *nll, bool expect_clause){ + if(!nll) return 0; + if(num_args(proof) > 1) return 0; + ResolventAppSet resos = nll->proofs; + int pos = resos.size()-1; + for( ResolventAppSet::reverse_iterator it = resos.rbegin(), en = resos.rend(); it != en; ++it, --pos){ + Z3_resolvent *reso = *it; + ast ante = reso->proof; + ast pivot = reso->pivot; + bool is_unit = reso->is_unit; + pfrule dk = pr(ante); + bool pushable = dk == PR_UNIT_RESOLUTION || dk == PR_LEMMA; + if(!pushable && num_args(ante) > 1){ +#if 0 + if (!is_local(conc(ante))) + std::cout << "non-local "; + std::cout << "pushable!\n"; +#endif + pushable = true; + } + if(pushable){ + // remove the resolvent from list and add new clause as resolvent + resos.erase((++it).base()); + for(; pos < (int)resos.size(); pos++){ + Z3_resolvent *r = resos[pos]; + resos[pos] = find_resolvent(r->proof,r->is_unit,mk_not(pivot)); + } + resos.push_back(find_resolvent(proof,!expect_clause,mk_not(pivot))); + non_local_lits *new_nll = find_nll(resos); + try { + int res = translate_main(ante,new_nll,!is_unit); + return res; + } + catch (const invalid_lemma &) { + std::cout << "\n\npushing: \n"; + std::cout << "nproof node: \n"; + show_step(proof); + std::cout << "\n\nold non-local: \n"; + show_nll(nll); + std::cout << "\n\nnew non-local: \n"; + show_nll(new_nll); + throw invalid_lemma(); + } + } + } + return 0; // no pushed resolvents are resolution steps + } + + non_local_lits *find_nll(ResolventAppSet &proofs){ + if(proofs.empty()) + return (non_local_lits *)0; + std::pair foo(non_local_lits(proofs),(non_local_lits *)0); + std::pair::iterator,bool> bar = + non_local_lits_unique.insert(foo); + non_local_lits *&res = bar.first->second; + if(bar.second) + res = new non_local_lits(bar.first->first); + return res; + } + + Z3_resolvent *find_resolvent(ast proof, bool unit, ast pivot){ + std::pair foo(Z3_resolvent(proof,unit,pivot),(Z3_resolvent *)0); + std::pair::iterator,bool> bar = + Z3_resolvent_unique.insert(foo); + Z3_resolvent *&res = bar.first->second; + if(bar.second) + res = new Z3_resolvent(bar.first->first); + return res; + } + + // translate a unit resolution at position pos of given app + int translate_ur(ast proof, int position, non_local_lits *nll){ + ast ante = prem(proof,position); + if(position <= 0) + return translate_main(ante, nll); + ast pnode = conc(ante); + ast pnode_abs = !is_not(pnode) ? pnode : mk_not(pnode); + if(is_local(pnode) || equivs.find(pnode) != equivs.end()){ + Iproof::node neg = translate_ur(proof,position-1,nll); + Iproof::node pos = translate_main(ante, nll, false); + if(is_not(pnode)){ + pnode = mk_not(pnode); + std::swap(neg,pos); + } + try { + return iproof->make_resolution(pnode,neg,pos); + } + catch (const iz3proof::proof_error){ + std::cout << "resolution error in unit_resolution, position" << position << "\n"; + show_step(proof); + throw invalid_lemma(); + } + } + else { + // non-local pivot we have no local equivalent for + if(true){ + // try pushing the non-local resolution up + pfrule dk = pr(ante); + non_local_lits *old_nll = nll; + if(dk == PR_HYPOTHESIS) + ; //std::cout << "non-local hyp!\n"; // resolving with a hyp is a no-op + else { + ResolventAppSet new_proofs; + if(nll) new_proofs = nll->proofs; + Z3_resolvent *reso = find_resolvent(ante,true,pnode); + new_proofs.push_back(reso); + nll = find_nll(new_proofs); + } + try { + return translate_ur(proof,position-1,nll); + } + catch (const invalid_lemma &) { + if(old_nll != nll){ + std::cout << "\n\nadded_nll: \n"; + std::cout << "nproof node: \n"; + show_step(proof); + std::cout << "\n\new non-local step: \n"; + show_step(nll->proofs.back()->proof); + } + throw invalid_lemma(); + } + + } + else { + // just make a lemma + std::vector lits; + do_unit_resolution(proof,position,lits); + int res; + if(!(res = extract_th_lemma_ur(proof,position,lits,nll))){ + for(int i = 0; i <= position; i++){ + z3pf p = prem(proof,i); + add_antes(p); + } + res = fix_lemma(lits,get_hyps(proof),nll); + } + return res; + } + } + } + + non_local_lits *update_nll(ast proof, bool expect_clause, non_local_lits *nll){ + std::vector lits; + collect_proof_clause(proof,expect_clause,lits); + AstSet litset; + litset.insert(lits.begin(),lits.end()); + ResolventAppSet to_keep; + for(int i = nll->proofs.size()-1; i >= 0; --i){ + ast traced_lit = (nll->proofs[i])->pivot; + ast traced_lit_neg = mk_not(traced_lit); + if(litset.find(traced_lit) != litset.end() || litset.find(traced_lit_neg) != litset.end()){ + to_keep.push_back(nll->proofs[i]); + std::vector reslits; + AstSet dummy; + collect_resolvent_lits(nll->proofs[i],dummy,reslits); + litset.insert(reslits.begin(),reslits.end()); + } + } + if(to_keep.size() == nll->proofs.size()) return nll; + ResolventAppSet new_proofs; + for(int i = to_keep.size() - 1; i >= 0; --i) + new_proofs.push_back(to_keep[i]); + return find_nll(new_proofs); + } + + // translate a Z3 proof term into a foci proof term + + Iproof::node translate_main(ast proof, non_local_lits *nll, bool expect_clause = true){ + non_local_lits *old_nll = nll; + if(nll) nll = update_nll(proof,expect_clause,nll); + AstToIpf &tr = nll ? non_local_translation[nll] : translation; + hash_map &trc = expect_clause ? tr.first : tr.second; + std::pair foo(proof,INT_MAX); + std::pair bar = trc.insert(foo); + int &res = bar.first->second; + if(!bar.second) return res; + + + try { + int frame = get_locality(proof); + if(frame != -1){ + ast e = from_ast(conc(proof)); + if(frame >= frames) frame = frames - 1; + std::vector foo; + if(expect_clause) + get_Z3_lits(conc(proof),foo); + else + foo.push_back(e); + AstSet &hyps = get_hyps(proof); + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it) + foo.push_back(mk_not(*it)); + res = iproof->make_assumption(frame,foo); + return res; + } + + pfrule dk = pr(proof); + unsigned nprems = num_prems(proof); + if(dk == PR_UNIT_RESOLUTION){ + res = translate_ur(proof, nprems - 1, nll); + } + else if(dk == PR_LEMMA){ + ast contra = prem(proof,0); // this is a proof of false from some hyps + res = translate_main(contra, nll); + if(!expect_clause){ + std::vector foo; // the negations of the hyps form a clause + foo.push_back(from_ast(conc(proof))); + AstSet &hyps = get_hyps(proof); + for(AstSet::iterator it = hyps.begin(), en = hyps.end(); it != en; ++it) + foo.push_back(mk_not(*it)); + res = iproof->make_contra(res,foo); + } + } + else { + std::vector lits; + ast con = conc(proof); + if(expect_clause) + get_Z3_lits(con, lits); + else + lits.push_back(from_ast(con)); +#ifdef NEW_EXTRACT_TH_LEMMA + if(!(res = push_into_resolvent(proof,lits,nll,expect_clause))){ + if(!(res = extract_th_lemma(proof,lits,nll))){ +#else + if(!(res = extract_th_lemma(proof,lits,nll))){ + if(!(res = push_into_resolvent(proof,lits,nll,expect_clause))){ +#endif + // std::cout << "extract theory lemma failed\n"; + add_antes(proof); + res = fix_lemma(lits,get_hyps(proof),nll); + } + } + } +#ifdef CHECK_PROOFS + + if(0){ + AstSet zpf_con_lits, ipf_con_lits; + get_local_conclusion_lits(proof, expect_clause, zpf_con_lits); + if(nll){ + for(unsigned i = 0; i < nll->proofs.size(); i++) + get_local_conclusion_lits(nll->proofs[i]->proof,!nll->proofs[i]->is_unit,zpf_con_lits); + } + std::vector ipf_con; + iproof->get_conclusion(res,ipf_con); + for(unsigned i = 0; i < ipf_con.size(); i++) + ipf_con_lits.insert(ipf_con[i]); + if(!(ipf_con_lits == zpf_con_lits)){ + std::cout << "proof error:\n"; + std::cout << "expected lits:\n"; + for(AstSet::iterator hit = zpf_con_lits.begin(), hen = zpf_con_lits.end(); hit != hen; ++hit) + show_lit(*hit); + std::cout << "got lits:\n"; + for(AstSet::iterator hit = ipf_con_lits.begin(), hen = ipf_con_lits.end(); hit != hen; ++hit) + show_lit(*hit); + std::cout << "\nproof step:"; + show_step(proof); + std::cout << "\n"; + throw invalid_lemma(); + } + } +#endif + + return res; + } + + catch (const invalid_lemma &) { + if(old_nll != nll){ + std::cout << "\n\nupdated nll: \n"; + std::cout << "nproof node: \n"; + show_step(proof); + std::cout << "\n\new non-local: \n"; + show_nll(nll); + } + throw invalid_lemma(); + } + + } + + // Proof translation is in two stages: + // 1) Translate ast proof term to Zproof + // 2) Translate Zproof to Iproof + + Iproof::node translate(ast proof, Iproof &dst){ + iproof = &dst; + Iproof::node Ipf = translate_main(proof,0); // builds result in dst + return Ipf; + } + + iz3translation_direct(iz3mgr &mgr, + iz3secondary *_secondary, + const std::vector > &cnsts, + const std::vector &parents, + const std::vector &theory) + : iz3translation(mgr, cnsts, parents, theory) + { + secondary = _secondary; + frames = cnsts.size(); + traced_lit = ast(); + } + + ~iz3translation_direct(){ + for(hash_map::iterator + it = non_local_lits_unique.begin(), + en = non_local_lits_unique.end(); + it != en; + ++it) + delete it->second; + + for(hash_map::iterator + it = Z3_resolvent_unique.begin(), + en = Z3_resolvent_unique.end(); + it != en; + ++it) + delete it->second; + } +}; + + + + +#ifdef IZ3_TRANSLATE_DIRECT + +iz3translation *iz3translation::create(iz3mgr &mgr, + iz3secondary *secondary, + const std::vector > &cnsts, + const std::vector &parents, + const std::vector &theory){ + return new iz3translation_direct(mgr,secondary,cnsts,parents,theory); +} + + +#if 1 + +void iz3translation_direct_trace_lit(iz3translation_direct *p, iz3mgr::ast lit, iz3mgr::ast proof){ + p->trace_lit(lit, proof); +} + +void iz3translation_direct_show_step(iz3translation_direct *p, iz3mgr::ast proof){ + p->show_step(proof); +} + +void iz3translation_direct_show_marked(iz3translation_direct *p, iz3mgr::ast proof){ + p->show_marked(proof); +} + +void iz3translation_direct_show_lit(iz3translation_direct *p, iz3mgr::ast lit){ + p->show_lit(lit); +} + +void iz3translation_direct_show_z3_lit(iz3translation_direct *p, iz3mgr::ast a){ + p->show_z3_lit(a); +} + +void iz3translation_direct_pfgoto(iz3translation_direct *p, iz3mgr::ast proof){ + p->pfgoto(proof); +} + +void iz3translation_direct_show_nll(iz3translation_direct *p, non_local_lits *nll){ + p->show_nll(nll); +} + +void iz3translation_direct_pfback(iz3translation_direct *p ){ + p->pfback(); +} + +void iz3translation_direct_pffwd(iz3translation_direct *p ){ + p->pffwd(); +} + +void iz3translation_direct_pfprem(iz3translation_direct *p, int i){ + p->pfprem(i); +} + + +struct stdio_fixer { + stdio_fixer(){ + std::cout.rdbuf()->pubsetbuf(0,0); + } + +} my_stdio_fixer; + +#endif + +#endif + + diff --git a/src/muz/base/dl_context.cpp b/src/muz/base/dl_context.cpp index 8183744b5..fbef46ae3 100644 --- a/src/muz/base/dl_context.cpp +++ b/src/muz/base/dl_context.cpp @@ -179,6 +179,7 @@ namespace datalog { void context::push() { m_trail.push_scope(); m_trail.push(restore_rules(m_rule_set)); + m_trail.push(restore_vec_size_trail(m_rule_fmls)); m_trail.push(restore_vec_size_trail(m_background)); } @@ -186,6 +187,10 @@ namespace datalog { if (m_trail.get_num_scopes() == 0) { throw default_exception("there are no backtracking points to pop to"); } + if(m_engine.get()){ + if(get_engine() != DUALITY_ENGINE) + throw default_exception("operation is not supported by engine"); + } m_trail.pop_scope(1); } @@ -699,6 +704,10 @@ namespace datalog { check_existential_tail(r); check_positive_predicates(r); break; + case DUALITY_ENGINE: + check_existential_tail(r); + check_positive_predicates(r); + break; case CLP_ENGINE: check_existential_tail(r); check_positive_predicates(r); @@ -920,6 +929,9 @@ namespace datalog { else if (e == symbol("clp")) { m_engine_type = CLP_ENGINE; } + else if (e == symbol("duality")) { + m_engine_type = DUALITY_ENGINE; + } if (m_engine_type == LAST_ENGINE) { expr_fast_mark1 mark; @@ -945,6 +957,18 @@ namespace datalog { } lbool context::query(expr* query) { +#if 0 + // TODO: what? + if(get_engine() != DUALITY_ENGINE) { + new_query(); + rule_set::iterator it = m_rule_set.begin(), end = m_rule_set.end(); + rule_ref r(m_rule_manager); + for (; it != end; ++it) { + r = *it; + check_rule(r); + } + } +#endif m_mc = mk_skip_model_converter(); m_last_status = OK; m_last_answer = 0; @@ -958,6 +982,8 @@ namespace datalog { case CLP_ENGINE: flush_add_rules(); break; + case DUALITY_ENGINE: + break; default: UNREACHABLE(); } @@ -983,6 +1009,7 @@ namespace datalog { if (get_engine() == DATALOG_ENGINE) { m_rel = dynamic_cast(m_engine.get()); } + } } @@ -1014,7 +1041,6 @@ namespace datalog { out << "\n---------------\n"; out << "Original rules\n"; display_rules(out); - out << "\n---------------\n"; out << "Transformed rules\n"; m_transformed_rule_set.display(out); @@ -1076,6 +1102,15 @@ namespace datalog { } } + void context::get_raw_rule_formulas(expr_ref_vector& rules, svector& names){ + for (unsigned i = 0; i < m_rule_fmls.size(); ++i) { + expr_ref r = bind_variables(m_rule_fmls[i].get(), true); + rules.push_back(r.get()); + // rules.push_back(m_rule_fmls[i].get()); + names.push_back(m_rule_names[i]); + } + } + void context::get_rules_as_formulas(expr_ref_vector& rules, svector& names) { expr_ref fml(m); datalog::rule_manager& rm = get_rule_manager(); diff --git a/src/muz/base/dl_context.h b/src/muz/base/dl_context.h index 5eef04202..51b54bc01 100644 --- a/src/muz/base/dl_context.h +++ b/src/muz/base/dl_context.h @@ -357,6 +357,7 @@ namespace datalog { rule_set & get_rules() { flush_add_rules(); return m_rule_set; } void get_rules_as_formulas(expr_ref_vector& fmls, svector& names); + void get_raw_rule_formulas(expr_ref_vector& fmls, svector& names); void add_fact(app * head); void add_fact(func_decl * pred, const relation_fact & fact); @@ -492,14 +493,16 @@ namespace datalog { /** \brief retrieve model from inductive invariant that shows query is unsat. - \pre engine == 'pdr' - this option is only supported for PDR mode. + \pre engine == 'pdr' || engine == 'duality' - this option is only supported + for PDR mode and Duality mode. */ model_ref get_model(); /** \brief retrieve proof from derivation of the query. - \pre engine == 'pdr' - this option is only supported for PDR mode. + \pre engine == 'pdr' || engine == 'duality'- this option is only supported + for PDR mode and Duality mode. */ proof_ref get_proof(); diff --git a/src/muz/base/dl_engine_base.h b/src/muz/base/dl_engine_base.h index 52dc9acd4..eaeebe979 100644 --- a/src/muz/base/dl_engine_base.h +++ b/src/muz/base/dl_engine_base.h @@ -30,7 +30,8 @@ namespace datalog { QBMC_ENGINE, TAB_ENGINE, CLP_ENGINE, - LAST_ENGINE + LAST_ENGINE, + DUALITY_ENGINE }; class engine_base { diff --git a/src/muz/base/dl_util.h b/src/muz/base/dl_util.h index e13e7b53a..57f7575ca 100644 --- a/src/muz/base/dl_util.h +++ b/src/muz/base/dl_util.h @@ -57,7 +57,6 @@ namespace datalog { LAST_CACHE_MODE }; - struct std_string_hash_proc { unsigned operator()(const std::string & s) const { return string_hash(s.c_str(), static_cast(s.length()), 17); } diff --git a/src/muz/base/fixedpoint_params.pyg b/src/muz/base/fixedpoint_params.pyg index 46f4fce99..e09e1307b 100644 --- a/src/muz/base/fixedpoint_params.pyg +++ b/src/muz/base/fixedpoint_params.pyg @@ -67,6 +67,13 @@ def_module_params('fixedpoint', ('print_statistics', BOOL, False, 'print statistics'), ('use_utvpi', BOOL, True, 'PDR: Enable UTVPI strategy'), ('tab_selection', SYMBOL, 'weight', 'selection method for tabular strategy: weight (default), first, var-use'), + ('full_expand', BOOL, False, 'DUALITY: Fully expand derivation trees'), + ('no_conj', BOOL, False, 'DUALITY: No forced covering (conjectures)'), + ('feasible_edges', BOOL, True, 'DUALITY: Don\'t expand definitley infeasible edges'), + ('use_underapprox', BOOL, False, 'DUALITY: Use underapproximations'), + ('stratified_inlining', BOOL, False, 'DUALITY: Use stratified inlining'), + ('recursion_bound', UINT, UINT_MAX, 'DUALITY: Recursion bound for stratified inlining'), + ('profile', BOOL, False, 'DUALITY: profile run time'), ('dump_aig', SYMBOL, '', 'Dump clauses in AIG text format (AAG) to the given file name'), )) diff --git a/src/muz/duality/duality_dl_interface.cpp b/src/muz/duality/duality_dl_interface.cpp new file mode 100644 index 000000000..12dd4ff3e --- /dev/null +++ b/src/muz/duality/duality_dl_interface.cpp @@ -0,0 +1,493 @@ +/*++ +Copyright (c) 2013 Microsoft Corporation + +Module Name: + + duality_dl_interface.cpp + +Abstract: + + SMT2 interface for Duality + +Author: + + Krystof Hoder (t-khoder) 2011-9-22. + Modified by Ken McMIllan (kenmcmil) 2013-4-18. + +Revision History: + +--*/ + +#include "dl_context.h" +#include "dl_mk_coi_filter.h" +#include "dl_mk_interp_tail_simplifier.h" +#include "dl_mk_subsumption_checker.h" +#include "dl_mk_rule_inliner.h" +#include "dl_rule.h" +#include "dl_rule_transformer.h" +#include "smt2parser.h" +#include "duality_dl_interface.h" +#include "dl_rule_set.h" +#include "dl_mk_slice.h" +#include "dl_mk_unfold.h" +#include "dl_mk_coalesce.h" +#include "expr_abstract.h" +#include "model_smt2_pp.h" +#include "model_v2_pp.h" +#include "fixedpoint_params.hpp" +#include "scoped_proof.h" + +// template class symbol_table; + + +#include "duality.h" +#include "duality_profiling.h" + +// using namespace Duality; + +namespace Duality { + + enum DualityStatus {StatusModel, StatusRefutation, StatusUnknown, StatusNull}; + + class duality_data { + public: + context ctx; + RPFP::LogicSolver *ls; + RPFP *rpfp; + + DualityStatus status; + std::vector clauses; + std::vector > clause_labels; + hash_map map; // edges to clauses + Solver::Counterexample cex; + + duality_data(ast_manager &_m) : ctx(_m,config(params_ref())) { + ls = 0; + rpfp = 0; + status = StatusNull; + } + ~duality_data(){ + if(rpfp) + dealloc(rpfp); + if(ls) + dealloc(ls); + if(cex.tree) + delete cex.tree; + } + }; + + +dl_interface::dl_interface(datalog::context& dl_ctx) : + engine_base(dl_ctx.get_manager(), "duality"), + m_ctx(dl_ctx) + +{ + _d = 0; + // dl_ctx.get_manager().toggle_proof_mode(PGM_FINE); +} + + +dl_interface::~dl_interface() { + if(_d) + dealloc(_d); +} + + +// +// Check if the new rules are weaker so that we can +// re-use existing context. +// +#if 0 +void dl_interface::check_reset() { + // TODO + datalog::rule_ref_vector const& new_rules = m_ctx.get_rules().get_rules(); + datalog::rule_ref_vector const& old_rules = m_old_rules.get_rules(); + bool is_subsumed = !old_rules.empty(); + for (unsigned i = 0; is_subsumed && i < new_rules.size(); ++i) { + is_subsumed = false; + for (unsigned j = 0; !is_subsumed && j < old_rules.size(); ++j) { + if (m_ctx.check_subsumes(*old_rules[j], *new_rules[i])) { + is_subsumed = true; + } + } + if (!is_subsumed) { + TRACE("pdr", new_rules[i]->display(m_ctx, tout << "Fresh rule ");); + m_context->reset(); + } + } + m_old_rules.reset(); + m_old_rules.add_rules(new_rules.size(), new_rules.c_ptr()); +} +#endif + + +lbool dl_interface::query(::expr * query) { + + // we restore the initial state in the datalog context + m_ctx.ensure_opened(); + + // if there is old data, get the cex and dispose (later) + Solver::Counterexample old_cex; + duality_data *old_data = _d; + if(old_data) + old_cex = old_data->cex; + + scoped_proof generate_proofs_please(m_ctx.get_manager()); + + // make a new problem and solver + _d = alloc(duality_data,m_ctx.get_manager()); + _d->ls = alloc(RPFP::iZ3LogicSolver,_d->ctx); + _d->rpfp = alloc(RPFP,_d->ls); + + + + expr_ref_vector rules(m_ctx.get_manager()); + svector< ::symbol> names; + // m_ctx.get_rules_as_formulas(rules, names); + m_ctx.get_raw_rule_formulas(rules, names); + + // get all the rules as clauses + std::vector &clauses = _d->clauses; + clauses.clear(); + for (unsigned i = 0; i < rules.size(); ++i) { + expr e(_d->ctx,rules[i].get()); + clauses.push_back(e); + } + + // turn the query into a clause + expr q(_d->ctx,m_ctx.bind_variables(query,false)); + + std::vector b_sorts; + std::vector b_names; + if (q.is_quantifier() && !q.is_quantifier_forall()) { + int bound = q.get_quantifier_num_bound(); + for(int j = 0; j < bound; j++){ + b_sorts.push_back(q.get_quantifier_bound_sort(j)); + b_names.push_back(q.get_quantifier_bound_name(j)); + } + q = q.arg(0); + } + + expr qc = implies(q,_d->ctx.bool_val(false)); + qc = _d->ctx.make_quant(Forall,b_sorts,b_names,qc); + clauses.push_back(qc); + + // get the background axioms + unsigned num_asserts = m_ctx.get_num_assertions(); + for (unsigned i = 0; i < num_asserts; ++i) { + expr e(_d->ctx,m_ctx.get_assertion(i)); + _d->rpfp->AssertAxiom(e); + } + + // creates 1-1 map between clauses and rpfp edges + _d->rpfp->FromClauses(clauses); + + // populate the edge-to-clause map + for(unsigned i = 0; i < _d->rpfp->edges.size(); ++i) + _d->map[_d->rpfp->edges[i]] = i; + + // create a solver object + + Solver *rs = Solver::Create("duality", _d->rpfp); + + rs->LearnFrom(old_cex); // new solver gets hints from old cex + + // set its options + IF_VERBOSE(1, rs->SetOption("report","1");); + rs->SetOption("full_expand",m_ctx.get_params().full_expand() ? "1" : "0"); + rs->SetOption("no_conj",m_ctx.get_params().no_conj() ? "1" : "0"); + rs->SetOption("feasible_edges",m_ctx.get_params().feasible_edges() ? "1" : "0"); + rs->SetOption("use_underapprox",m_ctx.get_params().use_underapprox() ? "1" : "0"); + rs->SetOption("stratified_inlining",m_ctx.get_params().stratified_inlining() ? "1" : "0"); + unsigned rb = m_ctx.get_params().recursion_bound(); + if(rb != UINT_MAX){ + std::ostringstream os; os << rb; + rs->SetOption("recursion_bound", os.str()); + } + + // Solve! + bool ans; + try { + ans = rs->Solve(); + } + catch (Duality::solver::cancel_exception &exn){ + throw default_exception("duality canceled"); + } + + // profile! + + if(m_ctx.get_params().profile()) + print_profile(std::cout); + + // save the result and counterexample if there is one + _d->status = ans ? StatusModel : StatusRefutation; + _d->cex = rs->GetCounterexample(); + + if(old_data){ + old_data->cex.tree = 0; // we own it now + dealloc(old_data); + } + + + dealloc(rs); + + // true means the RPFP problem is SAT, so the query is UNSAT + return ans ? l_false : l_true; +} + +expr_ref dl_interface::get_cover_delta(int level, ::func_decl* pred_orig) { + SASSERT(false); + return expr_ref(m_ctx.get_manager()); +} + + void dl_interface::add_cover(int level, ::func_decl* pred, ::expr* property) { + SASSERT(false); +} + + unsigned dl_interface::get_num_levels(::func_decl* pred) { + SASSERT(false); + return 0; +} + + void dl_interface::collect_statistics(::statistics& st) const { +} + +void dl_interface::reset_statistics() { +} + +static hash_set *local_func_decls; + +static void print_proof(dl_interface *d, std::ostream& out, Solver::Counterexample &cex) { + context &ctx = d->dd()->ctx; + RPFP::Node &node = *cex.root; + RPFP::Edge &edge = *node.Outgoing; + + // first, prove the children (that are actually used) + + for(unsigned i = 0; i < edge.Children.size(); i++){ + if(!cex.tree->Empty(edge.Children[i])){ + Solver::Counterexample foo = cex; + foo.root = edge.Children[i]; + print_proof(d,out,foo); + } + } + + // print the label and the proved fact + + out << "(step s!" << node.number; + out << " (" << node.Name.name(); + for(unsigned i = 0; i < edge.F.IndParams.size(); i++) + out << " " << cex.tree->Eval(&edge,edge.F.IndParams[i]); + out << ")\n"; + + // print the rule number + + out << " rule!" << node.Outgoing->map->number; + + // print the substitution + + out << " (subst\n"; + RPFP::Edge *orig_edge = edge.map; + int orig_clause = d->dd()->map[orig_edge]; + expr &t = d->dd()->clauses[orig_clause]; + if (t.is_quantifier() && t.is_quantifier_forall()) { + int bound = t.get_quantifier_num_bound(); + std::vector sorts; + std::vector names; + hash_map subst; + for(int j = 0; j < bound; j++){ + sort the_sort = t.get_quantifier_bound_sort(j); + symbol name = t.get_quantifier_bound_name(j); + expr skolem = ctx.constant(symbol(ctx,name),sort(ctx,the_sort)); + out << " (= " << skolem << " " << cex.tree->Eval(&edge,skolem) << ")\n"; + expr local_skolem = cex.tree->Localize(&edge,skolem); + (*local_func_decls).insert(local_skolem.decl()); + } + } + out << " )\n"; + + out << " (labels"; + std::vector labels; + cex.tree->GetLabels(&edge,labels); + for(unsigned j = 0; j < labels.size(); j++){ + out << " " << labels[j]; + } + + out << " )\n"; + + // reference the proofs of all the children, in syntactic order + // "true" means the child is not needed + + out << " (ref "; + for(unsigned i = 0; i < edge.Children.size(); i++){ + if(!cex.tree->Empty(edge.Children[i])) + out << " s!" << edge.Children[i]->number; + else + out << " true"; + } + out << " )"; + out << ")\n"; +} + + + void dl_interface::display_certificate(std::ostream& out) const { + ((dl_interface *)this)->display_certificate_non_const(out); + } + +void dl_interface::display_certificate_non_const(std::ostream& out) { + if(_d->status == StatusModel){ + ast_manager &m = m_ctx.get_manager(); + model_ref md = get_model(); + model_smt2_pp(out, m, *md.get(), 0); + } + else if(_d->status == StatusRefutation){ + out << "(derivation\n"; + // negation of the query is the last clause -- prove it + hash_set locals; + local_func_decls = &locals; + print_proof(this,out,_d->cex); + out << ")\n"; + out << "(model \n\""; + ::model mod(m_ctx.get_manager()); + model orig_model = _d->cex.tree->dualModel; + for(unsigned i = 0; i < orig_model.num_consts(); i++){ + func_decl cnst = orig_model.get_const_decl(i); + if(locals.find(cnst) == locals.end()){ + expr thing = orig_model.get_const_interp(cnst); + mod.register_decl(to_func_decl(cnst.raw()),to_expr(thing.raw())); + } + } + for(unsigned i = 0; i < orig_model.num_funcs(); i++){ + func_decl cnst = orig_model.get_func_decl(i); + if(locals.find(cnst) == locals.end()){ + func_interp thing = orig_model.get_func_interp(cnst); + ::func_interp *thing_raw = thing; + mod.register_decl(to_func_decl(cnst.raw()),thing_raw->copy()); + } + } + model_v2_pp(out,mod); + out << "\")\n"; + } +} + +expr_ref dl_interface::get_answer() { + SASSERT(false); + return expr_ref(m_ctx.get_manager()); +} + +void dl_interface::cancel() { +#if 0 + if(_d && _d->ls) + _d->ls->cancel(); +#else + // HACK: duality can't cancel at all times, we just exit here + std::cout << "(error \"duality canceled\")\nunknown\n"; + abort(); +#endif +} + +void dl_interface::cleanup() { +} + +void dl_interface::updt_params() { +} + +model_ref dl_interface::get_model() { + ast_manager &m = m_ctx.get_manager(); + model_ref md(alloc(::model, m)); + std::vector &nodes = _d->rpfp->nodes; + expr_ref_vector conjs(m); + for (unsigned i = 0; i < nodes.size(); ++i) { + RPFP::Node *node = nodes[i]; + func_decl &pred = node->Name; + expr_ref prop(m); + prop = to_expr(node->Annotation.Formula); + std::vector ¶ms = node->Annotation.IndParams; + expr_ref q(m); + expr_ref_vector sig_vars(m); + for (unsigned j = 0; j < params.size(); ++j) + sig_vars.push_back(params[params.size()-j-1]); // TODO: why backwards? + expr_abstract(m, 0, sig_vars.size(), sig_vars.c_ptr(), prop, q); + if (params.empty()) { + md->register_decl(pred, q); + } + else { + ::func_interp* fi = alloc(::func_interp, m, params.size()); + fi->set_else(q); + md->register_decl(pred, fi); + } + } + return md; +} + +static proof_ref extract_proof(dl_interface *d, Solver::Counterexample &cex) { + context &ctx = d->dd()->ctx; + ast_manager &mgr = ctx.m(); + RPFP::Node &node = *cex.root; + RPFP::Edge &edge = *node.Outgoing; + RPFP::Edge *orig_edge = edge.map; + + // first, prove the children (that are actually used) + + proof_ref_vector prems(mgr); + ::vector substs; + int orig_clause = d->dd()->map[orig_edge]; + expr &t = d->dd()->clauses[orig_clause]; + prems.push_back(mgr.mk_asserted(ctx.uncook(t))); + + substs.push_back(expr_ref_vector(mgr)); + if (t.is_quantifier() && t.is_quantifier_forall()) { + int bound = t.get_quantifier_num_bound(); + std::vector sorts; + std::vector names; + hash_map subst; + for(int j = 0; j < bound; j++){ + sort the_sort = t.get_quantifier_bound_sort(j); + symbol name = t.get_quantifier_bound_name(j); + expr skolem = ctx.constant(symbol(ctx,name),sort(ctx,the_sort)); + expr val = cex.tree->Eval(&edge,skolem); + expr_ref thing(ctx.uncook(val),mgr); + substs[0].push_back(thing); + expr local_skolem = cex.tree->Localize(&edge,skolem); + (*local_func_decls).insert(local_skolem.decl()); + } + } + + svector > pos; + for(unsigned i = 0; i < edge.Children.size(); i++){ + if(!cex.tree->Empty(edge.Children[i])){ + pos.push_back(std::pair(i+1,0)); + Solver::Counterexample foo = cex; + foo.root = edge.Children[i]; + proof_ref prem = extract_proof(d,foo); + prems.push_back(prem); + substs.push_back(expr_ref_vector(mgr)); + } + } + + func_decl f = node.Name; + std::vector args; + for(unsigned i = 0; i < edge.F.IndParams.size(); i++) + args.push_back(cex.tree->Eval(&edge,edge.F.IndParams[i])); + expr conc = f(args); + + + ::vector< ::proof *> pprems; + for(unsigned i = 0; i < prems.size(); i++) + pprems.push_back(prems[i].get()); + + proof_ref res(mgr.mk_hyper_resolve(pprems.size(),&pprems[0], ctx.uncook(conc), pos, substs),mgr); + return res; + +} + +proof_ref dl_interface::get_proof() { + if(_d->status == StatusRefutation){ + hash_set locals; + local_func_decls = &locals; + return extract_proof(this,_d->cex); + } + else + return proof_ref(m_ctx.get_manager()); +} +} diff --git a/src/muz/duality/duality_dl_interface.h b/src/muz/duality/duality_dl_interface.h new file mode 100644 index 000000000..19444ca50 --- /dev/null +++ b/src/muz/duality/duality_dl_interface.h @@ -0,0 +1,80 @@ +/*++ +Copyright (c) 2013 Microsoft Corporation + +Module Name: + + duality_dl_interface.h + +Abstract: + + SMT2 interface for Duality + +Author: + + Krystof Hoder (t-khoder) 2011-9-22. + Modified by Ken McMIllan (kenmcmil) 2013-4-18. + +Revision History: + +--*/ + +#ifndef _DUALITY_DL_INTERFACE_H_ +#define _DUALITY_DL_INTERFACE_H_ + +#include "lbool.h" +#include "dl_rule.h" +#include "dl_rule_set.h" +#include "dl_engine_base.h" +#include "statistics.h" + +namespace datalog { + class context; +} + +namespace Duality { + + class duality_data; + + class dl_interface : public datalog::engine_base { + duality_data *_d; + datalog::context &m_ctx; + + public: + dl_interface(datalog::context& ctx); + ~dl_interface(); + + lbool query(expr* query); + + void cancel(); + + void cleanup(); + + void display_certificate(std::ostream& out) const; + + void collect_statistics(statistics& st) const; + + void reset_statistics(); + + expr_ref get_answer(); + + unsigned get_num_levels(func_decl* pred); + + expr_ref get_cover_delta(int level, func_decl* pred); + + void add_cover(int level, func_decl* pred, expr* property); + + void updt_params(); + + model_ref get_model(); + + proof_ref get_proof(); + + duality_data *dd(){return _d;} + + private: + void display_certificate_non_const(std::ostream& out); + }; +} + + +#endif diff --git a/src/muz/fp/dl_cmds.cpp b/src/muz/fp/dl_cmds.cpp index f6965b084..827b90e60 100644 --- a/src/muz/fp/dl_cmds.cpp +++ b/src/muz/fp/dl_cmds.cpp @@ -454,6 +454,44 @@ public: } }; +/** + \brief fixedpoint-push command. +*/ +class dl_push_cmd : public cmd { + ref m_dl_ctx; +public: + dl_push_cmd(dl_context * dl_ctx): + cmd("fixedpoint-push"), + m_dl_ctx(dl_ctx) + {} + + virtual char const * get_usage() const { return ""; } + virtual char const * get_descr(cmd_context & ctx) const { return "push the fixedpoint context"; } + virtual unsigned get_arity() const { return 0; } + virtual void execute(cmd_context & ctx) { + m_dl_ctx->push(); + } +}; + +/** + \brief fixedpoint-pop command. +*/ +class dl_pop_cmd : public cmd { + ref m_dl_ctx; +public: + dl_pop_cmd(dl_context * dl_ctx): + cmd("fixedpoint-pop"), + m_dl_ctx(dl_ctx) + {} + + virtual char const * get_usage() const { return ""; } + virtual char const * get_descr(cmd_context & ctx) const { return "pop the fixedpoint context"; } + virtual unsigned get_arity() const { return 0; } + virtual void execute(cmd_context & ctx) { + m_dl_ctx->pop(); + } +}; + static void install_dl_cmds_aux(cmd_context& ctx, dl_collected_cmds* collected_cmds) { dl_context * dl_ctx = alloc(dl_context, ctx, collected_cmds); @@ -461,6 +499,13 @@ static void install_dl_cmds_aux(cmd_context& ctx, dl_collected_cmds* collected_c ctx.insert(alloc(dl_query_cmd, dl_ctx)); ctx.insert(alloc(dl_declare_rel_cmd, dl_ctx)); ctx.insert(alloc(dl_declare_var_cmd, dl_ctx)); + // #ifndef _EXTERNAL_RELEASE + // TODO: we need these! +#if 1 + ctx.insert(alloc(dl_push_cmd, dl_ctx)); // not exposed to keep command-extensions simple. + ctx.insert(alloc(dl_pop_cmd, dl_ctx)); +#endif + // #endif } void install_dl_cmds(cmd_context & ctx) { diff --git a/src/muz/fp/dl_register_engine.cpp b/src/muz/fp/dl_register_engine.cpp index 84b64fafa..57413b7cb 100644 --- a/src/muz/fp/dl_register_engine.cpp +++ b/src/muz/fp/dl_register_engine.cpp @@ -22,6 +22,7 @@ Revision History: #include "tab_context.h" #include "rel_context.h" #include "pdr_dl_interface.h" +#include "duality_dl_interface.h" namespace datalog { register_engine::register_engine(): m_ctx(0) {} @@ -40,6 +41,8 @@ namespace datalog { return alloc(tab, *m_ctx); case CLP_ENGINE: return alloc(clp, *m_ctx); + case DUALITY_ENGINE: + return alloc(Duality::dl_interface, *m_ctx); case LAST_ENGINE: UNREACHABLE(); return 0; diff --git a/src/muz/pdr/pdr_context.cpp b/src/muz/pdr/pdr_context.cpp index f530dd7f9..e55806243 100644 --- a/src/muz/pdr/pdr_context.cpp +++ b/src/muz/pdr/pdr_context.cpp @@ -1125,9 +1125,7 @@ namespace pdr { n->mk_instantiate(r0, r1, binding); proof_ref p1(m), p2(m); p1 = r0->get_proof(); - if (!p1) { - r0->display(dctx, std::cout); - } + IF_VERBOSE(0, if (!p1) r0->display(dctx, verbose_stream());); SASSERT(p1); pfs[0] = p1; rls[0] = r1; diff --git a/src/muz/transforms/dl_mk_array_blast.cpp b/src/muz/transforms/dl_mk_array_blast.cpp index 93e31ff23..9303181b6 100644 --- a/src/muz/transforms/dl_mk_array_blast.cpp +++ b/src/muz/transforms/dl_mk_array_blast.cpp @@ -19,6 +19,7 @@ Revision History: #include "dl_mk_array_blast.h" #include "qe_util.h" +#include "scoped_proof.h" namespace datalog { @@ -270,7 +271,7 @@ namespace datalog { } } - expr_ref fml2(m), body(m), head(m); + expr_ref fml1(m), fml2(m), body(m), head(m); body = m.mk_and(new_conjs.size(), new_conjs.c_ptr()); head = r.get_head(); sub(body); @@ -287,9 +288,17 @@ namespace datalog { proof_ref p(m); rule_set new_rules(m_ctx); rm.mk_rule(fml2, p, new_rules, r.name()); + rule_ref new_rule(rm); if (m_simplifier.transform_rule(new_rules.last(), new_rule)) { + if (r.get_proof()) { + scoped_proof _sc(m); + r.to_formula(fml1); + p = m.mk_rewrite(fml1, fml2); + p = m.mk_modus_ponens(r.get_proof(), p); + new_rule->set_proof(m, p); + } rules.add_rule(new_rule.get()); rm.mk_rule_rewrite_proof(r, *new_rule.get()); TRACE("dl", new_rule->display(m_ctx, tout << "new rule\n");); diff --git a/src/muz/transforms/dl_mk_bit_blast.cpp b/src/muz/transforms/dl_mk_bit_blast.cpp index 91481ed5a..112541541 100644 --- a/src/muz/transforms/dl_mk_bit_blast.cpp +++ b/src/muz/transforms/dl_mk_bit_blast.cpp @@ -25,6 +25,7 @@ Revision History: #include "filter_model_converter.h" #include "dl_mk_interp_tail_simplifier.h" #include "fixedpoint_params.hpp" +#include "scoped_proof.h" namespace datalog { @@ -268,7 +269,8 @@ namespace datalog { r->to_formula(fml); if (blast(r, fml)) { proof_ref pr(m); - if (m_context.generate_proof_trace()) { + if (r->get_proof()) { + scoped_proof _sc(m); pr = m.mk_asserted(fml); // loses original proof of r. } // TODO add logic for pc: diff --git a/src/muz/transforms/dl_mk_quantifier_abstraction.cpp b/src/muz/transforms/dl_mk_quantifier_abstraction.cpp index 0ad4d61ab..57948a2ff 100644 --- a/src/muz/transforms/dl_mk_quantifier_abstraction.cpp +++ b/src/muz/transforms/dl_mk_quantifier_abstraction.cpp @@ -341,7 +341,6 @@ namespace datalog { } head = mk_head(source, *result, r.get_head(), cnt); fml = m.mk_implies(m.mk_and(tail.size(), tail.c_ptr()), head); - rule_ref_vector added_rules(rm); proof_ref pr(m); rm.mk_rule(fml, pr, *result); TRACE("dl", result->last()->display(m_ctx, tout);); diff --git a/src/muz/transforms/dl_mk_rule_inliner.cpp b/src/muz/transforms/dl_mk_rule_inliner.cpp index 3e933b099..d9dad5c56 100644 --- a/src/muz/transforms/dl_mk_rule_inliner.cpp +++ b/src/muz/transforms/dl_mk_rule_inliner.cpp @@ -527,6 +527,9 @@ namespace datalog { bool mk_rule_inliner::do_eager_inlining(rule * r, rule_set const& rules, rule_ref& res) { + if (r->has_negation()) { + return false; + } SASSERT(rules.is_closed()); const rule_stratifier& strat = rules.get_stratifier(); diff --git a/src/opt/weighted_maxsat.cpp b/src/opt/weighted_maxsat.cpp index c5579bdb6..8bf5e8430 100644 --- a/src/opt/weighted_maxsat.cpp +++ b/src/opt/weighted_maxsat.cpp @@ -416,6 +416,9 @@ namespace opt { if (m_engine == symbol("wpm2")) { return wpm2_solve(); } + if (m_engine == symbol("wpm2b")) { + return wpm2b_solve(); + } return incremental_solve(); } @@ -664,12 +667,6 @@ namespace opt { m_lower = m_upper = rational::zero(); obj_map ans_index; -#ifdef WPM2b - // change from CP'13 - for (unsigned i = 0; i < s.get_num_assertions(); ++i) { - al.push_back(s.get_assertion(i)); - } -#endif vector amk; vector sc; for (unsigned i = 0; i < m_soft.size(); ++i) { @@ -758,10 +755,7 @@ namespace opt { } } rational k; - std::cout << "new bound"; is_sat = new_bound(al, ws, bs, k); - std::cout << " " << k << "\n"; - std::cout.flush(); if (is_sat != l_true) { return is_sat; } @@ -783,6 +777,177 @@ namespace opt { amk.push_back(k); } } + + // Version from CP'13 + lbool wpm2b_solve() { + solver::scoped_push _s(s); + pb_util u(m); + app_ref fml(m), a(m), b(m), c(m); + expr_ref val(m); + expr_ref_vector block(m), ans(m), am(m), soft(m); + m_lower = m_upper = rational::zero(); + obj_map ans_index; + + vector amk; + vector sc; // vector of indices used in at last constraints + expr_ref_vector al(m); // vector of at least constraints. + rational wmax; + for (unsigned i = 0; i < m_soft.size(); ++i) { + rational w = m_weights[i]; + m_upper += w; + if (wmax < w) wmax = w; + b = m.mk_fresh_const("b", m.mk_bool_sort()); + expr* bb = b; + s.mc().insert(b->get_decl()); + a = m.mk_fresh_const("a", m.mk_bool_sort()); + s.mc().insert(a->get_decl()); + ans.push_back(a); + ans_index.insert(a, i); + soft.push_back(0); // assert soft constraints lazily. + + c = m.mk_fresh_const("c", m.mk_bool_sort()); + s.mc().insert(c->get_decl()); + fml = m.mk_implies(c, u.mk_le(1,&w,&bb,rational(0))); + s.assert_expr(fml); + + sc.push_back(uint_set()); + sc.back().insert(i); + am.push_back(c); + + al.push_back(u.mk_ge(1,&w,&bb,rational(0))); + s.assert_expr(al.back()); + + amk.push_back(rational(0)); + } + ++wmax; + + while (true) { + enable_soft(soft, block, ans, wmax); + expr_ref_vector asms(m); + asms.append(ans); + asms.append(am); + lbool is_sat = s.check_sat(asms.size(), asms.c_ptr()); + if (m_cancel && is_sat != l_false) { + is_sat = l_undef; + } + if (is_sat == l_undef) { + return l_undef; + } + if (is_sat == l_true && wmax.is_zero()) { + s.get_model(m_model); + m_upper = m_lower; + for (unsigned i = 0; i < block.size(); ++i) { + VERIFY(m_model->eval(block[i].get(), val)); + m_assignment[i] = m.is_false(val); + } + return l_true; + } + if (is_sat == l_true) { + rational W(0); + for (unsigned i = 0; i < m_weights.size(); ++i) { + if (m_weights[i] < wmax) W += m_weights[i]; + } + harden(am, W); + wmax = decrease(wmax); + continue; + } + SASSERT(is_sat == l_false); + ptr_vector core; + s.get_unsat_core(core); + if (core.empty()) { + return l_false; + } + uint_set A; + for (unsigned i = 0; i < core.size(); ++i) { + unsigned j; + if (ans_index.find(core[i], j) && soft[j].get()) { + A.insert(j); + } + } + if (A.empty()) { + return l_false; + } + uint_set B; + for (unsigned i = 0; i < sc.size(); ++i) { + uint_set t(sc[i]); + t &= A; + if (!t.empty()) { + B |= sc[i]; + m_lower -= amk[i]; + sc[i] = sc.back(); + sc.pop_back(); + am[i] = am.back(); + am.pop_back(); + amk[i] = amk.back(); + amk.pop_back(); + --i; + } + } + vector ws; + expr_ref_vector bs(m); + for (unsigned i = 0; i < m_soft.size(); ++i) { + if (B.contains(i)) { + ws.push_back(m_weights[i]); + bs.push_back(block[i].get()); + } + } + rational k; + + expr_ref_vector al2(al); + for (unsigned i = 0; i < s.get_num_assertions(); ++i) { + al2.push_back(s.get_assertion(i)); + } + is_sat = new_bound(al2, ws, bs, k); + if (is_sat != l_true) { + return is_sat; + } + m_lower += k; + expr_ref B_le_k(m), B_ge_k(m); + B_le_k = u.mk_le(ws.size(), ws.c_ptr(), bs.c_ptr(), k); + B_ge_k = u.mk_ge(ws.size(), ws.c_ptr(), bs.c_ptr(), k); + s.assert_expr(B_ge_k); + al.push_back(B_ge_k); + IF_VERBOSE(1, verbose_stream() << "(wmaxsat.wpm2 lower bound: " << m_lower << ")\n";); + IF_VERBOSE(2, verbose_stream() << "New lower bound: " << B_ge_k << "\n";); + + c = m.mk_fresh_const("c", m.mk_bool_sort()); + s.mc().insert(c->get_decl()); + fml = m.mk_implies(c, B_le_k); + s.assert_expr(fml); + sc.push_back(B); + am.push_back(c); + amk.push_back(k); + } + } + + void harden(expr_ref_vector& am, rational const& W) { + // TBD + } + + rational decrease(rational const& wmax) { + rational wmin(0); + for (unsigned i = 0; i < m_weights.size(); ++i) { + rational w = m_weights[i]; + if (w < wmax && wmin < w) wmin = w; + } + return wmin; + } + + + // enable soft constraints that have reached wmax. + void enable_soft(expr_ref_vector& soft, + expr_ref_vector const& block, + expr_ref_vector const& ans, + rational wmax) { + for (unsigned i = 0; i < m_soft.size(); ++i) { + rational w = m_weights[i]; + if (w >= wmax && !soft[i].get()) { + soft[i] = m.mk_or(m_soft[i].get(), block[i], m.mk_not(ans[i])); + s.assert_expr(soft[i].get()); + } + } + } + lbool new_bound(expr_ref_vector const& al, vector const& ws, diff --git a/src/sat/sat_simplifier.cpp b/src/sat/sat_simplifier.cpp index ba0c67235..f20ffa7c7 100644 --- a/src/sat/sat_simplifier.cpp +++ b/src/sat/sat_simplifier.cpp @@ -9,7 +9,7 @@ Abstract: SAT simplification procedures that use a "full" occurrence list: Subsumption, Blocked Clause Removal, Variable Elimination, ... - + Author: @@ -54,21 +54,21 @@ namespace sat { m_use_list[l2.index()].erase(c); } } - + simplifier::simplifier(solver & _s, params_ref const & p): s(_s), m_num_calls(0) { updt_params(p); reset_statistics(); } - + simplifier::~simplifier() { } inline watch_list & simplifier::get_wlist(literal l) { return s.get_wlist(l); } inline watch_list const & simplifier::get_wlist(literal l) const { return s.get_wlist(l); } - + inline bool simplifier::is_external(bool_var v) const { return s.is_external(v); } inline bool simplifier::was_eliminated(bool_var v) const { return s.was_eliminated(v); } @@ -78,7 +78,7 @@ namespace sat { lbool simplifier::value(literal l) const { return s.value(l); } inline void simplifier::checkpoint() { s.checkpoint(); } - + void simplifier::register_clauses(clause_vector & cs) { std::stable_sort(cs.begin(), cs.end(), size_lt()); clause_vector::iterator it = cs.begin(); @@ -117,7 +117,7 @@ namespace sat { SASSERT(s.get_wlist(~l1).contains(watched(l2, learned))); s.get_wlist(~l1).erase(watched(l2, learned)); } - + void simplifier::init_visited() { m_visited.reset(); m_visited.resize(2*s.num_vars(), false); @@ -155,7 +155,7 @@ namespace sat { if (!learned && (m_elim_blocked_clauses || m_elim_blocked_clauses_at == m_num_calls)) elim_blocked_clauses(); - + if (!learned) m_num_calls++; @@ -180,6 +180,7 @@ namespace sat { bool vars_eliminated = m_num_elim_vars > old_num_elim_vars; if (!m_need_cleanup) { + TRACE("after_simplifier", tout << "skipping cleanup...\n";); if (vars_eliminated) { // must remove learned clauses with eliminated variables cleanup_clauses(s.m_learned, true, true, learned_in_use_lists); @@ -189,6 +190,7 @@ namespace sat { free_memory(); return; } + TRACE("after_simplifier", tout << "cleanning watches...\n";); cleanup_watches(); cleanup_clauses(s.m_learned, true, vars_eliminated, learned_in_use_lists); cleanup_clauses(s.m_clauses, false, vars_eliminated, true); @@ -234,7 +236,7 @@ namespace sat { s.del_clause(c); continue; } - + if (learned && vars_eliminated) { unsigned sz = c.size(); unsigned i; @@ -293,7 +295,7 @@ namespace sat { mark_visited(c[i]); } } - + void simplifier::unmark_all(clause const & c) { unsigned sz = c.size(); for (unsigned i = 0; i < sz; i++) @@ -325,7 +327,7 @@ namespace sat { */ bool simplifier::subsumes1(clause const & c1, clause const & c2, literal & l) { unsigned sz2 = c2.size(); - for (unsigned i = 0; i < sz2; i++) + for (unsigned i = 0; i < sz2; i++) mark_visited(c2[i]); bool r = true; @@ -344,7 +346,7 @@ namespace sat { } } - for (unsigned i = 0; i < sz2; i++) + for (unsigned i = 0; i < sz2; i++) unmark_visited(c2[i]); return r; } @@ -353,7 +355,7 @@ namespace sat { \brief Return the clauses subsumed by c1 and the clauses that can be subsumed resolved using c1. The collections is populated using the use list of target. */ - void simplifier::collect_subsumed1_core(clause const & c1, clause_vector & out, literal_vector & out_lits, + void simplifier::collect_subsumed1_core(clause const & c1, clause_vector & out, literal_vector & out_lits, literal target) { clause_use_list const & cs = m_use_list.get(target); clause_use_list::iterator it = cs.mk_iterator(); @@ -362,7 +364,7 @@ namespace sat { CTRACE("resolution_bug", c2.was_removed(), tout << "clause has been removed:\n" << c2 << "\n";); SASSERT(!c2.was_removed()); if (&c2 != &c1 && - c1.size() <= c2.size() && + c1.size() <= c2.size() && approx_subset(c1.approx(), c2.approx())) { m_sub_counter -= c1.size() + c2.size(); literal l; @@ -373,7 +375,7 @@ namespace sat { } it.next(); } - } + } /** \brief Return the clauses subsumed by c1 and the clauses that can be subsumed resolved using c1. @@ -400,12 +402,12 @@ namespace sat { if (*l_it == null_literal) { // c2 was subsumed if (c1.is_learned() && !c2.is_learned()) - c1.unset_learned(); + c1.unset_learned(); TRACE("subsumption", tout << c1 << " subsumed " << c2 << "\n";); remove_clause(c2); m_num_subsumed++; } - else { + else if (!c2.was_removed()) { // subsumption resolution TRACE("subsumption_resolution", tout << c1 << " sub-ref(" << *l_it << ") " << c2 << "\n";); elim_lit(c2, *l_it); @@ -447,9 +449,9 @@ namespace sat { */ bool simplifier::subsumes0(clause const & c1, clause const & c2) { unsigned sz2 = c2.size(); - for (unsigned i = 0; i < sz2; i++) + for (unsigned i = 0; i < sz2; i++) mark_visited(c2[i]); - + bool r = true; unsigned sz1 = c1.size(); for (unsigned i = 0; i < sz1; i++) { @@ -459,12 +461,12 @@ namespace sat { } } - for (unsigned i = 0; i < sz2; i++) + for (unsigned i = 0; i < sz2; i++) unmark_visited(c2[i]); - + return r; } - + /** \brief Collect the clauses subsumed by c1 (using the occurrence list of target). */ @@ -475,7 +477,7 @@ namespace sat { clause & c2 = it.curr(); SASSERT(!c2.was_removed()); if (&c2 != &c1 && - c1.size() <= c2.size() && + c1.size() <= c2.size() && approx_subset(c1.approx(), c2.approx())) { m_sub_counter -= c1.size() + c2.size(); if (subsumes0(c1, c2)) { @@ -485,7 +487,7 @@ namespace sat { it.next(); } } - + /** \brief Collect the clauses subsumed by c1 */ @@ -493,8 +495,8 @@ namespace sat { literal l = get_min_occ_var0(c1); collect_subsumed0_core(c1, out, l); } - - + + /** \brief Perform backward subsumption using c1. */ @@ -507,16 +509,16 @@ namespace sat { clause & c2 = *(*it); // c2 was subsumed if (c1.is_learned() && !c2.is_learned()) - c1.unset_learned(); + c1.unset_learned(); TRACE("subsumption", tout << c1 << " subsumed " << c2 << "\n";); remove_clause(c2); m_num_subsumed++; } } - + /** \brief Eliminate false literals from c, and update occurrence lists - + Return true if the clause is satisfied */ bool simplifier::cleanup_clause(clause & c, bool in_use_list) { @@ -666,7 +668,7 @@ namespace sat { back_subsumption1(c); if (w.is_learned() && !c.is_learned()) { SASSERT(wlist[j] == w); - TRACE("mark_not_learned_bug", + TRACE("mark_not_learned_bug", tout << "marking as not learned: " << l2 << " " << wlist[j].is_learned() << "\n";); wlist[j].mark_not_learned(); mark_as_not_learned_core(get_wlist(~l2), l); @@ -735,7 +737,7 @@ namespace sat { continue; } if (it2->get_literal() == last_lit) { - TRACE("subsumption", tout << "eliminating: " << ~to_literal(l_idx) + TRACE("subsumption", tout << "eliminating: " << ~to_literal(l_idx) << " " << it2->get_literal() << "\n";); elim++; } @@ -762,12 +764,12 @@ namespace sat { m_num_sub_res(s.m_num_sub_res) { m_watch.start(); } - + ~subsumption_report() { m_watch.stop(); - IF_VERBOSE(SAT_VB_LVL, + IF_VERBOSE(SAT_VB_LVL, verbose_stream() << " (sat-subsumer :subsumed " - << (m_simplifier.m_num_subsumed - m_num_subsumed) + << (m_simplifier.m_num_subsumed - m_num_subsumed) << " :subsumption-resolution " << (m_simplifier.m_num_sub_res - m_num_sub_res) << " :threshold " << m_simplifier.m_sub_counter << mem_stat() @@ -847,12 +849,12 @@ namespace sat { vector const & m_watches; public: literal_lt(use_list const & l, vector const & ws):m_use_list(l), m_watches(ws) {} - + unsigned weight(unsigned l_idx) const { return 2*m_use_list.get(~to_literal(l_idx)).size() + m_watches[l_idx].size(); } - - bool operator()(unsigned l_idx1, unsigned l_idx2) const { + + bool operator()(unsigned l_idx1, unsigned l_idx2) const { return weight(l_idx1) < weight(l_idx2); } }; @@ -861,9 +863,9 @@ namespace sat { heap m_queue; public: queue(use_list const & l, vector const & ws):m_queue(128, literal_lt(l, ws)) {} - void insert(literal l) { + void insert(literal l) { unsigned idx = l.index(); - m_queue.reserve(idx + 1); + m_queue.reserve(idx + 1); SASSERT(!m_queue.contains(idx)); m_queue.insert(idx); } @@ -877,14 +879,14 @@ namespace sat { literal next() { SASSERT(!empty()); return to_literal(m_queue.erase_min()); } bool empty() const { return m_queue.empty(); } }; - + simplifier & s; int m_counter; model_converter & mc; queue m_queue; clause_vector m_to_remove; - blocked_clause_elim(simplifier & _s, unsigned limit, model_converter & _mc, use_list & l, + blocked_clause_elim(simplifier & _s, unsigned limit, model_converter & _mc, use_list & l, vector & wlist): s(_s), m_counter(limit), @@ -946,7 +948,7 @@ namespace sat { clause_vector::iterator it = m_to_remove.begin(); clause_vector::iterator end = m_to_remove.end(); for (; it != end; ++it) { - s.remove_clause(*(*it)); + s.remove_clause(*(*it)); } } { @@ -1025,12 +1027,12 @@ namespace sat { m_num_blocked_clauses(s.m_num_blocked_clauses) { m_watch.start(); } - + ~blocked_cls_report() { m_watch.stop(); - IF_VERBOSE(SAT_VB_LVL, + IF_VERBOSE(SAT_VB_LVL, verbose_stream() << " (sat-blocked-clauses :elim-blocked-clauses " - << (m_simplifier.m_num_blocked_clauses - m_num_blocked_clauses) + << (m_simplifier.m_num_blocked_clauses - m_num_blocked_clauses) << mem_stat() << " :time " << std::fixed << std::setprecision(2) << m_watch.get_seconds() << ")\n";); } @@ -1062,8 +1064,8 @@ namespace sat { unsigned num_neg = m_use_list.get(neg_l).size(); unsigned num_bin_pos = get_num_non_learned_bin(pos_l); unsigned num_bin_neg = get_num_non_learned_bin(neg_l); - unsigned cost = 2 * num_pos * num_neg + num_pos * num_bin_neg + num_neg * num_bin_pos; - CTRACE("elim_vars_detail", cost == 0, tout << v << " num_pos: " << num_pos << " num_neg: " << num_neg << " num_bin_pos: " << num_bin_pos + unsigned cost = 2 * num_pos * num_neg + num_pos * num_bin_neg + num_neg * num_bin_pos; + CTRACE("elim_vars_detail", cost == 0, tout << v << " num_pos: " << num_pos << " num_neg: " << num_neg << " num_bin_pos: " << num_bin_pos << " num_bin_neg: " << num_bin_neg << " cost: " << cost << "\n";); return cost; } @@ -1071,7 +1073,7 @@ namespace sat { typedef std::pair bool_var_and_cost; struct bool_var_and_cost_lt { - bool operator()(bool_var_and_cost const & p1, bool_var_and_cost const & p2) const { return p1.second < p2.second; } + bool operator()(bool_var_and_cost const & p1, bool_var_and_cost const & p2) const { return p1.second < p2.second; } }; void simplifier::order_vars_for_elim(bool_var_vector & r) { @@ -1104,7 +1106,7 @@ namespace sat { r.push_back(it2->first); } } - + /** \brief Collect clauses and binary clauses containing l. */ @@ -1116,7 +1118,7 @@ namespace sat { SASSERT(r.back().size() == it.curr().size()); it.next(); } - + watch_list & wlist = get_wlist(~l); watch_list::iterator it2 = wlist.begin(); watch_list::iterator end2 = wlist.end(); @@ -1129,7 +1131,7 @@ namespace sat { } /** - \brief Resolve clauses c1 and c2. + \brief Resolve clauses c1 and c2. c1 must contain l. c2 must contain ~l. Store result in r. @@ -1149,7 +1151,7 @@ namespace sat { m_visited[l2.index()] = true; r.push_back(l2); } - + literal not_l = ~l; sz = c2.size(); m_elim_counter -= sz; @@ -1164,7 +1166,7 @@ namespace sat { if (!m_visited[l2.index()]) r.push_back(l2); } - + sz = c1.size(); for (unsigned i = 0; i < sz; i++) { literal l2 = c1[i]; @@ -1200,7 +1202,7 @@ namespace sat { break; } } - CTRACE("resolve_bug", it2 == end2, + CTRACE("resolve_bug", it2 == end2, tout << ~l1 << " -> "; display(tout, s.m_cls_allocator, wlist1); tout << "\n" << ~l2 << " -> "; display(tout, s.m_cls_allocator, wlist2); tout << "\n";); @@ -1262,7 +1264,7 @@ namespace sat { TRACE("resolution_bug", tout << "processing: " << v << "\n";); if (value(v) != l_undef) return false; - + literal pos_l(v, false); literal neg_l(v, true); unsigned num_bin_pos = get_num_non_learned_bin(pos_l); @@ -1274,12 +1276,12 @@ namespace sat { m_elim_counter -= num_pos + num_neg; TRACE("resolution", tout << v << " num_pos: " << num_pos << " neg_pos: " << num_neg << "\n";); - + if (num_pos >= m_res_occ_cutoff && num_neg >= m_res_occ_cutoff) return false; unsigned before_lits = num_bin_pos*2 + num_bin_neg*2; - + { clause_use_list::iterator it = pos_occs.mk_iterator(); while (!it.at_end()) { @@ -1287,7 +1289,7 @@ namespace sat { it.next(); } } - + { clause_use_list::iterator it2 = neg_occs.mk_iterator(); while (!it2.at_end()) { @@ -1297,23 +1299,23 @@ namespace sat { } TRACE("resolution", tout << v << " num_pos: " << num_pos << " neg_pos: " << num_neg << " before_lits: " << before_lits << "\n";); - + if (num_pos >= m_res_occ_cutoff3 && num_neg >= m_res_occ_cutoff3 && before_lits > m_res_lit_cutoff3 && s.m_clauses.size() > m_res_cls_cutoff2) return false; - if (num_pos >= m_res_occ_cutoff2 && num_neg >= m_res_occ_cutoff2 && before_lits > m_res_lit_cutoff2 && + if (num_pos >= m_res_occ_cutoff2 && num_neg >= m_res_occ_cutoff2 && before_lits > m_res_lit_cutoff2 && s.m_clauses.size() > m_res_cls_cutoff1 && s.m_clauses.size() <= m_res_cls_cutoff2) return false; - if (num_pos >= m_res_occ_cutoff1 && num_neg >= m_res_occ_cutoff1 && before_lits > m_res_lit_cutoff1 && + if (num_pos >= m_res_occ_cutoff1 && num_neg >= m_res_occ_cutoff1 && before_lits > m_res_lit_cutoff1 && s.m_clauses.size() <= m_res_cls_cutoff1) return false; - + m_pos_cls.reset(); m_neg_cls.reset(); collect_clauses(pos_l, m_pos_cls); collect_clauses(neg_l, m_neg_cls); - + m_elim_counter -= num_pos * num_neg + before_lits; - + TRACE("resolution_detail", tout << "collecting number of after_clauses\n";); unsigned before_clauses = num_pos + num_neg; unsigned after_clauses = 0; @@ -1350,7 +1352,7 @@ namespace sat { neg_occs.reset(); m_elim_counter -= num_pos * num_neg + before_lits; - + it1 = m_pos_cls.begin(); end1 = m_pos_cls.end(); for (; it1 != end1; ++it1) { @@ -1393,7 +1395,7 @@ namespace sat { return true; } } - + return true; } @@ -1406,10 +1408,10 @@ namespace sat { m_num_elim_vars(s.m_num_elim_vars) { m_watch.start(); } - + ~elim_var_report() { m_watch.stop(); - IF_VERBOSE(SAT_VB_LVL, + IF_VERBOSE(SAT_VB_LVL, verbose_stream() << " (sat-resolution :elim-bool-vars " << (m_simplifier.m_num_elim_vars - m_num_elim_vars) << " :threshold " << m_simplifier.m_elim_counter @@ -1417,12 +1419,12 @@ namespace sat { << " :time " << std::fixed << std::setprecision(2) << m_watch.get_seconds() << ")\n";); } }; - + void simplifier::elim_vars() { elim_var_report rpt(*this); bool_var_vector vars; order_vars_for_elim(vars); - + bool_var_vector::iterator it = vars.begin(); bool_var_vector::iterator end = vars.end(); for (; it != end; ++it) { @@ -1463,7 +1465,7 @@ namespace sat { void simplifier::collect_param_descrs(param_descrs & r) { sat_simplifier_params::collect_param_descrs(r); } - + void simplifier::collect_statistics(statistics & st) { st.update("subsumed", m_num_subsumed); st.update("subsumption resolution", m_num_sub_res); @@ -1471,7 +1473,7 @@ namespace sat { st.update("elim bool vars", m_num_elim_vars); st.update("elim blocked clauses", m_num_blocked_clauses); } - + void simplifier::reset_statistics() { m_num_blocked_clauses = 0; m_num_subsumed = 0; diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index ec00be897..f57623774 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -27,7 +27,7 @@ Revision History: // define to create a copy of the solver before starting the search // useful for checking models // #define CLONE_BEFORE_SOLVING - + namespace sat { solver::solver(params_ref const & p, extension * ext): @@ -103,7 +103,7 @@ namespace sat { } } } - + // ----------------------- // // Variable & Clause creation @@ -312,7 +312,7 @@ namespace sat { /** \brief Select a watch literal starting the search at the given position. This method is only used for clauses created during the search. - + I use the following rules to select a watch literal. 1- select a literal l in idx >= starting_at such that value(l) = l_true, @@ -329,7 +329,7 @@ namespace sat { lvl(l) >= lvl(l') Without rule 3, boolean propagation is incomplete, that is, it may miss possible propagations. - + \remark The method select_lemma_watch_lit is used to select the watch literal for regular learned clauses. */ unsigned solver::select_watch_lit(clause const & cls, unsigned starting_at) const { @@ -443,7 +443,7 @@ namespace sat { erase_clause_watch(get_wlist(~c[0]), cls_off); erase_clause_watch(get_wlist(~c[1]), cls_off); } - + void solver::dettach_ter_clause(clause & c) { erase_ternary_watch(get_wlist(~c[0]), c[1], c[2]); erase_ternary_watch(get_wlist(~c[1]), c[0], c[2]); @@ -498,10 +498,10 @@ namespace sat { unsigned sz = c.size(); for (unsigned i = 0; i < sz; i++) { switch (value(c[i])) { - case l_true: + case l_true: return l_true; - case l_undef: - found_undef = true; + case l_undef: + found_undef = true; break; default: break; @@ -515,7 +515,7 @@ namespace sat { // Propagation // // ----------------------- - + bool solver::propagate_core(bool update) { if (m_inconsistent) return false; @@ -545,7 +545,7 @@ namespace sat { } for (; it != end; ++it) { switch (it->get_kind()) { - case watched::BINARY: + case watched::BINARY: l1 = it->get_literal(); switch (value(l1)) { case l_false: @@ -585,15 +585,30 @@ namespace sat { break; case watched::CLAUSE: { if (value(it->get_blocked_literal()) == l_true) { + TRACE("propagate_clause_bug", tout << "blocked literal " << it->get_blocked_literal() << "\n"; + clause_offset cls_off = it->get_clause_offset(); + clause & c = *(m_cls_allocator.get_clause(cls_off)); + tout << c << "\n";); *it2 = *it; it2++; break; } clause_offset cls_off = it->get_clause_offset(); clause & c = *(m_cls_allocator.get_clause(cls_off)); + TRACE("propagate_clause_bug", tout << "processing... " << c << "\nwas_removed: " << c.was_removed() << "\n";); if (c[0] == not_l) std::swap(c[0], c[1]); CTRACE("propagate_bug", c[1] != not_l, tout << "l: " << l << " " << c << "\n";); + if (c.was_removed() || c[1] != not_l) { + // Remark: this method may be invoked when the watch lists are not in a consistent state, + // and may contain dead/removed clauses, or clauses with removed literals. + // See: method propagate_unit at sat_simplifier.cpp + // So, we must check whether the clause was marked for deletion, or + // c[1] != not_l + *it2 = *it; + it2++; + break; + } SASSERT(c[1] == not_l); if (value(c[0]) == l_true) { it2->set_clause(c[0], cls_off); @@ -694,7 +709,7 @@ namespace sat { m_conflicts_since_restart = 0; m_restart_threshold = m_config.m_restart_initial; } - + // iff3_finder(*this)(); simplify_problem(); @@ -705,10 +720,10 @@ namespace sat { IF_VERBOSE(SAT_VB_LVL, verbose_stream() << "\"abort: max-conflicts = 0\"\n";); return l_undef; } - + while (true) { SASSERT(!inconsistent()); - + lbool r = bounded_search(); if (r != l_undef) return r; @@ -717,7 +732,7 @@ namespace sat { IF_VERBOSE(SAT_VB_LVL, verbose_stream() << "\"abort: max-conflicts = " << m_conflicts << "\"\n";); return l_undef; } - + restart(); if (m_conflicts >= m_next_simplify) { simplify_problem(); @@ -735,7 +750,7 @@ namespace sat { bool_var solver::next_var() { bool_var next; - + if (m_rand() < static_cast(m_config.m_random_freq * random_gen::max_value())) { if (num_vars() == 0) return null_bool_var; @@ -744,16 +759,16 @@ namespace sat { if (value(next) == l_undef && !was_eliminated(next)) return next; } - + while (!m_case_split_queue.empty()) { next = m_case_split_queue.next_var(); if (value(next) == l_undef && !was_eliminated(next)) return next; } - + return null_bool_var; } - + bool solver::decide() { bool_var next = next_var(); if (next == null_bool_var) @@ -761,7 +776,7 @@ namespace sat { push(); m_stats.m_decision++; lbool phase = m_ext ? m_ext->get_phase(next) : l_undef; - + if (phase == l_undef) { switch (m_config.m_phase) { case PS_ALWAYS_TRUE: @@ -785,7 +800,7 @@ namespace sat { break; } } - + SASSERT(phase != l_undef); literal next_lit(next, phase == l_false); assign(next_lit, justification()); @@ -808,23 +823,23 @@ namespace sat { return l_undef; if (scope_lvl() == 0) { cleanup(); // cleaner may propagate frozen clauses - if (inconsistent()) + if (inconsistent()) return l_false; gc(); } } - + gc(); if (!decide()) { if (m_ext) { switch (m_ext->check()) { - case CR_DONE: + case CR_DONE: mk_model(); return l_true; - case CR_CONTINUE: + case CR_CONTINUE: break; - case CR_GIVEUP: + case CR_GIVEUP: throw abort_solver(); } } @@ -850,23 +865,23 @@ namespace sat { m_stopwatch.reset(); m_stopwatch.start(); } - + /** \brief Apply all simplifications. */ void solver::simplify_problem() { SASSERT(scope_lvl() == 0); - + m_cleaner(); CASSERT("sat_simplify_bug", check_invariant()); - + m_scc(); CASSERT("sat_simplify_bug", check_invariant()); - - m_simplifier(false); + + m_simplifier(false); CASSERT("sat_simplify_bug", check_invariant()); CASSERT("sat_missed_prop", check_missed_propagation()); - + if (!m_learned.empty()) { m_simplifier(true); CASSERT("sat_missed_prop", check_missed_propagation()); @@ -879,11 +894,11 @@ namespace sat { m_probing(); CASSERT("sat_missed_prop", check_missed_propagation()); CASSERT("sat_simplify_bug", check_invariant()); - + m_asymm_branch(); CASSERT("sat_missed_prop", check_missed_propagation()); CASSERT("sat_simplify_bug", check_invariant()); - + if (m_ext) { m_ext->clauses_modifed(); m_ext->simplify(); @@ -957,7 +972,7 @@ namespace sat { } } } - if (!m_mc.check_model(m)) + if (!m_mc.check_model(m)) ok = false; CTRACE("sat_model_bug", !ok, tout << m << "\n";); return ok; @@ -965,9 +980,9 @@ namespace sat { void solver::restart() { m_stats.m_restart++; - IF_VERBOSE(1, + IF_VERBOSE(1, verbose_stream() << "(sat-restart :conflicts " << m_stats.m_conflict << " :decisions " << m_stats.m_decision - << " :restarts " << m_stats.m_restart << mk_stat(*this) + << " :restarts " << m_stats.m_restart << mk_stat(*this) << " :time " << std::fixed << std::setprecision(2) << m_stopwatch.get_current_seconds() << ")\n";); IF_VERBOSE(30, display_status(verbose_stream());); pop(scope_lvl()); @@ -992,9 +1007,9 @@ namespace sat { // GC // // ----------------------- - + void solver::gc() { - if (m_conflicts_since_gc <= m_gc_threshold) + if (m_conflicts_since_gc <= m_gc_threshold) return; CASSERT("sat_gc_bug", check_invariant()); switch (m_config.m_gc_strategy) { @@ -1074,7 +1089,7 @@ namespace sat { std::stable_sort(m_learned.begin(), m_learned.end(), glue_lt()); gc_half("glue"); } - + void solver::gc_psm() { save_psm(); std::stable_sort(m_learned.begin(), m_learned.end(), psm_lt()); @@ -1135,8 +1150,8 @@ namespace sat { void solver::gc_dyn_psm() { // To do gc at scope_lvl() > 0, I will need to use the reinitialization stack, or live with the fact // that I may miss some propagations for reactivated clauses. - SASSERT(scope_lvl() == 0); - // compute + SASSERT(scope_lvl() == 0); + // compute // d_tk unsigned h = 0; unsigned V_tk = 0; @@ -1153,7 +1168,7 @@ namespace sat { double d_tk = V_tk == 0 ? static_cast(num_vars() + 1) : static_cast(h)/static_cast(V_tk); if (d_tk < m_min_d_tk) m_min_d_tk = d_tk; - TRACE("sat_frozen", tout << "m_min_d_tk: " << m_min_d_tk << "\n";); + TRACE("sat_frozen", tout << "m_min_d_tk: " << m_min_d_tk << "\n";); unsigned frozen = 0; unsigned deleted = 0; unsigned activated = 0; @@ -1219,15 +1234,15 @@ namespace sat { ++it2; } m_learned.set_end(it2); - IF_VERBOSE(SAT_VB_LVL, verbose_stream() << "(sat-gc :d_tk " << d_tk << " :min-d_tk " << m_min_d_tk << + IF_VERBOSE(SAT_VB_LVL, verbose_stream() << "(sat-gc :d_tk " << d_tk << " :min-d_tk " << m_min_d_tk << " :frozen " << frozen << " :activated " << activated << " :deleted " << deleted << ")\n";); } - + // return true if should keep the clause, and false if we should delete it. - bool solver::activate_frozen_clause(clause & c) { + bool solver::activate_frozen_clause(clause & c) { TRACE("sat_gc", tout << "reactivating:\n" << c << "\n";); SASSERT(scope_lvl() == 0); - // do some cleanup + // do some cleanup unsigned sz = c.size(); unsigned j = 0; for (unsigned i = 0; i < sz; i++) { @@ -1291,13 +1306,13 @@ namespace sat { bool solver::resolve_conflict() { while (true) { bool r = resolve_conflict_core(); + CASSERT("sat_check_marks", check_marks()); // after pop, clauses are reinitialized, this may trigger another conflict. - if (!r) + if (!r) return false; if (!inconsistent()) return true; } - CASSERT("sat_check_marks", check_marks()); } bool solver::resolve_conflict_core() { @@ -1312,7 +1327,7 @@ namespace sat { if (m_conflict_lvl == 0) return false; m_lemma.reset(); - + forget_phase_of_vars(m_conflict_lvl); unsigned idx = skip_literals_above_conflict_level(); @@ -1324,10 +1339,10 @@ namespace sat { TRACE("sat_conflict", tout << "not_l: " << m_not_l << "\n";); process_antecedent(m_not_l, num_marks); } - + literal consequent = m_not_l; justification js = m_conflict; - + do { TRACE("sat_conflict_detail", tout << "processing consequent: " << consequent << "\n"; tout << "num_marks: " << num_marks << ", js kind: " << js.get_kind() << "\n";); @@ -1363,7 +1378,7 @@ namespace sat { fill_ext_antecedents(consequent, js); literal_vector::iterator it = m_ext_antecedents.begin(); literal_vector::iterator end = m_ext_antecedents.end(); - for (; it != end; ++it) + for (; it != end; ++it) process_antecedent(*it, num_marks); break; } @@ -1371,10 +1386,10 @@ namespace sat { UNREACHABLE(); break; } - + while (true) { literal l = m_trail[idx]; - if (is_marked(l.var())) + if (is_marked(l.var())) break; SASSERT(idx > 0); idx--; @@ -1387,9 +1402,9 @@ namespace sat { idx--; num_marks--; reset_mark(c_var); - } + } while (num_marks > 0); - + m_lemma[0] = ~consequent; TRACE("sat_lemma", tout << "new lemma size: " << m_lemma.size() << "\n" << m_lemma << "\n";); @@ -1400,7 +1415,9 @@ namespace sat { dyn_sub_res(); TRACE("sat_lemma", tout << "new lemma (after minimization) size: " << m_lemma.size() << "\n" << m_lemma << "\n";); } - + else + reset_lemma_var_marks(); + literal_vector::iterator it = m_lemma.begin(); literal_vector::iterator end = m_lemma.end(); unsigned new_scope_lvl = 0; @@ -1431,10 +1448,10 @@ namespace sat { return 0; unsigned r = 0; - + if (consequent != null_literal) r = lvl(consequent); - + switch (js.get_kind()) { case justification::NONE: break; @@ -1467,7 +1484,7 @@ namespace sat { fill_ext_antecedents(consequent, js); literal_vector::iterator it = m_ext_antecedents.begin(); literal_vector::iterator end = m_ext_antecedents.end(); - for (; it != end; ++it) + for (; it != end; ++it) r = std::max(r, lvl(*it)); break; } @@ -1496,7 +1513,7 @@ namespace sat { } return idx; } - + void solver::process_antecedent(literal antecedent, unsigned & num_marks) { bool_var var = antecedent.var(); unsigned var_lvl = lvl(var); @@ -1510,7 +1527,7 @@ namespace sat { m_lemma.push_back(~antecedent); } } - + /** \brief js is an external justification. Collect its antecedents and store at m_ext_antecedents. */ @@ -1577,7 +1594,7 @@ namespace sat { unsigned var_lvl = lvl(var); if (!is_marked(var) && var_lvl > 0) { if (m_lvl_set.may_contain(var_lvl)) { - mark(var); + mark(var); m_unmark.push_back(var); m_lemma_min_stack.push_back(var); } @@ -1589,17 +1606,17 @@ namespace sat { } /** - \brief Return true if lit is implied by other marked literals - and/or literals assigned at the base level. - The set lvl_set is used as an optimization. + \brief Return true if lit is implied by other marked literals + and/or literals assigned at the base level. + The set lvl_set is used as an optimization. The idea is to stop the recursive search with a failure - as soon as we find a literal assigned in a level that is not in lvl_set. + as soon as we find a literal assigned in a level that is not in lvl_set. */ bool solver::implied_by_marked(literal lit) { m_lemma_min_stack.reset(); // avoid recursive function m_lemma_min_stack.push_back(lit.var()); unsigned old_size = m_unmark.size(); - + while (!m_lemma_min_stack.empty()) { bool_var var = m_lemma_min_stack.back(); m_lemma_min_stack.pop_back(); @@ -1700,7 +1717,7 @@ namespace sat { void solver::minimize_lemma() { m_unmark.reset(); updt_lemma_lvl_set(); - + unsigned sz = m_lemma.size(); unsigned i = 1; // the first literal is the FUIP unsigned j = 1; @@ -1716,12 +1733,12 @@ namespace sat { j++; } } - + reset_unmark(0); m_lemma.shrink(j); m_stats.m_minimized_lits += sz - j; } - + /** \brief Reset the mark of the variables in the current lemma. */ @@ -1741,17 +1758,17 @@ namespace sat { Only binary and ternary clauses are used. */ void solver::dyn_sub_res() { - unsigned sz = m_lemma.size(); + unsigned sz = m_lemma.size(); for (unsigned i = 0; i < sz; i++) { mark_lit(m_lemma[i]); } - + literal l0 = m_lemma[0]; // l0 is the FUIP, and we never remove the FUIP. - // + // // In the following loop, we use unmark_lit(l) to remove a // literal from m_lemma. - + for (unsigned i = 0; i < sz; i++) { literal l = m_lemma[i]; if (!is_marked_lit(l)) @@ -1763,8 +1780,8 @@ namespace sat { for (; it != end; ++it) { // In this for-loop, the conditions l0 != ~l2 and l0 != ~l3 // are not really needed if the solver does not miss unit propagations. - // However, we add them anyway because we don't want to rely on this - // property of the propagator. + // However, we add them anyway because we don't want to rely on this + // property of the propagator. // For example, if this property is relaxed in the future, then the code // without the conditions l0 != ~l2 and l0 != ~l3 may remove the FUIP if (it->is_binary_clause()) { @@ -1810,10 +1827,10 @@ namespace sat { // p1 \/ ~p2 // p2 \/ ~p3 // p3 \/ ~p4 - // q1 \/ q2 \/ p1 \/ p2 \/ p3 \/ p4 \/ l2 - // q1 \/ ~q2 \/ p1 \/ p2 \/ p3 \/ p4 \/ l2 - // ~q1 \/ q2 \/ p1 \/ p2 \/ p3 \/ p4 \/ l2 - // ~q1 \/ ~q2 \/ p1 \/ p2 \/ p3 \/ p4 \/ l2 + // q1 \/ q2 \/ p1 \/ p2 \/ p3 \/ p4 \/ l2 + // q1 \/ ~q2 \/ p1 \/ p2 \/ p3 \/ p4 \/ l2 + // ~q1 \/ q2 \/ p1 \/ p2 \/ p3 \/ p4 \/ l2 + // ~q1 \/ ~q2 \/ p1 \/ p2 \/ p3 \/ p4 \/ l2 // ... // // 2. Now suppose we learned the lemma @@ -1824,15 +1841,15 @@ namespace sat { // That is, l \/ l2 is an implied clause. Note that probing does not add // this clause to the clause database (there are too many). // - // 4. Lemma (*) is deleted (garbage collected). + // 4. Lemma (*) is deleted (garbage collected). // // 5. l is decided to be false, p1, p2, p3 and p4 are propagated using BCP, // but l2 is not since the lemma (*) was deleted. - // + // // Probing module still "knows" that l \/ l2 is valid binary clause - // + // // 6. A new lemma is created where ~l2 is the FUIP and the lemma also contains l. - // If we remove l0 != ~l2 may try to delete the FUIP. + // If we remove l0 != ~l2 may try to delete the FUIP. if (is_marked_lit(~l2) && l0 != ~l2) { // eliminate ~l2 from lemma because we have the clause l \/ l2 unmark_lit(~l2); @@ -1843,7 +1860,7 @@ namespace sat { // can't eliminat FUIP SASSERT(is_marked_lit(m_lemma[0])); - + unsigned j = 0; for (unsigned i = 0; i < sz; i++) { literal l = m_lemma[i]; @@ -1855,7 +1872,7 @@ namespace sat { } m_stats.m_dyn_sub_res += sz - j; - + SASSERT(j >= 1); m_lemma.shrink(j); } @@ -1948,7 +1965,7 @@ namespace sat { // Misc // // ----------------------- - + void solver::updt_params(params_ref const & p) { m_params = p; m_config.updt_params(p); @@ -1970,8 +1987,8 @@ namespace sat { void solver::set_cancel(bool f) { m_cancel = f; } - - void solver::collect_statistics(statistics & st) { + + void solver::collect_statistics(statistics & st) { m_stats.collect_statistics(st); m_cleaner.collect_statistics(st); m_simplifier.collect_statistics(st); @@ -2066,7 +2083,7 @@ namespace sat { } } } - + void solver::display_units(std::ostream & out) const { unsigned end = scope_lvl() == 0 ? m_trail.size() : m_scopes[0].m_trail_lim; for (unsigned i = 0; i < end; i++) { @@ -2220,26 +2237,26 @@ namespace sat { // Simplification // // ----------------------- - void solver::cleanup() { - if (scope_lvl() > 0 || inconsistent()) - return; + void solver::cleanup() { + if (scope_lvl() > 0 || inconsistent()) + return; if (m_cleaner() && m_ext) m_ext->clauses_modifed(); } - void solver::simplify(bool learned) { - if (scope_lvl() > 0 || inconsistent()) - return; - m_simplifier(learned); - m_simplifier.free_memory(); + void solver::simplify(bool learned) { + if (scope_lvl() > 0 || inconsistent()) + return; + m_simplifier(learned); + m_simplifier.free_memory(); if (m_ext) m_ext->clauses_modifed(); } - unsigned solver::scc_bin() { - if (scope_lvl() > 0 || inconsistent()) - return 0; - unsigned r = m_scc(); + unsigned solver::scc_bin() { + if (scope_lvl() > 0 || inconsistent()) + return 0; + unsigned r = m_scc(); if (r > 0 && m_ext) m_ext->clauses_modifed(); return r; @@ -2313,10 +2330,10 @@ namespace sat { out << " :inconsistent " << (m_inconsistent ? "true" : "false") << "\n"; out << " :vars " << num_vars() << "\n"; out << " :elim-vars " << num_elim << "\n"; - out << " :lits " << num_lits << "\n"; + out << " :lits " << num_lits << "\n"; out << " :assigned " << m_trail.size() << "\n"; - out << " :binary-clauses " << num_bin << "\n"; - out << " :ternary-clauses " << num_ter << "\n"; + out << " :binary-clauses " << num_bin << "\n"; + out << " :ternary-clauses " << num_ter << "\n"; out << " :clauses " << num_cls << "\n"; out << " :del-clause " << m_stats.m_del_clause << "\n"; out << " :avg-clause-size " << (total_cls == 0 ? 0.0 : static_cast(num_lits) / static_cast(total_cls)) << "\n"; diff --git a/src/shell/smtlib_frontend.cpp b/src/shell/smtlib_frontend.cpp index 2ec1e200a..eb3b3587d 100644 --- a/src/shell/smtlib_frontend.cpp +++ b/src/shell/smtlib_frontend.cpp @@ -30,6 +30,7 @@ Revision History: #include"polynomial_cmds.h" #include"subpaving_cmds.h" #include"smt_strategic_solver.h" +#include"smt_solver.h" extern bool g_display_statistics; extern void display_config(); @@ -106,6 +107,7 @@ unsigned read_smtlib2_commands(char const * file_name) { cmd_context ctx; ctx.set_solver_factory(mk_smt_strategic_solver_factory()); + ctx.set_interpolating_solver_factory(mk_smt_solver_factory()); install_dl_cmds(ctx); install_dbg_cmds(ctx); diff --git a/src/smt/params/qi_params.cpp b/src/smt/params/qi_params.cpp index f5506d35b..60fcd6fc4 100644 --- a/src/smt/params/qi_params.cpp +++ b/src/smt/params/qi_params.cpp @@ -27,6 +27,7 @@ void qi_params::updt_params(params_ref const & _p) { m_mbqi_max_iterations = p.mbqi_max_iterations(); m_mbqi_trace = p.mbqi_trace(); m_mbqi_force_template = p.mbqi_force_template(); + m_mbqi_id = p.mbqi_id(); m_qi_profile = p.qi_profile(); m_qi_profile_freq = p.qi_profile_freq(); m_qi_max_instances = p.qi_max_instances(); diff --git a/src/smt/params/qi_params.h b/src/smt/params/qi_params.h index 0cd817f22..bca3ad3fc 100644 --- a/src/smt/params/qi_params.h +++ b/src/smt/params/qi_params.h @@ -51,6 +51,7 @@ struct qi_params { unsigned m_mbqi_max_iterations; bool m_mbqi_trace; unsigned m_mbqi_force_template; + const char * m_mbqi_id; qi_params(params_ref const & p = params_ref()): /* @@ -97,7 +98,9 @@ struct qi_params { m_mbqi_max_cexs_incr(1), m_mbqi_max_iterations(1000), m_mbqi_trace(false), - m_mbqi_force_template(10) { + m_mbqi_force_template(10), + m_mbqi_id(0) + { updt_params(p); } diff --git a/src/smt/params/smt_params.cpp b/src/smt/params/smt_params.cpp index f97d49a91..ca471f716 100644 --- a/src/smt/params/smt_params.cpp +++ b/src/smt/params/smt_params.cpp @@ -40,6 +40,7 @@ void smt_params::updt_local_params(params_ref const & _p) { m_arith_pivot_strategy = ARITH_PIVOT_GREATEST_ERROR; else if (_p.get_bool("arith.least_error_pivot", false)) m_arith_pivot_strategy = ARITH_PIVOT_LEAST_ERROR; + theory_array_params::updt_params(_p); } void smt_params::updt_params(params_ref const & p) { @@ -48,6 +49,7 @@ void smt_params::updt_params(params_ref const & p) { theory_arith_params::updt_params(p); theory_bv_params::updt_params(p); theory_pb_params::updt_params(p); + // theory_array_params::updt_params(p); updt_local_params(p); } diff --git a/src/smt/params/smt_params_helper.pyg b/src/smt/params/smt_params_helper.pyg index 0c2a4f93a..cc1527b9c 100644 --- a/src/smt/params/smt_params_helper.pyg +++ b/src/smt/params/smt_params_helper.pyg @@ -21,6 +21,7 @@ def_module_params(module_name='smt', ('mbqi.max_iterations', UINT, 1000, 'maximum number of rounds of MBQI'), ('mbqi.trace', BOOL, False, 'generate tracing messages for Model Based Quantifier Instantiation (MBQI). It will display a message before every round of MBQI, and the quantifiers that were not satisfied'), ('mbqi.force_template', UINT, 10, 'some quantifiers can be used as templates for building interpretations for functions. Z3 uses heuristics to decide whether a quantifier will be used as a template or not. Quantifiers with weight >= mbqi.force_template are forced to be used as a template'), + ('mbqi.id', STRING, '', 'Only use model-based instantiation for quantifiers with id\'s beginning with string'), ('qi.profile', BOOL, False, 'profile quantifier instantiation'), ('qi.profile_freq', UINT, UINT_MAX, 'how frequent results are reported by qi.profile'), ('qi.max_instances', UINT, UINT_MAX, 'maximum number of quantifier instantiations'), @@ -44,5 +45,7 @@ def_module_params(module_name='smt', ('arith.ignore_int', BOOL, False, 'treat integer variables as real'), ('pb.conflict_frequency', UINT, 0, 'conflict frequency for Pseudo-Boolean theory'), ('pb.learn_complements', BOOL, True, 'learn complement literals for Pseudo-Boolean theory'), - ('pb.enable_compilation', BOOL, True, 'enable compilation into sorting circuits for Pseudo-Boolean'))) - + ('pb.enable_compilation', BOOL, True, 'enable compilation into sorting circuits for Pseudo-Boolean'), + ('array.weak', BOOL, False, 'weak array theory'), + ('array.extensional', BOOL, True, 'extensional array theory') + )) diff --git a/src/smt/params/theory_array_params.cpp b/src/smt/params/theory_array_params.cpp new file mode 100644 index 000000000..e3c8b2448 --- /dev/null +++ b/src/smt/params/theory_array_params.cpp @@ -0,0 +1,28 @@ +/*++ +Copyright (c) 2006 Microsoft Corporation + +Module Name: + + theory_array_params.cpp + +Abstract: + + + +Author: + + Leonardo de Moura (leonardo) 2008-05-06. + +Revision History: + +--*/ +#include"theory_array_params.h" +#include"smt_params_helper.hpp" + +void theory_array_params::updt_params(params_ref const & _p) { + smt_params_helper p(_p); + m_array_weak = p.array_weak(); + m_array_extensional = p.array_extensional(); +} + + diff --git a/src/smt/params/theory_array_params.h b/src/smt/params/theory_array_params.h index c1ae8ce32..a022bcadf 100644 --- a/src/smt/params/theory_array_params.h +++ b/src/smt/params/theory_array_params.h @@ -51,6 +51,9 @@ struct theory_array_params : public array_simplifier_params { m_array_lazy_ieq_delay(10) { } + + void updt_params(params_ref const & _p); + #if 0 void register_params(ini_params & p) { p.register_int_param("array_solver", 0, 3, reinterpret_cast(m_array_mode), "0 - no array, 1 - simple, 2 - model based, 3 - full"); diff --git a/src/smt/smt_model_checker.cpp b/src/smt/smt_model_checker.cpp index 53f3af961..526447f9f 100644 --- a/src/smt/smt_model_checker.cpp +++ b/src/smt/smt_model_checker.cpp @@ -322,6 +322,7 @@ namespace smt { for (; it != end; ++it) { quantifier * q = *it; + if(!m_qm->mbqi_enabled(q)) continue; if (m_context->is_relevant(q) && m_context->get_assignment(q) == l_true) { if (m_params.m_mbqi_trace && q->get_qid() != symbol::null) { verbose_stream() << "(smt.mbqi :checking " << q->get_qid() << ")\n"; diff --git a/src/smt/smt_quantifier.cpp b/src/smt/smt_quantifier.cpp index d56fe0cff..b30e0f8b8 100644 --- a/src/smt/smt_quantifier.cpp +++ b/src/smt/smt_quantifier.cpp @@ -335,6 +335,10 @@ namespace smt { return m_imp->m_plugin->model_based(); } + bool quantifier_manager::mbqi_enabled(quantifier *q) const { + return m_imp->m_plugin->mbqi_enabled(q); + } + void quantifier_manager::adjust_model(proto_model * m) { m_imp->m_plugin->adjust_model(m); } @@ -434,10 +438,24 @@ namespace smt { virtual bool model_based() const { return m_fparams->m_mbqi; } + virtual bool mbqi_enabled(quantifier *q) const { + if(!m_fparams->m_mbqi_id) return true; + const symbol &s = q->get_qid(); + unsigned len = strlen(m_fparams->m_mbqi_id); + if(s == symbol::null || s.is_numerical()) + return len == 0; + return strncmp(s.bare_str(),m_fparams->m_mbqi_id,static_cast(len)) == 0; + } + + /* Quantifier id's must begin with the prefix specified by + parameter mbqi.id to be instantiated with MBQI. The default + value is the empty string, so all quantifiers are + instantiated. + */ virtual void add(quantifier * q) { - if (m_fparams->m_mbqi) { - m_model_finder->register_quantifier(q); - } + if (m_fparams->m_mbqi && mbqi_enabled(q)) { + m_model_finder->register_quantifier(q); + } } virtual void del(quantifier * q) { diff --git a/src/smt/smt_quantifier.h b/src/smt/smt_quantifier.h index 19113229c..e3d626157 100644 --- a/src/smt/smt_quantifier.h +++ b/src/smt/smt_quantifier.h @@ -75,6 +75,7 @@ namespace smt { }; bool model_based() const; + bool mbqi_enabled(quantifier *q) const; // can mbqi instantiate this quantifier? void adjust_model(proto_model * m); check_model_result check_model(proto_model * m, obj_map const & root2value); @@ -144,6 +145,11 @@ namespace smt { */ virtual bool model_based() const = 0; + /** + \brief Is "model based" instantiate allowed to instantiate this quantifier? + */ + virtual bool mbqi_enabled(quantifier *q) const {return true;} + /** \brief Give a change to the plugin to adjust the interpretation of unintepreted functions. It can basically change the "else" of each uninterpreted function. diff --git a/src/tactic/arith/factor_tactic.cpp b/src/tactic/arith/factor_tactic.cpp index 752f6681d..4eec83037 100644 --- a/src/tactic/arith/factor_tactic.cpp +++ b/src/tactic/arith/factor_tactic.cpp @@ -50,7 +50,7 @@ class factor_tactic : public tactic { return args[0]; return m_util.mk_mul(sz, args); } - + expr * mk_zero_for(expr * arg) { return m_util.mk_numeral(rational(0), m_util.is_int(arg)); } @@ -92,10 +92,10 @@ class factor_tactic : public tactic { return k; } } - + // p1^{2*k1} * p2^{2*k2 + 1} >=< 0 // --> - // (p1^2)*p2 >=<0 + // (p1^2)*p2 >=<0 void mk_comp(decl_kind k, polynomial::factors const & fs, expr_ref & result) { SASSERT(k == OP_LT || k == OP_GT || k == OP_LE || k == OP_GE); expr_ref_buffer args(m); @@ -127,7 +127,7 @@ class factor_tactic : public tactic { } } } - + // Strict case // p1^{2*k1} * p2^{2*k2 + 1} >< 0 // --> @@ -158,11 +158,11 @@ class factor_tactic : public tactic { args.push_back(m.mk_app(m_util.get_family_id(), k, mk_mul(odd_factors.size(), odd_factors.c_ptr()), mk_zero_for(odd_factors[0]))); } SASSERT(!args.empty()); - if (args.size() == 1) + if (args.size() == 1) result = args[0]; else if (strict) result = m.mk_and(args.size(), args.c_ptr()); - else + else result = m.mk_or(args.size(), args.c_ptr()); } @@ -173,7 +173,7 @@ class factor_tactic : public tactic { scoped_mpz d2(m_qm); m_expr2poly.to_polynomial(lhs, p1, d1); m_expr2poly.to_polynomial(rhs, p2, d2); - TRACE("factor_tactic_bug", + TRACE("factor_tactic_bug", tout << "lhs: " << mk_ismt2_pp(lhs, m) << "\n"; tout << "p1: " << p1 << "\n"; tout << "d1: " << d1 << "\n"; @@ -195,18 +195,18 @@ class factor_tactic : public tactic { SASSERT(fs.distinct_factors() > 0); TRACE("factor_tactic_bug", tout << "factors:\n"; fs.display(tout); tout << "\n";); if (fs.distinct_factors() == 1 && fs.get_degree(0) == 1) - return BR_FAILED; + return BR_FAILED; if (m.is_eq(f)) { if (m_split_factors) mk_split_eq(fs, result); - else + else mk_eq(fs, result); } else { decl_kind k = f->get_decl_kind(); if (m_qm.is_neg(fs.get_constant())) k = flip(k); - + if (m_split_factors) mk_split_comp(k, fs, result); else @@ -215,10 +215,10 @@ class factor_tactic : public tactic { return BR_DONE; } - br_status reduce_app(func_decl * f, unsigned num, expr * const * args, expr_ref & result, proof_ref & result_pr) { + br_status reduce_app(func_decl * f, unsigned num, expr * const * args, expr_ref & result, proof_ref & result_pr) { if (num != 2) return BR_FAILED; - if (m.is_eq(f) && (m_util.is_arith_expr(args[0]) || m_util.is_arith_expr(args[1]))) + if (m.is_eq(f) && (m_util.is_arith_expr(args[0]) || m_util.is_arith_expr(args[1])) && (!m.is_bool(args[0]))) return factor(f, args[0], args[1], result); if (f->get_family_id() != m_util.get_family_id()) return BR_FAILED; @@ -232,10 +232,10 @@ class factor_tactic : public tactic { return BR_FAILED; } }; - + struct rw : public rewriter_tpl { rw_cfg m_cfg; - + rw(ast_manager & m, params_ref const & p): rewriter_tpl(m, m.proofs_enabled(), m_cfg), m_cfg(m, p) { @@ -245,24 +245,24 @@ class factor_tactic : public tactic { struct imp { ast_manager & m; rw m_rw; - + imp(ast_manager & _m, params_ref const & p): m(_m), m_rw(m, p) { } - + void set_cancel(bool f) { m_rw.set_cancel(f); m_rw.cfg().m_pm.set_cancel(f); } - + void updt_params(params_ref const & p) { m_rw.cfg().updt_params(p); } - - void operator()(goal_ref const & g, - goal_ref_buffer & result, - model_converter_ref & mc, + + void operator()(goal_ref const & g, + goal_ref_buffer & result, + model_converter_ref & mc, proof_converter_ref & pc, expr_dependency_ref & core) { SASSERT(g->is_well_sorted()); @@ -288,7 +288,7 @@ class factor_tactic : public tactic { SASSERT(g->is_well_sorted()); } }; - + imp * m_imp; params_ref m_params; public: @@ -300,7 +300,7 @@ public: virtual tactic * translate(ast_manager & m) { return alloc(factor_tactic, m, m_params); } - + virtual ~factor_tactic() { dealloc(m_imp); } @@ -311,14 +311,14 @@ public: } virtual void collect_param_descrs(param_descrs & r) { - r.insert("split_factors", CPK_BOOL, + r.insert("split_factors", CPK_BOOL, "(default: true) apply simplifications such as (= (* p1 p2) 0) --> (or (= p1 0) (= p2 0))."); polynomial::factor_params::get_param_descrs(r); } - - virtual void operator()(goal_ref const & in, - goal_ref_buffer & result, - model_converter_ref & mc, + + virtual void operator()(goal_ref const & in, + goal_ref_buffer & result, + model_converter_ref & mc, proof_converter_ref & pc, expr_dependency_ref & core) { try { @@ -331,7 +331,7 @@ public: throw tactic_exception(ex.msg()); } } - + virtual void cleanup() { ast_manager & m = m_imp->m; imp * d = m_imp; diff --git a/src/tactic/arith/pb2bv_tactic.cpp b/src/tactic/arith/pb2bv_tactic.cpp index 83a949579..89195a9d5 100644 --- a/src/tactic/arith/pb2bv_tactic.cpp +++ b/src/tactic/arith/pb2bv_tactic.cpp @@ -480,7 +480,10 @@ private: break; } - SASSERT (i < sz); + if (i >= sz) { + // [Christoph]: In this case, all the m_a are equal to m_c. + return; + } // copy lits [0, i) to m_clause for (unsigned j = 0; j < i; j++) @@ -500,6 +503,7 @@ private: } void mk_pbc(polynomial & m_p, numeral & m_c, expr_ref & r, bool enable_split) { + TRACE("mk_pbc", display(tout, m_p, m_c); ); if (m_c.is_nonpos()) { // constraint is equivalent to true. r = m.mk_true(); @@ -507,7 +511,7 @@ private: } polynomial::iterator it = m_p.begin(); polynomial::iterator end = m_p.end(); - numeral a_gcd = it->m_a; + numeral a_gcd = (it->m_a > m_c) ? m_c : it->m_a; for (; it != end; ++it) { if (it->m_a > m_c) it->m_a = m_c; // trimming coefficients @@ -520,6 +524,7 @@ private: it->m_a /= a_gcd; m_c = ceil(m_c/a_gcd); } + TRACE("mk_pbc", tout << "GCD = " << a_gcd << "; Normalized: "; display(tout, m_p, m_c); tout << "\n"; ); it = m_p.begin(); numeral a_sum; for (; it != end; ++it) { diff --git a/src/tactic/fpa/fpa2bv_converter.cpp b/src/tactic/fpa/fpa2bv_converter.cpp index ff24d1ceb..ebccf4d52 100644 --- a/src/tactic/fpa/fpa2bv_converter.cpp +++ b/src/tactic/fpa/fpa2bv_converter.cpp @@ -140,7 +140,7 @@ void fpa2bv_converter::mk_const(func_decl * f, expr_ref & result) { s_sig = m_bv_util.mk_sort(sbits-1); s_exp = m_bv_util.mk_sort(ebits); -#ifdef _DEBUG +#ifdef Z3DEBUG std::string p("fpa2bv"); std::string name = f->get_name().str(); @@ -271,7 +271,7 @@ void fpa2bv_converter::mk_rm_const(func_decl * f, expr_ref & result) { SASSERT(is_rm_sort(f->get_range())); result = m.mk_fresh_const( - #ifdef _DEBUG + #ifdef Z3DEBUG "fpa2bv_rm" #else 0 @@ -1588,8 +1588,7 @@ void fpa2bv_converter::mk_round_to_integral(func_decl * f, unsigned num, expr * v2 = x; unsigned ebits = m_util.get_ebits(f->get_range()); - unsigned sbits = m_util.get_sbits(f->get_range()); - SASSERT(ebits < sbits); + unsigned sbits = m_util.get_sbits(f->get_range()); expr_ref a_sgn(m), a_sig(m), a_exp(m), a_lz(m); unpack(x, a_sgn, a_sig, a_exp, a_lz, true); @@ -1619,9 +1618,22 @@ void fpa2bv_converter::mk_round_to_integral(func_decl * f, unsigned num, expr * expr_ref shift(m), r_shifted(m), l_shifted(m); shift = m_bv_util.mk_bv_sub(m_bv_util.mk_numeral(sbits-1, ebits+1), m_bv_util.mk_sign_extend(1, a_exp)); - r_shifted = m_bv_util.mk_bv_lshr(a_sig, m_bv_util.mk_zero_extend(sbits-ebits-1, shift)); + if (sbits > (ebits+1)) + r_shifted = m_bv_util.mk_bv_lshr(a_sig, m_bv_util.mk_zero_extend(sbits-(ebits+1), shift)); + else if (sbits < (ebits+1)) + r_shifted = m_bv_util.mk_extract(ebits, ebits-sbits+1, m_bv_util.mk_bv_lshr(m_bv_util.mk_zero_extend(ebits+1-sbits, a_sig), shift)); + else // sbits == ebits+1 + r_shifted = m_bv_util.mk_bv_lshr(a_sig, shift); + SASSERT(is_well_sorted(m, r_shifted)); SASSERT(m_bv_util.get_bv_size(r_shifted) == sbits); - l_shifted = m_bv_util.mk_bv_shl(r_shifted, m_bv_util.mk_zero_extend(sbits-ebits-1, shift)); + + if (sbits > (ebits+1)) + l_shifted = m_bv_util.mk_bv_shl(r_shifted, m_bv_util.mk_zero_extend(sbits-(ebits+1), shift)); + else if (sbits < (ebits+1)) + l_shifted = m_bv_util.mk_extract(ebits, ebits-sbits+1, m_bv_util.mk_bv_shl(m_bv_util.mk_zero_extend(ebits+1-sbits, r_shifted), shift)); + else // sbits == ebits+1 + l_shifted = m_bv_util.mk_bv_shl(r_shifted, shift); + SASSERT(is_well_sorted(m, l_shifted)); SASSERT(m_bv_util.get_bv_size(l_shifted) == sbits); res_sig = m_bv_util.mk_concat(m_bv_util.mk_numeral(0, 1), @@ -1825,146 +1837,158 @@ void fpa2bv_converter::mk_to_float(func_decl * f, unsigned num, expr * const * a mk_triple(args[0], args[1], args[2], result); } else if (num == 2 && is_app(args[1]) && m_util.is_float(m.get_sort(args[1]))) { - // We also support float to float conversion + // We also support float to float conversion + sort * s = f->get_range(); expr_ref rm(m), x(m); rm = args[0]; x = args[1]; - - expr_ref c1(m), c2(m), c3(m), c4(m), c5(m); - expr_ref v1(m), v2(m), v3(m), v4(m), v5(m), v6(m); - expr_ref one1(m); - one1 = m_bv_util.mk_numeral(1, 1); - expr_ref ninf(m), pinf(m); - mk_plus_inf(f, pinf); - mk_minus_inf(f, ninf); - - // NaN -> NaN - mk_is_nan(x, c1); - mk_nan(f, v1); - - // +0 -> +0 - mk_is_pzero(x, c2); - mk_pzero(f, v2); - - // -0 -> -0 - mk_is_nzero(x, c3); - mk_nzero(f, v3); - - // +oo -> +oo - mk_is_pinf(x, c4); - v4 = pinf; - - // -oo -> -oo - mk_is_ninf(x, c5); - v5 = ninf; - - // otherwise: the actual conversion with rounding. - sort * s = f->get_range(); - expr_ref sgn(m), sig(m), exp(m), lz(m); - unpack(x, sgn, sig, exp, lz, true); - - dbg_decouple("fpa2bv_to_float_x_sig", sig); - dbg_decouple("fpa2bv_to_float_x_exp", exp); - dbg_decouple("fpa2bv_to_float_lz", lz); - - expr_ref res_sgn(m), res_sig(m), res_exp(m); - - res_sgn = sgn; - - unsigned from_sbits = m_util.get_sbits(m.get_sort(args[1])); - unsigned from_ebits = m_util.get_ebits(m.get_sort(args[1])); + unsigned from_sbits = m_util.get_sbits(m.get_sort(x)); + unsigned from_ebits = m_util.get_ebits(m.get_sort(x)); unsigned to_sbits = m_util.get_sbits(s); unsigned to_ebits = m_util.get_ebits(s); - SASSERT(m_bv_util.get_bv_size(sgn) == 1); - SASSERT(m_bv_util.get_bv_size(sig) == from_sbits); - SASSERT(m_bv_util.get_bv_size(exp) == from_ebits); - SASSERT(m_bv_util.get_bv_size(lz) == from_ebits); + if (from_sbits == to_sbits && from_ebits == to_ebits) + result = x; + else { + expr_ref c1(m), c2(m), c3(m), c4(m), c5(m); + expr_ref v1(m), v2(m), v3(m), v4(m), v5(m), v6(m); + expr_ref one1(m); - if (from_sbits < (to_sbits + 3)) - { - // make sure that sig has at least to_sbits + 3 - res_sig = m_bv_util.mk_concat(sig, m_bv_util.mk_numeral(0, to_sbits+3-from_sbits)); - } - else if (from_sbits > (to_sbits + 3)) - { - // collapse the extra bits into a sticky bit. - expr_ref sticky(m), low(m), high(m); - low = m_bv_util.mk_extract(from_sbits - to_sbits - 3, 0, sig); - high = m_bv_util.mk_extract(from_sbits - 1, from_sbits - to_sbits - 2, sig); - sticky = m.mk_app(m_bv_util.get_fid(), OP_BREDOR, low.get()); - res_sig = m_bv_util.mk_concat(high, sticky); - } - else - res_sig = sig; + one1 = m_bv_util.mk_numeral(1, 1); + expr_ref ninf(m), pinf(m); + mk_plus_inf(f, pinf); + mk_minus_inf(f, ninf); - res_sig = m_bv_util.mk_zero_extend(1, res_sig); // extra zero in the front for the rounder. - unsigned sig_sz = m_bv_util.get_bv_size(res_sig); - SASSERT(sig_sz == to_sbits+4); + // NaN -> NaN + mk_is_nan(x, c1); + mk_nan(f, v1); - expr_ref exponent_overflow(m); - exponent_overflow = m.mk_false(); + // +0 -> +0 + mk_is_pzero(x, c2); + mk_pzero(f, v2); - if (from_ebits < (to_ebits + 2)) - { - res_exp = m_bv_util.mk_sign_extend(to_ebits-from_ebits+2, exp); - } - else if (from_ebits > (to_ebits + 2)) - { - expr_ref high(m), low(m), lows(m), high_red_or(m), high_red_and(m), h_or_eq(m), h_and_eq(m); - expr_ref no_ovf(m), zero1(m), s_is_one(m), s_is_zero(m); - high = m_bv_util.mk_extract(from_ebits - 1, to_ebits + 2, exp); - low = m_bv_util.mk_extract(to_ebits+1, 0, exp); - lows = m_bv_util.mk_extract(to_ebits+1, to_ebits+1, low); + // -0 -> -0 + mk_is_nzero(x, c3); + mk_nzero(f, v3); + + // +oo -> +oo + mk_is_pinf(x, c4); + v4 = pinf; + + // -oo -> -oo + mk_is_ninf(x, c5); + v5 = ninf; - high_red_or = m.mk_app(m_bv_util.get_fid(), OP_BREDOR, high.get()); - high_red_and = m.mk_app(m_bv_util.get_fid(), OP_BREDAND, high.get()); + // otherwise: the actual conversion with rounding. + expr_ref sgn(m), sig(m), exp(m), lz(m); + unpack(x, sgn, sig, exp, lz, true); - zero1 = m_bv_util.mk_numeral(0, 1); - m_simp.mk_eq(high_red_and, one1, h_and_eq); - m_simp.mk_eq(high_red_or, zero1, h_or_eq); - m_simp.mk_eq(lows, zero1, s_is_zero); - m_simp.mk_eq(lows, one1, s_is_one); + dbg_decouple("fpa2bv_to_float_x_sig", sig); + dbg_decouple("fpa2bv_to_float_x_exp", exp); + dbg_decouple("fpa2bv_to_float_lz", lz); + + expr_ref res_sgn(m), res_sig(m), res_exp(m); + + res_sgn = sgn; + + SASSERT(m_bv_util.get_bv_size(sgn) == 1); + SASSERT(m_bv_util.get_bv_size(sig) == from_sbits); + SASSERT(m_bv_util.get_bv_size(exp) == from_ebits); + SASSERT(m_bv_util.get_bv_size(lz) == from_ebits); + + if (from_sbits < (to_sbits + 3)) + { + // make sure that sig has at least to_sbits + 3 + res_sig = m_bv_util.mk_concat(sig, m_bv_util.mk_numeral(0, to_sbits+3-from_sbits)); + } + else if (from_sbits > (to_sbits + 3)) + { + // collapse the extra bits into a sticky bit. + expr_ref sticky(m), low(m), high(m); + low = m_bv_util.mk_extract(from_sbits - to_sbits - 3, 0, sig); + high = m_bv_util.mk_extract(from_sbits - 1, from_sbits - to_sbits - 2, sig); + sticky = m.mk_app(m_bv_util.get_fid(), OP_BREDOR, low.get()); + res_sig = m_bv_util.mk_concat(high, sticky); + } + else + res_sig = sig; + + res_sig = m_bv_util.mk_zero_extend(1, res_sig); // extra zero in the front for the rounder. + unsigned sig_sz = m_bv_util.get_bv_size(res_sig); + SASSERT(sig_sz == to_sbits+4); + + expr_ref exponent_overflow(m); + exponent_overflow = m.mk_false(); + + if (from_ebits < (to_ebits + 2)) + { + res_exp = m_bv_util.mk_sign_extend(to_ebits-from_ebits+2, exp); + + // subtract lz for subnormal numbers. + expr_ref lz_ext(m); + lz_ext = m_bv_util.mk_zero_extend(to_ebits-from_ebits+2, lz); + res_exp = m_bv_util.mk_bv_sub(res_exp, lz_ext); + } + else if (from_ebits > (to_ebits + 2)) + { + expr_ref high(m), low(m), lows(m), high_red_or(m), high_red_and(m), h_or_eq(m), h_and_eq(m); + expr_ref no_ovf(m), zero1(m), s_is_one(m), s_is_zero(m); + high = m_bv_util.mk_extract(from_ebits - 1, to_ebits + 2, exp); + low = m_bv_util.mk_extract(to_ebits+1, 0, exp); + lows = m_bv_util.mk_extract(to_ebits+1, to_ebits+1, low); - expr_ref c2(m); - m_simp.mk_ite(h_or_eq, s_is_one, m.mk_false(), c2); - m_simp.mk_ite(h_and_eq, s_is_zero, c2, exponent_overflow); + high_red_or = m.mk_app(m_bv_util.get_fid(), OP_BREDOR, high.get()); + high_red_and = m.mk_app(m_bv_util.get_fid(), OP_BREDAND, high.get()); + + zero1 = m_bv_util.mk_numeral(0, 1); + m_simp.mk_eq(high_red_and, one1, h_and_eq); + m_simp.mk_eq(high_red_or, zero1, h_or_eq); + m_simp.mk_eq(lows, zero1, s_is_zero); + m_simp.mk_eq(lows, one1, s_is_one); - // Note: Upon overflow, we _could_ try to shift the significand around... + expr_ref c2(m); + m_simp.mk_ite(h_or_eq, s_is_one, m.mk_false(), c2); + m_simp.mk_ite(h_and_eq, s_is_zero, c2, exponent_overflow); + + // Note: Upon overflow, we _could_ try to shift the significand around... - res_exp = low; - } - else - res_exp = exp; + // subtract lz for subnormal numbers. + expr_ref lz_ext(m), lz_rest(m), lz_redor(m), lz_redor_bool(m); + lz_ext = m_bv_util.mk_extract(to_ebits+1, 0, lz); + lz_rest = m_bv_util.mk_extract(from_ebits-1, to_ebits+2, lz); + lz_redor = m.mk_app(m_bv_util.get_fid(), OP_BREDOR, lz_rest.get()); + m_simp.mk_eq(lz_redor, one1, lz_redor_bool); + m_simp.mk_or(exponent_overflow, lz_redor_bool, exponent_overflow); - // subtract lz for subnormal numbers. - expr_ref lz_ext(m); - lz_ext = m_bv_util.mk_zero_extend(to_ebits-from_ebits+2, lz); - res_exp = m_bv_util.mk_bv_sub(res_exp, lz_ext); - SASSERT(m_bv_util.get_bv_size(res_exp) == to_ebits+2); + res_exp = m_bv_util.mk_bv_sub(low, lz_ext); + } + else // from_ebits == (to_ebits + 2) + res_exp = m_bv_util.mk_bv_sub(exp, lz); - dbg_decouple("fpa2bv_to_float_res_sig", res_sig); - dbg_decouple("fpa2bv_to_float_res_exp", res_exp); + SASSERT(m_bv_util.get_bv_size(res_exp) == to_ebits+2); + SASSERT(is_well_sorted(m, res_exp)); - expr_ref rounded(m); - round(s, rm, res_sgn, res_sig, res_exp, rounded); - + dbg_decouple("fpa2bv_to_float_res_sig", res_sig); + dbg_decouple("fpa2bv_to_float_res_exp", res_exp); - expr_ref is_neg(m), sig_inf(m); - m_simp.mk_eq(sgn, one1, is_neg); - mk_ite(is_neg, ninf, pinf, sig_inf); + expr_ref rounded(m); + round(s, rm, res_sgn, res_sig, res_exp, rounded); + + expr_ref is_neg(m), sig_inf(m); + m_simp.mk_eq(sgn, one1, is_neg); + mk_ite(is_neg, ninf, pinf, sig_inf); - dbg_decouple("fpa2bv_to_float_exp_ovf", exponent_overflow); + dbg_decouple("fpa2bv_to_float_exp_ovf", exponent_overflow); + mk_ite(exponent_overflow, sig_inf, rounded, v6); - mk_ite(exponent_overflow, sig_inf, rounded, v6); - - // And finally, we tie them together. - mk_ite(c5, v5, v6, result); - mk_ite(c4, v4, result, result); - mk_ite(c3, v3, result, result); - mk_ite(c2, v2, result, result); - mk_ite(c1, v1, result, result); + // And finally, we tie them together. + mk_ite(c5, v5, v6, result); + mk_ite(c4, v4, result, result); + mk_ite(c3, v3, result, result); + mk_ite(c2, v2, result, result); + mk_ite(c1, v1, result, result); + } } else { // .. other than that, we only support rationals for asFloat @@ -2167,14 +2191,14 @@ void fpa2bv_converter::mk_bot_exp(unsigned sz, expr_ref & result) { } void fpa2bv_converter::mk_min_exp(unsigned ebits, expr_ref & result) { - SASSERT(ebits > 0); + SASSERT(ebits >= 2); const mpz & z = m_mpf_manager.m_powers2.m1(ebits-1, true); result = m_bv_util.mk_numeral(z + mpz(1), ebits); } void fpa2bv_converter::mk_max_exp(unsigned ebits, expr_ref & result) { - SASSERT(ebits > 0); - result = m_bv_util.mk_numeral(m_mpf_manager.m_powers2.m1(ebits-1, false), ebits); + SASSERT(ebits >= 2); + result = m_bv_util.mk_numeral(m_mpf_manager.m_powers2.m1(ebits-1, false), ebits); } void fpa2bv_converter::mk_leading_zeros(expr * e, unsigned max_bits, expr_ref & result) { @@ -2220,14 +2244,14 @@ void fpa2bv_converter::mk_bias(expr * e, expr_ref & result) { unsigned ebits = m_bv_util.get_bv_size(e); SASSERT(ebits >= 2); - expr_ref mask(m); - mask = m_bv_util.mk_numeral(fu().fm().m_powers2.m1(ebits-1), ebits); - result = m_bv_util.mk_bv_add(e, mask); + expr_ref bias(m); + bias = m_bv_util.mk_numeral(fu().fm().m_powers2.m1(ebits-1), ebits); + result = m_bv_util.mk_bv_add(e, bias); } void fpa2bv_converter::mk_unbias(expr * e, expr_ref & result) { unsigned ebits = m_bv_util.get_bv_size(e); - SASSERT(ebits >= 2); + SASSERT(ebits >= 3); expr_ref e_plus_one(m); e_plus_one = m_bv_util.mk_bv_add(e, m_bv_util.mk_numeral(1, ebits)); @@ -2263,6 +2287,7 @@ void fpa2bv_converter::unpack(expr * e, expr_ref & sgn, expr_ref & sig, expr_ref expr_ref denormal_sig(m), denormal_exp(m); denormal_sig = m_bv_util.mk_zero_extend(1, sig); + SASSERT(ebits >= 3); // Note: when ebits=2 there is no 1-exponent, so mk_unbias will produce a 0. denormal_exp = m_bv_util.mk_numeral(1, ebits); mk_unbias(denormal_exp, denormal_exp); dbg_decouple("fpa2bv_unpack_denormal_exp", denormal_exp); @@ -2345,7 +2370,7 @@ void fpa2bv_converter::mk_rounding_mode(func_decl * f, expr_ref & result) } void fpa2bv_converter::dbg_decouple(const char * prefix, expr_ref & e) { - #ifdef _DEBUG + #ifdef Z3DEBUG return; // CMW: This works only for quantifier-free formulas. expr_ref new_e(m); @@ -2436,13 +2461,12 @@ void fpa2bv_converter::round(sort * s, expr_ref & rm, expr_ref & sgn, expr_ref & t = m_bv_util.mk_bv_add(exp, m_bv_util.mk_numeral(1, ebits+2)); t = m_bv_util.mk_bv_sub(t, lz); t = m_bv_util.mk_bv_sub(t, m_bv_util.mk_sign_extend(2, e_min)); - expr_ref TINY(m); + dbg_decouple("fpa2bv_rnd_t", t); + expr_ref TINY(m); TINY = m_bv_util.mk_sle(t, m_bv_util.mk_numeral((unsigned)-1, ebits+2)); - - TRACE("fpa2bv_dbg", tout << "TINY = " << mk_ismt2_pp(TINY, m) << std::endl;); - SASSERT(is_well_sorted(m, TINY)); - dbg_decouple("fpa2bv_rnd_TINY", TINY); + TRACE("fpa2bv_dbg", tout << "TINY = " << mk_ismt2_pp(TINY, m) << std::endl;); + SASSERT(is_well_sorted(m, TINY)); expr_ref beta(m); beta = m_bv_util.mk_bv_add(m_bv_util.mk_bv_sub(exp, lz), m_bv_util.mk_numeral(1, ebits+2)); @@ -2455,7 +2479,7 @@ void fpa2bv_converter::round(sort * s, expr_ref & rm, expr_ref & sgn, expr_ref & dbg_decouple("fpa2bv_rnd_e_min", e_min); dbg_decouple("fpa2bv_rnd_e_max", e_max); - expr_ref sigma(m), sigma_add(m), e_min_p2(m); + expr_ref sigma(m), sigma_add(m); sigma_add = m_bv_util.mk_bv_sub(exp, m_bv_util.mk_sign_extend(2, e_min)); sigma_add = m_bv_util.mk_bv_add(sigma_add, m_bv_util.mk_numeral(1, ebits+2)); m_simp.mk_ite(TINY, sigma_add, lz, sigma); @@ -2477,9 +2501,10 @@ void fpa2bv_converter::round(sort * s, expr_ref & rm, expr_ref & sgn, expr_ref & rs_sig(m), ls_sig(m), big_sh_sig(m), sigma_le_cap(m); sigma_neg = m_bv_util.mk_bv_neg(sigma); sigma_cap = m_bv_util.mk_numeral(sbits+2, sigma_size); - sigma_le_cap = m_bv_util.mk_sle(sigma_neg, sigma_cap); + sigma_le_cap = m_bv_util.mk_ule(sigma_neg, sigma_cap); m_simp.mk_ite(sigma_le_cap, sigma_neg, sigma_cap, sigma_neg_capped); dbg_decouple("fpa2bv_rnd_sigma_neg", sigma_neg); + dbg_decouple("fpa2bv_rnd_sigma_cap", sigma_cap); dbg_decouple("fpa2bv_rnd_sigma_neg_capped", sigma_neg_capped); sigma_lt_zero = m_bv_util.mk_sle(sigma, m_bv_util.mk_numeral((unsigned)-1, sigma_size)); dbg_decouple("fpa2bv_rnd_sigma_lt_zero", sigma_lt_zero); diff --git a/src/tactic/goal.cpp b/src/tactic/goal.cpp index 03a13aba5..51d7fedb0 100644 --- a/src/tactic/goal.cpp +++ b/src/tactic/goal.cpp @@ -255,7 +255,8 @@ void goal::get_formulas(ptr_vector & result) { } void goal::update(unsigned i, expr * f, proof * pr, expr_dependency * d) { - SASSERT(proofs_enabled() == (pr != 0 && !m().is_undef_proof(pr))); + // KLM: don't know why this assertion is no longer true + // SASSERT(proofs_enabled() == (pr != 0 && !m().is_undef_proof(pr))); if (m_inconsistent) return; if (proofs_enabled()) { diff --git a/src/tactic/sls/sls_tactic.cpp b/src/tactic/sls/sls_tactic.cpp index fff082bd8..ea5d60668 100644 --- a/src/tactic/sls/sls_tactic.cpp +++ b/src/tactic/sls/sls_tactic.cpp @@ -189,14 +189,14 @@ class sls_tactic : public tactic { bool what_if(goal_ref const & g, func_decl * fd, const unsigned & fd_inx, const mpz & temp, double & best_score, unsigned & best_const, mpz & best_value) { - #ifdef _DEBUG + #ifdef Z3DEBUG mpz old_value; m_mpz_manager.set(old_value, m_tracker.get_value(fd)); #endif double r = incremental_score(g, fd, temp); - #ifdef _DEBUG + #ifdef Z3DEBUG TRACE("sls_whatif", tout << "WHAT IF " << fd->get_name() << " WERE " << m_mpz_manager.to_string(temp) << " --> " << r << std::endl; ); diff --git a/src/tactic/smtlogics/qfbv_tactic.cpp b/src/tactic/smtlogics/qfbv_tactic.cpp index 51b040332..ac53ca0c8 100644 --- a/src/tactic/smtlogics/qfbv_tactic.cpp +++ b/src/tactic/smtlogics/qfbv_tactic.cpp @@ -43,6 +43,8 @@ tactic * mk_qfbv_tactic(ast_manager & m, params_ref const & p) { simp2_p.set_bool("push_ite_bv", false); simp2_p.set_bool("local_ctx", true); simp2_p.set_uint("local_ctx_limit", 10000000); + simp2_p.set_bool("flat", true); // required by som + simp2_p.set_bool("hoist_mul", false); // required by som params_ref local_ctx_p = p; local_ctx_p.set_bool("local_ctx", true); diff --git a/src/util/id_gen.h b/src/util/id_gen.h index b1713b524..c6d22246d 100644 --- a/src/util/id_gen.h +++ b/src/util/id_gen.h @@ -57,6 +57,11 @@ public: m_free_ids.finalize(); } + unsigned show_hash(){ + unsigned h = string_hash((char *)&m_free_ids[0],m_free_ids.size()*sizeof(unsigned),17); + return hash_u_u(h,m_next_id); + } + /** \brief Return N if the range of ids generated by this module is in the set [0..N) */ diff --git a/src/util/mpf.cpp b/src/util/mpf.cpp index 735d21c99..5910f55c9 100644 --- a/src/util/mpf.cpp +++ b/src/util/mpf.cpp @@ -360,6 +360,8 @@ void mpf_manager::set(mpf & o, unsigned ebits, unsigned sbits, mpf_rounding_mode mk_inf(ebits, sbits, x.sign, o); else if (is_zero(x)) mk_zero(ebits, sbits, x.sign, o); + else if (x.ebits == ebits && x.sbits == sbits) + set(o, x); else { set(o, x); unpack(o, true); @@ -1378,12 +1380,12 @@ bool mpf_manager::has_top_exp(mpf const & x) { } mpf_exp_t mpf_manager::mk_bot_exp(unsigned ebits) { - SASSERT(ebits > 0); + SASSERT(ebits >= 2); return m_mpz_manager.get_int64(m_powers2.m1(ebits-1, true)); } mpf_exp_t mpf_manager::mk_top_exp(unsigned ebits) { - SASSERT(ebits > 0); + SASSERT(ebits >= 2); return m_mpz_manager.get_int64(m_powers2(ebits-1)); } diff --git a/src/util/mpn.cpp b/src/util/mpn.cpp index 2cc35776f..6a544f583 100644 --- a/src/util/mpn.cpp +++ b/src/util/mpn.cpp @@ -31,7 +31,7 @@ mpn_manager static_mpn_manager; const mpn_digit mpn_manager::zero = 0; mpn_manager::mpn_manager() { -#ifdef _DEBUG +#ifdef Z3DEBUG trace_enabled=true; #endif } @@ -43,7 +43,7 @@ int mpn_manager::compare(mpn_digit const * a, size_t const lnga, mpn_digit const * b, size_t const lngb) const { int res = 0; - #ifdef _DEBUG + #ifdef Z3DEBUG if (trace_enabled) STRACE("mpn", tout << "[mpn] "; ); #endif @@ -60,7 +60,7 @@ int mpn_manager::compare(mpn_digit const * a, size_t const lnga, res = -1; } - #ifdef _DEBUG + #ifdef Z3DEBUG if (trace_enabled) STRACE("mpn", tout << ((res == 1) ? " > " : (res == -1) ? " < " : " == "); ); #endif @@ -212,7 +212,7 @@ bool mpn_manager::div(mpn_digit const * numer, size_t const lnum, trace(numer, lnum, denom, lden, "%"); trace_nl(rem, lden); -#ifdef _DEBUG +#ifdef Z3DEBUG mpn_sbuffer temp(lnum+1, 0); mul(quot, lnum-lden+1, denom, lden, temp.c_ptr()); size_t real_size; @@ -340,7 +340,7 @@ bool mpn_manager::div_n(mpn_sbuffer & numer, mpn_sbuffer const & denom, // Replace numer[j+n]...numer[j] with // numer[j+n]...numer[j] - q * (denom[n-1]...denom[0]) mpn_digit q_hat_small = (mpn_digit)q_hat; - #ifdef _DEBUG + #ifdef Z3DEBUG trace_enabled = false; #endif mul(&q_hat_small, 1, denom.c_ptr(), n, t_ms.c_ptr()); @@ -354,7 +354,7 @@ bool mpn_manager::div_n(mpn_sbuffer & numer, mpn_sbuffer const & denom, for (size_t i = 0; i < n+1; i++) numer[j+i] = t_ab[i]; } - #ifdef _DEBUG + #ifdef Z3DEBUG trace_enabled = true; #endif STRACE("mpn_div", tout << "q_hat=" << q_hat << " r_hat=" << r_hat; @@ -416,7 +416,7 @@ void mpn_manager::display_raw(std::ostream & out, mpn_digit const * a, size_t co void mpn_manager::trace(mpn_digit const * a, size_t const lnga, mpn_digit const * b, size_t const lngb, const char * op) const { -#ifdef _DEBUG +#ifdef Z3DEBUG if (trace_enabled) STRACE("mpn", tout << "[mpn] " << to_string(a, lnga, char_buf, sizeof(char_buf)); tout << " " << op << " " << to_string(b, lngb, char_buf, sizeof(char_buf)); @@ -425,14 +425,14 @@ void mpn_manager::trace(mpn_digit const * a, size_t const lnga, } void mpn_manager::trace(mpn_digit const * a, size_t const lnga) const { -#ifdef _DEBUG +#ifdef Z3DEBUG if (trace_enabled) STRACE("mpn", tout << to_string(a, lnga, char_buf, sizeof(char_buf)); ); #endif } void mpn_manager::trace_nl(mpn_digit const * a, size_t const lnga) const { -#ifdef _DEBUG +#ifdef Z3DEBUG if (trace_enabled) STRACE("mpn", tout << to_string(a, lnga, char_buf, sizeof(char_buf)) << std::endl; ); #endif diff --git a/src/util/mpn.h b/src/util/mpn.h index f995004c2..219368de3 100644 --- a/src/util/mpn.h +++ b/src/util/mpn.h @@ -101,7 +101,7 @@ private: bool div_n(mpn_sbuffer & numer, mpn_sbuffer const & denom, mpn_digit * quot, mpn_digit * rem); - #ifdef _DEBUG + #ifdef Z3DEBUG mutable char char_buf[4096]; bool trace_enabled; #endif