3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2025-04-13 04:28:18 +00:00

smtbmc: Add --track-assumes and --minimize-assumes options

The --track-assumes option makes smtbmc keep track of which assumptions
were used by the solver when reaching an unsat case and to output that
set of assumptions. This is particularly useful to debug PREUNSAT
failures.

The --minimize-assumes option can be used in addition to --track-assumes
which will cause smtbmc to spend additional solving effort to produce a
minimal set of assumptions that are sufficient to cause the unsat
result.
This commit is contained in:
Jannis Harder 2024-03-07 13:27:03 +01:00
parent e4f11eb0a0
commit 42122e240e
3 changed files with 219 additions and 24 deletions

View file

@ -57,6 +57,8 @@ keep_going = False
check_witness = False check_witness = False
detect_loops = False detect_loops = False
incremental = None incremental = None
track_assumes = False
minimize_assumes = False
so = SmtOpts() so = SmtOpts()
@ -189,6 +191,15 @@ def help():
--incremental --incremental
run in incremental mode (experimental) run in incremental mode (experimental)
--track-assumes
track individual assumptions and report a subset of used
assumptions that are sufficient for the reported outcome. This
can be used to debug PREUNSAT failures as well as to find a
smaller set of sufficient assumptions.
--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.
""" + so.helpmsg()) """ + so.helpmsg())
def usage(): def usage():
@ -200,7 +211,8 @@ try:
opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts + opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts +
["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat", ["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat",
"dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=", "dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"]) "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental",
"track-assumes", "minimize-assumes"])
except: except:
usage() usage()
@ -289,6 +301,10 @@ for o, a in opts:
elif o == "--incremental": elif o == "--incremental":
from smtbmc_incremental import Incremental from smtbmc_incremental import Incremental
incremental = Incremental() incremental = Incremental()
elif o == "--track-assumes":
track_assumes = True
elif o == "--minimize-assumes":
minimize_assumes = True
elif so.handle(o, a): elif so.handle(o, a):
pass pass
else: else:
@ -447,6 +463,9 @@ def get_constr_expr(db, state, final=False, getvalues=False, individual=False):
smt = SmtIo(opts=so) smt = SmtIo(opts=so)
if track_assumes:
smt.smt2_options[':produce-unsat-assumptions'] = 'true'
if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None: if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None:
smt.produce_models = False smt.produce_models = False
@ -1497,6 +1516,44 @@ def get_active_assert_map(step, active):
return assert_map return assert_map
assume_enables = {}
def declare_assume_enables():
def recurse(mod, path, key_base=()):
for expr, desc in smt.modinfo[mod].assumes.items():
enable = f"|assume_enable {len(assume_enables)}|"
smt.smt2_assumptions[(expr, key_base)] = enable
smt.write(f"(declare-const {enable} Bool)")
assume_enables[(expr, key_base)] = (enable, path, desc)
for cell, submod in smt.modinfo[mod].cells.items():
recurse(submod, f"{path}.{cell}", (mod, cell, key_base))
recurse(topmod, topmod)
if track_assumes:
declare_assume_enables()
def smt_assert_design_assumes(step):
if not track_assumes:
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
return
if not assume_enables:
return
def expr_for_assume(assume_key, base=None):
expr, key_base = assume_key
expr_prefix = f"(|{expr}| "
expr_suffix = ")"
while key_base:
mod, cell, key_base = key_base
expr_prefix += f"(|{mod}_h {cell}| "
expr_suffix += ")"
return f"{expr_prefix} s{step}{expr_suffix}"
for assume_key, (enable, path, desc) in assume_enables.items():
smt_assert_consequent(f"(=> {enable} {expr_for_assume(assume_key)})")
states = list() states = list()
asserts_antecedent_cache = [list()] asserts_antecedent_cache = [list()]
@ -1651,6 +1708,13 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_forall_assert() smt_forall_assert()
return smt.check_sat(expected=expected) return smt.check_sat(expected=expected)
def report_tracked_assumptions(msg):
if track_assumes:
print_msg(msg)
for key in smt.get_unsat_assumptions(minimize=minimize_assumes):
enable, path, descr = assume_enables[key]
print_msg(f" In {path}: {descr}")
if incremental: if incremental:
incremental.mainloop() incremental.mainloop()
@ -1664,7 +1728,7 @@ elif tempind:
break break
smt_state(step) smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step)) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step))
@ -1707,6 +1771,7 @@ elif tempind:
else: else:
print_msg("Temporal induction successful.") print_msg("Temporal induction successful.")
report_tracked_assumptions("Used assumptions:")
retstatus = "PASSED" retstatus = "PASSED"
break break
@ -1732,7 +1797,7 @@ elif covermode:
while step < num_steps: while step < num_steps:
smt_state(step) smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step))
@ -1753,6 +1818,7 @@ elif covermode:
smt_assert("(distinct (covers_%d s%d) #b%s)" % (coveridx, step, "0" * len(cover_desc))) smt_assert("(distinct (covers_%d s%d) #b%s)" % (coveridx, step, "0" * len(cover_desc)))
if smt_check_sat() == "unsat": if smt_check_sat() == "unsat":
report_tracked_assumptions("Used assumptions:")
smt_pop() smt_pop()
break break
@ -1761,13 +1827,14 @@ elif covermode:
print_msg("Appending additional step %d." % i) print_msg("Appending additional step %d." % i)
smt_state(i) smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i)) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i)) smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i)) smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..") print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat": if smt_check_sat() == "unsat":
print("%s Cannot appended steps without violating assumptions!" % smt.timestamp()) print("%s Cannot appended steps without violating assumptions!" % smt.timestamp())
report_tracked_assumptions("Conflicting assumptions:")
found_failed_assert = True found_failed_assert = True
retstatus = "FAILED" retstatus = "FAILED"
break break
@ -1823,7 +1890,7 @@ else: # not tempind, covermode
retstatus = "PASSED" retstatus = "PASSED"
while step < num_steps: while step < num_steps:
smt_state(step) smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step))
@ -1853,7 +1920,7 @@ else: # not tempind, covermode
if step+i < num_steps: if step+i < num_steps:
smt_state(step+i) smt_state(step+i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step+i)) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step+i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step+i)) smt_assert_design_assumes(step + i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step+i)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step+i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, step+i-1, step+i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, step+i-1, step+i))
smt_assert_consequent(get_constr_expr(constr_assumes, step+i)) smt_assert_consequent(get_constr_expr(constr_assumes, step+i))
@ -1867,7 +1934,8 @@ else: # not tempind, covermode
print_msg("Checking assumptions in steps %d to %d.." % (step, last_check_step)) print_msg("Checking assumptions in steps %d to %d.." % (step, last_check_step))
if smt_check_sat() == "unsat": if smt_check_sat() == "unsat":
print("%s Assumptions are unsatisfiable!" % smt.timestamp()) print_msg("Assumptions are unsatisfiable!")
report_tracked_assumptions("Conficting assumptions:")
retstatus = "PREUNSAT" retstatus = "PREUNSAT"
break break
@ -1920,13 +1988,14 @@ else: # not tempind, covermode
print_msg("Appending additional step %d." % i) print_msg("Appending additional step %d." % i)
smt_state(i) smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i)) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i)) smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i)) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i)) smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..") print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat": if smt_check_sat() == "unsat":
print("%s Cannot append steps without violating assumptions!" % smt.timestamp()) print_msg("Cannot append steps without violating assumptions!")
report_tracked_assumptions("Conflicting assumptions:")
retstatus = "FAILED" retstatus = "FAILED"
break break
print_anyconsts(step) print_anyconsts(step)

