From 0d99308ed11b94be52620541b6a8c9402bb4274e Mon Sep 17 00:00:00 2001
From: Rohan Devasthale
Date: Wed, 7 Aug 2024 14:29:16 -0400
Subject: [PATCH] Fix: Network hiccups using Retry logic

---
 src/fromager/context.py | 24 ++++++++++++++++++++++++
 src/fromager/sdist.py   |  6 +++---
 src/fromager/sources.py | 36 +++++++++++++++++++++++++++---------
 tests/test_sdist.py     |  3 ++-
 tests/test_sources.py   |  8 +++++---
 5 files changed, 61 insertions(+), 16 deletions(-)

diff --git a/src/fromager/context.py b/src/fromager/context.py
index cc7ee7cc..5adf4c45 100644
--- a/src/fromager/context.py
+++ b/src/fromager/context.py
@@ -5,9 +5,11 @@
 import typing
 from urllib.parse import urlparse
 
+import requests
 from packaging.requirements import Requirement
 from packaging.utils import NormalizedName, canonicalize_name
 from packaging.version import Version
+from requests.adapters import HTTPAdapter, Retry
 
 from . import constraints, settings
 
@@ -71,6 +73,8 @@ def __init__(
         self._seen_requirements: set[tuple[NormalizedName, tuple[str, ...], str]] = (
             set()
         )
+        # create a requests session
+        self.requests = self._make_requests_session()
 
     @property
     def pip_wheel_server_args(self) -> list[str]:
@@ -80,6 +84,26 @@ def __init__(
             args = args + ["--trusted-host", parsed.hostname]
         return args
 
+    def _make_requests_session(
+        self, retries=5, backoff_factor=0.1, status_forcelist=None
+    ) -> requests.Session:
+        session = requests.Session()
+        if status_forcelist is None:
+            status_forcelist = [500, 502, 503, 504]
+
+        # Initialize the retries object
+        retries = Retry(
+            total=retries,
+            backoff_factor=backoff_factor,
+            status_forcelist=status_forcelist,
+        )
+
+        # Initialize the adapter object
+        adapter = HTTPAdapter(max_retries=retries)
+        session.mount("http://", adapter)
+        session.mount("https://", adapter)
+        return session
+
     def _resolved_key(
         self, req: Requirement, version: Version
     ) -> tuple[NormalizedName, tuple[str, ...], str]:
diff --git a/src/fromager/sdist.py b/src/fromager/sdist.py
index cb1c76d7..8d6a838d 100644
--- a/src/fromager/sdist.py
+++ b/src/fromager/sdist.py
@@ -311,7 +311,7 @@ def download_wheel(
     wheel_filename = output_directory / os.path.basename(urlparse(wheel_url).path)
     if not wheel_filename.exists():
         logger.info(f"{req.name}: downloading pre-built wheel {wheel_url}")
-        wheel_filename = _download_wheel_check(output_directory, wheel_url)
+        wheel_filename = _download_wheel_check(ctx, output_directory, wheel_url)
         logger.info(f"{req.name}: saved wheel to {wheel_filename}")
     else:
         logger.info(f"{req.name}: have existing wheel {wheel_filename}")
@@ -321,8 +321,8 @@
 
 # Helper method to check whether the .whl file is a zip file and has contents in it.
 # It will throw BadZipFile exception if any other file is encountered. Eg: index.html
-def _download_wheel_check(destination_dir, wheel_url):
-    wheel_filename = sources.download_url(destination_dir, wheel_url)
+def _download_wheel_check(ctx, destination_dir, wheel_url):
+    wheel_filename = sources.download_url(ctx, destination_dir, wheel_url)
     wheel_directory_contents = zipfile.ZipFile(wheel_filename).namelist()
     if not wheel_directory_contents:
         raise zipfile.BadZipFile(f"Empty zip file encountered: {wheel_filename}")
diff --git a/src/fromager/sources.py b/src/fromager/sources.py
index 19113f1b..bef612f0 100644
--- a/src/fromager/sources.py
+++ b/src/fromager/sources.py
@@ -10,7 +10,6 @@
 import zipfile
 from urllib.parse import urlparse
 
-import requests
 import resolvelib
 from packaging.requirements import Requirement
 from packaging.version import InvalidVersion, Version
@@ -165,7 +164,7 @@ def default_download_source(
     )
 
     source_filename = _download_source_check(
-        ctx.sdists_downloads, url, destination_filename
+        ctx, ctx.sdists_downloads, url, destination_filename
     )
 
     logger.debug(
@@ -177,9 +176,12 @@
 # Helper method to check whether .zip /.tar / .tgz is able to extract and check its content.
 # It will throw exception if any other file is encountered. Eg: index.html
 def _download_source_check(
-    destination_dir: pathlib.Path, url: str, destination_filename: str | None = None
+    ctx: context.WorkContext,
+    destination_dir: pathlib.Path,
+    url: str,
+    destination_filename: str | None = None,
 ) -> str:
-    source_filename = download_url(destination_dir, url, destination_filename)
+    source_filename = download_url(ctx, destination_dir, url, destination_filename)
     if source_filename.suffix == ".zip":
         source_file_contents = zipfile.ZipFile(source_filename).namelist()
         if not source_file_contents:
@@ -197,7 +199,10 @@
 
 
 def download_url(
-    destination_dir: pathlib.Path, url: str, destination_filename: str | None = None
+    ctx: context.WorkContext,
+    destination_dir: pathlib.Path,
+    url: str,
+    destination_filename: str | None = None,
 ) -> pathlib.Path:
     basename = (
         destination_filename
@@ -209,13 +214,26 @@
         "looking for %s %s", outfile, "(exists)" if outfile.exists() else "(not there)"
     )
     if outfile.exists():
-        logger.debug(f"already have {outfile}")
-        return outfile
+        outfile_size_on_disk = outfile.stat().st_size
+        header_response = ctx.requests.head(url)
+        outfile_size_online = int(header_response.headers.get("Content-Length", 0))
+
+        # Check if size on disk matches the size of file online
+        if outfile_size_on_disk == outfile_size_online:
+            logger.debug(f"already have {outfile}")
+            return outfile
+        else:
+            # Set the range header
+            headers = {"Range": f"bytes={outfile_size_on_disk}-"}
+    else:
+        outfile_size_on_disk = 0
+        headers = {}
+
     # Open the URL first in case that fails, so we don't end up with an empty file.
     logger.debug(f"reading from {url}")
-    with requests.get(url, stream=True) as r:
+    with ctx.requests.get(url, stream=True, headers=headers) as r:
         r.raise_for_status()
-        with open(outfile, "wb") as f:
+        with open(outfile, "ab" if outfile_size_on_disk else "wb") as f:
             logger.debug(f"writing to {outfile}")
             for chunk in r.iter_content(chunk_size=1024 * 1024):
                 f.write(chunk)
diff --git a/tests/test_sdist.py b/tests/test_sdist.py
index 699bde25..91b5c2d0 100644
--- a/tests/test_sdist.py
+++ b/tests/test_sdist.py
@@ -51,6 +51,7 @@ def test_ignore_based_on_marker(tmp_context: WorkContext):
 
 @patch("fromager.sources.download_url")
 def test_invalid_wheel_file_exception(mock_download_url, tmp_path: pathlib.Path):
+    tmp_context = WorkContext
     mock_download_url.return_value = pathlib.Path(tmp_path / "test" / "fake_wheel.txt")
     fake_url = "https://www.thisisafakeurl.com"
     fake_dir = tmp_path / "test"
@@ -58,4 +59,4 @@ def test_invalid_wheel_file_exception(mock_download_url, tmp_path: pathlib.Path)
     text_file = fake_dir / "fake_wheel.txt"
     text_file.write_text("This is a test file")
     with pytest.raises(zipfile.BadZipFile):
-        sdist._download_wheel_check(fake_dir, fake_url)
+        sdist._download_wheel_check(tmp_context, fake_dir, fake_url)
diff --git a/tests/test_sources.py b/tests/test_sources.py
index 06de6bcf..b7bb093e 100644
--- a/tests/test_sources.py
+++ b/tests/test_sources.py
@@ -8,7 +8,9 @@
 
 
 @patch("fromager.sources.download_url")
-def test_invalid_tarfile(mock_download_url, tmp_path: pathlib.Path):
+def test_invalid_tarfile(
+    mock_download_url, tmp_path: pathlib.Path, tmp_context: context.WorkContext
+):
     mock_download_url.return_value = pathlib.Path(tmp_path / "test" / "fake_wheel.txt")
     fake_url = "https://www.thisisafakeurl.com"
     fake_dir = tmp_path / "test"
@@ -16,7 +18,7 @@ def test_invalid_tarfile(mock_download_url, tmp_path: pathlib.Path):
     text_file = fake_dir / "fake_wheel.txt"
     text_file.write_text("This is a test file")
     with pytest.raises(TypeError):
-        sources._download_source_check(fake_dir, fake_url)
+        sources._download_source_check(tmp_context, fake_dir, fake_url)
 
 
 @patch("fromager.sources.resolve_dist")
@@ -41,7 +43,7 @@ def test_default_download_source_from_settings(
 
     resolve_dist.assert_called_with(tmp_context, req, sdist_server_url, True, False)
     download_source_check.assert_called_with(
-        tmp_context.sdists_downloads, "predefined_url-1.0", "foo-1.0"
+        tmp_context, tmp_context.sdists_downloads, "predefined_url-1.0", "foo-1.0"
     )
 
 
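For reviewers who want to try the pattern outside the fromager codebase, the sketch below reproduces the retry-and-resume behaviour this patch wires into WorkContext and download_url. It is illustrative only: the helper names (make_session, fetch) and the placeholder URL are not part of the patch, it omits the HEAD/Content-Length size check the patch performs before resuming, and the retry parameters simply mirror the defaults added in _make_requests_session.

import pathlib

import requests
from requests.adapters import HTTPAdapter, Retry


def make_session(retries: int = 5, backoff_factor: float = 0.1) -> requests.Session:
    # Retry transient server errors (500/502/503/504) with backoff,
    # matching the defaults the patch adds in _make_requests_session().
    session = requests.Session()
    retry = Retry(
        total=retries,
        backoff_factor=backoff_factor,
        status_forcelist=[500, 502, 503, 504],
    )
    adapter = HTTPAdapter(max_retries=retry)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session


def fetch(url: str, outfile: pathlib.Path) -> pathlib.Path:
    # Hypothetical helper: resume a partial download by requesting only
    # the bytes we do not already have, then append them to the file.
    session = make_session()
    offset = outfile.stat().st_size if outfile.exists() else 0
    headers = {"Range": f"bytes={offset}-"} if offset else {}
    with session.get(url, stream=True, headers=headers) as r:
        r.raise_for_status()
        with open(outfile, "ab" if offset else "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    return outfile

Note that a server which ignores the Range header answers 200 with the full body, so a stricter variant would check for a 206 Partial Content status before opening the file in append mode.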