
Commit

added more tests and fixes wrt comments
ethantang-db committed Nov 2, 2024
1 parent 9bb255a commit 537a5ae
Showing 3 changed files with 52 additions and 15 deletions.
2 changes: 1 addition & 1 deletion streaming/base/storage/download.py
@@ -60,7 +60,7 @@ def get(cls, remote_dir: Optional[str] = None) -> 'CloudDownloader':
         if remote_dir is None:
             return _LOCAL_DOWNLOADER()
 
-        logger.info('Acquiring downloader client for remote directory %s', remote_dir)
+        logger.debug('Acquiring downloader client for remote directory %s', remote_dir)
 
         prefix = urllib.parse.urlparse(remote_dir).scheme
         if prefix == 'dbfs' and remote_dir.startswith('dbfs:/Volumes'):
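
For context, the `get` factory shown above picks a downloader implementation from the URL scheme of `remote_dir` (for example, the `dbfs:/Volumes` check routes to the Unity Catalog downloader). A minimal usage sketch, assuming `CloudDownloader` is importable from `streaming.base.storage.download`; the bucket path is hypothetical:

    from streaming.base.storage.download import CloudDownloader

    # Hypothetical remote object; the 'gs://' scheme determines which downloader `get` returns.
    remote_file = 'gs://my-bucket/train/file.txt'
    downloader = CloudDownloader.get(remote_file)

    # Download the remote object to a local path (same call shape the tests below use).
    downloader.download(remote_file, '/tmp/file.txt')
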
2 changes: 1 addition & 1 deletion streaming/base/util.py
@@ -490,7 +490,7 @@ def retry( # type: ignore
             num_tries = 0
 
             def clean_up():
-                print("cleaning up")
+                # Do clean up stuff here
 
             @retry(RuntimeError, clean_up_fn=clean_up, num_attempts=3, initial_backoff=0.1)
             def flaky_function():
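
The hunk above edits the docstring example for `retry`. A runnable sketch of that example, assuming `retry` and the `clean_up_fn`, `num_attempts`, and `initial_backoff` parameters are importable from `streaming.base.util` as shown in the diff, and that `clean_up_fn` runs between attempts; the flaky function body is hypothetical:

    from streaming.base.util import retry

    num_tries = 0

    def clean_up():
        # Hypothetical clean-up hook; assumed to be called by the decorator between attempts.
        print('cleaning up')

    @retry(RuntimeError, clean_up_fn=clean_up, num_attempts=3, initial_backoff=0.1)
    def flaky_function():
        # Fail on the first two calls, then succeed, so the retry path is exercised.
        global num_tries
        if num_tries < 2:
            num_tries += 1
            raise RuntimeError('called too soon')
        return 'success'

    print(flaky_function())  # 'success' once the third attempt goes through
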
63 changes: 50 additions & 13 deletions tests/test_download.py
@@ -19,13 +19,14 @@
 from tests.conftest import GCS_URL, MY_BUCKET, R2_URL
 
 MY_PREFIX = 'train'
+TEST_FILE = 'file.txt'
 
 
 @pytest.fixture(scope='function')
 def remote_local_file() -> Any:
     """Creates a temporary directory and then deletes it when the calling function is done."""
 
-    def _method(cloud_prefix: str = '', filename: str = 'file.txt') -> tuple[str, str]:
+    def _method(cloud_prefix: str = '', filename: str = TEST_FILE) -> tuple[str, str]:
         try:
             mock_local_dir = tempfile.TemporaryDirectory()
             mock_local_filepath = os.path.join(mock_local_dir.name, filename)
@@ -114,15 +115,14 @@ class TestGCSClient:
     @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_file')
     def test_download_from_gcs(self, remote_local_file: Any):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            file_name = 'file.txt'
-            tmp = os.path.join(tmp_dir, file_name)
-            mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://', filename=file_name)
+            tmp = os.path.join(tmp_dir, TEST_FILE)
+            mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://', filename=TEST_FILE)
             client = boto3.client('s3',
                                   region_name='us-east-1',
                                   endpoint_url=GCS_URL,
                                   aws_access_key_id=os.environ['GCS_KEY'],
                                   aws_secret_access_key=os.environ['GCS_SECRET'])
-            client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='')
+            client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, TEST_FILE), Body='')
             downloader = GCSDownloader()
             downloader.download(mock_remote_filepath, tmp)
             assert os.path.isfile(tmp)
@@ -176,20 +176,58 @@ class TestDatabricksUnityCatalog:
     @pytest.mark.parametrize('cloud_prefix', ['dbfs:/Volumess', 'dbfs:/Content'])
     def test_invalid_prefix_from_db_uc(self, remote_local_file: Any, cloud_prefix: str):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            file_name = os.path.join(tmp_dir, 'file.txt')
-            mock_remote_filepath, _ = remote_local_file(cloud_prefix=cloud_prefix,
-                                                        filename=file_name)
+            file_name = os.path.join(tmp_dir, TEST_FILE)
+            mock_remote_filepath, _ = remote_local_file(cloud_prefix=cloud_prefix)
             with pytest.raises(Exception, match='Expected path prefix to be `dbfs:/Volumes`.*'):
                 downloader = DatabricksUnityCatalogDownloader()
                 downloader.download(mock_remote_filepath, file_name)
+
+    @patch('databricks.sdk.WorkspaceClient', autospec=True)
+    def test_databricks_error_file_not_found(self, workspace_client_mock: Mock,
+                                             remote_local_file: Any):
+        from databricks.sdk.core import DatabricksError
+        workspace_client_mock_instance = workspace_client_mock.return_value
+        workspace_client_mock_instance.files = Mock()
+        workspace_client_mock_instance.files.download = Mock()
+        download_return_val = workspace_client_mock_instance.files.download.return_value
+        download_return_val.contents = Mock()
+        download_return_val.contents.__enter__ = Mock(
+            side_effect=DatabricksError('Error', error_code='NOT_FOUND'))
+        download_return_val.contents.__exit__ = Mock()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            file_name = os.path.join(tmp_dir, TEST_FILE)
+            mock_remote_filepath, _ = remote_local_file(cloud_prefix='dbfs:/Volumes')
+            with pytest.raises(FileNotFoundError):
+                downloader = DatabricksUnityCatalogDownloader()
+                downloader.download(mock_remote_filepath, file_name)
+
+    @patch('databricks.sdk.WorkspaceClient', autospec=True)
+    def test_databricks_error(self, workspace_client_mock: Mock, remote_local_file: Any):
+        from databricks.sdk.core import DatabricksError
+        workspace_client_mock_instance = workspace_client_mock.return_value
+        workspace_client_mock_instance.files = Mock()
+        workspace_client_mock_instance.files.download = Mock()
+        download_return_val = workspace_client_mock_instance.files.download.return_value
+        download_return_val.contents = Mock()
+        download_return_val.contents.__enter__ = Mock(
+            side_effect=DatabricksError('Error', error_code='REQUEST_LIMIT_EXCEEDED'))
+        download_return_val.contents.__exit__ = Mock()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            file_name = os.path.join(tmp_dir, TEST_FILE)
+            mock_remote_filepath, _ = remote_local_file(cloud_prefix='dbfs:/Volumes')
+            with pytest.raises(DatabricksError):
+                downloader = DatabricksUnityCatalogDownloader()
+                downloader.download(mock_remote_filepath, file_name)
 
 
 class TestDatabricksFileSystem:
 
     def test_invalid_prefix_from_dbfs(self, remote_local_file: Any):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            file_name = os.path.join(tmp_dir, 'file.txt')
-            mock_remote_filepath, _ = remote_local_file(cloud_prefix='dbfsx:/', filename=file_name)
+            file_name = os.path.join(tmp_dir, TEST_FILE)
+            mock_remote_filepath, _ = remote_local_file(cloud_prefix='dbfsx:/')
             with pytest.raises(Exception, match='Expected remote path to start with.*'):
                 downloader = DBFSDownloader()
                 downloader.download(mock_remote_filepath, file_name)
@@ -198,9 +236,8 @@ def test_invalid_prefix_from_dbfs(self, remote_local_file: Any):
 def test_download_from_local():
     mock_remote_dir = tempfile.TemporaryDirectory()
     mock_local_dir = tempfile.TemporaryDirectory()
-    file_name = 'file.txt'
-    mock_remote_file = os.path.join(mock_remote_dir.name, file_name)
-    mock_local_file = os.path.join(mock_local_dir.name, file_name)
+    mock_remote_file = os.path.join(mock_remote_dir.name, TEST_FILE)
+    mock_local_file = os.path.join(mock_local_dir.name, TEST_FILE)
     # Creates a new empty file
     with open(mock_remote_file, 'w') as _:
         pass
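
The two new Databricks tests above patch `databricks.sdk.WorkspaceClient` and make the object returned by `files.download(...).contents` raise when entered as a context manager. A standalone sketch of that mocking pattern with `MagicMock` and a hypothetical error class (a stand-in, not the Databricks SDK's `DatabricksError`):

    from unittest.mock import MagicMock

    class FakeDatabricksError(Exception):
        # Stand-in for databricks.sdk.core.DatabricksError in this sketch.
        pass

    # The mocked download result acts as a context manager that fails on entry.
    contents = MagicMock()
    contents.__enter__.side_effect = FakeDatabricksError('NOT_FOUND')

    try:
        with contents:
            pass
    except FakeDatabricksError as err:
        # The NOT_FOUND case is what the test above expects to surface as FileNotFoundError.
        print(f'download raised: {err}')
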
