diff options
Diffstat (limited to 'backends/smt2')
-rw-r--r-- | backends/smt2/smt2.cc | 115 | ||||
-rw-r--r-- | backends/smt2/smtbmc.py | 202 | ||||
-rw-r--r-- | backends/smt2/smtio.py | 52 |
3 files changed, 288 insertions, 81 deletions
diff --git a/backends/smt2/smt2.cc b/backends/smt2/smt2.cc index f2fa003bc..7481e0510 100644 --- a/backends/smt2/smt2.cc +++ b/backends/smt2/smt2.cc @@ -53,6 +53,8 @@ struct Smt2Worker std::map<int, int> bvsizes; dict<IdString, char*> ids; + bool is_smtlib2_module; + const char *get_id(IdString n) { if (ids.count(n) == 0) { @@ -112,9 +114,10 @@ struct Smt2Worker } Smt2Worker(RTLIL::Module *module, bool bvmode, bool memmode, bool wiresmode, bool verbose, bool statebv, bool statedt, bool forallmode, - dict<IdString, int> &mod_stbv_width, dict<IdString, dict<IdString, pair<bool, bool>>> &mod_clk_cache) : - ct(module->design), sigmap(module), module(module), bvmode(bvmode), memmode(memmode), wiresmode(wiresmode), - verbose(verbose), statebv(statebv), statedt(statedt), forallmode(forallmode), mod_stbv_width(mod_stbv_width) + dict<IdString, int> &mod_stbv_width, dict<IdString, dict<IdString, pair<bool, bool>>> &mod_clk_cache) + : ct(module->design), sigmap(module), module(module), bvmode(bvmode), memmode(memmode), wiresmode(wiresmode), verbose(verbose), + statebv(statebv), statedt(statedt), forallmode(forallmode), mod_stbv_width(mod_stbv_width), + is_smtlib2_module(module->has_attribute(ID::smtlib2_module)) { pool<SigBit> noclock; @@ -124,6 +127,9 @@ struct Smt2Worker memories = Mem::get_all_memories(module); for (auto &mem : memories) { + if (is_smtlib2_module) + log_error("Memory %s.%s not allowed in module with smtlib2_module attribute", get_id(module), mem.memid.c_str()); + mem.narrow(); mem_dict[mem.memid] = &mem; for (auto &port : mem.wr_ports) @@ -649,6 +655,27 @@ struct Smt2Worker return export_bvop(cell, "(bvurem A B)", 'd'); } } + // "div" = flooring division + if (cell->type == ID($divfloor)) { + if (cell->getParam(ID::A_SIGNED).as_bool()) { + // bvsdiv is truncating division, so we can't use it here. + int width = max(GetSize(cell->getPort(ID::A)), GetSize(cell->getPort(ID::B))); + width = max(width, GetSize(cell->getPort(ID::Y))); + auto expr = stringf("(let (" + "(a_neg (bvslt A #b%0*d)) " + "(b_neg (bvslt B #b%0*d))) " + "(let ((abs_a (ite a_neg (bvneg A) A)) " + "(abs_b (ite b_neg (bvneg B) B))) " + "(let ((u (bvudiv abs_a abs_b)) " + "(adj (ite (= #b%0*d (bvurem abs_a abs_b)) #b%0*d #b%0*d))) " + "(ite (= a_neg b_neg) u " + "(bvneg (bvadd u adj))))))", + width, 0, width, 0, width, 0, width, 0, width, 1); + return export_bvop(cell, expr, 'd'); + } else { + return export_bvop(cell, "(bvudiv A B)", 'd'); + } + } if (cell->type.in(ID($reduce_and), ID($reduce_or), ID($reduce_bool)) && 2*GetSize(cell->getPort(ID::A).chunks()) < GetSize(cell->getPort(ID::A))) { @@ -872,10 +899,25 @@ struct Smt2Worker log_id(cell->type), log_id(module), log_id(cell)); } + void verify_smtlib2_module() + { + if (!module->get_blackbox_attribute()) + log_error("Module %s with smtlib2_module attribute must also have blackbox attribute.\n", log_id(module)); + if (module->cells().size() > 0) + log_error("Module %s with smtlib2_module attribute must not have any cells inside it.\n", log_id(module)); + for (auto wire : module->wires()) + if (!wire->port_id) + log_error("Wire %s.%s must be input or output since module has smtlib2_module attribute.\n", log_id(module), + log_id(wire)); + } + void run() { if (verbose) log("=> export logic driving outputs\n"); + if (is_smtlib2_module) + verify_smtlib2_module(); + pool<SigBit> reg_bits; for (auto cell : module->cells()) if (cell->type.in(ID($ff), ID($dff), ID($_FF_), ID($_DFF_P_), ID($_DFF_N_))) { @@ -884,11 +926,25 @@ struct Smt2Worker reg_bits.insert(bit); } + std::string smtlib2_inputs; + std::vector<std::string> smtlib2_decls; + if (is_smtlib2_module) { + for (auto wire : module->wires()) { + if (!wire->port_input) + continue; + smtlib2_inputs += stringf("(|%s| (|%s_n %s| state))\n", get_id(wire), get_id(module), get_id(wire)); + } + } + for (auto wire : module->wires()) { bool is_register = false; for (auto bit : SigSpec(wire)) if (reg_bits.count(bit)) is_register = true; + bool is_smtlib2_comb_expr = wire->has_attribute(ID::smtlib2_comb_expr); + if (is_smtlib2_comb_expr && !is_smtlib2_module) + log_error("smtlib2_comb_expr is only valid in a module with the smtlib2_module attribute: wire %s.%s", log_id(module), + log_id(wire)); if (wire->port_id || is_register || wire->get_bool_attribute(ID::keep) || (wiresmode && wire->name.isPublic())) { RTLIL::SigSpec sig = sigmap(wire); std::vector<std::string> comments; @@ -903,11 +959,22 @@ struct Smt2Worker if (GetSize(wire) == 1 && (clock_posedge.count(sig) || clock_negedge.count(sig))) comments.push_back(stringf("; yosys-smt2-clock %s%s%s\n", get_id(wire), clock_posedge.count(sig) ? " posedge" : "", clock_negedge.count(sig) ? " negedge" : "")); + std::string smtlib2_comb_expr; + if (is_smtlib2_comb_expr) { + smtlib2_comb_expr = + "(let (\n" + smtlib2_inputs + ")\n" + wire->get_string_attribute(ID::smtlib2_comb_expr) + "\n)"; + if (wire->port_input || !wire->port_output) + log_error("smtlib2_comb_expr is only valid on output: wire %s.%s", log_id(module), log_id(wire)); + if (!bvmode && GetSize(sig) > 1) + log_error("smtlib2_comb_expr is unsupported on multi-bit wires when -nobv is specified: wire %s.%s", + log_id(module), log_id(wire)); + } + auto &out_decls = is_smtlib2_comb_expr ? smtlib2_decls : decls; if (bvmode && GetSize(sig) > 1) { - std::string sig_bv = get_bv(sig); + std::string sig_bv = is_smtlib2_comb_expr ? smtlib2_comb_expr : get_bv(sig); if (!comments.empty()) - decls.insert(decls.end(), comments.begin(), comments.end()); - decls.push_back(stringf("(define-fun |%s_n %s| ((state |%s_s|)) (_ BitVec %d) %s)\n", + out_decls.insert(out_decls.end(), comments.begin(), comments.end()); + out_decls.push_back(stringf("(define-fun |%s_n %s| ((state |%s_s|)) (_ BitVec %d) %s)\n", get_id(module), get_id(wire), get_id(module), GetSize(sig), sig_bv.c_str())); if (wire->port_input) ex_input_eq.push_back(stringf(" (= (|%s_n %s| state) (|%s_n %s| other_state))", @@ -915,19 +982,19 @@ struct Smt2Worker } else { std::vector<std::string> sig_bool; for (int i = 0; i < GetSize(sig); i++) { - sig_bool.push_back(get_bool(sig[i])); + sig_bool.push_back(is_smtlib2_comb_expr ? smtlib2_comb_expr : get_bool(sig[i])); } if (!comments.empty()) - decls.insert(decls.end(), comments.begin(), comments.end()); + out_decls.insert(out_decls.end(), comments.begin(), comments.end()); for (int i = 0; i < GetSize(sig); i++) { if (GetSize(sig) > 1) { - decls.push_back(stringf("(define-fun |%s_n %s %d| ((state |%s_s|)) Bool %s)\n", + out_decls.push_back(stringf("(define-fun |%s_n %s %d| ((state |%s_s|)) Bool %s)\n", get_id(module), get_id(wire), i, get_id(module), sig_bool[i].c_str())); if (wire->port_input) ex_input_eq.push_back(stringf(" (= (|%s_n %s %d| state) (|%s_n %s %d| other_state))", get_id(module), get_id(wire), i, get_id(module), get_id(wire), i)); } else { - decls.push_back(stringf("(define-fun |%s_n %s| ((state |%s_s|)) Bool %s)\n", + out_decls.push_back(stringf("(define-fun |%s_n %s| ((state |%s_s|)) Bool %s)\n", get_id(module), get_id(wire), get_id(module), sig_bool[i].c_str())); if (wire->port_input) ex_input_eq.push_back(stringf(" (= (|%s_n %s| state) (|%s_n %s| other_state))", @@ -938,11 +1005,17 @@ struct Smt2Worker } } + decls.insert(decls.end(), smtlib2_decls.begin(), smtlib2_decls.end()); + if (verbose) log("=> export logic associated with the initial state\n"); vector<string> init_list; for (auto wire : module->wires()) if (wire->attributes.count(ID::init)) { + if (is_smtlib2_module) + log_error("init attribute not allowed on wires in module with smtlib2_module attribute: wire %s.%s", + log_id(module), log_id(wire)); + RTLIL::SigSpec sig = sigmap(wire); Const val = wire->attributes.at(ID::init); val.bits.resize(GetSize(sig), State::Sx); @@ -985,8 +1058,10 @@ struct Smt2Worker string name_a = get_bool(cell->getPort(ID::A)); string name_en = get_bool(cell->getPort(ID::EN)); - string infostr = (cell->name[0] == '$' && cell->attributes.count(ID::src)) ? cell->attributes.at(ID::src).decode_string() : get_id(cell); - decls.push_back(stringf("; yosys-smt2-%s %d %s\n", cell->type.c_str() + 1, id, infostr.c_str())); + if (cell->name[0] == '$' && cell->attributes.count(ID::src)) + decls.push_back(stringf("; yosys-smt2-%s %d %s %s\n", cell->type.c_str() + 1, id, get_id(cell), cell->attributes.at(ID::src).decode_string().c_str())); + else + decls.push_back(stringf("; yosys-smt2-%s %d %s\n", cell->type.c_str() + 1, id, get_id(cell))); if (cell->type == ID($cover)) decls.push_back(stringf("(define-fun |%s_%c %d| ((state |%s_s|)) Bool (and %s %s)) ; %s\n", @@ -1183,10 +1258,12 @@ struct Smt2Worker data = stringf("(bvor (bvand %s %s) (bvand (select (|%s#%d#%d| state) %s) (bvnot %s)))", data.c_str(), mask.c_str(), get_id(module), arrayid, i, addr.c_str(), mask.c_str()); + string empty_mask(mem->width, '0'); + decls.push_back(stringf("(define-fun |%s#%d#%d| ((state |%s_s|)) (Array (_ BitVec %d) (_ BitVec %d)) " - "(store (|%s#%d#%d| state) %s %s)) ; %s\n", + "(ite (= %s #b%s) (|%s#%d#%d| state) (store (|%s#%d#%d| state) %s %s))) ; %s\n", get_id(module), arrayid, i+1, get_id(module), abits, mem->width, - get_id(module), arrayid, i, addr.c_str(), data.c_str(), get_id(mem->memid))); + mask.c_str(), empty_mask.c_str(), get_id(module), arrayid, i, get_id(module), arrayid, i, addr.c_str(), data.c_str(), get_id(mem->memid))); } } @@ -1531,6 +1608,11 @@ struct Smt2Backend : public Backend { log_header(design, "Executing SMT2 backend.\n"); + log_push(); + Pass::call(design, "bmuxmap"); + Pass::call(design, "demuxmap"); + log_pop(); + size_t argidx; for (argidx = 1; argidx < args.size(); argidx++) { @@ -1657,7 +1739,10 @@ struct Smt2Backend : public Backend { for (auto module : sorted_modules) { - if (module->get_blackbox_attribute() || module->has_processes_warn()) + if (module->get_blackbox_attribute() && !module->has_attribute(ID::smtlib2_module)) + continue; + + if (module->has_processes_warn()) continue; log("Creating SMT-LIBv2 representation of module %s.\n", log_id(module)); diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index e5cfcdc08..137182f33 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -50,6 +50,7 @@ smtcinit = False smtctop = None noinit = False binarymode = False +keep_going = False so = SmtOpts() @@ -143,7 +144,7 @@ def usage(): --dump-all when using -g or -i, create a dump file for each - step. The character '%' is replaces in all dump + step. The character '%' is replaced in all dump filenames with the step number. --append <num_steps> @@ -153,6 +154,13 @@ def usage(): --binary dump anyconst values as raw bit strings + + --keep-going + continue BMC after the first failed assertion and report + further failed assertions. To output multiple traces + covering all found failed assertions, the character '%' is + replaced in all dump filenames with an increasing number. + """ + so.helpmsg()) sys.exit(1) @@ -161,7 +169,7 @@ try: opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:igcm:", so.longopts + ["final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "btorwit=", "presat", "dump-vcd=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=", - "smtc-init", "smtc-top=", "noinit", "binary"]) + "smtc-init", "smtc-top=", "noinit", "binary", "keep-going"]) except: usage() @@ -234,6 +242,8 @@ for o, a in opts: topmod = a elif o == "--binary": binarymode = True + elif o == "--keep-going": + keep_going = True elif so.handle(o, a): pass else: @@ -341,13 +351,13 @@ for fn in inconstr: assert False -def get_constr_expr(db, state, final=False, getvalues=False): +def get_constr_expr(db, state, final=False, getvalues=False, individual=False): if final: if ("final-%d" % state) not in db: - return ([], [], []) if getvalues else "true" + return ([], [], []) if getvalues or individual else "true" else: if state not in db: - return ([], [], []) if getvalues else "true" + return ([], [], []) if getvalues or individual else "true" netref_regex = re.compile(r'(^|[( ])\[(-?[0-9]+:|)([^\]]*|\S*)\](?=[ )]|$)') @@ -368,15 +378,18 @@ def get_constr_expr(db, state, final=False, getvalues=False): expr_list = list() for loc, expr in db[("final-%d" % state) if final else state]: actual_expr = netref_regex.sub(replace_netref, expr) - if getvalues: + if getvalues or individual: expr_list.append((loc, expr, actual_expr)) else: expr_list.append(actual_expr) - if getvalues: - loc_list, expr_list, acual_expr_list = zip(*expr_list) - value_list = smt.get_list(acual_expr_list) - return loc_list, expr_list, value_list + if getvalues or individual: + loc_list, expr_list, actual_expr_list = zip(*expr_list) + if individual: + return loc_list, expr_list, actual_expr_list + else: + value_list = smt.get_list(actual_expr_list) + return loc_list, expr_list, value_list if len(expr_list) == 0: return "true" @@ -492,7 +505,7 @@ if aimfile is not None: got_state = True for entry in f.read().splitlines(): - if len(entry) == 0 or entry[0] in "bcjfu.": + if len(entry) == 0 or entry[0] in "bcjfu.#": continue if not got_state: @@ -583,7 +596,10 @@ if aimfile is not None: if not got_topt: skip_steps = max(skip_steps, step) - num_steps = max(num_steps, step+1) + # some solvers optimize the properties so that they fail one cycle early, + # thus we check the properties in the cycle the aiger witness ends, and + # if that doesn't work, we check the cycle after that as well. + num_steps = max(num_steps, step+2) step += 1 if btorwitfile is not None: @@ -826,7 +842,7 @@ def char_ok_in_verilog(c,i): return False def escape_identifier(identifier): - if type(identifier) is list: + if type(identifier) is list: return map(escape_identifier, identifier) if "." in identifier: return ".".join(escape_identifier(identifier.split("."))) @@ -1068,7 +1084,7 @@ def write_trace(steps_start, steps_stop, index): write_constr_trace(steps_start, steps_stop, index) -def print_failed_asserts_worker(mod, state, path, extrainfo): +def print_failed_asserts_worker(mod, state, path, extrainfo, infomap, infokey=()): assert mod in smt.modinfo found_failed_assert = False @@ -1076,29 +1092,31 @@ def print_failed_asserts_worker(mod, state, path, extrainfo): return for cellname, celltype in smt.modinfo[mod].cells.items(): - if print_failed_asserts_worker(celltype, "(|%s_h %s| %s)" % (mod, cellname, state), path + "." + cellname, extrainfo): + cell_infokey = (mod, cellname, infokey) + if print_failed_asserts_worker(celltype, "(|%s_h %s| %s)" % (mod, cellname, state), path + "." + cellname, extrainfo, infomap, cell_infokey): found_failed_assert = True for assertfun, assertinfo in smt.modinfo[mod].asserts.items(): if smt.get("(|%s| %s)" % (assertfun, state)) in ["false", "#b0"]: - print_msg("Assert failed in %s: %s%s" % (path, assertinfo, extrainfo)) + assert_key = (assertfun, infokey) + print_msg("Assert failed in %s: %s%s%s" % (path, assertinfo, extrainfo, infomap.get(assert_key, ''))) found_failed_assert = True return found_failed_assert -def print_failed_asserts(state, final=False, extrainfo=""): +def print_failed_asserts(state, final=False, extrainfo="", infomap={}): if noinfo: return loc_list, expr_list, value_list = get_constr_expr(constr_asserts, state, final=final, getvalues=True) found_failed_assert = False for loc, expr, value in zip(loc_list, expr_list, value_list): if smt.bv2int(value) == 0: - print_msg("Assert %s failed: %s%s" % (loc, expr, extrainfo)) + print_msg("Assert %s failed: %s%s%s" % (loc, expr, extrainfo, infomap.get(loc, ''))) found_failed_assert = True if not final: - if print_failed_asserts_worker(topmod, "s%d" % state, topmod, extrainfo): + if print_failed_asserts_worker(topmod, "s%d" % state, topmod, extrainfo, infomap): found_failed_assert = True return found_failed_assert @@ -1145,6 +1163,43 @@ def get_cover_list(mod, base): return cover_expr, cover_desc + +def get_assert_map(mod, base, path, key_base=()): + assert mod in smt.modinfo + + assert_map = dict() + + for expr, desc in smt.modinfo[mod].asserts.items(): + assert_map[(expr, key_base)] = ("(|%s| %s)" % (expr, base), path, desc) + + for cell, submod in smt.modinfo[mod].cells.items(): + assert_map.update(get_assert_map(submod, "(|%s_h %s| %s)" % (mod, cell, base), path + "." + cell, (mod, cell, key_base))) + + return assert_map + + +def get_assert_keys(): + keys = set() + keys.update(get_assert_map(topmod, 'state', topmod).keys()) + for step_constr_asserts in constr_asserts.values(): + keys.update(loc for loc, expr in step_constr_asserts) + + return keys + + +def get_active_assert_map(step, active): + assert_map = dict() + for key, assert_data in get_assert_map(topmod, "s%s" % step, topmod).items(): + if key in active: + assert_map[key] = assert_data + + for loc, expr, actual_expr in zip(*get_constr_expr(constr_asserts, step, individual=True)): + if loc in active: + assert_map[loc] = (actual_expr, None, (expr, loc)) + + return assert_map + + states = list() asserts_antecedent_cache = [list()] asserts_consequent_cache = [list()] @@ -1454,6 +1509,10 @@ elif covermode: print_msg("Unreached cover statement at %s." % cover_desc[i]) else: # not tempind, covermode + active_assert_keys = get_assert_keys() + failed_assert_infomap = dict() + traceidx = 0 + step = 0 retstatus = "PASSED" while step < num_steps: @@ -1507,44 +1566,83 @@ else: # not tempind, covermode break if not final_only: - if last_check_step == step: - print_msg("Checking assertions in step %d.." % (step)) - else: - print_msg("Checking assertions in steps %d to %d.." % (step, last_check_step)) - smt_push() - - smt_assert("(not (and %s))" % " ".join(["(|%s_a| s%d)" % (topmod, i) for i in range(step, last_check_step+1)] + - [get_constr_expr(constr_asserts, i) for i in range(step, last_check_step+1)])) - - if smt_check_sat() == "sat": - print("%s BMC failed!" % smt.timestamp()) - if append_steps > 0: - for i in range(last_check_step+1, last_check_step+1+append_steps): - print_msg("Appending additional step %d." % i) - smt_state(i) - smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i)) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, i)) - smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i)) - smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i)) - smt_assert_consequent(get_constr_expr(constr_assumes, i)) - print_msg("Re-solving with appended steps..") - if smt_check_sat() == "unsat": - print("%s Cannot append steps without violating assumptions!" % smt.timestamp()) - retstatus = "FAILED" - break - print_anyconsts(step) + recheck_current_step = True + while recheck_current_step: + recheck_current_step = False + if last_check_step == step: + print_msg("Checking assertions in step %d.." % (step)) + else: + print_msg("Checking assertions in steps %d to %d.." % (step, last_check_step)) + smt_push() + + active_assert_maps = dict() + active_assert_exprs = list() for i in range(step, last_check_step+1): - print_failed_asserts(i) - write_trace(0, last_check_step+1+append_steps, '%') - retstatus = "FAILED" - break + assert_expr_map = get_active_assert_map(i, active_assert_keys) + active_assert_maps[i] = assert_expr_map + active_assert_exprs.extend(assert_data[0] for assert_data in assert_expr_map.values()) - smt_pop() + if active_assert_exprs: + if len(active_assert_exprs) == 1: + active_assert_expr = active_assert_exprs[0] + else: + active_assert_expr = "(and %s)" % " ".join(active_assert_exprs) + + smt_assert("(not %s)" % active_assert_expr) + else: + smt_assert("false") + + + if smt_check_sat() == "sat": + if retstatus != "FAILED": + print("%s BMC failed!" % smt.timestamp()) + + if append_steps > 0: + for i in range(last_check_step+1, last_check_step+1+append_steps): + print_msg("Appending additional step %d." % i) + smt_state(i) + smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i)) + smt_assert_consequent("(|%s_u| s%d)" % (topmod, i)) + smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i)) + smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i)) + smt_assert_consequent(get_constr_expr(constr_assumes, i)) + print_msg("Re-solving with appended steps..") + if smt_check_sat() == "unsat": + print("%s Cannot append steps without violating assumptions!" % smt.timestamp()) + retstatus = "FAILED" + break + print_anyconsts(step) + + for i in range(step, last_check_step+1): + print_failed_asserts(i, infomap=failed_assert_infomap) + + if keep_going: + for i in range(step, last_check_step+1): + for key, (expr, path, desc) in active_assert_maps[i].items(): + if key in active_assert_keys and not smt.bv2int(smt.get(expr)): + failed_assert_infomap[key] = " [failed before]" + + active_assert_keys.remove(key) + + if active_assert_keys: + recheck_current_step = True + + write_trace(0, last_check_step+1+append_steps, "%d" % traceidx if keep_going else '%') + traceidx += 1 + retstatus = "FAILED" + + smt_pop() + if recheck_current_step: + print_msg("Checking remaining assertions..") + + if retstatus == "FAILED" and not (keep_going and active_assert_keys): + break if (constr_final_start is not None) or (last_check_step+1 != num_steps): for i in range(step, last_check_step+1): - smt_assert("(|%s_a| s%d)" % (topmod, i)) - smt_assert(get_constr_expr(constr_asserts, i)) + assert_expr_map = get_active_assert_map(i, active_assert_keys) + for assert_data in assert_expr_map.values(): + smt_assert(assert_data[0]) if constr_final_start is not None: for i in range(step, last_check_step+1): diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index d73a875ba..91efc13a3 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -20,7 +20,7 @@ import sys, re, os, signal import subprocess if os.name == "posix": import resource -from copy import deepcopy +from copy import copy from select import select from time import time from queue import Queue, Empty @@ -123,6 +123,7 @@ class SmtIo: self.forall = False self.timeout = 0 self.produce_models = True + self.recheck = False self.smt2cache = [list()] self.smt2_options = dict() self.p = None @@ -176,7 +177,10 @@ class SmtIo: self.unroll = False if self.solver == "yices": - if self.noincr or self.forall: + if self.forall: + self.noincr = True + + if self.noincr: self.popen_vargs = ['yices-smt2'] + self.solver_opts else: self.popen_vargs = ['yices-smt2', '--incremental'] + self.solver_opts @@ -189,11 +193,12 @@ class SmtIo: if self.timeout != 0: self.popen_vargs.append('-T:%d' % self.timeout); - if self.solver == "cvc4": + if self.solver in ["cvc4", "cvc5"]: + self.recheck = True if self.noincr: - self.popen_vargs = ['cvc4', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts + self.popen_vargs = [self.solver, '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts else: - self.popen_vargs = ['cvc4', '--incremental', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts + self.popen_vargs = [self.solver, '--incremental', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts if self.timeout != 0: self.popen_vargs.append('--tlimit=%d000' % self.timeout); @@ -301,7 +306,7 @@ class SmtIo: key = tuple(stmt) if key not in self.unroll_cache: - decl = deepcopy(self.unroll_decls[key[0]]) + decl = copy(self.unroll_decls[key[0]]) self.unroll_cache[key] = "|UNROLL#%d|" % self.unroll_idcnt decl[1] = self.unroll_cache[key] @@ -404,6 +409,15 @@ class SmtIo: stmt = re.sub(r" *;.*", "", stmt) if stmt == "": return + recheck = None + + if self.solver != "dummy": + if self.noincr: + # Don't close the solver yet, if we're just unrolling definitions + # required for a (get-...) statement + if self.p is not None and not stmt.startswith("(get-") and unroll: + self.p_close() + if unroll and self.unroll: stmt = self.unroll_buffer + stmt self.unroll_buffer = "" @@ -415,6 +429,9 @@ class SmtIo: s = self.parse(stmt) + if self.recheck and s and s[0].startswith("get-"): + recheck = self.unroll_idcnt + if self.debug_print: print("-> %s" % s) @@ -440,12 +457,15 @@ class SmtIo: stmt = self.unparse(self.unroll_stmt(s)) + if recheck is not None and recheck != self.unroll_idcnt: + self.check_sat(["sat"]) + if stmt == "(push 1)": self.unroll_stack.append(( - deepcopy(self.unroll_sorts), - deepcopy(self.unroll_objs), - deepcopy(self.unroll_decls), - deepcopy(self.unroll_cache), + copy(self.unroll_sorts), + copy(self.unroll_objs), + copy(self.unroll_decls), + copy(self.unroll_cache), )) if stmt == "(pop 1)": @@ -460,8 +480,6 @@ class SmtIo: if self.solver != "dummy": if self.noincr: - if self.p is not None and not stmt.startswith("(get-"): - self.p_close() if stmt == "(push 1)": self.smt2cache.append(list()) elif stmt == "(pop 1)": @@ -536,10 +554,16 @@ class SmtIo: self.modinfo[self.curmod].clocks[fields[2]] = "event" if fields[1] == "yosys-smt2-assert": - self.modinfo[self.curmod].asserts["%s_a %s" % (self.curmod, fields[2])] = fields[3] + if len(fields) > 4: + self.modinfo[self.curmod].asserts["%s_a %s" % (self.curmod, fields[2])] = f'{fields[4]} ({fields[3]})' + else: + self.modinfo[self.curmod].asserts["%s_a %s" % (self.curmod, fields[2])] = fields[3] if fields[1] == "yosys-smt2-cover": - self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3] + if len(fields) > 4: + self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = f'{fields[4]} ({fields[3]})' + else: + self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3] if fields[1] == "yosys-smt2-maximize": self.modinfo[self.curmod].maximize.add(fields[2]) |