Skip to content

Commit 45528cf

Browse files
eujingEu Jing Chuamotus
authored
Proposed fixes for token-based auth in azure fileshare service (#820)
To use token-based authentication for `ShareClient`, I think we should be passing in the credential object derived from `TokenCredential` (in our case, some instance of `DefaultAzureCredential`. Previously, we were passing specific string tokens to the `credential` argument, which is being intepreted as a SAS token. This leads to errors like: `azure.core.exceptions.ClientAuthenticationError: Server failed to authenticate the request. Make sure the value of Authorization header is formed correctly including the signature.` [ShareClient](https://learn.microsoft.com/en-us/python/api/azure-storage-file-share/azure.storage.fileshare.shareclient?view=azure-python) documentation on the `credential` argument. By passing in the whole `TokenCredential` object, I believe `ShareClient` will manage the token lifecycle and we won't need to do so as mentioned in #818. --------- Co-authored-by: Eu Jing Chua <[email protected]> Co-authored-by: Sergiy Matusevych <[email protected]>
1 parent 7fe167d commit 45528cf

File tree

5 files changed

+63
-24
lines changed

5 files changed

+63
-24
lines changed

mlos_bench/mlos_bench/services/remote/azure/azure_auth.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from datetime import datetime
1010
from typing import Any, Callable, Dict, List, Optional, Union
1111

12-
import azure.identity as azure_id
12+
from azure.core.credentials import TokenCredential
13+
from azure.identity import CertificateCredential, DefaultAzureCredential
1314
from azure.keyvault.secrets import SecretClient
1415
from pytz import UTC
1516

@@ -20,7 +21,7 @@
2021
_LOG = logging.getLogger(__name__)
2122

2223

23-
class AzureAuthService(Service, SupportsAuth):
24+
class AzureAuthService(Service, SupportsAuth[TokenCredential]):
2425
"""Helper methods to get access to Azure services."""
2526

2627
_REQ_INTERVAL = 300 # = 5 min
@@ -56,6 +57,7 @@ def __init__(
5657
[
5758
self.get_access_token,
5859
self.get_auth_headers,
60+
self.get_credential,
5961
],
6062
),
6163
)
@@ -65,10 +67,7 @@ def __init__(
6567

6668
self._access_token = "RENEW *NOW*"
6769
self._token_expiration_ts = datetime.now(UTC) # Typically, some future timestamp.
68-
69-
# Login as the first identity available, usually ourselves or a managed identity
70-
self._cred: Union[azure_id.DefaultAzureCredential, azure_id.CertificateCredential]
71-
self._cred = azure_id.DefaultAzureCredential()
70+
self._cred: Optional[TokenCredential] = None
7271

7372
# Verify info required for SP auth early
7473
if "spClientId" in self.config:
@@ -82,18 +81,22 @@ def __init__(
8281
},
8382
)
8483

85-
def _init_sp(self) -> None:
84+
def get_credential(self) -> TokenCredential:
85+
"""Return the Azure SDK credential object."""
8686
# Perform this initialization outside of __init__ so that environment loading tests
8787
# don't need to specifically mock keyvault interactions out
88+
if self._cred is not None:
89+
return self._cred
8890

89-
# Already logged in as SP
90-
if isinstance(self._cred, azure_id.CertificateCredential):
91-
return
91+
self._cred = DefaultAzureCredential()
92+
if "spClientId" not in self.config:
93+
return self._cred
9294

9395
sp_client_id = self.config["spClientId"]
9496
keyvault_name = self.config["keyVaultName"]
9597
cert_name = self.config["certName"]
9698
tenant_id = self.config["tenant"]
99+
_LOG.debug("Log in with Azure Service Principal %s", sp_client_id)
97100

98101
# Get a client for fetching cert info
99102
keyvault_secrets_client = SecretClient(
@@ -108,23 +111,20 @@ def _init_sp(self) -> None:
108111
cert_bytes = b64decode(secret.value)
109112

110113
# Reauthenticate as the service principal.
111-
self._cred = azure_id.CertificateCredential(
114+
self._cred = CertificateCredential(
112115
tenant_id=tenant_id,
113116
client_id=sp_client_id,
114117
certificate_data=cert_bytes,
115118
)
119+
return self._cred
116120

117121
def get_access_token(self) -> str:
118122
"""Get the access token from Azure CLI, if expired."""
119-
# Ensure we are logged as the Service Principal, if provided
120-
if "spClientId" in self.config:
121-
self._init_sp()
122-
123123
ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds()
124124
_LOG.debug("Time to renew the token: %.2f sec.", ts_diff)
125125
if ts_diff < self._req_interval:
126126
_LOG.debug("Request new accessToken")
127-
res = self._cred.get_token("https://management.azure.com/.default")
127+
res = self.get_credential().get_token("https://management.azure.com/.default")
128128
self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC)
129129
self._access_token = res.token
130130
_LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts)

mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
from typing import Any, Callable, Dict, List, Optional, Set, Union
1010

11+
from azure.core.credentials import TokenCredential
1112
from azure.core.exceptions import ResourceNotFoundError
1213
from azure.storage.fileshare import ShareClient
1314

@@ -60,20 +61,25 @@ def __init__(
6061
"storageFileShareName",
6162
},
6263
)
64+
assert self._parent is not None and isinstance(
65+
self._parent, SupportsAuth
66+
), "Authorization service not provided. Include service-auth.jsonc?"
67+
self._auth_service: SupportsAuth[TokenCredential] = self._parent
6368
self._share_client: Optional[ShareClient] = None
6469

6570
def _get_share_client(self) -> ShareClient:
6671
"""Get the Azure file share client object."""
6772
if self._share_client is None:
68-
assert self._parent is not None and isinstance(
69-
self._parent, SupportsAuth
70-
), "Authorization service not provided. Include service-auth.jsonc?"
73+
credential = self._auth_service.get_credential()
74+
assert isinstance(
75+
credential, TokenCredential
76+
), f"Expected a TokenCredential, but got {type(credential)} instead."
7177
self._share_client = ShareClient.from_share_url(
7278
self._SHARE_URL.format(
7379
account_name=self.config["storageAccountName"],
7480
fs_name=self.config["storageFileShareName"],
7581
),
76-
credential=self._parent.get_access_token(),
82+
credential=credential,
7783
token_intent="backup",
7884
)
7985
return self._share_client

mlos_bench/mlos_bench/services/types/authenticator_type.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
#
55
"""Protocol interface for authentication for the cloud services."""
66

7-
from typing import Protocol, runtime_checkable
7+
from typing import Protocol, TypeVar, runtime_checkable
8+
9+
T_co = TypeVar("T_co", covariant=True)
810

911

1012
@runtime_checkable
11-
class SupportsAuth(Protocol):
13+
class SupportsAuth(Protocol[T_co]):
1214
"""Protocol interface for authentication for the cloud services."""
1315

1416
def get_access_token(self) -> str:
@@ -30,3 +32,13 @@ def get_auth_headers(self) -> dict:
3032
access_header : dict
3133
HTTP header containing the access token.
3234
"""
35+
36+
def get_credential(self) -> T_co:
37+
"""
38+
Get the credential object for cloud services.
39+
40+
Returns
41+
-------
42+
credential : T
43+
Cloud-specific credential object.
44+
"""

mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,28 @@ def test_load_service_config_examples(
4848
config_path: str,
4949
) -> None:
5050
"""Tests loading a config example."""
51+
parent: Service = config_loader_service
5152
config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE)
53+
# Add other services that require a SupportsAuth parent service as necessary.
54+
requires_auth_service_parent = {
55+
"AzureFileShareService",
56+
}
57+
config_class_name = str(config.get("class", "MISSING CLASS")).rsplit(".", maxsplit=1)[-1]
58+
if config_class_name in requires_auth_service_parent:
59+
# AzureFileShareService requires an auth service to be loaded as well.
60+
auth_service_config = config_loader_service.load_config(
61+
"services/remote/mock/mock_auth_service.jsonc",
62+
ConfigSchema.SERVICE,
63+
)
64+
auth_service = config_loader_service.build_service(
65+
config=auth_service_config,
66+
parent=config_loader_service,
67+
)
68+
parent = auth_service
5269
# Make an instance of the class based on the config.
5370
service_inst = config_loader_service.build_service(
5471
config=config,
55-
parent=config_loader_service,
72+
parent=parent,
5673
)
5774
assert service_inst is not None
5875
assert isinstance(service_inst, Service)

mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
_LOG = logging.getLogger(__name__)
1414

1515

16-
class MockAuthService(Service, SupportsAuth):
16+
class MockAuthService(Service, SupportsAuth[str]):
1717
"""A collection Service functions for mocking authentication ops."""
1818

1919
def __init__(
@@ -32,6 +32,7 @@ def __init__(
3232
[
3333
self.get_access_token,
3434
self.get_auth_headers,
35+
self.get_credential,
3536
],
3637
),
3738
)
@@ -41,3 +42,6 @@ def get_access_token(self) -> str:
4142

4243
def get_auth_headers(self) -> dict:
4344
return {"Authorization": "Bearer " + self.get_access_token()}
45+
46+
def get_credential(self) -> str:
47+
return "MOCK CREDENTIAL"

0 commit comments

Comments
 (0)