@@ -482,33 +482,19 @@ def get_prediction_job(
482
482
:return: job object with prediction applied to it
483
483
:rtype: Response[dict]
484
484
"""
485
- prediction_response = get_prediction (prediction_id )
486
- prediction = prediction_response .result
487
- if prediction :
488
- job = get_default_client ().get_job (job_id )
489
- if "error_code" in job :
490
- return Response (error = DatabricksAPIError (** job ))
491
-
492
- job_settings = job ["settings" ]
493
- tasks = job_settings .get ("tasks" , [])
494
- if tasks :
495
- cluster_response = _get_job_cluster (tasks , job_settings .get ("job_clusters" , []))
496
- cluster = cluster_response .result
497
- if cluster :
498
- # num_workers/autoscale are mutually exclusive settings, and we are relying on our Prediction
499
- # Recommendations to set these appropriately. Since we may recommend a Static cluster (i.e. a cluster
500
- # with `num_workers`) for a cluster that was originally autoscaled, we want to make sure to remove this
501
- # prior configuration
502
- if "num_workers" in cluster :
503
- del cluster ["num_workers" ]
504
-
505
- if "autoscale" in cluster :
506
- del cluster ["autoscale" ]
507
-
508
- prediction_cluster = _deep_update (
509
- cluster , prediction ["solutions" ][preference ]["configuration" ]
510
- )
485
+ job = get_default_client ().get_job (job_id )
486
+ if "error_code" in job :
487
+ return Response (error = DatabricksAPIError (** job ))
511
488
489
+ job_settings = job ["settings" ]
490
+ tasks = job_settings .get ("tasks" , [])
491
+ if tasks :
492
+ cluster_response = _get_job_cluster (tasks , job_settings .get ("job_clusters" , []))
493
+ cluster = cluster_response .result
494
+ if cluster :
495
+ prediction_cluster_response = get_prediction_cluster (cluster , prediction_id , preference )
496
+ prediction_cluster = prediction_cluster_response .result
497
+ if prediction_cluster :
512
498
cluster_key = tasks [0 ].get ("job_cluster_key" )
513
499
if cluster_key :
514
500
job_settings ["job_clusters" ] = [
@@ -517,10 +503,51 @@ def get_prediction_job(
517
503
if j .get ("job_cluster_key" ) != cluster_key
518
504
] + [{"job_cluster_key" : cluster_key , "new_cluster" : prediction_cluster }]
519
505
else :
506
+ # For `new_cluster` definitions, Databricks will automatically assign the newly created cluster a name,
507
+ # and will reject any run submissions where the `cluster_name` is pre-populated
508
+ if "cluster_name" in prediction_cluster :
509
+ del prediction_cluster ["cluster_name" ]
520
510
tasks [0 ]["new_cluster" ] = prediction_cluster
521
511
return Response (result = job )
522
- return cluster_response
523
- return Response (error = DatabricksError (message = "No task found in job" ))
512
+ return prediction_cluster_response
513
+ return cluster_response
514
+ return Response (error = DatabricksError (message = "No task found in job" ))
515
+
516
+
517
+ def get_prediction_cluster (
518
+ cluster : dict , prediction_id : str , preference : str = CONFIG .default_prediction_preference .value
519
+ ) -> Response [dict ]:
520
+ """Apply the prediction to the provided cluster.
521
+
522
+ The cluster is updated with configuration from the prediction and returned in the result.
523
+
524
+ :param cluster: Databricks cluster object
525
+ :type cluster: dict
526
+ :param prediction_id: prediction ID
527
+ :type prediction_id: str
528
+ :param preference: preferred prediction solution, defaults to local configuration
529
+ :type preference: str, optional
530
+ :return: job object with prediction applied to it
531
+ :rtype: Response[dict]
532
+ """
533
+ prediction_response = get_prediction (prediction_id )
534
+ prediction = prediction_response .result
535
+ if prediction :
536
+ # num_workers/autoscale are mutually exclusive settings, and we are relying on our Prediction
537
+ # Recommendations to set these appropriately. Since we may recommend a Static cluster (i.e. a cluster
538
+ # with `num_workers`) for a cluster that was originally autoscaled, we want to make sure to remove this
539
+ # prior configuration
540
+ if "num_workers" in cluster :
541
+ del cluster ["num_workers" ]
542
+
543
+ if "autoscale" in cluster :
544
+ del cluster ["autoscale" ]
545
+
546
+ prediction_cluster = _deep_update (
547
+ cluster , prediction ["solutions" ][preference ]["configuration" ]
548
+ )
549
+
550
+ return Response (result = prediction_cluster )
524
551
return prediction_response
525
552
526
553
@@ -550,10 +577,9 @@ def get_project_job(job_id: str, project_id: str, region_name: str = None) -> Re
550
577
cluster_response = _get_job_cluster (tasks , job_settings .get ("job_clusters" , []))
551
578
cluster = cluster_response .result
552
579
if cluster :
553
- project_settings_response = get_project_cluster_settings (project_id , region_name )
554
- project_cluster_settings = project_settings_response .result
555
- if project_cluster_settings :
556
- project_cluster = _deep_update (cluster , project_cluster_settings )
580
+ project_cluster_response = get_project_cluster (cluster , project_id , region_name )
581
+ project_cluster = project_cluster_response .result
582
+ if project_cluster :
557
583
cluster_key = tasks [0 ].get ("job_cluster_key" )
558
584
if cluster_key :
559
585
job_settings ["job_clusters" ] = [
@@ -565,11 +591,34 @@ def get_project_job(job_id: str, project_id: str, region_name: str = None) -> Re
565
591
tasks [0 ]["new_cluster" ] = project_cluster
566
592
567
593
return Response (result = job )
568
- return project_settings_response
594
+ return project_cluster_response
569
595
return cluster_response
570
596
return Response (error = DatabricksError (message = "No task found in job" ))
571
597
572
598
599
+ def get_project_cluster (cluster : dict , project_id : str , region_name : str = None ) -> Response [dict ]:
600
+ """Apply project configuration to a cluster.
601
+
602
+ The cluster is updated with tags and a log configuration to facilitate project continuity.
603
+
604
+ :param cluster: Databricks cluster object
605
+ :type cluster: dict
606
+ :param project_id: Sync project ID
607
+ :type project_id: str
608
+ :param region_name: region name, defaults to AWS configuration
609
+ :type region_name: str, optional
610
+ :return: project job object
611
+ :rtype: Response[dict]
612
+ """
613
+ project_settings_response = get_project_cluster_settings (project_id , region_name )
614
+ project_cluster_settings = project_settings_response .result
615
+ if project_cluster_settings :
616
+ project_cluster = _deep_update (cluster , project_cluster_settings )
617
+
618
+ return Response (result = project_cluster )
619
+ return project_settings_response
620
+
621
+
573
622
def get_project_cluster_settings (project_id : str , region_name : str = None ) -> Response [dict ]:
574
623
"""Gets cluster configuration for a project.
575
624
0 commit comments