From aee7deedae138530a4ad5fbb556980f08113210f Mon Sep 17 00:00:00 2001 From: Jorge Esteban Quilcate Otoya Date: Wed, 20 Dec 2023 09:45:08 -0500 Subject: [PATCH] feat: add custom Azure host/port to support custom blob endpoint e.g. azurite --- rohmu/object_storage/azure.py | 57 ++++++++++++-------- rohmu/object_storage/config.py | 26 ++++++++- test/object_storage/test_azure.py | 90 ++++++++++++++++++++++++++++++- 3 files changed, 150 insertions(+), 23 deletions(-) diff --git a/rohmu/object_storage/azure.py b/rohmu/object_storage/azure.py index 0a09e07d..4fe8c95d 100644 --- a/rohmu/object_storage/azure.py +++ b/rohmu/object_storage/azure.py @@ -24,6 +24,7 @@ SourceStorageModelT, ) from rohmu.object_storage.config import ( # pylint: disable=unused-import + AZURE_ENDPOINT_SUFFIXES, AZURE_MAX_BLOCK_SIZE as MAX_BLOCK_SIZE, AzureObjectStorageConfig as Config, calculate_azure_max_block_size as calculate_max_block_size, @@ -42,14 +43,6 @@ from azure.storage.blob._models import BlobPrefix, BlobType # type: ignore -ENDPOINT_SUFFIXES = { - None: "core.windows.net", - "germany": "core.cloudapi.de", # Azure Germany is a completely separate cloud from the regular Azure Public cloud - "china": "core.chinacloudapi.cn", - "public": "core.windows.net", -} - - # Reduce Azure logging verbocity of http requests and responses logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING) @@ -64,6 +57,9 @@ def __init__( account_key: Optional[str] = None, sas_token: Optional[str] = None, prefix: Optional[str] = None, + is_secure: bool = True, + host: Optional[str] = None, + port: Optional[int] = None, azure_cloud: Optional[str] = None, proxy_info: Optional[dict[str, Union[str, int]]] = None, notifier: Optional[Notifier] = None, @@ -78,16 +74,13 @@ def __init__( self.account_key = account_key self.container_name = bucket_name self.sas_token = sas_token - try: - endpoint_suffix = ENDPOINT_SUFFIXES[azure_cloud] - except KeyError: - raise InvalidConfigurationError(f"Unknown azure cloud {repr(azure_cloud)}") - - conn_str = ( - "DefaultEndpointsProtocol=https;" - f"AccountName={self.account_name};" - f"AccountKey={self.account_key};" - f"EndpointSuffix={endpoint_suffix}" + conn_str = self.conn_string( + account_name=account_name, + account_key=account_key, + azure_cloud=azure_cloud, + host=host, + port=port, + is_secure=is_secure, ) config: dict[str, Any] = {"max_block_size": MAX_BLOCK_SIZE} if proxy_info: @@ -97,13 +90,13 @@ def __init__( auth = f"{username}:{password}@" else: auth = "" - host = proxy_info["host"] - port = proxy_info["port"] + proxy_host = proxy_info["host"] + proxy_port = proxy_info["port"] if proxy_info.get("type") == "socks5": schema = "socks5" else: schema = "http" - config["proxies"] = {"https": f"{schema}://{auth}{host}:{port}"} + config["proxies"] = {"https": f"{schema}://{auth}{proxy_host}:{proxy_port}"} self.conn: BlobServiceClient = BlobServiceClient.from_connection_string( conn_str=conn_str, @@ -113,6 +106,28 @@ def __init__( self.container = self.get_or_create_container(self.container_name) self.log.debug("AzureTransfer initialized, %r", self.container_name) + @staticmethod + def conn_string( + account_name: str, + account_key: Optional[str], + azure_cloud: Optional[str], + host: Optional[str], + port: Optional[int], + is_secure: bool, + ) -> str: + protocol = "https" if is_secure else "http" + conn = [ + f"DefaultEndpointsProtocol={protocol}", + f"AccountName={account_name}", + f"AccountKey={account_key}", + ] + if not host and not port: + endpoint_suffix = AZURE_ENDPOINT_SUFFIXES[azure_cloud] + conn.append(f"EndpointSuffix={endpoint_suffix}") + else: + conn.append(f"BlobEndpoint={protocol}://{host}:{port}/{account_name}") + return ";".join(conn) + def copy_file( self, *, source_key: str, destination_key: str, metadata: Optional[Metadata] = None, **kwargs: Any ) -> None: diff --git a/rohmu/object_storage/config.py b/rohmu/object_storage/config.py index 5bfec47c..2068c6bb 100644 --- a/rohmu/object_storage/config.py +++ b/rohmu/object_storage/config.py @@ -10,7 +10,7 @@ from enum import Enum, unique from pathlib import Path -from pydantic import Field, root_validator +from pydantic import Field, root_validator, validator from rohmu.common.models import ProxyInfo, StorageDriver, StorageModel from typing import Any, Dict, Final, Literal, Optional, TypeVar @@ -42,6 +42,12 @@ def calculate_azure_max_block_size() -> int: return max(min(int(total_mem_mib / 1000), 100), 4) * 1024 * 1024 +AZURE_ENDPOINT_SUFFIXES = { + None: "core.windows.net", + "germany": "core.cloudapi.de", # Azure Germany is a completely separate cloud from the regular Azure Public cloud + "china": "core.chinacloudapi.cn", + "public": "core.windows.net", +} # Increase block size based on host memory. Azure supports up to 50k blocks and up to 5 TiB individual # files. Default block size is set to 4 MiB so only ~200 GB files can be uploaded. In order to get close # to that 5 TiB increase the block size based on host memory; we don't want to use the max 100 for all @@ -83,10 +89,28 @@ class AzureObjectStorageConfig(StorageModel): account_key: Optional[str] = Field(None, repr=False) sas_token: Optional[str] = Field(None, repr=False) prefix: Optional[str] = None + is_secure: bool = True + host: Optional[str] = None + port: Optional[int] = None azure_cloud: Optional[str] = None proxy_info: Optional[ProxyInfo] = None storage_type: Literal[StorageDriver.azure] = StorageDriver.azure + @root_validator + @classmethod + def host_and_port_must_be_set_together(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "host" in values and "port" in values: + if not values["host"] or not values["port"]: + raise ValueError("host and port must be set together") + return values + + @validator("azure_cloud") + @classmethod + def valid_azure_cloud_endpoint(cls, v: str) -> str: + if v not in AZURE_ENDPOINT_SUFFIXES: + raise ValueError(f"azure_cloud must be one of {AZURE_ENDPOINT_SUFFIXES.keys()}") + return v + class GoogleObjectStorageConfig(StorageModel): project_id: str diff --git a/test/object_storage/test_azure.py b/test/object_storage/test_azure.py index e193f53d..eeb20287 100644 --- a/test/object_storage/test_azure.py +++ b/test/object_storage/test_azure.py @@ -2,9 +2,10 @@ from datetime import datetime from io import BytesIO from rohmu.errors import InvalidByteRangeError +from rohmu.object_storage.config import AzureObjectStorageConfig from tempfile import NamedTemporaryFile from types import ModuleType -from typing import Any, Tuple +from typing import Any, Optional, Tuple from unittest.mock import MagicMock, patch import pytest @@ -103,3 +104,90 @@ def test_get_contents_to_fileobj_raises_error_on_invalid_byte_range(azure_module fileobj_to_store_to=BytesIO(), byte_range=(100, 10), ) + + +def test_azure_config_host_port_set_together() -> None: + with pytest.raises(ValueError): + AzureObjectStorageConfig(account_name="test", host="localhost") + with pytest.raises(ValueError): + AzureObjectStorageConfig(account_name="test", port=10000) + + +def test_valid_azure_cloud_endpoint() -> None: + with pytest.raises(ValueError): + AzureObjectStorageConfig(account_name="test", azure_cloud="invalid") + + +@pytest.mark.parametrize( + "host,port,is_secured,expected", + [ + ( + None, + None, + True, + ";".join( + [ + "DefaultEndpointsProtocol=https", + "AccountName=test_name", + "AccountKey=test_key", + "EndpointSuffix=core.windows.net", + ] + ), + ), + ( + None, + None, + False, + ";".join( + [ + "DefaultEndpointsProtocol=http", + "AccountName=test_name", + "AccountKey=test_key", + "EndpointSuffix=core.windows.net", + ] + ), + ), + ( + "localhost", + 10000, + True, + ";".join( + [ + "DefaultEndpointsProtocol=https", + "AccountName=test_name", + "AccountKey=test_key", + "BlobEndpoint=https://localhost:10000/test_name", + ] + ), + ), + ( + "localhost", + 10000, + False, + ";".join( + [ + "DefaultEndpointsProtocol=http", + "AccountName=test_name", + "AccountKey=test_key", + "BlobEndpoint=http://localhost:10000/test_name", + ] + ), + ), + ], +) +def test_conn_string(host: Optional[str], port: Optional[int], is_secured: bool, expected: str) -> None: + get_blob_client_mock = MagicMock() + blob_client = MagicMock(get_blob_client=get_blob_client_mock) + service_client = MagicMock(from_connection_string=MagicMock(return_value=blob_client)) + module_patches = { + "azure.common": MagicMock(), + "azure.core.exceptions": MagicMock(), + "azure.storage.blob": MagicMock(BlobServiceClient=service_client), + } + with patch.dict(sys.modules, module_patches): + from rohmu.object_storage.azure import AzureTransfer + + conn_string = AzureTransfer.conn_string( + account_name="test_name", account_key="test_key", azure_cloud=None, host=host, port=port, is_secure=is_secured + ) + assert expected == conn_string