@@ -142,13 +142,49 @@ def create_prediction_for_run(
142
142
:return: prediction ID
143
143
:rtype: Response[str]
144
144
"""
145
+ run = get_default_client ().get_run (run_id )
146
+
147
+ if "error_code" in run :
148
+ return Response (error = DatabricksAPIError (** run ))
149
+
150
+ project_cluster_tasks = _get_project_cluster_tasks (run , exclude_tasks )
151
+
152
+ cluster_id = None
153
+ if project_id :
154
+ if project_id in project_cluster_tasks :
155
+ cluster_id , tasks = project_cluster_tasks .get (project_id )
156
+
157
+ if not cluster_id :
158
+ if None in project_cluster_tasks and len (project_cluster_tasks ) == 1 :
159
+ # If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project
160
+ cluster_id , tasks = project_cluster_tasks .get (None )
161
+
162
+ if not cluster_id :
163
+ return Response (
164
+ error = DatabricksError (
165
+ message = f"No cluster found in run { run_id } for project { project_id } "
166
+ )
167
+ )
168
+
169
+ return _create_prediction (
170
+ cluster_id , tasks , plan_type , compute_type , project_id , allow_incomplete_cluster_report
171
+ )
172
+
173
+
174
+ def _create_prediction (
175
+ cluster_id : str ,
176
+ tasks : List [dict ],
177
+ plan_type : str ,
178
+ compute_type : str ,
179
+ project_id : str = None ,
180
+ allow_incomplete_cluster_report : bool = False ,
181
+ ):
145
182
run_information_response = _get_run_information (
146
- run_id ,
183
+ cluster_id ,
184
+ tasks ,
147
185
plan_type ,
148
186
compute_type ,
149
- project_id = project_id ,
150
187
allow_incomplete_cluster_report = allow_incomplete_cluster_report ,
151
- exclude_tasks = exclude_tasks ,
152
188
)
153
189
154
190
if run_information_response .error :
@@ -247,14 +283,26 @@ def create_submission_for_run(
247
283
:return: prediction ID
248
284
:rtype: Response[str]
249
285
"""
286
+ run = get_default_client ().get_run (run_id )
287
+
288
+ if "error_code" in run :
289
+ return Response (error = DatabricksAPIError (** run ))
290
+
291
+ project_cluster_tasks = _get_project_cluster_tasks (run , exclude_tasks )
292
+
293
+ if project_id in project_cluster_tasks :
294
+ cluster_id , tasks = project_cluster_tasks .get (project_id )
295
+ elif None in project_cluster_tasks and len (project_cluster_tasks ) == 1 :
296
+ # If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project
297
+ cluster_id , tasks = project_cluster_tasks .get (None )
298
+
250
299
run_information_response = _get_run_information (
251
- run_id ,
300
+ cluster_id ,
301
+ tasks ,
252
302
plan_type ,
253
303
compute_type ,
254
- project_id = project_id ,
255
304
allow_failed_tasks = True ,
256
305
allow_incomplete_cluster_report = allow_incomplete_cluster_report ,
257
- exclude_tasks = exclude_tasks ,
258
306
)
259
307
260
308
if run_information_response .error :
@@ -269,24 +317,13 @@ def create_submission_for_run(
269
317
270
318
271
319
def _get_run_information (
272
- run_id : str ,
320
+ cluster_id : str ,
321
+ tasks : List [dict ],
273
322
plan_type : str ,
274
323
compute_type : str ,
275
- project_id : str = None ,
276
324
allow_failed_tasks : bool = False ,
277
325
allow_incomplete_cluster_report : bool = False ,
278
- exclude_tasks : Union [Collection [str ], None ] = None ,
279
326
) -> Response [Tuple [DatabricksClusterReport , bytes ]]:
280
- run = get_default_client ().get_run (run_id )
281
-
282
- if "error_code" in run :
283
- return Response (error = DatabricksAPIError (** run ))
284
-
285
- try :
286
- cluster_id , tasks = _get_cluster_id_and_tasks_from_run_tasks (run , exclude_tasks , project_id )
287
- except Exception as e :
288
- return Response (error = DatabricksError (message = str (e )))
289
-
290
327
if not allow_failed_tasks and any (
291
328
task ["state" ].get ("result_state" ) != "SUCCESS" for task in tasks
292
329
):
@@ -297,12 +334,13 @@ def _get_run_information(
297
334
cluster_report_response = _get_cluster_report (
298
335
cluster_id , tasks , plan_type , compute_type , allow_incomplete_cluster_report
299
336
)
337
+
300
338
cluster_report = cluster_report_response .result
301
339
if cluster_report :
302
-
303
340
cluster = cluster_report .cluster
304
341
spark_context_id = _get_run_spark_context_id (tasks )
305
- eventlog_response = _get_eventlog (cluster , spark_context_id .result , run .get ("end_time" ))
342
+ end_time = max (task ["end_time" ] for task in tasks )
343
+ eventlog_response = _get_eventlog (cluster , spark_context_id .result , end_time )
306
344
307
345
eventlog = eventlog_response .result
308
346
if eventlog :
@@ -343,12 +381,16 @@ def get_cluster_report(
343
381
if "error_code" in run :
344
382
return Response (error = DatabricksAPIError (** run ))
345
383
346
- try :
347
- cluster_id , tasks = _get_cluster_id_and_tasks_from_run_tasks (run , exclude_tasks , project_id )
348
- except Exception as e :
349
- return Response (error = DatabricksError (message = str (e )))
384
+ project_cluster_tasks = _get_project_cluster_tasks (run , exclude_tasks , project_id )
385
+ cluster_tasks = project_cluster_tasks .get (project_id )
386
+ if not cluster_tasks :
387
+ return Response (
388
+ error = DatabricksError (f"Failed to locate cluster for project ID { project_id } " )
389
+ )
350
390
351
- return _get_cluster_report (cluster_id , tasks , plan_type , compute_type , allow_incomplete )
391
+ return _get_cluster_report (
392
+ cluster_tasks [0 ], cluster_tasks [1 ], plan_type , compute_type , allow_incomplete
393
+ )
352
394
353
395
354
396
def _get_cluster_report (
def record_run(
    run_id: str,
    plan_type: str,
    compute_type: str,
    project_id: Union[str, None] = None,
    allow_incomplete_cluster_report: bool = False,
    exclude_tasks: Union[Collection[str], None] = None,
) -> Response[List[str]]:
    """See :py:func:`~create_prediction_for_run`

    If a project ID is provided, a prediction is created only for the cluster tagged with it or,
    when the only cluster found for the run is untagged, for that cluster.
    If no project ID is provided, a prediction is created for each cluster tagged with a project ID.

    :param run_id: Databricks run ID
    :type run_id: str
    :param plan_type: either "Standard", "Premium" or "Enterprise"
    :type plan_type: str
    :param compute_type: e.g. "Jobs Compute"
    :type compute_type: str
    :param project_id: Sync project ID
    :type project_id: str, optional, defaults to None
    :param allow_incomplete_cluster_report: Whether creating a prediction with incomplete cluster report data should be allowable
    :type allow_incomplete_cluster_report: bool, optional, defaults to False
    :param exclude_tasks: Keys of tasks (task names) to exclude
    :type exclude_tasks: Collection[str], optional, defaults to None
    :return: IDs of the predictions that were created (one per matching cluster)
    :rtype: Response[List[str]]
    """
    run = get_default_client().get_run(run_id)

    if "error_code" in run:
        return Response(error=DatabricksAPIError(**run))

    project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks)

    if project_id:
        if project_id in project_cluster_tasks:
            filtered_project_cluster_tasks = {project_id: project_cluster_tasks[project_id]}
        elif None in project_cluster_tasks and len(project_cluster_tasks) == 1:
            # If there's only 1 cluster and it's not tagged with a project ID,
            # assume that's the one for the requested project.
            filtered_project_cluster_tasks = {project_id: project_cluster_tasks[None]}
        else:
            filtered_project_cluster_tasks = {}
    else:
        # No project requested: keep every cluster that carries a project tag
        # and ignore the untagged (None) entry.
        filtered_project_cluster_tasks = {
            cluster_project_id: cluster_tasks
            for cluster_project_id, cluster_tasks in project_cluster_tasks.items()
            if cluster_project_id
        }

    if not filtered_project_cluster_tasks:
        return Response(
            error=DatabricksError(
                message=f"No cluster found in run {run_id} for project {project_id}"
            )
        )

    prediction_ids = []
    for cluster_project_id, (cluster_id, tasks) in filtered_project_cluster_tasks.items():
        prediction_response = _create_prediction(
            cluster_id,
            tasks,
            plan_type,
            compute_type,
            cluster_project_id,
            allow_incomplete_cluster_report,
        )

        prediction_id = prediction_response.result
        if prediction_id:
            prediction_ids.append(prediction_id)
        else:
            # A failure for one cluster must not prevent predictions for the others;
            # surface the underlying error so the failure can be diagnosed.
            logger.error(
                f"Failed to create prediction for cluster {cluster_id} in project {cluster_project_id}: "
                f"{prediction_response.error}"
            )

    if prediction_ids:
        return Response(result=prediction_ids)

    return Response(
        error=DatabricksError(message=f"Failed to create any predictions for run {run_id}")
    )
401
496
402
497
403
498
def get_prediction_job (
@@ -1025,53 +1120,36 @@ def _get_job_cluster(tasks: List[dict], job_clusters: list) -> Response[dict]:
1025
1120
return Response (error = DatabricksError (message = "Not all tasks use the same cluster" ))
1026
1121
1027
1122
1028
- def _get_cluster_id_and_tasks_from_run_tasks (
1123
+ def _get_project_cluster_tasks (
1029
1124
run : dict ,
1030
1125
exclude_tasks : Union [Collection [str ], None ] = None ,
1031
- project_id : str = None ,
1032
- ) -> Tuple [ str , List [ dict ]]:
1126
+ ) -> Dict [ str , Tuple [ str , List [ dict ]]]:
1127
+ """Returns a mapping of project IDs to cluster ID-tasks pairs"""
1033
1128
job_clusters = {c ["job_cluster_key" ]: c ["new_cluster" ] for c in run .get ("job_clusters" , [])}
1034
- project_cluster_ids = defaultdict (list )
1035
- all_cluster_tasks = defaultdict (list )
1129
+ all_project_cluster_tasks = defaultdict (lambda : defaultdict (list ))
1036
1130
1037
1131
for task in run ["tasks" ]:
1038
1132
if "cluster_instance" in task and (
1039
1133
not exclude_tasks or task ["task_key" ] not in exclude_tasks
1040
1134
):
1041
1135
cluster_id = task ["cluster_instance" ]["cluster_id" ]
1042
- all_cluster_tasks [cluster_id ].append (task )
1043
1136
1044
1137
task_cluster = task .get ("new_cluster" )
1045
1138
if not task_cluster :
1046
1139
task_cluster = job_clusters .get (task .get ("job_cluster_key" ))
1047
1140
1048
1141
if task_cluster :
1049
1142
cluster_project_id = task_cluster .get ("custom_tags" , {}).get ("sync:project-id" )
1050
- if cluster_project_id :
1051
- project_cluster_ids [cluster_project_id ].append (cluster_id )
1143
+ all_project_cluster_tasks [cluster_project_id ][cluster_id ].append (task )
1052
1144
1053
- project_cluster_tasks = None
1054
- if project_id :
1055
- cluster_ids = project_cluster_ids .get (project_id )
1056
- if cluster_ids :
1057
- project_cluster_tasks = {
1058
- cluster_id : tasks
1059
- for cluster_id , tasks in all_cluster_tasks .items ()
1060
- if cluster_id in cluster_ids
1061
- }
1145
+ filtered_project_cluster_tasks = {}
1146
+ for project_id , cluster_tasks in all_project_cluster_tasks .items ():
1147
+ if len (cluster_tasks ) > 1 :
1148
+ logger .warning (f"More than 1 cluster found for project ID { project_id } " )
1062
1149
else :
1063
- logger .warning (
1064
- "No task clusters found matching the provided project-id - assuming all non-excluded tasks are relevant"
1065
- )
1066
-
1067
- cluster_tasks = project_cluster_tasks or all_cluster_tasks
1068
- num_clusters = len (cluster_tasks )
1069
- if num_clusters == 0 :
1070
- raise RuntimeError ("No cluster found for tasks" )
1071
- elif num_clusters > 1 :
1072
- raise RuntimeError ("More than 1 cluster found for tasks" )
1150
+ filtered_project_cluster_tasks [project_id ] = next (iter (cluster_tasks .items ()))
1073
1151
1074
- return cluster_tasks . popitem ()
1152
+ return filtered_project_cluster_tasks
1075
1153
1076
1154
1077
1155
def _get_run_spark_context_id (tasks : List [dict ]) -> Response [str ]:
0 commit comments