Skip to content

Commit 2bbe74c

Browse files
author
Brandon Kaplan
authored
Prod 1779 updated sync spark py to support azure hosted (#108)
1 parent 97964c3 commit 2bbe74c

File tree

1 file changed

+32
-5
lines changed

1 file changed

+32
-5
lines changed

sync/azuredatabricks.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import sys
55
from pathlib import Path
66
from time import sleep
7-
from typing import Dict, List, Optional, Type, TypeVar
7+
from typing import Dict, List, Optional, Type, TypeVar, Union
88
from urllib.parse import urlparse
99

1010
from azure.common.credentials import get_cli_profile
1111
from azure.core.exceptions import ClientAuthenticationError
12-
from azure.identity import DefaultAzureCredential
12+
from azure.identity import DefaultAzureCredential, ClientSecretCredential
1313
from azure.mgmt.compute import ComputeManagementClient
1414
from azure.mgmt.resource import ResourceManagementClient
1515

@@ -125,7 +125,7 @@ def get_access_report(log_url: str = None) -> AccessReport:
125125
)
126126

127127
try:
128-
DefaultAzureCredential().get_token("https://management.azure.com/.default")
128+
get_azure_credential().get_token("https://management.azure.com/.default")
129129
report.append(
130130
AccessReportLine(
131131
name="Azure Authentication",
@@ -463,20 +463,47 @@ def _get_databricks_resource_group_name() -> str:
463463
AzureClient = TypeVar("AzureClient")
464464

465465

466+
def get_azure_credential() -> Union[ClientSecretCredential, DefaultAzureCredential]:
467+
global _azure_credential
468+
if _azure_credential is None:
469+
_azure_credential = DefaultAzureCredential()
470+
return _azure_credential
471+
472+
473+
def set_azure_client_credentials(
474+
azure_subscription_id: str, azure_credential: ClientSecretCredential
475+
):
476+
global _azure_subscription_id
477+
if _azure_subscription_id is not None:
478+
raise RuntimeError("Azure client credentials already set, cannot reset subscription id")
479+
_azure_subscription_id = azure_subscription_id
480+
481+
global _azure_credential
482+
if _azure_credential is not None:
483+
raise RuntimeError("Azure client credentials already set, cannot reset credentials")
484+
_azure_credential = azure_credential
485+
486+
466487
def _get_azure_client(azure_client_class: Type[AzureClient]) -> AzureClient:
467488
global _azure_subscription_id
468489
if not _azure_subscription_id:
469490
_azure_subscription_id = _get_azure_subscription_id()
470491

471492
global _azure_credential
472493
if not _azure_credential:
473-
_azure_credential = DefaultAzureCredential()
494+
_azure_credential = get_azure_credential()
474495

475496
return azure_client_class(_azure_credential, _azure_subscription_id)
476497

477498

478499
def _get_azure_subscription_id():
479-
return os.getenv("AZURE_SUBSCRIPTION_ID") or get_cli_profile().get_login_credentials()[1]
500+
global _azure_subscription_id
501+
subscription_id = (
502+
_azure_subscription_id
503+
or os.getenv("AZURE_SUBSCRIPTION_ID")
504+
or get_cli_profile().get_login_credentials()[1]
505+
)
506+
return subscription_id
480507

481508

482509
def _get_running_vms_by_id(

0 commit comments

Comments
 (0)