Skip to content

Commit 4c9b858

Browse files
authored
[CUS-7] Allow creating predictions for each cluster in a job run (#44)
1 parent 91ca9af commit 4c9b858

File tree

2 files changed

+140
-62
lines changed

2 files changed

+140
-62
lines changed

sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Library for leveraging the power of Sync"""
2-
__version__ = "0.4.2"
2+
__version__ = "0.4.3"
33

44
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

sync/_databricks.py

Lines changed: 139 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,49 @@ def create_prediction_for_run(
142142
:return: prediction ID
143143
:rtype: Response[str]
144144
"""
145+
run = get_default_client().get_run(run_id)
146+
147+
if "error_code" in run:
148+
return Response(error=DatabricksAPIError(**run))
149+
150+
project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks)
151+
152+
cluster_id = None
153+
if project_id:
154+
if project_id in project_cluster_tasks:
155+
cluster_id, tasks = project_cluster_tasks.get(project_id)
156+
157+
if not cluster_id:
158+
if None in project_cluster_tasks and len(project_cluster_tasks) == 1:
159+
# If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project
160+
cluster_id, tasks = project_cluster_tasks.get(None)
161+
162+
if not cluster_id:
163+
return Response(
164+
error=DatabricksError(
165+
message=f"No cluster found in run {run_id} for project {project_id}"
166+
)
167+
)
168+
169+
return _create_prediction(
170+
cluster_id, tasks, plan_type, compute_type, project_id, allow_incomplete_cluster_report
171+
)
172+
173+
174+
def _create_prediction(
175+
cluster_id: str,
176+
tasks: List[dict],
177+
plan_type: str,
178+
compute_type: str,
179+
project_id: str = None,
180+
allow_incomplete_cluster_report: bool = False,
181+
):
145182
run_information_response = _get_run_information(
146-
run_id,
183+
cluster_id,
184+
tasks,
147185
plan_type,
148186
compute_type,
149-
project_id=project_id,
150187
allow_incomplete_cluster_report=allow_incomplete_cluster_report,
151-
exclude_tasks=exclude_tasks,
152188
)
153189

