Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

encryptor: add support for symmetric encryption #165

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 200 additions & 73 deletions rohmu/encryptor.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
"""
rohmu - content encryption

Copyright (c) 2016 Ohmu Ltd
Copyright (c) 2023 Aiven, Helsinki, Finland. https://aiven.io/
See LICENSE for details
"""
from .common.constants import IO_BLOCK_SIZE
from .errors import UninitializedError
from .filewrap import FileWrap, Sink, Stream
from .typing import BinaryData, FileLike, HasRead, HasSeek, HasWrite
from abc import ABC, abstractmethod
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, CipherContext, modes
from cryptography.hazmat.primitives.hashes import SHA1, SHA256
from cryptography.hazmat.primitives.hmac import HMAC
from typing import Optional, Union
from typing import Callable, Optional, Union

import io
import logging
Expand All @@ -30,16 +31,15 @@ class EncryptorError(Exception):
"""EncryptorError"""


class Encryptor:
def __init__(self, rsa_public_key_pem: Union[str, bytes]):
if not isinstance(rsa_public_key_pem, bytes):
rsa_public_key_pem = rsa_public_key_pem.encode("ascii")
public_key = serialization.load_pem_public_key(rsa_public_key_pem, backend=default_backend())
assert isinstance(public_key, RSAPublicKey)
self.rsa_public_key = public_key
class BaseEncryptor(ABC):
def __init__(self) -> None:
self._cipher: Optional[CipherContext] = None
self._authenticator: Optional[HMAC] = None

@abstractmethod
def init_cipher(self) -> bytes:
pass

@property
def cipher(self) -> CipherContext:
if self._cipher is None:
Expand All @@ -55,14 +55,8 @@ def authenticator(self) -> HMAC:
def update(self, data: bytes) -> bytes:
ret = b""
if self._cipher is None:
key = os.urandom(16)
nonce = os.urandom(16)
auth_key = os.urandom(32)
self._cipher = Cipher(algorithms.AES(key), modes.CTR(nonce), backend=default_backend()).encryptor()
self._authenticator = HMAC(auth_key, SHA256(), backend=default_backend())
pad = padding.OAEP(mgf=padding.MGF1(algorithm=SHA1()), algorithm=SHA1(), label=None)
cipherkey = self.rsa_public_key.encrypt(key + nonce + auth_key, pad)
ret = FILEMAGIC + struct.pack(">H", len(cipherkey)) + cipherkey
ret = self.init_cipher()

cur = self.cipher.update(data)
self.authenticator.update(cur)
if ret:
Expand All @@ -82,16 +76,15 @@ def finalize(self) -> bytes:
return ret


class EncryptorFile(FileWrap):
def __init__(self, next_fp: FileLike, rsa_public_key_pem: Union[str, bytes]) -> None:
class BaseEncryptorFile(FileWrap):
def __init__(self, next_fp: FileLike, encryptor: BaseEncryptor) -> None:
super().__init__(next_fp)
self.key = rsa_public_key_pem
self._encryptor: Optional[Encryptor] = Encryptor(self.key)
self._encryptor: Optional[BaseEncryptor] = encryptor
self.offset = 0
self.state = "OPEN"

@property
def encryptor(self) -> Encryptor:
def encryptor(self) -> BaseEncryptor:
if self._encryptor is None:
raise UninitializedError("encryptor was not initialized")
return self._encryptor
Expand Down Expand Up @@ -125,12 +118,10 @@ def write(self, data: BinaryData) -> int: # type: ignore[override]
return len(data_as_bytes)


class EncryptorStream(Stream):
"""Non-seekable stream of data that adds encryption on top of given source stream"""

def __init__(self, src_fp: HasRead, rsa_public_key_pem: Union[str, bytes]) -> None:
class BaseEncryptorStream(Stream):
def __init__(self, src_fp: HasRead, encryptor: BaseEncryptor) -> None:
super().__init__(src_fp)
self._encryptor = Encryptor(rsa_public_key_pem)
self._encryptor = encryptor

def _process_chunk(self, data: bytes) -> bytes:
return self._encryptor.update(data)
Expand All @@ -139,18 +130,10 @@ def _finalize(self) -> bytes:
return self._encryptor.finalize()


class Decryptor:
def __init__(self, rsa_private_key_pem: Union[str, bytes]) -> None:
if not isinstance(rsa_private_key_pem, bytes):
rsa_private_key_pem = rsa_private_key_pem.encode("ascii")
self.rsa_private_key = serialization.load_pem_private_key(
data=rsa_private_key_pem, password=None, backend=default_backend()
)
class BaseDecryptor(ABC):
def __init__(self) -> None:
self._cipher: Optional[CipherContext] = None
self._authenticator: Optional[HMAC] = None
self._cipher_key_len = None
self._header_size = None
self._footer_size = 32

@property
def authenticator(self) -> HMAC:
Expand All @@ -164,39 +147,21 @@ def cipher(self) -> CipherContext:
raise UninitializedError("cipher not initialized")
return self._cipher

