Skip to content

Commit

Permalink
Merge pull request #7 from kfirgoldberg/jeremy/tqdm_download
Browse files Browse the repository at this point in the history
Jeremy/tqdm download
  • Loading branch information
kfirgoldberg authored Apr 6, 2024
2 parents d371c62 + d5edae8 commit 5808c66
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 52 deletions.
2 changes: 1 addition & 1 deletion anypathlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.1"
__version__ = "0.1"

from anypathlib.anypath import AnyPath
from anypathlib.path_handlers.path_types import PathType
31 changes: 21 additions & 10 deletions anypathlib/anypath.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def stem(self) -> str:
def name(self) -> str:
return self.path_handler.name(self.base_path)

def __get_local_path(self, target_path: Optional[Path] = None, force_overwrite: bool = False) -> Path:
def __get_local_path(self, target_path: Optional[Path] = None, force_overwrite: bool = False,
verbose: bool = False) -> Optional[Path]:
if target_path is None:
if self.is_dir():
valid_target_path = Path(tempfile.mkdtemp())
Expand All @@ -129,15 +130,21 @@ def __get_local_path(self, target_path: Optional[Path] = None, force_overwrite:
return valid_target_path
else:
if self.is_dir():
local_path, _ = self.path_handler.download_directory(url=self.base_path,
force_overwrite=force_overwrite,
target_dir=valid_target_path)
result = self.path_handler.download_directory(url=self.base_path,
force_overwrite=force_overwrite,
target_dir=valid_target_path,
verbose=verbose)
if result is not None:
local_path, _ = result
else:
return None

else:
local_path = self.path_handler.download_file(url=self.base_path, force_overwrite=force_overwrite,
target_path=valid_target_path)

assert local_path == valid_target_path, f'local_path {local_path} is not equal to valid_target_path {valid_target_path}'
assert local_path == valid_target_path, \
f'local_path {local_path} is not equal to valid_target_path {valid_target_path}'
return Path(local_path)

def __get_local_cache_path(self) -> 'AnyPath':
Expand All @@ -149,26 +156,30 @@ def __get_local_cache_path(self) -> 'AnyPath':
local_cache_path.parent.mkdir(exist_ok=True, parents=True)
return AnyPath(local_cache_path)

def copy(self, target: Optional['AnyPath'] = None, force_overwrite: bool = True) -> 'AnyPath':
def copy(self, target: Optional['AnyPath'] = None, force_overwrite: bool = True, verbose: bool = False) -> 'AnyPath':
assert self.exists(), f'source path: {self.base_path} does not exist'
if target is None:
valid_target = self.__get_local_cache_path()
else:
valid_target = target
if valid_target.is_local:
self.__get_local_path(target_path=Path(valid_target.base_path), force_overwrite=force_overwrite)
self.__get_local_path(target_path=Path(valid_target.base_path), force_overwrite=force_overwrite,
verbose=verbose)
else:
if valid_target.is_s3 and self.is_s3:
S3Handler.copy(source_url=self.base_path, target_url=valid_target.base_path)
elif valid_target.is_azure and self.is_azure:
AzureHandler.copy(source_url=self.base_path, target_url=valid_target.base_path)
else:
# valid_target and source are different, so we need to download the source and upload it to the valid_target
# valid_target and source are different,
# so we need to download the source and upload it to the valid_target

local_path = Path(self.base_path) if self.is_local else self.__get_local_path(
force_overwrite=force_overwrite)
force_overwrite=force_overwrite, verbose=verbose)
target_path_handler = valid_target.path_handler
if self.is_dir():
target_path_handler.upload_directory(local_dir=local_path, target_url=valid_target.base_path)
target_path_handler.upload_directory(local_dir=local_path, target_url=valid_target.base_path,
verbose=verbose)
else:
target_path_handler.upload_file(local_path=str(local_path), target_url=valid_target.base_path)
return valid_target
38 changes: 21 additions & 17 deletions anypathlib/path_handlers/azure_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional, List, Tuple
from urllib.parse import urlparse

from tqdm import tqdm
from azure.core.exceptions import ResourceNotFoundError

from azure.identity import DefaultAzureCredential
Expand Down Expand Up @@ -240,7 +241,7 @@ def remove(cls, url: str, allow_missing: bool = False):
raise e

@classmethod
def download_directory(cls, url: str, force_overwrite: bool, target_dir: Path) -> \
def download_directory(cls, url: str, force_overwrite: bool, target_dir: Path, verbose: bool) -> \
Optional[Tuple[Path, List[Path]]]:
"""Download a directory (all blobs with the same prefix) from Azure Blob Storage."""
assert target_dir.is_dir()
Expand All @@ -249,31 +250,28 @@ def download_directory(cls, url: str, force_overwrite: bool, target_dir: Path) -
blob_service_client = BlobServiceClient.from_connection_string(azure_storage_path.connection_string)
container_client = blob_service_client.get_container_client(container=azure_storage_path.container_name)
local_paths = []
blob_urls = []
for blob in container_client.list_blobs(name_starts_with=azure_storage_path.blob_name):

