diff options
-rw-r--r-- | cryptography/exceptions.py | 4 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/hashes.py | 9 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_hashes.py | 11 |
3 files changed, 23 insertions, 1 deletions
diff --git a/cryptography/exceptions.py b/cryptography/exceptions.py index 391bed82..c2e71493 100644 --- a/cryptography/exceptions.py +++ b/cryptography/exceptions.py @@ -14,3 +14,7 @@ class UnsupportedAlgorithm(Exception): pass + + +class AlreadyFinalized(Exception): + pass diff --git a/cryptography/hazmat/primitives/hashes.py b/cryptography/hazmat/primitives/hashes.py index 6ae622cd..f85d36a0 100644 --- a/cryptography/hazmat/primitives/hashes.py +++ b/cryptography/hazmat/primitives/hashes.py @@ -15,6 +15,7 @@ from __future__ import absolute_import, division, print_function import six +from cryptography import exceptions from cryptography.hazmat.primitives import interfaces @@ -37,17 +38,23 @@ class Hash(object): self._ctx = ctx def update(self, data): + if self._ctx is None: + raise exceptions.AlreadyFinalized() if isinstance(data, six.text_type): raise TypeError("Unicode-objects must be encoded before hashing") self._ctx.update(data) def copy(self): + if self._ctx is None: + raise exceptions.AlreadyFinalized() return Hash( self.algorithm, backend=self._backend, ctx=self._ctx.copy() ) def finalize(self): - return self._ctx.finalize() + digest = self._ctx.finalize() + self._ctx = None + return digest @interfaces.register(interfaces.HashAlgorithm) diff --git a/tests/hazmat/primitives/test_hashes.py b/tests/hazmat/primitives/test_hashes.py index 6cdb0a07..a5c440b8 100644 --- a/tests/hazmat/primitives/test_hashes.py +++ b/tests/hazmat/primitives/test_hashes.py @@ -19,6 +19,7 @@ import pytest import six +from cryptography import exceptions from cryptography.hazmat.bindings import _default_backend from cryptography.hazmat.primitives import hashes @@ -51,6 +52,16 @@ class TestHashContext(object): with pytest.raises(TypeError): hashes.Hash(hashes.SHA1) + def test_raises_after_finalize(self): + h = hashes.Hash(hashes.SHA1()) + h.finalize() + + with pytest.raises(exceptions.AlreadyFinalized): + h.update(b"foo") + + with pytest.raises(exceptions.AlreadyFinalized): + h.copy() + class TestSHA1(object): test_SHA1 = generate_base_hash_test( |