|
14 | 14 |
|
15 | 15 | import boto3 as boto
|
16 | 16 |
|
17 |
| -from sync.api.predictions import create_prediction_with_eventlog_bytes, get_prediction |
| 17 | +from sync.api.predictions import ( |
| 18 | + create_prediction_with_eventlog_bytes, |
| 19 | + get_prediction, |
| 20 | + get_predictions, |
| 21 | +) |
18 | 22 | from sync.api.projects import (
|
19 | 23 | create_project_submission_with_eventlog_bytes,
|
20 | 24 | get_project,
|
@@ -153,22 +157,22 @@ def create_prediction_for_run(
|
153 | 157 |
|
154 | 158 | project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks)
|
155 | 159 |
|
156 |
| - cluster_id = None |
| 160 | + cluster_tasks = None |
157 | 161 | if project_id:
|
158 |
| - if project_id in project_cluster_tasks: |
159 |
| - cluster_id, tasks = project_cluster_tasks.get(project_id) |
| 162 | + cluster_tasks = project_cluster_tasks.get(project_id) |
160 | 163 |
|
161 |
| - if not cluster_id: |
162 |
| - if None in project_cluster_tasks and len(project_cluster_tasks) == 1: |
163 |
| - # If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project |
164 |
| - cluster_id, tasks = project_cluster_tasks.get(None) |
165 |
| - |
166 |
| - if not cluster_id: |
167 |
| - return Response( |
168 |
| - error=DatabricksError( |
169 |
| - message=f"No cluster found in run {run_id} for project {project_id}" |
| 164 | + if not cluster_tasks: |
| 165 | + if len(project_cluster_tasks) == 1: |
| 166 | + # If there's only 1 cluster assume that's the one for the project |
| 167 | + cluster_tasks = next(iter(project_cluster_tasks.values())) |
| 168 | + else: |
| 169 | + return Response( |
| 170 | + error=DatabricksError( |
| 171 | + message=f"No cluster found in run {run_id} for project {project_id}" |
| 172 | + ) |
170 | 173 | )
|
171 |
| - ) |
| 174 | + |
| 175 | + cluster_id, tasks = cluster_tasks |
172 | 176 |
|
173 | 177 | return _create_prediction(
|
174 | 178 | cluster_id, tasks, plan_type, compute_type, project_id, allow_incomplete_cluster_report
|
@@ -294,11 +298,19 @@ def create_submission_for_run(
|
294 | 298 |
|
295 | 299 | project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks)
|
296 | 300 |
|
297 |
| - if project_id in project_cluster_tasks: |
298 |
| - cluster_id, tasks = project_cluster_tasks.get(project_id) |
299 |
| - elif None in project_cluster_tasks and len(project_cluster_tasks) == 1: |
300 |
| - # If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project |
301 |
| - cluster_id, tasks = project_cluster_tasks.get(None) |
| 301 | + cluster_tasks = project_cluster_tasks.get(project_id) |
| 302 | + if not cluster_tasks: |
| 303 | + if len(project_cluster_tasks) == 1: |
| 304 | + # If there's only 1 cluster assume that's the one for the project |
| 305 | + cluster_tasks = next(iter(project_cluster_tasks.values())) |
| 306 | + else: |
| 307 | + return Response( |
| 308 | + error=DatabricksError( |
| 309 | + message=f"Unable to locate cluster in run {run_id} for project {project_id}" |
| 310 | + ) |
| 311 | + ) |
| 312 | + |
| 313 | + cluster_id, tasks = cluster_tasks |
302 | 314 |
|
303 | 315 | run_information_response = _get_run_information(
|
304 | 316 | cluster_id,
|
@@ -386,6 +398,7 @@ def get_cluster_report(
|
386 | 398 | return Response(error=DatabricksAPIError(**run))
|
387 | 399 |
|
388 | 400 | project_cluster_tasks = _get_project_cluster_tasks(run, exclude_tasks)
|
| 401 | + |
389 | 402 | cluster_tasks = project_cluster_tasks.get(project_id)
|
390 | 403 | if not cluster_tasks:
|
391 | 404 | return Response(
|
@@ -455,9 +468,11 @@ def record_run(
|
455 | 468 | if project_id:
|
456 | 469 | if project_id in project_cluster_tasks:
|
457 | 470 | filtered_project_cluster_tasks = {project_id: project_cluster_tasks.get(project_id)}
|
458 |
| - elif None in project_cluster_tasks and len(project_cluster_tasks) == 1: |
459 |
| - # If there's only 1 cluster and it's not tagged with project ID assume that's the one for the project |
460 |
| - filtered_project_cluster_tasks = {project_id: project_cluster_tasks.get(None)} |
| 471 | + elif len(project_cluster_tasks) == 1: |
| 472 | + # If there's only 1 cluster assume that's the one for the project |
| 473 | + filtered_project_cluster_tasks = { |
| 474 | + project_id: next(iter(project_cluster_tasks.values())) |
| 475 | + } |
461 | 476 | else:
|
462 | 477 | filtered_project_cluster_tasks = {
|
463 | 478 | cluster_project_id: cluster_tasks
|
@@ -499,6 +514,83 @@ def record_run(
|
499 | 514 | )
|
500 | 515 |
|
501 | 516 |
|
def apply_prediction(
    job_id: str, project_id: str, prediction_id: str = None, preference: str = None
) -> Response[str]:
    """Updates a job with the prediction's cluster configuration.

    :param job_id: ID of job to apply prediction to
    :type job_id: str
    :param project_id: Sync project ID
    :type project_id: str
    :param prediction_id: Sync prediction ID, defaults to latest in project
    :type prediction_id: str, optional
    :param preference: Prediction preference, defaults to "recommended" then "economy"
    :type preference: str, optional
    :return: ID of applied prediction
    :rtype: Response[str]
    """
    if not prediction_id:
        # No explicit prediction given: fall back to the most recent one in the project.
        predictions_response = get_predictions(project_id=project_id)
        if predictions_response.error:
            return predictions_response
        if not predictions_response.result:
            # Guard: indexing an empty result list would raise IndexError.
            return Response(
                error=DatabricksError(
                    message=f"No predictions found for project {project_id}"
                )
            )
        prediction_id = predictions_response.result[0]["prediction_id"]

    prediction_response = get_prediction(prediction_id, preference)
    if prediction_response.error:
        return prediction_response
    prediction = prediction_response.result

    databricks_client = get_default_client()

    job = databricks_client.get_job(job_id)
    job_clusters = _get_project_job_clusters(job)

    project_cluster = job_clusters.get(project_id)
    if not project_cluster:
        if len(job_clusters) == 1:
            # Only one cluster in the job: assume it's the one for the project.
            project_cluster = next(iter(job_clusters.values()))
        else:
            return Response(
                error=DatabricksError(
                    message=f"Unable to locate cluster in job {job_id} for project {project_id}"
                )
            )

    project_cluster_path, _ = project_cluster

    # Select the solution configuration: the explicit preference if given,
    # otherwise "recommended" with "economy" as the fallback.
    if preference:
        prediction_cluster = prediction["solutions"][preference]["configuration"]
    else:
        prediction_cluster = prediction["solutions"].get(
            "recommended", prediction["solutions"]["economy"]
        )["configuration"]

    # NOTE(review): presumably the Jobs API rejects a cluster definition that
    # carries a cluster_name — confirm against the Databricks Jobs API docs.
    if "cluster_name" in prediction_cluster:
        del prediction_cluster["cluster_name"]

    # Target either the shared job cluster or the task's embedded cluster,
    # depending on where _get_project_job_clusters located it.
    if project_cluster_path[0] == "job_clusters":
        new_settings = {
            "job_clusters": [
                {"job_cluster_key": project_cluster_path[1], "new_cluster": prediction_cluster}
            ]
        }
    else:
        new_settings = {
            "tasks": [{"task_key": project_cluster_path[1], "new_cluster": prediction_cluster}]
        }

    response = databricks_client.update_job(job_id, new_settings)

    if "error_code" in response:
        return Response(error=DatabricksAPIError(**response))

    return Response(result=prediction_id)
| 592 | + |
| 593 | + |
502 | 594 | def get_prediction_job(
|
503 | 595 | job_id: str, prediction_id: str, preference: str = CONFIG.default_prediction_preference.value
|
504 | 596 | ) -> Response[dict]:
|
@@ -586,6 +678,62 @@ def get_prediction_cluster(
|
586 | 678 | return prediction_response
|
587 | 679 |
|
588 | 680 |
|
def apply_project_recommendation(job_id: str, project_id: str, recommendation_id: str):
    """Updates a job with the project recommendation's cluster configuration.

    :param job_id: ID of job to apply the recommendation to
    :type job_id: str
    :param project_id: Sync project ID
    :type project_id: str
    :param recommendation_id: Sync project recommendation ID
    :type recommendation_id: str
    :return: ID of applied recommendation
    :rtype: Response[str]
    """
    client = get_default_client()

    clusters_by_project = _get_project_job_clusters(client.get_job(job_id))

    cluster_entry = clusters_by_project.get(project_id)
    if not cluster_entry:
        if len(clusters_by_project) != 1:
            return Response(
                error=DatabricksError(
                    message=f"Unable to locate cluster in job {job_id} for project {project_id}"
                )
            )
        # Exactly one cluster in the job — assume it belongs to the project.
        (cluster_entry,) = clusters_by_project.values()

    cluster_path, cluster_def = cluster_entry

    recommended_cluster_response = get_recommendation_cluster(
        cluster_def, project_id, recommendation_id
    )
    if recommended_cluster_response.error:
        return recommended_cluster_response
    recommended_cluster = recommended_cluster_response.result

    # Build the partial settings update targeting wherever the cluster lives:
    # a shared job cluster or a task's embedded cluster.
    section, key = cluster_path
    if section == "job_clusters":
        new_settings = {
            "job_clusters": [{"job_cluster_key": key, "new_cluster": recommended_cluster}]
        }
    else:
        new_settings = {
            "tasks": [{"task_key": key, "new_cluster": recommended_cluster}]
        }

    update_response = client.update_job(job_id, new_settings)

    if "error_code" in update_response:
        return Response(error=DatabricksAPIError(**update_response))

    return Response(result=recommendation_id)
| 735 | + |
| 736 | + |
589 | 737 | def get_recommendation_job(job_id: str, project_id: str, recommendation_id: str) -> Response[dict]:
|
590 | 738 | """Apply the recommendation to the specified job.
|
591 | 739 |
|
@@ -1222,6 +1370,46 @@ def _get_job_cluster(tasks: List[dict], job_clusters: list) -> Response[dict]:
|
1222 | 1370 | return Response(error=DatabricksError(message="Not all tasks use the same cluster"))
|
1223 | 1371 |
|
1224 | 1372 |
|
| 1373 | +def _get_project_job_clusters( |
| 1374 | + job: dict, |
| 1375 | + exclude_tasks: Union[Collection[str], None] = None, |
| 1376 | +) -> Dict[str, Tuple[Tuple[str], dict]]: |
| 1377 | + """Returns a mapping of project IDs to cluster paths and clusters. |
| 1378 | +
|
| 1379 | + Cluster paths are tuples that can be used to locate clusters in a job object, e.g. |
| 1380 | +
|
| 1381 | + ("tasks", <task_key>) or ("job_clusters", <job_cluster_key>) |
| 1382 | +
|
| 1383 | + Items for project IDs with more than 1 associated cluster are omitted""" |
| 1384 | + job_clusters = { |
| 1385 | + c["job_cluster_key"]: c["new_cluster"] for c in job["settings"].get("job_clusters", []) |
| 1386 | + } |
| 1387 | + all_project_clusters = defaultdict(list) |
| 1388 | + |
| 1389 | + for task in job["settings"]["tasks"]: |
| 1390 | + if not exclude_tasks or task["task_key"] not in exclude_tasks: |
| 1391 | + task_cluster = task.get("new_cluster") |
| 1392 | + if task_cluster: |
| 1393 | + task_cluster_path = ("tasks", task["task_key"]) |
| 1394 | + |
| 1395 | + if not task_cluster: |
| 1396 | + task_cluster = job_clusters.get(task.get("job_cluster_key")) |
| 1397 | + task_cluster_path = ("job_clusters", task.get("job_cluster_key")) |
| 1398 | + |
| 1399 | + if task_cluster: |
| 1400 | + cluster_project_id = task_cluster.get("custom_tags", {}).get("sync:project-id") |
| 1401 | + all_project_clusters[cluster_project_id].append((task_cluster_path, task_cluster)) |
| 1402 | + |
| 1403 | + filtered_project_clusters = {} |
| 1404 | + for project_id, clusters in all_project_clusters.items(): |
| 1405 | + if len(clusters) > 1: |
| 1406 | + logger.warning(f"More than 1 cluster found for project ID {project_id}") |
| 1407 | + else: |
| 1408 | + filtered_project_clusters[project_id] = clusters[0] |
| 1409 | + |
| 1410 | + return filtered_project_clusters |
| 1411 | + |
| 1412 | + |
1225 | 1413 | def _get_project_cluster_tasks(
|
1226 | 1414 | run: dict,
|
1227 | 1415 | exclude_tasks: Union[Collection[str], None] = None,
|
|
0 commit comments