if verbose:
container_iterator = container_client.list_blobs(name_starts_with=azure_storage_path.blob_name)
progress_bar = tqdm(container_iterator, desc='Downloading directory',
total=len(list(container_iterator)))
else:
progress_bar = container_client.list_blobs(name_starts_with=azure_storage_path.blob_name)

for blob in progress_bar:
blob_url = AzureStoragePath(storage_account=azure_storage_path.storage_account,
container_name=azure_storage_path.container_name, blob_name=blob.name,
connection_string=azure_storage_path.connection_string).http_url
local_target = target_dir / blob.name
local_target = target_dir / Path(blob_url).relative_to(Path(url))
local_path = cls.download_file(url=blob_url, force_overwrite=force_overwrite, target_path=local_target)
assert local_path is not None, f'could not download from {url}'
local_paths.append(Path(local_path))
blob_urls.append(blob_url)
if len(local_paths) == 0:
return None
if target_dir is not None:
local_files = []
for blob_url, local_file in zip(blob_urls, local_paths):
relative_path = Path(blob_url).relative_to(Path(url))
target_path = target_dir / relative_path
target_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(local_file, target_path)
local_files.append(target_path)
return target_dir, local_files
return local_paths[0].parent, local_paths

@classmethod
def upload_directory(cls, local_dir: Path, target_url: str):
def upload_directory(cls, local_dir: Path, target_url: str, verbose: bool):
"""Upload a directory to Azure Blob Storage."""
azure_storage_path = cls.http_to_storage_params(target_url)
blob_service_client = BlobServiceClient.from_connection_string(azure_storage_path.connection_string)
Expand Down Expand Up @@ -302,8 +300,14 @@ def upload_file_wrapper(local_path: str, blob_name: str):
with ThreadPoolExecutor() as executor:
futures = [executor.submit(upload_file_wrapper, str(local_path), blob_name) for local_path, blob_name in
files_to_upload]
for future in futures:
future.result() # Wait for each upload to complete
if verbose:
with tqdm(total=len(files_to_upload), desc='Uploading directory') as pbar:
for future in futures:
future.result() # Wait for each upload to complete
pbar.update(1)
else:
for future in futures:
future.result() # Wait for each upload to complete

@classmethod
def copy(cls, source_url: str, target_url: str):
Expand Down
5 changes: 3 additions & 2 deletions anypathlib/path_handlers/base_path_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def remove(cls, url: str):

@classmethod
@abstractmethod
def download_directory(cls, url: str, force_overwrite: bool, target_dir: Path) -> Optional[Tuple[Path, List[Path]]]:
def download_directory(cls, url: str, force_overwrite: bool, target_dir: Path,
verbose: bool) -> Optional[Tuple[Path, List[Path]]]:
pass

@classmethod
Expand All @@ -31,7 +32,7 @@ def upload_file(cls, local_path: str, target_url: str):

@classmethod
@abstractmethod
def upload_directory(cls, local_dir: Path, target_url: str):
def upload_directory(cls, local_dir: Path, target_url: str, verbose: bool):
pass

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions anypathlib/path_handlers/local_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def upload_file(cls, local_path: str, target_url: str):
cls.copy_path(url=Path(local_path).absolute().as_posix(), target_path=Path(target_url), force_overwrite=True)

@classmethod
def upload_directory(cls, local_dir: Path, target_url: str):
def upload_directory(cls, local_dir: Path, target_url: str, verbose: bool):
cls.copy_path(url=local_dir.absolute().as_posix(), target_path=Path(target_url), force_overwrite=True)

@classmethod
Expand All @@ -52,7 +52,7 @@ def copy_path(cls, url: str, target_path: Path, force_overwrite: bool = True) ->
shutil.copy(local_path, target_path)

@classmethod
def download_directory(cls, url: str, force_overwrite: bool, target_dir: Path) -> \
def download_directory(cls, url: str, force_overwrite: bool, target_dir: Path, verbose: bool) -> \
Optional[Tuple[Path, List[Path]]]:
cls.copy_path(url=url, target_path=target_dir, force_overwrite=force_overwrite)
return target_dir, [p for p in target_dir.rglob('*')]
Expand Down
50 changes: 39 additions & 11 deletions anypathlib/path_handlers/s3_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import boto3 as boto3
import botocore
from tqdm import tqdm

from anypathlib.path_handlers.base_path_handler import BasePathHandler

Expand Down Expand Up @@ -111,7 +112,7 @@ def remove(cls, url: str):
bucket.objects.filter(Prefix=key).delete()

@classmethod
def download_directory(cls, url: str, force_overwrite: bool, target_dir: Path) -> \
def download_directory(cls, url: str, force_overwrite: bool, target_dir: Path, verbose: bool) -> \
Optional[Tuple[Path, List[Path]]]:

