|
4 | 4 | import sys
|
5 | 5 | from pathlib import Path
|
6 | 6 | from time import sleep
|
7 |
| -from typing import Dict, List, Optional, Type, TypeVar |
| 7 | +from typing import Dict, List, Optional, Type, TypeVar, Union |
8 | 8 | from urllib.parse import urlparse
|
9 | 9 |
|
10 | 10 | from azure.common.credentials import get_cli_profile
|
11 | 11 | from azure.core.exceptions import ClientAuthenticationError
|
12 |
| -from azure.identity import DefaultAzureCredential |
| 12 | +from azure.identity import DefaultAzureCredential, ClientSecretCredential |
13 | 13 | from azure.mgmt.compute import ComputeManagementClient
|
14 | 14 | from azure.mgmt.resource import ResourceManagementClient
|
15 | 15 |
|
@@ -125,7 +125,7 @@ def get_access_report(log_url: str = None) -> AccessReport:
|
125 | 125 | )
|
126 | 126 |
|
127 | 127 | try:
|
128 |
| - DefaultAzureCredential().get_token("https://management.azure.com/.default") |
| 128 | + get_azure_credential().get_token("https://management.azure.com/.default") |
129 | 129 | report.append(
|
130 | 130 | AccessReportLine(
|
131 | 131 | name="Azure Authentication",
|
@@ -463,20 +463,47 @@ def _get_databricks_resource_group_name() -> str:
|
463 | 463 | AzureClient = TypeVar("AzureClient")
|
464 | 464 |
|
465 | 465 |
|
| 466 | +def get_azure_credential() -> Union[ClientSecretCredential, DefaultAzureCredential]: |
| 467 | + global _azure_credential |
| 468 | + if _azure_credential is None: |
| 469 | + _azure_credential = DefaultAzureCredential() |
| 470 | + return _azure_credential |
| 471 | + |
| 472 | + |
| 473 | +def set_azure_client_credentials( |
| 474 | + azure_subscription_id: str, azure_credential: ClientSecretCredential |
| 475 | +): |
| 476 | + global _azure_subscription_id |
| 477 | + if _azure_subscription_id is not None: |
| 478 | + raise RuntimeError("Azure client credentials already set, cannot reset subscription id") |
| 479 | + _azure_subscription_id = azure_subscription_id |
| 480 | + |
| 481 | + global _azure_credential |
| 482 | + if _azure_credential is not None: |
| 483 | + raise RuntimeError("Azure client credentials already set, cannot reset credentials") |
| 484 | + _azure_credential = azure_credential |
| 485 | + |
| 486 | + |
466 | 487 | def _get_azure_client(azure_client_class: Type[AzureClient]) -> AzureClient:
|
467 | 488 | global _azure_subscription_id
|
468 | 489 | if not _azure_subscription_id:
|
469 | 490 | _azure_subscription_id = _get_azure_subscription_id()
|
470 | 491 |
|
471 | 492 | global _azure_credential
|
472 | 493 | if not _azure_credential:
|
473 |
| - _azure_credential = DefaultAzureCredential() |
| 494 | + _azure_credential = get_azure_credential() |
474 | 495 |
|
475 | 496 | return azure_client_class(_azure_credential, _azure_subscription_id)
|
476 | 497 |
|
477 | 498 |
|
478 | 499 | def _get_azure_subscription_id():
|
479 |
| - return os.getenv("AZURE_SUBSCRIPTION_ID") or get_cli_profile().get_login_credentials()[1] |
| 500 | + global _azure_subscription_id |
| 501 | + subscription_id = ( |
| 502 | + _azure_subscription_id |
| 503 | + or os.getenv("AZURE_SUBSCRIPTION_ID") |
| 504 | + or get_cli_profile().get_login_credentials()[1] |
| 505 | + ) |
| 506 | + return subscription_id |
480 | 507 |
|
481 | 508 |
|
482 | 509 | def _get_running_vms_by_id(
|
|
0 commit comments