aboutsummaryrefslogtreecommitdiffstats
path: root/libs/ezsat/ezsat.cc
diff options
context:
space:
mode:
Diffstat (limited to 'libs/ezsat/ezsat.cc')
-rw-r--r--libs/ezsat/ezsat.cc322
1 files changed, 205 insertions, 117 deletions
diff --git a/libs/ezsat/ezsat.cc b/libs/ezsat/ezsat.cc
index dccc00555..13ed112ed 100644
--- a/libs/ezsat/ezsat.cc
+++ b/libs/ezsat/ezsat.cc
@@ -19,21 +19,21 @@
#include "ezsat.h"
+#include <cmath>
#include <algorithm>
+#include <cassert>
#include <stdlib.h>
-#include <assert.h>
const int ezSAT::TRUE = 1;
const int ezSAT::FALSE = 2;
ezSAT::ezSAT()
{
- literal("TRUE");
- literal("FALSE");
+ flag_keep_cnf = false;
+ flag_non_incremental = false;
- assert(literal("TRUE") == TRUE);
- assert(literal("FALSE") == FALSE);
+ non_incremental_solve_used_up = false;
cnfConsumed = false;
cnfVariableCount = 0;
@@ -41,6 +41,12 @@ ezSAT::ezSAT()
solverTimeout = 0;
solverTimoutStatus = false;
+
+ literal("TRUE");
+ literal("FALSE");
+
+ assert(literal("TRUE") == TRUE);
+ assert(literal("FALSE") == FALSE);
}
ezSAT::~ezSAT()
@@ -67,6 +73,20 @@ int ezSAT::literal(const std::string &name)
return literalsCache.at(name);
}
+int ezSAT::frozen_literal()
+{
+ int id = literal();
+ freeze(id);
+ return id;
+}
+
+int ezSAT::frozen_literal(const std::string &name)
+{
+ int id = literal(name);
+ freeze(id);
+ return id;
+}
+
int ezSAT::expression(OpId op, int a, int b, int c, int d, int e, int f)
{
std::vector<int> args(6);
@@ -342,13 +362,19 @@ void ezSAT::clear()
cnfLiteralVariables.clear();
cnfExpressionVariables.clear();
cnfClauses.clear();
- cnfAssumptions.clear();
}
-void ezSAT::assume(int id)
+void ezSAT::freeze(int)
+{
+}
+
+bool ezSAT::eliminated(int)
{
- cnfAssumptions.insert(id);
+ return false;
+}
+void ezSAT::assume(int id)
+{
if (id < 0)
{
assert(0 < -id && -id <= int(expressions.size()));
@@ -462,11 +488,32 @@ int ezSAT::bound(int id) const
return 0;
}
-int ezSAT::bind(int id)
+std::string ezSAT::cnfLiteralInfo(int idx) const
+{
+ for (size_t i = 0; i < cnfLiteralVariables.size(); i++) {
+ if (cnfLiteralVariables[i] == idx)
+ return to_string(i+1);
+ if (cnfLiteralVariables[i] == -idx)
+ return "NOT " + to_string(i+1);
+ }
+ for (size_t i = 0; i < cnfExpressionVariables.size(); i++) {
+ if (cnfExpressionVariables[i] == idx)
+ return to_string(-i-1);
+ if (cnfExpressionVariables[i] == -idx)
+ return "NOT " + to_string(-i-1);
+ }
+ return "<unnamed>";
+}
+
+int ezSAT::bind(int id, bool auto_freeze)
{
if (id >= 0) {
assert(0 < id && id <= int(literals.size()));
cnfLiteralVariables.resize(literals.size());
+ if (eliminated(cnfLiteralVariables[id-1])) {
+ fprintf(stderr, "ezSAT: Missing freeze on literal `%s'.\n", to_string(id).c_str());
+ abort();
+ }
if (cnfLiteralVariables[id-1] == 0) {
cnfLiteralVariables[id-1] = ++cnfVariableCount;
if (id == TRUE)
@@ -480,6 +527,17 @@ int ezSAT::bind(int id)
assert(0 < -id && -id <= int(expressions.size()));
cnfExpressionVariables.resize(expressions.size());
+ if (eliminated(cnfExpressionVariables[-id-1]))
+ {
+ cnfExpressionVariables[-id-1] = 0;
+
+ // this will recursively call bind(id). within the recursion
+ // the cnf is pre-set to 0. an idx is allocated there, then it
+ // is frozen, then it returns here with the new idx already set.
+ if (auto_freeze)
+ freeze(id);
+ }
+
if (cnfExpressionVariables[-id-1] == 0)
{
OpId op;
@@ -497,7 +555,7 @@ int ezSAT::bind(int id)
newArgs.push_back(OR(AND(args[i], NOT(args[i+1])), AND(NOT(args[i]), args[i+1])));
args.swap(newArgs);
}
- idx = bind(args.at(0));
+ idx = bind(args.at(0), false);
goto assign_idx;
}
@@ -505,17 +563,17 @@ int ezSAT::bind(int id)
std::vector<int> invArgs;
for (auto arg : args)
invArgs.push_back(NOT(arg));
- idx = bind(OR(expression(OpAnd, args), expression(OpAnd, invArgs)));
+ idx = bind(OR(expression(OpAnd, args), expression(OpAnd, invArgs)), false);
goto assign_idx;
}
if (op == OpITE) {
- idx = bind(OR(AND(args[0], args[1]), AND(NOT(args[0]), args[2])));
+ idx = bind(OR(AND(args[0], args[1]), AND(NOT(args[0]), args[2])), false);
goto assign_idx;
}
for (int i = 0; i < int(args.size()); i++)
- args[i] = bind(args[i]);
+ args[i] = bind(args[i], false);
switch (op)
{
@@ -535,120 +593,45 @@ int ezSAT::bind(int id)
void ezSAT::consumeCnf()
{
- cnfConsumed = true;
+ if (mode_keep_cnf())
+ cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end());
+ else
+ cnfConsumed = true;
cnfClauses.clear();
}
void ezSAT::consumeCnf(std::vector<std::vector<int>> &cnf)
{
- cnfConsumed = true;
+ if (mode_keep_cnf())
+ cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end());
+ else
+ cnfConsumed = true;
cnf.swap(cnfClauses);
cnfClauses.clear();
}
-static bool test_bit(uint32_t bitmask, int idx)
+void ezSAT::getFullCnf(std::vector<std::vector<int>> &full_cnf) const
{
- if (idx > 0)
- return (bitmask & (1 << (+idx-1))) != 0;
- else
- return (bitmask & (1 << (-idx-1))) == 0;
+ assert(full_cnf.empty());
+ full_cnf.insert(full_cnf.end(), cnfClausesBackup.begin(), cnfClausesBackup.end());
+ full_cnf.insert(full_cnf.end(), cnfClauses.begin(), cnfClauses.end());
}
-bool ezSAT::solver(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, const std::vector<int> &assumptions)
+void ezSAT::preSolverCallback()
{
- std::vector<int> extraClauses, modelIdx;
- std::vector<int> values(numLiterals());
-
- for (auto id : assumptions)
- extraClauses.push_back(bind(id));
- for (auto id : modelExpressions)
- modelIdx.push_back(bind(id));
-
- if (cnfVariableCount > 20) {
- fprintf(stderr, "*************************************************************************************\n");
- fprintf(stderr, "ERROR: You are trying to use the builtin solver of ezSAT with more than 20 variables!\n");
- fprintf(stderr, "The builtin solver is a dumb brute force solver and only ment for testing and demo\n");
- fprintf(stderr, "purposes. Use a real SAT solve like MiniSAT (e.g. using the ezMiniSAT class) instead.\n");
- fprintf(stderr, "*************************************************************************************\n");
- abort();
- }
-
- for (uint32_t bitmask = 0; bitmask < (1 << numCnfVariables()); bitmask++)
- {
- // printf("%07o:", int(bitmask));
- // for (int i = 2; i < numLiterals(); i++)
- // if (bound(i+1))
- // printf(" %s=%d", to_string(i+1).c_str(), test_bit(bitmask, bound(i+1)));
- // printf(" |");
- // for (int idx = 1; idx <= numCnfVariables(); idx++)
- // printf(" %3d", test_bit(bitmask, idx) ? idx : -idx);
- // printf("\n");
-
- for (auto idx : extraClauses)
- if (!test_bit(bitmask, idx))
- goto next;
-
- for (auto &clause : cnfClauses) {
- for (auto idx : clause)
- if (test_bit(bitmask, idx))
- goto next_clause;
- // printf("failed clause:");
- // for (auto idx2 : clause)
- // printf(" %3d", idx2);
- // printf("\n");
- goto next;
- next_clause:;
- // printf("passed clause:");
- // for (auto idx2 : clause)
- // printf(" %3d", idx2);
- // printf("\n");
- }
-
- modelValues.resize(modelIdx.size());
- for (int i = 0; i < int(modelIdx.size()); i++)
- modelValues[i] = test_bit(bitmask, modelIdx[i]);
-
- // validate result using eval()
-
- values[0] = TRUE, values[1] = FALSE;
- for (int i = 2; i < numLiterals(); i++) {
- int idx = bound(i+1);
- values[i] = idx != 0 ? (test_bit(bitmask, idx) ? TRUE : FALSE) : 0;
- }
-
- for (auto id : cnfAssumptions) {
- int result = eval(id, values);
- if (result != TRUE) {
- printInternalState(stderr);
- fprintf(stderr, "Variables:");
- for (int i = 0; i < numLiterals(); i++)
- fprintf(stderr, " %s=%s", lookup_literal(i+1).c_str(), values[i] == TRUE ? "TRUE" : values[i] == FALSE ? "FALSE" : "UNDEF");
- fprintf(stderr, "\nValidation of solver results failed: got `%d' (%s) for assumption '%d': %s\n",
- result, result == FALSE ? "FALSE" : "UNDEF", id, to_string(id).c_str());
- abort();
- }
- // printf("OK: %d -> %d\n", id, result);
- }
-
- for (auto id : assumptions) {
- int result = eval(id, values);
- if (result != TRUE) {
- printInternalState(stderr);
- fprintf(stderr, "Variables:");
- for (int i = 0; i < numLiterals(); i++)
- fprintf(stderr, " %s=%s", lookup_literal(i+1).c_str(), values[i] == TRUE ? "TRUE" : values[i] == FALSE ? "FALSE" : "UNDEF");
- fprintf(stderr, "\nValidation of solver results failed: got `%d' (%s) for assumption '%d': %s\n",
- result, result == FALSE ? "FALSE" : "UNDEF", id, to_string(id).c_str());
- abort();
- }
- // printf("OK: %d -> %d\n", id, result);
- }
-
- return true;
- next:;
- }
+ assert(!non_incremental_solve_used_up);
+ if (mode_non_incremental())
+ non_incremental_solve_used_up = true;
+}
- return false;
+bool ezSAT::solver(const std::vector<int>&, std::vector<bool>&, const std::vector<int>&)
+{
+ preSolverCallback();
+ fprintf(stderr, "************************************************************************\n");
+ fprintf(stderr, "ERROR: You are trying to use the solve() method of the ezSAT base class!\n");
+ fprintf(stderr, "Use a dervied class like ezMiniSAT instead.\n");
+ fprintf(stderr, "************************************************************************\n");
+ abort();
}
std::vector<int> ezSAT::vec_const(const std::vector<bool> &bits)
@@ -994,6 +977,96 @@ std::vector<int> ezSAT::vec_srl(const std::vector<int> &vec1, int shift)
return vec;
}
+std::vector<int> ezSAT::vec_shift(const std::vector<int> &vec1, int shift, int extend_left, int extend_right)
+{
+ std::vector<int> vec;
+ for (int i = 0; i < int(vec1.size()); i++) {
+ int j = i+shift;
+ if (j < 0)
+ vec.push_back(extend_right);
+ else if (j >= int(vec1.size()))
+ vec.push_back(extend_left);
+ else
+ vec.push_back(vec1[j]);
+ }
+ return vec;
+}
+
+static int my_clog2(int x)
+{
+ int result = 0;
+ for (x--; x > 0; result++)
+ x >>= 1;
+ return result;
+}
+
+std::vector<int> ezSAT::vec_shift_right(const std::vector<int> &vec1, const std::vector<int> &vec2, bool vec2_signed, int extend_left, int extend_right)
+{
+ int vec2_bits = std::min(my_clog2(vec1.size()) + (vec2_signed ? 1 : 0), int(vec2.size()));
+
+ std::vector<int> overflow_bits(vec2.begin() + vec2_bits, vec2.end());
+ int overflow_left = FALSE, overflow_right = FALSE;
+
+ if (vec2_signed) {
+ int overflow = FALSE;
+ for (auto bit : overflow_bits)
+ overflow = OR(overflow, XOR(bit, vec2[vec2_bits-1]));
+ overflow_left = AND(overflow, NOT(vec2.back()));
+ overflow_right = AND(overflow, vec2.back());
+ } else
+ overflow_left = vec_reduce_or(overflow_bits);
+
+ std::vector<int> buffer = vec1;
+
+ if (vec2_signed)
+ while (buffer.size() < vec1.size() + (1 << vec2_bits))
+ buffer.push_back(extend_left);
+
+ std::vector<int> overflow_pattern_left(buffer.size(), extend_left);
+ std::vector<int> overflow_pattern_right(buffer.size(), extend_right);
+
+ buffer = vec_ite(overflow_left, overflow_pattern_left, buffer);
+
+ if (vec2_signed)
+ buffer = vec_ite(overflow_right, overflow_pattern_left, buffer);
+
+ for (int i = vec2_bits-1; i >= 0; i--) {
+ std::vector<int> shifted_buffer;
+ if (vec2_signed && i == vec2_bits-1)
+ shifted_buffer = vec_shift(buffer, -(1 << i), extend_left, extend_right);
+ else
+ shifted_buffer = vec_shift(buffer, 1 << i, extend_left, extend_right);
+ buffer = vec_ite(vec2[i], shifted_buffer, buffer);
+ }
+
+ buffer.resize(vec1.size());
+ return buffer;
+}
+
+std::vector<int> ezSAT::vec_shift_left(const std::vector<int> &vec1, const std::vector<int> &vec2, bool vec2_signed, int extend_left, int extend_right)
+{
+ // vec2_signed is not implemented in vec_shift_left() yet
+ assert(vec2_signed == false);
+
+ int vec2_bits = std::min(my_clog2(vec1.size()), int(vec2.size()));
+
+ std::vector<int> overflow_bits(vec2.begin() + vec2_bits, vec2.end());
+ int overflow = vec_reduce_or(overflow_bits);
+
+ std::vector<int> buffer = vec1;
+ std::vector<int> overflow_pattern_right(buffer.size(), extend_right);
+ buffer = vec_ite(overflow, overflow_pattern_right, buffer);
+
+ for (int i = 0; i < vec2_bits; i++) {
+ std::vector<int> shifted_buffer;
+ shifted_buffer = vec_shift(buffer, -(1 << i), extend_left, extend_right);
+ buffer = vec_ite(vec2[i], shifted_buffer, buffer);
+ }
+
+ buffer.resize(vec1.size());
+ return buffer;
+}
+
void ezSAT::vec_append(std::vector<int> &vec, const std::vector<int> &vec1) const
{
for (auto bit : vec1)
@@ -1124,17 +1197,32 @@ void ezSAT::printDIMACS(FILE *f, bool verbose) const
if (cnfExpressionVariables[i] != 0)
fprintf(f, "c %*d: %s\n", digits, cnfExpressionVariables[i], to_string(-i-1).c_str());
+ if (mode_keep_cnf()) {
+ fprintf(f, "c\n");
+ fprintf(f, "c %d clauses from backup, %d from current buffer\n",
+ int(cnfClausesBackup.size()), int(cnfClauses.size()));
+ }
+
fprintf(f, "c\n");
}
- fprintf(f, "p cnf %d %d\n", cnfVariableCount, int(cnfClauses.size()));
+ std::vector<std::vector<int>> all_clauses;
+ getFullCnf(all_clauses);
+ assert(cnfClausesCount == int(all_clauses.size()));
+
+ fprintf(f, "p cnf %d %d\n", cnfVariableCount, cnfClausesCount);
int maxClauseLen = 0;
- for (auto &clause : cnfClauses)
+ for (auto &clause : all_clauses)
maxClauseLen = std::max(int(clause.size()), maxClauseLen);
- for (auto &clause : cnfClauses) {
+ if (!verbose)
+ maxClauseLen = std::min(maxClauseLen, 3);
+ for (auto &clause : all_clauses) {
for (auto idx : clause)
fprintf(f, " %*d", digits, idx);
- fprintf(f, " %*d\n", (digits + 1)*int(maxClauseLen - clause.size()) + digits, 0);
+ if (maxClauseLen >= int(clause.size()))
+ fprintf(f, " %*d\n", (digits + 1)*int(maxClauseLen - clause.size()) + digits, 0);
+ else
+ fprintf(f, " %*d\n", digits, 0);
}
}