diff options
Diffstat (limited to 'backends/smt2/smtio.py')
-rw-r--r-- | backends/smt2/smtio.py | 365 |
1 files changed, 327 insertions, 38 deletions
diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index abf8e812d..34bf7ef38 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -16,10 +16,67 @@ # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # -import sys, subprocess, re, os +import sys, re, os, signal +import subprocess +if os.name == "posix": + import resource from copy import deepcopy from select import select from time import time +from queue import Queue, Empty +from threading import Thread + + +# This is needed so that the recursive SMT2 S-expression parser +# does not run out of stack frames when parsing large expressions +if os.name == "posix": + smtio_reclimit = 64 * 1024 + if sys.getrecursionlimit() < smtio_reclimit: + sys.setrecursionlimit(smtio_reclimit) + + current_rlimit_stack = resource.getrlimit(resource.RLIMIT_STACK) + if current_rlimit_stack[0] != resource.RLIM_INFINITY: + smtio_stacksize = 128 * 1024 * 1024 + if os.uname().sysname == "Darwin": + # MacOS has rather conservative stack limits + smtio_stacksize = 16 * 1024 * 1024 + if current_rlimit_stack[1] != resource.RLIM_INFINITY: + smtio_stacksize = min(smtio_stacksize, current_rlimit_stack[1]) + if current_rlimit_stack[0] < smtio_stacksize: + try: + resource.setrlimit(resource.RLIMIT_STACK, (smtio_stacksize, current_rlimit_stack[1])) + except ValueError: + # couldn't get more stack, just run with what we have + pass + + +# currently running solvers (so we can kill them) +running_solvers = dict() +forced_shutdown = False +solvers_index = 0 + +def force_shutdown(signum, frame): + global forced_shutdown + if not forced_shutdown: + forced_shutdown = True + if signum is not None: + print("<%s>" % signal.Signals(signum).name) + for p in running_solvers.values(): + # os.killpg(os.getpgid(p.pid), signal.SIGTERM) + os.kill(p.pid, signal.SIGTERM) + sys.exit(1) + +if os.name == "posix": + signal.signal(signal.SIGHUP, force_shutdown) +signal.signal(signal.SIGINT, force_shutdown) +signal.signal(signal.SIGTERM, force_shutdown) + +def except_hook(exctype, value, traceback): + if not forced_shutdown: + sys.__excepthook__(exctype, value, traceback) + force_shutdown(None, None) + +sys.excepthook = except_hook hex_dict = { @@ -40,24 +97,33 @@ class SmtModInfo: self.memories = dict() self.wires = set() self.wsize = dict() + self.clocks = dict() self.cells = dict() self.asserts = dict() self.covers = dict() self.anyconsts = dict() self.anyseqs = dict() + self.allconsts = dict() + self.allseqs = dict() + self.asize = dict() class SmtIo: def __init__(self, opts=None): + global solvers_index + self.logic = None self.logic_qf = True self.logic_ax = True self.logic_uf = True self.logic_bv = True self.logic_dt = False + self.forall = False self.produce_models = True self.smt2cache = [list()] self.p = None + self.p_index = solvers_index + solvers_index += 1 if opts is not None: self.logic = opts.logic @@ -91,23 +157,41 @@ class SmtIo: self.topmod = None self.setup_done = False + def __del__(self): + if self.p is not None and not forced_shutdown: + os.killpg(os.getpgid(self.p.pid), signal.SIGTERM) + if running_solvers is not None: + del running_solvers[self.p_index] + def setup(self): assert not self.setup_done + if self.forall: + self.unroll = False + if self.solver == "yices": - self.popen_vargs = ['yices-smt2', '--incremental'] + self.solver_opts + if self.noincr: + self.popen_vargs = ['yices-smt2'] + self.solver_opts + else: + self.popen_vargs = ['yices-smt2', '--incremental'] + self.solver_opts if self.solver == "z3": self.popen_vargs = ['z3', '-smt2', '-in'] + self.solver_opts if self.solver == "cvc4": - self.popen_vargs = ['cvc4', '--incremental', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts + if self.noincr: + self.popen_vargs = ['cvc4', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts + else: + self.popen_vargs = ['cvc4', '--incremental', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts if self.solver == "mathsat": self.popen_vargs = ['mathsat'] + self.solver_opts if self.solver == "boolector": - self.popen_vargs = ['boolector', '--smt2', '-i'] + self.solver_opts + if self.noincr: + self.popen_vargs = ['boolector', '--smt2'] + self.solver_opts + else: + self.popen_vargs = ['boolector', '--smt2', '-i'] + self.solver_opts self.unroll = True if self.solver == "abc": @@ -126,9 +210,10 @@ class SmtIo: if self.dummy_file is not None: self.dummy_fd = open(self.dummy_file, "w") if not self.noincr: - self.p = subprocess.Popen(self.popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + self.p_open() if self.unroll: + assert not self.forall self.logic_uf = False self.unroll_idcnt = 0 self.unroll_buffer = "" @@ -148,17 +233,17 @@ class SmtIo: self.setup_done = True + for stmt in self.info_stmts: + self.write(stmt) + if self.produce_models: self.write("(set-option :produce-models true)") self.write("(set-logic %s)" % self.logic) - for stmt in self.info_stmts: - self.write(stmt) - def timestamp(self): secs = int(time() - self.start_time) - return "## %6d %3d:%02d:%02d " % (secs, secs // (60*60), (secs // 60) % 60, secs % 60) + return "## %3d:%02d:%02d " % (secs // (60*60), (secs // 60) % 60, secs % 60) def replace_in_stmt(self, stmt, pat, repl): if stmt == pat: @@ -209,6 +294,65 @@ class SmtIo: return stmt + def p_thread_main(self): + while True: + data = self.p.stdout.readline().decode("ascii") + if data == "": break + self.p_queue.put(data) + self.p_queue.put("") + self.p_running = False + + def p_open(self): + assert self.p is None + try: + self.p = subprocess.Popen(self.popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + except FileNotFoundError: + print("%s SMT Solver '%s' not found in path." % (self.timestamp(), self.popen_vargs[0]), flush=True) + sys.exit(1) + running_solvers[self.p_index] = self.p + self.p_running = True + self.p_next = None + self.p_queue = Queue() + self.p_thread = Thread(target=self.p_thread_main) + self.p_thread.start() + + def p_write(self, data, flush): + assert self.p is not None + self.p.stdin.write(bytes(data, "ascii")) + if flush: self.p.stdin.flush() + + def p_read(self): + assert self.p is not None + if self.p_next is not None: + data = self.p_next + self.p_next = None + return data + if not self.p_running: + return "" + return self.p_queue.get() + + def p_poll(self, timeout=0.1): + assert self.p is not None + assert self.p_running + if self.p_next is not None: + return False + try: + self.p_next = self.p_queue.get(True, timeout) + return False + except Empty: + return True + + def p_close(self): + assert self.p is not None + self.p.stdin.close() + self.p_thread.join() + assert not self.p_running + del running_solvers[self.p_index] + self.p = None + self.p_next = None + self.p_queue = None + self.p_thread = None + def write(self, stmt, unroll=True): if stmt.startswith(";"): self.info(stmt) @@ -281,20 +425,17 @@ class SmtIo: if self.solver != "dummy": if self.noincr: if self.p is not None and not stmt.startswith("(get-"): - self.p.stdin.close() - self.p = None + self.p_close() if stmt == "(push 1)": self.smt2cache.append(list()) elif stmt == "(pop 1)": self.smt2cache.pop() else: if self.p is not None: - self.p.stdin.write(bytes(stmt + "\n", "ascii")) - self.p.stdin.flush() + self.p_write(stmt + "\n", True) self.smt2cache[-1].append(stmt) else: - self.p.stdin.write(bytes(stmt + "\n", "ascii")) - self.p.stdin.flush() + self.p_write(stmt + "\n", True) def info(self, stmt): if not stmt.startswith("; yosys-smt2-"): @@ -314,6 +455,11 @@ class SmtIo: if self.logic is None: self.logic_dt = True + if fields[1] == "yosys-smt2-forall": + if self.logic is None: + self.logic_qf = False + self.forall = True + if fields[1] == "yosys-smt2-module": self.curmod = fields[2] self.modinfo[self.curmod] = SmtModInfo() @@ -337,12 +483,19 @@ class SmtIo: self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3]) if fields[1] == "yosys-smt2-memory": - self.modinfo[self.curmod].memories[fields[2]] = (int(fields[3]), int(fields[4]), int(fields[5]), int(fields[6])) + self.modinfo[self.curmod].memories[fields[2]] = (int(fields[3]), int(fields[4]), int(fields[5]), int(fields[6]), fields[7] == "async") if fields[1] == "yosys-smt2-wire": self.modinfo[self.curmod].wires.add(fields[2]) self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3]) + if fields[1] == "yosys-smt2-clock": + for edge in fields[3:]: + if fields[2] not in self.modinfo[self.curmod].clocks: + self.modinfo[self.curmod].clocks[fields[2]] = edge + elif self.modinfo[self.curmod].clocks[fields[2]] != edge: + self.modinfo[self.curmod].clocks[fields[2]] = "event" + if fields[1] == "yosys-smt2-assert": self.modinfo[self.curmod].asserts["%s_a %s" % (self.curmod, fields[2])] = fields[3] @@ -350,10 +503,20 @@ class SmtIo: self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3] if fields[1] == "yosys-smt2-anyconst": - self.modinfo[self.curmod].anyconsts[fields[2]] = (fields[3], None if len(fields) <= 4 else fields[4]) + self.modinfo[self.curmod].anyconsts[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5]) + self.modinfo[self.curmod].asize[fields[2]] = int(fields[3]) if fields[1] == "yosys-smt2-anyseq": - self.modinfo[self.curmod].anyseqs[fields[2]] = (fields[3], None if len(fields) <= 4 else fields[4]) + self.modinfo[self.curmod].anyseqs[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5]) + self.modinfo[self.curmod].asize[fields[2]] = int(fields[3]) + + if fields[1] == "yosys-smt2-allconst": + self.modinfo[self.curmod].allconsts[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5]) + self.modinfo[self.curmod].asize[fields[2]] = int(fields[3]) + + if fields[1] == "yosys-smt2-allseq": + self.modinfo[self.curmod].allseqs[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5]) + self.modinfo[self.curmod].asize[fields[2]] = int(fields[3]) def hiernets(self, top, regs_only=False): def hiernets_worker(nets, mod, cursor): @@ -370,7 +533,8 @@ class SmtIo: def hieranyconsts(self, top): def worker(results, mod, cursor): for name, value in sorted(self.modinfo[mod].anyconsts.items()): - results.append((cursor, name, value[0], value[1])) + width = self.modinfo[mod].asize[name] + results.append((cursor, name, value[0], value[1], width)) for cellname, celltype in sorted(self.modinfo[mod].cells.items()): worker(results, celltype, cursor + [cellname]) @@ -381,7 +545,32 @@ class SmtIo: def hieranyseqs(self, top): def worker(results, mod, cursor): for name, value in sorted(self.modinfo[mod].anyseqs.items()): - results.append((cursor, name, value[0], value[1])) + width = self.modinfo[mod].asize[name] + results.append((cursor, name, value[0], value[1], width)) + for cellname, celltype in sorted(self.modinfo[mod].cells.items()): + worker(results, celltype, cursor + [cellname]) + + results = list() + worker(results, top, []) + return results + + def hierallconsts(self, top): + def worker(results, mod, cursor): + for name, value in sorted(self.modinfo[mod].allconsts.items()): + width = self.modinfo[mod].asize[name] + results.append((cursor, name, value[0], value[1], width)) + for cellname, celltype in sorted(self.modinfo[mod].cells.items()): + worker(results, celltype, cursor + [cellname]) + + results = list() + worker(results, top, []) + return results + + def hierallseqs(self, top): + def worker(results, mod, cursor): + for name, value in sorted(self.modinfo[mod].allseqs.items()): + width = self.modinfo[mod].asize[name] + results.append((cursor, name, value[0], value[1], width)) for cellname, celltype in sorted(self.modinfo[mod].cells.items()): worker(results, celltype, cursor + [cellname]) @@ -408,7 +597,7 @@ class SmtIo: if self.solver == "dummy": line = self.dummy_fd.readline().strip() else: - line = self.p.stdout.readline().decode("ascii").strip() + line = self.p_read().strip() if self.dummy_file is not None: self.dummy_fd.write(line + "\n") @@ -421,12 +610,14 @@ class SmtIo: if count_brackets == 0: break if self.solver != "dummy" and self.p.poll(): - print("SMT Solver terminated unexpectedly: %s" % "".join(stmt)) + print("%s Solver terminated unexpectedly: %s" % (self.timestamp(), "".join(stmt)), flush=True) sys.exit(1) stmt = "".join(stmt) if stmt.startswith("(error"): - print("SMT Solver Error: %s" % stmt, file=sys.stderr) + print("%s Solver Error: %s" % (self.timestamp(), stmt), flush=True) + if self.solver != "dummy": + self.p_close() sys.exit(1) return stmt @@ -441,15 +632,13 @@ class SmtIo: if self.solver != "dummy": if self.noincr: if self.p is not None: - self.p.stdin.close() - self.p = None - self.p = subprocess.Popen(self.popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + self.p_close() + self.p_open() for cache_ctx in self.smt2cache: for cache_stmt in cache_ctx: - self.p.stdin.write(bytes(cache_stmt + "\n", "ascii")) + self.p_write(cache_stmt + "\n", False) - self.p.stdin.write(bytes("(check-sat)\n", "ascii")) - self.p.stdin.flush() + self.p_write("(check-sat)\n", True) if self.timeinfo: i = 0 @@ -457,7 +646,7 @@ class SmtIo: count = 0 num_bs = 0 - while select([self.p.stdout], [], [], 0.1) == ([], [], []): + while self.p_poll(): count += 1 if count < 25: @@ -486,11 +675,43 @@ class SmtIo: print("\b \b" * num_bs, end="", file=sys.stderr) sys.stderr.flush() + else: + count = 0 + while self.p_poll(60): + count += 1 + msg = None + + if count == 1: + msg = "1 minute" + + elif count in [5, 10, 15, 30]: + msg = "%d minutes" % count + + elif count == 60: + msg = "1 hour" + + elif count % 60 == 0: + msg = "%d hours" % (count // 60) + + if msg is not None: + print("%s waiting for solver (%s)" % (self.timestamp(), msg), flush=True) + result = self.read() + if self.debug_file: print("(set-info :status %s)" % result, file=self.debug_file) print("(check-sat)", file=self.debug_file) self.debug_file.flush() + + if result not in ["sat", "unsat"]: + if result == "": + print("%s Unexpected EOF response from solver." % (self.timestamp()), flush=True) + else: + print("%s Unexpected response from solver: %s" % (self.timestamp(), result), flush=True) + if self.solver != "dummy": + self.p_close() + sys.exit(1) + return result def parse(self, stmt): @@ -545,6 +766,9 @@ class SmtIo: return h def bv2bin(self, v): + if type(v) is list and len(v) == 3 and v[0] == "_" and v[1].startswith("bv"): + x, n = int(v[1][2:]), int(v[2]) + return "".join("1" if (x & (1 << i)) else "0" for i in range(n-1, -1, -1)) if v == "true": return "1" if v == "false": return "0" if v.startswith("#b"): @@ -568,7 +792,7 @@ class SmtIo: def get_path(self, mod, path): assert mod in self.modinfo - path = path.split(".") + path = path.replace("\\", "/").split(".") for i in range(len(path)-1): first = ".".join(path[0:i+1]) @@ -613,6 +837,17 @@ class SmtIo: assert net_path[-1] in self.modinfo[mod].wsize return self.modinfo[mod].wsize[net_path[-1]] + def net_clock(self, mod, net_path): + for i in range(len(net_path)-1): + assert mod in self.modinfo + assert net_path[i] in self.modinfo[mod].cells + mod = self.modinfo[mod].cells[net_path[i]] + + assert mod in self.modinfo + if net_path[-1] not in self.modinfo[mod].clocks: + return None + return self.modinfo[mod].clocks[net_path[-1]] + def net_exists(self, mod, net_path): for i in range(len(net_path)-1): if mod not in self.modinfo: return False @@ -672,6 +907,7 @@ class SmtIo: def wait(self): if self.p is not None: self.p.wait() + self.p_close() class SmtOpts: @@ -763,6 +999,7 @@ class MkVcd: self.f = f self.t = -1 self.nets = dict() + self.clocks = dict() def add_net(self, path, width): path = tuple(path) @@ -770,34 +1007,86 @@ class MkVcd: key = "n%d" % len(self.nets) self.nets[path] = (key, width) + def add_clock(self, path, edge): + path = tuple(path) + assert self.t == -1 + key = "n%d" % len(self.nets) + self.nets[path] = (key, 1) + self.clocks[path] = (key, edge) + def set_net(self, path, bits): path = tuple(path) assert self.t >= 0 assert path in self.nets - print("b%s %s" % (bits, self.nets[path][0]), file=self.f) + if path not in self.clocks: + print("b%s %s" % (bits, self.nets[path][0]), file=self.f) + + def escape_name(self, name): + name = re.sub(r"\[([0-9a-zA-Z_]*[a-zA-Z_][0-9a-zA-Z_]*)\]", r"<\1>", name) + if re.match("[\[\]]", name) and name[0] != "\\": + name = "\\" + name + return name def set_time(self, t): assert t >= self.t if t != self.t: if self.t == -1: + print("$version Generated by Yosys-SMTBMC $end", file=self.f) + print("$timescale 1ns $end", file=self.f) print("$var integer 32 t smt_step $end", file=self.f) print("$var event 1 ! smt_clock $end", file=self.f) + + def vcdescape(n): + if n.startswith("$") or ":" in n: + return "\\" + n + return n + scope = [] for path in sorted(self.nets): - while len(scope)+1 > len(path) or (len(scope) > 0 and scope[-1] != path[len(scope)-1]): + key, width = self.nets[path] + + uipath = list(path) + if "." in uipath[-1] and not uipath[-1].startswith("$"): + uipath = uipath[0:-1] + uipath[-1].split(".") + for i in range(len(uipath)): + uipath[i] = re.sub(r"\[([^\]]*)\]", r"<\1>", uipath[i]) + + while uipath[:len(scope)] != scope: print("$upscope $end", file=self.f) scope = scope[:-1] - while len(scope)+1 < len(path): - print("$scope module %s $end" % path[len(scope)], file=self.f) - scope.append(path[len(scope)-1]) - key, width = self.nets[path] - print("$var wire %d %s %s $end" % (width, key, path[-1]), file=self.f) + + while uipath[:-1] != scope: + scopename = uipath[len(scope)] + print("$scope module %s $end" % vcdescape(scopename), file=self.f) + scope.append(uipath[len(scope)]) + + if path in self.clocks and self.clocks[path][1] == "event": + print("$var event 1 %s %s $end" % (key, vcdescape(uipath[-1])), file=self.f) + else: + print("$var wire %d %s %s $end" % (width, key, vcdescape(uipath[-1])), file=self.f) + for i in range(len(scope)): print("$upscope $end", file=self.f) + print("$enddefinitions $end", file=self.f) + self.t = t assert self.t >= 0 + + if self.t > 0: + print("#%d" % (10 * self.t - 5), file=self.f) + for path in sorted(self.clocks.keys()): + if self.clocks[path][1] == "posedge": + print("b0 %s" % self.nets[path][0], file=self.f) + elif self.clocks[path][1] == "negedge": + print("b1 %s" % self.nets[path][0], file=self.f) + print("#%d" % (10 * self.t), file=self.f) print("1!", file=self.f) print("b%s t" % format(self.t, "032b"), file=self.f) + for path in sorted(self.clocks.keys()): + if self.clocks[path][1] == "negedge": + print("b0 %s" % self.nets[path][0], file=self.f) + else: + print("b1 %s" % self.nets[path][0], file=self.f) |