s3_resource = boto3.resource('s3')
Expand All @@ -137,14 +138,27 @@ def s3_path_to_local_file_path(s3_path: str, local_base_path: Path) -> Path:
target_path=s3_path_to_local_file_path(s3_path=s3_path,
local_base_path=target_dir),
force_overwrite=force_overwrite): s3_path for s3_path in s3_paths}
for future in as_completed(future_to_s3_path):
s3_path = future_to_s3_path[future]
try:
local_path = future.result()
if local_path:
all_files.append(local_path)
except Exception as exc:
print(f'{s3_path} generated an exception: {exc}')

def process_futures():
for future in as_completed(future_to_s3_path):
s3_path = future_to_s3_path[future]
try:
local_path = future.result()
if local_path:
all_files.append(local_path)
except Exception as exc:
print(f'{s3_path} generated an exception: {exc}')

yield None

if verbose:
with tqdm(total=len(s3_paths), desc='Downloading directory') as pbar:
for _ in process_futures():
pbar.update(1)
else:
for _ in process_futures():
pass

return target_dir, all_files

@classmethod
Expand All @@ -153,14 +167,28 @@ def upload_file(cls, local_path: str, target_url: str):
cls.s3_client.upload_file(local_path, bucket, key)

@classmethod
def upload_directory(cls, local_dir: Path, target_url: str):
def upload_directory(cls, local_dir: Path, target_url: str, verbose: bool = False):
bucket, key = cls.get_bucket_and_key_from_uri(target_url)
for root, dirs, files in os.walk(local_dir):

total_files = 0
if verbose:
for root, dirs, files in os.walk(local_dir):
total_files += len(files)

if verbose:
progress_bar = tqdm(os.walk(local_dir), desc='Uploading directory', total=total_files)
else:
progress_bar = os.walk(local_dir)

for root, dirs, files in progress_bar:
for file in files:
local_path = os.path.join(root, file)
s3_key = f'{key}/{os.path.relpath(local_path, local_dir)}'
cls.s3_client.upload_file(local_path, bucket, s3_key)

if verbose:
progress_bar.update(len(files))

@classmethod
def copy(cls, source_url: str, target_url: str):
s3_resource = boto3.resource('s3')
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ azure-storage-blob>=12.14.0
azure-identity>=1.10.0
azure-mgmt-storage>=21.1.0
boto3>=1.34.23
loguru>=0.7.2
loguru>=0.7.2
tqdm>=4.66.2
9 changes: 5 additions & 4 deletions tests/test_anypath_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ def test_is_file(path_type: PathType, temp_dir_with_files, clean_remote_dir):

@pytest.mark.usefixtures("clean_remote_dir")
@pytest.mark.parametrize("path_type", [PathType.azure, PathType.s3, PathType.local])
def test_caching(path_type: PathType, temp_dir_with_files, clean_remote_dir):
@pytest.mark.parametrize("verbose", [True, False])
def test_caching(path_type: PathType, temp_dir_with_files, clean_remote_dir, verbose: bool):
cloud_handler = PATH_TYPE_TO_HANDLER[path_type]
local_dir_path, local_dir_files = temp_dir_with_files
remote_dir = clean_remote_dir
cloud_handler.upload_directory(local_dir=local_dir_path, target_url=remote_dir)
target1 = AnyPath(remote_dir).copy(target=None, force_overwrite=False)
target2 = AnyPath(remote_dir).copy(target=None, force_overwrite=False)
cloud_handler.upload_directory(local_dir=local_dir_path, target_url=remote_dir, verbose=verbose)
target1 = AnyPath(remote_dir).copy(target=None, force_overwrite=False, verbose=verbose)
target2 = AnyPath(remote_dir).copy(target=None, force_overwrite=False, verbose=verbose)
assert target1.base_path == target2.base_path
9 changes: 5 additions & 4 deletions tests/test_download_from_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def test_copy_to_local_from_cloud(path_type: PathType, temp_dir_with_files, temp
local_dir_path, local_dir_files = temp_dir_with_files

remote_dir = clean_remote_dir
cloud_handler.upload_directory(local_dir=local_dir_path, target_url=remote_dir)
AnyPath(remote_dir).copy(target=AnyPath(temp_local_dir), force_overwrite=True)
cloud_handler.remove(remote_dir)
assert sorted([fn for fn in temp_local_dir.iterdir()]) == sorted([fn for fn in temp_local_dir.iterdir()])
cloud_handler.upload_directory(local_dir=local_dir_path, target_url=remote_dir, verbose=False)
local_download_dir = AnyPath(remote_dir).copy(target=AnyPath(temp_local_dir), force_overwrite=True)
remote_files = AnyPath(remote_dir).listdir()
assert sorted([fn.name for fn in remote_files]) == sorted(
[fn.name for fn in local_download_dir.listdir()])

0 comments on commit 5808c66

Please sign in to comment.