Skip to content

Commit

Permalink
Fix: Network hiccups using Retry logic
Browse files Browse the repository at this point in the history
  • Loading branch information
rd4398 committed Aug 8, 2024
1 parent 43d9ece commit 0d99308
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 16 deletions.
24 changes: 24 additions & 0 deletions src/fromager/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import typing
from urllib.parse import urlparse

import requests
from packaging.requirements import Requirement
from packaging.utils import NormalizedName, canonicalize_name
from packaging.version import Version
from requests.adapters import HTTPAdapter, Retry

from . import constraints, settings

Expand Down Expand Up @@ -71,6 +73,8 @@ def __init__(
self._seen_requirements: set[tuple[NormalizedName, tuple[str, ...], str]] = (
set()
)
# create a requests session
self.requests = self._make_requests_session()

@property
def pip_wheel_server_args(self) -> list[str]:
Expand All @@ -80,6 +84,26 @@ def pip_wheel_server_args(self) -> list[str]:
args = args + ["--trusted-host", parsed.hostname]
return args

def _make_requests_session(
self, retries=5, backoff_factor=0.1, status_forcelist=None
) -> requests.Session:
session = requests.Session()
if status_forcelist is None:
status_forcelist = [500, 502, 503, 504]

# Initialize the retries object
retries = Retry(
total=retries,
backoff_factor=backoff_factor,
status_forcelist=status_forcelist,
)

# Initialize the adapter object
adapter = HTTPAdapter(max_retries=retries)
session.mount("http://", adapter)
session.mount("https://", adapter)
return session

def _resolved_key(
self, req: Requirement, version: Version
) -> tuple[NormalizedName, tuple[str, ...], str]:
Expand Down
6 changes: 3 additions & 3 deletions src/fromager/sdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def download_wheel(
wheel_filename = output_directory / os.path.basename(urlparse(wheel_url).path)
if not wheel_filename.exists():
logger.info(f"{req.name}: downloading pre-built wheel {wheel_url}")
wheel_filename = _download_wheel_check(output_directory, wheel_url)
wheel_filename = _download_wheel_check(ctx, output_directory, wheel_url)
logger.info(f"{req.name}: saved wheel to {wheel_filename}")
else:
logger.info(f"{req.name}: have existing wheel {wheel_filename}")
Expand All @@ -321,8 +321,8 @@ def download_wheel(

# Helper method to check whether the .whl file is a zip file and has contents in it.
# It will throw BadZipFile exception if any other file is encountered. Eg: index.html
def _download_wheel_check(destination_dir, wheel_url):
wheel_filename = sources.download_url(destination_dir, wheel_url)
def _download_wheel_check(ctx, destination_dir, wheel_url):
wheel_filename = sources.download_url(ctx, destination_dir, wheel_url)
wheel_directory_contents = zipfile.ZipFile(wheel_filename).namelist()
if not wheel_directory_contents:
raise zipfile.BadZipFile(f"Empty zip file encountered: {wheel_filename}")
Expand Down
36 changes: 27 additions & 9 deletions src/fromager/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import zipfile
from urllib.parse import urlparse

import requests
import resolvelib
from packaging.requirements import Requirement
from packaging.version import InvalidVersion, Version
Expand Down Expand Up @@ -165,7 +164,7 @@ def default_download_source(
)

source_filename = _download_source_check(
ctx.sdists_downloads, url, destination_filename
ctx, ctx.sdists_downloads, url, destination_filename
)

logger.debug(
Expand All @@ -177,9 +176,12 @@ def default_download_source(
# Helper method to check whether .zip /.tar / .tgz is able to extract and check its content.
# It will throw exception if any other file is encountered. Eg: index.html
def _download_source_check(
destination_dir: pathlib.Path, url: str, destination_filename: str | None = None
ctx: context.WorkContext,
destination_dir: pathlib.Path,
url: str,
destination_filename: str | None = None,
) -> str:
source_filename = download_url(destination_dir, url, destination_filename)
source_filename = download_url(ctx, destination_dir, url, destination_filename)
if source_filename.suffix == ".zip":
source_file_contents = zipfile.ZipFile(source_filename).namelist()
if not source_file_contents:
Expand All @@ -197,7 +199,10 @@ def _download_source_check(


def download_url(
destination_dir: pathlib.Path, url: str, destination_filename: str | None = None
ctx: context.WorkContext,
destination_dir: pathlib.Path,
url: str,
destination_filename: str | None = None,
) -> pathlib.Path:
basename = (
destination_filename
Expand All @@ -209,13 +214,26 @@ def download_url(
"looking for %s %s", outfile, "(exists)" if outfile.exists() else "(not there)"
)
if outfile.exists():
logger.debug(f"already have {outfile}")
return outfile
outfile_size_on_disk = outfile.stat().st_size
header_response = ctx.requests.head(url)
outfile_size_online = int(header_response.headers.get("Content-Length", 0))

# Check if size on disk matches the size of file online
if outfile_size_on_disk == outfile_size_online:
logger.debug(f"already have {outfile}")
return outfile
else:
# Set the range header
headers = {"Range": f"bytes={outfile_size_on_disk}-"}
else:
outfile_size_on_disk = 0
headers = {}

# Open the URL first in case that fails, so we don't end up with an empty file.
logger.debug(f"reading from {url}")
with requests.get(url, stream=True) as r:
with ctx.requests.get(url, stream=True, headers=headers) as r:
r.raise_for_status()
with open(outfile, "wb") as f:
with open(outfile, "ab" if outfile_size_on_disk else "wb") as f:
logger.debug(f"writing to {outfile}")
for chunk in r.iter_content(chunk_size=1024 * 1024):
f.write(chunk)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_sdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ def test_ignore_based_on_marker(tmp_context: WorkContext):

@patch("fromager.sources.download_url")
def test_invalid_wheel_file_exception(mock_download_url, tmp_path: pathlib.Path):
tmp_context = WorkContext
mock_download_url.return_value = pathlib.Path(tmp_path / "test" / "fake_wheel.txt")
fake_url = "https://www.thisisafakeurl.com"
fake_dir = tmp_path / "test"
fake_dir.mkdir()
text_file = fake_dir / "fake_wheel.txt"
text_file.write_text("This is a test file")
with pytest.raises(zipfile.BadZipFile):
sdist._download_wheel_check(fake_dir, fake_url)
sdist._download_wheel_check(tmp_context, fake_dir, fake_url)
8 changes: 5 additions & 3 deletions tests/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@


@patch("fromager.sources.download_url")
def test_invalid_tarfile(mock_download_url, tmp_path: pathlib.Path):
def test_invalid_tarfile(
mock_download_url, tmp_path: pathlib.Path, tmp_context: context.WorkContext
):
mock_download_url.return_value = pathlib.Path(tmp_path / "test" / "fake_wheel.txt")
fake_url = "https://www.thisisafakeurl.com"
fake_dir = tmp_path / "test"
fake_dir.mkdir()
text_file = fake_dir / "fake_wheel.txt"
text_file.write_text("This is a test file")
with pytest.raises(TypeError):
sources._download_source_check(fake_dir, fake_url)
sources._download_source_check(tmp_context, fake_dir, fake_url)


@patch("fromager.sources.resolve_dist")
Expand All @@ -41,7 +43,7 @@ def test_default_download_source_from_settings(

resolve_dist.assert_called_with(tmp_context, req, sdist_server_url, True, False)
download_source_check.assert_called_with(
tmp_context.sdists_downloads, "predefined_url-1.0", "foo-1.0"
tmp_context, tmp_context.sdists_downloads, "predefined_url-1.0", "foo-1.0"
)


Expand Down

0 comments on commit 0d99308

Please sign in to comment.