diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index fa887dd15..1de0b2a30 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -19,7 +19,7 @@ import os, sys, getopt, re ##yosys-sys-path## -from smtio import smtio, smtopts, mkvcd +from smtio import SmtIo, SmtOpts, MkVcd from collections import defaultdict skip_steps = 0 @@ -35,7 +35,7 @@ dumpall = False assume_skipped = None final_only = False topmod = None -so = smtopts() +so = SmtOpts() def usage(): @@ -137,7 +137,7 @@ if len(args) != 1: if tempind and len(inconstr) != 0: - print("Error: options -i and --smtc are exclusive."); + print("Error: options -i and --smtc are exclusive.") sys.exit(1) @@ -179,7 +179,6 @@ for fn in inconstr: else: assert 0 continue - continue if tokens[0] == "state": current_states = set() @@ -275,7 +274,7 @@ def get_constr_expr(db, state, final=False, getvalues=False): return "(and %s)" % " ".join(expr_list) -smt = smtio(opts=so) +smt = SmtIo(opts=so) print("%s Solver: %s" % (smt.timestamp(), so.solver)) smt.setup("QF_AUFBV") @@ -297,7 +296,7 @@ def write_vcd_trace(steps_start, steps_stop, index): print("%s Writing trace to VCD file: %s" % (smt.timestamp(), filename)) with open(filename, "w") as vcd_file: - vcd = mkvcd(vcd_file) + vcd = MkVcd(vcd_file) path_list = list() for netpath in sorted(smt.hiernets(topmod)): @@ -343,10 +342,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index): print(" reg [%d:0] PI_%s;" % (width-1, name), file=f) print(" %s UUT (" % topmod, file=f) - for i in range(len(primary_inputs)): - name, width = primary_inputs[i] - last_pi = i+1 == len(primary_inputs) - print(" .%s(PI_%s)%s" % (name, name, "" if last_pi else ","), file=f) + print(",\n".join(" .{name}(PI_{name})".format(name=name) for name, _ in primary_inputs), file=f) print(" );", file=f) print(" initial begin", file=f) @@ -365,7 +361,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index): regs = sorted(smt.hiernets(topmod, regs_only=True)) regvals = smt.get_net_bin_list(topmod, regs, "s%d" % steps_start) - print(" #1;", file=f); + print(" #1;", file=f) for reg, val in zip(regs, regvals): hidden_net = False for n in reg: @@ -399,14 +395,14 @@ def write_vlogtb_trace(steps_start, steps_stop, index): pi_names = [[name] for name, _ in primary_inputs if name not in clock_inputs] pi_values = smt.get_net_bin_list(topmod, pi_names, "s%d" % i) - print(" #1;", file=f); - print(" // state %d" % i, file=f); + print(" #1;", file=f) + print(" // state %d" % i, file=f) if i > 0: - print(" @(posedge clock);", file=f); + print(" @(posedge clock);", file=f) for name, val in zip(pi_names, pi_values): print(" PI_%s <= %d'b%s;" % (".".join(name), len(val), val), file=f) - print(" genclock = 0;", file=f); + print(" genclock = 0;", file=f) print(" end", file=f) print("endmodule", file=f) @@ -423,7 +419,6 @@ def write_constr_trace(steps_start, steps_stop, index): width = smt.modinfo[topmod].wsize[name] primary_inputs.append((name, width)) - if steps_start == 0: print("initial", file=f) else: @@ -445,9 +440,7 @@ def write_constr_trace(steps_start, steps_stop, index): for j in range(ports): addr_expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, j)) - addr_list = set() - for val in smt.get_list(addr_expr_list): - addr_list.add(smt.bv2int(val)) + addr_list = set((smt.bv2int(val) for val in smt.get_list(addr_expr_list))) expr_list = list() for i in addr_list: @@ -456,7 +449,6 @@ def write_constr_trace(steps_start, steps_stop, index): for i, val in zip(addr_list, smt.get_list(expr_list)): print("assume (= (select [%s] #b%s) %s)" % (".".join(mempath), format(i, "0%db" % abits), val), file=f) - for k in range(steps_start, steps_stop): print("", file=f) print("state %d" % k, file=f) @@ -564,7 +556,7 @@ if tempind: break -else: # not tempind +else: # not tempind step = 0 retstatus = True while step < num_steps: @@ -650,7 +642,7 @@ else: # not tempind if not retstatus: break - else: # gentrace + else: # gentrace for i in range(step, last_check_step+1): smt.write("(assert (%s_a s%d))" % (topmod, i)) smt.write("(assert %s)" % get_constr_expr(constr_asserts, i)) @@ -677,4 +669,3 @@ smt.wait() print("%s Status: %s" % (smt.timestamp(), "PASSED" if retstatus else "FAILED (!)")) sys.exit(0 if retstatus else 1) - diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index fc7d1e13d..dad63e567 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -22,7 +22,18 @@ import subprocess from select import select from time import time -class smtmodinfo: + +hex_dict = { + "0": "0000", "1": "0001", "2": "0010", "3": "0011", + "4": "0100", "5": "0101", "6": "0110", "7": "0111", + "8": "1000", "9": "1001", "A": "1010", "B": "1011", + "C": "1100", "D": "1101", "E": "1110", "F": "1111", + "a": "1010", "b": "1011", "c": "1100", "d": "1101", + "e": "1110", "f": "1111" +} + + +class SmtModInfo: def __init__(self): self.inputs = set() self.outputs = set() @@ -34,7 +45,8 @@ class smtmodinfo: self.asserts = dict() self.anyconsts = dict() -class smtio: + +class SmtIo: def __init__(self, solver=None, debug_print=None, debug_file=None, timeinfo=None, opts=None): if opts is not None: self.solver = opts.solver @@ -108,7 +120,7 @@ class smtio: if fields[1] == "yosys-smt2-module": self.curmod = fields[2] - self.modinfo[self.curmod] = smtmodinfo() + self.modinfo[self.curmod] = SmtModInfo() if fields[1] == "yosys-smt2-cell": self.modinfo[self.curmod].cells[fields[3]] = fields[2] @@ -274,7 +286,7 @@ class smtio: def bv2hex(self, v): h = "" - v = bv2bin(v) + v = self.bv2bin(v) while len(v) > 0: d = 0 if len(v) > 0 and v[-1] == "1": d += 1 @@ -292,25 +304,7 @@ class smtio: if v.startswith("#b"): return v[2:] if v.startswith("#x"): - digits = [] - for d in v[2:]: - if d == "0": digits.append("0000") - if d == "1": digits.append("0001") - if d == "2": digits.append("0010") - if d == "3": digits.append("0011") - if d == "4": digits.append("0100") - if d == "5": digits.append("0101") - if d == "6": digits.append("0110") - if d == "7": digits.append("0111") - if d == "8": digits.append("1000") - if d == "9": digits.append("1001") - if d in ("a", "A"): digits.append("1010") - if d in ("b", "B"): digits.append("1011") - if d in ("c", "C"): digits.append("1100") - if d in ("d", "D"): digits.append("1101") - if d in ("e", "E"): digits.append("1110") - if d in ("f", "F"): digits.append("1111") - return "".join(digits) + return "".join(hex_dict.get(x) for x in v[2:]) assert False def bv2int(self, v): @@ -406,7 +400,7 @@ class smtio: self.p.wait() -class smtopts: +class SmtOpts: def __init__(self): self.shortopts = "s:v" self.longopts = ["no-progress", "dump-smt2="] @@ -445,7 +439,7 @@ class smtopts: """ -class mkvcd: +class MkVcd: def __init__(self, f): self.f = f self.t = -1 diff --git a/examples/smtbmc/demo1.v b/examples/smtbmc/demo1.v index d9be41513..567dde148 100644 --- a/examples/smtbmc/demo1.v +++ b/examples/smtbmc/demo1.v @@ -9,7 +9,7 @@ module demo1(input clk, input addtwo, output iseven); `ifdef FORMAL assert property (cnt != 15); - initial assume (!cnt[3] && !cnt[0]); + initial assume (!cnt[2]); `endif endmodule