Skip to content

Commit ff5fa60

Browse files
authored
[PROD-1291] Apply prediction to existing job (#49)
1 parent ed81ffe commit ff5fa60

File tree

10 files changed

+295
-26
lines changed

10 files changed

+295
-26
lines changed

sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Library for leveraging the power of Sync"""
2-
__version__ = "0.4.5"
2+
__version__ = "0.4.6"
33

44
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

sync/_databricks.py

Lines changed: 210 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414

1515
import boto3 as boto
1616

17-
from sync.api.predictions import create_prediction_with_eventlog_bytes, get_prediction
17+
from sync.api.predictions import (
18+
create_prediction_with_eventlog_bytes,
19+
get_prediction,
20+
get_predictions,
21+
)
1822
from sync.api.projects import (
1923
create_project_submission_with_eventlog_bytes,
2024
get_project,
@@ -153,22 +157,22 @@ def create_prediction_for_run(
153157

154158
project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks)
155159

156-
cluster_id = None
160+
cluster_tasks = None
157161
if project_id:
158-
if project_id in project_cluster_tasks:
159-
cluster_id, tasks = project_cluster_tasks.get(project_id)
162+
cluster_tasks = project_cluster_tasks.get(project_id)
160163

161-
if not cluster_id:
162-
if None in project_cluster_tasks and len(project_cluster_tasks) == 1:
163-
# If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project
164-
cluster_id, tasks = project_cluster_tasks.get(None)
165-
166-
if not cluster_id:
167-
return Response(
168-
error=DatabricksError(
169-
message=f"No cluster found in run {run_id} for project {project_id}"
164+
if not cluster_tasks:
165+
if len(project_cluster_tasks) == 1:
166+
# If there's only 1 cluster assume that's the one for the project
167+
cluster_tasks = next(iter(project_cluster_tasks.values()))
168+
else:
169+
return Response(
170+
error=DatabricksError(
171+
message=f"No cluster found in run {run_id} for project {project_id}"
172+
)
170173
)
171-
)
174+
175+
cluster_id, tasks = cluster_tasks
172176

173177
return _create_prediction(
174178
cluster_id, tasks, plan_type, compute_type, project_id, allow_incomplete_cluster_report
@@ -294,11 +298,19 @@ def create_submission_for_run(
294298

295299
project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks)
296300

297-
if project_id in project_cluster_tasks:
298-
cluster_id, tasks = project_cluster_tasks.get(project_id)
299-
elif None in project_cluster_tasks and len(project_cluster_tasks) == 1:
300-
# If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project
301-
cluster_id, tasks = project_cluster_tasks.get(None)
301+
cluster_tasks = project_cluster_tasks.get(project_id)
302+
if not cluster_tasks:
303+
if len(project_cluster_tasks) == 1:
304+
# If there's only 1 cluster assume that's the one for the project
305+
cluster_tasks = next(iter(project_cluster_tasks.values()))
306+
else:
307+
return Response(
308+
error=DatabricksError(
309+
message=f"Unable to locate cluster in run {run_id} for project {project_id}"
310+
)
311+
)
312+
313+
cluster_id, tasks = cluster_tasks
302314

303315
run_information_response = _get_run_information(
304316
cluster_id,
@@ -386,6 +398,7 @@ def get_cluster_report(
386398
return Response(error=DatabricksAPIError(**run))
387399

388400
project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks)
401+
389402
cluster_tasks = project_cluster_tasks.get(project_id)
390403
if not cluster_tasks:
391404
return Response(
@@ -455,9 +468,11 @@ def record_run(
455468
if project_id:
456469
if project_id in project_cluster_tasks:
457470
filtered_project_cluster_tasks = {project_id: project_cluster_tasks.get(project_id)}
458-
elif None in project_cluster_tasks and len(project_cluster_tasks) == 1:
459-
# If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project
460-
filtered_project_cluster_tasks = {project_id: project_cluster_tasks.get(None)}
471+
elif len(project_cluster_tasks) == 1:
472+
# If there's only 1 cluster assume that's the one for the project
473+
filtered_project_cluster_tasks = {
474+
project_id: next(iter(project_cluster_tasks.values()))
475+
}
461476
else:
462477
filtered_project_cluster_tasks = {
463478
cluster_project_id: cluster_tasks
@@ -499,6 +514,83 @@ def record_run(
499514
)
500515

501516

517+
def apply_prediction(
518+
job_id: str, project_id: str, prediction_id: str = None, preference: str = None
519+
):
520+
"""Updates jobs with prediction configuration
521+
522+
:param job_id: ID of job to apply prediction to
523+
:type job_id: str
524+
:param project_id: Sync project ID
525+
:type project_id: str
526+
:param prediction_id: Sync prediction ID, defaults to latest in project
527+
:type prediction_id: str, optional
528+
:param preference: Prediction preference, defaults to "recommended" then "economy"
529+
:type preference: str, optional
530+
:return: ID of applied prediction
531+
:rtype: Response[str]
532+
"""
533+
if prediction_id:
534+
prediction_response = get_prediction(prediction_id, preference)
535+
else:
536+
predictions_response = get_predictions(project_id=project_id)
537+
if predictions_response.error:
538+
return predictions_response
539+
prediction_id = predictions_response.result[0]["prediction_id"]
540+
prediction_response = get_prediction(prediction_id, preference)
541+
542+
if prediction_response.error:
543+
return prediction_response
544+
545+
prediction = prediction_response.result
546+
547+
databricks_client = get_default_client()
548+
549+
job = databricks_client.get_job(job_id)
550+
job_clusters = _get_project_job_clusters(job)
551+
552+
project_cluster = job_clusters.get(project_id)
553+
if not project_cluster:
554+
if len(job_clusters) == 1:
555+
project_cluster = next(iter(job_clusters.values()))
556+
else:
557+
return Response(
558+
error=DatabricksError(
559+
message=f"Unable to locate cluster in job {job_id} for project {project_id}"
560+
)
561+
)
562+
563+
project_cluster_path, _ = project_cluster
564+
565+
if preference:
566+
prediction_cluster = prediction["solutions"][preference]["configuration"]
567+
else:
568+
prediction_cluster = prediction["solutions"].get(
569+
"recommended", prediction["solutions"]["economy"]
570+
)["configuration"]
571+
572+
if "cluster_name" in prediction_cluster:
573+
del prediction_cluster["cluster_name"]
574+
575+
if project_cluster_path[0] == "job_clusters":
576+
new_settings = {
577+
"job_clusters": [
578+
{"job_cluster_key": project_cluster_path[1], "new_cluster": prediction_cluster}
579+
]
580+
}
581+
else:
582+
new_settings = {
583+
"tasks": [{"task_key": project_cluster_path[1], "new_cluster": prediction_cluster}]
584+
}
585+
586+
response = databricks_client.update_job(job_id, new_settings)
587+
588+
if "error_code" in response:
589+
return Response(error=DatabricksAPIError(**response))
590+
591+
return Response(result=prediction_id)
592+
593+
502594
def get_prediction_job(
503595
job_id: str, prediction_id: str, preference: str = CONFIG.default_prediction_preference.value
504596
) -> Response[dict]:
@@ -586,6 +678,62 @@ def get_prediction_cluster(
586678
return prediction_response
587679

588680

681+
def apply_project_recommendation(job_id: str, project_id: str, recommendation_id: str):
682+
"""Updates jobs with project recommendation
683+
684+
:param job_id: ID of job to apply prediction to
685+
:type job_id: str
686+
:param project_id: Sync project ID
687+
:type project_id: str
688+
:param recommendation_id: Sync project recommendation ID
689+
:type recommendation_id: str
690+
:return: ID of applied recommendation
691+
:rtype: Response[str]
692+
"""
693+
databricks_client = get_default_client()
694+
695+
job = databricks_client.get_job(job_id)
696+
job_clusters = _get_project_job_clusters(job)
697+
698+
project_cluster = job_clusters.get(project_id)
699+
if not project_cluster:
700+
if len(job_clusters) == 1:
701+
project_cluster = next(iter(job_clusters.values()))
702+
else:
703+
return Response(
704+
error=DatabricksError(
705+
message=f"Unable to locate cluster in job {job_id} for project {project_id}"
706+
)
707+
)
708+
709+
project_cluster_path, project_cluster_def = project_cluster
710+
711+
new_cluster_def_response = get_recommendation_cluster(
712+
project_cluster_def, project_id, recommendation_id
713+
)
714+
if new_cluster_def_response.error:
715+
return new_cluster_def_response
716+
new_cluster_def = new_cluster_def_response.result
717+
718+
if project_cluster_path[0] == "job_clusters":
719+
new_settings = {
720+
"job_clusters": [
721+
{"job_cluster_key": project_cluster_path[1], "new_cluster": new_cluster_def}
722+
]
723+
}
724+
else:
725+
new_settings = {
726+
"tasks": [{"task_key": project_cluster_path[1], "new_cluster": new_cluster_def}]
727+
}
728+
729+
response = databricks_client.update_job(job_id, new_settings)
730+
731+
if "error_code" in response:
732+
return Response(error=DatabricksAPIError(**response))
733+
734+
return Response(result=recommendation_id)
735+
736+
589737
def get_recommendation_job(job_id: str, project_id: str, recommendation_id: str) -> Response[dict]:
590738
"""Apply the recommendation to the specified job.
591739
@@ -1222,6 +1370,46 @@ def _get_job_cluster(tasks: List[dict], job_clusters: list) -> Response[dict]:
12221370
return Response(error=DatabricksError(message="Not all tasks use the same cluster"))
12231371

12241372

1373+
def _get_project_job_clusters(
1374+
job: dict,
1375+
exclude_tasks: Union[Collection[str], None] = None,
1376+
) -> Dict[str, Tuple[Tuple[str], dict]]:
1377+
"""Returns a mapping of project IDs to cluster paths and clusters.
1378+
1379+
Cluster paths are tuples that can be used to locate clusters in a job object, e.g.
1380+
1381+
("tasks", <task_key>) or ("job_clusters", <job_cluster_key>)
1382+
1383+
Items for project IDs with more than 1 associated cluster are omitted"""
1384+
job_clusters = {
1385+
c["job_cluster_key"]: c["new_cluster"] for c in job["settings"].get("job_clusters", [])
1386+
}
1387+
all_project_clusters = defaultdict(list)
1388+
1389+
for task in job["settings"]["tasks"]:
1390+
if not exclude_tasks or task["task_key"] not in exclude_tasks:
1391+
task_cluster = task.get("new_cluster")
1392+
if task_cluster:
1393+
task_cluster_path = ("tasks", task["task_key"])
1394+
1395+
if not task_cluster:
1396+
task_cluster = job_clusters.get(task.get("job_cluster_key"))
1397+
task_cluster_path = ("job_clusters", task.get("job_cluster_key"))
1398+
1399+
if task_cluster:
1400+
cluster_project_id = task_cluster.get("custom_tags", {}).get("sync:project-id")
1401+
all_project_clusters[cluster_project_id].append((task_cluster_path, task_cluster))
1402+
1403+
filtered_project_clusters = {}
1404+
for project_id, clusters in all_project_clusters.items():
1405+
if len(clusters) > 1:
1406+
logger.warning(f"More than 1 cluster found for project ID {project_id}")
1407+
else:
1408+
filtered_project_clusters[project_id] = clusters[0]
1409+
1410+
return filtered_project_clusters
1411+
1412+
12251413
def _get_project_cluster_tasks(
12261414
run: dict,
12271415
exclude_tasks: Union[Collection[str], None] = None,

sync/api/projects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
SubmissionError,
2020
)
2121

22-
logger = logging.getLogger()
22+
logger = logging.getLogger(__name__)
2323

2424

2525
def get_prediction(project_id: str, preference: Preference = None) -> Response[dict]:

sync/awsdatabricks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
_get_all_cluster_events,
1515
_get_cluster_instances_from_dbfs,
1616
_wait_for_cluster_termination,
17+
apply_prediction,
18+
apply_project_recommendation,
1719
create_and_record_run,
1820
create_and_wait_for_run,
1921
create_cluster,
@@ -83,6 +85,8 @@
8385
"wait_for_run_and_cluster",
8486
"terminate_cluster",
8587
"event_log_poll_duration_seconds",
88+
"apply_prediction",
89+
"apply_project_recommendation",
8690
]
8791

8892

sync/azuredatabricks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
_get_all_cluster_events,
1919
_get_cluster_instances_from_dbfs,
2020
_wait_for_cluster_termination,
21+
apply_prediction,
22+
apply_project_recommendation,
2123
create_and_record_run,
2224
create_and_wait_for_run,
2325
create_cluster,
@@ -87,6 +89,8 @@
8789
"wait_for_run_and_cluster",
8890
"terminate_cluster",
8991
"event_log_poll_duration_seconds",
92+
"apply_prediction",
93+
"apply_project_recommendation",
9094
]
9195

9296

0 commit comments

Comments (0)