@abstractmethod
def expected_header_bytes(self) -> int:
if self._header_size is not None:
return 0
return self._cipher_key_len or 8
pass

@abstractmethod
def header_size(self) -> int:
if self._header_size is None:
raise UninitializedError("header_size not initialized")
return self._header_size
pass

@abstractmethod
def footer_size(self) -> int:
return self._footer_size
pass

@abstractmethod
def process_header(self, data: bytes) -> None:
if self._cipher_key_len is None:
if data[0:6] != FILEMAGIC:
raise EncryptorError("Invalid magic bytes")
self._cipher_key_len = struct.unpack(">H", data[6:8])[0]
else:
pad = padding.OAEP(mgf=padding.MGF1(algorithm=SHA1()), algorithm=SHA1(), label=None)
try:
plainkey = self.rsa_private_key.decrypt(data, pad)
except AssertionError:
raise EncryptorError("Decrypting key data failed")
if len(plainkey) != 64:
raise EncryptorError("Integrity check failed")
key = plainkey[0:16]
nonce = plainkey[16:32]
auth_key = plainkey[32:64]
self._header_size = 8 + len(data)

self._cipher = Cipher(algorithms.AES(key), modes.CTR(nonce), backend=default_backend()).decryptor()
self._authenticator = HMAC(auth_key, SHA256(), backend=default_backend())
pass

def process_data(self, data: bytes) -> bytes:
if not data:
Expand All @@ -213,12 +178,12 @@ def finalize(self, footer: bytes) -> bytes:
return result


class DecryptorFile(FileWrap):
def __init__(self, next_fp: FileLike, rsa_private_key_pem: Union[bytes, str]):
class BaseDecryptorFile(FileWrap):
def __init__(self, next_fp: FileLike, decryptor_factory: Callable[[], BaseDecryptor]):
super().__init__(next_fp)
self._key = rsa_private_key_pem
self._decryptor_factory = decryptor_factory
self.log = logging.getLogger(self.__class__.__name__)
self._maybe_decryptor: Optional[Decryptor] = None
self._maybe_decryptor: Optional[BaseDecryptor] = None
self._maybe_crypted_size: Optional[int] = None
self._maybe_plaintext_size: Optional[int] = None
self._maybe_boundary_block: Optional[bytes] = None
Expand All @@ -230,7 +195,7 @@ def __init__(self, next_fp: FileLike, rsa_private_key_pem: Union[bytes, str]):
self._reset()

@property
def _decryptor(self) -> Decryptor:
def _decryptor(self) -> BaseDecryptor:
if self._maybe_decryptor is None:
raise UninitializedError("decryptor not initialized")
return self._maybe_decryptor
Expand All @@ -248,7 +213,7 @@ def _plaintext_size(self) -> int:
return self._maybe_plaintext_size

def _reset(self) -> None:
self._maybe_decryptor = Decryptor(self._key)
self._maybe_decryptor = self._decryptor_factory()
self._maybe_crypted_size = self._file_size(self.next_fp)
self._maybe_boundary_block = None
self._maybe_plaintext_size = None
Expand Down Expand Up @@ -403,14 +368,14 @@ def seekable(self) -> bool:
return True


class DecryptSink(Sink):
def __init__(self, next_sink: HasWrite, file_size: int, encryption_key_data: Union[str, bytes]) -> None:
class BaseDecryptSink(Sink):
def __init__(self, next_sink: HasWrite, file_size: int, decryptor: BaseDecryptor) -> None:
super().__init__(next_sink)
if file_size < 0:
raise ValueError("Invalid file_size: " + str(file_size))
self.data_bytes_received = 0
self.data_size = file_size
self.decryptor = Decryptor(encryption_key_data)
self.decryptor = decryptor
self.file_size = file_size
self.footer = b""
self.header = b""
Expand Down Expand Up @@ -458,3 +423,165 @@ def write(self, data: BinaryData) -> int:
return written
self._write_to_next_sink(data)
return written


class Encryptor(BaseEncryptor):
def __init__(self, public_key_pem: Union[str, bytes]):
if not isinstance(public_key_pem, bytes):
public_key_pem = public_key_pem.encode("ascii")
public_key = serialization.load_pem_public_key(public_key_pem, backend=default_backend())
if not isinstance(public_key, RSAPublicKey):
raise ValueError("Key must be RSA")

super().__init__()
self.rsa_public_key = public_key

def init_cipher(self) -> bytes:
cipher_key = os.urandom(16)
nonce = os.urandom(16)
hmac_key = os.urandom(32)
self._cipher = Cipher(algorithms.AES(cipher_key), modes.CTR(nonce), backend=default_backend()).encryptor()
self._authenticator = HMAC(hmac_key, SHA256(), backend=default_backend())
pad = padding.OAEP(mgf=padding.MGF1(algorithm=SHA1()), algorithm=SHA1(), label=None)
key = self.rsa_public_key.encrypt(cipher_key + nonce + hmac_key, pad)
return FILEMAGIC + struct.pack(">H", len(key)) + key


