diff options
-rw-r--r-- | src/cryptography/hazmat/backends/openssl/rsa.py | 4 | ||||
-rw-r--r-- | src/cryptography/hazmat/primitives/asymmetric/padding.py | 3 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_rsa.py | 11 | ||||
-rw-r--r-- | tests/utils.py | 7 |
4 files changed, 10 insertions, 15 deletions
diff --git a/src/cryptography/hazmat/backends/openssl/rsa.py b/src/cryptography/hazmat/backends/openssl/rsa.py index 3e4c2fd2..458071ca 100644 --- a/src/cryptography/hazmat/backends/openssl/rsa.py +++ b/src/cryptography/hazmat/backends/openssl/rsa.py @@ -4,8 +4,6 @@ from __future__ import absolute_import, division, print_function -import math - from cryptography import utils from cryptography.exceptions import ( InvalidSignature, UnsupportedAlgorithm, _Reasons @@ -352,7 +350,7 @@ class _RSAPrivateKey(object): return _RSASignatureContext(self._backend, self, padding, algorithm) def decrypt(self, ciphertext, padding): - key_size_bytes = int(math.ceil(self.key_size / 8.0)) + key_size_bytes = (self.key_size + 7) // 8 if key_size_bytes != len(ciphertext): raise ValueError("Ciphertext length must be equal to key size.") diff --git a/src/cryptography/hazmat/primitives/asymmetric/padding.py b/src/cryptography/hazmat/primitives/asymmetric/padding.py index a37c3f90..828e03bc 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/padding.py +++ b/src/cryptography/hazmat/primitives/asymmetric/padding.py @@ -5,7 +5,6 @@ from __future__ import absolute_import, division, print_function import abc -import math import six @@ -73,7 +72,7 @@ def calculate_max_pss_salt_length(key, hash_algorithm): if not isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)): raise TypeError("key must be an RSA public or private key") # bit length - 1 per RFC 3447 - emlen = int(math.ceil((key.key_size - 1) / 8.0)) + emlen = (key.key_size + 6) // 8 salt_length = emlen - hash_algorithm.digest_size - 2 assert salt_length >= 0 return salt_length diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index 9a0aaf1a..e6482651 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -6,7 +6,6 @@ from __future__ import absolute_import, division, print_function import binascii import itertools -import math import os import pytest @@ -434,7 +433,7 @@ class TestRSASignature(object): ), hashes.SHA1() ) - assert len(signature) == math.ceil(private_key.key_size / 8.0) + assert len(signature) == (private_key.key_size + 7) // 8 # PSS signatures contain randomness so we can't do an exact # signature check. Instead we'll verify that the signature created # successfully verifies. @@ -1428,7 +1427,7 @@ class TestRSADecryption(object): ) ).private_key(backend) ciphertext = binascii.unhexlify(example["encryption"]) - assert len(ciphertext) == math.ceil(skey.key_size / 8.0) + assert len(ciphertext) == (skey.key_size + 7) // 8 message = skey.decrypt(ciphertext, padding.PKCS1v15()) assert message == binascii.unhexlify(example["message"]) @@ -1684,7 +1683,7 @@ class TestRSAEncryption(object): public_key = private_key.public_key() ct = public_key.encrypt(pt, pad) assert ct != pt - assert len(ct) == math.ceil(public_key.key_size / 8.0) + assert len(ct) == (public_key.key_size + 7) // 8 recovered_pt = private_key.decrypt(ct, pad) assert recovered_pt == pt @@ -1725,7 +1724,7 @@ class TestRSAEncryption(object): public_key = private_key.public_key() ct = public_key.encrypt(pt, pad) assert ct != pt - assert len(ct) == math.ceil(public_key.key_size / 8.0) + assert len(ct) == (public_key.key_size + 7) // 8 recovered_pt = private_key.decrypt(ct, pad) assert recovered_pt == pt @@ -1750,7 +1749,7 @@ class TestRSAEncryption(object): public_key = private_key.public_key() ct = public_key.encrypt(pt, pad) assert ct != pt - assert len(ct) == math.ceil(public_key.key_size / 8.0) + assert len(ct) == (public_key.key_size + 7) // 8 recovered_pt = private_key.decrypt(ct, pad) assert recovered_pt == pt diff --git a/tests/utils.py b/tests/utils.py index ca3245b0..7e79830b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,6 @@ from __future__ import absolute_import, division, print_function import binascii import collections import json -import math import os import re from contextlib import contextmanager @@ -744,15 +743,15 @@ def load_x963_vectors(vector_data): vector["key_data_length"] = key_data_len elif line.startswith("Z"): vector["Z"] = line.split("=")[1].strip() - assert math.ceil(shared_secret_len / 8) * 2 == len(vector["Z"]) + assert ((shared_secret_len + 7) // 8) * 2 == len(vector["Z"]) elif line.startswith("SharedInfo"): if shared_info_len != 0: vector["sharedinfo"] = line.split("=")[1].strip() silen = len(vector["sharedinfo"]) - assert math.ceil(shared_info_len / 8) * 2 == silen + assert ((shared_info_len + 7) // 8) * 2 == silen elif line.startswith("key_data"): vector["key_data"] = line.split("=")[1].strip() - assert math.ceil(key_data_len / 8) * 2 == len(vector["key_data"]) + assert ((key_data_len + 7) // 8) * 2 == len(vector["key_data"]) vectors.append(vector) vector = {} |