aboutsummaryrefslogtreecommitdiffstats
path: root/backends/smt2/smtbmc.py
diff options
context:
space:
mode:
Diffstat (limited to 'backends/smt2/smtbmc.py')
-rw-r--r--backends/smt2/smtbmc.py476
1 files changed, 363 insertions, 113 deletions
diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py
index 137182f33..cb21eb3aa 100644
--- a/backends/smt2/smtbmc.py
+++ b/backends/smt2/smtbmc.py
@@ -17,9 +17,10 @@
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#
-import os, sys, getopt, re
+import os, sys, getopt, re, bisect
##yosys-sys-path##
from smtio import SmtIo, SmtOpts, MkVcd
+from ywio import ReadWitness, WriteWitness, WitnessValues
from collections import defaultdict
got_topt = False
@@ -28,6 +29,8 @@ step_size = 1
num_steps = 20
append_steps = 0
vcdfile = None
+inywfile = None
+outywfile = None
cexfile = None
aimfile = None
aiwfile = None
@@ -51,6 +54,8 @@ smtctop = None
noinit = False
binarymode = False
keep_going = False
+check_witness = False
+detect_loops = False
so = SmtOpts()
@@ -94,6 +99,9 @@ def usage():
the AIGER witness file does not include the status and
properties lines.
+ --yw <yosys_witness_filename>
+ read a Yosys witness.
+
--btorwit <btor_witness_filename>
read a BTOR witness.
@@ -121,6 +129,9 @@ def usage():
(hint: use 'write_smt2 -wires' for maximum
coverage of signals in generated VCD file)
+ --dump-yw <yw_filename>
+ write trace as a Yosys witness trace
+
--dump-vlogtb <verilog_filename>
write trace as Verilog test bench
@@ -161,15 +172,23 @@ def usage():
covering all found failed assertions, the character '%' is
replaced in all dump filenames with an increasing number.
+ --check-witness
+ check that the used witness file contains sufficient
+ constraints to force an assertion failure.
+
+ --detect-loops
+ check if states are unique in temporal induction counter examples
+ (this feature is experimental and incomplete)
+
""" + so.helpmsg())
sys.exit(1)
try:
opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:igcm:", so.longopts +
- ["final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "btorwit=", "presat",
- "dump-vcd=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
- "smtc-init", "smtc-top=", "noinit", "binary", "keep-going"])
+ ["final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat",
+ "dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
+ "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops"])
except:
usage()
@@ -204,10 +223,14 @@ for o, a in opts:
aiwfile = a + ".aiw"
elif o == "--aig-noheader":
aigheader = False
+ elif o == "--yw":
+ inywfile = a
elif o == "--btorwit":
btorwitfile = a
elif o == "--dump-vcd":
vcdfile = a
+ elif o == "--dump-yw":
+ outywfile = a
elif o == "--dump-vlogtb":
vlogtbfile = a
elif o == "--vlogtb-top":
@@ -244,6 +267,10 @@ for o, a in opts:
binarymode = True
elif o == "--keep-going":
keep_going = True
+ elif o == "--check-witness":
+ check_witness = True
+ elif o == "--detect-loops":
+ detect_loops = True
elif so.handle(o, a):
pass
else:
@@ -426,7 +453,6 @@ assert topmod in smt.modinfo
if cexfile is not None:
if not got_topt:
- assume_skipped = 0
skip_steps = 0
num_steps = 0
@@ -462,7 +488,8 @@ if cexfile is not None:
constr_assumes[step].append((cexfile, smtexpr))
if not got_topt:
- skip_steps = max(skip_steps, step)
+ if not check_witness:
+ skip_steps = max(skip_steps, step)
num_steps = max(num_steps, step+1)
if aimfile is not None:
@@ -471,7 +498,6 @@ if aimfile is not None:
latch_map = dict()
if not got_topt:
- assume_skipped = 0
skip_steps = 0
num_steps = 0
@@ -595,13 +621,118 @@ if aimfile is not None:
constr_assumes[step].append((cexfile, smtexpr))
if not got_topt:
- skip_steps = max(skip_steps, step)
+ if not check_witness:
+ skip_steps = max(skip_steps, step)
# some solvers optimize the properties so that they fail one cycle early,
# thus we check the properties in the cycle the aiger witness ends, and
# if that doesn't work, we check the cycle after that as well.
num_steps = max(num_steps, step+2)
step += 1
+if inywfile is not None:
+ if not got_topt:
+ skip_steps = 0
+ num_steps = 0
+
+ with open(inywfile, "r") as f:
+ inyw = ReadWitness(f)
+
+ inits, seqs, clocks, mems = smt.hierwitness(topmod, allregs=True, blackbox=True)
+
+ smt_wires = defaultdict(list)
+ smt_mems = defaultdict(list)
+
+ for wire in inits + seqs:
+ smt_wires[wire["path"]].append(wire)
+
+ for mem in mems:
+ smt_mems[mem["path"]].append(mem)
+
+ addr_re = re.compile(r'\\\[[0-9]+\]$')
+ bits_re = re.compile(r'[01?]*$')
+
+ for t, step in inyw.steps():
+ present_signals, missing = step.present_signals(inyw.sigmap)
+ for sig in present_signals:
+ bits = step[sig]
+ if not bits_re.match(bits):
+ raise ValueError("unsupported bit value in Yosys witness file")
+
+ sig_end = sig.offset + len(bits)
+ if sig.path in smt_wires:
+ for wire in smt_wires[sig.path]:
+ width, offset = wire["width"], wire["offset"]
+
+ smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1
+
+ offset = max(offset, 0)
+
+ end = width + offset
+ common_offset = max(sig.offset, offset)
+ common_end = min(sig_end, end)
+ if common_end <= common_offset:
+ continue
+
+ smt_expr = smt.witness_net_expr(topmod, f"s{t}", wire)
+
+ if not smt_bool:
+ slice_high = common_end - offset - 1
+ slice_low = common_offset - offset
+ smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
+
+ bit_slice = bits[len(bits) - (common_end - sig.offset):len(bits) - (common_offset - sig.offset)]
+
+ if bit_slice.count("?") == len(bit_slice):
+ continue
+
+ if smt_bool:
+ assert width == 1
+ smt_constr = "(= %s %s)" % (smt_expr, "true" if bit_slice == "1" else "false")
+ else:
+ if "?" in bit_slice:
+ mask = bit_slice.replace("0", "1").replace("?", "0")
+ bit_slice = bit_slice.replace("?", "0")
+ smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
+
+ smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
+
+ constr_assumes[t].append((inywfile, smt_constr))
+
+ if sig.memory_path:
+ if sig.memory_path in smt_mems:
+ for mem in smt_mems[sig.memory_path]:
+ width, size, bv = mem["width"], mem["size"], mem["statebv"]
+
+ smt_expr = smt.net_expr(topmod, f"s{t}", mem["smtpath"])
+
+ if bv:
+ word_low = sig.memory_addr * width
+ word_high = word_low + width - 1
+ smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
+ else:
+ addr_width = (size - 1).bit_length()
+ addr_bits = f"{sig.memory_addr:0{addr_width}b}"
+ smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)
+
+ if len(bits) < width:
+ slice_high = sig.offset + len(bits) - 1
+ smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)
+
+ bit_slice = bits
+
+ if "?" in bit_slice:
+ mask = bit_slice.replace("0", "1").replace("?", "0")
+ bit_slice = bit_slice.replace("?", "0")
+ smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
+
+ smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
+ constr_assumes[t].append((inywfile, smt_constr))
+
+ if not got_topt:
+ if not check_witness:
+ skip_steps = max(skip_steps, t)
+ num_steps = max(num_steps, t+1)
+
if btorwitfile is not None:
with open(btorwitfile, "r") as f:
step = None
@@ -699,128 +830,137 @@ if btorwitfile is not None:
skip_steps = step
num_steps = step+1
-def write_vcd_trace(steps_start, steps_stop, index):
- filename = vcdfile.replace("%", index)
- print_msg("Writing trace to VCD file: %s" % (filename))
+def collect_mem_trace_data(steps_start, steps_stop, vcd=None):
+ mem_trace_data = dict()
- with open(filename, "w") as vcd_file:
- vcd = MkVcd(vcd_file)
- path_list = list()
+ for mempath in sorted(smt.hiermems(topmod)):
+ abits, width, rports, wports, asyncwr = smt.mem_info(topmod, mempath)
- for netpath in sorted(smt.hiernets(topmod)):
- hidden_net = False
- for n in netpath:
- if n.startswith("$"):
- hidden_net = True
- if not hidden_net:
- edge = smt.net_clock(topmod, netpath)
- if edge is None:
- vcd.add_net([topmod] + netpath, smt.net_width(topmod, netpath))
- else:
- vcd.add_clock([topmod] + netpath, edge)
- path_list.append(netpath)
+ expr_id = list()
+ expr_list = list()
+ for i in range(steps_start, steps_stop):
+ for j in range(rports):
+ expr_id.append(('R', i-steps_start, j, 'A'))
+ expr_id.append(('R', i-steps_start, j, 'D'))
+ expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dA" % j))
+ expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dD" % j))
+ for j in range(wports):
+ expr_id.append(('W', i-steps_start, j, 'A'))
+ expr_id.append(('W', i-steps_start, j, 'D'))
+ expr_id.append(('W', i-steps_start, j, 'M'))
+ expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dA" % j))
+ expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dD" % j))
+ expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dM" % j))
+
+ rdata = list()
+ wdata = list()
+ addrs = set()
+
+ for eid, edat in zip(expr_id, smt.get_list(expr_list)):
+ t, i, j, f = eid
+
+ if t == 'R':
+ c = rdata
+ elif t == 'W':
+ c = wdata
+ else:
+ assert False
- mem_trace_data = dict()
- for mempath in sorted(smt.hiermems(topmod)):
- abits, width, rports, wports, asyncwr = smt.mem_info(topmod, mempath)
+ while len(c) <= i:
+ c.append(list())
+ c = c[i]
- expr_id = list()
- expr_list = list()
- for i in range(steps_start, steps_stop):
- for j in range(rports):
- expr_id.append(('R', i-steps_start, j, 'A'))
- expr_id.append(('R', i-steps_start, j, 'D'))
- expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dA" % j))
- expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dD" % j))
- for j in range(wports):
- expr_id.append(('W', i-steps_start, j, 'A'))
- expr_id.append(('W', i-steps_start, j, 'D'))
- expr_id.append(('W', i-steps_start, j, 'M'))
- expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dA" % j))
- expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dD" % j))
- expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dM" % j))
-
- rdata = list()
- wdata = list()
- addrs = set()
-
- for eid, edat in zip(expr_id, smt.get_list(expr_list)):
- t, i, j, f = eid
-
- if t == 'R':
- c = rdata
- elif t == 'W':
- c = wdata
- else:
- assert False
+ while len(c) <= j:
+ c.append(dict())
+ c = c[j]
- while len(c) <= i:
- c.append(list())
- c = c[i]
+ c[f] = smt.bv2bin(edat)
- while len(c) <= j:
- c.append(dict())
- c = c[j]
+ if f == 'A':
+ addrs.add(c[f])
- c[f] = smt.bv2bin(edat)
+ for addr in addrs:
+ tdata = list()
+ data = ["x"] * width
+ gotread = False
- if f == 'A':
- addrs.add(c[f])
+ if len(wdata) == 0 and len(rdata) != 0:
+ wdata = [[]] * len(rdata)
- for addr in addrs:
- tdata = list()
- data = ["x"] * width
- gotread = False
+ assert len(rdata) == len(wdata)
- if len(wdata) == 0 and len(rdata) != 0:
- wdata = [[]] * len(rdata)
+ for i in range(len(wdata)):
+ if not gotread:
+ for j_data in rdata[i]:
+ if j_data["A"] == addr:
+ data = list(j_data["D"])
+ gotread = True
+ break
- assert len(rdata) == len(wdata)
+ if gotread:
+ buf = data[:]
+ for ii in reversed(range(len(tdata))):
+ for k in range(width):
+ if tdata[ii][k] == "x":
+ tdata[ii][k] = buf[k]
+ else:
+ buf[k] = tdata[ii][k]
- for i in range(len(wdata)):
- if not gotread:
- for j_data in rdata[i]:
- if j_data["A"] == addr:
- data = list(j_data["D"])
- gotread = True
- break
-
- if gotread:
- buf = data[:]
- for ii in reversed(range(len(tdata))):
- for k in range(width):
- if tdata[ii][k] == "x":
- tdata[ii][k] = buf[k]
- else:
- buf[k] = tdata[ii][k]
+ if not asyncwr:
+ tdata.append(data[:])
- if not asyncwr:
- tdata.append(data[:])
+ for j_data in wdata[i]:
+ if j_data["A"] != addr:
+ continue
- for j_data in wdata[i]:
- if j_data["A"] != addr:
- continue
+ D = j_data["D"]
+ M = j_data["M"]
- D = j_data["D"]
- M = j_data["M"]
+ for k in range(width):
+ if M[k] == "1":
+ data[k] = D[k]
- for k in range(width):
- if M[k] == "1":
- data[k] = D[k]
+ if asyncwr:
+ tdata.append(data[:])
- if asyncwr:
- tdata.append(data[:])
+ assert len(tdata) == len(rdata)
- assert len(tdata) == len(rdata)
+ int_addr = int(addr, 2)
- netpath = mempath[:]
- netpath[-1] += "<%0*x>" % ((len(addr)+3) // 4, int(addr, 2))
+ netpath = mempath[:]
+ if vcd:
+ netpath[-1] += "<%0*x>" % ((len(addr)+3) // 4, int_addr)
vcd.add_net([topmod] + netpath, width)
- for i in range(steps_start, steps_stop):
- if i not in mem_trace_data:
- mem_trace_data[i] = list()
- mem_trace_data[i].append((netpath, "".join(tdata[i-steps_start])))
+ for i in range(steps_start, steps_stop):
+ if i not in mem_trace_data:
+ mem_trace_data[i] = list()
+ mem_trace_data[i].append((netpath, int_addr, "".join(tdata[i-steps_start])))
+
+ return mem_trace_data
+
+def write_vcd_trace(steps_start, steps_stop, index):
+ filename = vcdfile.replace("%", index)
+ print_msg("Writing trace to VCD file: %s" % (filename))
+
+ with open(filename, "w") as vcd_file:
+ vcd = MkVcd(vcd_file)
+ path_list = list()
+
+ for netpath in sorted(smt.hiernets(topmod)):
+ hidden_net = False
+ for n in netpath:
+ if n.startswith("$"):
+ hidden_net = True
+ if not hidden_net:
+ edge = smt.net_clock(topmod, netpath)
+ if edge is None:
+ vcd.add_net([topmod] + netpath, smt.net_width(topmod, netpath))
+ else:
+ vcd.add_clock([topmod] + netpath, edge)
+ path_list.append(netpath)
+
+ mem_trace_data = collect_mem_trace_data(steps_start, steps_stop, vcd)
for i in range(steps_start, steps_stop):
vcd.set_time(i)
@@ -828,11 +968,35 @@ def write_vcd_trace(steps_start, steps_stop, index):
for path, value in zip(path_list, value_list):
vcd.set_net([topmod] + path, value)
if i in mem_trace_data:
- for path, value in mem_trace_data[i]:
+ for path, addr, value in mem_trace_data[i]:
vcd.set_net([topmod] + path, value)
vcd.set_time(steps_stop)
+def detect_state_loop(steps_start, steps_stop):
+ print_msg(f"Checking for loops in found induction counter example")
+ print_msg(f"This feature is experimental and incomplete")
+
+ path_list = sorted(smt.hiernets(topmod, regs_only=True))
+
+ mem_trace_data = collect_mem_trace_data(steps_start, steps_stop)
+
+ # Map state to index of step when it occurred
+ states = dict()
+
+ for i in range(steps_start, steps_stop):
+ value_list = smt.get_net_bin_list(topmod, path_list, "s%d" % i)
+ mem_state = sorted(
+ [(tuple(path), addr, data)
+ for path, addr, data in mem_trace_data.get(i, [])])
+ state = tuple(value_list), tuple(mem_state)
+ if state in states:
+ return (i, states[state])
+ else:
+ states[state] = i
+
+ return None
+
def char_ok_in_verilog(c,i):
if ('A' <= c <= 'Z'): return True
if ('a' <= c <= 'z'): return True
@@ -1072,8 +1236,73 @@ def write_constr_trace(steps_start, steps_stop, index):
for name, val in zip(pi_names, pi_values):
print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f)
+def write_yw_trace(steps_start, steps_stop, index, allregs=False):
+ filename = outywfile.replace("%", index)
+ print_msg("Writing trace to Yosys witness file: %s" % (filename))
+
+ mem_trace_data = collect_mem_trace_data(steps_start, steps_stop)
+
+ with open(filename, "w") as f:
+ inits, seqs, clocks, mems = smt.hierwitness(topmod, allregs)
+
+ yw = WriteWitness(f, "smtbmc")
+
+ for clock in clocks:
+ yw.add_clock(clock["path"], clock["offset"], clock["type"])
+
+ for seq in seqs:
+ seq["sig"] = yw.add_sig(seq["path"], seq["offset"], seq["width"])
+
+ for init in inits:
+ init["sig"] = yw.add_sig(init["path"], init["offset"], init["width"], True)
+
+ inits = seqs + inits
+
+ mem_dict = {tuple(mem["smtpath"]): mem for mem in mems}
+ mem_init_values = []
+
+ for path, addr, value in mem_trace_data.get(0, ()):
+ json_mem = mem_dict.get(tuple(path))
+ if not json_mem:
+ continue
+
+ bit_addr = addr * json_mem["width"]
+ uninit_chunks = [(chunk["width"] + chunk["offset"], chunk["offset"]) for chunk in json_mem["uninitialized"]]
+ first_chunk_nr = bisect.bisect_left(uninit_chunks, (bit_addr + 1,))
-def write_trace(steps_start, steps_stop, index):
+ for uninit_end, uninit_offset in uninit_chunks[first_chunk_nr:]:
+ assert uninit_end > bit_addr
+ if uninit_offset > bit_addr + json_mem["width"]:
+ break
+
+ word_path = (*json_mem["path"], f"\\[{addr}]")
+
+ overlap_start = max(uninit_offset - bit_addr, 0)
+ overlap_end = min(uninit_end - bit_addr, json_mem["width"])
+ overlap_bits = value[len(value)-overlap_end:len(value)-overlap_start]
+
+ sig = yw.add_sig(word_path, overlap_start, overlap_end - overlap_start, True)
+ mem_init_values.append((sig, overlap_bits.replace("x", "?")))
+
+ for k in range(steps_start, steps_stop):
+ step_values = WitnessValues()
+
+ if k == steps_start:
+ for sig, value in mem_init_values:
+ step_values[sig] = value
+ sigs = inits + seqs
+ else:
+ sigs = seqs
+
+ for sig in sigs:
+ value = smt.bv2bin(smt.get(smt.witness_net_expr(topmod, f"s{k}", sig)))
+ step_values[sig["sig"]] = value
+ yw.step(step_values)
+
+ yw.end_trace()
+
+
+def write_trace(steps_start, steps_stop, index, allregs=False):
if vcdfile is not None:
write_vcd_trace(steps_start, steps_stop, index)
@@ -1083,6 +1312,9 @@ def write_trace(steps_start, steps_stop, index):
if outconstr is not None:
write_constr_trace(steps_start, steps_stop, index)
+ if outywfile is not None:
+ write_yw_trace(steps_start, steps_stop, index, allregs)
+
def print_failed_asserts_worker(mod, state, path, extrainfo, infomap, infokey=()):
assert mod in smt.modinfo
@@ -1392,12 +1624,16 @@ if tempind:
print_msg("Temporal induction failed!")
print_anyconsts(num_steps)
print_failed_asserts(num_steps)
- write_trace(step, num_steps+1, '%')
+ write_trace(step, num_steps+1, '%', allregs=True)
+ if detect_loops:
+ loop = detect_state_loop(step, num_steps+1)
+ if loop:
+ print_msg(f"Loop detected, increasing induction depth will not help. Step {loop[0]} = step {loop[1]}")
elif dumpall:
print_anyconsts(num_steps)
print_failed_asserts(num_steps)
- write_trace(step, num_steps+1, "%d" % step)
+ write_trace(step, num_steps+1, "%d" % step, allregs=True)
else:
print_msg("Temporal induction successful.")
@@ -1590,6 +1826,7 @@ else: # not tempind, covermode
smt_assert("(not %s)" % active_assert_expr)
else:
+ active_assert_expr = "true"
smt_assert("false")
@@ -1597,6 +1834,17 @@ else: # not tempind, covermode
if retstatus != "FAILED":
print("%s BMC failed!" % smt.timestamp())
+ if check_witness:
+ print_msg("Checking witness constraints...")
+ smt_pop()
+ smt_push()
+ smt_assert(active_assert_expr)
+ if smt_check_sat() != "sat":
+ retstatus = "PASSED"
+ check_witness = False
+ num_steps = -1
+ break
+
if append_steps > 0:
for i in range(last_check_step+1, last_check_step+1+append_steps):
print_msg("Appending additional step %d." % i)
@@ -1689,6 +1937,8 @@ else: # not tempind, covermode
print_anyconsts(0)
write_trace(0, num_steps, '%')
+ if check_witness:
+ retstatus = "FAILED"
smt.write("(exit)")
smt.wait()