class EncryptorFile(BaseEncryptorFile):
def __init__(self, next_fp: FileLike, public_key_pem: Union[str, bytes]) -> None:
super().__init__(next_fp, Encryptor(public_key_pem))


class EncryptorStream(BaseEncryptorStream):
"""Non-seekable stream of data that adds encryption on top of given source stream"""

def __init__(self, src_fp: HasRead, public_key_pem: Union[str, bytes]) -> None:
super().__init__(src_fp, Encryptor(public_key_pem))


class Decryptor(BaseDecryptor):
def __init__(self, private_key_pem: Union[str, bytes]) -> None:
if not isinstance(private_key_pem, bytes):
private_key_pem = private_key_pem.encode("ascii")
private_key = serialization.load_pem_private_key(data=private_key_pem, password=None, backend=default_backend())
if not isinstance(private_key, RSAPrivateKey):
raise ValueError("Key must be RSA")

super().__init__()
self.rsa_private_key = private_key
self._key_size = None
self._header_size = None

def expected_header_bytes(self) -> int:
if self._cipher is not None:
return 0
return self._key_size or 8

def header_size(self) -> int:
if self._cipher is None:
raise UninitializedError("header_size not initialized")
assert self._header_size is not None
return self._header_size

def footer_size(self) -> int:
return 32

def process_header(self, data: bytes) -> None:
if self._key_size is None:
n = len(FILEMAGIC)
if data[0:n] != FILEMAGIC:
raise EncryptorError("Invalid magic bytes")
self._key_size = struct.unpack(">H", data[n : n + 2])[0]
else:
pad = padding.OAEP(mgf=padding.MGF1(algorithm=SHA1()), algorithm=SHA1(), label=None)
try:
key = self.rsa_private_key.decrypt(data, pad)
except AssertionError:
raise EncryptorError("Decrypting key data failed")
if len(key) != 64:
raise EncryptorError("Integrity check failed")

cipher_key = key[0:16]
nonce = key[16:32]
hmac_key = key[32:64]
self._cipher = Cipher(algorithms.AES(cipher_key), modes.CTR(nonce), backend=default_backend()).decryptor()
self._authenticator = HMAC(hmac_key, SHA256(), backend=default_backend())
self._header_size = len(FILEMAGIC) + 2 + self._key_size
self._key_size = None


class DecryptorFile(BaseDecryptorFile):
def __init__(self, next_fp: FileLike, private_key_pem: Union[bytes, str]):
super().__init__(next_fp, lambda: Decryptor(private_key_pem))


class DecryptSink(BaseDecryptSink):
def __init__(self, next_sink: HasWrite, file_size: int, private_key_pem: Union[bytes, str]):
super().__init__(next_sink, file_size, Decryptor(private_key_pem))


class SymmetricEncryptor(BaseEncryptor):
def __init__(self, key: bytes):
if len(key) != 32:
raise ValueError("Key must be 32 bytes")

super().__init__()
self._cipher_key = key[:16]
self._hmac_key = key[16:]

def init_cipher(self) -> bytes:
nonce = os.urandom(16)
self._cipher = Cipher(algorithms.AES(self._cipher_key), modes.CTR(nonce), backend=default_backend()).encryptor()
self._authenticator = HMAC(self._hmac_key, SHA256(), backend=default_backend())
return FILEMAGIC + nonce


class SymmetricEncryptorFile(BaseEncryptorFile):
def __init__(self, next_fp: FileLike, key: bytes) -> None:
super().__init__(next_fp, SymmetricEncryptor(key))


class SymmetricEncryptorStream(BaseEncryptorStream):
"""Non-seekable stream of data that adds encryption on top of given source stream"""

def __init__(self, src_fp: HasRead, key: bytes) -> None:
super().__init__(src_fp, SymmetricEncryptor(key))


class SymmetricDecryptor(BaseDecryptor):
def __init__(self, key: bytes) -> None:
if len(key) != 32:
raise ValueError("Key must be 32 bytes")

super().__init__()
self._cipher_key = key[:16]
self._hmac_key = key[16:]

def expected_header_bytes(self) -> int:
if self._cipher is not None:
return 0
return self.header_size()

def header_size(self) -> int:
return len(FILEMAGIC) + 16

def footer_size(self) -> int:
return 32

def process_header(self, data: bytes) -> None:
n = len(FILEMAGIC)
if data[0:n] != FILEMAGIC:
raise EncryptorError("Invalid magic bytes")
nonce = data[n : n + 16]
self._cipher = Cipher(algorithms.AES(self._cipher_key), modes.CTR(nonce), backend=default_backend()).decryptor()
self._authenticator = HMAC(self._hmac_key, SHA256(), backend=default_backend())


class SymmetricDecryptorFile(BaseDecryptorFile):
def __init__(self, next_fp: FileLike, key: bytes):
super().__init__(next_fp, lambda: SymmetricDecryptor(key))


class SymmetricDecryptSink(BaseDecryptSink):
def __init__(self, next_sink: HasWrite, file_size: int, key: bytes):
super().__init__(next_sink, file_size, SymmetricDecryptor(key))
Loading
Loading