Skip to content

Commit

Permalink
feat: add custom Azure host/port to support custom blob endpoint
Browse files Browse the repository at this point in the history
 e.g. azurite
  • Loading branch information
jeqo committed Dec 28, 2023
1 parent cc2b427 commit 63da30f
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 14 deletions.
51 changes: 38 additions & 13 deletions rohmu/object_storage/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def __init__(
account_key: Optional[str] = None,
sas_token: Optional[str] = None,
prefix: Optional[str] = None,
is_secure: bool = True,
host: Optional[str] = None,
port: Optional[int] = None,
azure_cloud: Optional[str] = None,
proxy_info: Optional[dict[str, Union[str, int]]] = None,
notifier: Optional[Notifier] = None,
Expand All @@ -78,16 +81,13 @@ def __init__(
self.account_key = account_key
self.container_name = bucket_name
self.sas_token = sas_token
try:
endpoint_suffix = ENDPOINT_SUFFIXES[azure_cloud]
except KeyError:
raise InvalidConfigurationError(f"Unknown azure cloud {repr(azure_cloud)}")

conn_str = (
"DefaultEndpointsProtocol=https;"
f"AccountName={self.account_name};"
f"AccountKey={self.account_key};"
f"EndpointSuffix={endpoint_suffix}"
conn_str = self.conn_string(
account_name=account_name,
account_key=account_key,
azure_cloud=azure_cloud,
host=host,
port=port,
is_secure=is_secure,
)
config: dict[str, Any] = {"max_block_size": MAX_BLOCK_SIZE}
if proxy_info:
Expand All @@ -97,13 +97,13 @@ def __init__(
auth = f"{username}:{password}@"
else:
auth = ""
host = proxy_info["host"]
port = proxy_info["port"]
proxy_host = proxy_info["host"]
proxy_port = proxy_info["port"]
if proxy_info.get("type") == "socks5":
schema = "socks5"
else:
schema = "http"
config["proxies"] = {"https": f"{schema}://{auth}{host}:{port}"}
config["proxies"] = {"https": f"{schema}://{auth}{proxy_host}:{proxy_port}"}

self.conn: BlobServiceClient = BlobServiceClient.from_connection_string(
conn_str=conn_str,
Expand All @@ -113,6 +113,31 @@ def __init__(
self.container = self.get_or_create_container(self.container_name)
self.log.debug("AzureTransfer initialized, %r", self.container_name)

@staticmethod
def conn_string(
account_name: str,
account_key: Optional[str],
azure_cloud: Optional[str],
host: Optional[str],
port: Optional[int],
is_secure: bool,
) -> str:
protocol = "https" if is_secure else "http"
conn_str = f"DefaultEndpointsProtocol={protocol};" f"AccountName={account_name};" f"AccountKey={account_key};"
if not host and not port:
try:
endpoint_suffix = ENDPOINT_SUFFIXES[azure_cloud]
except KeyError:
raise InvalidConfigurationError(f"Unknown azure cloud {repr(azure_cloud)}")

conn_str = f"{conn_str}" f"EndpointSuffix={endpoint_suffix};"
else:
if not host or not port:
raise InvalidConfigurationError("Custom host and port must be specified together")

conn_str = f"{conn_str}" f"BlobEndpoint={protocol}://{host}:{port}/{account_name};"
return conn_str

def copy_file(
self, *, source_key: str, destination_key: str, metadata: Optional[Metadata] = None, **kwargs: Any
) -> None:
Expand Down
3 changes: 3 additions & 0 deletions rohmu/object_storage/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class AzureObjectStorageConfig(StorageModel):
account_key: Optional[str] = Field(None, repr=False)
sas_token: Optional[str] = Field(None, repr=False)
prefix: Optional[str] = None
is_secure: bool = True
host: Optional[str] = None
port: Optional[int] = None
azure_cloud: Optional[str] = None
proxy_info: Optional[ProxyInfo] = None
storage_type: Literal[StorageDriver.azure] = StorageDriver.azure
Expand Down
51 changes: 50 additions & 1 deletion test/object_storage/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from rohmu.errors import InvalidByteRangeError
from tempfile import NamedTemporaryFile
from types import ModuleType
from typing import Any, Tuple
from typing import Any, Optional, Tuple
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -103,3 +103,52 @@ def test_get_contents_to_fileobj_raises_error_on_invalid_byte_range(azure_module
fileobj_to_store_to=BytesIO(),
byte_range=(100, 10),
)


@pytest.mark.parametrize(
"host,port,is_secured,expected",
[
(
None,
None,
True,
"DefaultEndpointsProtocol=https;AccountName=test_name;AccountKey=test_key;EndpointSuffix=core.windows.net;",
),
(
None,
None,
False,
"DefaultEndpointsProtocol=http;AccountName=test_name;AccountKey=test_key;EndpointSuffix=core.windows.net;",
),
(
"localhost",
10000,
True,
"DefaultEndpointsProtocol=https;AccountName=test_name;AccountKey=test_key;"
"BlobEndpoint=https://localhost:10000/test_name;",
),
(
"localhost",
10000,
False,
"DefaultEndpointsProtocol=http;AccountName=test_name;AccountKey=test_key;"
"BlobEndpoint=http://localhost:10000/test_name;",
),
],
)
def test_conn_string(host: Optional[str], port: Optional[int], is_secured: bool, expected: str) -> None:
get_blob_client_mock = MagicMock()
blob_client = MagicMock(get_blob_client=get_blob_client_mock)
service_client = MagicMock(from_connection_string=MagicMock(return_value=blob_client))
module_patches = {
"azure.common": MagicMock(),
"azure.core.exceptions": MagicMock(),
"azure.storage.blob": MagicMock(BlobServiceClient=service_client),
}
with patch.dict(sys.modules, module_patches):
from rohmu.object_storage.azure import AzureTransfer

conn_string = AzureTransfer.conn_string(
account_name="test_name", account_key="test_key", azure_cloud=None, host=host, port=port, is_secure=is_secured
)
assert expected == conn_string

0 comments on commit 63da30f

Please sign in to comment.