diff options
Diffstat (limited to 'backends/smt2/smtio.py')
-rw-r--r-- | backends/smt2/smtio.py | 161 |
1 files changed, 103 insertions, 58 deletions
diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index 5d0a78485..58554572d 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -47,11 +47,19 @@ class SmtModInfo: class SmtIo: - def __init__(self, solver=None, debug_print=None, debug_file=None, timeinfo=None, unroll=None, opts=None): + def __init__(self, opts=None): + self.logic = None + self.logic_qf = True + self.logic_ax = True + self.logic_uf = True + self.logic_bv = True + if opts is not None: + self.logic = opts.logic self.solver = opts.solver self.debug_print = opts.debug_print self.debug_file = opts.debug_file + self.dummy_file = opts.dummy_file self.timeinfo = opts.timeinfo self.unroll = opts.unroll @@ -59,24 +67,10 @@ class SmtIo: self.solver = "z3" self.debug_print = False self.debug_file = None + self.dummy_file = None self.timeinfo = True self.unroll = False - if solver is not None: - self.solver = solver - - if debug_print is not None: - self.debug_print = debug_print - - if debug_file is not None: - self.debug_file = debug_file - - if timeinfo is not None: - self.timeinfo = timeinfo - - if unroll is not None: - self.unroll = unroll - if self.solver == "yices": popen_vargs = ['yices-smt2', '--incremental'] @@ -93,7 +87,16 @@ class SmtIo: popen_vargs = ['boolector', '--smt2', '-i'] self.unroll = True + if self.solver == "dummy": + assert self.dummy_file is not None + self.dummy_fd = open(self.dummy_file, "r") + else: + if self.dummy_file is not None: + self.dummy_fd = open(self.dummy_file, "w") + self.p = subprocess.Popen(popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + if self.unroll: + self.logic_uf = False self.unroll_idcnt = 0 self.unroll_buffer = "" self.unroll_sorts = set() @@ -102,20 +105,26 @@ class SmtIo: self.unroll_cache = dict() self.unroll_stack = list() - self.p = subprocess.Popen(popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) self.start_time = time() self.modinfo = dict() self.curmod = None self.topmod = None + self.setup_done = False + + def setup(self): + assert not self.setup_done + + if self.logic is None: + self.logic = "" + if self.logic_qf: self.logic += "QF_" + if self.logic_ax: self.logic += "A" + if self.logic_uf: self.logic += "UF" + if self.logic_bv: self.logic += "BV" - def setup(self, logic="ALL", info=None): + self.setup_done = True self.write("(set-option :produce-models true)") - self.write("(set-logic %s)" % logic) - if info is not None: - self.write("(set-info :source |%s|)" % info) - self.write("(set-info :smt-lib-version 2.5)") - self.write("(set-info :category \"industrial\")") + self.write("(set-logic %s)" % self.logic) def timestamp(self): secs = int(time() - self.start_time) @@ -171,6 +180,11 @@ class SmtIo: return stmt def write(self, stmt, unroll=True): + if stmt.startswith(";"): + self.info(stmt) + elif not self.setup_done: + self.setup() + stmt = stmt.strip() if unroll and self.unroll: @@ -231,8 +245,9 @@ class SmtIo: print(stmt, file=self.debug_file) self.debug_file.flush() - self.p.stdin.write(bytes(stmt + "\n", "ascii")) - self.p.stdin.flush() + if self.solver != "dummy": + self.p.stdin.write(bytes(stmt + "\n", "ascii")) + self.p.stdin.flush() def info(self, stmt): if not stmt.startswith("; yosys-smt2-"): @@ -240,6 +255,14 @@ class SmtIo: fields = stmt.split() + if fields[1] == "yosys-smt2-nomem": + if self.logic is None: + self.logic_ax = False + + if fields[1] == "yosys-smt2-nobv": + if self.logic is None: + self.logic_bv = False + if fields[1] == "yosys-smt2-module": self.curmod = fields[2] self.modinfo[self.curmod] = SmtModInfo() @@ -303,15 +326,22 @@ class SmtIo: count_brackets = 0 while True: - line = self.p.stdout.readline().decode("ascii").strip() + if self.solver == "dummy": + line = self.dummy_fd.readline().strip() + else: + line = self.p.stdout.readline().decode("ascii").strip() + if self.dummy_file is not None: + self.dummy_fd.write(line + "\n") + count_brackets += line.count("(") count_brackets -= line.count(")") stmt.append(line) + if self.debug_print: print("< %s" % line) if count_brackets == 0: break - if self.p.poll(): + if self.solver != "dummy" and self.p.poll(): print("SMT Solver terminated unexpectedly: %s" % "".join(stmt)) sys.exit(1) @@ -329,43 +359,44 @@ class SmtIo: print("; running check-sat..", file=self.debug_file) self.debug_file.flush() - self.p.stdin.write(bytes("(check-sat)\n", "ascii")) - self.p.stdin.flush() + if self.solver != "dummy": + self.p.stdin.write(bytes("(check-sat)\n", "ascii")) + self.p.stdin.flush() - if self.timeinfo: - i = 0 - s = "/-\|" + if self.timeinfo: + i = 0 + s = "/-\|" - count = 0 - num_bs = 0 - while select([self.p.stdout], [], [], 0.1) == ([], [], []): - count += 1 + count = 0 + num_bs = 0 + while select([self.p.stdout], [], [], 0.1) == ([], [], []): + count += 1 - if count < 25: - continue + if count < 25: + continue - if count % 10 == 0 or count == 25: - secs = count // 10 + if count % 10 == 0 or count == 25: + secs = count // 10 - if secs < 60: - m = "(%d seconds)" % secs - elif secs < 60*60: - m = "(%d seconds -- %d:%02d)" % (secs, secs // 60, secs % 60) - else: - m = "(%d seconds -- %d:%02d:%02d)" % (secs, secs // (60*60), (secs // 60) % 60, secs % 60) + if secs < 60: + m = "(%d seconds)" % secs + elif secs < 60*60: + m = "(%d seconds -- %d:%02d)" % (secs, secs // 60, secs % 60) + else: + m = "(%d seconds -- %d:%02d:%02d)" % (secs, secs // (60*60), (secs // 60) % 60, secs % 60) - print("%s %s %c" % ("\b \b" * num_bs, m, s[i]), end="", file=sys.stderr) - num_bs = len(m) + 3 + print("%s %s %c" % ("\b \b" * num_bs, m, s[i]), end="", file=sys.stderr) + num_bs = len(m) + 3 - else: - print("\b" + s[i], end="", file=sys.stderr) + else: + print("\b" + s[i], end="", file=sys.stderr) - sys.stderr.flush() - i = (i + 1) % len(s) + sys.stderr.flush() + i = (i + 1) % len(s) - if num_bs != 0: - print("\b \b" * num_bs, end="", file=sys.stderr) - sys.stderr.flush() + if num_bs != 0: + print("\b \b" * num_bs, end="", file=sys.stderr) + sys.stderr.flush() result = self.read() if self.debug_file: @@ -524,18 +555,21 @@ class SmtIo: return [self.bv2bin(v) for v in self.get_net_list(mod_name, net_path_list, state_name)] def wait(self): - self.p.wait() + if self.solver != "dummy": + self.p.wait() class SmtOpts: def __init__(self): self.shortopts = "s:v" - self.longopts = ["unroll", "no-progress", "dump-smt2="] + self.longopts = ["unroll", "no-progress", "dump-smt2=", "logic=", "dummy="] self.solver = "z3" self.debug_print = False self.debug_file = None + self.dummy_file = None self.unroll = False self.timeinfo = True + self.logic = None def handle(self, o, a): if o == "-s": @@ -548,6 +582,10 @@ class SmtOpts: self.timeinfo = True elif o == "--dump-smt2": self.debug_file = open(a, "w") + elif o == "--logic": + self.logic = a + elif o == "--dummy": + self.dummy_file = a else: return False return True @@ -555,9 +593,16 @@ class SmtOpts: def helpmsg(self): return """ -s <solver> - set SMT solver: z3, cvc4, yices, mathsat + set SMT solver: z3, cvc4, yices, mathsat, boolector, dummy default: z3 + --logic <smt2_logic> + use the specified SMT2 logic (e.g. QF_AUFBV) + + --dummy <filename> + if solver is "dummy", read solver output from that file + otherwise: write solver output to that file + -v enable debug output |