diff options
Diffstat (limited to 'libs/ezsat/ezsat.cc')
-rw-r--r-- | libs/ezsat/ezsat.cc | 322 |
1 files changed, 205 insertions, 117 deletions
diff --git a/libs/ezsat/ezsat.cc b/libs/ezsat/ezsat.cc index dccc00555..13ed112ed 100644 --- a/libs/ezsat/ezsat.cc +++ b/libs/ezsat/ezsat.cc @@ -19,21 +19,21 @@ #include "ezsat.h" +#include <cmath> #include <algorithm> +#include <cassert> #include <stdlib.h> -#include <assert.h> const int ezSAT::TRUE = 1; const int ezSAT::FALSE = 2; ezSAT::ezSAT() { - literal("TRUE"); - literal("FALSE"); + flag_keep_cnf = false; + flag_non_incremental = false; - assert(literal("TRUE") == TRUE); - assert(literal("FALSE") == FALSE); + non_incremental_solve_used_up = false; cnfConsumed = false; cnfVariableCount = 0; @@ -41,6 +41,12 @@ ezSAT::ezSAT() solverTimeout = 0; solverTimoutStatus = false; + + literal("TRUE"); + literal("FALSE"); + + assert(literal("TRUE") == TRUE); + assert(literal("FALSE") == FALSE); } ezSAT::~ezSAT() @@ -67,6 +73,20 @@ int ezSAT::literal(const std::string &name) return literalsCache.at(name); } +int ezSAT::frozen_literal() +{ + int id = literal(); + freeze(id); + return id; +} + +int ezSAT::frozen_literal(const std::string &name) +{ + int id = literal(name); + freeze(id); + return id; +} + int ezSAT::expression(OpId op, int a, int b, int c, int d, int e, int f) { std::vector<int> args(6); @@ -342,13 +362,19 @@ void ezSAT::clear() cnfLiteralVariables.clear(); cnfExpressionVariables.clear(); cnfClauses.clear(); - cnfAssumptions.clear(); } -void ezSAT::assume(int id) +void ezSAT::freeze(int) +{ +} + +bool ezSAT::eliminated(int) { - cnfAssumptions.insert(id); + return false; +} +void ezSAT::assume(int id) +{ if (id < 0) { assert(0 < -id && -id <= int(expressions.size())); @@ -462,11 +488,32 @@ int ezSAT::bound(int id) const return 0; } -int ezSAT::bind(int id) +std::string ezSAT::cnfLiteralInfo(int idx) const +{ + for (size_t i = 0; i < cnfLiteralVariables.size(); i++) { + if (cnfLiteralVariables[i] == idx) + return to_string(i+1); + if (cnfLiteralVariables[i] == -idx) + return "NOT " + to_string(i+1); + } + for (size_t i = 0; i < cnfExpressionVariables.size(); i++) { + if (cnfExpressionVariables[i] == idx) + return to_string(-i-1); + if (cnfExpressionVariables[i] == -idx) + return "NOT " + to_string(-i-1); + } + return "<unnamed>"; +} + +int ezSAT::bind(int id, bool auto_freeze) { if (id >= 0) { assert(0 < id && id <= int(literals.size())); cnfLiteralVariables.resize(literals.size()); + if (eliminated(cnfLiteralVariables[id-1])) { + fprintf(stderr, "ezSAT: Missing freeze on literal `%s'.\n", to_string(id).c_str()); + abort(); + } if (cnfLiteralVariables[id-1] == 0) { cnfLiteralVariables[id-1] = ++cnfVariableCount; if (id == TRUE) @@ -480,6 +527,17 @@ int ezSAT::bind(int id) assert(0 < -id && -id <= int(expressions.size())); cnfExpressionVariables.resize(expressions.size()); + if (eliminated(cnfExpressionVariables[-id-1])) + { + cnfExpressionVariables[-id-1] = 0; + + // this will recursively call bind(id). within the recursion + // the cnf is pre-set to 0. an idx is allocated there, then it + // is frozen, then it returns here with the new idx already set. + if (auto_freeze) + freeze(id); + } + if (cnfExpressionVariables[-id-1] == 0) { OpId op; @@ -497,7 +555,7 @@ int ezSAT::bind(int id) newArgs.push_back(OR(AND(args[i], NOT(args[i+1])), AND(NOT(args[i]), args[i+1]))); args.swap(newArgs); } - idx = bind(args.at(0)); + idx = bind(args.at(0), false); goto assign_idx; } @@ -505,17 +563,17 @@ int ezSAT::bind(int id) std::vector<int> invArgs; for (auto arg : args) invArgs.push_back(NOT(arg)); - idx = bind(OR(expression(OpAnd, args), expression(OpAnd, invArgs))); + idx = bind(OR(expression(OpAnd, args), expression(OpAnd, invArgs)), false); goto assign_idx; } if (op == OpITE) { - idx = bind(OR(AND(args[0], args[1]), AND(NOT(args[0]), args[2]))); + idx = bind(OR(AND(args[0], args[1]), AND(NOT(args[0]), args[2])), false); goto assign_idx; } for (int i = 0; i < int(args.size()); i++) - args[i] = bind(args[i]); + args[i] = bind(args[i], false); switch (op) { @@ -535,120 +593,45 @@ int ezSAT::bind(int id) void ezSAT::consumeCnf() { - cnfConsumed = true; + if (mode_keep_cnf()) + cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end()); + else + cnfConsumed = true; cnfClauses.clear(); } void ezSAT::consumeCnf(std::vector<std::vector<int>> &cnf) { - cnfConsumed = true; + if (mode_keep_cnf()) + cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end()); + else + cnfConsumed = true; cnf.swap(cnfClauses); cnfClauses.clear(); } -static bool test_bit(uint32_t bitmask, int idx) +void ezSAT::getFullCnf(std::vector<std::vector<int>> &full_cnf) const { - if (idx > 0) - return (bitmask & (1 << (+idx-1))) != 0; - else - return (bitmask & (1 << (-idx-1))) == 0; + assert(full_cnf.empty()); + full_cnf.insert(full_cnf.end(), cnfClausesBackup.begin(), cnfClausesBackup.end()); + full_cnf.insert(full_cnf.end(), cnfClauses.begin(), cnfClauses.end()); } -bool ezSAT::solver(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, const std::vector<int> &assumptions) +void ezSAT::preSolverCallback() { - std::vector<int> extraClauses, modelIdx; - std::vector<int> values(numLiterals()); - - for (auto id : assumptions) - extraClauses.push_back(bind(id)); - for (auto id : modelExpressions) - modelIdx.push_back(bind(id)); - - if (cnfVariableCount > 20) { - fprintf(stderr, "*************************************************************************************\n"); - fprintf(stderr, "ERROR: You are trying to use the builtin solver of ezSAT with more than 20 variables!\n"); - fprintf(stderr, "The builtin solver is a dumb brute force solver and only ment for testing and demo\n"); - fprintf(stderr, "purposes. Use a real SAT solve like MiniSAT (e.g. using the ezMiniSAT class) instead.\n"); - fprintf(stderr, "*************************************************************************************\n"); - abort(); - } - - for (uint32_t bitmask = 0; bitmask < (1 << numCnfVariables()); bitmask++) - { - // printf("%07o:", int(bitmask)); - // for (int i = 2; i < numLiterals(); i++) - // if (bound(i+1)) - // printf(" %s=%d", to_string(i+1).c_str(), test_bit(bitmask, bound(i+1))); - // printf(" |"); - // for (int idx = 1; idx <= numCnfVariables(); idx++) - // printf(" %3d", test_bit(bitmask, idx) ? idx : -idx); - // printf("\n"); - - for (auto idx : extraClauses) - if (!test_bit(bitmask, idx)) - goto next; - - for (auto &clause : cnfClauses) { - for (auto idx : clause) - if (test_bit(bitmask, idx)) - goto next_clause; - // printf("failed clause:"); - // for (auto idx2 : clause) - // printf(" %3d", idx2); - // printf("\n"); - goto next; - next_clause:; - // printf("passed clause:"); - // for (auto idx2 : clause) - // printf(" %3d", idx2); - // printf("\n"); - } - - modelValues.resize(modelIdx.size()); - for (int i = 0; i < int(modelIdx.size()); i++) - modelValues[i] = test_bit(bitmask, modelIdx[i]); - - // validate result using eval() - - values[0] = TRUE, values[1] = FALSE; - for (int i = 2; i < numLiterals(); i++) { - int idx = bound(i+1); - values[i] = idx != 0 ? (test_bit(bitmask, idx) ? TRUE : FALSE) : 0; - } - - for (auto id : cnfAssumptions) { - int result = eval(id, values); - if (result != TRUE) { - printInternalState(stderr); - fprintf(stderr, "Variables:"); - for (int i = 0; i < numLiterals(); i++) - fprintf(stderr, " %s=%s", lookup_literal(i+1).c_str(), values[i] == TRUE ? "TRUE" : values[i] == FALSE ? "FALSE" : "UNDEF"); - fprintf(stderr, "\nValidation of solver results failed: got `%d' (%s) for assumption '%d': %s\n", - result, result == FALSE ? "FALSE" : "UNDEF", id, to_string(id).c_str()); - abort(); - } - // printf("OK: %d -> %d\n", id, result); - } - - for (auto id : assumptions) { - int result = eval(id, values); - if (result != TRUE) { - printInternalState(stderr); - fprintf(stderr, "Variables:"); - for (int i = 0; i < numLiterals(); i++) - fprintf(stderr, " %s=%s", lookup_literal(i+1).c_str(), values[i] == TRUE ? "TRUE" : values[i] == FALSE ? "FALSE" : "UNDEF"); - fprintf(stderr, "\nValidation of solver results failed: got `%d' (%s) for assumption '%d': %s\n", - result, result == FALSE ? "FALSE" : "UNDEF", id, to_string(id).c_str()); - abort(); - } - // printf("OK: %d -> %d\n", id, result); - } - - return true; - next:; - } + assert(!non_incremental_solve_used_up); + if (mode_non_incremental()) + non_incremental_solve_used_up = true; +} - return false; +bool ezSAT::solver(const std::vector<int>&, std::vector<bool>&, const std::vector<int>&) +{ + preSolverCallback(); + fprintf(stderr, "************************************************************************\n"); + fprintf(stderr, "ERROR: You are trying to use the solve() method of the ezSAT base class!\n"); + fprintf(stderr, "Use a dervied class like ezMiniSAT instead.\n"); + fprintf(stderr, "************************************************************************\n"); + abort(); } std::vector<int> ezSAT::vec_const(const std::vector<bool> &bits) @@ -994,6 +977,96 @@ std::vector<int> ezSAT::vec_srl(const std::vector<int> &vec1, int shift) return vec; } +std::vector<int> ezSAT::vec_shift(const std::vector<int> &vec1, int shift, int extend_left, int extend_right) +{ + std::vector<int> vec; + for (int i = 0; i < int(vec1.size()); i++) { + int j = i+shift; + if (j < 0) + vec.push_back(extend_right); + else if (j >= int(vec1.size())) + vec.push_back(extend_left); + else + vec.push_back(vec1[j]); + } + return vec; +} + +static int my_clog2(int x) +{ + int result = 0; + for (x--; x > 0; result++) + x >>= 1; + return result; +} + +std::vector<int> ezSAT::vec_shift_right(const std::vector<int> &vec1, const std::vector<int> &vec2, bool vec2_signed, int extend_left, int extend_right) +{ + int vec2_bits = std::min(my_clog2(vec1.size()) + (vec2_signed ? 1 : 0), int(vec2.size())); + + std::vector<int> overflow_bits(vec2.begin() + vec2_bits, vec2.end()); + int overflow_left = FALSE, overflow_right = FALSE; + + if (vec2_signed) { + int overflow = FALSE; + for (auto bit : overflow_bits) + overflow = OR(overflow, XOR(bit, vec2[vec2_bits-1])); + overflow_left = AND(overflow, NOT(vec2.back())); + overflow_right = AND(overflow, vec2.back()); + } else + overflow_left = vec_reduce_or(overflow_bits); + + std::vector<int> buffer = vec1; + + if (vec2_signed) + while (buffer.size() < vec1.size() + (1 << vec2_bits)) + buffer.push_back(extend_left); + + std::vector<int> overflow_pattern_left(buffer.size(), extend_left); + std::vector<int> overflow_pattern_right(buffer.size(), extend_right); + + buffer = vec_ite(overflow_left, overflow_pattern_left, buffer); + + if (vec2_signed) + buffer = vec_ite(overflow_right, overflow_pattern_left, buffer); + + for (int i = vec2_bits-1; i >= 0; i--) { + std::vector<int> shifted_buffer; + if (vec2_signed && i == vec2_bits-1) + shifted_buffer = vec_shift(buffer, -(1 << i), extend_left, extend_right); + else + shifted_buffer = vec_shift(buffer, 1 << i, extend_left, extend_right); + buffer = vec_ite(vec2[i], shifted_buffer, buffer); + } + + buffer.resize(vec1.size()); + return buffer; +} + +std::vector<int> ezSAT::vec_shift_left(const std::vector<int> &vec1, const std::vector<int> &vec2, bool vec2_signed, int extend_left, int extend_right) +{ + // vec2_signed is not implemented in vec_shift_left() yet + assert(vec2_signed == false); + + int vec2_bits = std::min(my_clog2(vec1.size()), int(vec2.size())); + + std::vector<int> overflow_bits(vec2.begin() + vec2_bits, vec2.end()); + int overflow = vec_reduce_or(overflow_bits); + + std::vector<int> buffer = vec1; + std::vector<int> overflow_pattern_right(buffer.size(), extend_right); + buffer = vec_ite(overflow, overflow_pattern_right, buffer); + + for (int i = 0; i < vec2_bits; i++) { + std::vector<int> shifted_buffer; + shifted_buffer = vec_shift(buffer, -(1 << i), extend_left, extend_right); + buffer = vec_ite(vec2[i], shifted_buffer, buffer); + } + + buffer.resize(vec1.size()); + return buffer; +} + void ezSAT::vec_append(std::vector<int> &vec, const std::vector<int> &vec1) const { for (auto bit : vec1) @@ -1124,17 +1197,32 @@ void ezSAT::printDIMACS(FILE *f, bool verbose) const if (cnfExpressionVariables[i] != 0) fprintf(f, "c %*d: %s\n", digits, cnfExpressionVariables[i], to_string(-i-1).c_str()); + if (mode_keep_cnf()) { + fprintf(f, "c\n"); + fprintf(f, "c %d clauses from backup, %d from current buffer\n", + int(cnfClausesBackup.size()), int(cnfClauses.size())); + } + fprintf(f, "c\n"); } - fprintf(f, "p cnf %d %d\n", cnfVariableCount, int(cnfClauses.size())); + std::vector<std::vector<int>> all_clauses; + getFullCnf(all_clauses); + assert(cnfClausesCount == int(all_clauses.size())); + + fprintf(f, "p cnf %d %d\n", cnfVariableCount, cnfClausesCount); int maxClauseLen = 0; - for (auto &clause : cnfClauses) + for (auto &clause : all_clauses) maxClauseLen = std::max(int(clause.size()), maxClauseLen); - for (auto &clause : cnfClauses) { + if (!verbose) + maxClauseLen = std::min(maxClauseLen, 3); + for (auto &clause : all_clauses) { for (auto idx : clause) fprintf(f, " %*d", digits, idx); - fprintf(f, " %*d\n", (digits + 1)*int(maxClauseLen - clause.size()) + digits, 0); + if (maxClauseLen >= int(clause.size())) + fprintf(f, " %*d\n", (digits + 1)*int(maxClauseLen - clause.size()) + digits, 0); + else + fprintf(f, " %*d\n", digits, 0); } } |