diff --git a/litestar/connection/request.py b/litestar/connection/request.py index 23c60f0b3c..154af41a29 100644 --- a/litestar/connection/request.py +++ b/litestar/connection/request.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, cast from litestar._multipart import parse_content_header, parse_multipart_form from litestar._parsers import parse_url_encoded_form_data @@ -222,17 +222,7 @@ async def form(self) -> FormMultiDict: self._connection_state.form = form_data - # form_data is a dict[str, list[str] | str | UploadFile]. Convert it to a - # list[tuple[str, str | UploadFile]] before passing it to FormMultiDict so - # multi-keys can be accessed properly - items = [] - for k, v in form_data.items(): - if isinstance(v, list): - for sv in v: - items.append((k, sv)) - else: - items.append((k, v)) - self._form = FormMultiDict(items) + self._form = FormMultiDict.from_form_data(cast("dict[str, Any]", form_data)) return self._form diff --git a/litestar/datastructures/multi_dicts.py b/litestar/datastructures/multi_dicts.py index 7702e1a8d5..733be6be0b 100644 --- a/litestar/datastructures/multi_dicts.py +++ b/litestar/datastructures/multi_dicts.py @@ -95,6 +95,27 @@ def copy(self) -> Self: # type: ignore[override] class FormMultiDict(ImmutableMultiDict[Any]): """MultiDict for form data.""" + @classmethod + def from_form_data(cls, form_data: dict[str, list[str] | str | UploadFile]) -> FormMultiDict: + """Create a FormMultiDict from form data. + + Args: + form_data: Form data to create the FormMultiDict from. + + Returns: + A FormMultiDict instance + """ + # Convert form_data to a list[tuple[str, str | UploadFile]] before passing it + # to FormMultiDict so multi-keys can be accessed properly + items = [] + for k, v in form_data.items(): + if not isinstance(v, list): + items.append((k, v)) + else: + for sv in v: + items.append((k, sv)) + return cls(items) + async def close(self) -> None: """Close all files in the multi-dict. diff --git a/litestar/datastructures/upload_file.py b/litestar/datastructures/upload_file.py index 93d76476a2..78f27c92bc 100644 --- a/litestar/datastructures/upload_file.py +++ b/litestar/datastructures/upload_file.py @@ -94,6 +94,8 @@ async def close(self) -> None: Returns: None. """ + if self.file.closed: + return None if self.rolled_to_disk: return await sync_to_thread(self.file.close) return self.file.close() diff --git a/litestar/routes/http.py b/litestar/routes/http.py index a8e47f1b99..2547fdce9b 100644 --- a/litestar/routes/http.py +++ b/litestar/routes/http.py @@ -5,7 +5,7 @@ from msgspec.msgpack import decode as _decode_msgpack_plain -from litestar.datastructures.upload_file import UploadFile +from litestar.datastructures.multi_dicts import FormMultiDict from litestar.enums import HttpMethod, MediaType, ScopeType from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException from litestar.handlers.http_handlers import HTTPRouteHandler @@ -86,8 +86,10 @@ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None: if after_response_handler := route_handler.resolve_after_response(): await after_response_handler(request) + if request._form is not Empty: + await request._form.close() if form_data := scope.get("_form", {}): - await self._cleanup_temporary_files(form_data=cast("dict[str, Any]", form_data)) + await FormMultiDict.from_form_data(cast("dict[str, Any]", form_data)).close() def create_handler_map(self) -> None: """Parse the ``router_handlers`` of this route and return a mapping of @@ -258,9 +260,3 @@ def options_handler(scope: Scope) -> Response: include_in_schema=False, sync_to_thread=False, )(options_handler) - - @staticmethod - async def _cleanup_temporary_files(form_data: dict[str, Any]) -> None: - for v in form_data.values(): - if isinstance(v, UploadFile) and not v.file.closed: - await v.close() diff --git a/pyproject.toml b/pyproject.toml index e0cc63efa2..f5390bf0a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -221,6 +221,9 @@ fail_under = 50 addopts = "--strict-markers --strict-config --dist=loadgroup -m 'not server_integration'" asyncio_mode = "auto" filterwarnings = [ + "error", + # https://github.com/pytest-dev/pytest-asyncio/issues/724 + "default:.*socket.socket:pytest.PytestUnraisableExceptionWarning", "ignore::trio.TrioDeprecationWarning:anyio._backends._trio*:", "ignore::DeprecationWarning:pkg_resources.*", "ignore::DeprecationWarning:google.rpc", diff --git a/tests/e2e/test_routing/conftest.py b/tests/e2e/test_routing/conftest.py index eaa178e1ad..1295f57647 100644 --- a/tests/e2e/test_routing/conftest.py +++ b/tests/e2e/test_routing/conftest.py @@ -1,4 +1,3 @@ -import subprocess import time from pathlib import Path from typing import Callable, List @@ -16,16 +15,13 @@ def runner(app: str, server_command: List[str]) -> None: tmp_path.joinpath("app.py").write_text(app) monkeypatch.chdir(tmp_path) - proc = psutil.Popen( - server_command, - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - ) + proc = psutil.Popen(server_command) def kill() -> None: for child in proc.children(recursive=True): child.kill() proc.kill() + proc.wait() request.addfinalizer(kill) diff --git a/tests/unit/test_datastructures/test_multi_dicts.py b/tests/unit/test_datastructures/test_multi_dicts.py index 78fec65f69..7ec7f386ca 100644 --- a/tests/unit/test_datastructures/test_multi_dicts.py +++ b/tests/unit/test_datastructures/test_multi_dicts.py @@ -1,7 +1,8 @@ from __future__ import annotations +from unittest.mock import patch + import pytest -from pytest_mock import MockerFixture from litestar.datastructures import UploadFile from litestar.datastructures.multi_dicts import FormMultiDict, ImmutableMultiDict, MultiDict @@ -34,20 +35,19 @@ def test_immutable_multi_dict_as_mutable() -> None: assert multi.mutable_copy().dict() == MultiDict(data).dict() -async def test_form_multi_dict_close(mocker: MockerFixture) -> None: - close = mocker.patch("litestar.datastructures.multi_dicts.UploadFile.close") - +async def test_form_multi_dict_close() -> None: multi = FormMultiDict( [ ("foo", UploadFile(filename="foo", content_type="text/plain")), ("bar", UploadFile(filename="foo", content_type="text/plain")), ] ) - + with patch("litestar.datastructures.multi_dicts.UploadFile.close") as mock_close: + await multi.close() + assert mock_close.call_count == 2 + # calls the real UploadFile.close method to clean up await multi.close() - assert close.call_count == 2 - @pytest.mark.parametrize("type_", [MultiDict, ImmutableMultiDict]) def test_copy(type_: type[MultiDict | ImmutableMultiDict]) -> None: