Skip to content

Commit 4cad8b3

Browse files
Merge pull request #18 from synccomputingcode/graham/dbx-error-message
PROD-958 Update handling of incomplete cluster_report data for Databricks jobs
2 parents 2adc605 + 7faac13 commit 4cad8b3

File tree

8 files changed

+140
-39
lines changed

8 files changed

+140
-39
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"httpx~=0.23.0",
3232
"orjson~=3.8.0",
3333
"click~=8.1.0",
34+
"tenacity==8.2.2"
3435
]
3536
dynamic = ["version", "description"]
3637

sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Library for leveraging the power of Sync"""
2-
__version__ = "0.0.7"
2+
__version__ = "0.0.8"
33

44
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

sync/awsdatabricks.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,11 @@ def get_cluster(cluster_id: str) -> Response[dict]:
112112

113113
# TODO - Databricks configuration documentation
114114
def create_prediction_for_run(
115-
run_id: str, plan_type: str, compute_type: str, project_id: str = None
115+
run_id: str,
116+
plan_type: str,
117+
compute_type: str,
118+
project_id: str = None,
119+
allow_incomplete_cluster_report: bool = False,
116120
) -> Response[str]:
117121
"""Create a prediction for the specified Databricks run.
118122
@@ -124,6 +128,8 @@ def create_prediction_for_run(
124128
:type compute_type: str
125129
:param project_id: Sync project ID, defaults to None
126130
:type project_id: str, optional
131+
:param allow_incomplete_cluster_report: Whether to allow creating a prediction when the cluster report data is incomplete
132+
:type allow_incomplete_cluster_report: bool, optional, defaults to False
127133
:return: prediction ID
128134
:rtype: Response[str]
129135
"""
@@ -138,7 +144,9 @@ def create_prediction_for_run(
138144
if cluster_id := cluster_id_response.result:
139145
# Making these calls prior to fetching the event log allows Databricks a little extra time to finish
140146
# uploading all the event log data before we start checking for it
141-
cluster_report_response = _get_cluster_report(cluster_id, plan_type, compute_type)
147+
cluster_report_response = _get_cluster_report(
148+
cluster_id, plan_type, compute_type, allow_incomplete_cluster_report
149+
)
142150
if cluster_report := cluster_report_response.result:
143151

144152
cluster = cluster_report.cluster
@@ -161,7 +169,7 @@ def create_prediction_for_run(
161169

162170

163171
def get_cluster_report(
164-
run_id: str, plan_type: str, compute_type: str
172+
run_id: str, plan_type: str, compute_type: str, allow_incomplete: bool = False
165173
) -> Response[DatabricksClusterReport]:
166174
"""Fetches the cluster information required to create a Sync prediction
167175
@@ -171,6 +179,8 @@ def get_cluster_report(
171179
:type plan_type: str
172180
:param compute_type: Cluster compute type, e.g. "Jobs Compute"
173181
:type compute_type: str
182+
:param allow_incomplete: Whether to allow creating a cluster report when some of its data is missing
183+
:type allow_incomplete: bool, optional, defaults to False
174184
:return: cluster report
175185
:rtype: Response[DatabricksClusterReport]
176186
"""
@@ -183,13 +193,13 @@ def get_cluster_report(
183193

184194
cluster_id_response = _get_run_cluster_id(run["tasks"])
185195
if cluster_id := cluster_id_response.result:
186-
return _get_cluster_report(cluster_id, plan_type, compute_type)
196+
return _get_cluster_report(cluster_id, plan_type, compute_type, allow_incomplete)
187197

188198
return cluster_id_response
189199

190200

191201
def _get_cluster_report(
192-
cluster_id: str, plan_type: str, compute_type: str
202+
cluster_id: str, plan_type: str, compute_type: str, allow_incomplete: bool
193203
) -> Response[DatabricksClusterReport]:
194204
cluster = get_default_client().get_cluster(cluster_id)
195205
if "error_code" in cluster:
@@ -209,11 +219,15 @@ def _get_cluster_report(
209219
]
210220
)
211221
if not instances["Reservations"]:
212-
return Response(
213-
error=DatabricksError(
214-
message=f"Unable to find any active or recently terminated instances for cluster `{cluster_id}` in `{aws_region_name}`"
215-
)
222+
no_instances_message = (
223+
f"Unable to find any active or recently terminated instances for cluster `{cluster_id}` in `{aws_region_name}`. "
224+
+ "Please refer to the following documentation for options on how to address this - "
225+
+ "https://synccomputingcode.github.io/syncsparkpy/reference/awsdatabricks.html"
216226
)
227+
if allow_incomplete:
228+
logger.warning(no_instances_message)
229+
else:
230+
return Response(error=DatabricksError(message=no_instances_message))
217231

218232
return Response(
219233
result=DatabricksClusterReport(
@@ -281,7 +295,9 @@ def get_prediction_job(
281295
if "autoscale" in cluster:
282296
del cluster["autoscale"]
283297

284-
prediction_cluster = _deep_update(cluster, prediction["solutions"][preference]["configuration"])
298+
prediction_cluster = _deep_update(
299+
cluster, prediction["solutions"][preference]["configuration"]
300+
)
285301

286302
if cluster_key := tasks[0].get("job_cluster_key"):
287303
job_settings["job_clusters"] = [

sync/cli/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def main(debug: bool):
2323
if debug:
2424
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
2525
else:
26-
logging.basicConfig(level=logging.CRITICAL, format=LOG_FORMAT)
26+
logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT)
2727

2828

2929
main.add_command(predictions.predictions)

sync/cli/awsdatabricks.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,22 @@ def run_job(
6060
default=DatabricksComputeType.JOBS_COMPUTE,
6161
)
6262
@click.option("--project", callback=validate_project)
63+
@click.option(
64+
"--allow-incomplete",
65+
is_flag=True,
66+
default=False,
67+
help="Force creation of a prediction even with incomplete cluster data.",
68+
)
6369
def create_prediction(
64-
run_id: str, plan: DatabricksPlanType, compute: DatabricksComputeType, project: dict = None
70+
run_id: str,
71+
plan: DatabricksPlanType,
72+
compute: DatabricksComputeType,
73+
project: dict = None,
74+
allow_incomplete: bool = False,
6575
):
6676
"""Create a prediction for a job run"""
6777
prediction_response = awsdatabricks.create_prediction_for_run(
68-
run_id, plan, compute, project["id"]
78+
run_id, plan, compute, project["id"], allow_incomplete
6979
)
7080
if prediction := prediction_response.result:
7181
click.echo(f"Prediction ID: {prediction}")
@@ -81,9 +91,20 @@ def create_prediction(
8191
type=click.Choice(DatabricksComputeType),
8292
default=DatabricksComputeType.JOBS_COMPUTE,
8393
)
84-
def get_cluster_report(run_id: str, plan: DatabricksPlanType, compute: DatabricksComputeType):
94+
@click.option(
95+
"--allow-incomplete",
96+
is_flag=True,
97+
default=False,
98+
help="Force creation of a cluster report even if some data is missing.",
99+
)
100+
def get_cluster_report(
101+
run_id: str,
102+
plan: DatabricksPlanType,
103+
compute: DatabricksComputeType,
104+
allow_incomplete: bool = False,
105+
):
85106
"""Get a cluster report"""
86-
config_response = awsdatabricks.get_cluster_report(run_id, plan, compute)
107+
config_response = awsdatabricks.get_cluster_report(run_id, plan, compute, allow_incomplete)
87108
if config := config_response.result:
88109
click.echo(
89110
orjson.dumps(

sync/clients/__init__.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import httpx
12
import orjson
3+
from tenacity import Retrying, TryAgain, stop_after_attempt, wait_exponential_jitter
24

35
from sync import __version__
46

@@ -15,3 +17,56 @@ def encode_json(obj: dict) -> tuple[dict, str]:
1517
"Content-Length": str(len(json)),
1618
"Content-Type": "application/json",
1719
}, json
20+
21+
22+
class RetryableHTTPClient:
    """
    Small wrapper around httpx.Client/AsyncClient that adds the retry logic httpx does not offer natively.

    A request is attempted up to 3 times. Another attempt is made when the response status is one of
    ``_DEFAULT_RETRYABLE_STATUS_CODES``, or when sending raises (tenacity's default retry policy).
    If the final attempt still yields a retryable status, that response is returned to the caller
    as-is; if the final attempt raises, that exception propagates (``reraise=True``).
    """

    # Statuses treated as transient; a response carrying one of these triggers another attempt.
    _DEFAULT_RETRYABLE_STATUS_CODES: set[httpx.codes] = {
        httpx.codes.REQUEST_TIMEOUT,
        httpx.codes.TOO_EARLY,
        httpx.codes.TOO_MANY_REQUESTS,
        httpx.codes.INTERNAL_SERVER_ERROR,
        httpx.codes.BAD_GATEWAY,
        httpx.codes.SERVICE_UNAVAILABLE,
        httpx.codes.GATEWAY_TIMEOUT,
    }

    def __init__(self, client: httpx.Client | httpx.AsyncClient):
        """
        :param client: the httpx client that requests are sent through
        :type client: httpx.Client | httpx.AsyncClient
        """
        self._client: httpx.Client | httpx.AsyncClient = client

    @staticmethod
    def _retry_options() -> dict:
        """Build the tenacity settings shared by the sync and async send paths.

        Fresh stop/wait objects are created on every call so no state is shared between requests.

        :return: keyword arguments for ``Retrying`` / ``AsyncRetrying``
        :rtype: dict
        """
        return {
            "stop": stop_after_attempt(3),
            "wait": wait_exponential_jitter(initial=2, max=10, jitter=2),
            "reraise": True,
        }

    def _send_request(self, request: httpx.Request) -> httpx.Response:
        """Send ``request`` with the synchronous client, retrying transient failures.

        :param request: request to send
        :type request: httpx.Request
        :return: the last response received, which may still carry a retryable status code
        :rtype: httpx.Response
        """
        try:
            for attempt in Retrying(**self._retry_options()):
                with attempt:
                    response = self._client.send(request)
                    if response.status_code in self._DEFAULT_RETRYABLE_STATUS_CODES:
                        raise TryAgain()
        except TryAgain:
            # If we max out on retries, then return the bad response back to the caller to handle as appropriate
            pass

        return response

    async def _send_request_async(self, request: httpx.Request) -> httpx.Response:
        """Send ``request`` with the asynchronous client, retrying transient failures.

        :param request: request to send
        :type request: httpx.Request
        :return: the last response received, which may still carry a retryable status code
        :rtype: httpx.Response
        """
        # Imported here so this block does not depend on the module-level tenacity import list.
        from tenacity import AsyncRetrying

        try:
            # AsyncRetrying awaits between attempts; the synchronous Retrying would block the
            # event loop with time.sleep for the whole back-off period.
            async for attempt in AsyncRetrying(**self._retry_options()):
                with attempt:
                    response = await self._client.send(request)
                    if response.status_code in self._DEFAULT_RETRYABLE_STATUS_CODES:
                        raise TryAgain()
        except TryAgain:
            # If we max out on retries, then return the bad response back to the caller to handle as appropriate
            pass

        return response

sync/clients/databricks.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import httpx
55

66
from ..config import DB_CONFIG
7-
from . import USER_AGENT, encode_json
7+
from . import USER_AGENT, RetryableHTTPClient, encode_json
88

99
logger = logging.getLogger(__name__)
1010

@@ -19,10 +19,14 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re
1919
yield request
2020

2121

22-
class DatabricksClient:
22+
class DatabricksClient(RetryableHTTPClient):
2323
def __init__(self, base_url: str, access_token: str):
24-
self._client = httpx.Client(
25-
base_url=base_url, headers={"User-Agent": USER_AGENT}, auth=DatabricksAuth(access_token)
24+
super().__init__(
25+
client=httpx.Client(
26+
base_url=base_url,
27+
headers={"User-Agent": USER_AGENT},
28+
auth=DatabricksAuth(access_token),
29+
)
2630
)
2731

2832
def create_cluster(self, config: dict) -> dict:
@@ -92,9 +96,9 @@ def get_run(self, run_id: str) -> dict:
9296
)
9397

9498
def _send(self, request: httpx.Request) -> dict:
95-
response = self._client.send(request)
99+
response = self._send_request(request)
96100

97-
if response.status_code >= 200 and response.status_code < 300:
101+
if 200 <= response.status_code < 300:
98102
return response.json()
99103

100104
if response.headers.get("Content-Type", "").startswith("application/json"):

sync/clients/sync.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import httpx
55

66
from ..config import API_KEY, CONFIG, APIKey
7-
from . import USER_AGENT, encode_json
7+
from . import USER_AGENT, RetryableHTTPClient, encode_json
88

99
logger = logging.getLogger(__name__)
1010

@@ -48,13 +48,15 @@ def update_access_token(self, response: httpx.Response):
4848
logger.error(f"{response.status_code}: Failed to authenticate")
4949

5050

51-
class SyncClient:
51+
class SyncClient(RetryableHTTPClient):
5252
def __init__(self, api_url, api_key):
53-
self._client = httpx.Client(
54-
base_url=api_url,
55-
headers={"User-Agent": USER_AGENT},
56-
auth=SyncAuth(api_url, api_key),
57-
timeout=60.0,
53+
super().__init__(
54+
client=httpx.Client(
55+
base_url=api_url,
56+
headers={"User-Agent": USER_AGENT},
57+
auth=SyncAuth(api_url, api_key),
58+
timeout=60.0,
59+
)
5860
)
5961

6062
def get_products(self) -> dict:
@@ -109,9 +111,9 @@ def delete_project(self, project_id: str) -> dict:
109111
return self._send(self._client.build_request("DELETE", f"/v1/projects/{project_id}"))
110112

111113
def _send(self, request: httpx.Request) -> dict:
112-
response = self._client.send(request)
114+
response = self._send_request(request)
113115

114-
if response.status_code >= 200 and response.status_code < 300:
116+
if 200 <= response.status_code < 300:
115117
return response.json()
116118

117119
if response.headers.get("Content-Type", "").startswith("application/json"):
@@ -126,13 +128,15 @@ def _send(self, request: httpx.Request) -> dict:
126128
return {"error": {"code": "Sync API Error", "message": "Transaction failure"}}
127129

128130

129-
class ASyncClient:
131+
class ASyncClient(RetryableHTTPClient):
130132
def __init__(self, api_url, api_key):
131-
self._client = httpx.AsyncClient(
132-
base_url=api_url,
133-
headers={"User-Agent": USER_AGENT},
134-
auth=SyncAuth(api_url, api_key),
135-
timeout=60.0,
133+
super().__init__(
134+
client=httpx.AsyncClient(
135+
base_url=api_url,
136+
headers={"User-Agent": USER_AGENT},
137+
auth=SyncAuth(api_url, api_key),
138+
timeout=60.0,
139+
)
136140
)
137141

138142
async def create_prediction(self, prediction: dict) -> dict:
@@ -184,9 +188,9 @@ async def delete_project(self, project_id: str) -> dict:
184188
return await self._send(self._client.build_request("DELETE", f"/v1/projects/{project_id}"))
185189

186190
async def _send(self, request: httpx.Request) -> dict:
187-
response = await self._client.send(request)
191+
response = await self._send_request_async(request)
188192

189-
if response.status_code >= 200 and response.status_code < 300:
193+
if 200 <= response.status_code < 300:
190194
return response.json()
191195

192196
if response.headers.get("Content-Type", "").startswith("application/json"):

0 commit comments

Comments
 (0)