View file

@ -15,6 +15,14 @@ class InteractiveError(Exception):
pass pass
def mkkey(data):
if isinstance(data, list):
return tuple(map(mkkey, data))
elif isinstance(data, dict):
raise InteractiveError(f"JSON objects found in assumption key: {data!r}")
return data
class Incremental: class Incremental:
def __init__(self): def __init__(self):
self.traceidx = 0 self.traceidx = 0
@ -73,17 +81,17 @@ class Incremental:
if min_len is not None and arg_len < min_len: if min_len is not None and arg_len < min_len:
if min_len == max_len: if min_len == max_len:
raise ( raise InteractiveError(
f"{json.dumps(expr[0])} expression must have " f"{json.dumps(expr[0])} expression must have "
f"{min_len} argument{'s' if min_len != 1 else ''}" f"{min_len} argument{'s' if min_len != 1 else ''}"
) )
else: else:
raise ( raise InteractiveError(
f"{json.dumps(expr[0])} expression must have at least " f"{json.dumps(expr[0])} expression must have at least "
f"{min_len} argument{'s' if min_len != 1 else ''}" f"{min_len} argument{'s' if min_len != 1 else ''}"
) )
if max_len is not None and arg_len > max_len: if max_len is not None and arg_len > max_len:
raise ( raise InteractiveError(
f"{json.dumps(expr[0])} expression can have at most " f"{json.dumps(expr[0])} expression can have at most "
f"{min_len} argument{'s' if max_len != 1 else ''}" f"{min_len} argument{'s' if max_len != 1 else ''}"
) )
@ -96,14 +104,31 @@ class Incremental:
smt_out.append(f"s{step}") smt_out.append(f"s{step}")
return "module", smtbmc.topmod return "module", smtbmc.topmod
def expr_mod_constraint(self, expr, smt_out): def expr_cell(self, expr, smt_out):
self.expr_arg_len(expr, 1) self.expr_arg_len(expr, 2)
position = len(smt_out) position = len(smt_out)
smt_out.append(None) smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None]) arg_sort = self.expr(expr[2], smt_out, required_sort=["module", None])
smt_out.append(")")
module = arg_sort[1] module = arg_sort[1]
cell = expr[1]
submod = smtbmc.smt.modinfo[module].cells.get(cell)
if submod is None:
raise InteractiveError(f"module {module!r} has no cell {cell!r}")
smt_out[position] = f"(|{module}_h {cell}| "
return ("module", submod)
def expr_mod_constraint(self, expr, smt_out):
suffix = expr[0][3:] suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| " self.expr_arg_len(expr, 1, 2 if suffix in ["_a", "_u", "_c"] else 1)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[-1], smt_out, required_sort=["module", None])
module = arg_sort[1]
if len(expr) == 3:
smt_out[position] = f"(|{module}{suffix} {expr[1]}| "
else:
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")") smt_out.append(")")
return "Bool" return "Bool"
@ -223,20 +248,19 @@ class Incremental:
subexpr = expr[2] subexpr = expr[2]
if not isinstance(label, str): if not isinstance(label, str):
raise InteractiveError(f"expression label has to be a string") raise InteractiveError("expression label has to be a string")
smt_out.append("(! ") smt_out.append("(! ")
smt_out.appedd(label)
smt_out.append(" ")
sort = self.expr(subexpr, smt_out) sort = self.expr(subexpr, smt_out)
smt_out.append(" :named ")
smt_out.append(label)
smt_out.append(")") smt_out.append(")")
return sort return sort
expr_handlers = { expr_handlers = {
"step": expr_step, "step": expr_step,
"cell": expr_cell,
"mod_h": expr_mod_constraint, "mod_h": expr_mod_constraint,
"mod_is": expr_mod_constraint, "mod_is": expr_mod_constraint,
"mod_i": expr_mod_constraint, "mod_i": expr_mod_constraint,
@ -302,6 +326,30 @@ class Incremental:
assert_fn(self.expr_smt(cmd.get("expr"), "Bool")) assert_fn(self.expr_smt(cmd.get("expr"), "Bool"))
def cmd_assert_design_assumes(self, cmd):
step = self.arg_step(cmd)
smtbmc.smt_assert_design_assumes(step)
def cmd_get_design_assume(self, cmd):
key = mkkey(cmd.get("key"))
return smtbmc.assume_enables.get(key)
def cmd_update_assumptions(self, cmd):
expr = cmd.get("expr")
key = cmd.get("key")
key = mkkey(key)
result = smtbmc.smt.smt2_assumptions.pop(key, None)
if expr is not None:
expr = self.expr_smt(expr, "Bool")
smtbmc.smt.smt2_assumptions[key] = expr
return result
def cmd_get_unsat_assumptions(self, cmd):
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize')))
def cmd_push(self, cmd): def cmd_push(self, cmd):
smtbmc.smt_push() smtbmc.smt_push()
@ -313,11 +361,14 @@ class Incremental:
def cmd_smtlib(self, cmd): def cmd_smtlib(self, cmd):
command = cmd.get("command") command = cmd.get("command")
response = cmd.get("response", False)
if not isinstance(command, str): if not isinstance(command, str):
raise InteractiveError( raise InteractiveError(
f"raw SMT-LIB command must be a string, found {json.dumps(command)}" f"raw SMT-LIB command must be a string, found {json.dumps(command)}"
) )
smtbmc.smt.write(command) smtbmc.smt.write(command)
if response:
return smtbmc.smt.read()
def cmd_design_hierwitness(self, cmd=None): def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False)) allregs = (cmd is None) or bool(cmd.get("allreges", False))
@ -369,6 +420,21 @@ class Incremental:
return dict(last_step=last_step) return dict(last_step=last_step)
def cmd_modinfo(self, cmd):
fields = cmd.get("fields", [])
mod = cmd.get("mod")
if mod is None:
mod = smtbmc.topmod
modinfo = smtbmc.smt.modinfo.get(mod)
if modinfo is None:
return None
result = dict(name=mod)
for field in fields:
result[field] = getattr(modinfo, field, None)
return result
def cmd_ping(self, cmd): def cmd_ping(self, cmd):
return cmd return cmd
@ -377,6 +443,10 @@ class Incremental:
"assert": cmd_assert, "assert": cmd_assert,
"assert_antecedent": cmd_assert, "assert_antecedent": cmd_assert,
"assert_consequent": cmd_assert, "assert_consequent": cmd_assert,
"assert_design_assumes": cmd_assert_design_assumes,
"get_design_assume": cmd_get_design_assume,
"update_assumptions": cmd_update_assumptions,
"get_unsat_assumptions": cmd_get_unsat_assumptions,
"push": cmd_push, "push": cmd_push,
"pop": cmd_pop, "pop": cmd_pop,
"check": cmd_check, "check": cmd_check,
@ -384,6 +454,7 @@ class Incremental:
"design_hierwitness": cmd_design_hierwitness, "design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace, "write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace, "read_yw_trace": cmd_read_yw_trace,
"modinfo": cmd_modinfo,
"ping": cmd_ping, "ping": cmd_ping,
} }

