Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use token-based authentication instead of the access key in the AzureFileShareService #779

Merged
merged 24 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2a6cd0d
use token-based credential instead of a storage key in Azure file sha…
motus Jul 8, 2024
951fdba
initialize auth service before azure fileshare in unit tests
motus Jul 8, 2024
9aeaa92
late initialization for azure fileshare _share_client; proper monkey …
motus Jul 9, 2024
54c308b
remove storageAccountKey parameter from azure_fileshare fixture
motus Jul 9, 2024
9b65bd7
Merge branch 'main' into sergiym/svc/fs_token_auth
motus Jul 10, 2024
101b593
Merge branch 'main' into sergiym/svc/fs_token_auth
motus Jul 10, 2024
71fae3c
Merge branch 'main' into sergiym/svc/fs_token_auth
motus Jul 12, 2024
4bc4c25
Merge branch 'main' of https://github.com/microsoft/MLOS into sergiym…
motus Jul 12, 2024
289af50
Merge branch 'sergiym/svc/fs_token_auth' of https://github.com/motus/…
motus Jul 12, 2024
ead96a1
black formatting updates
motus Jul 12, 2024
a46bcd4
Merge branch 'main' into sergiym/svc/fs_token_auth
motus Jul 15, 2024
68c6d0d
Merge branch 'main' into sergiym/svc/fs_token_auth
motus Jul 16, 2024
99172ed
Merge branch 'main' of https://github.com/microsoft/MLOS into sergiym…
motus Jul 19, 2024
8078d70
docformatter fixes
motus Jul 19, 2024
527ae66
Merge branch 'sergiym/svc/fs_token_auth' of https://github.com/motus/…
motus Jul 19, 2024
56704f6
Merge branch 'main' into sergiym/svc/fs_token_auth
motus Jul 19, 2024
5af8b6f
Merge branch 'main' into sergiym/svc/fs_token_auth
bpkroth Jul 22, 2024
378f073
Merge branch 'main' into sergiym/svc/fs_token_auth
motus Jul 22, 2024
f2b47c6
formatting fixes
motus Jul 22, 2024
e9e48f2
remove git merge artifact
motus Jul 22, 2024
2925ebb
more git merge fixes
motus Jul 22, 2024
d9bf3fe
Merge branch 'main' into sergiym/svc/fs_token_auth
motus Jul 23, 2024
827492f
create share client every time to make sure token-based authenticatio…
motus Jul 23, 2024
af55d5e
roll back creation of the ShareClient every time we need it. Will pro…
motus Jul 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from mlos_bench.services.base_fileshare import FileShareService
from mlos_bench.services.base_service import Service
from mlos_bench.services.types.authenticator_type import SupportsAuth
from mlos_bench.util import check_required_params

_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,23 +53,30 @@ def __init__(
parent,
self.merge_methods(methods, [self.upload, self.download]),
)

check_required_params(
self.config,
{
"storageAccountName",
"storageFileShareName",
"storageAccountKey",
},
)

self._share_client = ShareClient.from_share_url(
AzureFileShareService._SHARE_URL.format(
account_name=self.config["storageAccountName"],
fs_name=self.config["storageFileShareName"],
),
credential=self.config["storageAccountKey"],
)
self._share_client: Optional[ShareClient] = None

def _get_share_client(self) -> ShareClient:
"""Get the Azure file share client object."""
if self._share_client is None:
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
motus marked this conversation as resolved.
Show resolved Hide resolved
self._share_client = ShareClient.from_share_url(
self._SHARE_URL.format(
account_name=self.config["storageAccountName"],
fs_name=self.config["storageFileShareName"],
),
credential=self._parent.get_access_token(),
token_intent="backup",
)
return self._share_client

