diff --git a/packages/atproto_firehose/models.py b/packages/atproto_firehose/models.py index 375dd43b..0cd3f2a7 100644 --- a/packages/atproto_firehose/models.py +++ b/packages/atproto_firehose/models.py @@ -15,10 +15,6 @@ class FrameType(Enum): MESSAGE = 1 ERROR = -1 - @classmethod - def has_value(cls, value: int) -> bool: - return value in cls._value2member_map_ - @dataclass class MessageFrameHeader: @@ -50,13 +46,12 @@ class ErrorFrameBody: def parse_frame_header(raw_header: dict) -> FrameHeader: try: header_op = int(raw_header.get('op', 0)) - if not FrameType.has_value(header_op): - raise FirehoseDecodingError('Invalid frame type') - - frame_type = FrameType(header_op) - if frame_type is FrameType.MESSAGE: + if header_op == FrameType.MESSAGE.value: return get_or_create(raw_header, MessageFrameHeader) - return get_or_create(raw_header, ErrorFrameHeader) + elif header_op == FrameType.ERROR.value: + return get_or_create(raw_header, ErrorFrameHeader) + else: + raise FirehoseDecodingError('Invalid frame type') except (ValueError, AtProtocolError) as e: raise FirehoseDecodingError('Invalid frame header') from e @@ -106,23 +101,16 @@ def from_bytes(data: Union[bytes, bytearray]) -> Union['MessageFrame', 'ErrorFra :obj:`atproto.firehose_models.MessageFrame` or :obj:`atproto.firehose_models.ErrorFrame` Raises: - :class:`atproto.exceptions.FirehoseError`: Invalid data frame. + :class:`atproto.exceptions.FirehoseDecodingError`: Invalid data frame. """ decoded_parts = decode_dag_multi(data) + if not decoded_parts: + raise FirehoseDecodingError('Invalid frame without CBOR data') if len(decoded_parts) > 2: raise FirehoseDecodingError('Too many CBOR data parts in the frame') - if not len(decoded_parts): - raise FirehoseDecodingError('Invalid frame without CBOR data') - - raw_header = decoded_parts[0] - - raw_body = None - if len(decoded_parts) > 1: - raw_body = decoded_parts[1] - - if raw_body is None: + if len(decoded_parts) < 2: raise FirehoseDecodingError('Frame body not found') - + raw_header, raw_body = decoded_parts header = parse_frame_header(raw_header) return parse_frame(header, raw_body)