154190
if run_information_response.error:
@@ -247,14 +283,26 @@ def create_submission_for_run(
247283
:return: submission ID
248284
:rtype: Response[str]
249285
"""
286+
run = get_default_client().get_run(run_id)
287+
288+
if "error_code" in run:
289+
return Response(error=DatabricksAPIError(**run))
290+
291+
project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks)
292+
293+
if project_id in project_cluster_tasks:
294+
cluster_id, tasks = project_cluster_tasks.get(project_id)
295+
elif None in project_cluster_tasks and len(project_cluster_tasks) == 1:
296+
# If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project
297+
cluster_id, tasks = project_cluster_tasks.get(None)
298+
250299
run_information_response = _get_run_information(
251-
run_id,
300+
cluster_id,
301+
tasks,
252302
plan_type,
253303
compute_type,
254-
project_id=project_id,
255304
allow_failed_tasks=True,
256305
allow_incomplete_cluster_report=allow_incomplete_cluster_report,
257-
exclude_tasks=exclude_tasks,
258306
)
259307

260308
if run_information_response.error:
@@ -269,24 +317,13 @@ def create_submission_for_run(
269317

270318

271319
def _get_run_information(
272-
run_id: str,
320+
cluster_id: str,
321+
tasks: List[dict],
273322
plan_type: str,
274323
compute_type: str,
275-
project_id: str = None,
276324
allow_failed_tasks: bool = False,
277325
allow_incomplete_cluster_report: bool = False,
278-
exclude_tasks: Union[Collection[str], None] = None,
279326
) -> Response[Tuple[DatabricksClusterReport, bytes]]:
280-
run = get_default_client().get_run(run_id)
281-
282-
if "error_code" in run:
283-
return Response(error=DatabricksAPIError(**run))
284-
285-
try:
286-
cluster_id, tasks = _get_cluster_id_and_tasks_from_run_tasks(run, exclude_tasks, project_id)
287-
except Exception as e:
288-
return Response(error=DatabricksError(message=str(e)))
289-
290327
if not allow_failed_tasks and any(
291328
task["state"].get("result_state") != "SUCCESS" for task in tasks
292329
):
@@ -297,12 +334,13 @@ def _get_run_information(
297334
cluster_report_response = _get_cluster_report(
298335
cluster_id, tasks, plan_type, compute_type, allow_incomplete_cluster_report
299336
)
337+
300338
cluster_report = cluster_report_response.result
301339
if cluster_report:
302-
303340
cluster = cluster_report.cluster
304341
spark_context_id = _get_run_spark_context_id(tasks)
305-
eventlog_response = _get_eventlog(cluster, spark_context_id.result, run.get("end_time"))
342+
end_time = max(task["end_time"] for task in tasks)
343+
eventlog_response = _get_eventlog(cluster, spark_context_id.result, end_time)
306344

307345
eventlog = eventlog_response.result
308346
if eventlog:
@@ -343,12 +381,16 @@ def get_cluster_report(
343381
if "error_code" in run:
344382
return Response(error=DatabricksAPIError(**run))
345383

346-
try:
347-
cluster_id, tasks = _get_cluster_id_and_tasks_from_run_tasks(run, exclude_tasks, project_id)
348-
except Exception as e:
349-
return Response(error=DatabricksError(message=str(e)))
384+
project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks, project_id)
385+
cluster_tasks = project_cluster_tasks.get(project_id)
386+
if not cluster_tasks:
387+
return Response(
388+
error=DatabricksError(f"Failed to locate cluster for project ID {project_id}")
389+
)
350390

351-
return _get_cluster_report(cluster_id, tasks, plan_type, compute_type, allow_incomplete)
391+
return _get_cluster_report(
392+
cluster_tasks[0], cluster_tasks[1], plan_type, compute_type, allow_incomplete
393+
)
352394

353395

354396
def _get_cluster_report(
@@ -374,30 +416,83 @@ def record_run(
374416
run_id: str,
375417
plan_type: str,
376418
compute_type: str,
377-
project_id: str,
419+
project_id: Union[str, None] = None,
378420
allow_incomplete_cluster_report: bool = False,
379421
exclude_tasks: Union[Collection[str], None] = None,
380-
) -> Response[str]:
422+
) -> Response[List[str]]:
381423
"""See :py:func:`~create_prediction_for_run`
382424
425+
If a project ID is provided, only create a prediction for the cluster tagged with it, or for the run's only cluster if that single cluster is untagged.
426+
If no project ID is provided then create a prediction for each cluster tagged with a project ID.
427+
383428
:param run_id: Databricks run ID
384429
:type run_id: str
385430
:param plan_type: either "Standard", "Premium" or "Enterprise"
386431
:type plan_type: str
387432
:param compute_type: e.g. "Jobs Compute"
388433
:type compute_type: str
389-
:param project_id: Sync project ID, defaults to None
390-
:type project_id: str
434+
:param project_id: Sync project ID
435+
:type project_id: str, optional, defaults to None
391436
:param allow_incomplete_cluster_report: Whether creating a prediction with incomplete cluster report data should be allowable
392437
:type allow_incomplete_cluster_report: bool, optional, defaults to False
393438
:param exclude_tasks: Keys of tasks (task names) to exclude
394439
:type exclude_tasks: Collection[str], optional, defaults to None
395440
:return: prediction IDs, one per cluster for which a prediction was created
396441
:rtype: Response[List[str]]
397442
"""
398-
return create_prediction_for_run(
399-
run_id, plan_type, compute_type, project_id, allow_incomplete_cluster_report, exclude_tasks
400-
)
443+
run = get_default_client().get_run(run_id)
444+
445+
if "error_code" in run:
446+
return Response(error=DatabricksAPIError(**run))
447+
448+
project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks)
449+
450+
filtered_project_cluster_tasks = {}
451+
if project_id:
452+
if project_id in project_cluster_tasks:
453+
filtered_project_cluster_tasks = {project_id: project_cluster_tasks.get(project_id)}
454+
elif None in project_cluster_tasks and len(project_cluster_tasks) == 1:
455+
# If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project
456+
filtered_project_cluster_tasks = {project_id: project_cluster_tasks.get(None)}
457+
else:
458+
filtered_project_cluster_tasks = {
459+
cluster_project_id: cluster_tasks
460+
for cluster_project_id, cluster_tasks in project_cluster_tasks.items()
461+
if cluster_project_id
462+
}
463+
464+
if not filtered_project_cluster_tasks:
465+
return Response(
466+
error=DatabricksError(
467+
message=f"No cluster found in run {run_id} for project {project_id}"
468+
)
469+
)
470+
471+
prediction_ids = []
472+
for cluster_project_id, (cluster_id, tasks) in filtered_project_cluster_tasks.items():
473+
prediction_response = _create_prediction(
474+
cluster_id,
475+
tasks,
476+
plan_type,
477+
compute_type,
478+
cluster_project_id,
479+
allow_incomplete_cluster_report,
480+
)
481+
482+
prediction_id = prediction_response.result
483+
if prediction_id:
484+
prediction_ids.append(prediction_id)
485+
else:
486+
logger.error(
487+
f"Failed to create prediction for cluster {cluster_id} in project {cluster_project_id}"
488+
)
489+
490+
if prediction_ids:
491+
return Response(result=prediction_ids)
492+
else:
493+
return Response(
494+
error=DatabricksError(message=f"Failed to create any predictions for run {run_id}")
495+
)
401496

402497

403498
def get_prediction_job(
@@ -1025,53 +1120,36 @@ def _get_job_cluster(tasks: List[dict], job_clusters: list) -> Response[dict]:
10251120
return Response(error=DatabricksError(message="Not all tasks use the same cluster"))
10261121

10271122

1028-
def _get_cluster_id_and_tasks_from_run_tasks(
1123+
def _get_project_cluster_tasks(
10291124
run: dict,
10301125
exclude_tasks: Union[Collection[str], None] = None,
1031-
project_id: str = None,
1032-
) -> Tuple[str, List[dict]]:
1126+
) -> Dict[str, Tuple[str, List[dict]]]:
1127+
"""Returns a mapping of project IDs to cluster ID-tasks pairs"""
10331128
job_clusters = {c["job_cluster_key"]: c["new_cluster"] for c in run.get("job_clusters", [])}
1034-
project_cluster_ids = defaultdict(list)
1035-
all_cluster_tasks = defaultdict(list)
1129+
all_project_cluster_tasks = defaultdict(lambda: defaultdict(list))
10361130

10371131
for task in run["tasks"]:
10381132
if "cluster_instance" in task and (
10391133
not exclude_tasks or task["task_key"] not in exclude_tasks
10401134
):
10411135
cluster_id = task["cluster_instance"]["cluster_id"]
1042-
all_cluster_tasks[cluster_id].append(task)
10431136

10441137
task_cluster = task.get("new_cluster")
10451138
if not task_cluster:
10461139
task_cluster = job_clusters.get(task.get("job_cluster_key"))
10471140

10481141
if task_cluster:
10491142
cluster_project_id = task_cluster.get("custom_tags", {}).get("sync:project-id")
1050-
if cluster_project_id:
1051-
project_cluster_ids[cluster_project_id].append(cluster_id)
1143+
all_project_cluster_tasks[cluster_project_id][cluster_id].append(task)
10521144

1053-
project_cluster_tasks = None
1054-
if project_id:
1055-
cluster_ids = project_cluster_ids.get(project_id)
1056-
if cluster_ids:
1057-
project_cluster_tasks = {
1058-
cluster_id: tasks
1059-
for cluster_id, tasks in all_cluster_tasks.items()
1060-
if cluster_id in cluster_ids
1061-
}
1145+
filtered_project_cluster_tasks = {}
1146+
for project_id, cluster_tasks in all_project_cluster_tasks.items():
1147+
if len(cluster_tasks) > 1:
1148+
logger.warning(f"More than 1 cluster found for project ID {project_id}")
10621149
else:
1063-
logger.warning(
1064-
"No task clusters found matching the provided project-id - assuming all non-excluded tasks are relevant"
1065-
)
1066-
1067-
cluster_tasks = project_cluster_tasks or all_cluster_tasks
1068-
num_clusters = len(cluster_tasks)
1069-
if num_clusters == 0:
1070-
raise RuntimeError("No cluster found for tasks")
1071-
elif num_clusters > 1:
1072-
raise RuntimeError("More than 1 cluster found for tasks")
1150+
filtered_project_cluster_tasks[project_id] = next(iter(cluster_tasks.items()))
10731151

1074-
return cluster_tasks.popitem()
1152+
return filtered_project_cluster_tasks
10751153

10761154

10771155
def _get_run_spark_context_id(tasks: List[dict]) -> Response[str]:

0 commit comments

Comments
 (0)