View file

@ -114,6 +114,7 @@ class SmtModInfo:
self.clocks = dict() self.clocks = dict()
self.cells = dict() self.cells = dict()
self.asserts = dict() self.asserts = dict()
self.assumes = dict()
self.covers = dict() self.covers = dict()
self.maximize = set() self.maximize = set()
self.minimize = set() self.minimize = set()
@ -141,6 +142,7 @@ class SmtIo:
self.recheck = False self.recheck = False
self.smt2cache = [list()] self.smt2cache = [list()]
self.smt2_options = dict() self.smt2_options = dict()
self.smt2_assumptions = dict()
self.p = None self.p = None
self.p_index = solvers_index self.p_index = solvers_index
solvers_index += 1 solvers_index += 1
@ -602,6 +604,12 @@ class SmtIo:
else: else:
self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3] self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3]
if fields[1] == "yosys-smt2-assume":
if len(fields) > 4:
self.modinfo[self.curmod].assumes["%s_u %s" % (self.curmod, fields[2])] = f'{fields[4]} ({fields[3]})'
else:
self.modinfo[self.curmod].assumes["%s_u %s" % (self.curmod, fields[2])] = fields[3]
if fields[1] == "yosys-smt2-maximize": if fields[1] == "yosys-smt2-maximize":
self.modinfo[self.curmod].maximize.add(fields[2]) self.modinfo[self.curmod].maximize.add(fields[2])
@ -785,8 +793,13 @@ class SmtIo:
return stmt return stmt
def check_sat(self, expected=["sat", "unsat", "unknown", "timeout", "interrupted"]): def check_sat(self, expected=["sat", "unsat", "unknown", "timeout", "interrupted"]):
if self.smt2_assumptions:
assume_exprs = " ".join(self.smt2_assumptions.values())
check_stmt = f"(check-sat-assuming ({assume_exprs}))"
else:
check_stmt = "(check-sat)"
if self.debug_print: if self.debug_print:
print("> (check-sat)") print(f"> {check_stmt}")
if self.debug_file and not self.nocomments: if self.debug_file and not self.nocomments:
print("; running check-sat..", file=self.debug_file) print("; running check-sat..", file=self.debug_file)
self.debug_file.flush() self.debug_file.flush()
@ -800,7 +813,7 @@ class SmtIo:
for cache_stmt in cache_ctx: for cache_stmt in cache_ctx:
self.p_write(cache_stmt + "\n", False) self.p_write(cache_stmt + "\n", False)
self.p_write("(check-sat)\n", True) self.p_write(f"{check_stmt}\n", True)
if self.timeinfo: if self.timeinfo:
i = 0 i = 0
@ -868,7 +881,7 @@ class SmtIo:
if self.debug_file: if self.debug_file:
print("(set-info :status %s)" % result, file=self.debug_file) print("(set-info :status %s)" % result, file=self.debug_file)
print("(check-sat)", file=self.debug_file) print(check_stmt, file=self.debug_file)
self.debug_file.flush() self.debug_file.flush()
if result not in expected: if result not in expected:
@ -945,6 +958,48 @@ class SmtIo:
def bv2int(self, v): def bv2int(self, v):
return int(self.bv2bin(v), 2) return int(self.bv2bin(v), 2)
def get_raw_unsat_assumptions(self):
self.write("(get-unsat-assumptions)")
exprs = set(self.unparse(part) for part in self.parse(self.read()))
unsat_assumptions = []
for key, value in self.smt2_assumptions.items():
# normalize expression
value = self.unparse(self.parse(value))
if value in exprs:
exprs.remove(value)
unsat_assumptions.append(key)
return unsat_assumptions
def get_unsat_assumptions(self, minimize=False):
if not minimize:
return self.get_raw_unsat_assumptions()
required_assumptions = {}
while True:
candidate_assumptions = {}
for key in self.get_raw_unsat_assumptions():
if key not in required_assumptions:
candidate_assumptions[key] = self.smt2_assumptions[key]
while candidate_assumptions:
candidate_key, candidate_assume = candidate_assumptions.popitem()
self.smt2_assumptions = {}
for key, assume in candidate_assumptions.items():
self.smt2_assumptions[key] = assume
for key, assume in required_assumptions.items():
self.smt2_assumptions[key] = assume
result = self.check_sat()
if result == 'unsat':
candidate_assumptions = None
else:
required_assumptions[candidate_key] = candidate_assume
if candidate_assumptions is not None:
return list(required_assumptions)
def get(self, expr): def get(self, expr):
self.write("(get-value (%s))" % (expr)) self.write("(get-value (%s))" % (expr))
return self.parse(self.read())[0][1] return self.parse(self.read())[0][1]