diff options
-rw-r--r-- | cryptography/fernet.py | 18 | ||||
-rw-r--r-- | tests/test_fernet.py | 33 |
2 files changed, 50 insertions, 1 deletions
diff --git a/cryptography/fernet.py b/cryptography/fernet.py index a8e0330e..9fee3eba 100644 --- a/cryptography/fernet.py +++ b/cryptography/fernet.py @@ -127,3 +127,21 @@ class Fernet(object): except ValueError: raise InvalidToken return unpadded + + +class MultiFernet(object): + def __init__(self, fernets): + if not fernets: + raise ValueError("MultiFernet requires at least one fernet") + self._fernets = fernets + + def encrypt(self, msg): + return self._fernets[0].encrypt(msg) + + def decrypt(self, msg, ttl=None): + for f in self._fernets: + try: + return f.decrypt(msg, ttl) + except InvalidToken: + pass + raise InvalidToken diff --git a/tests/test_fernet.py b/tests/test_fernet.py index 0b4e3e87..91af32ad 100644 --- a/tests/test_fernet.py +++ b/tests/test_fernet.py @@ -24,7 +24,7 @@ import pytest import six -from cryptography.fernet import Fernet, InvalidToken +from cryptography.fernet import Fernet, InvalidToken, MultiFernet from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import algorithms, modes @@ -115,3 +115,34 @@ class TestFernet(object): def test_bad_key(self, backend): with pytest.raises(ValueError): Fernet(base64.urlsafe_b64encode(b"abc"), backend=backend) + + +@pytest.mark.supported( + only_if=lambda backend: backend.cipher_supported( + algorithms.AES("\x00" * 32), modes.CBC("\x00" * 16) + ), + skip_message="Does not support AES CBC", +) +class TestMultiFernet(object): + def test_encrypt(self, backend): + single_f = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend) + f = MultiFernet([ + single_f, + Fernet(base64.urlsafe_b64encode(b"\x01" * 32), backend=backend) + ]) + assert single_f.decrypt(f.encrypt(b"abc")) == b"abc" + + def test_decrypt(self, backend): + f1 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend) + f2 = Fernet(base64.urlsafe_b64encode(b"\x00" * 32), backend=backend) + f = MultiFernet([f1, f2]) + + assert f.decrypt(f1.encrypt(b"abc")) == b"abc" + assert f.decrypt(f2.encrypt(b"abc")) == b"abc" + + with pytest.raises(InvalidToken): + f.decrypt(b"\x00" * 16) + + def test_no_fernets(self, backend): + with pytest.raises(ValueError): + MultiFernet([]) |