def download(
self,
Expand All @@ -78,7 +86,7 @@ def download(
recursive: bool = True,
) -> None:
super().download(params, remote_path, local_path, recursive)
dir_client = self._share_client.get_directory_client(remote_path)
dir_client = self._get_share_client().get_directory_client(remote_path)
if dir_client.exists():
os.makedirs(local_path, exist_ok=True)
for content in dir_client.list_directories_and_files():
Expand All @@ -91,7 +99,7 @@ def download(
# Ensure parent folders exist
folder, _ = os.path.split(local_path)
os.makedirs(folder, exist_ok=True)
file_client = self._share_client.get_file_client(remote_path)
file_client = self._get_share_client().get_file_client(remote_path)
try:
data = file_client.download_file()
with open(local_path, "wb") as output_file:
Expand Down Expand Up @@ -147,7 +155,7 @@ def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[
# Ensure parent folders exist
folder, _ = os.path.split(remote_path)
self._remote_makedirs(folder)
file_client = self._share_client.get_file_client(remote_path)
file_client = self._get_share_client().get_file_client(remote_path)
with open(local_path, "rb") as file_data:
_LOG.debug("Upload file: %s -> %s", local_path, remote_path)
file_client.upload_file(file_data)
Expand All @@ -167,6 +175,6 @@ def _remote_makedirs(self, remote_path: str) -> None:
if not folder:
continue
path += folder + "/"
dir_client = self._share_client.get_directory_client(path)
dir_client = self._get_share_client().get_directory_client(path)
if not dir_client.exists():
dir_client.create_directory()
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ def test_download_file(
local_folder = "some/local/folder"
remote_path = f"{remote_folder}/{filename}"
local_path = f"{local_folder}/{filename}"
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access

config: dict = {}
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, patch.object(
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_file_client"
) as mock_get_file_client, patch.object(
mock_share_client, "get_directory_client"
) as mock_get_directory_client:

mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False))

azure_fileshare.download(config, remote_path, local_path)
Expand Down Expand Up @@ -81,8 +84,9 @@ def test_download_folder_non_recursive(
local_folder = "some/local/folder"
dir_client_returns = make_dir_client_returns(remote_folder)
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access

config: dict = {}
with patch.object(
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_directory_client"
) as mock_get_directory_client, patch.object(
mock_share_client, "get_file_client"
Expand Down Expand Up @@ -114,15 +118,14 @@ def test_download_folder_recursive(
remote_folder = "a/remote/folder"
local_folder = "some/local/folder"
dir_client_returns = make_dir_client_returns(remote_folder)
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access

config: dict = {}
with patch.object(
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_directory_client"
) as mock_get_directory_client, patch.object(
mock_share_client, "get_file_client"
) as mock_get_file_client:
mock_get_directory_client.side_effect = lambda x: dir_client_returns[x]

azure_fileshare.download(config, remote_folder, local_folder, recursive=True)

mock_get_file_client.assert_has_calls(
Expand Down Expand Up @@ -157,9 +160,11 @@ def test_upload_file(
local_path = f"{local_folder}/{filename}"
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
mock_isdir.return_value = False
config: dict = {}

with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
config: dict = {}
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_file_client"
) as mock_get_file_client:
azure_fileshare.upload(config, local_path, remote_path)

mock_get_file_client.assert_called_with(remote_path)
Expand Down Expand Up @@ -228,9 +233,11 @@ def test_upload_directory_non_recursive(
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
config: dict = {}

with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
config: dict = {}
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_file_client"
) as mock_get_file_client:
azure_fileshare.upload(config, local_folder, remote_folder, recursive=False)

mock_get_file_client.assert_called_with(f"{remote_folder}/a_file_1.csv")
Expand All @@ -252,9 +259,11 @@ def test_upload_directory_recursive(
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
config: dict = {}

with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
config: dict = {}
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
mock_share_client, "get_file_client"
) as mock_get_file_client:
azure_fileshare.upload(config, local_folder, remote_folder, recursive=True)

mock_get_file_client.assert_has_calls(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> A


@pytest.fixture
def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService:
def azure_fileshare(azure_auth_service: AzureAuthService) -> AzureFileShareService:
"""Creates a dummy AzureFileShareService for tests that require it."""
with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"):
return AzureFileShareService(
Expand All @@ -112,5 +112,5 @@ def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> Azu
"storageAccountKey": "TEST_ACCOUNT_KEY",
},
global_config={},
parent=config_persistence_service,
parent=azure_auth_service,
)
Loading