diff options
Diffstat (limited to 'backends/smt2/smtbmc.py')
-rw-r--r-- | backends/smt2/smtbmc.py | 476 |
1 files changed, 363 insertions, 113 deletions
diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index 137182f33..cb21eb3aa 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -17,9 +17,10 @@ # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # -import os, sys, getopt, re +import os, sys, getopt, re, bisect ##yosys-sys-path## from smtio import SmtIo, SmtOpts, MkVcd +from ywio import ReadWitness, WriteWitness, WitnessValues from collections import defaultdict got_topt = False @@ -28,6 +29,8 @@ step_size = 1 num_steps = 20 append_steps = 0 vcdfile = None +inywfile = None +outywfile = None cexfile = None aimfile = None aiwfile = None @@ -51,6 +54,8 @@ smtctop = None noinit = False binarymode = False keep_going = False +check_witness = False +detect_loops = False so = SmtOpts() @@ -94,6 +99,9 @@ def usage(): the AIGER witness file does not include the status and properties lines. + --yw <yosys_witness_filename> + read a Yosys witness. + --btorwit <btor_witness_filename> read a BTOR witness. @@ -121,6 +129,9 @@ def usage(): (hint: use 'write_smt2 -wires' for maximum coverage of signals in generated VCD file) + --dump-yw <yw_filename> + write trace as a Yosys witness trace + --dump-vlogtb <verilog_filename> write trace as Verilog test bench @@ -161,15 +172,23 @@ def usage(): covering all found failed assertions, the character '%' is replaced in all dump filenames with an increasing number. + --check-witness + check that the used witness file contains sufficient + constraints to force an assertion failure. + + --detect-loops + check if states are unique in temporal induction counter examples + (this feature is experimental and incomplete) + """ + so.helpmsg()) sys.exit(1) try: opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:igcm:", so.longopts + - ["final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "btorwit=", "presat", - "dump-vcd=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=", - "smtc-init", "smtc-top=", "noinit", "binary", "keep-going"]) + ["final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat", + "dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=", + "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops"]) except: usage() @@ -204,10 +223,14 @@ for o, a in opts: aiwfile = a + ".aiw" elif o == "--aig-noheader": aigheader = False + elif o == "--yw": + inywfile = a elif o == "--btorwit": btorwitfile = a elif o == "--dump-vcd": vcdfile = a + elif o == "--dump-yw": + outywfile = a elif o == "--dump-vlogtb": vlogtbfile = a elif o == "--vlogtb-top": @@ -244,6 +267,10 @@ for o, a in opts: binarymode = True elif o == "--keep-going": keep_going = True + elif o == "--check-witness": + check_witness = True + elif o == "--detect-loops": + detect_loops = True elif so.handle(o, a): pass else: @@ -426,7 +453,6 @@ assert topmod in smt.modinfo if cexfile is not None: if not got_topt: - assume_skipped = 0 skip_steps = 0 num_steps = 0 @@ -462,7 +488,8 @@ if cexfile is not None: constr_assumes[step].append((cexfile, smtexpr)) if not got_topt: - skip_steps = max(skip_steps, step) + if not check_witness: + skip_steps = max(skip_steps, step) num_steps = max(num_steps, step+1) if aimfile is not None: @@ -471,7 +498,6 @@ if aimfile is not None: latch_map = dict() if not got_topt: - assume_skipped = 0 skip_steps = 0 num_steps = 0 @@ -595,13 +621,118 @@ if aimfile is not None: constr_assumes[step].append((cexfile, smtexpr)) if not got_topt: - skip_steps = max(skip_steps, step) + if not check_witness: + skip_steps = max(skip_steps, step) # some solvers optimize the properties so that they fail one cycle early, # thus we check the properties in the cycle the aiger witness ends, and # if that doesn't work, we check the cycle after that as well. num_steps = max(num_steps, step+2) step += 1 +if inywfile is not None: + if not got_topt: + skip_steps = 0 + num_steps = 0 + + with open(inywfile, "r") as f: + inyw = ReadWitness(f) + + inits, seqs, clocks, mems = smt.hierwitness(topmod, allregs=True, blackbox=True) + + smt_wires = defaultdict(list) + smt_mems = defaultdict(list) + + for wire in inits + seqs: + smt_wires[wire["path"]].append(wire) + + for mem in mems: + smt_mems[mem["path"]].append(mem) + + addr_re = re.compile(r'\\\[[0-9]+\]$') + bits_re = re.compile(r'[01?]*$') + + for t, step in inyw.steps(): + present_signals, missing = step.present_signals(inyw.sigmap) + for sig in present_signals: + bits = step[sig] + if not bits_re.match(bits): + raise ValueError("unsupported bit value in Yosys witness file") + + sig_end = sig.offset + len(bits) + if sig.path in smt_wires: + for wire in smt_wires[sig.path]: + width, offset = wire["width"], wire["offset"] + + smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1 + + offset = max(offset, 0) + + end = width + offset + common_offset = max(sig.offset, offset) + common_end = min(sig_end, end) + if common_end <= common_offset: + continue + + smt_expr = smt.witness_net_expr(topmod, f"s{t}", wire) + + if not smt_bool: + slice_high = common_end - offset - 1 + slice_low = common_offset - offset + smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr) + + bit_slice = bits[len(bits) - (common_end - sig.offset):len(bits) - (common_offset - sig.offset)] + + if bit_slice.count("?") == len(bit_slice): + continue + + if smt_bool: + assert width == 1 + smt_constr = "(= %s %s)" % (smt_expr, "true" if bit_slice == "1" else "false") + else: + if "?" in bit_slice: + mask = bit_slice.replace("0", "1").replace("?", "0") + bit_slice = bit_slice.replace("?", "0") + smt_expr = "(bvand %s #b%s)" % (smt_expr, mask) + + smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice) + + constr_assumes[t].append((inywfile, smt_constr)) + + if sig.memory_path: + if sig.memory_path in smt_mems: + for mem in smt_mems[sig.memory_path]: + width, size, bv = mem["width"], mem["size"], mem["statebv"] + + smt_expr = smt.net_expr(topmod, f"s{t}", mem["smtpath"]) + + if bv: + word_low = sig.memory_addr * width + word_high = word_low + width - 1 + smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr) + else: + addr_width = (size - 1).bit_length() + addr_bits = f"{sig.memory_addr:0{addr_width}b}" + smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits) + + if len(bits) < width: + slice_high = sig.offset + len(bits) - 1 + smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr) + + bit_slice = bits + + if "?" in bit_slice: + mask = bit_slice.replace("0", "1").replace("?", "0") + bit_slice = bit_slice.replace("?", "0") + smt_expr = "(bvand %s #b%s)" % (smt_expr, mask) + + smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice) + constr_assumes[t].append((inywfile, smt_constr)) + + if not got_topt: + if not check_witness: + skip_steps = max(skip_steps, t) + num_steps = max(num_steps, t+1) + if btorwitfile is not None: with open(btorwitfile, "r") as f: step = None @@ -699,128 +830,137 @@ if btorwitfile is not None: skip_steps = step num_steps = step+1 -def write_vcd_trace(steps_start, steps_stop, index): - filename = vcdfile.replace("%", index) - print_msg("Writing trace to VCD file: %s" % (filename)) +def collect_mem_trace_data(steps_start, steps_stop, vcd=None): + mem_trace_data = dict() - with open(filename, "w") as vcd_file: - vcd = MkVcd(vcd_file) - path_list = list() + for mempath in sorted(smt.hiermems(topmod)): + abits, width, rports, wports, asyncwr = smt.mem_info(topmod, mempath) - for netpath in sorted(smt.hiernets(topmod)): - hidden_net = False - for n in netpath: - if n.startswith("$"): - hidden_net = True - if not hidden_net: - edge = smt.net_clock(topmod, netpath) - if edge is None: - vcd.add_net([topmod] + netpath, smt.net_width(topmod, netpath)) - else: - vcd.add_clock([topmod] + netpath, edge) - path_list.append(netpath) + expr_id = list() + expr_list = list() + for i in range(steps_start, steps_stop): + for j in range(rports): + expr_id.append(('R', i-steps_start, j, 'A')) + expr_id.append(('R', i-steps_start, j, 'D')) + expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dA" % j)) + expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dD" % j)) + for j in range(wports): + expr_id.append(('W', i-steps_start, j, 'A')) + expr_id.append(('W', i-steps_start, j, 'D')) + expr_id.append(('W', i-steps_start, j, 'M')) + expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dA" % j)) + expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dD" % j)) + expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dM" % j)) + + rdata = list() + wdata = list() + addrs = set() + + for eid, edat in zip(expr_id, smt.get_list(expr_list)): + t, i, j, f = eid + + if t == 'R': + c = rdata + elif t == 'W': + c = wdata + else: + assert False - mem_trace_data = dict() - for mempath in sorted(smt.hiermems(topmod)): - abits, width, rports, wports, asyncwr = smt.mem_info(topmod, mempath) + while len(c) <= i: + c.append(list()) + c = c[i] - expr_id = list() - expr_list = list() - for i in range(steps_start, steps_stop): - for j in range(rports): - expr_id.append(('R', i-steps_start, j, 'A')) - expr_id.append(('R', i-steps_start, j, 'D')) - expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dA" % j)) - expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dD" % j)) - for j in range(wports): - expr_id.append(('W', i-steps_start, j, 'A')) - expr_id.append(('W', i-steps_start, j, 'D')) - expr_id.append(('W', i-steps_start, j, 'M')) - expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dA" % j)) - expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dD" % j)) - expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dM" % j)) - - rdata = list() - wdata = list() - addrs = set() - - for eid, edat in zip(expr_id, smt.get_list(expr_list)): - t, i, j, f = eid - - if t == 'R': - c = rdata - elif t == 'W': - c = wdata - else: - assert False + while len(c) <= j: + c.append(dict()) + c = c[j] - while len(c) <= i: - c.append(list()) - c = c[i] + c[f] = smt.bv2bin(edat) - while len(c) <= j: - c.append(dict()) - c = c[j] + if f == 'A': + addrs.add(c[f]) - c[f] = smt.bv2bin(edat) + for addr in addrs: + tdata = list() + data = ["x"] * width + gotread = False - if f == 'A': - addrs.add(c[f]) + if len(wdata) == 0 and len(rdata) != 0: + wdata = [[]] * len(rdata) - for addr in addrs: - tdata = list() - data = ["x"] * width - gotread = False + assert len(rdata) == len(wdata) - if len(wdata) == 0 and len(rdata) != 0: - wdata = [[]] * len(rdata) + for i in range(len(wdata)): + if not gotread: + for j_data in rdata[i]: + if j_data["A"] == addr: + data = list(j_data["D"]) + gotread = True + break - assert len(rdata) == len(wdata) + if gotread: + buf = data[:] + for ii in reversed(range(len(tdata))): + for k in range(width): + if tdata[ii][k] == "x": + tdata[ii][k] = buf[k] + else: + buf[k] = tdata[ii][k] - for i in range(len(wdata)): - if not gotread: - for j_data in rdata[i]: - if j_data["A"] == addr: - data = list(j_data["D"]) - gotread = True - break - - if gotread: - buf = data[:] - for ii in reversed(range(len(tdata))): - for k in range(width): - if tdata[ii][k] == "x": - tdata[ii][k] = buf[k] - else: - buf[k] = tdata[ii][k] + if not asyncwr: + tdata.append(data[:]) - if not asyncwr: - tdata.append(data[:]) + for j_data in wdata[i]: + if j_data["A"] != addr: + continue - for j_data in wdata[i]: - if j_data["A"] != addr: - continue + D = j_data["D"] + M = j_data["M"] - D = j_data["D"] - M = j_data["M"] + for k in range(width): + if M[k] == "1": + data[k] = D[k] - for k in range(width): - if M[k] == "1": - data[k] = D[k] + if asyncwr: + tdata.append(data[:]) - if asyncwr: - tdata.append(data[:]) + assert len(tdata) == len(rdata) - assert len(tdata) == len(rdata) + int_addr = int(addr, 2) - netpath = mempath[:] - netpath[-1] += "<%0*x>" % ((len(addr)+3) // 4, int(addr, 2)) + netpath = mempath[:] + if vcd: + netpath[-1] += "<%0*x>" % ((len(addr)+3) // 4, int_addr) vcd.add_net([topmod] + netpath, width) - for i in range(steps_start, steps_stop): - if i not in mem_trace_data: - mem_trace_data[i] = list() - mem_trace_data[i].append((netpath, "".join(tdata[i-steps_start]))) + for i in range(steps_start, steps_stop): + if i not in mem_trace_data: + mem_trace_data[i] = list() + mem_trace_data[i].append((netpath, int_addr, "".join(tdata[i-steps_start]))) + + return mem_trace_data + +def write_vcd_trace(steps_start, steps_stop, index): + filename = vcdfile.replace("%", index) + print_msg("Writing trace to VCD file: %s" % (filename)) + + with open(filename, "w") as vcd_file: + vcd = MkVcd(vcd_file) + path_list = list() + + for netpath in sorted(smt.hiernets(topmod)): + hidden_net = False + for n in netpath: + if n.startswith("$"): + hidden_net = True + if not hidden_net: + edge = smt.net_clock(topmod, netpath) + if edge is None: + vcd.add_net([topmod] + netpath, smt.net_width(topmod, netpath)) + else: + vcd.add_clock([topmod] + netpath, edge) + path_list.append(netpath) + + mem_trace_data = collect_mem_trace_data(steps_start, steps_stop, vcd) for i in range(steps_start, steps_stop): vcd.set_time(i) @@ -828,11 +968,35 @@ def write_vcd_trace(steps_start, steps_stop, index): for path, value in zip(path_list, value_list): vcd.set_net([topmod] + path, value) if i in mem_trace_data: - for path, value in mem_trace_data[i]: + for path, addr, value in mem_trace_data[i]: vcd.set_net([topmod] + path, value) vcd.set_time(steps_stop) +def detect_state_loop(steps_start, steps_stop): + print_msg(f"Checking for loops in found induction counter example") + print_msg(f"This feature is experimental and incomplete") + + path_list = sorted(smt.hiernets(topmod, regs_only=True)) + + mem_trace_data = collect_mem_trace_data(steps_start, steps_stop) + + # Map state to index of step when it occurred + states = dict() + + for i in range(steps_start, steps_stop): + value_list = smt.get_net_bin_list(topmod, path_list, "s%d" % i) + mem_state = sorted( + [(tuple(path), addr, data) + for path, addr, data in mem_trace_data.get(i, [])]) + state = tuple(value_list), tuple(mem_state) + if state in states: + return (i, states[state]) + else: + states[state] = i + + return None + def char_ok_in_verilog(c,i): if ('A' <= c <= 'Z'): return True if ('a' <= c <= 'z'): return True @@ -1072,8 +1236,73 @@ def write_constr_trace(steps_start, steps_stop, index): for name, val in zip(pi_names, pi_values): print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f) +def write_yw_trace(steps_start, steps_stop, index, allregs=False): + filename = outywfile.replace("%", index) + print_msg("Writing trace to Yosys witness file: %s" % (filename)) + + mem_trace_data = collect_mem_trace_data(steps_start, steps_stop) + + with open(filename, "w") as f: + inits, seqs, clocks, mems = smt.hierwitness(topmod, allregs) + + yw = WriteWitness(f, "smtbmc") + + for clock in clocks: + yw.add_clock(clock["path"], clock["offset"], clock["type"]) + + for seq in seqs: + seq["sig"] = yw.add_sig(seq["path"], seq["offset"], seq["width"]) + + for init in inits: + init["sig"] = yw.add_sig(init["path"], init["offset"], init["width"], True) + + inits = seqs + inits + + mem_dict = {tuple(mem["smtpath"]): mem for mem in mems} + mem_init_values = [] + + for path, addr, value in mem_trace_data.get(0, ()): + json_mem = mem_dict.get(tuple(path)) + if not json_mem: + continue + + bit_addr = addr * json_mem["width"] + uninit_chunks = [(chunk["width"] + chunk["offset"], chunk["offset"]) for chunk in json_mem["uninitialized"]] + first_chunk_nr = bisect.bisect_left(uninit_chunks, (bit_addr + 1,)) -def write_trace(steps_start, steps_stop, index): + for uninit_end, uninit_offset in uninit_chunks[first_chunk_nr:]: + assert uninit_end > bit_addr + if uninit_offset > bit_addr + json_mem["width"]: + break + + word_path = (*json_mem["path"], f"\\[{addr}]") + + overlap_start = max(uninit_offset - bit_addr, 0) + overlap_end = min(uninit_end - bit_addr, json_mem["width"]) + overlap_bits = value[len(value)-overlap_end:len(value)-overlap_start] + + sig = yw.add_sig(word_path, overlap_start, overlap_end - overlap_start, True) + mem_init_values.append((sig, overlap_bits.replace("x", "?"))) + + for k in range(steps_start, steps_stop): + step_values = WitnessValues() + + if k == steps_start: + for sig, value in mem_init_values: + step_values[sig] = value + sigs = inits + seqs + else: + sigs = seqs + + for sig in sigs: + value = smt.bv2bin(smt.get(smt.witness_net_expr(topmod, f"s{k}", sig))) + step_values[sig["sig"]] = value + yw.step(step_values) + + yw.end_trace() + + +def write_trace(steps_start, steps_stop, index, allregs=False): if vcdfile is not None: write_vcd_trace(steps_start, steps_stop, index) @@ -1083,6 +1312,9 @@ def write_trace(steps_start, steps_stop, index): if outconstr is not None: write_constr_trace(steps_start, steps_stop, index) + if outywfile is not None: + write_yw_trace(steps_start, steps_stop, index, allregs) + def print_failed_asserts_worker(mod, state, path, extrainfo, infomap, infokey=()): assert mod in smt.modinfo @@ -1392,12 +1624,16 @@ if tempind: print_msg("Temporal induction failed!") print_anyconsts(num_steps) print_failed_asserts(num_steps) - write_trace(step, num_steps+1, '%') + write_trace(step, num_steps+1, '%', allregs=True) + if detect_loops: + loop = detect_state_loop(step, num_steps+1) + if loop: + print_msg(f"Loop detected, increasing induction depth will not help. Step {loop[0]} = step {loop[1]}") elif dumpall: print_anyconsts(num_steps) print_failed_asserts(num_steps) - write_trace(step, num_steps+1, "%d" % step) + write_trace(step, num_steps+1, "%d" % step, allregs=True) else: print_msg("Temporal induction successful.") @@ -1590,6 +1826,7 @@ else: # not tempind, covermode smt_assert("(not %s)" % active_assert_expr) else: + active_assert_expr = "true" smt_assert("false") @@ -1597,6 +1834,17 @@ else: # not tempind, covermode if retstatus != "FAILED": print("%s BMC failed!" % smt.timestamp()) + if check_witness: + print_msg("Checking witness constraints...") + smt_pop() + smt_push() + smt_assert(active_assert_expr) + if smt_check_sat() != "sat": + retstatus = "PASSED" + check_witness = False + num_steps = -1 + break + if append_steps > 0: for i in range(last_check_step+1, last_check_step+1+append_steps): print_msg("Appending additional step %d." % i) @@ -1689,6 +1937,8 @@ else: # not tempind, covermode print_anyconsts(0) write_trace(0, num_steps, '%') + if check_witness: + retstatus = "FAILED" smt.write("(exit)") smt.wait() |