aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cryptography/bindings/openssl/api.py18
-rw-r--r--cryptography/primitives/block/modes.py6
-rw-r--r--tests/bindings/test_openssl.py7
3 files changed, 13 insertions, 18 deletions
diff --git a/cryptography/bindings/openssl/api.py b/cryptography/bindings/openssl/api.py
index 17823786..02957d74 100644
--- a/cryptography/bindings/openssl/api.py
+++ b/cryptography/bindings/openssl/api.py
@@ -72,13 +72,12 @@ class API(object):
)
evp_cipher = self._lib.EVP_get_cipherbyname(ciphername.encode("ascii"))
assert evp_cipher != self._ffi.NULL
- # TODO: only use the key and initialization_vector as needed. Sometimes
- # this needs to be a DecryptInit, when?
- iv = self._get_iv(mode)
+ iv_nonce = mode.get_iv_or_nonce(self)
+ # TODO: Sometimes this needs to be a DecryptInit, when?
res = self._lib.EVP_EncryptInit_ex(
ctx, evp_cipher, self._ffi.NULL, cipher.key,
- iv
+ iv_nonce
)
assert res != 0
@@ -87,15 +86,8 @@ class API(object):
self._lib.EVP_CIPHER_CTX_set_padding(ctx, 0)
return ctx
- def _get_iv(self, mode):
- # TODO: refactor this to visitor pattern
- klass_name = mode.__class__.__name__
- if klass_name == 'CBC':
- return mode.initialization_vector
- elif klass_name == 'ECB':
- return self._ffi.NULL
- else:
- raise NotImplementedError
+ def get_null_for_ecb(self):
+ return self._ffi.NULL
def update_encrypt_context(self, ctx, plaintext):
buf = self._ffi.new("unsigned char[]", len(plaintext))
diff --git a/cryptography/primitives/block/modes.py b/cryptography/primitives/block/modes.py
index ac3392c5..82141437 100644
--- a/cryptography/primitives/block/modes.py
+++ b/cryptography/primitives/block/modes.py
@@ -21,6 +21,12 @@ class CBC(object):
super(CBC, self).__init__()
self.initialization_vector = initialization_vector
+ def get_iv_or_nonce(self, api):
+ return self.initialization_vector
+
class ECB(object):
name = "ECB"
+
+ def get_iv_or_nonce(self, api):
+ return api.get_null_for_ecb()
diff --git a/tests/bindings/test_openssl.py b/tests/bindings/test_openssl.py
index e4b73460..c5927b76 100644
--- a/tests/bindings/test_openssl.py
+++ b/tests/bindings/test_openssl.py
@@ -11,8 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytest
-
from cryptography.bindings.openssl import api
@@ -31,6 +29,5 @@ class TestOpenSSL(object):
"""
assert api.openssl_version_text().startswith("OpenSSL")
- def test_get_iv_invalid_mode(self):
- with pytest.raises(NotImplementedError):
- api._get_iv(None)
+ def test_get_null_for_ecb(self):
+ assert api.get_null_for_ecb() == api._ffi.NULL