from urllib.parse import urlparse

import boto3 as boto
+ import orjson

from botocore.exceptions import ClientError
- from orjson import orjson

+ from sync.api import get_access_report as get_api_access_report
from sync.api.predictions import create_prediction_with_eventlog_bytes, get_prediction
from sync.api.projects import get_project
from sync.clients.databricks import get_default_client
from sync.config import CONFIG, DB_CONFIG
from sync.models import (
+     AccessReport,
+     AccessReportLine,
+     AccessStatusCode,
      DatabricksAPIError,
      DatabricksClusterReport,
      DatabricksError,
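The import hunk also swaps `from orjson import orjson` for the conventional top-level `import orjson`, so call sites use orjson's documented module-level API. A minimal sanity check of that API (note that orjson.dumps returns bytes, not str):

    import orjson

    payload = orjson.dumps({"cluster_id": "0123-456789-abcdefgh"})  # returns bytes
    assert orjson.loads(payload)["cluster_id"] == "0123-456789-abcdefgh"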
...

logger = logging.getLogger(__name__)

+ def get_access_report(log_url: str = None) -> AccessReport:
+     """Reports access to Databricks, AWS and Sync required for integrating jobs with Sync.
+     Access is partially determined by the configuration of this library and boto3.
+
+     :param log_url: location of event logs, defaults to None
+     :type log_url: str, optional
+     :return: access report
+     :rtype: AccessReport
+     """
+     report = get_api_access_report()
+     dbx_client = get_default_client()
+
+     response = dbx_client.get_current_user()
+     user_name = response.get("userName")
+     if user_name:
+         report.append(
+             AccessReportLine(
+                 name="Databricks Authentication",
+                 status=AccessStatusCode.GREEN,
+                 message=f"Authenticated as '{user_name}'",
+             )
+         )
+     else:
+         report.append(
+             AccessReportLine(
+                 name="Databricks Authentication",
+                 status=AccessStatusCode.RED,
+                 message=f"{response.get('error_code')}: {response.get('message')}",
+             )
+         )
+
+     response = boto.client("sts").get_caller_identity()
+     arn = response.get("Arn")
+     if arn:
+         report.append(
+             AccessReportLine(
+                 name="AWS Authentication",
+                 status=AccessStatusCode.GREEN,
+                 message=f"Authenticated as '{arn}'",
+             )
+         )
+
+         ec2 = boto.client("ec2", region_name=DB_CONFIG.aws_region_name)
+         report.add_boto_method_call(ec2.describe_instances, AccessStatusCode.YELLOW, DryRun=True)
+     else:
+         report.append(
+             AccessReportLine(
+                 name="AWS Authentication",
+                 status=AccessStatusCode.RED,
+                 message="Failed to authenticate AWS credentials",
+             )
+         )
+
+     if log_url:
+         parsed_log_url = urlparse(log_url)
+
+         if parsed_log_url.scheme == "s3" and arn:
+             s3 = boto.client("s3")
+             report.add_boto_method_call(
+                 s3.list_objects_v2,
+                 Bucket=parsed_log_url.netloc,
+                 Prefix=parsed_log_url.path.strip("/"),  # the key prefix lives in .path, not .params
+                 MaxKeys=1,
+             )
+         elif parsed_log_url.scheme == "dbfs":
+             response = dbx_client.list_dbfs_directory(parsed_log_url.geturl())
+             if "error_code" not in response:
+                 report.append(
+                     AccessReportLine(
+                         name="Log Access",
+                         status=AccessStatusCode.GREEN,
+                         message=f"Can list objects at {parsed_log_url.geturl()}",
+                     )
+                 )
+             else:
+                 report.append(
+                     AccessReportLine(
+                         name="Log Access",
+                         status=AccessStatusCode.RED,
+                         message=f"Cannot list objects at {parsed_log_url.geturl()}",
+                     )
+                 )
+         else:
+             report.append(
+                 AccessReportLine(
+                     name="Log Access",
+                     status=AccessStatusCode.RED,
+                     message=f"scheme in {parsed_log_url.geturl()} is not supported",
+                 )
+             )
+
+     return report

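For orientation, a hypothetical usage sketch of the new helper (the module path and the iterability of AccessReport are assumptions, not confirmed by this diff):

    from sync.awsdatabricks import get_access_report  # module path assumed from this change

    report = get_access_report(log_url="s3://example-bucket/event-logs/")  # bucket is illustrative
    for line in report:  # assumes AccessReport yields its AccessReportLine entries
        print(f"{line.name}: {line.status} - {line.message}")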
def create_prediction(
    plan_type: str,
    compute_type: str,
@@ -314,15 +412,18 @@ def _get_cluster_instances(cluster: dict) -> Response[dict]:
    # are associated with this cluster
    if cluster_instances is None:
        ec2 = boto.client("ec2", region_name=aws_region_name)
-         cluster_instances = ec2.describe_instances(
-             Filters=[
-                 {"Name": "tag:Vendor", "Values": ["Databricks"]},
-                 {"Name": "tag:ClusterId", "Values": [cluster_id]},
-                 # {'Name': 'tag:JobId', 'Values': []}
-             ]
-         )
+         try:
+             cluster_instances = ec2.describe_instances(
+                 Filters=[
+                     {"Name": "tag:Vendor", "Values": ["Databricks"]},
+                     {"Name": "tag:ClusterId", "Values": [cluster_id]},
+                     # {'Name': 'tag:JobId', 'Values': []}
+                 ]
+             )
+         except Exception as exc:
+             logger.warning(exc)

-     if not cluster_instances["Reservations"]:
+     if not cluster_instances or not cluster_instances["Reservations"]:
        no_instances_message = (
            f"Unable to find any active or recently terminated instances for cluster `{cluster_id}` in `{aws_region_name}`. "
+             "Please refer to the following documentation for options on how to address this - "
@@ -1179,10 +1280,10 @@ def _get_eventlog_from_dbfs(
    run_end_time_millis: int,
    poll_duration_seconds: int,
):
-     prefix = format_dbfs_filepath(f"{base_filepath}/eventlog/")
-
-     root_dir = get_default_client().list_dbfs_directory(prefix)
+     dbx_client = get_default_client()

+     prefix = format_dbfs_filepath(f"{base_filepath}/eventlog/")
+     root_dir = dbx_client.list_dbfs_directory(prefix)

    eventlog_files = [f for f in root_dir["files"] if f["is_dir"]]
    matching_subdirectory = None
@@ -1195,7 +1296,7 @@ def _get_eventlog_from_dbfs(
        eventlog_file_metadata = eventlog_files.pop()
        path = eventlog_file_metadata["path"]

-         subdir = get_default_client().list_dbfs_directory(path)
+         subdir = dbx_client.list_dbfs_directory(path)

        subdir_files = subdir["files"]
        matching_subdirectory = next(
@@ -1204,7 +1305,7 @@ def _get_eventlog_from_dbfs(
        )

    if matching_subdirectory:
-         eventlog_dir = get_default_client().list_dbfs_directory(matching_subdirectory["path"])
+         eventlog_dir = dbx_client.list_dbfs_directory(matching_subdirectory["path"])

        poll_num_attempts = 0
        poll_max_attempts = 20  # 5 minutes / 15 seconds = 20 attempts
@@ -1219,14 +1320,13 @@ def _get_eventlog_from_dbfs(
            )
            sleep(poll_duration_seconds)

-             eventlog_dir = get_default_client().list_dbfs_directory(matching_subdirectory["path"])
+             eventlog_dir = dbx_client.list_dbfs_directory(matching_subdirectory["path"])
            poll_num_attempts += 1

    eventlog_zip = io.BytesIO()
    eventlog_zip_file = zipfile.ZipFile(eventlog_zip, "a", zipfile.ZIP_DEFLATED)

    eventlog_files = eventlog_dir["files"]
-     dbx_client = get_default_client()
    for eventlog_file_metadata in eventlog_files:
        filename: str = eventlog_file_metadata["path"].split("/")[-1]
        filesize: int = eventlog_file_metadata["file_size"]
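The last four hunks are one refactor: get_default_client() is resolved once into dbx_client at the top of _get_eventlog_from_dbfs and reused for every DBFS listing, including inside the polling loop, instead of being re-fetched at each call site. A condensed sketch of the resulting shape (eventlog_path and the stop condition are illustrative; list_dbfs_directory is assumed to return a dict with a "files" list, as the hunks imply):

    from time import sleep

    from sync.clients.databricks import get_default_client

    dbx_client = get_default_client()  # resolve once, reuse everywhere below

    eventlog_dir = dbx_client.list_dbfs_directory(eventlog_path)  # eventlog_path: hypothetical
    poll_num_attempts, poll_max_attempts = 0, 20  # 5 minutes / 15 seconds = 20 attempts
    while poll_num_attempts < poll_max_attempts and not eventlog_dir["files"]:
        sleep(15)
        eventlog_dir = dbx_client.list_dbfs_directory(eventlog_path)  # same client instance
        poll_num_attempts += 1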