From 822d62c348e72387036f601013ef3ab855c5a04e Mon Sep 17 00:00:00 2001 From: Julien Castiaux Date: Thu, 9 Jan 2025 18:44:48 +0100 Subject: [PATCH] Add headers.get and headers.getlist --- h11/_headers.py | 43 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/h11/_headers.py b/h11/_headers.py index b97d020..d7e69ff 100644 --- a/h11/_headers.py +++ b/h11/_headers.py @@ -1,5 +1,5 @@ import re -from typing import AnyStr, cast, List, overload, Sequence, Tuple, TYPE_CHECKING, Union +from typing import List, overload, Sequence, Tuple, TYPE_CHECKING, TypeVar, Union from ._abnf import field_name, field_value from ._util import bytesify, LocalProtocolError, validate @@ -13,6 +13,8 @@ from typing_extensions import Literal # type: ignore +T = TypeVar("T") + # Facts # ----- # @@ -84,19 +86,29 @@ class Headers(Sequence[Tuple[bytes, bytes]]): r = Request( method="GET", target="/", - headers=[("Host", "example.org"), ("Connection", "keep-alive")], + headers=[ + ("Host", "example.org"), + ("Connection", "keep-alive"), + ("Cookie", "session=1234"), + ("Cookie", "lang=en_US"), + ], http_version="1.1", ) assert r.headers == [ (b"host", b"example.org"), - (b"connection", b"keep-alive") + (b"connection", b"keep-alive"), + (b"cookie", b"session=1234"), + (b"cookie", b"lang=en_US"), ] assert r.headers.raw_items() == [ (b"Host", b"example.org"), - (b"Connection", b"keep-alive") + (b"Connection", b"keep-alive"), + (b"Cookie", b"session=1234"), + (b"Cookie", b"lang=en_US"), ] + assert r.headers.get(b"host") == b"example.org" + assert r.headers.getlist(b"cookie") == [b"session=1234", b"lang=en_US"] """ - __slots__ = "_full_items" def __init__(self, full_items: List[Tuple[bytes, bytes, bytes]]) -> None: @@ -118,6 +130,27 @@ def __getitem__(self, idx: int) -> Tuple[bytes, bytes]: # type: ignore[override _, name, value = self._full_items[idx] return (name, value) + def get(self, name: bytes, default: T = None) -> bytes | T: + """Find the first header with lowercased-name :param:`name`, it returns + its value when found, and :param:`default` otherwise. + + Args: + name (bytes): The lowercased header name to find. + + default: The value to return when the header is not found. + """ + return next((value for name_, value in self if name_ == name), default) + + def getlist(self, name: bytes) -> list[bytes]: + """Find the all the headers with lowercased-name :param:`name`, + it returns their values in a list. It returns an empty list when + no header matched. + + Args: + name (bytes): The lowercased header name to find. + """ + return [value for name_, value in self if name_ == name] + def raw_items(self) -> List[Tuple[bytes, bytes]]: return [(raw_name, value) for raw_name, _, value in self._full_items]