
Commit b8475c0

[PROD-1044] Refactor and CLI update for easier Airflow integration (#32)
1 parent: 60803c3

3 files changed: 122 additions & 56 deletions

sync/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 """Library for leveraging the power of Sync"""
-__version__ = "0.1.1"
+__version__ = "0.1.2"
 
 TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

sync/awsdatabricks.py

Lines changed: 82 additions & 33 deletions
@@ -482,33 +482,19 @@ def get_prediction_job(
     :return: job object with prediction applied to it
     :rtype: Response[dict]
     """
-    prediction_response = get_prediction(prediction_id)
-    prediction = prediction_response.result
-    if prediction:
-        job = get_default_client().get_job(job_id)
-        if "error_code" in job:
-            return Response(error=DatabricksAPIError(**job))
-
-        job_settings = job["settings"]
-        tasks = job_settings.get("tasks", [])
-        if tasks:
-            cluster_response = _get_job_cluster(tasks, job_settings.get("job_clusters", []))
-            cluster = cluster_response.result
-            if cluster:
-                # num_workers/autoscale are mutually exclusive settings, and we are relying on our Prediction
-                # Recommendations to set these appropriately. Since we may recommend a Static cluster (i.e. a cluster
-                # with `num_workers`) for a cluster that was originally autoscaled, we want to make sure to remove this
-                # prior configuration
-                if "num_workers" in cluster:
-                    del cluster["num_workers"]
-
-                if "autoscale" in cluster:
-                    del cluster["autoscale"]
-
-                prediction_cluster = _deep_update(
-                    cluster, prediction["solutions"][preference]["configuration"]
-                )
+    job = get_default_client().get_job(job_id)
+    if "error_code" in job:
+        return Response(error=DatabricksAPIError(**job))
 
+    job_settings = job["settings"]
+    tasks = job_settings.get("tasks", [])
+    if tasks:
+        cluster_response = _get_job_cluster(tasks, job_settings.get("job_clusters", []))
+        cluster = cluster_response.result
+        if cluster:
+            prediction_cluster_response = get_prediction_cluster(cluster, prediction_id, preference)
+            prediction_cluster = prediction_cluster_response.result
+            if prediction_cluster:
                 cluster_key = tasks[0].get("job_cluster_key")
                 if cluster_key:
                     job_settings["job_clusters"] = [
@@ -517,10 +503,51 @@ def get_prediction_job(
                         if j.get("job_cluster_key") != cluster_key
                     ] + [{"job_cluster_key": cluster_key, "new_cluster": prediction_cluster}]
                 else:
+                    # For `new_cluster` definitions, Databricks will automatically assign the newly created cluster a name,
+                    # and will reject any run submissions where the `cluster_name` is pre-populated
+                    if "cluster_name" in prediction_cluster:
+                        del prediction_cluster["cluster_name"]
                     tasks[0]["new_cluster"] = prediction_cluster
                 return Response(result=job)
-        return cluster_response
-    return Response(error=DatabricksError(message="No task found in job"))
+            return prediction_cluster_response
+        return cluster_response
+    return Response(error=DatabricksError(message="No task found in job"))
+
+
+def get_prediction_cluster(
+    cluster: dict, prediction_id: str, preference: str = CONFIG.default_prediction_preference.value
+) -> Response[dict]:
+    """Apply the prediction to the provided cluster.
+
+    The cluster is updated with configuration from the prediction and returned in the result.
+
+    :param cluster: Databricks cluster object
+    :type cluster: dict
+    :param prediction_id: prediction ID
+    :type prediction_id: str
+    :param preference: preferred prediction solution, defaults to local configuration
+    :type preference: str, optional
+    :return: cluster object with prediction applied to it
+    :rtype: Response[dict]
+    """
+    prediction_response = get_prediction(prediction_id)
+    prediction = prediction_response.result
+    if prediction:
+        # num_workers/autoscale are mutually exclusive settings, and we are relying on our Prediction
+        # Recommendations to set these appropriately. Since we may recommend a Static cluster (i.e. a cluster
+        # with `num_workers`) for a cluster that was originally autoscaled, we want to make sure to remove this
+        # prior configuration
+        if "num_workers" in cluster:
+            del cluster["num_workers"]
+
+        if "autoscale" in cluster:
+            del cluster["autoscale"]
+
+        prediction_cluster = _deep_update(
+            cluster, prediction["solutions"][preference]["configuration"]
+        )
+
+        return Response(result=prediction_cluster)
     return prediction_response
 
 
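Pulling this logic out into `get_prediction_cluster` means a caller that already has a cluster definition, e.g. an Airflow task building a one-off run submission, can apply a prediction without looking up a Databricks job first. A minimal usage sketch, with a placeholder cluster definition and a hypothetical prediction ID:

from sync.awsdatabricks import get_prediction_cluster

cluster = {
    "spark_version": "11.3.x-scala2.12",  # placeholder cluster definition
    "node_type_id": "i3.xlarge",
    "autoscale": {"min_workers": 2, "max_workers": 8},
}

# Returns Response[dict]: `result` is the cluster with the recommended configuration
# merged in (any prior num_workers/autoscale removed first); `error` otherwise.
response = get_prediction_cluster(cluster, "<prediction-id>")
if response.result:
    print(response.result)
else:
    print(response.error)
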
@@ -550,10 +577,9 @@ def get_project_job(job_id: str, project_id: str, region_name: str = None) -> Re
         cluster_response = _get_job_cluster(tasks, job_settings.get("job_clusters", []))
         cluster = cluster_response.result
         if cluster:
-            project_settings_response = get_project_cluster_settings(project_id, region_name)
-            project_cluster_settings = project_settings_response.result
-            if project_cluster_settings:
-                project_cluster = _deep_update(cluster, project_cluster_settings)
+            project_cluster_response = get_project_cluster(cluster, project_id, region_name)
+            project_cluster = project_cluster_response.result
+            if project_cluster:
                 cluster_key = tasks[0].get("job_cluster_key")
                 if cluster_key:
                     job_settings["job_clusters"] = [
@@ -565,11 +591,34 @@ def get_project_job(job_id: str, project_id: str, region_name: str = None) -> Re
                     tasks[0]["new_cluster"] = project_cluster
 
                 return Response(result=job)
-            return project_settings_response
+            return project_cluster_response
         return cluster_response
     return Response(error=DatabricksError(message="No task found in job"))
 
 
+def get_project_cluster(cluster: dict, project_id: str, region_name: str = None) -> Response[dict]:
+    """Apply project configuration to a cluster.
+
+    The cluster is updated with tags and a log configuration to facilitate project continuity.
+
+    :param cluster: Databricks cluster object
+    :type cluster: dict
+    :param project_id: Sync project ID
+    :type project_id: str
+    :param region_name: region name, defaults to AWS configuration
+    :type region_name: str, optional
+    :return: cluster object with project configuration applied
+    :rtype: Response[dict]
+    """
+    project_settings_response = get_project_cluster_settings(project_id, region_name)
+    project_cluster_settings = project_settings_response.result
+    if project_cluster_settings:
+        project_cluster = _deep_update(cluster, project_cluster_settings)
+
+        return Response(result=project_cluster)
+    return project_settings_response
+
+
 def get_project_cluster_settings(project_id: str, region_name: str = None) -> Response[dict]:
     """Gets cluster configuration for a project.
 
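`get_project_cluster` is split out the same way, so project tags and log configuration can be applied to a standalone cluster definition. A sketch under the same assumptions (hypothetical project ID; `region_name` falls back to the AWS configuration when omitted):

from sync.awsdatabricks import get_project_cluster

cluster = {"node_type_id": "i3.xlarge", "num_workers": 4}  # placeholder definition

response = get_project_cluster(cluster, "<project-id>")
if response.result:
    cluster = response.result  # updated with the project's tags and log configuration
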
sync/cli/__init__.py

Lines changed: 39 additions & 22 deletions
@@ -33,40 +33,57 @@ def main(debug: bool):
 
 
 @main.command
-def configure():
+@click.option("--api-key-id")
+@click.option("--api-key-secret")
+@click.option("--prediction-preference")
+@click.option("--databricks-host")
+@click.option("--databricks-token")
+@click.option("--databricks-region")
+def configure(
+    api_key_id: str = None,
+    api_key_secret: str = None,
+    prediction_preference: str = None,
+    databricks_host: str = None,
+    databricks_token: str = None,
+    databricks_region: str = None,
+):
     """Configure Sync Library"""
-    api_key_id = click.prompt("Sync API key ID", default=API_KEY.id if API_KEY else None)
-    api_key_secret = click.prompt(
+    api_key_id = api_key_id or click.prompt(
+        "Sync API key ID", default=API_KEY.id if API_KEY else None
+    )
+    api_key_secret = api_key_secret or click.prompt(
         "Sync API key secret",
         default=API_KEY.secret if API_KEY else None,
         hide_input=True,
         show_default=False,
     )
 
-    prediction_preference = click.prompt(
+    prediction_preference = prediction_preference or click.prompt(
         "Default prediction preference",
         type=click.Choice([p.value for p in Preference]),
         default=(CONFIG.default_prediction_preference or Preference.ECONOMY).value,
     )
 
-    dbx_host = OPTIONAL_DEFAULT
-    dbx_token = OPTIONAL_DEFAULT
-    dbx_region = OPTIONAL_DEFAULT
-    if click.confirm("Would you like to configure a Databricks workspace?"):
-        dbx_host = click.prompt(
-            "Databricks host (prefix with https://)",
-            default=DB_CONFIG.host if DB_CONFIG else OPTIONAL_DEFAULT,
-        )
-        dbx_token = click.prompt(
-            "Databricks token",
-            default=DB_CONFIG.token if DB_CONFIG else OPTIONAL_DEFAULT,
-            hide_input=True,
-            show_default=False,
-        )
-        dbx_region = click.prompt(
-            "Databricks AWS region name",
-            default=DB_CONFIG.aws_region_name if DB_CONFIG else OPTIONAL_DEFAULT,
-        )
+    dbx_host = databricks_host or OPTIONAL_DEFAULT
+    dbx_token = databricks_token or OPTIONAL_DEFAULT
+    dbx_region = databricks_region or OPTIONAL_DEFAULT
+    # Prompt only if any value is missing, since all are required to initialize the configuration below
+    if any(param == OPTIONAL_DEFAULT for param in (dbx_host, dbx_token, dbx_region)):
+        if click.confirm("Would you like to configure a Databricks workspace?"):
+            dbx_host = click.prompt(
+                "Databricks host (prefix with https://)",
+                default=DB_CONFIG.host if DB_CONFIG else OPTIONAL_DEFAULT,
+            )
+            dbx_token = click.prompt(
+                "Databricks token",
+                default=DB_CONFIG.token if DB_CONFIG else OPTIONAL_DEFAULT,
+                hide_input=True,
+                show_default=False,
+            )
+            dbx_region = click.prompt(
+                "Databricks AWS region name",
+                default=DB_CONFIG.aws_region_name if DB_CONFIG else OPTIONAL_DEFAULT,
+            )
 
     init(
         APIKey(api_key_id=api_key_id, api_key_secret=api_key_secret),
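With every option supplied on the command line, `configure` now runs without prompting, which is what a non-interactive environment like an Airflow container needs. A minimal sketch using click's test runner; the credential values are placeholders, and `economy` is assumed to be a valid `Preference` value:

from click.testing import CliRunner

from sync.cli import configure

runner = CliRunner()
result = runner.invoke(
    configure,
    [
        "--api-key-id", "<key-id>",
        "--api-key-secret", "<key-secret>",
        "--prediction-preference", "economy",
        "--databricks-host", "https://dbc-example.cloud.databricks.com",
        "--databricks-token", "<token>",
        "--databricks-region", "us-east-1",
    ],
)
assert result.exit_code == 0  # every prompt is skipped because every value was provided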
