diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 5bb7d2cad..d93e7eeef 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -41,7 +41,7 @@ jobs: type=edge type=sha,prefix=ubuntu-20.04-bare-z3-sha- - name: Build and push Bare Z3 Docker Image - uses: docker/build-push-action@v5.1.0 + uses: docker/build-push-action@v5.3.0 with: context: . push: true diff --git a/CMakeLists.txt b/CMakeLists.txt index 39f89ae74..cdccb8e4b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.16) set(CMAKE_USER_MAKE_RULES_OVERRIDE_CXX "${CMAKE_CURRENT_SOURCE_DIR}/cmake/cxx_compiler_flags_overrides.cmake") -project(Z3 VERSION 4.12.6.0 LANGUAGES CXX) +project(Z3 VERSION 4.13.1.0 LANGUAGES CXX) ################################################################################ # Project version diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 9b041c07e..3830f566f 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -10,6 +10,10 @@ Version 4.next - native word level bit-vector solving. - introduction of simple induction lemmas to handle a limited repertoire of induction proofs. +Version 4.13.0 +============== +- add ARM64 wheels for Python, thanks to Steven Moy, smoy + Version 4.12.6 ============== - remove expensive rewrite that coalesces adjacent stores diff --git a/azure-pipelines.yml b/azure-pipelines.yml index d9d2ab2b2..b3bc0f226 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -43,6 +43,39 @@ jobs: - ${{if eq(variables['runRegressions'], 'True')}}: - template: scripts/test-regressions.yml +- job: LinuxBuildsArm64 + displayName: "ManyLinux ARM64 build" + variables: + name: ManyLinux + python: "/opt/python/cp37-cp37m/bin/python" + pool: + vmImage: "ubuntu-latest" + container: "quay.io/pypa/manylinux2014_x86_64:latest" + steps: + - script: curl -L -o /tmp/arm-toolchain.tar.xz 'https://developer.arm.com/-/media/Files/downloads/gnu/11.2-2022.02/binrel/gcc-arm-11.2-2022.02-x86_64-aarch64-none-linux-gnu.tar.xz?rev=33c6e30e5ac64e6dba8f0431f2c35f1b&hash=9918A05BF47621B632C7A5C8D2BB438FB80A4480' + - script: mkdir -p /tmp/arm-toolchain/ + - script: tar xf /tmp/arm-toolchain.tar.xz -C /tmp/arm-toolchain/ --strip-components=1 + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/bin' + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/aarch64-none-linux-gnu/libc/usr/bin' + - script: echo $PATH + - script: stat /tmp/arm-toolchain/bin/aarch64-none-linux-gnu-gcc + - task: PythonScript@0 + displayName: Build + inputs: + scriptSource: 'filepath' + scriptPath: scripts/mk_unix_dist.py + arguments: --nodotnet --nojava --arch=arm64 + pythonInterpreter: $(python) + - task: CopyFiles@2 + inputs: + sourceFolder: dist + contents: '*.zip' + targetFolder: $(Build.ArtifactStagingDirectory) + - task: PublishPipelineArtifact@0 + inputs: + artifactName: 'ManyLinuxBuildArm64' + targetPath: $(Build.ArtifactStagingDirectory) + - job: "Ubuntu20OCaml" displayName: "Ubuntu 20 with OCaml" pool: diff --git a/scripts/mk_project.py b/scripts/mk_project.py index 8a6b1b943..95a5659d7 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -8,7 +8,7 @@ from mk_util import * def init_version(): - set_version(4, 12, 6, 0) # express a default build version or pick up ci build version + set_version(4, 13, 0, 1) # express a default build version or pick up ci build version # Z3 Project definition def init_project_def(): @@ -60,7 +60,7 @@ def init_project_def(): add_lib('smt', ['bit_blaster', 'macros', 'normal_forms', 'cmd_context', 'proto_model', 'solver_assertions', 'substitution', 'grobner', 'simplex', 'proofs', 'pattern', 'parser_util', 'fpa', 'lp']) add_lib('polysat', ['util', 'dd', 'sat'], 'sat/smt/polysat'), - add_lib('sat_smt', ['sat', 'euf', 'smt', 'tactic', 'solver', 'smt_params', 'bit_blaster', 'fpa', 'mbp', 'polysat', 'normal_forms', 'lp', 'pattern', 'qe_lite'], 'sat/smt') + add_lib('sat_smt', ['sat', 'ast_sls', 'euf', 'smt', 'tactic', 'solver', 'smt_params', 'bit_blaster', 'fpa', 'mbp', 'polysat', 'normal_forms', 'lp', 'pattern', 'qe_lite'], 'sat/smt') add_lib('sat_tactic', ['tactic', 'sat', 'solver', 'sat_smt'], 'sat/tactic') add_lib('nlsat_tactic', ['nlsat', 'sat_tactic', 'arith_tactics'], 'nlsat/tactic') add_lib('bv_tactics', ['tactic', 'bit_blaster', 'core_tactics'], 'tactic/bv') diff --git a/scripts/mk_unix_dist.py b/scripts/mk_unix_dist.py index 3b1e71391..d967e9109 100644 --- a/scripts/mk_unix_dist.py +++ b/scripts/mk_unix_dist.py @@ -118,7 +118,9 @@ def check_build_dir(path): # Create a build directory using mk_make.py def mk_build_dir(path): + global LINUX_X64 if not check_build_dir(path) or FORCE_MK: + env = os.environ opts = [sys.executable, os.path.join('scripts', 'mk_make.py'), "-b", path, "--staticlib"] if DOTNET_CORE_ENABLED: opts.append('--dotnet') @@ -133,7 +135,17 @@ def mk_build_dir(path): opts.append('--python') if mk_util.IS_ARCH_ARM64: opts.append('--arm64=true') - if subprocess.call(opts) != 0: + if mk_util.IS_ARCH_ARM64 and LINUX_X64: + # we are machine x64 but build against arm64 + # so we have to do cross compiling + # the cross compiler is download from ARM GNU + # toolchain + myvar = { + "CC": "aarch64-none-linux-gnu-gcc", + "CXX": "aarch64-none-linux-gnu-g++" + } + env.update(myvar) + if subprocess.call(opts, env=env) != 0: raise MKException("Failed to generate build directory at '%s'" % path) # Create build directories @@ -159,12 +171,22 @@ def mk_z3(): return 1 def get_os_name(): + global LINUX_X64 if OS_NAME is not None: return OS_NAME import platform basic = os.uname()[0].lower() if basic == 'linux': - dist = platform.libc_ver() + if mk_util.IS_ARCH_ARM64 and LINUX_X64: + # handle cross compiling + # example: 'ldd (GNU) 2.34' + lines = subprocess.check_output(["ldd", "--version"]).decode('ascii') + first_line = lines.split("\n")[0] + ldd_version = first_line.split()[-1] + # coerce the format to platform.libc_ver() return type + dist = ('glibc', ldd_version) + else: + dist = platform.libc_ver() if len(dist) == 2 and len(dist[0]) > 0 and len(dist[1]) > 0: return '%s-%s' % (dist[0].lower(), dist[1].lower()) else: @@ -187,8 +209,14 @@ def get_os_name(): return basic def get_z3_name(): + import platform as platform_module + # Note that the platform name this function return + # has to work together with setup.py + # It's not the typical output from platform.machine() major, minor, build, revision = get_version() - if mk_util.IS_ARCH_ARM64: + if mk_util.IS_ARCH_ARM64 or platform_module.machine() == "aarch64": + # the second case handle native build on aarch64 + # TODO: we don't handle cross compile on host aarch64 to target x64 platform = "arm64" elif sys.maxsize >= 2**32: platform = "x64" diff --git a/scripts/nightly.yaml b/scripts/nightly.yaml index f6c3c66c7..00e27507e 100644 --- a/scripts/nightly.yaml +++ b/scripts/nightly.yaml @@ -1,7 +1,7 @@ variables: Major: '4' - Minor: '12' - Patch: '6' + Minor: '13' + Patch: '1' ReleaseVersion: $(Major).$(Minor).$(Patch) AssemblyVersion: $(Major).$(Minor).$(Patch).$(Build.BuildId) NightlyVersion: $(AssemblyVersion)-$(Build.buildId) @@ -194,6 +194,39 @@ stages: artifactName: 'ManyLinuxBuild' targetPath: $(Build.ArtifactStagingDirectory) + - job: LinuxBuildsArm64 + displayName: "ManyLinux ARM64 build" + variables: + name: ManyLinux + python: "/opt/python/cp37-cp37m/bin/python" + pool: + vmImage: "ubuntu-latest" + container: "quay.io/pypa/manylinux2014_x86_64:latest" + steps: + - script: curl -L -o /tmp/arm-toolchain.tar.xz 'https://developer.arm.com/-/media/Files/downloads/gnu/11.2-2022.02/binrel/gcc-arm-11.2-2022.02-x86_64-aarch64-none-linux-gnu.tar.xz?rev=33c6e30e5ac64e6dba8f0431f2c35f1b&hash=9918A05BF47621B632C7A5C8D2BB438FB80A4480' + - script: mkdir -p /tmp/arm-toolchain/ + - script: tar xf /tmp/arm-toolchain.tar.xz -C /tmp/arm-toolchain/ --strip-components=1 + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/bin' + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/aarch64-none-linux-gnu/libc/usr/bin' + - script: echo $PATH + - script: stat /tmp/arm-toolchain/bin/aarch64-none-linux-gnu-gcc + - task: PythonScript@0 + displayName: Build + inputs: + scriptSource: 'filepath' + scriptPath: scripts/mk_unix_dist.py + arguments: --nodotnet --nojava --arch=arm64 + pythonInterpreter: $(python) + - task: CopyFiles@2 + inputs: + sourceFolder: dist + contents: '*.zip' + targetFolder: $(Build.ArtifactStagingDirectory) + - task: PublishPipelineArtifact@0 + inputs: + artifactName: 'ManyLinuxBuildArm64' + targetPath: $(Build.ArtifactStagingDirectory) + - template: build-win-signed.yml parameters: ReleaseVersion: $(ReleaseVersion) @@ -247,7 +280,7 @@ stages: displayName: 'Download macOS Arm64 Build' inputs: artifact: 'MacArm64' - path: $(Agent.TempDirectory)\package + path: $(Agent.TempDirectory)\package - task: NuGetToolInstaller@0 inputs: versionSpec: 5.x @@ -459,6 +492,10 @@ stages: inputs: artifactName: 'ManyLinuxBuild' targetPath: $(Agent.TempDirectory) + - task: DownloadPipelineArtifact@2 + inputs: + artifactName: 'ManyLinuxBuildArm64' + targetPath: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 inputs: artifactName: 'macOsBuild' @@ -469,14 +506,16 @@ stages: targetPath: $(Agent.TempDirectory) - script: cd $(Agent.TempDirectory); mkdir osx-x64-bin; cd osx-x64-bin; unzip ../*x64-osx*.zip - script: cd $(Agent.TempDirectory); mkdir osx-arm64-bin; cd osx-arm64-bin; unzip ../*arm64-osx*.zip - - script: cd $(Agent.TempDirectory); mkdir libc-bin; cd libc-bin; unzip ../*glibc*.zip + - script: cd $(Agent.TempDirectory); mkdir libc-x64-bin; cd libc-x64-bin; unzip ../*x64-glibc*.zip + - script: cd $(Agent.TempDirectory); mkdir libc-arm64-bin; cd libc-arm64-bin; unzip ../*arm64-glibc*.zip # - script: cd $(Agent.TempDirectory); mkdir musl-bin; cd musl-bin; unzip ../*-linux.zip - script: cd $(Agent.TempDirectory); mkdir win32-bin; cd win32-bin; unzip ../*x86-win*.zip - script: cd $(Agent.TempDirectory); mkdir win64-bin; cd win64-bin; unzip ../*x64-win*.zip - script: python3 -m pip install --user -U setuptools wheel - script: cd src/api/python; python3 setup.py sdist # take a look at this PREMIUM HACK I came up with to get around the fact that the azure variable syntax overloads the bash syntax for subshells - - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel + - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-x64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel + - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-arm64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel # - script: cd src/api/python; echo $(Agent.TempDirectory)/musl-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - script: cd src/api/python; echo $(Agent.TempDirectory)/win32-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - script: cd src/api/python; echo $(Agent.TempDirectory)/win64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel diff --git a/scripts/release.yml b/scripts/release.yml index 7a2c4ef41..8f1242f82 100644 --- a/scripts/release.yml +++ b/scripts/release.yml @@ -6,7 +6,7 @@ trigger: none variables: - ReleaseVersion: '4.12.6' + ReleaseVersion: '4.13.1' stages: @@ -199,6 +199,39 @@ stages: artifactName: 'ManyLinuxBuild' targetPath: $(Build.ArtifactStagingDirectory) + - job: LinuxBuildsArm64 + displayName: "ManyLinux ARM64 build" + variables: + name: ManyLinux + python: "/opt/python/cp37-cp37m/bin/python" + pool: + vmImage: "ubuntu-latest" + container: "quay.io/pypa/manylinux2014_x86_64:latest" + steps: + - script: curl -L -o /tmp/arm-toolchain.tar.xz 'https://developer.arm.com/-/media/Files/downloads/gnu/11.2-2022.02/binrel/gcc-arm-11.2-2022.02-x86_64-aarch64-none-linux-gnu.tar.xz?rev=33c6e30e5ac64e6dba8f0431f2c35f1b&hash=9918A05BF47621B632C7A5C8D2BB438FB80A4480' + - script: mkdir -p /tmp/arm-toolchain/ + - script: tar xf /tmp/arm-toolchain.tar.xz -C /tmp/arm-toolchain/ --strip-components=1 + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/bin' + - script: echo '##vso[task.prependpath]/tmp/arm-toolchain/aarch64-none-linux-gnu/libc/usr/bin' + - script: echo $PATH + - script: stat /tmp/arm-toolchain/bin/aarch64-none-linux-gnu-gcc + - task: PythonScript@0 + displayName: Build + inputs: + scriptSource: 'filepath' + scriptPath: scripts/mk_unix_dist.py + arguments: --nodotnet --nojava --arch=arm64 + pythonInterpreter: $(python) + - task: CopyFiles@2 + inputs: + sourceFolder: dist + contents: '*.zip' + targetFolder: $(Build.ArtifactStagingDirectory) + - task: PublishPipelineArtifact@0 + inputs: + artifactName: 'ManyLinuxBuildArm64' + targetPath: $(Build.ArtifactStagingDirectory) + - template: build-win-signed.yml parameters: ReleaseVersion: $(ReleaseVersion) @@ -458,6 +491,11 @@ stages: inputs: artifact: 'ManyLinuxBuild' path: $(Agent.TempDirectory) + - task: DownloadPipelineArtifact@2 + displayName: 'Download ManyLinux Arm64 Build' + inputs: + artifact: 'ManyLinuxBuildArm64' + path: $(Agent.TempDirectory) - task: DownloadPipelineArtifact@2 displayName: 'Download Win32 Build' inputs: @@ -470,7 +508,8 @@ stages: path: $(Agent.TempDirectory) - script: cd $(Agent.TempDirectory); mkdir osx-x64-bin; cd osx-x64-bin; unzip ../*x64-osx*.zip - script: cd $(Agent.TempDirectory); mkdir osx-arm64-bin; cd osx-arm64-bin; unzip ../*arm64-osx*.zip - - script: cd $(Agent.TempDirectory); mkdir libc-bin; cd libc-bin; unzip ../*glibc*.zip + - script: cd $(Agent.TempDirectory); mkdir libc-x64-bin; cd libc-x64-bin; unzip ../*x64-glibc*.zip + - script: cd $(Agent.TempDirectory); mkdir libc-arm64-bin; cd libc-arm64-bin; unzip ../*arm64-glibc*.zip - script: cd $(Agent.TempDirectory); mkdir win32-bin; cd win32-bin; unzip ../*x86-win*.zip - script: cd $(Agent.TempDirectory); mkdir win64-bin; cd win64-bin; unzip ../*x64-win*.zip - script: python3 -m pip install --user -U setuptools wheel @@ -478,7 +517,8 @@ stages: # take a look at this PREMIUM HACK I came up with to get around the fact that the azure variable syntax overloads the bash syntax for subshells - script: cd src/api/python; echo $(Agent.TempDirectory)/osx-x64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - script: cd src/api/python; echo $(Agent.TempDirectory)/osx-arm64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel + - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-x64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel + - script: cd src/api/python; echo $(Agent.TempDirectory)/libc-arm64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - script: cd src/api/python; echo $(Agent.TempDirectory)/win32-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - script: cd src/api/python; echo $(Agent.TempDirectory)/win64-bin/* | xargs printf 'PACKAGE_FROM_RELEASE=%s\n' | xargs -I '{}' env '{}' python3 setup.py bdist_wheel - task: PublishPipelineArtifact@0 diff --git a/src/api/api_util.h b/src/api/api_util.h index 0ff2c8ddd..174d75144 100644 --- a/src/api/api_util.h +++ b/src/api/api_util.h @@ -110,6 +110,7 @@ inline param_descrs * to_param_descrs_ptr(Z3_param_descrs p) { return p == nullp Z3_TRY; \ RESET_ERROR_CODE(); \ EXTRA_CODE; \ + CHECK_IS_EXPR(n, nullptr); \ expr * _n = to_expr(n); \ ast* a = mk_c(c)->m().mk_app(FID, OP, 0, 0, 1, &_n); \ mk_c(c)->save_ast_trail(a); \ @@ -127,6 +128,8 @@ Z3_ast Z3_API NAME(Z3_context c, Z3_ast n) { \ Z3_TRY; \ RESET_ERROR_CODE(); \ EXTRA_CODE; \ + CHECK_IS_EXPR(n1, nullptr); \ + CHECK_IS_EXPR(n2, nullptr); \ expr * args[2] = { to_expr(n1), to_expr(n2) }; \ ast* a = mk_c(c)->m().mk_app(FID, OP, 0, 0, 2, args); \ mk_c(c)->save_ast_trail(a); \ diff --git a/src/api/ml/z3.ml b/src/api/ml/z3.ml index 166e474d9..0be6e57a0 100644 --- a/src/api/ml/z3.ml +++ b/src/api/ml/z3.ml @@ -8,7 +8,7 @@ open Z3enums exception Error of string -let _ = Callback.register_exception "Z3EXCEPTION" (Error "") +let () = Callback.register_exception "Z3EXCEPTION" (Error "") type context = Z3native.context @@ -27,22 +27,9 @@ struct let full_version : string = Z3native.get_full_version() - let to_string = - string_of_int major ^ "." ^ - string_of_int minor ^ "." ^ - string_of_int build ^ "." ^ - string_of_int revision + let to_string = Printf.sprintf "%d.%d.%d.%d" major minor build revision end -let mk_list f n = - let rec mk_list' i accu = - if i >= n then - List.rev accu - else - mk_list' (i + 1) ((f i)::accu) - in - mk_list' 0 [] - let check_int32 v = v = Int32.to_int (Int32.of_int v) let mk_int_expr ctx v ty = @@ -68,7 +55,7 @@ let interrupt (ctx:context) = module Symbol = struct type symbol = Z3native.symbol - let gc = Z3native.context_of_symbol + let gc s = Z3native.context_of_symbol s let kind o = symbol_kind_of_int (Z3native.get_symbol_kind (gc o) o) let is_int_symbol o = kind o = INT_SYMBOL @@ -80,8 +67,8 @@ struct | INT_SYMBOL -> string_of_int (Z3native.get_symbol_int (gc o) o) | STRING_SYMBOL -> Z3native.get_symbol_string (gc o) o - let mk_int = Z3native.mk_int_symbol - let mk_string = Z3native.mk_string_symbol + let mk_int ctx = Z3native.mk_int_symbol ctx + let mk_string ctx s = Z3native.mk_string_symbol ctx s let mk_ints ctx names = List.map (mk_int ctx) names let mk_strings ctx names = List.map (mk_string ctx) names @@ -135,12 +122,12 @@ sig val translate : ast -> context -> ast end = struct type ast = Z3native.ast - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a module ASTVector = struct type ast_vector = Z3native.ast_vector - let gc = Z3native.context_of_ast_vector + let gc v = Z3native.context_of_ast_vector v let mk_ast_vector = Z3native.mk_ast_vector let get_size (x:ast_vector) = Z3native.ast_vector_size (gc x) x @@ -153,12 +140,12 @@ end = struct let to_list (x:ast_vector) = let xs = get_size x in let f i = get x i in - mk_list f xs + List.init xs f let to_expr_list (x:ast_vector) = let xs = get_size x in let f i = get x i in - mk_list f xs + List.init xs f let to_string x = Z3native.ast_vector_to_string (gc x) x end @@ -166,7 +153,7 @@ end = struct module ASTMap = struct type ast_map = Z3native.ast_map - let gc = Z3native.context_of_ast_map + let gc m = Z3native.context_of_ast_map m let mk_ast_map = Z3native.mk_ast_map let contains (x:ast_map) (key:ast) = Z3native.ast_map_contains (gc x) x key @@ -231,7 +218,7 @@ sig val mk_uninterpreted_s : context -> string -> sort end = struct type sort = Z3native.sort - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a let equal a b = (a = b) || (gc a = gc b && Z3native.is_eq_sort (gc a) a b) @@ -239,7 +226,7 @@ end = struct let get_sort_kind (x:sort) = sort_kind_of_int (Z3native.get_sort_kind (gc x) x) let get_name (x:sort) = Z3native.get_sort_name (gc x) x let to_string (x:sort) = Z3native.sort_to_string (gc x) x - let mk_uninterpreted = Z3native.mk_uninterpreted_sort + let mk_uninterpreted ctx s = Z3native.mk_uninterpreted_sort ctx s let mk_uninterpreted_s (ctx:context) (s:string) = mk_uninterpreted ctx (Symbol.mk_string ctx s) end @@ -290,7 +277,7 @@ sig val apply : func_decl -> Expr.expr list -> Expr.expr end = struct type func_decl = AST.ast - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a module Parameter = struct @@ -378,7 +365,7 @@ end = struct let get_domain (x:func_decl) = let n = get_domain_size x in let f i = Z3native.get_domain (gc x) x i in - mk_list f n + List.init n f let get_range (x:func_decl) = Z3native.get_range (gc x) x let get_decl_kind (x:func_decl) = decl_kind_of_int (Z3native.get_decl_kind (gc x) x) @@ -397,7 +384,7 @@ end = struct | PARAMETER_FUNC_DECL -> Parameter.P_Fdl (Z3native.get_decl_func_decl_parameter (gc x) x i) | PARAMETER_RATIONAL -> Parameter.P_Rat (Z3native.get_decl_rational_parameter (gc x) x i) in - mk_list f n + List.init n f let apply (x:func_decl) (args:Expr.expr list) = Expr.expr_of_func_app (gc x) x args end @@ -426,12 +413,12 @@ sig val set_print_mode : context -> Z3enums.ast_print_mode -> unit end = struct type params = Z3native.params - let gc = Z3native.context_of_params + let gc p = Z3native.context_of_params p module ParamDescrs = struct type param_descrs = Z3native.param_descrs - let gc = Z3native.context_of_param_descrs + let gc p = Z3native.context_of_param_descrs p let validate (x:param_descrs) (p:params) = Z3native.params_validate (gc x) p x let get_kind (x:param_descrs) (name:Symbol.symbol) = param_kind_of_int (Z3native.param_descrs_get_kind (gc x) x name) @@ -439,7 +426,7 @@ end = struct let get_names (x:param_descrs) = let n = Z3native.param_descrs_size (gc x) x in let f i = Z3native.param_descrs_get_name (gc x) x i in - mk_list f n + List.init n f let get_size (x:param_descrs) = Z3native.param_descrs_size (gc x) x let to_string (x:param_descrs) = Z3native.param_descrs_to_string (gc x) x @@ -491,7 +478,7 @@ sig val compare : expr -> expr -> int end = struct type expr = AST.ast - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a let expr_of_ast a = let q = Z3enums.ast_kind_of_int (Z3native.get_ast_kind (gc a) a) in @@ -517,7 +504,7 @@ end = struct let get_args (x:expr) = let n = get_num_args x in let f i = Z3native.get_app_arg (gc x) x i in - mk_list f n + List.init n f let update (x:expr) (args:expr list) = if AST.is_app x && List.length args <> get_num_args x then @@ -567,11 +554,11 @@ open Expr module Boolean = struct - let mk_sort = Z3native.mk_bool_sort + let mk_sort ctx = Z3native.mk_bool_sort ctx let mk_const (ctx:context) (name:Symbol.symbol) = Expr.mk_const ctx name (mk_sort ctx) let mk_const_s (ctx:context) (name:string) = mk_const ctx (Symbol.mk_string ctx name) - let mk_true = Z3native.mk_true - let mk_false = Z3native.mk_false + let mk_true ctx = Z3native.mk_true ctx + let mk_false ctx = Z3native.mk_false ctx let mk_val (ctx:context) (value:bool) = if value then mk_true ctx else mk_false ctx let mk_not = Z3native.mk_not let mk_ite = Z3native.mk_ite @@ -609,7 +596,7 @@ end module Quantifier = struct type quantifier = AST.ast - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a let expr_of_quantifier q = q @@ -623,14 +610,14 @@ struct module Pattern = struct type pattern = Z3native.pattern - let gc = Z3native.context_of_ast + let gc a = Z3native.context_of_ast a let get_num_terms x = Z3native.get_pattern_num_terms (gc x) x let get_terms x = let n = get_num_terms x in let f i = Z3native.get_pattern (gc x) x i in - mk_list f n + List.init n f let to_string x = Z3native.pattern_to_string (gc x) x end @@ -648,26 +635,26 @@ struct let get_patterns x = let n = get_num_patterns x in let f i = Z3native.get_quantifier_pattern_ast (gc x) x i in - mk_list f n + List.init n f let get_num_no_patterns x = Z3native.get_quantifier_num_no_patterns (gc x) x let get_no_patterns x = let n = get_num_patterns x in let f i = Z3native.get_quantifier_no_pattern_ast (gc x) x i in - mk_list f n + List.init n f let get_num_bound x = Z3native.get_quantifier_num_bound (gc x) x let get_bound_variable_names x = let n = get_num_bound x in let f i = Z3native.get_quantifier_bound_name (gc x) x i in - mk_list f n + List.init n f let get_bound_variable_sorts x = let n = get_num_bound x in let f i = Z3native.get_quantifier_bound_sort (gc x) x i in - mk_list f n + List.init n f let get_body x = Z3native.get_quantifier_body (gc x) x let mk_bound = Z3native.mk_bound @@ -746,7 +733,7 @@ end module Z3Array = struct - let mk_sort = Z3native.mk_array_sort + let mk_sort ctx domain range = Z3native.mk_array_sort ctx domain range let is_store x = AST.is_app x && FuncDecl.get_decl_kind (Expr.get_func_decl x) = OP_STORE let is_select x = AST.is_app x && FuncDecl.get_decl_kind (Expr.get_func_decl x) = OP_SELECT let is_constant_array x = AST.is_app x && FuncDecl.get_decl_kind (Expr.get_func_decl x) = OP_CONST_ARRAY @@ -806,7 +793,7 @@ end module FiniteDomain = struct - let mk_sort = Z3native.mk_finite_domain_sort + let mk_sort ctx s size = Z3native.mk_finite_domain_sort ctx s size let mk_sort_s ctx name size = mk_sort ctx (Symbol.mk_string ctx name) size let is_finite_domain (x:expr) = @@ -849,7 +836,7 @@ struct let get_column_sorts (x:Sort.sort) = let n = get_arity x in let f i = Z3native.get_relation_column (Sort.gc x) x i in - mk_list f n + List.init n f end @@ -932,12 +919,12 @@ struct let get_constructors (x:Sort.sort) = let n = get_num_constructors x in let f i = Z3native.get_datatype_sort_constructor (Sort.gc x) x i in - mk_list f n + List.init n f let get_recognizers (x:Sort.sort) = let n = (get_num_constructors x) in let f i = Z3native.get_datatype_sort_recognizer (Sort.gc x) x i in - mk_list f n + List.init n f let get_accessors (x:Sort.sort) = let n = (get_num_constructors x) in @@ -945,8 +932,8 @@ struct let fd = Z3native.get_datatype_sort_constructor (Sort.gc x) x i in let ds = Z3native.get_domain_size (FuncDecl.gc fd) fd in let g j = Z3native.get_datatype_sort_constructor_accessor (Sort.gc x) x i j in - mk_list g ds) in - mk_list f n + List.init ds g) in + List.init n f end @@ -962,21 +949,21 @@ struct let get_const_decls (x:Sort.sort) = let n = Z3native.get_datatype_sort_num_constructors (Sort.gc x) x in let f i = Z3native.get_datatype_sort_constructor (Sort.gc x) x i in - mk_list f n + List.init n f let get_const_decl (x:Sort.sort) (inx:int) = Z3native.get_datatype_sort_constructor (Sort.gc x) x inx let get_consts (x:Sort.sort) = let n = Z3native.get_datatype_sort_num_constructors (Sort.gc x) x in let f i = Expr.mk_const_f (Sort.gc x) (get_const_decl x i) in - mk_list f n + List.init n f let get_const (x:Sort.sort) (inx:int) = Expr.mk_const_f (Sort.gc x) (get_const_decl x inx) let get_tester_decls (x:Sort.sort) = let n = Z3native.get_datatype_sort_num_constructors (Sort.gc x) x in let f i = Z3native.get_datatype_sort_recognizer (Sort.gc x) x i in - mk_list f n + List.init n f let get_tester_decl (x:Sort.sort) (inx:int) = Z3native.get_datatype_sort_recognizer (Sort.gc x) x inx end @@ -1010,8 +997,8 @@ struct let get_field_decls (x:Sort.sort) = let n = get_num_fields x in - let f i =Z3native.get_tuple_sort_field_decl (Sort.gc x) x i in - mk_list f n + let f i = Z3native.get_tuple_sort_field_decl (Sort.gc x) x i in + List.init n f end @@ -1043,7 +1030,7 @@ struct module Integer = struct - let mk_sort = Z3native.mk_int_sort + let mk_sort ctx = Z3native.mk_int_sort ctx let get_int x = match Z3native.get_numeral_int (Expr.gc x) x with @@ -1070,7 +1057,7 @@ struct module Real = struct - let mk_sort = Z3native.mk_real_sort + let mk_sort ctx = Z3native.mk_real_sort ctx let get_numerator x = Z3native.get_numerator (Expr.gc x) x let get_denominator x = Z3native.get_denominator (Expr.gc x) x @@ -1467,7 +1454,7 @@ end module Goal = struct type goal = Z3native.goal - let gc = Z3native.context_of_goal + let gc g = Z3native.context_of_goal g let get_precision (x:goal) = goal_prec_of_int (Z3native.goal_precision (gc x) x) let is_precise (x:goal) = (get_precision x) = GOAL_PRECISE @@ -1486,7 +1473,7 @@ struct let get_formulas (x:goal) = let n = get_size x in let f i = Z3native.goal_formula (gc x) x i in - mk_list f n + List.init n f let get_num_exprs (x:goal) = Z3native.goal_num_exprs (gc x) x let is_decided_sat (x:goal) = Z3native.goal_is_decided_sat (gc x) x @@ -1527,17 +1514,17 @@ end module Model = struct type model = Z3native.model - let gc = Z3native.context_of_model + let gc m = Z3native.context_of_model m module FuncInterp = struct type func_interp = Z3native.func_interp - let gc = Z3native.context_of_func_interp + let gc f = Z3native.context_of_func_interp f module FuncEntry = struct type func_entry = Z3native.func_entry - let gc = Z3native.context_of_func_entry + let gc f = Z3native.context_of_func_entry f let get_value (x:func_entry) = Z3native.func_entry_get_value (gc x) x let get_num_args (x:func_entry) = Z3native.func_entry_get_num_args (gc x) x @@ -1545,7 +1532,7 @@ struct let get_args (x:func_entry) = let n = get_num_args x in let f i = Z3native.func_entry_get_arg (gc x) x i in - mk_list f n + List.init n f let to_string (x:func_entry) = let a = get_args x in @@ -1558,7 +1545,7 @@ struct let get_entries (x:func_interp) = let n = get_num_entries x in let f i = Z3native.func_interp_get_entry (gc x) x i in - mk_list f n + List.init n f let get_else (x:func_interp) = Z3native.func_interp_get_else (gc x) x @@ -1614,21 +1601,24 @@ struct let get_const_decls (x:model) = let n = (get_num_consts x) in let f i = Z3native.model_get_const_decl (gc x) x i in - mk_list f n + List.init n f let get_num_funcs (x:model) = Z3native.model_get_num_funcs (gc x) x let get_func_decls (x:model) = let n = (get_num_funcs x) in let f i = Z3native.model_get_func_decl (gc x) x i in - mk_list f n + List.init n f let get_decls (x:model) = let n_funcs = get_num_funcs x in let n_consts = get_num_consts x in let f i = Z3native.model_get_func_decl (gc x) x i in let g i = Z3native.model_get_const_decl (gc x) x i in - (mk_list f n_funcs) @ (mk_list g n_consts) + List.init (n_funcs + n_consts) (fun i -> + if i < n_funcs then f i + else g i + ) let eval (x:model) (t:expr) (completion:bool) = match Z3native.model_eval (gc x) x t completion with @@ -1641,7 +1631,7 @@ struct let get_sorts (x:model) = let n = get_num_sorts x in let f i = Z3native.model_get_sort (gc x) x i in - mk_list f n + List.init n f let sort_universe (x:model) (s:Sort.sort) = let av = Z3native.model_get_sort_universe (gc x) x s in @@ -1656,12 +1646,12 @@ struct type probe = Z3native.probe let apply (x:probe) (g:Goal.goal) = Z3native.probe_apply (gc x) x g - let get_num_probes = Z3native.get_num_probes + let get_num_probes ctx = Z3native.get_num_probes ctx let get_probe_names (ctx:context) = let n = get_num_probes ctx in let f i = Z3native.get_probe_name ctx i in - mk_list f n + List.init n f let get_probe_description = Z3native.probe_get_descr let mk_probe = Z3native.mk_probe @@ -1680,19 +1670,19 @@ end module Tactic = struct type tactic = Z3native.tactic - let gc = Z3native.context_of_tactic + let gc t = Z3native.context_of_tactic t module ApplyResult = struct type apply_result = Z3native.apply_result - let gc = Z3native.context_of_apply_result + let gc a = Z3native.context_of_apply_result a let get_num_subgoals (x:apply_result) = Z3native.apply_result_get_num_subgoals (gc x) x let get_subgoals (x:apply_result) = let n = get_num_subgoals x in let f i = Z3native.apply_result_get_subgoal (gc x) x i in - mk_list f n + List.init n f let get_subgoal (x:apply_result) (i:int) = Z3native.apply_result_get_subgoal (gc x) x i let to_string (x:apply_result) = Z3native.apply_result_to_string (gc x) x @@ -1706,23 +1696,26 @@ struct | None -> Z3native.tactic_apply (gc x) x g | Some pn -> Z3native.tactic_apply_ex (gc x) x g pn - let get_num_tactics = Z3native.get_num_tactics + let get_num_tactics ctx = Z3native.get_num_tactics ctx let get_tactic_names (ctx:context) = let n = get_num_tactics ctx in let f i = Z3native.get_tactic_name ctx i in - mk_list f n + List.init n f let get_tactic_description = Z3native.tactic_get_descr let mk_tactic = Z3native.mk_tactic let and_then (ctx:context) (t1:tactic) (t2:tactic) (ts:tactic list) = - let f p c = (match p with - | None -> Some c - | Some(x) -> Some (Z3native.tactic_and_then ctx c x)) in - match (List.fold_left f None ts) with + let f p c = + match p with + | None -> Some c + | Some x -> Some (Z3native.tactic_and_then ctx c x) + in + match List.fold_left f None ts with | None -> Z3native.tactic_and_then ctx t1 t2 - | Some(x) -> let o = Z3native.tactic_and_then ctx t2 x in + | Some x -> + let o = Z3native.tactic_and_then ctx t2 x in Z3native.tactic_and_then ctx t1 o let or_else = Z3native.tactic_or_else @@ -1744,18 +1737,18 @@ end module Simplifier = struct type simplifier = Z3native.simplifier - let gc = Z3native.context_of_simplifier + let gc s = Z3native.context_of_simplifier s let get_help (x:simplifier) = Z3native.simplifier_get_help (gc x) x let get_param_descrs (x:simplifier) = Z3native.simplifier_get_param_descrs (gc x) x - let get_num_simplifiers = Z3native.get_num_simplifiers + let get_num_simplifiers ctx = Z3native.get_num_simplifiers ctx let get_simplifier_names (ctx:context) = let n = get_num_simplifiers ctx in let f i = Z3native.get_simplifier_name ctx i in - mk_list f n + List.init n f let get_simplifier_description = Z3native.simplifier_get_descr @@ -1778,7 +1771,7 @@ end module Statistics = struct type statistics = Z3native.stats - let gc = Z3native.context_of_stats + let gc s = Z3native.context_of_stats s module Entry = struct @@ -1822,12 +1815,12 @@ struct else Entry.create_sd k (Z3native.stats_get_double_value (gc x) x i) in - mk_list f n + List.init n f let get_keys (x:statistics) = let n = get_size x in let f i = Z3native.stats_get_key (gc x) x i in - mk_list f n + List.init n f let get (x:statistics) (key:string) = try Some(List.find (fun c -> Entry.get_key c = key) (get_entries x)) with @@ -1842,7 +1835,7 @@ module Solver = struct type solver = Z3native.solver type status = UNSATISFIABLE | UNKNOWN | SATISFIABLE - let gc = Z3native.context_of_solver + let gc s = Z3native.context_of_solver s let string_of_status (s:status) = match s with | UNSATISFIABLE -> "unsatisfiable" @@ -1923,7 +1916,7 @@ end module Fixedpoint = struct type fixedpoint = Z3native.fixedpoint - let gc = Z3native.context_of_fixedpoint + let gc f = Z3native.context_of_fixedpoint f let get_help x = Z3native.fixedpoint_get_help (gc x) x let set_parameters x = Z3native.fixedpoint_set_params (gc x) x @@ -2055,22 +2048,22 @@ struct formula let parse_smtlib2_string (ctx:context) (str:string) (sort_names:Symbol.symbol list) (sorts:Sort.sort list) (decl_names:Symbol.symbol list) (decls:func_decl list) = - let csn = List.length sort_names in let cs = List.length sorts in - let cdn = List.length decl_names in let cd = List.length decls in - if csn <> cs || cdn <> cd then + if List.compare_length_with sort_names cs <> 0 + || List.compare_length_with decl_names cd <> 0 + then raise (Error "Argument size mismatch") else Z3native.parse_smtlib2_string ctx str cs sort_names sorts cd decl_names decls let parse_smtlib2_file (ctx:context) (file_name:string) (sort_names:Symbol.symbol list) (sorts:Sort.sort list) (decl_names:Symbol.symbol list) (decls:func_decl list) = - let csn = List.length sort_names in let cs = List.length sorts in - let cdn = List.length decl_names in let cd = List.length decls in - if csn <> cs || cdn <> cd then + if List.compare_length_with sort_names cs <> 0 + || List.compare_length_with decl_names cd <> 0 + then raise (Error "Argument size mismatch") else Z3native.parse_smtlib2_file ctx file_name @@ -2082,7 +2075,7 @@ module RCF = struct type rcf_num = Z3native.rcf_num - let del (ctx:context) (a:rcf_num) = Z3native.rcf_del ctx a + let del (ctx:context) (a:rcf_num) : unit = Z3native.rcf_del ctx a let del_list (ctx:context) (ns:rcf_num list) = List.iter (fun a -> Z3native.rcf_del ctx a) ns let mk_rational (ctx:context) (v:string) = Z3native.rcf_mk_rational ctx v let mk_small_int (ctx:context) (v:int) = Z3native.rcf_mk_small_int ctx v @@ -2093,7 +2086,14 @@ struct let mk_roots (ctx:context) (a:rcf_num list) = let n, r = Z3native.rcf_mk_roots ctx (List.length a) a in - List.init n (fun x -> List.nth r x) + let _i, l = + (* keep only the first `n` elements of the list `r` *) + List.fold_left (fun (i, acc) x -> + if i = 0 then i, acc + else (i - 1, x :: acc) + ) (n, []) r + in + List.rev l let add (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_add ctx a b let sub (ctx:context) (a:rcf_num) (b:rcf_num) = Z3native.rcf_sub ctx a b diff --git a/src/api/ml/z3native.ml.pre b/src/api/ml/z3native.ml.pre index 1d75d5d1e..fe4e8a194 100644 --- a/src/api/ml/z3native.ml.pre +++ b/src/api/ml/z3native.ml.pre @@ -4,36 +4,36 @@ open Z3enums (**/**) type ptr -and symbol = ptr -and config = ptr -and context = ptr -and ast = ptr -and app = ast -and sort = ast -and func_decl = ast -and pattern = ast -and model = ptr -and literals = ptr -and constructor = ptr -and constructor_list = ptr -and solver = ptr -and solver_callback = ptr -and goal = ptr -and tactic = ptr -and simplifier = ptr -and params = ptr -and parser_context = ptr -and probe = ptr -and stats = ptr -and ast_vector = ptr -and ast_map = ptr -and apply_result = ptr -and func_interp = ptr -and func_entry = ptr -and fixedpoint = ptr -and optimize = ptr -and param_descrs = ptr -and rcf_num = ptr +type symbol = ptr +type config = ptr +type context = ptr +type ast = ptr +type app = ast +type sort = ast +type func_decl = ast +type pattern = ast +type model = ptr +type literals = ptr +type constructor = ptr +type constructor_list = ptr +type solver = ptr +type solver_callback = ptr +type goal = ptr +type tactic = ptr +type simplifier = ptr +type params = ptr +type parser_context = ptr +type probe = ptr +type stats = ptr +type ast_vector = ptr +type ast_map = ptr +type apply_result = ptr +type func_interp = ptr +type func_entry = ptr +type fixedpoint = ptr +type optimize = ptr +type param_descrs = ptr +type rcf_num = ptr external set_internal_error_handler : ptr -> unit = "n_set_internal_error_handler" diff --git a/src/api/python/setup.py b/src/api/python/setup.py index c3f65f848..5faf5aad1 100644 --- a/src/api/python/setup.py +++ b/src/api/python/setup.py @@ -297,6 +297,15 @@ if 'bdist_wheel' in sys.argv and '--plat-name' not in sys.argv: elif distos == 'glibc': if arch == 'x64': plat_name = 'manylinux2014_x86_64' + elif arch == 'arm64' or arch == 'aarch64': + # context on why are we match on arm64 + # but use aarch64 on the plat_name is + # due to a workaround current python + # legacy build doesn't support aarch64 + # so using the currently supported arm64 + # build and simply rename it to aarch64 + # see full context on #7148 + plat_name = 'manylinux2014_aarch64' else: plat_name = 'manylinux2014_i686' elif distos == 'linux' and os_id == 'alpine': diff --git a/src/api/python/z3/z3.py b/src/api/python/z3/z3.py index 16db39afd..9a3dadda2 100644 --- a/src/api/python/z3/z3.py +++ b/src/api/python/z3/z3.py @@ -5445,10 +5445,10 @@ def EnumSort(name, values, ctx=None): num = len(values) _val_names = (Symbol * num)() for i in range(num): - _val_names[i] = to_symbol(values[i]) + _val_names[i] = to_symbol(values[i], ctx) _values = (FuncDecl * num)() _testers = (FuncDecl * num)() - name = to_symbol(name) + name = to_symbol(name, ctx) S = DatatypeSortRef(Z3_mk_enumeration_sort(ctx.ref(), name, num, _val_names, _values, _testers), ctx) V = [] for i in range(num): diff --git a/src/ast/ast_smt_pp.cpp b/src/ast/ast_smt_pp.cpp index bea669438..0da4f1c12 100644 --- a/src/ast/ast_smt_pp.cpp +++ b/src/ast/ast_smt_pp.cpp @@ -986,7 +986,7 @@ void ast_smt_pp::display_smt2(std::ostream& strm, expr* n) { ast_mark sort_mark; for (sort* s : decls.get_sorts()) { if (!(*m_is_declared)(s)) { - smt_printer p(strm, m, ql, rn, m_logic, true, true, m_simplify_implies, 0); + smt_printer p(strm, m, ql, rn, m_logic, true, m_simplify_implies, 0); p.pp_sort_decl(sort_mark, s); } } @@ -994,7 +994,7 @@ void ast_smt_pp::display_smt2(std::ostream& strm, expr* n) { for (unsigned i = 0; i < decls.get_num_decls(); ++i) { func_decl* d = decls.get_func_decls()[i]; if (!(*m_is_declared)(d)) { - smt_printer p(strm, m, ql, rn, m_logic, true, true, m_simplify_implies, 0); + smt_printer p(strm, m, ql, rn, m_logic, true, m_simplify_implies, 0); p(d); strm << "\n"; } @@ -1003,20 +1003,20 @@ void ast_smt_pp::display_smt2(std::ostream& strm, expr* n) { #endif for (expr* a : m_assumptions) { - smt_printer p(strm, m, ql, rn, m_logic, false, true, m_simplify_implies, 1); + smt_printer p(strm, m, ql, rn, m_logic, false, m_simplify_implies, 1); strm << "(assert\n "; p(a); strm << ")\n"; } for (expr* a : m_assumptions_star) { - smt_printer p(strm, m, ql, rn, m_logic, false, true, m_simplify_implies, 1); + smt_printer p(strm, m, ql, rn, m_logic, false, m_simplify_implies, 1); strm << "(assert\n "; p(a); strm << ")\n"; } - smt_printer p(strm, m, ql, rn, m_logic, false, true, m_simplify_implies, 0); + smt_printer p(strm, m, ql, rn, m_logic, false, m_simplify_implies, 0); if (m.is_bool(n)) { if (!m.is_true(n)) { strm << "(assert\n "; diff --git a/src/ast/bv_decl_plugin.cpp b/src/ast/bv_decl_plugin.cpp index 327e280cf..5dd9f6080 100644 --- a/src/ast/bv_decl_plugin.cpp +++ b/src/ast/bv_decl_plugin.cpp @@ -942,3 +942,13 @@ app* bv_util::mk_int2bv(unsigned sz, expr* e) { parameter p(sz); return m_manager.mk_app(get_fid(), OP_INT2BV, 1, &p, 1, &e); } + +app* bv_util::mk_bv_rotate_left(expr* arg, unsigned n) { + parameter p(n); + return m_manager.mk_app(get_fid(), OP_ROTATE_LEFT, 1, &p, 1, &arg); +} + +app* bv_util::mk_bv_rotate_right(expr* arg, unsigned n) { + parameter p(n); + return m_manager.mk_app(get_fid(), OP_ROTATE_RIGHT, 1, &p, 1, &arg); +} \ No newline at end of file diff --git a/src/ast/bv_decl_plugin.h b/src/ast/bv_decl_plugin.h index 89588ee0e..137dc754f 100644 --- a/src/ast/bv_decl_plugin.h +++ b/src/ast/bv_decl_plugin.h @@ -445,6 +445,11 @@ public: MATCH_BINARY(is_bv_sdivi); MATCH_BINARY(is_bv_udivi); MATCH_BINARY(is_bv_smodi); + MATCH_BINARY(is_bv_urem0); + MATCH_BINARY(is_bv_srem0); + MATCH_BINARY(is_bv_sdiv0); + MATCH_BINARY(is_bv_udiv0); + MATCH_BINARY(is_bv_smod0); MATCH_UNARY(is_bit2bool); MATCH_UNARY(is_int2bv); bool is_bit2bool(expr* e, expr*& bv, unsigned& idx) const; @@ -546,16 +551,21 @@ public: app * mk_bv2int(expr* e); app * mk_int2bv(unsigned sz, expr* e); + app* mk_bv_rotate_left(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_EXT_ROTATE_LEFT, arg1, arg2); } + app* mk_bv_rotate_right(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_EXT_ROTATE_RIGHT, arg1, arg2); } + app* mk_bv_rotate_left(expr* arg, unsigned n); + app* mk_bv_rotate_right(expr* arg, unsigned n); + // TODO: all these binary ops commute (right?) but it'd be more logical to swap `n` & `m` in the `return` - app * mk_bvsmul_no_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_NO_OVFL, n, m); } - app * mk_bvsmul_no_udfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_NO_UDFL, n, m); } - app * mk_bvumul_no_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUMUL_NO_OVFL, n, m); } - app * mk_bvsmul_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_OVFL, n, m); } - app * mk_bvumul_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUMUL_OVFL, n, m); } + app * mk_bvsmul_no_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_NO_OVFL, m, n); } + app * mk_bvsmul_no_udfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_NO_UDFL, m, n); } + app * mk_bvumul_no_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUMUL_NO_OVFL, m, n); } + app * mk_bvsmul_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSMUL_OVFL, m, n); } + app * mk_bvumul_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUMUL_OVFL, m, n); } app * mk_bvsdiv_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSDIV_OVFL, m, n); } app * mk_bvneg_ovfl(expr* m) { return m_manager.mk_app(get_fid(), OP_BNEG_OVFL, m); } - app * mk_bvuadd_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUADD_OVFL, n, m); } - app * mk_bvsadd_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSADD_OVFL, n, m); } + app * mk_bvuadd_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUADD_OVFL, m, n); } + app * mk_bvsadd_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSADD_OVFL, m, n); } app * mk_bvusub_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BUSUB_OVFL, m, n); } app * mk_bvssub_ovfl(expr* m, expr* n) { return m_manager.mk_app(get_fid(), OP_BSSUB_OVFL, m, n); } diff --git a/src/ast/macros/macro_manager.cpp b/src/ast/macros/macro_manager.cpp index bbe7f245c..b7c94b1b5 100644 --- a/src/ast/macros/macro_manager.cpp +++ b/src/ast/macros/macro_manager.cpp @@ -319,14 +319,21 @@ struct macro_manager::macro_expander_cfg : public default_rewriter_cfg { if (m.proofs_enabled()) { expr_ref instance = s(q->get_expr(), num, subst_args.data()); expr* eq, * lhs, * rhs; + + expr* q_inst = m.mk_or(m.mk_not(q), instance); + proof * qi_pr = m.mk_quant_inst(q_inst, num, subst_args.data()); if (m.is_not(instance, eq) && m.is_eq(eq, lhs, rhs)) { + expr_ref instance2(m); if (revert) - instance = m.mk_eq(m.mk_not(lhs), rhs); + instance2 = m.mk_eq(m.mk_not(lhs), rhs); else - instance = m.mk_eq(lhs, m.mk_not(rhs)); + instance2 = m.mk_eq(lhs, m.mk_not(rhs)); + expr* q_inst2 = m.mk_or(m.mk_not(q), instance2); + proof* eq_pr = m.mk_rewrite(q_inst, q_inst2); + qi_pr = m.mk_modus_ponens(qi_pr, eq_pr); + instance = instance2; } SASSERT(m.is_eq(instance)); - proof * qi_pr = m.mk_quant_inst(m.mk_or(m.mk_not(q), instance), num, subst_args.data()); proof * q_pr = mm.m_decl2macro_pr.find(d); proof * prs[2] = { qi_pr, q_pr }; p = m.mk_unit_resolution(2, prs); diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index 17b803cf3..24eaec4dc 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -1,7 +1,12 @@ z3_add_component(ast_sls SOURCES bvsls_opt_engine.cpp - sls_engine.cpp + bv_sls.cpp + bv_sls_eval.cpp + bv_sls_fixed.cpp + bv_sls_terms.cpp + sls_engine.cpp + sls_valuation.cpp COMPONENT_DEPENDENCIES ast converters diff --git a/src/ast/sls/bv_sls.cpp b/src/ast/sls/bv_sls.cpp new file mode 100644 index 000000000..f80a362ba --- /dev/null +++ b/src/ast/sls/bv_sls.cpp @@ -0,0 +1,295 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls.cpp + +Abstract: + + A Stochastic Local Search (SLS) engine + Uses invertibility conditions, + interval annotations + don't care annotations + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/sls/bv_sls.h" +#include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" +#include "params/sls_params.hpp" + +namespace bv { + + sls::sls(ast_manager& m): + m(m), + bv(m), + m_terms(m), + m_eval(m) + {} + + void sls::init() { + m_terms.init(); + } + + void sls::init_eval(std::function& eval) { + m_eval.init_eval(m_terms.assertions(), eval); + m_eval.tighten_range(m_terms.assertions()); + init_repair(); + } + + void sls::init_repair() { + m_repair_down = UINT_MAX; + m_repair_up.reset(); + m_repair_roots.reset(); + for (auto* e : m_terms.assertions()) { + if (!m_eval.bval0(e)) { + m_eval.set(e, true); + m_repair_roots.insert(e->get_id()); + } + } + for (auto* t : m_terms.terms()) { + if (t && !re_eval_is_correct(t)) + m_repair_roots.insert(t->get_id()); + } + } + + void sls::init_repair_goal(app* t) { + if (m.is_bool(t)) + m_eval.set(t, m_eval.bval1(t)); + else if (bv.is_bv(t)) { + auto& v = m_eval.wval(t); + v.bits().copy_to(v.nw, v.eval); + } + } + + void sls::reinit_eval() { + std::function eval = [&](expr* e, unsigned i) { + auto should_keep = [&]() { + return m_rand() % 100 <= 92; + }; + if (m.is_bool(e)) { + if (m_eval.is_fixed0(e) || should_keep()) + return m_eval.bval0(e); + } + else if (bv.is_bv(e)) { + auto& w = m_eval.wval(e); + if (w.fixed.get(i) || should_keep()) + return w.get_bit(i); + } + return m_rand() % 2 == 0; + }; + m_eval.init_eval(m_terms.assertions(), eval); + init_repair(); + } + + std::pair sls::next_to_repair() { + app* e = nullptr; + if (m_repair_down != UINT_MAX) { + e = m_terms.term(m_repair_down); + m_repair_down = UINT_MAX; + return { true, e }; + } + + if (!m_repair_up.empty()) { + unsigned index = m_repair_up.elem_at(m_rand(m_repair_up.size())); + m_repair_up.remove(index); + e = m_terms.term(index); + return { false, e }; + } + + while (!m_repair_roots.empty()) { + unsigned index = m_repair_roots.elem_at(m_rand(m_repair_roots.size())); + e = m_terms.term(index); + if (m_terms.is_assertion(e) && !m_eval.bval1(e)) { + SASSERT(m_eval.bval0(e)); + return { true, e }; + } + if (!re_eval_is_correct(e)) { + init_repair_goal(e); + return { true, e }; + } + m_repair_roots.remove(index); + } + + return { false, nullptr }; + } + + lbool sls::search() { + // init and init_eval were invoked + unsigned n = 0; + for (; n++ < m_config.m_max_repairs && m.inc(); ) { + auto [down, e] = next_to_repair(); + if (!e) + return l_true; + + + trace_repair(down, e); + + ++m_stats.m_moves; + + if (down) + try_repair_down(e); + else + try_repair_up(e); + } + return l_undef; + } + + + lbool sls::operator()() { + lbool res = l_undef; + m_stats.reset(); + m_stats.m_restarts = 0; + do { + res = search(); + if (res != l_undef) + break; + trace(); + reinit_eval(); + } + while (m.inc() && m_stats.m_restarts++ < m_config.m_max_restarts); + + return res; + } + + void sls::try_repair_down(app* e) { + + unsigned n = e->get_num_args(); + if (n == 0) { + if (m.is_bool(e)) + m_eval.set(e, m_eval.bval1(e)); + else + VERIFY(m_eval.wval(e).commit_eval()); + + for (auto p : m_terms.parents(e)) + m_repair_up.insert(p->get_id()); + return; + } + + unsigned s = m_rand(n); + for (unsigned i = 0; i < n; ++i) { + auto j = (i + s) % n; + if (m_eval.try_repair(e, j)) { + set_repair_down(e->get_arg(j)); + return; + } + } + // search a new root / random walk to repair + } + + void sls::try_repair_up(app* e) { + + if (m_terms.is_assertion(e) || !m_eval.repair_up(e)) + m_repair_roots.insert(e->get_id()); + else { + if (!eval_is_correct(e)) { + verbose_stream() << "incorrect eval #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n"; + } + SASSERT(eval_is_correct(e)); + for (auto p : m_terms.parents(e)) + m_repair_up.insert(p->get_id()); + } + } + + bool sls::eval_is_correct(app* e) { + if (!m_eval.can_eval1(e)) + return false; + if (m.is_bool(e)) + return m_eval.bval0(e) == m_eval.bval1(e); + if (bv.is_bv(e)) { + auto const& v = m_eval.wval(e); + return v.eval == v.bits(); + } + UNREACHABLE(); + return false; + } + + + bool sls::re_eval_is_correct(app* e) { + if (!m_eval.can_eval1(e)) + return false; + if (m.is_bool(e)) + return m_eval.bval0(e) == m_eval.bval1(e); + if (bv.is_bv(e)) { + auto const& v = m_eval.eval(e); + return v.eval == v.bits(); + } + UNREACHABLE(); + return false; + } + + model_ref sls::get_model() { + model_ref mdl = alloc(model, m); + auto& terms = m_eval.sort_assertions(m_terms.assertions()); + for (expr* e : terms) { + if (!re_eval_is_correct(to_app(e))) { + verbose_stream() << "missed evaluation #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n"; + if (bv.is_bv(e)) { + auto const& v = m_eval.wval(e); + verbose_stream() << v << "\n" << v.eval << "\n"; + } + } + if (!is_uninterp_const(e)) + continue; + + auto f = to_app(e)->get_decl(); + if (m.is_bool(e)) + mdl->register_decl(f, m.mk_bool_val(m_eval.bval0(e))); + else if (bv.is_bv(e)) { + auto const& v = m_eval.wval(e); + rational n = v.get_value(); + mdl->register_decl(f, bv.mk_numeral(n, v.bw)); + } + } + terms.reset(); + return mdl; + } + + std::ostream& sls::display(std::ostream& out) { + auto& terms = m_eval.sort_assertions(m_terms.assertions()); + for (expr* e : terms) { + out << e->get_id() << ": " << mk_bounded_pp(e, m, 1) << " "; + if (m_eval.is_fixed0(e)) + out << "f "; + if (m_repair_up.contains(e->get_id())) + out << "u "; + if (m_repair_roots.contains(e->get_id())) + out << "r "; + if (bv.is_bv(e)) + out << m_eval.wval(e); + else if (m.is_bool(e)) + out << (m_eval.bval0(e)?"T":"F"); + out << "\n"; + } + terms.reset(); + return out; + } + + void sls::updt_params(params_ref const& _p) { + sls_params p(_p); + m_config.m_max_restarts = p.max_restarts(); + m_rand.set_seed(p.random_seed()); + } + + void sls::trace_repair(bool down, expr* e) { + IF_VERBOSE(20, + verbose_stream() << (down ? "d #" : "u #") + << e->get_id() << ": " + << mk_bounded_pp(e, m, 1) << " "; + if (bv.is_bv(e)) verbose_stream() << m_eval.wval(e) << " " << (m_eval.is_fixed0(e) ? "fixed " : " "); + if (m.is_bool(e)) verbose_stream() << m_eval.bval0(e) << " "; + verbose_stream() << "\n"); + } + + void sls::trace() { + IF_VERBOSE(2, verbose_stream() + << "(bvsls :restarts " << m_stats.m_restarts + << " :repair-up " << m_repair_up.size() + << " :repair-roots " << m_repair_roots.size() << ")\n"); + } +} diff --git a/src/ast/sls/bv_sls.h b/src/ast/sls/bv_sls.h new file mode 100644 index 000000000..bbcd59aea --- /dev/null +++ b/src/ast/sls/bv_sls.h @@ -0,0 +1,110 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls.h + +Abstract: + + A Stochastic Local Search (SLS) engine + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "util/lbool.h" +#include "util/params.h" +#include "util/scoped_ptr_vector.h" +#include "util/uint_set.h" +#include "ast/ast.h" +#include "ast/sls/sls_stats.h" +#include "ast/sls/sls_powers.h" +#include "ast/sls/sls_valuation.h" +#include "ast/sls/bv_sls_terms.h" +#include "ast/sls/bv_sls_eval.h" +#include "ast/bv_decl_plugin.h" +#include "model/model.h" + +namespace bv { + + + class sls { + + struct config { + unsigned m_max_restarts = 1000; + unsigned m_max_repairs = 1000; + }; + + ast_manager& m; + bv_util bv; + sls_terms m_terms; + sls_eval m_eval; + sls_stats m_stats; + indexed_uint_set m_repair_up, m_repair_roots; + unsigned m_repair_down = UINT_MAX; + ptr_vector m_todo; + random_gen m_rand; + config m_config; + + std::pair next_to_repair(); + + bool eval_is_correct(app* e); + bool re_eval_is_correct(app* e); + void init_repair_goal(app* e); + void try_repair_down(app* e); + void try_repair_up(app* e); + void set_repair_down(expr* e) { m_repair_down = e->get_id(); } + + lbool search(); + void reinit_eval(); + void init_repair(); + void trace(); + void trace_repair(bool down, expr* e); + + public: + sls(ast_manager& m); + + /** + * Add constraints + */ + void assert_expr(expr* e) { m_terms.assert_expr(e); } + + /* + * Invoke init after all expressions are asserted. + * No other expressions can be asserted after init. + */ + void init(); + + /** + * Invoke init_eval to initialize, or re-initialize, values of + * uninterpreted constants. + */ + void init_eval(std::function& eval); + + /** + * Run (bounded) local search to find feasible assignments. + */ + lbool operator()(); + + void updt_params(params_ref const& p); + void collect_statistics(statistics & st) const { m_stats.collect_statistics(st); } + void reset_statistics() { m_stats.reset(); } + + sls_stats const& get_stats() const { return m_stats; } + + std::ostream& display(std::ostream& out); + + /** + * Retrieve valuation + */ + sls_valuation const& wval(expr* e) const { return m_eval.wval(e); } + + model_ref get_model(); + + void cancel() { m.limit().cancel(); } + }; +} diff --git a/src/ast/sls/bv_sls_eval.cpp b/src/ast/sls/bv_sls_eval.cpp new file mode 100644 index 000000000..4b7bf9546 --- /dev/null +++ b/src/ast/sls/bv_sls_eval.cpp @@ -0,0 +1,1735 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls_eval.cpp + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" +#include "ast/sls/bv_sls.h" + +namespace bv { + + sls_eval::sls_eval(ast_manager& m): + m(m), + bv(m), + m_fix(*this) + {} + + void sls_eval::init_eval(expr_ref_vector const& es, std::function const& eval) { + sort_assertions(es); + for (expr* e : m_todo) { + if (!is_app(e)) + continue; + app* a = to_app(e); + if (bv.is_bv(e)) + add_bit_vector(a); + if (a->get_family_id() == basic_family_id) + init_eval_basic(a); + else if (a->get_family_id() == bv.get_family_id()) + init_eval_bv(a); + else if (is_uninterp(e)) { + if (bv.is_bv(e)) { + auto& v = wval(e); + for (unsigned i = 0; i < v.bw; ++i) + m_tmp.set(i, eval(e, i)); + v.set_repair(random_bool(), m_tmp); + } + else if (m.is_bool(e)) + m_eval.setx(e->get_id(), eval(e, 0), false); + } + else { + TRACE("sls", tout << "Unhandled expression " << mk_pp(e, m) << "\n"); + } + } + m_todo.reset(); + } + + /** + * Sort all sub-expressions by depth, smallest first. + */ + ptr_vector& sls_eval::sort_assertions(expr_ref_vector const& es) { + expr_fast_mark1 mark; + for (expr* e : es) { + if (!mark.is_marked(e)) { + mark.mark(e); + m_todo.push_back(e); + } + } + for (unsigned i = 0; i < m_todo.size(); ++i) { + auto e = m_todo[i]; + if (!is_app(e)) + continue; + for (expr* arg : *to_app(e)) { + if (!mark.is_marked(arg)) { + mark.mark(arg); + m_todo.push_back(arg); + } + } + } + std::stable_sort(m_todo.begin(), m_todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); + return m_todo; + } + + bool sls_eval::add_bit_vector(app* e) { + m_values.reserve(e->get_id() + 1); + if (m_values.get(e->get_id())) + return false; + auto v = alloc_valuation(e); + m_values.set(e->get_id(), v); + if (bv.is_sign_ext(e)) { + unsigned p = e->get_parameter(0).get_int(); + v->set_signed(p); + } + return true; + } + + sls_valuation* sls_eval::alloc_valuation(app* e) { + auto bit_width = bv.get_bv_size(e); + auto* r = alloc(sls_valuation, bit_width); + while (m_tmp.size() < 2 * r->nw) { + m_tmp.push_back(0); + m_tmp2.push_back(0); + m_tmp3.push_back(0); + m_tmp4.push_back(0); + m_zero.push_back(0); + m_one.push_back(0); + m_a.push_back(0); + m_b.push_back(0); + m_nexta.push_back(0); + m_nextb.push_back(0); + m_aux.push_back(0); + m_minus_one.push_back(~0); + m_one[0] = 1; + } + return r; + } + + void sls_eval::init_eval_basic(app* e) { + auto id = e->get_id(); + if (m.is_bool(e)) + m_eval.setx(id, bval1(e), false); + else if (m.is_ite(e)) { + SASSERT(bv.is_bv(e->get_arg(1))); + auto& val = wval(e); + auto& val_th = wval(e->get_arg(1)); + auto& val_el = wval(e->get_arg(2)); + if (bval0(e->get_arg(0))) + val.set(val_th.bits()); + else + val.set(val_el.bits()); + } + else { + UNREACHABLE(); + } + } + + void sls_eval::init_eval_bv(app* e) { + if (bv.is_bv(e)) + eval(e).commit_eval(); + else if (m.is_bool(e)) + m_eval.setx(e->get_id(), bval1_bv(e), false); + } + + bool sls_eval::bval1_basic(app* e) const { + SASSERT(m.is_bool(e)); + SASSERT(e->get_family_id() == basic_family_id); + + auto id = e->get_id(); + switch (e->get_decl_kind()) { + case OP_TRUE: + return true; + case OP_FALSE: + return false; + case OP_AND: + return all_of(*to_app(e), [&](expr* arg) { return bval0(arg); }); + case OP_OR: + return any_of(*to_app(e), [&](expr* arg) { return bval0(arg); }); + case OP_NOT: + return !bval0(e->get_arg(0)); + case OP_XOR: { + bool r = false; + for (auto* arg : *to_app(e)) + r ^= bval0(arg); + return r; + } + case OP_IMPLIES: { + auto a = e->get_arg(0); + auto b = e->get_arg(1); + return !bval0(a) || bval0(b); + } + case OP_ITE: { + auto c = bval0(e->get_arg(0)); + return bval0(c ? e->get_arg(1) : e->get_arg(2)); + } + case OP_EQ: { + auto a = e->get_arg(0); + auto b = e->get_arg(1); + if (m.is_bool(a)) + return bval0(a) == bval0(b); + else if (bv.is_bv(a)) { + auto const& va = wval(a); + auto const& vb = wval(b); + return va.eq(vb); + } + return m.are_equal(a, b); + } + case OP_DISTINCT: + default: + verbose_stream() << mk_bounded_pp(e, m) << "\n"; + UNREACHABLE(); + break; + } + UNREACHABLE(); + return false; + } + + bool sls_eval::can_eval1(app* e) const { + expr* x, * y, * z; + if (m.is_eq(e, x, y)) + return m.is_bool(x) || bv.is_bv(x); + if (m.is_ite(e, x, y, z)) + return m.is_bool(y) || bv.is_bv(y); + if (e->get_family_id() == bv.get_fid()) { + switch (e->get_decl_kind()) { + case OP_BNEG_OVFL: + case OP_BSADD_OVFL: + case OP_BSDIV_OVFL: + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BSMUL_OVFL: + return false; + default: + return true; + } + } + if (e->get_family_id() == basic_family_id) + return true; + if (is_uninterp_const(e)) + return m.is_bool(e) || bv.is_bv(e); + return false; + } + + bool sls_eval::bval1_bv(app* e) const { + SASSERT(m.is_bool(e)); + SASSERT(e->get_family_id() == bv.get_fid()); + + auto ucompare = [&](std::function const& f) { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + return f(mpn.compare(a.bits().data(), a.nw, b.bits().data(), b.nw)); + }; + + // x x + 2^{bw-1} const& f) { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + add_p2_1(a, m_tmp); + add_p2_1(b, m_tmp2); + return f(mpn.compare(m_tmp.data(), a.nw, m_tmp2.data(), b.nw)); + }; + + auto umul_overflow = [&]() { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + return a.set_mul(m_tmp2, a.bits(), b.bits()); + }; + + switch (e->get_decl_kind()) { + case OP_ULEQ: + return ucompare([](int i) { return i <= 0; }); + case OP_ULT: + return ucompare([](int i) { return i < 0; }); + case OP_UGT: + return ucompare([](int i) { return i > 0; }); + case OP_UGEQ: + return ucompare([](int i) { return i >= 0; }); + case OP_SLEQ: + return scompare([](int i) { return i <= 0; }); + case OP_SLT: + return scompare([](int i) { return i < 0; }); + case OP_SGT: + return scompare([](int i) { return i > 0; }); + case OP_SGEQ: + return scompare([](int i) { return i >= 0; }); + case OP_BIT2BOOL: { + expr* child; + unsigned idx; + VERIFY(bv.is_bit2bool(e, child, idx)); + auto& a = wval(child); + return a.get_bit(idx); + } + case OP_BUMUL_NO_OVFL: + return !umul_overflow(); + case OP_BUMUL_OVFL: + return umul_overflow(); + case OP_BUADD_OVFL: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + return a.set_add(m_tmp, a.bits(), b.bits()); + } + case OP_BNEG_OVFL: + case OP_BSADD_OVFL: + case OP_BSDIV_OVFL: + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BSMUL_OVFL: + NOT_IMPLEMENTED_YET(); + break; + default: + UNREACHABLE(); + break; + } + return false; + } + + bool sls_eval::bval1(app* e) const { + if (e->get_family_id() == basic_family_id) + return bval1_basic(e); + if (e->get_family_id() == bv.get_fid()) + return bval1_bv(e); + SASSERT(is_uninterp_const(e)); + return bval0(e); + } + + sls_valuation& sls_eval::eval(app* e) const { + auto& val = *m_values[e->get_id()]; + eval(e, val); + return val; + } + + void sls_eval::eval(app* e, sls_valuation& val) const { + SASSERT(bv.is_bv(e)); + if (m.is_ite(e)) { + SASSERT(bv.is_bv(e->get_arg(1))); + auto& val_th = wval(e->get_arg(1)); + auto& val_el = wval(e->get_arg(2)); + if (bval0(e->get_arg(0))) + val.set(val_th.bits()); + else + val.set(val_el.bits()); + return; + } + if (e->get_family_id() == null_family_id) { + val.set(wval(e).bits()); + return; + } + auto set_sdiv = [&]() { + // d = udiv(abs(x), abs(y)) + // y = 0, x >= 0 -> -1 + // y = 0, x < 0 -> 1 + // x = 0, y != 0 -> 0 + // x > 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + bool sign_a = a.sign(); + bool sign_b = b.sign(); + if (b.is_zero()) { + if (sign_a) + val.set(m_one); + else + val.set(m_minus_one); + } + else if (a.is_zero()) + val.set(m_zero); + else { + if (sign_a) + a.set_sub(m_tmp, m_zero, a.bits()); + else + a.get(m_tmp); + + if (sign_b) + b.set_sub(m_tmp2, m_zero, b.bits()); + else + b.get(m_tmp2); + + set_div(m_tmp, m_tmp2, a.bw, m_tmp3, m_tmp4); + if (sign_a == sign_b) + val.set(m_tmp3); + else + val.set_sub(val.eval, m_zero, m_tmp3); + } + }; + + auto mk_rotate_left = [&](unsigned n) { + auto& a = wval(e->get_arg(0)); + VERIFY(try_repair_rotate_left(a.bits(), val, a.bw - n)); + }; + + SASSERT(e->get_family_id() == bv.get_fid()); + switch (e->get_decl_kind()) { + case OP_BV_NUM: { + rational n; + VERIFY(bv.is_numeral(e, n)); + val.set_value(m_tmp, n); + val.set(m_tmp); + break; + } + case OP_BAND: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] = a.bits()[i] & b.bits()[i]; + break; + } + case OP_BOR: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] = a.bits()[i] | b.bits()[i]; + break; + } + case OP_BXOR: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] = a.bits()[i] ^ b.bits()[i]; + break; + } + case OP_BNAND: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] = ~(a.bits()[i] & b.bits()[i]); + break; + } + case OP_BADD: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + val.set_add(val.eval, a.bits(), b.bits()); + break; + } + case OP_BSUB: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + val.set_sub(val.eval, a.bits(), b.bits()); + break; + } + case OP_BMUL: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + val.set_mul(m_tmp2, a.bits(), b.bits()); + val.set(m_tmp2); + break; + } + case OP_CONCAT: { + SASSERT(e->get_num_args() == 2); + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < b.bw; ++i) + val.eval.set(i, b.get_bit(i)); + for (unsigned i = 0; i < a.bw; ++i) + val.eval.set(i + b.bw, a.get_bit(i)); + break; + } + case OP_EXTRACT: { + expr* child; + unsigned lo, hi; + VERIFY(bv.is_extract(e, lo, hi, child)); + auto const& a = wval(child); + SASSERT(lo <= hi && hi + 1 <= a.bw && hi - lo + 1 == val.bw); + for (unsigned i = lo; i <= hi; ++i) + val.eval.set(i - lo, a.get_bit(i)); + break; + } + case OP_BNOT: { + auto& a = wval(e->get_arg(0)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] = ~a.bits()[i]; + break; + } + case OP_BNEG: { + auto& a = wval(e->get_arg(0)); + val.set_sub(val.eval, m_zero, a.bits()); + break; + } + case OP_BIT0: + val.eval.set(0, false); + break; + case OP_BIT1: + val.eval.set(0, true); + break; + case OP_BSHL: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + auto sh = b.to_nat(b.bw); + if (sh == 0) + val.set(a.bits()); + else if (sh >= b.bw) + val.set_zero(); + else { + for (unsigned i = 0; i < a.bw; ++i) + val.eval.set(i, i >= sh && a.get_bit(i - sh)); + } + break; + } + case OP_BLSHR: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + auto sh = b.to_nat(b.bw); + if (sh == 0) + val.set(a.bits()); + else if (sh >= b.bw) + val.set_zero(); + else { + for (unsigned i = 0; i < a.bw; ++i) + val.eval.set(i, i + sh < a.bw && a.get_bit(i + sh)); + } + break; + } + case OP_BASHR: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + auto sh = b.to_nat(b.bw); + auto sign = a.sign(); + if (sh == 0) + val.set(a.bits()); + else if (sh >= b.bw) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = sign ? ~0 : 0; + val.set(m_tmp); + } + else { + a.set_zero(m_tmp); + for (unsigned i = 0; i < a.bw; ++i) + m_tmp.set(i, i + sh < a.bw && a.get_bit(i + sh)); + if (sign) + val.set_range(m_tmp, a.bw - sh, a.bw, true); + val.set(m_tmp); + } + break; + } + case OP_SIGN_EXT: { + auto& a = wval(e->get_arg(0)); + a.get(m_tmp); + bool sign = a.sign(); + val.set_range(m_tmp, a.bw, val.bw, sign); + val.set(m_tmp); + break; + } + case OP_ZERO_EXT: { + auto& a = wval(e->get_arg(0)); + a.get(m_tmp); + val.set_range(m_tmp, a.bw, val.bw, false); + val.set(m_tmp); + break; + } + case OP_BUREM: + case OP_BUREM_I: + case OP_BUREM0: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + + if (b.is_zero()) + val.set(a.bits()); + else { + set_div(a.bits(), b.bits(), b.bw, m_tmp, m_tmp2); + val.set(m_tmp2); + } + break; + } + case OP_BSMOD: + case OP_BSMOD_I: + case OP_BSMOD0: { + // u = mod(abs(x),abs(y)) + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + if (b.is_zero()) + val.set(a.bits()); + else { + if (a.sign()) + a.set_sub(m_tmp3, m_zero, a.bits()); + else + a.set(m_tmp3, a.bits()); + if (b.sign()) + b.set_sub(m_tmp4, m_zero, b.bits()); + else + a.set(m_tmp4, b.bits()); + set_div(m_tmp3, m_tmp4, a.bw, m_tmp, m_tmp2); + if (val.is_zero(m_tmp2)) + val.set(m_tmp2); + else if (a.sign() && b.sign()) + val.set_sub(val.eval, m_zero, m_tmp2); + else if (a.sign()) + val.set_sub(val.eval, b.bits(), m_tmp2); + else if (b.sign()) + val.set_add(val.eval, b.bits(), m_tmp2); + else + val.set(m_tmp2); + } + break; + } + case OP_BUDIV: + case OP_BUDIV_I: + case OP_BUDIV0: { + // x div 0 = -1 + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + if (b.is_zero()) + val.set(m_minus_one); + else { + set_div(a.bits(), b.bits(), a.bw, m_tmp, m_tmp2); + val.set(m_tmp); + } + break; + } + + case OP_BSDIV: + case OP_BSDIV_I: + case OP_BSDIV0: { + set_sdiv(); + break; + } + case OP_BSREM: + case OP_BSREM0: + case OP_BSREM_I: { + // b = 0 -> a + // else a - sdiv(a, b) * b + // + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + if (b.is_zero()) + val.set(a.bits()); + else { + set_sdiv(); + val.set_mul(m_tmp, val.eval, b.bits()); + val.set_sub(val.eval, a.bits(), m_tmp); + } + break; + } + case OP_ROTATE_LEFT: { + unsigned n = e->get_parameter(0).get_int() % val.bw; + mk_rotate_left(n); + break; + } + case OP_ROTATE_RIGHT: { + unsigned n = e->get_parameter(0).get_int() % val.bw; + mk_rotate_left(val.bw - n); + break; + } + case OP_EXT_ROTATE_LEFT: { + auto& b = wval(e->get_arg(1)); + rational n = b.get_value(); + n = mod(n, rational(val.bw)); + SASSERT(n.is_unsigned()); + mk_rotate_left(n.get_unsigned()); + break; + } + case OP_EXT_ROTATE_RIGHT: { + auto& b = wval(e->get_arg(1)); + rational n = b.get_value(); + n = mod(n, rational(val.bw)); + SASSERT(n.is_unsigned()); + mk_rotate_left(val.bw - n.get_unsigned()); + break; + } + case OP_BCOMP: { + auto const& a = wval(e->get_arg(0)); + auto const& b = wval(e->get_arg(1)); + if (a.bits() == b.bits()) + val.set(val.eval, 1); + else + val.set(val.eval, 0); + break; + } + case OP_BREDAND: + case OP_BREDOR: + case OP_BXNOR: + case OP_INT2BV: + + verbose_stream() << mk_bounded_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + break; + case OP_BIT2BOOL: + case OP_BV2INT: + case OP_BNEG_OVFL: + case OP_BSADD_OVFL: + case OP_BUADD_OVFL: + case OP_BSDIV_OVFL: + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BSMUL_OVFL: + case OP_BUMUL_NO_OVFL: + case OP_BUMUL_OVFL: + case OP_ULEQ: + case OP_UGEQ: + case OP_UGT: + case OP_ULT: + case OP_SLEQ: + case OP_SGEQ: + case OP_SGT: + case OP_SLT: + UNREACHABLE(); + break; + default: + UNREACHABLE(); + break; + } + val.clear_overflow_bits(val.eval); + } + + digit_t sls_eval::random_bits() { + return sls_valuation::random_bits(m_rand); + } + + bool sls_eval::try_repair(app* e, unsigned i) { + if (is_fixed0(e->get_arg(i))) + return false; + else if (e->get_family_id() == basic_family_id) + return try_repair_basic(e, i); + if (e->get_family_id() == bv.get_family_id()) + return try_repair_bv(e, i); + return false; + } + + bool sls_eval::try_repair_basic(app* e, unsigned i) { + switch (e->get_decl_kind()) { + case OP_AND: + return try_repair_and_or(e, i); + case OP_OR: + return try_repair_and_or(e, i); + case OP_NOT: + return try_repair_not(e); + case OP_FALSE: + return false; + case OP_TRUE: + return false; + case OP_EQ: + return try_repair_eq(e, i); + case OP_IMPLIES: + return try_repair_implies(e, i); + case OP_XOR: + return try_repair_xor(e, i); + case OP_ITE: + return try_repair_ite(e, i); + default: + UNREACHABLE(); + return false; + } + } + + bool sls_eval::try_repair_bv(app* e, unsigned i) { + switch (e->get_decl_kind()) { + case OP_BAND: + return try_repair_band(eval_value(e), wval(e, i), wval(e, 1 - i)); + case OP_BOR: + return try_repair_bor(eval_value(e), wval(e, i), wval(e, 1 - i)); + case OP_BXOR: + return try_repair_bxor(eval_value(e), wval(e, i), wval(e, 1 - i)); + case OP_BADD: + return try_repair_add(eval_value(e), wval(e, i), wval(e, 1 - i)); + case OP_BSUB: + return try_repair_sub(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BMUL: + return try_repair_mul(eval_value(e), wval(e, i), wval(e, 1 - i)); + case OP_BNOT: + return try_repair_bnot(eval_value(e), wval(e, i)); + case OP_BNEG: + return try_repair_bneg(eval_value(e), wval(e, i)); + case OP_BIT0: + return false; + case OP_BIT1: + return false; + case OP_BV2INT: + return false; + case OP_INT2BV: + return false; + case OP_ULEQ: + if (i == 0) + return try_repair_ule(bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_uge(bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_UGEQ: + if (i == 0) + return try_repair_uge(bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_ule(bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_UGT: + if (i == 0) + return try_repair_ule(!bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_uge(!bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_ULT: + if (i == 0) + return try_repair_uge(!bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_ule(!bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_SLEQ: + if (i == 0) + return try_repair_sle(bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_sge(bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_SGEQ: + if (i == 0) + return try_repair_sge(bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_sle(bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_SGT: + if (i == 0) + return try_repair_sle(!bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_sge(!bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_SLT: + if (i == 0) + return try_repair_sge(!bval0(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_sle(!bval0(e), wval(e, i), wval(e, 1 - i)); + case OP_BASHR: + return try_repair_ashr(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BLSHR: + return try_repair_lshr(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BSHL: + return try_repair_shl(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BIT2BOOL: { + unsigned idx; + expr* arg; + VERIFY(bv.is_bit2bool(e, arg, idx)); + return try_repair_bit2bool(wval(e, 0), idx); + } + + case OP_BUDIV: + case OP_BUDIV_I: + case OP_BUDIV0: + return try_repair_udiv(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BUREM: + case OP_BUREM_I: + case OP_BUREM0: + return try_repair_urem(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_ROTATE_LEFT: + return try_repair_rotate_left(eval_value(e), wval(e, 0), e->get_parameter(0).get_int()); + case OP_ROTATE_RIGHT: + return try_repair_rotate_left(eval_value(e), wval(e, 0), wval(e).bw - e->get_parameter(0).get_int()); + case OP_EXT_ROTATE_LEFT: + return try_repair_rotate_left(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_EXT_ROTATE_RIGHT: + return try_repair_rotate_right(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_ZERO_EXT: + return try_repair_zero_ext(eval_value(e), wval(e, 0)); + case OP_SIGN_EXT: + return try_repair_sign_ext(eval_value(e), wval(e, 0)); + case OP_CONCAT: + return try_repair_concat(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_EXTRACT: { + unsigned hi, lo; + expr* arg; + VERIFY(bv.is_extract(e, lo, hi, arg)); + return try_repair_extract(eval_value(e), wval(arg), lo); + } + case OP_BUMUL_NO_OVFL: + return try_repair_umul_ovfl(!bval0(e), wval(e, 0), wval(e, 1), i); + case OP_BUMUL_OVFL: + return try_repair_umul_ovfl(bval0(e), wval(e, 0), wval(e, 1), i); + case OP_BCOMP: + return try_repair_comp(eval_value(e), wval(e, 0), wval(e, 1), i); + case OP_BUADD_OVFL: + + case OP_BNAND: + case OP_BREDAND: + case OP_BREDOR: + case OP_BXNOR: + case OP_BNEG_OVFL: + case OP_BSADD_OVFL: + case OP_BSDIV_OVFL: + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BSMUL_OVFL: + verbose_stream() << mk_pp(e, m) << "\n"; + return false; + case OP_BSREM: + case OP_BSREM_I: + case OP_BSREM0: + case OP_BSMOD: + case OP_BSMOD_I: + case OP_BSMOD0: + case OP_BSDIV: + case OP_BSDIV_I: + case OP_BSDIV0: + // these are currently compiled to udiv and urem. + UNREACHABLE(); + return false; + default: + return false; + } + } + + bool sls_eval::try_repair_and_or(app* e, unsigned i) { + auto b = bval0(e); + auto child = e->get_arg(i); + if (b == bval0(child)) + return false; + m_eval[child->get_id()] = b; + return true; + } + + bool sls_eval::try_repair_not(app* e) { + auto child = e->get_arg(0); + m_eval[child->get_id()] = !bval0(e); + return true; + } + + bool sls_eval::try_repair_eq(app* e, unsigned i) { + auto child = e->get_arg(i); + auto is_true = bval0(e); + if (m.is_bool(child)) { + SASSERT(!is_fixed0(child)); + auto bv = bval0(e->get_arg(1 - i)); + m_eval[child->get_id()] = is_true == bv; + return true; + } + else if (bv.is_bv(child)) { + auto & a = wval(e->get_arg(i)); + auto & b = wval(e->get_arg(1 - i)); + return try_repair_eq(is_true, a, b); + } + return false; + } + + bool sls_eval::try_repair_eq(bool is_true, bvval& a, bvval const& b) { + if (is_true) { + if (m_rand() % 20 != 0) + if (a.try_set(b.bits())) + return true; + + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + else { + bool try_above = m_rand() % 2 == 0; + if (try_above) { + a.set_add(m_tmp, b.bits(), m_one); + if (!a.is_zero(m_tmp) && a.set_random_at_least(m_tmp, m_tmp2, m_rand)) + return true; + } + a.set_sub(m_tmp, b.bits(), m_one); + if (!a.is_zero(m_tmp) && a.set_random_at_most(m_tmp, m_tmp2, m_rand)) + return true; + if (!try_above) { + a.set_add(m_tmp, b.bits(), m_one); + if (!a.is_zero(m_tmp) && a.set_random_at_least(m_tmp, m_tmp2, m_rand)) + return true; + } + return false; + } + } + + bool sls_eval::try_repair_xor(app* e, unsigned i) { + bool ev = bval0(e); + bool bv = bval0(e->get_arg(1 - i)); + auto child = e->get_arg(i); + m_eval[child->get_id()] = ev != bv; + return true; + } + + bool sls_eval::try_repair_ite(app* e, unsigned i) { + auto child = e->get_arg(i); + bool c = bval0(e->get_arg(0)); + if (i == 0) { + m_eval[child->get_id()] = !c; + return true; + } + if (c != (i == 1)) + return false; + if (m.is_bool(e)) { + m_eval[child->get_id()] = bval0(e); + return true; + } + if (bv.is_bv(e)) + return wval(child).try_set(wval(e).bits()); + return false; + } + + bool sls_eval::try_repair_implies(app* e, unsigned i) { + auto child = e->get_arg(i); + bool ev = bval0(e); + bool av = bval0(child); + bool bv = bval0(e->get_arg(1 - i)); + if (i == 0) { + if (ev == (!av || bv)) + return false; + } + else if (ev != (!bv || av)) + return false; + m_eval[child->get_id()] = ev; + return true; + } + + // + // e = a & b + // e[i] = 1 -> a[i] = 1 + // e[i] = 0 & b[i] = 1 -> a[i] = 0 + // e[i] = 0 & b[i] = 0 -> a[i] = random + // a := e[i] | (~b[i] & a[i]) + + bool sls_eval::try_repair_band(bvect const& e, bvval& a, bvval const& b) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = ~a.fixed[i] & (e[i] | (~b.bits()[i] & random_bits())); + return a.set_repair(random_bool(), m_tmp); + } + + // + // e = a | b + // set a[i] to 1 where b[i] = 0, e[i] = 1 + // set a[i] to 0 where e[i] = 0, a[i] = 1 + // + bool sls_eval::try_repair_bor(bvect const& e, bvval& a, bvval const& b) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = e[i] & (~b.bits()[i] | random_bits()); + return a.set_repair(random_bool(), m_tmp); + } + + bool sls_eval::try_repair_bxor(bvect const& e, bvval& a, bvval const& b) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = e[i] ^ b.bits()[i]; + a.clear_overflow_bits(m_tmp); + return a.set_repair(random_bool(), m_tmp); + } + + + // + // first try to set a := e - b + // If this fails, set a to a random value + // + bool sls_eval::try_repair_add(bvect const& e, bvval& a, bvval const& b) { + if (m_rand() % 20 != 0) { + a.set_sub(m_tmp, e, b.bits()); + if (a.try_set(m_tmp)) + return true; + } + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + + bool sls_eval::try_repair_sub(bvect const& e, bvval& a, bvval & b, unsigned i) { + if (m_rand() % 20 != 0) { + if (i == 0) + // e = a - b -> a := e + b + a.set_add(m_tmp, e, b.bits()); + else + // b := a - e + b.set_sub(m_tmp, a.bits(), e); + if (a.try_set(m_tmp)) + return true; + } + // fall back to a random value + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + + /** + * e = a*b, then a = e * b^-1 + * 8*e = a*(2b), then a = 4e*b^-1 + */ + bool sls_eval::try_repair_mul(bvect const& e, bvval& a, bvval const& b) { + unsigned parity_e = b.parity(e); + unsigned parity_b = b.parity(b.bits()); + + if (b.is_zero(e)) { + a.get_variant(m_tmp, m_rand); + for (unsigned i = 0; i < b.bw - parity_b; ++i) + m_tmp.set(i, false); + return a.set_repair(random_bool(), m_tmp); + } + + if (b.is_zero()) { + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + + if (m_rand() % 20 == 0) { + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + +#if 0 + verbose_stream() << "solve for " << e << "\n"; + + rational r = e.get_value(e.nw); + rational root; + verbose_stream() << r.is_int_perfect_square(root) << "\n"; +#endif + + + auto& x = m_tmp; + auto& y = m_tmp2; + auto& quot = m_tmp3; + auto& rem = m_tmp4; + auto& ta = m_a; + auto& tb = m_b; + auto& nexta = m_nexta; + auto& nextb = m_nextb; + auto& aux = m_aux; + auto bw = b.bw; + + + // x*ta + y*tb = x + + b.get(y); + if (parity_b > 0) { + b.shift_right(y, parity_b); +#if 0 + for (unsigned i = parity_b; i < b.bw; ++i) + y.set(i, m_rand() % 2 == 0); +#endif + } + + y[a.nw] = 0; + x[a.nw] = 0; + + + a.set_bw((a.nw + 1)* 8 * sizeof(digit_t)); + y.set_bw(a.bw); // enable comparisons + a.set_zero(x); + x.set(bw, true); // x = 2 ^ b.bw + + a.set_one(ta); + a.set_zero(tb); + a.set_zero(nexta); + a.set_one(nextb); + + rem.reserve(2 * a.nw); + SASSERT(y <= x); + while (y > m_zero) { + SASSERT(y <= x); + set_div(x, y, a.bw, quot, rem); // quot, rem := quot_rem(x, y) + SASSERT(y >= rem); + a.set(x, y); // x := y + a.set(y, rem); // y := rem + a.set(aux, nexta); // aux := nexta + a.set_mul(rem, quot, nexta, false); + a.set_sub(nexta, ta, rem); // nexta := ta - quot*nexta + a.set(ta, aux); // ta := aux + a.set(aux, nextb); // aux := nextb + a.set_mul(rem, quot, nextb, false); + a.set_sub(nextb, tb, rem); // nextb := tb - quot*nextb + a.set(tb, aux); // tb := aux + } + + a.set_bw(bw); + y.set_bw(0); + // x*a + y*b = 1 + + tb.set_bw(0); +#if Z3DEBUG + b.get(y); + if (parity_b > 0) + b.shift_right(y, parity_b); + a.set_mul(m_tmp, tb, y); + SASSERT(a.is_one(m_tmp)); +#endif + e.copy_to(b.nw, m_tmp2); + if (parity_e > 0 && parity_b > 0) + b.shift_right(m_tmp2, std::min(parity_b, parity_e)); + a.set_mul(m_tmp, tb, m_tmp2); + if (a.set_repair(random_bool(), m_tmp)) + return true; + + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + + bool sls_eval::try_repair_bnot(bvect const& e, bvval& a) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = ~e[i]; + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); + } + + bool sls_eval::try_repair_bneg(bvect const& e, bvval& a) { + a.set_sub(m_tmp, m_zero, e); + return a.try_set(m_tmp); + } + + + // a <=s b <-> a + p2 <=u b + p2 + // + // NB: p2 = -p2 + // + // to solve x for x >s b: + // infeasible if b + 1 = p2 + // solve for x >=s b + 1 + // + bool sls_eval::try_repair_sle(bool e, bvval& a, bvval const& b) { + auto& p2 = m_b; + b.set_zero(p2); + p2.set(b.bw - 1, true); + p2.set_bw(b.bw); + bool r = false; + if (e) + r = try_repair_sle(a, b.bits(), p2); + else { + auto& b1 = m_nexta; + a.set_add(b1, b.bits(), m_one); + b1.set_bw(b.bw); + if (p2 == b1) + r = false; + else + r = try_repair_sge(a, b1, p2); + b1.set_bw(0); + } + p2.set_bw(0); + return r; + } + + // to solve x for x = p2 if c >= p2 (b < p2) + // or + // x := random p2 <= x <= b if c < p2 (b >= p2) + // + bool sls_eval::try_repair_sle(bvval& a, bvect const& b, bvect const& p2) { + bool r = false; + if (b < p2) { + bool coin = m_rand() % 2 == 0; + if (coin) + r = a.set_random_at_least(p2, m_tmp3, m_rand); + if (!r) + r = a.set_random_at_most(b, m_tmp3, m_rand); + if (!coin && !r) + r = a.set_random_at_least(p2, m_tmp3, m_rand); + } + else + r = a.set_random_in_range(p2, b, m_tmp3, m_rand); + return r; + } + + // solve for x >=s b + // + // d := b + p2 + // + // x := random b <= x < p2 if d >= p2 (b < p2) + // or + // x := random b <= x or x < p2 if d < p2 + // + + bool sls_eval::try_repair_sge(bvval& a, bvect const& b, bvect const& p2) { + auto& p2_1 = m_tmp4; + a.set_sub(p2_1, p2, m_one); + p2_1.set_bw(a.bw); + bool r = false; + if (p2 < b) + // random b <= x < p2 + r = a.set_random_in_range(b, p2_1, m_tmp3, m_rand); + else { + // random b <= x or x < p2 + bool coin = m_rand() % 2 == 0; + if (coin) + r = a.set_random_at_most(p2_1, m_tmp3, m_rand); + if (!r) + r = a.set_random_at_least(b, m_tmp3, m_rand); + if (!r && !coin) + r = a.set_random_at_most(p2_1, m_tmp3, m_rand); + } + p2_1.set_bw(0); + return r; + } + + void sls_eval::add_p2_1(bvval const& a, bvect& t) const { + m_zero.set(a.bw - 1, true); + a.set_add(t, a.bits(), m_zero); + m_zero.set(a.bw - 1, false); + a.clear_overflow_bits(t); + } + + bool sls_eval::try_repair_ule(bool e, bvval& a, bvval const& b) { + if (e) { + // a <= t + return a.set_random_at_most(b.bits(), m_tmp, m_rand); + } + else { + // a > t + a.set_add(m_tmp, b.bits(), m_one); + if (a.is_zero(m_tmp)) + return false; + return a.set_random_at_least(m_tmp, m_tmp2, m_rand); + } + } + + bool sls_eval::try_repair_uge(bool e, bvval& a, bvval const& b) { + if (e) { + // a >= t + return a.set_random_at_least(b.bits(), m_tmp, m_rand); + } + else { + // a < t + if (b.is_zero()) + return false; + a.set_sub(m_tmp, b.bits(), m_one); + return a.set_random_at_most(m_tmp, m_tmp2, m_rand); + } + } + + bool sls_eval::try_repair_bit2bool(bvval& a, unsigned idx) { + return a.try_set_bit(idx, !a.get_bit(idx)); + } + + bool sls_eval::try_repair_shl(bvect const& e, bvval& a, bvval& b, unsigned i) { + if (i == 0) { + unsigned sh = b.to_nat(b.bw); + if (sh == 0) + return a.try_set(e); + else if (sh >= b.bw) + return false; + else { + // + // e = a << sh + // set bw - sh low order bits to bw - sh high-order of e. + // a[bw - sh - 1: 0] = e[bw - 1: sh] + // a[bw - 1: bw - sh] = unchanged + // + for (unsigned i = 0; i < a.bw - sh; ++i) + m_tmp.set(i, e.get(sh + i)); + for (unsigned i = a.bw - sh; i < a.bw; ++i) + m_tmp.set(i, a.get_bit(i)); + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); + } + } + else { + // NB. blind sub-range of possible values for b + SASSERT(i == 1); + unsigned sh = m_rand(a.bw + 1); + b.set(m_tmp, sh); + return b.try_set(m_tmp); + } + return false; + } + + bool sls_eval::try_repair_ashr(bvect const& e, bvval & a, bvval& b, unsigned i) { + if (i == 0) { + unsigned sh = b.to_nat(b.bw); + if (sh == 0) + return a.try_set(e); + else if (sh >= b.bw) { + if (e.get(a.bw - 1)) + return a.try_set_bit(a.bw - 1, true); + else + return a.try_set_bit(a.bw - 1, false); + } + else { + // e = a >> sh + // a[bw-1:sh] = e[bw-sh-1:0] + // a[sh-1:0] = a[sh-1:0] + // ignore sign + for (unsigned i = sh; i < a.bw; ++i) + m_tmp.set(i, e.get(i - sh)); + for (unsigned i = 0; i < sh; ++i) + m_tmp.set(i, a.get_bit(i)); + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); + } + } + else { + // NB. blind sub-range of possible values for b + SASSERT(i == 1); + unsigned sh = m_rand(a.bw + 1); + b.set(m_tmp, sh); + return b.try_set(m_tmp); + } + } + + bool sls_eval::try_repair_lshr(bvect const& e, bvval& a, bvval& b, unsigned i) { + return try_repair_ashr(e, a, b, i); + } + + bool sls_eval::try_repair_comp(bvect const& e, bvval& a, bvval& b, unsigned i) { + SASSERT(e[0] == 0 || e[0] == 1); + SASSERT(e.bw == 1); + return try_repair_eq(e[0] == 1, i == 0 ? a : b, i == 0 ? b : a); + } + + // e = a udiv b + // e = 0 => a != ones + // b = 0 => e = -1 // nothing to repair on a + // e != -1 => max(a) >=u e + + bool sls_eval::try_repair_udiv(bvect const& e, bvval& a, bvval& b, unsigned i) { + if (i == 0) { + if (a.is_zero(e) && a.is_ones(a.fixed) && a.is_ones()) + return false; + if (b.is_zero()) + return false; + if (!a.is_ones(e)) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = ~a.fixed[i] | a.bits()[i]; + a.clear_overflow_bits(m_tmp); + if (e > m_tmp) + return false; + } + // e = 1 => a := b + if (a.is_one(e)) { + a.set(m_tmp, b.bits()); + return a.set_repair(false, m_tmp); + } + // b * e + r = a + if (mul_overflow_on_fixed(b, e)) { + a.get_variant(m_tmp, m_rand); + return a.set_repair(random_bool(), m_tmp); + } + + b.get_variant(m_tmp2, m_rand); + while (b.bits() < m_tmp2) + m_tmp2.set(b.msb(m_tmp2), false); + while (a.set_add(m_tmp3, m_tmp, m_tmp2)) + m_tmp2.set(b.msb(m_tmp2), false); + a.clear_overflow_bits(m_tmp3); + return a.set_repair(true, m_tmp3); + } + else { + if (a.is_one(e) && a.is_zero()) { + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = random_bits(); + a.clear_overflow_bits(m_tmp); + return b.set_repair(true, m_tmp); + } + if (a.is_one(e)) { + a.set(m_tmp, a.bits()); + return b.set_repair(true, m_tmp); + } + + // e * b + r = a + // b = (a - r) udiv e + // random version of r: + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = random_bits(); + a.clear_overflow_bits(m_tmp); + // ensure r <= m + while (a.bits() < m_tmp) + m_tmp.set(a.msb(m_tmp), false); + a.set_sub(m_tmp2, a.bits(), m_tmp); + set_div(m_tmp2, e, a.bw, m_tmp3, m_tmp4); + return b.set_repair(random_bool(), m_tmp4); + } + } + + // table III in Niemetz et al + // x urem s = t <=> + // ~(-s) >=u t + // ((s = 0 or t = ones) => mcb(x, t)) + // ((s != 0 and t != ones) => exists y . (mcb(x, s*y + t) and ~mulo(s, y) and ~addo(s*y, t)) + // s urem x = t <=> + // (s = t => x can be >u t) + // (s != t => exists y . (mcb(x, y) and y >u t and (s - t) mod y = 0) + + + bool sls_eval::try_repair_urem(bvect const& e, bvval& a, bvval& b, unsigned i) { + + if (i == 0) { + if (b.is_zero()) { + a.set(m_tmp, e); + return a.set_repair(random_bool(), m_tmp); + } + // a urem b = e: b*y + e = a + // ~Ovfl*(b, y) + // ~Ovfl+(b*y, e) + // choose y at random + // lower y as long as y*b overflows with fixed bits in b + + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = random_bits(); + a.clear_overflow_bits(m_tmp); + while (mul_overflow_on_fixed(b, m_tmp)) { + auto i = b.msb(m_tmp); + m_tmp.set(i, false); + } + while (true) { + a.set_mul(m_tmp2, m_tmp, b.bits()); + if (!a.set_add(m_tmp3, m_tmp2, e)) + break; + auto i = b.msb(m_tmp); + m_tmp.set(i, false); + } + return a.set_repair(random_bool(), m_tmp3); + } + else { + // a urem b = e: b*y + e = a + // b*y = a - e + // b = (a - e) div y + // ~Ovfl*(b, y) + // ~Ovfl+(b*y, e) + // choose y at random + // lower y as long as y*b overflows with fixed bits in b + for (unsigned i = 0; i < a.nw; ++i) + m_tmp[i] = random_bits(); + a.set_sub(m_tmp2, a.bits(), e); + set_div(m_tmp2, m_tmp, a.bw, m_tmp3, m_tmp4); + a.clear_overflow_bits(m_tmp3); + return b.set_repair(random_bool(), m_tmp3); + } + } + + bool sls_eval::add_overflow_on_fixed(bvval const& a, bvect const& t) { + a.set(m_tmp3, m_zero); + for (unsigned i = 0; i < a.nw; ++i) + m_tmp3[i] = a.fixed[i] & a.bits()[i]; + return a.set_add(m_tmp4, t, m_tmp3); + } + + bool sls_eval::mul_overflow_on_fixed(bvval const& a, bvect const& t) { + a.set(m_tmp3, m_zero); + for (unsigned i = 0; i < a.nw; ++i) + m_tmp3[i] = a.fixed[i] & a.bits()[i]; + return a.set_mul(m_tmp4, m_tmp3, t); + } + + bool sls_eval::try_repair_rotate_left(bvect const& e, bvval& a, unsigned n) const { + // a := rotate_right(e, n) + n = (a.bw - n) % a.bw; + for (unsigned i = a.bw - n; i < a.bw; ++i) + m_tmp.set(i + n - a.bw, e.get(i)); + for (unsigned i = 0; i < a.bw - n; ++i) + m_tmp.set(i + n, e.get(i)); + return a.set_repair(true, m_tmp); + } + + bool sls_eval::try_repair_rotate_left(bvect const& e, bvval& a, bvval& b, unsigned i) { + if (i == 0) { + rational n = b.get_value(); + n = mod(n, rational(b.bw)); + return try_repair_rotate_left(e, a, n.get_unsigned()); + } + else { + SASSERT(i == 1); + unsigned sh = m_rand(b.bw); + b.set(m_tmp, sh); + return b.set_repair(random_bool(), m_tmp); + } + } + + bool sls_eval::try_repair_rotate_right(bvect const& e, bvval& a, bvval& b, unsigned i) { + if (i == 0) { + rational n = b.get_value(); + n = mod(b.bw - n, rational(b.bw)); + return try_repair_rotate_left(e, a, n.get_unsigned()); + } + else { + SASSERT(i == 1); + unsigned sh = m_rand(b.bw); + b.set(m_tmp, sh); + return b.set_repair(random_bool(), m_tmp); + } + } + + bool sls_eval::try_repair_umul_ovfl(bool e, bvval& a, bvval& b, unsigned i) { + if (e) { + // maximize + if (i == 0) { + a.max_feasible(m_tmp); + return a.set_repair(false, m_tmp); + } + else { + b.max_feasible(m_tmp); + return b.set_repair(false, m_tmp); + } + } + else { + // minimize + if (i == 0) { + a.min_feasible(m_tmp); + return a.set_repair(true, m_tmp); + } + else { + b.min_feasible(m_tmp); + return b.set_repair(true, m_tmp); + } + } + } + + // + // prefix of e must be 1s or 0 and match bit position of last bit in a. + // set a to suffix of e, matching signs. + // + bool sls_eval::try_repair_sign_ext(bvect const& e, bvval& a) { + for (unsigned i = a.bw; i < e.bw; ++i) + if (e.get(i) != e.get(a.bw - 1)) + return false; + + for (unsigned i = 0; i < e.nw; ++i) + m_tmp[i] = e[i]; + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); + } + + // + // prefix of e must be 0s. + // + bool sls_eval::try_repair_zero_ext(bvect const& e, bvval& a) { + for (unsigned i = a.bw; i < e.bw; ++i) + if (e.get(i)) + return false; + + for (unsigned i = 0; i < e.nw; ++i) + m_tmp[i] = e[i]; + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); + } + + bool sls_eval::try_repair_concat(bvect const& e, bvval& a, bvval& b, unsigned idx) { + bool r = false; + if (idx == 0) { + for (unsigned i = 0; i < a.bw; ++i) + m_tmp.set(i, e.get(i + b.bw)); + a.clear_overflow_bits(m_tmp); + r = a.try_set(m_tmp); + } + else { + for (unsigned i = 0; i < b.bw; ++i) + m_tmp.set(i, e.get(i)); + b.clear_overflow_bits(m_tmp); + r = b.try_set(m_tmp); + } + //verbose_stream() << e << " := " << a << " " << b << "\n"; + return r; + } + + // + // e = a[hi:lo], where hi = e.bw + lo - 1 + // for the randomized assignment, + // set a outside of [hi:lo] to random values with preference to 0 or 1 bits + // + bool sls_eval::try_repair_extract(bvect const& e, bvval& a, unsigned lo) { + if (m_rand() % m_config.m_prob_randomize_extract <= 100) { + a.get_variant(m_tmp, m_rand); + if (0 == (m_rand() % 2)) { + auto bit = 0 == (m_rand() % 2); + if (!a.try_set_range(m_tmp, 0, lo, bit)) + a.try_set_range(m_tmp, 0, lo, !bit); + } + if (0 == (m_rand() % 2)) { + auto bit = 0 == (m_rand() % 2); + if (!a.try_set_range(m_tmp, lo + e.bw, a.bw, bit)) + a.try_set_range(m_tmp, lo + e.bw, a.bw, !bit); + } + } + else + a.get(m_tmp); + for (unsigned i = 0; i < e.bw; ++i) + m_tmp.set(i + lo, e.get(i)); + if (a.try_set(m_tmp)) + return true; + a.get_variant(m_tmp, m_rand); + bool res = a.set_repair(random_bool(), m_tmp); + // verbose_stream() << "try set " << res << " " << m_tmp[0] << " " << a << "\n"; + return res; + } + + void sls_eval::set_div(bvect const& a, bvect const& b, unsigned bw, + bvect& quot, bvect& rem) const { + unsigned nw = (bw + 8 * sizeof(digit_t) - 1) / (8 * sizeof(digit_t)); + unsigned bnw = nw; + while (bnw > 1 && b[bnw - 1] == 0) + --bnw; + if (b[bnw-1] == 0) { + for (unsigned i = 0; i < nw; ++i) { + quot[i] = ~0; + rem[i] = 0; + } + quot[nw - 1] = (1 << (bw % (8 * sizeof(digit_t)))) - 1; + } + else { + for (unsigned i = 0; i < nw; ++i) + rem[i] = quot[i] = 0; + mpn.div(a.data(), nw, b.data(), bnw, quot.data(), rem.data()); + } + } + + bool sls_eval::repair_up(expr* e) { + if (!is_app(e)) + return false; + if (m.is_bool(e)) { + auto b = bval1(to_app(e)); + if (is_fixed0(e)) + return b == bval0(e); + m_eval[e->get_id()] = b; + return true; + } + if (bv.is_bv(e)) { + auto& v = eval(to_app(e)); + // verbose_stream() << "committing: " << v << "\n"; + for (unsigned i = 0; i < v.nw; ++i) + if (0 != (v.fixed[i] & (v.bits()[i] ^ v.eval[i]))) { + v.bits().copy_to(v.nw, v.eval); + return false; + } + if (v.commit_eval()) + return true; + v.bits().copy_to(v.nw, v.eval); + return false; + } + return false; + } + + sls_valuation& sls_eval::wval(expr* e) const { + // if (!m_values[e->get_id()]) verbose_stream() << mk_bounded_pp(e, m) << "\n"; + return *m_values[e->get_id()]; + } + + std::ostream& sls_eval::display(std::ostream& out, expr_ref_vector const& es) { + auto& terms = sort_assertions(es); + for (expr* e : terms) { + out << e->get_id() << ": " << mk_bounded_pp(e, m, 1) << " "; + if (is_fixed0(e)) + out << "f "; + if (bv.is_bv(e)) + out << wval(e); + else if (m.is_bool(e)) + out << (bval0(e) ? "T" : "F"); + out << "\n"; + } + terms.reset(); + return out; + } +} diff --git a/src/ast/sls/bv_sls_eval.h b/src/ast/sls/bv_sls_eval.h new file mode 100644 index 000000000..5422d5b7c --- /dev/null +++ b/src/ast/sls/bv_sls_eval.h @@ -0,0 +1,178 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls.h + +Abstract: + + A Stochastic Local Search (SLS) engine + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "ast/ast.h" +#include "ast/sls/sls_valuation.h" +#include "ast/sls/bv_sls_fixed.h" +#include "ast/bv_decl_plugin.h" + +namespace bv { + + class sls_fixed; + + class sls_eval { + struct config { + unsigned m_prob_randomize_extract = 50; + }; + + friend class sls_fixed; + friend class sls_test; + ast_manager& m; + bv_util bv; + sls_fixed m_fix; + mutable mpn_manager mpn; + ptr_vector m_todo; + random_gen m_rand; + config m_config; + + + + scoped_ptr_vector m_values; // expr-id -> bv valuation + bool_vector m_eval; // expr-id -> boolean valuation + bool_vector m_fixed; // expr-id -> is Boolean fixed + + mutable bvect m_tmp, m_tmp2, m_tmp3, m_tmp4, m_zero, m_one, m_minus_one; + bvect m_a, m_b, m_nextb, m_nexta, m_aux; + + using bvval = sls_valuation; + + + void init_eval_basic(app* e); + void init_eval_bv(app* e); + + /** + * Register e as a bit-vector. + * Return true if not already registered, false if already registered. + */ + bool add_bit_vector(app* e); + sls_valuation* alloc_valuation(app* e); + + bool bval1_basic(app* e) const; + bool bval1_bv(app* e) const; + + /** + * Repair operations + */ + bool try_repair_basic(app* e, unsigned i); + bool try_repair_bv(app * e, unsigned i); + bool try_repair_and_or(app* e, unsigned i); + bool try_repair_not(app* e); + bool try_repair_eq(app* e, unsigned i); + bool try_repair_xor(app* e, unsigned i); + bool try_repair_ite(app* e, unsigned i); + bool try_repair_implies(app* e, unsigned i); + bool try_repair_band(bvect const& e, bvval& a, bvval const& b); + bool try_repair_bor(bvect const& e, bvval& a, bvval const& b); + bool try_repair_add(bvect const& e, bvval& a, bvval const& b); + bool try_repair_sub(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_mul(bvect const& e, bvval& a, bvval const& b); + bool try_repair_bxor(bvect const& e, bvval& a, bvval const& b); + bool try_repair_bnot(bvect const& e, bvval& a); + bool try_repair_bneg(bvect const& e, bvval& a); + bool try_repair_ule(bool e, bvval& a, bvval const& b); + bool try_repair_uge(bool e, bvval& a, bvval const& b); + bool try_repair_sle(bool e, bvval& a, bvval const& b); + bool try_repair_sge(bool e, bvval& a, bvval const& b); + bool try_repair_sge(bvval& a, bvect const& b, bvect const& p2); + bool try_repair_sle(bvval& a, bvect const& b, bvect const& p2); + bool try_repair_shl(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_ashr(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_lshr(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_bit2bool(bvval& a, unsigned idx); + bool try_repair_udiv(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_urem(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_rotate_left(bvect const& e, bvval& a, unsigned n) const; + bool try_repair_rotate_left(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_rotate_right(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_ule(bool e, bvval& a, bvect const& t); + bool try_repair_uge(bool e, bvval& a, bvect const& t); + bool try_repair_umul_ovfl(bool e, bvval& a, bvval& b, unsigned i); + bool try_repair_zero_ext(bvect const& e, bvval& a); + bool try_repair_sign_ext(bvect const& e, bvval& a); + bool try_repair_concat(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_extract(bvect const& e, bvval& a, unsigned lo); + bool try_repair_comp(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_eq(bool is_true, bvval& a, bvval const& b); + void add_p2_1(bvval const& a, bvect& t) const; + + bool add_overflow_on_fixed(bvval const& a, bvect const& t); + bool mul_overflow_on_fixed(bvval const& a, bvect const& t); + void set_div(bvect const& a, bvect const& b, unsigned nw, + bvect& quot, bvect& rem) const; + + digit_t random_bits(); + bool random_bool() { return m_rand() % 2 == 0; } + + sls_valuation& wval(app* e, unsigned i) { return wval(e->get_arg(i)); } + + void eval(app* e, sls_valuation& val) const; + + bvect const& eval_value(app* e) const { return wval(e).eval; } + + public: + sls_eval(ast_manager& m); + + void init_eval(expr_ref_vector const& es, std::function const& eval); + + void tighten_range(expr_ref_vector const& es) { m_fix.init(es); } + + ptr_vector& sort_assertions(expr_ref_vector const& es); + + /** + * Retrieve evaluation based on cache. + * bval - Boolean values + * wval - Word (bit-vector) values + */ + + bool bval0(expr* e) const { return m_eval[e->get_id()]; } + + sls_valuation& wval(expr* e) const; + + bool is_fixed0(expr* e) const { return m_fixed.get(e->get_id(), false); } + + /** + * Retrieve evaluation based on immediate children. + */ + bool bval1(app* e) const; + bool can_eval1(app* e) const; + + sls_valuation& eval(app* e) const; + + /** + * Override evaluaton. + */ + + void set(expr* e, bool b) { + m_eval[e->get_id()] = b; + } + + /* + * Try to invert value of child to repair value assignment of parent. + */ + + bool try_repair(app* e, unsigned i); + + /* + * Propagate repair up to parent + */ + bool repair_up(expr* e); + + + std::ostream& display(std::ostream& out, expr_ref_vector const& es); + }; +} diff --git a/src/ast/sls/bv_sls_fixed.cpp b/src/ast/sls/bv_sls_fixed.cpp new file mode 100644 index 000000000..91ce8e0e2 --- /dev/null +++ b/src/ast/sls/bv_sls_fixed.cpp @@ -0,0 +1,423 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls_fixed.cpp + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/ast_pp.h" +#include "ast/ast_ll_pp.h" +#include "ast/sls/bv_sls_fixed.h" +#include "ast/sls/bv_sls_eval.h" + +namespace bv { + + sls_fixed::sls_fixed(sls_eval& ev): + ev(ev), + m(ev.m), + bv(ev.bv) + {} + + void sls_fixed::init(expr_ref_vector const& es) { + ev.sort_assertions(es); + for (expr* e : ev.m_todo) { + if (!is_app(e)) + continue; + app* a = to_app(e); + ev.m_fixed.setx(a->get_id(), is_fixed1(a), false); + if (a->get_family_id() == basic_family_id) + init_fixed_basic(a); + else if (a->get_family_id() == bv.get_family_id()) + init_fixed_bv(a); + else + ; + } + ev.m_todo.reset(); + init_ranges(es); + } + + + void sls_fixed::init_ranges(expr_ref_vector const& es) { + for (expr* e : es) { + bool sign = m.is_not(e, e); + if (is_app(e)) + init_range(to_app(e), sign); + } + } + + // s <=s t <=> s + K <= t + K, K = 2^{bw-1} + + void sls_fixed::init_range(app* e, bool sign) { + expr* s, * t, * x, * y; + rational a, b; + unsigned idx; + auto N = [&](expr* s) { + auto b = bv.get_bv_size(s); + return b > 0 ? rational::power_of_two(b - 1) : rational(0); + }; + if (bv.is_ule(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(x, a, y, b, sign); + } + else if (bv.is_ult(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(y, b, x, a, !sign); + } + else if (bv.is_uge(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(y, b, x, a, sign); + } + else if (bv.is_ugt(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(x, a, y, b, !sign); + } + else if (bv.is_sle(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(x, a + N(s), y, b + N(s), sign); + } + else if (bv.is_slt(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(y, b + N(s), x, a + N(s), !sign); + } + else if (bv.is_sge(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(y, b + N(s), x, a + N(s), sign); + } + else if (bv.is_sgt(e, s, t)) { + get_offset(s, x, a); + get_offset(t, y, b); + init_range(x, a + N(s), y, b + N(s), !sign); + } + else if (!sign && m.is_eq(e, s, t)) { + if (bv.is_numeral(s, a)) + // t - a <= 0 + init_range(t, -a, nullptr, rational(0), false); + else if (bv.is_numeral(t, a)) + init_range(s, -a, nullptr, rational(0), false); + } + else if (bv.is_bit2bool(e, s, idx)) { + auto& val = wval(s); + val.try_set_bit(idx, !sign); + val.fixed.set(idx, true); + val.tighten_range(); + } + } + + // + // x + a <= b <=> x in [-a, b - a + 1[ b != -1 + // a <= x + b <=> x in [a - b, -b[ a != 0 + // x + a <= x + b <=> x in [-a, -b[ a != b + // + // x + a < b <=> ! (b <= x + a) <=> x not in [-b, a - b + 1[ <=> x in [a - b + 1, -b [ b != 0 + // a < x + b <=> ! (x + b <= a) <=> x not in [-a, b - a [ <=> x in [b - a, -a [ a != -1 + // x + a < x + b <=> ! (x + b <= x + a) <=> x in [-a, -b [ a != b + // + void sls_fixed::init_range(expr* x, rational const& a, expr* y, rational const& b, bool sign) { + if (!x && !y) + return; + if (!x) { + // a <= y + b + if (a == 0) + return; + auto& v = wval(y); + if (!sign) + v.add_range(a - b, -b); + else + v.add_range(-b, a - b); + } + else if (!y) { + + if (mod(b + 1, rational::power_of_two(bv.get_bv_size(x))) == 0) + return; + auto& v = wval(x); + if (!sign) + v.add_range(-a, b - a + 1); + else + v.add_range(b - a + 1, -a); + } + else if (x == y) { + if (a == b) + return; + auto& v = wval(x); + if (!sign) + v.add_range(-a, -b); + else + v.add_range(-b, -a); + } + + } + + void sls_fixed::get_offset(expr* e, expr*& x, rational& offset) { + expr* s, * t; + x = e; + offset = 0; + if (bv.is_bv_add(e, s, t)) { + if (bv.is_numeral(s, offset)) + x = t; + else if (bv.is_numeral(t, offset)) + x = s; + } + else if (bv.is_numeral(e, offset)) + x = nullptr; + } + + sls_valuation& sls_fixed::wval(expr* e) { + return ev.wval(e); + } + + void sls_fixed::init_fixed_basic(app* e) { + if (bv.is_bv(e) && m.is_ite(e)) { + auto& val = wval(e); + auto& val_th = wval(e->get_arg(1)); + auto& val_el = wval(e->get_arg(2)); + for (unsigned i = 0; i < val.nw; ++i) + val.fixed[i] = val_el.fixed[i] & val_th.fixed[i] & ~(val_el.bits(i) ^ val_th.bits(i)); + } + } + + void sls_fixed::init_fixed_bv(app* e) { + if (bv.is_bv(e)) + set_fixed_bw(e); + } + + bool sls_fixed::is_fixed1(app* e) const { + if (is_uninterp(e)) + return false; + if (e->get_family_id() == basic_family_id) + return is_fixed1_basic(e); + return all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); }); + } + + bool sls_fixed::is_fixed1_basic(app* e) const { + switch (e->get_decl_kind()) { + case OP_TRUE: + case OP_FALSE: + return true; + case OP_AND: + return any_of(*e, [&](expr* arg) { return ev.is_fixed0(arg) && !ev.bval0(e); }); + case OP_OR: + return any_of(*e, [&](expr* arg) { return ev.is_fixed0(arg) && ev.bval0(e); }); + default: + return all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); }); + } + } + + void sls_fixed::set_fixed_bw(app* e) { + SASSERT(bv.is_bv(e)); + SASSERT(e->get_family_id() == bv.get_fid()); + auto& v = ev.wval(e); + if (all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); })) { + for (unsigned i = 0; i < v.bw; ++i) + v.fixed.set(i, true); + ev.m_fixed.setx(e->get_id(), true, false); + return; + } + switch (e->get_decl_kind()) { + case OP_BAND: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + // (a.fixed & b.fixed) | (a.fixed & ~a.bits) | (b.fixed & ~b.bits) + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & ~a.bits(i)) | (b.fixed[i] & ~b.bits(i)); + break; + } + case OP_BOR: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + // (a.fixed & b.fixed) | (a.fixed & a.bits) | (b.fixed & b.bits) + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & a.bits(i)) | (b.fixed[i] & b.bits(i)); + break; + } + case OP_BXOR: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = a.fixed[i] & b.fixed[i]; + break; + } + case OP_BNOT: { + auto& a = wval(e->get_arg(0)); + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = a.fixed[i]; + break; + } + case OP_BADD: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + rational r; + if (bv.is_numeral(e->get_arg(0), r) && b.has_range()) + v.add_range(r + b.lo(), r + b.hi()); + else if (bv.is_numeral(e->get_arg(1), r) && a.has_range()) + v.add_range(r + a.lo(), r + a.hi()); + bool pfixed = true; + for (unsigned i = 0; i < v.bw; ++i) { + if (pfixed && a.fixed.get(i) && b.fixed.get(i)) + v.fixed.set(i, true); + else if (!pfixed && a.fixed.get(i) && b.fixed.get(i) && + !a.get_bit(i) && !b.get_bit(i)) { + pfixed = true; + v.fixed.set(i, false); + } + else { + pfixed = false; + v.fixed.set(i, false); + } + } + + break; + } + case OP_BMUL: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + unsigned j = 0, k = 0, zj = 0, zk = 0, hzj = 0, hzk = 0; + // i'th bit depends on bits j + k = i + // if the first j, resp k bits are 0, the bits j + k are 0 + for (; j < v.bw; ++j) + if (!a.fixed.get(j)) + break; + for (; k < v.bw; ++k) + if (!b.fixed.get(k)) + break; + for (; zj < v.bw; ++zj) + if (!a.fixed.get(zj) || a.get_bit(zj)) + break; + for (; zk < v.bw; ++zk) + if (!b.fixed.get(zk) || b.get_bit(zk)) + break; + for (; hzj < v.bw; ++hzj) + if (!a.fixed.get(v.bw - hzj - 1) || a.get_bit(v.bw - hzj - 1)) + break; + for (; hzk < v.bw; ++hzk) + if (!b.fixed.get(v.bw - hzk - 1) || b.get_bit(v.bw - hzk - 1)) + break; + + + if (j > 0 && k > 0) { + for (unsigned i = 0; i < std::min(k, j); ++i) { + SASSERT(!v.get_bit(i)); + v.fixed.set(i, true); + } + } + // lower zj + jk bits are 0 + if (zk > 0 || zj > 0) { + for (unsigned i = 0; i < zk + zj; ++i) { + SASSERT(!v.get_bit(i)); + v.fixed.set(i, true); + } + } + // upper bits are 0, if enough high order bits of a, b are 0. + // TODO - buggy + if (false && hzj < v.bw && hzk < v.bw && hzj + hzk > v.bw) { + hzj = v.bw - hzj; + hzk = v.bw - hzk; + for (unsigned i = hzj + hzk - 1; i < v.bw; ++i) { + SASSERT(!v.get_bit(i)); + v.fixed.set(i, true); + } + } + break; + } + case OP_CONCAT: { + auto& a = wval(e->get_arg(0)); + auto& b = wval(e->get_arg(1)); + for (unsigned i = 0; i < b.bw; ++i) + v.fixed.set(i, b.fixed.get(i)); + for (unsigned i = 0; i < a.bw; ++i) + v.fixed.set(i + b.bw, a.fixed.get(i)); + break; + } + case OP_EXTRACT: { + expr* child; + unsigned lo, hi; + VERIFY(bv.is_extract(e, lo, hi, child)); + auto& a = wval(child); + for (unsigned i = lo; i <= hi; ++i) + v.fixed.set(i - lo, a.fixed.get(i)); + break; + } + case OP_BNEG: { + auto& a = wval(e->get_arg(0)); + bool pfixed = true; + for (unsigned i = 0; i < v.bw; ++i) { + if (pfixed && a.fixed.get(i)) + v.fixed.set(i, true); + else { + pfixed = false; + v.fixed.set(i, false); + } + } + break; + } + case OP_BSHL: { + // determine range of b. + // if b = 0, then inherit fixed from a + // if b >= v.bw then make e fixed to 0 + // if 0 < b < v.bw is known, then inherit shift of fixed values of a + // if 0 < b < v.bw but not known, then inherit run lengths of equal bits of a + // that are fixed. + break; + } + + case OP_BASHR: + case OP_BLSHR: + case OP_INT2BV: + case OP_BCOMP: + case OP_BNAND: + case OP_BREDAND: + case OP_BREDOR: + case OP_BSDIV: + case OP_BSDIV_I: + case OP_BSDIV0: + case OP_BUDIV: + case OP_BUDIV_I: + case OP_BUDIV0: + case OP_BUREM: + case OP_BUREM_I: + case OP_BUREM0: + case OP_BSMOD: + case OP_BSMOD_I: + case OP_BSMOD0: + case OP_BXNOR: + // NOT_IMPLEMENTED_YET(); + break; + case OP_BV_NUM: + case OP_BIT0: + case OP_BIT1: + case OP_BV2INT: + case OP_BNEG_OVFL: + case OP_BSADD_OVFL: + case OP_BUADD_OVFL: + case OP_BSDIV_OVFL: + case OP_BSMUL_NO_OVFL: + case OP_BSMUL_NO_UDFL: + case OP_BSMUL_OVFL: + case OP_BUMUL_NO_OVFL: + case OP_BUMUL_OVFL: + case OP_BIT2BOOL: + case OP_ULEQ: + case OP_UGEQ: + case OP_UGT: + case OP_ULT: + case OP_SLEQ: + case OP_SGEQ: + case OP_SGT: + case OP_SLT: + UNREACHABLE(); + break; + } + } +} diff --git a/src/ast/sls/bv_sls_fixed.h b/src/ast/sls/bv_sls_fixed.h new file mode 100644 index 000000000..14970c20c --- /dev/null +++ b/src/ast/sls/bv_sls_fixed.h @@ -0,0 +1,52 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls_fixed.h + +Abstract: + + Initialize fixed information. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "ast/ast.h" +#include "ast/sls/sls_valuation.h" +#include "ast/bv_decl_plugin.h" + +namespace bv { + + class sls_eval; + + class sls_fixed { + sls_eval& ev; + ast_manager& m; + bv_util& bv; + + void init_ranges(expr_ref_vector const& es); + void init_range(app* e, bool sign); + void init_range(expr* x, rational const& a, expr* y, rational const& b, bool sign); + void get_offset(expr* e, expr*& x, rational& offset); + + void init_fixed_basic(app* e); + void init_fixed_bv(app* e); + + bool is_fixed1(app* e) const; + bool is_fixed1_basic(app* e) const; + void set_fixed_bw(app* e); + + sls_valuation& wval(expr* e); + + public: + sls_fixed(sls_eval& ev); + + void init(expr_ref_vector const& es); + + }; +} diff --git a/src/ast/sls/bv_sls_terms.cpp b/src/ast/sls/bv_sls_terms.cpp new file mode 100644 index 000000000..8702c3c48 --- /dev/null +++ b/src/ast/sls/bv_sls_terms.cpp @@ -0,0 +1,212 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls.cpp + +Abstract: + + A Stochastic Local Search (SLS) engine + Uses invertibility conditions, + interval annotations + don't care annotations + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/ast_ll_pp.h" +#include "ast/sls/bv_sls.h" + +namespace bv { + + sls_terms::sls_terms(ast_manager& m): + m(m), + bv(m), + m_assertions(m), + m_pinned(m), + m_translated(m), + m_terms(m){} + + + void sls_terms::assert_expr(expr* e) { + m_assertions.push_back(ensure_binary(e)); + } + + expr* sls_terms::ensure_binary(expr* e) { + expr* top = e; + m_pinned.push_back(e); + m_todo.push_back(e); + expr_fast_mark1 mark; + for (unsigned i = 0; i < m_todo.size(); ++i) { + expr* e = m_todo[i]; + if (!is_app(e)) + continue; + if (m_translated.get(e->get_id(), nullptr)) + continue; + if (mark.is_marked(e)) + continue; + mark.mark(e); + for (auto arg : *to_app(e)) + m_todo.push_back(arg); + } + std::stable_sort(m_todo.begin(), m_todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); + for (expr* e : m_todo) + ensure_binary_core(e); + m_todo.reset(); + return m_translated.get(top->get_id()); + } + + void sls_terms::ensure_binary_core(expr* e) { + if (m_translated.get(e->get_id(), nullptr)) + return; + + app* a = to_app(e); + auto arg = [&](unsigned i) { + return m_translated.get(a->get_arg(i)->get_id()); + }; + unsigned num_args = a->get_num_args(); + expr_ref r(m); +#define FOLD_OP(oper) \ + r = arg(0); \ + for (unsigned i = 1; i < num_args; ++i)\ + r = oper(r, arg(i)); \ + + if (m.is_and(e)) { + FOLD_OP(m.mk_and); + } + else if (m.is_or(e)) { + FOLD_OP(m.mk_or); + } + else if (m.is_xor(e)) { + FOLD_OP(m.mk_xor); + } + else if (bv.is_bv_and(e)) { + FOLD_OP(bv.mk_bv_and); + } + else if (bv.is_bv_or(e)) { + FOLD_OP(bv.mk_bv_or); + } + else if (bv.is_bv_xor(e)) { + FOLD_OP(bv.mk_bv_xor); + } + else if (bv.is_bv_add(e)) { + FOLD_OP(bv.mk_bv_add); + } + else if (bv.is_bv_mul(e)) { + FOLD_OP(bv.mk_bv_mul); + } + else if (bv.is_concat(e)) { + FOLD_OP(bv.mk_concat); + } + else if (m.is_distinct(e)) { + expr_ref_vector es(m); + for (unsigned i = 0; i < num_args; ++i) + for (unsigned j = i + 1; j < num_args; ++j) + es.push_back(m.mk_not(m.mk_eq(arg(i), arg(j)))); + r = m.mk_and(es); + } + else if (bv.is_bv_sdiv(e) || bv.is_bv_sdiv0(e) || bv.is_bv_sdivi(e)) { + r = mk_sdiv(arg(0), arg(1)); + } + else if (bv.is_bv_smod(e) || bv.is_bv_smod0(e) || bv.is_bv_smodi(e)) { + r = mk_smod(arg(0), arg(1)); + } + else if (bv.is_bv_srem(e) || bv.is_bv_srem0(e) || bv.is_bv_sremi(e)) { + r = mk_srem(arg(0), arg(1)); + } + else { + for (unsigned i = 0; i < num_args; ++i) + m_args.push_back(arg(i)); + r = m.mk_app(a->get_decl(), num_args, m_args.data()); + m_args.reset(); + } + m_translated.setx(e->get_id(), r); + } + + expr* sls_terms::mk_sdiv(expr* x, expr* y) { + // d = udiv(abs(x), abs(y)) + // y = 0, x >= 0 -> -1 + // y = 0, x < 0 -> 1 + // x = 0, y != 0 -> 0 + // x > 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + unsigned sz = bv.get_bv_size(x); + rational N = rational::power_of_two(sz); + expr_ref z(bv.mk_zero(sz), m); + expr* signx = bv.mk_ule(bv.mk_numeral(N / 2, sz), x); + expr* signy = bv.mk_ule(bv.mk_numeral(N / 2, sz), y); + expr* absx = m.mk_ite(signx, bv.mk_bv_sub(bv.mk_numeral(N - 1, sz), x), x); + expr* absy = m.mk_ite(signy, bv.mk_bv_sub(bv.mk_numeral(N - 1, sz), y), y); + expr* d = bv.mk_bv_udiv(absx, absy); + expr* r = m.mk_ite(m.mk_eq(signx, signy), d, bv.mk_bv_neg(d)); + r = m.mk_ite(m.mk_eq(z, y), + m.mk_ite(signx, bv.mk_one(sz), bv.mk_numeral(N - 1, sz)), + m.mk_ite(m.mk_eq(x, z), z, r)); + return r; + } + + expr* sls_terms::mk_smod(expr* x, expr* y) { + // u := umod(abs(x), abs(y)) + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u + unsigned sz = bv.get_bv_size(x); + expr_ref z(bv.mk_zero(sz), m); + expr_ref abs_x(m.mk_ite(bv.mk_sle(z, x), x, bv.mk_bv_neg(x)), m); + expr_ref abs_y(m.mk_ite(bv.mk_sle(z, y), y, bv.mk_bv_neg(y)), m); + expr_ref u(bv.mk_bv_urem(abs_x, abs_y), m); + return + m.mk_ite(m.mk_eq(u, z), z, + m.mk_ite(m.mk_eq(y, z), x, + m.mk_ite(m.mk_and(bv.mk_sle(z, x), bv.mk_sle(z, x)), u, + m.mk_ite(bv.mk_sle(z, x), bv.mk_bv_add(y, u), + m.mk_ite(bv.mk_sle(z, y), bv.mk_bv_sub(y, u), bv.mk_bv_neg(u)))))); + + } + + expr* sls_terms::mk_srem(expr* x, expr* y) { + // y = 0 -> x + // else x - sdiv(x, y) * y + return + m.mk_ite(m.mk_eq(y, bv.mk_zero(bv.get_bv_size(x))), + x, bv.mk_bv_sub(x, bv.mk_bv_mul(y, mk_sdiv(x, y)))); + } + + + void sls_terms::init() { + // populate terms + expr_fast_mark1 mark; + for (expr* e : m_assertions) + m_todo.push_back(e); + while (!m_todo.empty()) { + expr* e = m_todo.back(); + m_todo.pop_back(); + if (mark.is_marked(e) || !is_app(e)) + continue; + mark.mark(e); + m_terms.setx(e->get_id(), to_app(e)); + for (expr* arg : *to_app(e)) + m_todo.push_back(arg); + } + // populate parents + m_parents.reserve(m_terms.size()); + for (expr* e : m_terms) { + if (!e || !is_app(e)) + continue; + for (expr* arg : *to_app(e)) + m_parents[arg->get_id()].push_back(e); + } + for (auto a : m_assertions) + m_assertion_set.insert(a->get_id()); + } + +} diff --git a/src/ast/sls/bv_sls_terms.h b/src/ast/sls/bv_sls_terms.h new file mode 100644 index 000000000..3baffc78e --- /dev/null +++ b/src/ast/sls/bv_sls_terms.h @@ -0,0 +1,75 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls_terms.h + +Abstract: + + A Stochastic Local Search (SLS) engine + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "util/lbool.h" +#include "util/params.h" +#include "util/scoped_ptr_vector.h" +#include "util/uint_set.h" +#include "ast/ast.h" +#include "ast/sls/sls_stats.h" +#include "ast/sls/sls_powers.h" +#include "ast/sls/sls_valuation.h" +#include "ast/bv_decl_plugin.h" + +namespace bv { + + class sls_terms { + ast_manager& m; + bv_util bv; + ptr_vector m_todo, m_args; + expr_ref_vector m_assertions, m_pinned, m_translated; + app_ref_vector m_terms; + vector> m_parents; + tracked_uint_set m_assertion_set; + + expr* ensure_binary(expr* e); + void ensure_binary_core(expr* e); + + expr* mk_sdiv(expr* x, expr* y); + expr* mk_smod(expr* x, expr* y); + expr* mk_srem(expr* x, expr* y); + + public: + sls_terms(ast_manager& m); + + /** + * Add constraints + */ + void assert_expr(expr* e); + + /** + * Initialize structures: assertions, parents, terms + */ + void init(); + + /** + * Accessors. + */ + + ptr_vector const& parents(expr* e) const { return m_parents[e->get_id()]; } + + expr_ref_vector const& assertions() const { return m_assertions; } + + app* term(unsigned id) const { return m_terms.get(id); } + + app_ref_vector const& terms() const { return m_terms; } + + bool is_assertion(expr* e) const { return m_assertion_set.contains(e->get_id()); } + + }; +} diff --git a/src/ast/sls/sls_engine.cpp b/src/ast/sls/sls_engine.cpp index 8bf70f3dd..249c771ed 100644 --- a/src/ast/sls/sls_engine.cpp +++ b/src/ast/sls/sls_engine.cpp @@ -76,19 +76,6 @@ void sls_engine::updt_params(params_ref const & _p) { NOT_IMPLEMENTED_YET(); } -void sls_engine::collect_statistics(statistics& st) const { - double seconds = m_stats.m_stopwatch.get_current_seconds(); - st.update("sls restarts", m_stats.m_restarts); - st.update("sls full evals", m_stats.m_full_evals); - st.update("sls incr evals", m_stats.m_incr_evals); - st.update("sls incr evals/sec", m_stats.m_incr_evals / seconds); - st.update("sls FLIP moves", m_stats.m_flips); - st.update("sls INC moves", m_stats.m_incs); - st.update("sls DEC moves", m_stats.m_decs); - st.update("sls INV moves", m_stats.m_invs); - st.update("sls moves", m_stats.m_moves); - st.update("sls moves/sec", m_stats.m_moves / seconds); -} bool sls_engine::full_eval(model & mdl) { diff --git a/src/ast/sls/sls_engine.h b/src/ast/sls/sls_engine.h index 32338b8ae..614534f1a 100644 --- a/src/ast/sls/sls_engine.h +++ b/src/ast/sls/sls_engine.h @@ -22,42 +22,15 @@ Notes: #include "util/lbool.h" #include "ast/converters/model_converter.h" +#include "ast/sls/sls_stats.h" #include "ast/sls/sls_tracker.h" #include "ast/sls/sls_evaluator.h" -#include "util/statistics.h" class sls_engine { -public: - class stats { - public: - unsigned m_restarts; - stopwatch m_stopwatch; - unsigned m_full_evals; - unsigned m_incr_evals; - unsigned m_moves, m_flips, m_incs, m_decs, m_invs; - - stats() : - m_restarts(0), - m_full_evals(0), - m_incr_evals(0), - m_moves(0), - m_flips(0), - m_incs(0), - m_decs(0), - m_invs(0) { - m_stopwatch.reset(); - m_stopwatch.start(); - } - void reset() { - m_full_evals = m_flips = m_incr_evals = 0; - m_stopwatch.reset(); - m_stopwatch.start(); - } - }; protected: ast_manager & m_manager; - stats m_stats; + bv::sls_stats m_stats; unsynch_mpz_manager m_mpz_manager; powers m_powers; mpz m_zero, m_one, m_two; @@ -94,8 +67,8 @@ public: void assert_expr(expr * e) { m_assertions.push_back(e); } - stats const & get_stats(void) { return m_stats; } - void collect_statistics(statistics & st) const; + bv::sls_stats const & get_stats(void) { return m_stats; } + void collect_statistics(statistics & st) const { m_stats.collect_statistics(st); } void reset_statistics() { m_stats.reset(); } bool full_eval(model & mdl); diff --git a/src/ast/sls/sls_powers.h b/src/ast/sls/sls_powers.h index 9616c43ab..80ccbe04f 100644 --- a/src/ast/sls/sls_powers.h +++ b/src/ast/sls/sls_powers.h @@ -20,6 +20,7 @@ Notes: #pragma once #include "util/mpz.h" +#include "util/map.h" class powers : public u_map { unsynch_mpz_manager & m; diff --git a/src/ast/sls/sls_stats.h b/src/ast/sls/sls_stats.h new file mode 100644 index 000000000..9468e9c8d --- /dev/null +++ b/src/ast/sls/sls_stats.h @@ -0,0 +1,51 @@ +#pragma once +#include "util/statistics.h" +#include "util/stopwatch.h" + + +namespace bv { + class sls_stats { + public: + unsigned m_restarts; + stopwatch m_stopwatch; + unsigned m_full_evals; + unsigned m_incr_evals; + unsigned m_moves, m_flips, m_incs, m_decs, m_invs; + + sls_stats() : + m_restarts(0), + m_full_evals(0), + m_incr_evals(0), + m_moves(0), + m_flips(0), + m_incs(0), + m_decs(0), + m_invs(0) { + m_stopwatch.reset(); + m_stopwatch.start(); + } + void reset() { + m_full_evals = m_flips = m_incr_evals = 0; + m_stopwatch.reset(); + m_stopwatch.start(); + } + + void collect_statistics(statistics& st) const { + double seconds = m_stopwatch.get_current_seconds(); + st.update("sls restarts", m_restarts); + st.update("sls full evals", m_full_evals); + st.update("sls incr evals", m_incr_evals); + if (seconds > 0 && m_incr_evals > 0) + st.update("sls incr evals/sec", m_incr_evals / seconds); + if (seconds > 0 && m_moves > 0) + st.update("sls moves/sec", m_moves / seconds); + st.update("sls FLIP moves", m_flips); + st.update("sls INC moves", m_incs); + st.update("sls DEC moves", m_decs); + st.update("sls INV moves", m_invs); + st.update("sls moves", m_moves); + + } + + }; +} diff --git a/src/ast/sls/sls_valuation.cpp b/src/ast/sls/sls_valuation.cpp new file mode 100644 index 000000000..3160e5cf5 --- /dev/null +++ b/src/ast/sls/sls_valuation.cpp @@ -0,0 +1,653 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_valuation.cpp + +Abstract: + + A Stochastic Local Search (SLS) engine + Uses invertibility conditions, + interval annotations + don't care annotations + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/sls/sls_valuation.h" + +namespace bv { + + void bvect::set_bw(unsigned bw) { + this->bw = bw; + nw = (bw + sizeof(digit_t) * 8 - 1) / (8 * sizeof(digit_t)); + mask = (1 << (bw % (8 * sizeof(digit_t)))) - 1; + if (mask == 0) + mask = ~(digit_t)0; + reserve(nw + 1); + } + + bool operator==(bvect const& a, bvect const& b) { + SASSERT(a.nw > 0); + return 0 == mpn_manager().compare(a.data(), a.nw, b.data(), a.nw); + } + + bool operator<(bvect const& a, bvect const& b) { + SASSERT(a.nw > 0); + return mpn_manager().compare(a.data(), a.nw, b.data(), a.nw) < 0; + } + + bool operator>(bvect const& a, bvect const& b) { + SASSERT(a.nw > 0); + return mpn_manager().compare(a.data(), a.nw, b.data(), a.nw) > 0; + } + + bool operator<=(bvect const& a, bvect const& b) { + SASSERT(a.nw > 0); + return mpn_manager().compare(a.data(), a.nw, b.data(), a.nw) <= 0; + } + + bool operator>=(bvect const& a, bvect const& b) { + SASSERT(a.nw > 0); + return mpn_manager().compare(a.data(), a.nw, b.data(), a.nw) >= 0; + } + + std::ostream& operator<<(std::ostream& out, bvect const& v) { + out << std::hex; + bool nz = false; + for (unsigned i = v.nw; i-- > 0;) { + auto w = v[i]; + if (i + 1 == v.nw) + w &= v.mask; + if (nz) + out << std::setw(8) << std::setfill('0') << w; + else if (w != 0) + out << w, nz = true; + } + if (!nz) + out << "0"; + out << std::dec; + return out; + } + + rational bvect::get_value(unsigned nw) const { + rational p(1), r(0); + for (unsigned i = 0; i < nw; ++i) { + r += p * rational((*this)[i]); + p *= rational::power_of_two(8 * sizeof(digit_t)); + } + return r; + } + + sls_valuation::sls_valuation(unsigned bw) { + set_bw(bw); + m_lo.set_bw(bw); + m_hi.set_bw(bw); + m_bits.set_bw(bw); + fixed.set_bw(bw); + eval.set_bw(bw); + // have lo, hi bits, fixed point to memory allocated within this of size num_bytes each allocated + for (unsigned i = 0; i < nw; ++i) + m_lo[i] = 0, m_hi[i] = 0, m_bits[i] = 0, fixed[i] = 0, eval[i] = 0; + fixed[nw - 1] = ~mask; + } + + void sls_valuation::set_bw(unsigned b) { + bw = b; + nw = (bw + sizeof(digit_t) * 8 - 1) / (8 * sizeof(digit_t)); + mask = (1 << (bw % (8 * sizeof(digit_t)))) - 1; + if (mask == 0) + mask = ~(digit_t)0; + } + + bool sls_valuation::commit_eval() { + for (unsigned i = 0; i < nw; ++i) + if (0 != (fixed[i] & (m_bits[i] ^ eval[i]))) + return false; + if (!in_range(eval)) + return false; + for (unsigned i = 0; i < nw; ++i) + m_bits[i] = eval[i]; + SASSERT(well_formed()); + return true; + } + + bool sls_valuation::in_range(bvect const& bits) const { + mpn_manager m; + auto c = m.compare(m_lo.data(), nw, m_hi.data(), nw); + SASSERT(!has_overflow(bits)); + // full range + + if (c == 0) + return true; + // lo < hi: then lo <= bits & bits < hi + if (c < 0) + return + m.compare(m_lo.data(), nw, bits.data(), nw) <= 0 && + m.compare(bits.data(), nw, m_hi.data(), nw) < 0; + // hi < lo: bits < hi or lo <= bits + return + m.compare(m_lo.data(), nw, bits.data(), nw) <= 0 || + m.compare(bits.data(), nw, m_hi.data(), nw) < 0; + } + + // + // largest dst <= src and dst is feasible + // set dst := src & (~fixed | bits) + // + // increment dst if dst < src by setting bits below msb(src & ~dst) to 1 + // + // if dst < lo < hi: + // return false + // if lo < hi <= dst: + // set dst := hi - 1 + // if hi <= dst < lo + // set dst := hi - 1 + // + + bool sls_valuation::get_at_most(bvect const& src, bvect& dst) const { + SASSERT(!has_overflow(src)); + for (unsigned i = 0; i < nw; ++i) + dst[i] = src[i] & (~fixed[i] | m_bits[i]); + + // + // If dst < src, then find the most significant + // bit where src[idx] = 1, dst[idx] = 0 + // set dst[j] = bits_j | ~fixed_j for j < idx + // + for (unsigned i = nw; i-- > 0; ) { + if (0 != (~dst[i] & src[i])) { + auto idx = log2(~dst[i] & src[i]); + auto mask = (1 << idx) - 1; + dst[i] = (~fixed[i] & mask) | dst[i]; + for (unsigned j = i; j-- > 0; ) + dst[j] = (~fixed[j] | m_bits[j]); + break; + } + } + SASSERT(!has_overflow(dst)); + return round_down(dst); + } + + // + // smallest dst >= src and dst is feasible with respect to this. + // set dst := (src & ~fixed) | (fixed & bits) + // + // decrement dst if dst > src by setting bits below msb to 0 unless fixed + // + // if lo < hi <= dst + // return false + // if dst < lo < hi: + // set dst := lo + // if hi <= dst < lo + // set dst := lo + // + bool sls_valuation::get_at_least(bvect const& src, bvect& dst) const { + SASSERT(!has_overflow(src)); + for (unsigned i = 0; i < nw; ++i) + dst[i] = (~fixed[i] & src[i]) | (fixed[i] & m_bits[i]); + + // + // If dst > src, then find the most significant + // bit where src[idx] = 0, dst[idx] = 1 + // set dst[j] = dst[j] & fixed_j for j < idx + // + for (unsigned i = nw; i-- > 0; ) { + if (0 != (dst[i] & ~src[i])) { + auto idx = log2(dst[i] & ~src[i]); + auto mask = (1 << idx); + dst[i] = dst[i] & (fixed[i] | mask); + for (unsigned j = i; j-- > 0; ) + dst[j] = dst[j] & fixed[j]; + break; + } + } + SASSERT(!has_overflow(dst)); + return round_up(dst); + } + + bool sls_valuation::round_up(bvect& dst) const { + if (m_lo < m_hi) { + if (m_hi <= dst) + return false; + if (m_lo > dst) + set(dst, m_lo); + } + else if (m_hi <= dst && m_lo > dst) + set(dst, m_lo); + SASSERT(!has_overflow(dst)); + return true; + } + + bool sls_valuation::round_down(bvect& dst) const { + if (m_lo < m_hi) { + if (m_lo > dst) + return false; + if (m_hi <= dst) { + set(dst, m_hi); + sub1(dst); + } + } + else if (m_hi <= dst && m_lo > dst) { + set(dst, m_hi); + sub1(dst); + } + SASSERT(well_formed()); + return true; + } + + bool sls_valuation::set_random_at_most(bvect const& src, bvect& tmp, random_gen& r) { + if (!get_at_most(src, tmp)) + return false; + if (is_zero(tmp) || (0 == r() % 2)) + return try_set(tmp); + + set_random_below(tmp, r); + // random value below tmp + + if (m_lo == m_hi || is_zero(m_lo) || m_lo <= tmp) + return try_set(tmp); + + // for simplicity, bail out if we were not lucky + return get_at_most(src, tmp) && try_set(tmp); + } + + bool sls_valuation::set_random_at_least(bvect const& src, bvect& tmp, random_gen& r) { + if (!get_at_least(src, tmp)) + return false; + if (is_ones(tmp) || (0 == r() % 2)) + return try_set(tmp); + + // random value at least tmp + set_random_above(tmp, r); + + if (m_lo == m_hi || is_zero(m_hi) || m_hi > tmp) + return try_set(tmp); + + // for simplicity, bail out if we were not lucky + return get_at_least(src, tmp) && try_set(tmp); + } + + bool sls_valuation::set_random_in_range(bvect const& lo, bvect const& hi, bvect& tmp, random_gen& r) { + if (0 == r() % 2) { + if (!get_at_least(lo, tmp)) + return false; + SASSERT(in_range(tmp)); + if (hi < tmp) + return false; + + if (is_ones(tmp) || (0 == r() % 2)) + return try_set(tmp); + set_random_above(tmp, r); + round_down(tmp, [&](bvect const& t) { return hi >= t && in_range(t); }); + if (in_range(tmp) && lo <= tmp && hi >= tmp) + return try_set(tmp); + return get_at_least(lo, tmp) && hi >= tmp && try_set(tmp); + } + else { + if (!get_at_most(hi, tmp)) + return false; + SASSERT(in_range(tmp)); + if (lo > tmp) + return false; + if (is_zero(tmp) || (0 == r() % 2)) + return try_set(tmp); + set_random_below(tmp, r); + round_up(tmp, [&](bvect const& t) { return lo <= t && in_range(t); }); + if (in_range(tmp) && lo <= tmp && hi >= tmp) + return try_set(tmp); + return get_at_most(hi, tmp) && lo <= tmp && try_set(tmp); + } + } + + void sls_valuation::round_down(bvect& dst, std::function const& is_feasible) { + for (unsigned i = bw; !is_feasible(dst) && i-- > 0; ) + if (!fixed.get(i) && dst.get(i)) + dst.set(i, false); + repair_sign_bits(dst); + } + + void sls_valuation::round_up(bvect& dst, std::function const& is_feasible) { + for (unsigned i = 0; !is_feasible(dst) && i < bw; ++i) + if (!fixed.get(i) && !dst.get(i)) + dst.set(i, true); + repair_sign_bits(dst); + } + + void sls_valuation::set_random_above(bvect& dst, random_gen& r) { + for (unsigned i = 0; i < nw; ++i) + dst[i] = dst[i] | (random_bits(r) & ~fixed[i]); + repair_sign_bits(dst); + } + + void sls_valuation::set_random_below(bvect& dst, random_gen& r) { + if (is_zero(dst)) + return; + unsigned n = 0, idx = UINT_MAX; + for (unsigned i = 0; i < bw; ++i) + if (dst.get(i) && !fixed.get(i) && (r() % ++n) == 0) + idx = i; + + if (idx == UINT_MAX) + return; + dst.set(idx, false); + for (unsigned i = 0; i < idx; ++i) + if (!fixed.get(i)) + dst.set(i, r() % 2 == 0); + repair_sign_bits(dst); + } + + bool sls_valuation::set_repair(bool try_down, bvect& dst) { + for (unsigned i = 0; i < nw; ++i) + dst[i] = (~fixed[i] & dst[i]) | (fixed[i] & m_bits[i]); + + repair_sign_bits(dst); + if (in_range(dst)) { + set(eval, dst); + return true; + } + bool repaired = false; + dst.set_bw(bw); + if (m_lo < m_hi) { + for (unsigned i = bw; m_hi <= dst && !in_range(dst) && i-- > 0; ) + if (!fixed.get(i) && dst.get(i)) + dst.set(i, false); + for (unsigned i = 0; i < bw && dst < m_lo && !in_range(dst); ++i) + if (!fixed.get(i) && !dst.get(i)) + dst.set(i, true); + } + else { + for (unsigned i = 0; !in_range(dst) && i < bw; ++i) + if (!fixed.get(i) && !dst.get(i)) + dst.set(i, true); + for (unsigned i = bw; !in_range(dst) && i-- > 0;) + if (!fixed.get(i) && dst.get(i)) + dst.set(i, false); + } + repair_sign_bits(dst); + if (in_range(dst)) { + set(eval, dst); + repaired = true; + } + dst.set_bw(0); + return repaired; + } + + void sls_valuation::min_feasible(bvect& out) const { + if (m_lo < m_hi) + m_lo.copy_to(nw, out); + else { + for (unsigned i = 0; i < nw; ++i) + out[i] = fixed[i] & m_bits[i]; + } + repair_sign_bits(out); + SASSERT(!has_overflow(out)); + } + + void sls_valuation::max_feasible(bvect& out) const { + if (m_lo < m_hi) { + m_hi.copy_to(nw, out); + sub1(out); + } + else { + for (unsigned i = 0; i < nw; ++i) + out[i] = ~fixed[i] | m_bits[i]; + } + repair_sign_bits(out); + SASSERT(!has_overflow(out)); + } + + unsigned sls_valuation::msb(bvect const& src) const { + SASSERT(!has_overflow(src)); + for (unsigned i = nw; i-- > 0; ) + if (src[i] != 0) + return i * 8 * sizeof(digit_t) + log2(src[i]); + return bw; + } + + void sls_valuation::set_value(bvect& bits, rational const& n) { + for (unsigned i = 0; i < bw; ++i) + bits.set(i, n.get_bit(i)); + clear_overflow_bits(bits); + } + + void sls_valuation::get(bvect& dst) const { + m_bits.copy_to(nw, dst); + } + + digit_t sls_valuation::random_bits(random_gen& rand) { + digit_t r = 0; + for (digit_t i = 0; i < sizeof(digit_t); ++i) + r ^= rand() << (8 * i); + return r; + } + + void sls_valuation::get_variant(bvect& dst, random_gen& r) const { + for (unsigned i = 0; i < nw; ++i) + dst[i] = (random_bits(r) & ~fixed[i]) | (fixed[i] & m_bits[i]); + repair_sign_bits(dst); + clear_overflow_bits(dst); + } + + void sls_valuation::repair_sign_bits(bvect& dst) const { + if (m_signed_prefix == 0) + return; + bool sign = dst.get(bw - 1); + for (unsigned i = bw; i-- >= bw - m_signed_prefix; ) { + if (dst.get(i) != sign) { + if (fixed.get(i)) { + for (unsigned i = bw; i-- >= bw - m_signed_prefix; ) + if (!fixed.get(i)) + dst.set(i, !sign); + return; + } + else + dst.set(i, sign); + } + } + } + + // + // new_bits != bits => ~fixed + // 0 = (new_bits ^ bits) & fixed + // also check that new_bits are in range + // + bool sls_valuation::can_set(bvect const& new_bits) const { + SASSERT(!has_overflow(new_bits)); + for (unsigned i = 0; i < nw; ++i) + if (0 != ((new_bits[i] ^ m_bits[i]) & fixed[i])) + return false; + return in_range(new_bits); + } + + unsigned sls_valuation::to_nat(unsigned max_n) { + bvect const& d = m_bits; + SASSERT(!has_overflow(d)); + SASSERT(max_n < UINT_MAX / 2); + unsigned p = 1; + unsigned value = 0; + for (unsigned i = 0; i < bw; ++i) { + if (p >= max_n) { + for (unsigned j = i; j < bw; ++j) + if (d.get(j)) + return max_n; + return value; + } + if (d.get(i)) + value += p; + p <<= 1; + } + return value; + } + + void sls_valuation::shift_right(bvect& out, unsigned shift) const { + SASSERT(shift < bw); + for (unsigned i = 0; i < bw; ++i) + out.set(i, i + shift < bw ? m_bits.get(i + shift) : false); + SASSERT(well_formed()); + } + + void sls_valuation::add_range(rational l, rational h) { + + l = mod(l, rational::power_of_two(bw)); + h = mod(h, rational::power_of_two(bw)); + if (h == l) + return; + + //verbose_stream() << "[" << l << ", " << h << "[\n"; + //verbose_stream() << *this << "\n"; + + if (m_lo == m_hi) { + set_value(m_lo, l); + set_value(m_hi, h); + } + else { + auto old_lo = lo(); + auto old_hi = hi(); + if (old_lo < old_hi) { + if (old_lo < l && l < old_hi) + set_value(m_lo, l), + old_lo = l; + if (old_hi < h && h < old_hi) + set_value(m_hi, h); + } + else { + SASSERT(old_hi < old_lo); + if (old_lo < l || l < old_hi) + set_value(m_lo, l), + old_lo = l; + if (old_lo < h && h < old_hi) + set_value(m_hi, h); + else if (old_hi < old_lo && (h < old_hi || old_lo < h)) + set_value(m_hi, h); + } + } + + + + SASSERT(!has_overflow(m_lo)); + SASSERT(!has_overflow(m_hi)); + + tighten_range(); + SASSERT(well_formed()); + // verbose_stream() << *this << "\n"; + } + + // + // update bits based on ranges + // tighten lo/hi based on fixed bits. + // lo[bit_i] != fixedbit[bit_i] + // let bit_i be most significant bit position of disagreement. + // if fixedbit = 1, lo = 0, increment lo + // if fixedbit = 0, lo = 1, lo := fixed & bits + // (hi-1)[bit_i] != fixedbit[bit_i] + // if fixedbit = 0, hi-1 = 1, set hi-1 := 0, maximize below bit_i + // if fixedbit = 1, hi-1 = 0, hi := fixed & bits + // tighten fixed bits based on lo/hi + // lo + 1 = hi -> set bits = lo + // lo < hi, set most significant bits based on hi + // + void sls_valuation::tighten_range() { + + // verbose_stream() << "tighten " << *this << "\n"; + if (m_lo == m_hi) + return; + + if (!in_range(m_bits)) { + // verbose_stream() << "not in range\n"; + bool compatible = true; + for (unsigned i = 0; i < nw && compatible; ++i) + compatible = 0 == (fixed[i] & (m_bits[i] ^ m_lo[i])); + //verbose_stream() << (fixed[0] & (m_bits[0] ^ m_lo[0])) << "\n"; + //verbose_stream() << bw << " " << m_lo[0] << " " << m_bits[0] << "\n"; + if (compatible) { + //verbose_stream() << "compatible\n"; + set(m_bits, m_lo); + } + else { + bvect tmp(m_bits.nw); + tmp.set_bw(bw); + set(tmp, m_lo); + unsigned max_diff = bw; + for (unsigned i = 0; i < bw; ++i) { + if (fixed.get(i) && (m_bits.get(i) ^ m_lo.get(i))) + max_diff = i; + } + SASSERT(max_diff != bw); + + for (unsigned i = 0; i <= max_diff; ++i) + tmp.set(i, fixed.get(i) && m_bits.get(i)); + + bool found0 = false; + for (unsigned i = max_diff + 1; i < bw; ++i) { + if (found0 || m_lo.get(i) || fixed.get(i)) + tmp.set(i, m_lo.get(i) && fixed.get(i)); + else { + tmp.set(i, true); + found0 = true; + } + } + set(m_bits, tmp); + } + } + // update lo, hi to be feasible. + + for (unsigned i = bw; i-- > 0; ) { + if (!fixed.get(i)) + continue; + if (m_bits.get(i) == m_lo.get(i)) + continue; + if (m_bits.get(i)) { + m_lo.set(i, true); + for (unsigned j = i; j-- > 0; ) + m_lo.set(j, fixed.get(j) && m_bits.get(j)); + } + else { + for (unsigned j = bw; j-- > 0; ) + m_lo.set(j, fixed.get(j) && m_bits.get(j)); + } + break; + } + + SASSERT(well_formed()); + } + + void sls_valuation::set_sub(bvect& out, bvect const& a, bvect const& b) const { + digit_t c; + mpn_manager().sub(a.data(), nw, b.data(), nw, out.data(), &c); + clear_overflow_bits(out); + } + + bool sls_valuation::set_add(bvect& out, bvect const& a, bvect const& b) const { + digit_t c; + mpn_manager().add(a.data(), nw, b.data(), nw, out.data(), nw + 1, &c); + bool ovfl = out[nw] != 0 || has_overflow(out); + clear_overflow_bits(out); + return ovfl; + } + + bool sls_valuation::set_mul(bvect& out, bvect const& a, bvect const& b, bool check_overflow) const { + mpn_manager().mul(a.data(), nw, b.data(), nw, out.data()); + bool ovfl = false; + if (check_overflow) { + ovfl = has_overflow(out); + for (unsigned i = nw; i < 2 * nw; ++i) + ovfl |= out[i] != 0; + } + clear_overflow_bits(out); + return ovfl; + } + + bool sls_valuation::is_power_of2(bvect const& src) const { + unsigned c = 0; + for (unsigned i = 0; i < nw; ++i) + c += get_num_1bits(src[i]); + return c == 1; + } + + +} diff --git a/src/ast/sls/sls_valuation.h b/src/ast/sls/sls_valuation.h new file mode 100644 index 000000000..dcabf04c0 --- /dev/null +++ b/src/ast/sls/sls_valuation.h @@ -0,0 +1,313 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_valuation.h + +Abstract: + + A Stochastic Local Search (SLS) engine + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "util/lbool.h" +#include "util/params.h" +#include "util/scoped_ptr_vector.h" +#include "util/uint_set.h" +#include "ast/ast.h" +#include "ast/sls/sls_stats.h" +#include "ast/sls/sls_powers.h" +#include "ast/bv_decl_plugin.h" + +namespace bv { + + class bvect : public svector { + public: + unsigned bw = 0; + unsigned nw = 0; + unsigned mask = 0; + + bvect() {} + bvect(unsigned sz) : svector(sz, (unsigned)0) {} + void set_bw(unsigned bw); + + void copy_to(unsigned nw, bvect & dst) const { + SASSERT(nw <= this->size()); + for (unsigned i = 0; i < nw; ++i) + dst[i] = (*this)[i]; + } + + void set(unsigned bit_idx, bool val) { + auto _val = static_cast(0 - static_cast(val)); + get_bit_word(bit_idx) ^= (_val ^ get_bit_word(bit_idx)) & get_pos_mask(bit_idx); + } + + bool get(unsigned bit_idx) const { + return (get_bit_word(bit_idx) & get_pos_mask(bit_idx)) != 0; + } + + unsigned parity() const { + SASSERT(bw > 0); + for (unsigned i = 0; i < nw; ++i) + if ((*this)[i] != 0) + return (8 * sizeof(digit_t) * i) + trailing_zeros((*this)[i]); + return bw; + } + + rational get_value(unsigned nw) const; + + friend bool operator==(bvect const& a, bvect const& b); + friend bool operator<(bvect const& a, bvect const& b); + friend bool operator>(bvect const& a, bvect const& b); + friend bool operator<=(bvect const& a, bvect const& b); + friend bool operator>=(bvect const& a, bvect const& b); + friend std::ostream& operator<<(std::ostream& out, bvect const& v); + + private: + + static digit_t get_pos_mask(unsigned bit_idx) { + return (digit_t)1 << (digit_t)(bit_idx % (8 * sizeof(digit_t))); + } + + digit_t get_bit_word(unsigned bit_idx) const { + return (*this)[bit_idx / (8 * sizeof(digit_t))]; + } + + digit_t& get_bit_word(unsigned bit_idx) { + return (*this)[bit_idx / (8 * sizeof(digit_t))]; + } + }; + + bool operator==(bvect const& a, bvect const& b); + bool operator<(bvect const& a, bvect const& b); + bool operator<=(bvect const& a, bvect const& b); + bool operator>=(bvect const& a, bvect const& b); + bool operator>(bvect const& a, bvect const& b); + inline bool operator!=(bvect const& a, bvect const& b) { return !(a == b); } + std::ostream& operator<<(std::ostream& out, bvect const& v); + + class sls_valuation { + protected: + bvect m_bits; + bvect m_lo, m_hi; // range assignment to bit-vector, as wrap-around interval + unsigned m_signed_prefix = 0; + + unsigned mask; + bool round_up(bvect& dst) const; + bool round_down(bvect& dst) const; + + void repair_sign_bits(bvect& dst) const; + + + public: + unsigned bw; // bit-width + unsigned nw; // num words + bvect fixed; // bit assignment and don't care bit + bvect eval; // current evaluation + + sls_valuation(unsigned bw); + + void set_bw(unsigned bw); + void set_signed(unsigned prefix) { m_signed_prefix = prefix; } + + unsigned num_bytes() const { return (bw + 7) / 8; } + + digit_t bits(unsigned i) const { return m_bits[i]; } + bvect const& bits() const { return m_bits; } + bool commit_eval(); + + bool get_bit(unsigned i) const { return m_bits.get(i); } + bool try_set_bit(unsigned i, bool b) { + SASSERT(in_range(m_bits)); + if (fixed.get(i) && get_bit(i) != b) + return false; + eval.set(i, b); + if (in_range(m_bits)) + return true; + eval.set(i, !b); + return false; + } + + void set_value(bvect& bits, rational const& r); + + rational get_value() const { return m_bits.get_value(nw); } + rational get_eval() const { return eval.get_value(nw); } + rational lo() const { return m_lo.get_value(nw); } + rational hi() const { return m_hi.get_value(nw); } + + + void get(bvect& dst) const; + void add_range(rational lo, rational hi); + bool has_range() const { return m_lo != m_hi; } + void tighten_range(); + + void clear_overflow_bits(bvect& bits) const { + SASSERT(nw > 0); + bits[nw - 1] &= mask; + SASSERT(!has_overflow(bits)); + } + + bool in_range(bvect const& bits) const; + bool can_set(bvect const& bits) const; + + bool eq(sls_valuation const& other) const { return eq(other.m_bits); } + bool eq(bvect const& other) const { return other == m_bits; } + + bool is_zero() const { return is_zero(m_bits); } + bool is_zero(bvect const& a) const { + for (unsigned i = 0; i < nw - 1; ++i) + if (a[i] != 0) + return false; + return (a[nw - 1] & mask) == 0; + } + + bool is_ones() const { return is_ones(m_bits); } + + bool is_ones(bvect const& a) const { + SASSERT(!has_overflow(a)); + for (unsigned i = 0; i + 1 < nw; ++i) + if (0 != ~a[i]) + return false; + return 0 == (mask & ~a[nw - 1]); + } + + bool is_one() const { return is_one(m_bits); } + bool is_one(bvect const& a) const { + SASSERT(!has_overflow(a)); + for (unsigned i = 1; i < nw; ++i) + if (0 != a[i]) + return false; + return 1 == a[0]; + } + + bool sign() const { return m_bits.get(bw - 1); } + + bool has_overflow(bvect const& bits) const { return 0 != (bits[nw - 1] & ~mask); } + + unsigned parity(bvect const& bits) const { return bits.parity(); } + + void min_feasible(bvect& out) const; + void max_feasible(bvect& out) const; + + // most significant bit or bw if src = 0 + unsigned msb(bvect const& src) const; + + bool is_power_of2(bvect const& src) const; + + // retrieve largest number at or below (above) src which is feasible + // with respect to fixed, lo, hi. + bool get_at_most(bvect const& src, bvect& dst) const; + bool get_at_least(bvect const& src, bvect& dst) const; + + bool set_random_at_most(bvect const& src, bvect& tmp, random_gen& r); + bool set_random_at_least(bvect const& src, bvect& tmp, random_gen& r); + bool set_random_in_range(bvect const& lo, bvect const& hi, bvect& tmp, random_gen& r); + + bool set_repair(bool try_down, bvect& dst); + void set_random_above(bvect& dst, random_gen& r); + void set_random_below(bvect& dst, random_gen& r); + void round_down(bvect& dst, std::function const& is_feasible); + void round_up(bvect& dst, std::function const& is_feasible); + + + static digit_t random_bits(random_gen& r); + void get_variant(bvect& dst, random_gen& r) const; + + bool try_set(bvect const& src) { + if (!can_set(src)) + return false; + set(src); + return true; + } + + void set(bvect const& src) { + for (unsigned i = nw; i-- > 0; ) + eval[i] = src[i]; + clear_overflow_bits(eval); + } + + void set_zero(bvect& out) const { + for (unsigned i = 0; i < nw; ++i) + out[i] = 0; + } + + void set_one(bvect& out) const { + for (unsigned i = 1; i < nw; ++i) + out[i] = 0; + out[0] = 1; + } + + void set_zero() { + set_zero(eval); + } + + void sub1(bvect& out) const { + for (unsigned i = 0; i < bw; ++i) { + if (out.get(i)) { + out.set(i, false); + return; + } + else + out.set(i, true); + } + } + + void set_sub(bvect& out, bvect const& a, bvect const& b) const; + bool set_add(bvect& out, bvect const& a, bvect const& b) const; + bool set_mul(bvect& out, bvect const& a, bvect const& b, bool check_overflow = true) const; + void shift_right(bvect& out, unsigned shift) const; + + void set_range(bvect& dst, unsigned lo, unsigned hi, bool b) { + for (unsigned i = lo; i < hi; ++i) + dst.set(i, b); + } + + bool try_set_range(bvect& dst, unsigned lo, unsigned hi, bool b) { + for (unsigned i = lo; i < hi; ++i) + if (fixed.get(i) && get_bit(i) != b) + return false; + for (unsigned i = lo; i < hi; ++i) + dst.set(i, b); + return true; + } + + void set(bvect& dst, unsigned v) const { + dst[0] = v; + for (unsigned i = 1; i < nw; ++i) + dst[i] = 0; + } + + void set(bvect& dst, bvect const& src) const { + for (unsigned i = 0; i < nw; ++i) + dst[i] = src[i]; + } + + unsigned to_nat(unsigned max_n); + + std::ostream& display(std::ostream& out) const { + out << m_bits; + out << " ev: " << eval; + if (!is_zero(fixed)) { + out << " fix:"; + out << fixed; + } + if (m_lo != m_hi) + out << " [" << m_lo << ", " << m_hi << "["; + return out; + } + + bool well_formed() const { + return !has_overflow(m_bits) && (!has_range() || in_range(m_bits)); + } + + }; + + inline std::ostream& operator<<(std::ostream& out, sls_valuation const& v) { return v.display(out); } + +} diff --git a/src/cmd_context/basic_cmds.cpp b/src/cmd_context/basic_cmds.cpp index a2757d955..c93e4432f 100644 --- a/src/cmd_context/basic_cmds.cpp +++ b/src/cmd_context/basic_cmds.cpp @@ -202,6 +202,7 @@ ATOMIC_CMD(get_proof_cmd, "get-proof", "retrieve proof", { cmd_is_declared isd(ctx); pp.set_is_declared(&isd); pp.set_logic(ctx.get_logic()); + pp.set_simplify_implies(params.simplify_implies()); pp.display_smt2(ctx.regular_stream(), pr); ctx.regular_stream() << std::endl; } diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 4dc701bc8..7b21a4f34 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -50,6 +50,7 @@ z3_add_component(sat_smt q_solver.cpp recfun_solver.cpp sat_th.cpp + sls_solver.cpp specrel_solver.cpp tseitin_theory_checker.cpp user_solver.cpp diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp new file mode 100644 index 000000000..8feb9f83e --- /dev/null +++ b/src/sat/smt/sls_solver.cpp @@ -0,0 +1,130 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_solver + +Abstract: + + Interface to Concurrent SLS solver + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-21 + +--*/ + +#include "sat/smt/sls_solver.h" +#include "sat/smt/euf_solver.h" + + + +namespace sls { + + solver::solver(euf::solver& ctx): + th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")) {} + + solver::~solver() { + if (m_bvsls) { + m_bvsls->cancel(); + m_thread.join(); + } + } + + void solver::push_core() { + if (s().scope_lvl() == s().search_lvl() + 1) + init_local_search(); + } + + void solver::pop_core(unsigned n) { + if (s().scope_lvl() - n <= s().search_lvl()) + sample_local_search(); + } + + void solver::simplify() { + + } + + + void solver::init_local_search() { + if (m_bvsls) { + m_bvsls->cancel(); + m_thread.join(); + if (m_result == l_true) { + verbose_stream() << "Found model using local search - INIT\n"; + exit(1); + } + } + // set up state for local search solver here + + m_m = alloc(ast_manager, m); + ast_translation tr(m, *m_m); + + m_completed = false; + m_result = l_undef; + m_bvsls = alloc(bv::sls, *m_m); + // walk clauses, add them + // walk trail stack until search level, add units + // encapsulate bvsls within the arguments of run-local-search. + // ensure bvsls does not touch ast-manager. + + unsigned trail_sz = s().trail_size(); + for (unsigned i = 0; i < trail_sz; ++i) { + auto lit = s().trail_literal(i); + if (s().lvl(lit) > s().search_lvl()) + break; + expr_ref fml = literal2expr(lit); + m_bvsls->assert_expr(tr(fml.get())); + } + unsigned num_vars = s().num_vars(); + for (unsigned i = 0; i < 2*num_vars; ++i) { + auto l1 = ~sat::to_literal(i); + auto const& wlist = s().get_wlist(l1); + for (sat::watched const& w : wlist) { + if (!w.is_binary_non_learned_clause()) + continue; + sat::literal l2 = w.get_literal(); + if (l1.index() > l2.index()) + continue; + expr_ref fml(m.mk_or(literal2expr(l1), literal2expr(l2)), m); + m_bvsls->assert_expr(tr(fml.get())); + } + } + for (auto clause : s().clauses()) { + expr_ref_vector cls(m); + for (auto lit : *clause) + cls.push_back(literal2expr(lit)); + expr_ref fml(m.mk_or(cls), m); + m_bvsls->assert_expr(tr(fml.get())); + } + + // use phase assignment from literals? + std::function eval = [&](expr* e, unsigned r) { + return false; + }; + + m_bvsls->init(); + m_bvsls->init_eval(eval); + m_bvsls->updt_params(s().params()); + + m_thread = std::thread([this]() { run_local_search(); }); + } + + void solver::sample_local_search() { + if (m_completed) { + m_thread.join(); + if (m_result == l_true) { + verbose_stream() << "Found model using local search\n"; + exit(1); + } + } + } + + void solver::run_local_search() { + lbool r = (*m_bvsls)(); + m_result = r; + m_completed = true; + } + +} diff --git a/src/sat/smt/sls_solver.h b/src/sat/smt/sls_solver.h new file mode 100644 index 000000000..c473264ac --- /dev/null +++ b/src/sat/smt/sls_solver.h @@ -0,0 +1,63 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_solver + +Abstract: + + Interface to Concurrent SLS solver + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-21 + +--*/ +#pragma once + +#include +#include "util/rlimit.h" +#include "ast/sls/bv_sls.h" +#include "sat/smt/sat_th.h" + + +namespace euf { + class solver; +} + +namespace sls { + + class solver : public euf::th_euf_solver { + std::atomic m_result; + std::atomic m_completed; + std::thread m_thread; + scoped_ptr m_m; + scoped_ptr m_bvsls; + + void run_local_search(); + void init_local_search(); + void sample_local_search(); + public: + solver(euf::solver& ctx); + ~solver(); + + void push_core() override; + void pop_core(unsigned n) override; + void simplify() override; + + sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; } + void internalize(expr* e) override { UNREACHABLE(); } + th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); } + + + bool unit_propagate() override { return false; } + void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); } + sat::check_result check() override { return sat::check_result::CR_DONE; } + std::ostream & display(std::ostream & out) const override { return out; } + std::ostream & display_justification(std::ostream & out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; } + std::ostream & display_constraint(std::ostream & out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; } + + }; + +} diff --git a/src/tactic/core/simplify_tactic.cpp b/src/tactic/core/simplify_tactic.cpp index 8d9ff759f..f05b4c4fc 100644 --- a/src/tactic/core/simplify_tactic.cpp +++ b/src/tactic/core/simplify_tactic.cpp @@ -42,6 +42,10 @@ struct simplify_tactic::imp { m_num_steps = 0; } + void collect_statistics(statistics& st) { + st.update("rewriter.steps", m_num_steps); + } + void operator()(goal & g) { tactic_report report("simplifier", g); m_num_steps = 0; @@ -108,6 +112,11 @@ void simplify_tactic::cleanup() { new (m_imp) imp(m, p); } +void simplify_tactic::collect_statistics(statistics& st) const { + if (m_imp) + m_imp->collect_statistics(st); +} + unsigned simplify_tactic::get_num_steps() const { return m_imp->get_num_steps(); } diff --git a/src/tactic/core/simplify_tactic.h b/src/tactic/core/simplify_tactic.h index 1594b3d37..7baabb8d6 100644 --- a/src/tactic/core/simplify_tactic.h +++ b/src/tactic/core/simplify_tactic.h @@ -81,6 +81,8 @@ public: static void get_param_descrs(param_descrs & r); void collect_param_descrs(param_descrs & r) override { get_param_descrs(r); } + + void collect_statistics(statistics& st) const override; void operator()(goal_ref const & in, goal_ref_buffer & result) override; diff --git a/src/tactic/dependent_expr_state_tactic.h b/src/tactic/dependent_expr_state_tactic.h index 4d695e300..79c1993b2 100644 --- a/src/tactic/dependent_expr_state_tactic.h +++ b/src/tactic/dependent_expr_state_tactic.h @@ -62,7 +62,7 @@ public: if (m_simp) pop(1); } - + /** * size(), [](), update() and inconsistent() implement the abstract interface of dependent_expr_state */ @@ -140,6 +140,12 @@ public: cleanup(); } + void collect_statistics(statistics& st) const override { + if (m_simp) + m_simp->collect_statistics(st); + st.copy(m_st); + } + void cleanup() override { if (m_simp) { m_simp->collect_statistics(m_st); @@ -151,13 +157,6 @@ public: m_dep = dependent_expr(m, m.mk_true(), nullptr, nullptr); } - void collect_statistics(statistics& st) const override { - if (m_simp) - m_simp->collect_statistics(st); - else - st.copy(m_st); - } - void reset_statistics() override { if (m_simp) m_simp->reset_statistics(); diff --git a/src/tactic/goal.cpp b/src/tactic/goal.cpp index 23e3ff969..43cecf92d 100644 --- a/src/tactic/goal.cpp +++ b/src/tactic/goal.cpp @@ -696,7 +696,7 @@ bool goal::is_cnf() const { if (!is_literal(lit)) return false; } - if (!is_literal(f)) + else if (!is_literal(f)) return false; } return true; diff --git a/src/tactic/sls/sls_tactic.cpp b/src/tactic/sls/sls_tactic.cpp index 6daadc83b..198204d90 100644 --- a/src/tactic/sls/sls_tactic.cpp +++ b/src/tactic/sls/sls_tactic.cpp @@ -29,6 +29,7 @@ Notes: #include "tactic/sls/sls_tactic.h" #include "params/sls_params.hpp" #include "ast/sls/sls_engine.h" +#include "ast/sls/bv_sls.h" class sls_tactic : public tactic { ast_manager & m; @@ -123,11 +124,115 @@ public: }; +class bv_sls_tactic : public tactic { + ast_manager& m; + params_ref m_params; + bv::sls* m_sls; + statistics m_st; + +public: + bv_sls_tactic(ast_manager& _m, params_ref const& p) : + m(_m), + m_params(p) { + m_sls = alloc(bv::sls, m); + } + + tactic* translate(ast_manager& m) override { + return alloc(bv_sls_tactic, m, m_params); + } + + ~bv_sls_tactic() override { + dealloc(m_sls); + } + + char const* name() const override { return "bv-sls"; } + + void updt_params(params_ref const& p) override { + m_params.append(p); + m_sls->updt_params(m_params); + } + + void collect_param_descrs(param_descrs& r) override { + sls_params::collect_param_descrs(r); + } + + void run(goal_ref const& g, model_converter_ref& mc) { + if (g->inconsistent()) { + mc = nullptr; + return; + } + + for (unsigned i = 0; i < g->size(); i++) + m_sls->assert_expr(g->form(i)); + + m_sls->init(); + std::function false_eval = [&](expr* e, unsigned idx) { + return false; + }; + m_sls->init_eval(false_eval); + + lbool res = m_sls->operator()(); + auto const& stats = m_sls->get_stats(); + report_tactic_progress("Number of flips:", stats.m_moves); + IF_VERBOSE(20, verbose_stream() << res << "\n"); + IF_VERBOSE(20, m_sls->display(verbose_stream())); + m_st.reset(); + m_sls->collect_statistics(m_st); + if (res == l_true) { + if (g->models_enabled()) { + model_ref mdl = m_sls->get_model(); + mc = model2model_converter(mdl.get()); + TRACE("sls_model", mc->display(tout);); + } + g->reset(); + } + else + mc = nullptr; + + } + + void operator()(goal_ref const& g, + goal_ref_buffer& result) override { + result.reset(); + + TRACE("sls", g->display(tout);); + tactic_report report("sls", *g); + + model_converter_ref mc; + run(g, mc); + g->add(mc.get()); + g->inc_depth(); + result.push_back(g.get()); + } + + void cleanup() override { + + auto* d = alloc(bv::sls, m); + std::swap(d, m_sls); + dealloc(d); + } + + void collect_statistics(statistics& st) const override { + st.copy(m_st); + } + + void reset_statistics() override { + m_sls->reset_statistics(); + m_st.reset(); + } + +}; + static tactic * mk_sls_tactic(ast_manager & m, params_ref const & p) { return and_then(fail_if_not(mk_is_qfbv_probe()), // Currently only QF_BV is supported. clean(alloc(sls_tactic, m, p))); } +tactic* mk_bv_sls_tactic(ast_manager& m, params_ref const& p) { + return and_then(fail_if_not(mk_is_qfbv_probe()), // Currently only QF_BV is supported. + clean(alloc(bv_sls_tactic, m, p))); +} + static tactic * mk_preamble(ast_manager & m, params_ref const & p) { params_ref main_p; @@ -171,3 +276,9 @@ tactic * mk_qfbv_sls_tactic(ast_manager & m, params_ref const & p) { t->updt_params(p); return t; } + +tactic* mk_qfbv_new_sls_tactic(ast_manager& m, params_ref const& p) { + tactic* t = and_then(mk_preamble(m, p), mk_bv_sls_tactic(m, p)); + t->updt_params(p); + return t; +} diff --git a/src/tactic/sls/sls_tactic.h b/src/tactic/sls/sls_tactic.h index 3c0612e6e..d58d310e3 100644 --- a/src/tactic/sls/sls_tactic.h +++ b/src/tactic/sls/sls_tactic.h @@ -24,7 +24,16 @@ class tactic; tactic * mk_qfbv_sls_tactic(ast_manager & m, params_ref const & p = params_ref()); +tactic* mk_qfbv_new_sls_tactic(ast_manager& m, params_ref const& p = params_ref()); + +tactic* mk_bv_sls_tactic(ast_manager& m, params_ref const& p = params_ref()); + /* ADD_TACTIC("qfbv-sls", "(try to) solve using stochastic local search for QF_BV.", "mk_qfbv_sls_tactic(m, p)") + + ADD_TACTIC("qfbv-new-sls", "(try to) solve using stochastic local search for QF_BV.", "mk_qfbv_new_sls_tactic(m, p)") + + ADD_TACTIC("qfbv-new-sls-core", "(try to) solve using stochastic local search for QF_BV.", "mk_bv_sls_tactic(m, p)") + */ diff --git a/src/tactic/tactic.h b/src/tactic/tactic.h index ddd187337..652bf8130 100644 --- a/src/tactic/tactic.h +++ b/src/tactic/tactic.h @@ -62,7 +62,7 @@ public: */ virtual void operator()(goal_ref const & in, goal_ref_buffer& result) = 0; - virtual void collect_statistics(statistics & st) const { } + virtual void collect_statistics(statistics& st) const { } virtual void reset_statistics() {} virtual void cleanup() = 0; virtual void reset() { cleanup(); } @@ -130,6 +130,7 @@ public: void cleanup() override {} tactic * translate(ast_manager & m) override { return this; } char const* name() const override { return "skip"; } + void collect_statistics(statistics& st) const override {} }; tactic * mk_skip_tactic(); diff --git a/src/tactic/tactical.cpp b/src/tactic/tactical.cpp index 5b1ea9587..0b8189e8d 100644 --- a/src/tactic/tactical.cpp +++ b/src/tactic/tactical.cpp @@ -1190,6 +1190,9 @@ public: tactic * translate(ast_manager & m) override { return this; } + + void collect_statistics(statistics& st) const override { + } }; tactic * fail_if(probe * p) { @@ -1216,6 +1219,7 @@ public: } tactic * translate(ast_manager & m) override { return translate_core(m); } + }; class if_no_unsat_cores_tactical : public unary_tactical { diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 14b51f822..4ddc1b8cb 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -108,6 +108,7 @@ add_executable(test-z3 simple_parser.cpp simplex.cpp simplifier.cpp + sls_test.cpp small_object_allocator.cpp smt2print_parse.cpp smt_context.cpp diff --git a/src/test/hwf.cpp b/src/test/hwf.cpp index 8a019ec02..b81a9cef3 100644 --- a/src/test/hwf.cpp +++ b/src/test/hwf.cpp @@ -103,7 +103,9 @@ static void bug_to_rational() { static void bug_is_int() { unsigned raw_val[2] = { 2147483648u, 1077720461u }; - double val = *(double*)(raw_val); + double val; + static_assert(sizeof(raw_val) == sizeof(val)); + memcpy(&val, raw_val, sizeof(val)); std::cout << val << "\n"; hwf_manager m; hwf a; diff --git a/src/test/main.cpp b/src/test/main.cpp index 3f073abf2..0c3d0e01a 100644 --- a/src/test/main.cpp +++ b/src/test/main.cpp @@ -267,4 +267,5 @@ int main(int argc, char ** argv) { TST(distribution); TST(euf_bv_plugin); TST(euf_arith_plugin); + TST(sls_test); } diff --git a/src/test/sls_test.cpp b/src/test/sls_test.cpp new file mode 100644 index 000000000..d99035398 --- /dev/null +++ b/src/test/sls_test.cpp @@ -0,0 +1,243 @@ + +#include "ast/sls/bv_sls_eval.h" +#include "ast/rewriter/th_rewriter.h" +#include "ast/reg_decl_plugins.h" +#include "ast/ast_pp.h" + +namespace bv { + class sls_test { + ast_manager& m; + bv_util bv; + + public: + sls_test(ast_manager& m): + m(m), + bv(m) + {} + + void check_eval(expr* a, expr* b, unsigned j) { + auto es = create_exprs(a, b, j); + for (expr* e : es) + check_eval(e); + } + + void check_eval(expr* e) { + std::function value = [](expr*, unsigned) { + return false; + }; + expr_ref_vector es(m); + bv_util bv(m); + es.push_back(e); + sls_eval ev(m); + ev.init_eval(es, value); + ev.tighten_range(es); + th_rewriter rw(m); + expr_ref r(e, m); + rw(r); + + if (bv.is_bv(e)) { + auto const& val = ev.wval(e); + rational n1, n2; + + n1 = val.get_value(); + + VERIFY(bv.is_numeral(r, n2)); + if (n1 != n2) { + verbose_stream() << mk_pp(e, m) << " computed value " << val << "\n"; + verbose_stream() << "should be " << n2 << "\n"; + } + SASSERT(n1 == n2); + VERIFY(n1 == n2); + } + else if (m.is_bool(e)) { + auto val1 = ev.bval0(e); + auto val2 = m.is_true(r); + if (val1 != val2) { + verbose_stream() << mk_pp(e, m) << " computed value " << val1 << " at odds with definition\n"; + } + SASSERT(val1 == val2); + VERIFY(val1 == val2); + } + } + + expr_ref_vector create_exprs(expr* a, expr* b, unsigned j) { + expr_ref_vector result(m); + result.push_back(bv.mk_bv_add(a, b)) + .push_back(bv.mk_bv_mul(a, b)) + .push_back(bv.mk_bv_sub(a, b)) + .push_back(bv.mk_bv_udiv(a, b)) + .push_back(bv.mk_bv_sdiv(a, b)) + .push_back(bv.mk_bv_srem(a, b)) + .push_back(bv.mk_bv_urem(a, b)) + .push_back(bv.mk_bv_smod(a, b)) + .push_back(bv.mk_bv_shl(a, b)) + .push_back(bv.mk_bv_ashr(a, b)) + .push_back(bv.mk_bv_lshr(a, b)) + .push_back(bv.mk_bv_and(a, b)) + .push_back(bv.mk_bv_or(a, b)) + .push_back(bv.mk_bv_xor(a, b)) + .push_back(bv.mk_bv_neg(a)) + .push_back(bv.mk_bv_not(a)) + .push_back(bv.mk_bvumul_ovfl(a, b)) + .push_back(bv.mk_bvumul_no_ovfl(a, b)) + .push_back(bv.mk_zero_extend(3, a)) + .push_back(bv.mk_sign_extend(3, a)) + .push_back(bv.mk_ule(a, b)) + .push_back(bv.mk_sle(a, b)) + .push_back(bv.mk_concat(a, b)) + .push_back(bv.mk_extract(4, 2, a)) + .push_back(bv.mk_bvuadd_ovfl(a, b)) + .push_back(bv.mk_bv_rotate_left(a, j)) + .push_back(bv.mk_bv_rotate_right(a, j)) + .push_back(bv.mk_bv_rotate_left(a, b)) + .push_back(bv.mk_bv_rotate_right(a, b)) + // .push_back(bv.mk_bvsadd_ovfl(a, b)) + // .push_back(bv.mk_bvneg_ovfl(a)) + // .push_back(bv.mk_bvsmul_no_ovfl(a, b)) + // .push_back(bv.mk_bvsmul_no_udfl(a, b)) + // .push_back(bv.mk_bvsmul_ovfl(a, b)) + // .push_back(bv.mk_bvsdiv_ovfl(a, b)) + ; + return result; + } + + + // e = op(a, b), + // update value of a to "random" + // repair a based on computed values. + void check_repair(expr* a, expr* b, unsigned j) { + expr_ref x(m.mk_const("x", bv.mk_sort(bv.get_bv_size(a))), m); + expr_ref y(m.mk_const("y", bv.mk_sort(bv.get_bv_size(b))), m); + auto es1 = create_exprs(a, b, j); + auto es2 = create_exprs(x, b, j); + auto es3 = create_exprs(a, y, j); + for (unsigned i = 0; i < es1.size(); ++i) { + auto e1 = es1.get(i); + auto e2 = es2.get(i); + auto e3 = es3.get(i); + if (bv.is_bv_sdiv(e1)) + continue; + if (bv.is_bv_srem(e1)) + continue; + if (bv.is_bv_smod(e1)) + continue; + if (is_app_of(e1, bv.get_fid(), OP_BUADD_OVFL)) + continue; + check_repair_idx(e1, e2, 0, x); + if (is_app(e1) && to_app(e1)->get_num_args() == 2) + check_repair_idx(e1, e3, 1, y); + } + } + + random_gen rand; + + void check_repair_idx(expr* e1, expr* e2, unsigned idx, expr* x) { + std::function value = [&](expr*, unsigned) { + return rand() % 2 == 0; + }; + expr_ref_vector es(m); + bv_util bv(m); + th_rewriter rw(m); + expr_ref r(e1, m); + rw(r); + es.push_back(m.is_false(r) ? m.mk_not(e1) : e1); + es.push_back(m.is_false(r) ? m.mk_not(e2) : e2); + sls_eval ev(m); + ev.init_eval(es, value); + ev.tighten_range(es); + + if (m.is_bool(e1)) { + SASSERT(m.is_true(r) || m.is_false(r)); + auto val = m.is_true(r); + auto val2 = ev.bval0(e2); + if (val != val2) { + ev.set(e2, val); + auto rep1 = ev.try_repair(to_app(e2), idx); + if (!rep1) { + verbose_stream() << "Not repaired " << mk_pp(e1, m) << " " << mk_pp(e2, m) << " r: " << r << "\n"; + } + auto val3 = ev.bval0(e2); + if (val3 != val) { + verbose_stream() << "Repaired but not corrected " << mk_pp(e2, m) << "\n"; + ev.display(std::cout, es); + exit(0); + } + //SASSERT(rep1); + } + } + if (bv.is_bv(e1)) { + auto& val1 = ev.wval(e1); + auto& val2 = ev.wval(e2); + if (!val1.eq(val2)) { + val2.set(val1.bits()); + auto rep2 = ev.try_repair(to_app(e2), idx); + if (!rep2) { + verbose_stream() << "Not repaired " << mk_pp(e2, m) << "\n"; + } + auto val3 = ev.wval(e2); + val3.commit_eval(); + if (!val3.eq(val1)) { + verbose_stream() << "Repaired but not corrected " << mk_pp(e2, m) << "\n"; + } + //SASSERT(rep2); + } + } + } + + // todo: + void test_fixed() { + + } + }; +} + + +static void test_eval1() { + ast_manager m; + reg_decl_plugins(m); + bv_util bv(m); + + expr_ref e(m); + + bv::sls_test validator(m); + + unsigned k = 0; + unsigned bw = 6; + for (unsigned i = 0; i < 1ul << bw; ++i) { + expr_ref a(bv.mk_numeral(rational(i), bw), m); + for (unsigned j = 0; j < 1ul << bw; ++j) { + expr_ref b(bv.mk_numeral(rational(j), bw), m); + ++k; + if (k % 1000 == 0) + verbose_stream() << "tests " << k << "\n"; + validator.check_eval(a, b, j); + } + } +} + +static void test_repair1() { + ast_manager m; + reg_decl_plugins(m); + bv_util bv(m); + expr_ref e(m); + bv::sls_test validator(m); + + unsigned k = 0; + unsigned bw = 6; + for (unsigned i = 0; i < 1ul << bw; ++i) { + expr_ref a(bv.mk_numeral(rational(i), bw), m); + for (unsigned j = 0; j < 1ul << bw; ++j) { + expr_ref b(bv.mk_numeral(rational(j), bw), m); + ++k; + if (k % 1000 == 0) + verbose_stream() << "tests " << k << "\n"; + validator.check_repair(a, b, j); + } + } +} + +void tst_sls_test() { + test_eval1(); + test_repair1(); + +} diff --git a/src/util/util.h b/src/util/util.h index a4bf78073..f05e4f9f4 100644 --- a/src/util/util.h +++ b/src/util/util.h @@ -142,6 +142,7 @@ static inline unsigned get_num_1bits(uint64_t v) { v = (v + (v >> 4)) & 0x0F0F0F0F0F0F0F0F; uint64_t r = (v * 0x0101010101010101) >> 56; SASSERT(c == r); + return r; #endif }