
Commit 9af6577

Merge pull request #27 from synccomputingcode/graham/dbfs-2
[PROD-1062] Add support for getting the event log from DBFS
2 parents e9596ad + aef18bd commit 9af6577

5 files changed: +240 -107 lines changed

sync/awsdatabricks.py

Lines changed: 150 additions & 106 deletions
@@ -1,7 +1,7 @@
 """
 Utilities for interacting with Databricks
 """
-import base64
+import gzip
 import io
 import logging
 import time
@@ -26,6 +26,7 @@
     Platform,
     Response,
 )
+from sync.utils.dbfs import format_dbfs_filepath, read_dbfs_file, write_dbfs_file

 logger = logging.getLogger(__name__)

@@ -254,34 +255,60 @@ def _get_cluster_report(
     )


+def _get_cluster_instances_from_s3(bucket: str, file_key: str, cluster_id):
+    s3 = boto.client("s3")
+    try:
+        return s3.get_object(Bucket=bucket, Key=file_key)["Body"].read()
+    except ClientError as ex:
+        if ex.response["Error"]["Code"] == "NoSuchKey":
+            logger.warning(
+                f"Could not find sync_data/cluster_instances.json for cluster: {cluster_id}"
+            )
+        else:
+            logger.error(
+                f"Unexpected error encountered while attempting to fetch sync_data/cluster_instances.json: {ex}"
+            )
+
+
+def _get_cluster_instances_from_dbfs(filepath: str):
+    filepath = format_dbfs_filepath(filepath)
+    dbx_client = get_default_client()
+    try:
+        return read_dbfs_file(filepath, dbx_client)
+    except Exception as ex:
+        logger.error(f"Unexpected error encountered while attempting to fetch {filepath}: {ex}")
+
+
 def _get_cluster_instances(cluster: dict) -> Response[dict]:
     cluster_instances = None
     aws_region_name = DB_CONFIG.aws_region_name

-    cluster_id = cluster["cluster_id"]
     cluster_log_dest = _cluster_log_destination(cluster)

     if cluster_log_dest:
         (_, filesystem, bucket, base_prefix) = cluster_log_dest
-        if filesystem == "s3":
-            s3 = boto.client("s3")

-            cluster_instances_file_key = f"{base_prefix}/sync_data/cluster_instances.json"
+        cluster_id = cluster["cluster_id"]
+        spark_context_id = cluster["spark_context_id"]
+        cluster_instances_file_key = (
+            f"{base_prefix}/sync_data/{spark_context_id}/cluster_instances.json"
+        )

-            try:
-                cluster_instances_file_response = s3.get_object(
-                    Bucket=bucket, Key=cluster_instances_file_key
-                )
-                cluster_instances = orjson.loads(cluster_instances_file_response["Body"].read())
-            except ClientError as ex:
-                if ex.response["Error"]["Code"] == "NoSuchKey":
-                    logger.warning(
-                        f"Could not find sync_data/cluster_instances.json for cluster: {cluster_id}"
-                    )
-                else:
-                    logger.error(
-                        f"Unexpected error encountered while attempting to fetch sync_data/cluster_instances.json: {ex}"
-                    )
+        cluster_instances_file_response = None
+        if filesystem == "s3":
+            cluster_instances_file_response = _get_cluster_instances_from_s3(
+                bucket, cluster_instances_file_key, cluster_id
+            )
+        elif filesystem == "dbfs":
+            cluster_instances_file_response = _get_cluster_instances_from_dbfs(
+                cluster_instances_file_key
+            )
+
+        cluster_instances = (
+            orjson.loads(cluster_instances_file_response)
+            if cluster_instances_file_response
+            else None
+        )

     # If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that
     # are associated with this cluster
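
Note: format_dbfs_filepath, read_dbfs_file and write_dbfs_file come from sync/utils/dbfs.py, which is not included in this diff. A minimal sketch of what format_dbfs_filepath might look like, assuming it only normalizes a log prefix into a dbfs:/ path (inferred from the hand-built f"dbfs:/..." string it replaces in a later hunk; the exact behavior is an assumption):

# Hypothetical sketch of format_dbfs_filepath; sync/utils/dbfs.py is not shown in this diff
def format_dbfs_filepath(path: str) -> str:
    prefix = "dbfs:/"
    # e.g. "cluster-logs/sync_data/<context>/cluster_instances.json" -> "dbfs:/cluster-logs/..."
    return path if path.startswith(prefix) else prefix + path.lstrip("/")
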
@@ -886,91 +913,111 @@ def _cluster_log_destination(

 def monitor_cluster(cluster_id: str, polling_period: int = 30) -> None:
     cluster = get_default_client().get_cluster(cluster_id)
-    cluster_log_dest = _cluster_log_destination(cluster)
-    if cluster_log_dest:
-        (_, filesystem, bucket, base_prefix) = cluster_log_dest
-        # If the event log destination is just a *bucket* without any sub-path, then we don't want to include
-        # a leading `/` in our Prefix (which will make it so that we never actually find the event log), so
-        # we make sure to re-strip our final Prefix
-        file_key = f"{base_prefix}/sync_data/cluster_instances.json".strip("/")
+    spark_context_id = cluster.get("spark_context_id")

-        aws_region_name = DB_CONFIG.aws_region_name
-        ec2 = boto.client("ec2", region_name=aws_region_name)
+    while not spark_context_id:
+        # This is largely just a convenience for when this command is run by someone locally
+        logger.info("Waiting for cluster startup...")
+        sleep(30)
+        cluster = get_default_client().get_cluster(cluster_id)
+        spark_context_id = cluster.get("spark_context_id")

-        if filesystem == "s3":
-            s3 = boto.client("s3")
+    (log_url, filesystem, bucket, base_prefix) = _cluster_log_destination(cluster)
+    if log_url:
+        _monitor_cluster(
+            (log_url, filesystem, bucket, base_prefix), cluster_id, spark_context_id, polling_period
+        )
+    else:
+        logger.warning("Unable to monitor cluster due to missing cluster log destination - exiting")

-            def write_file(body: bytes):
-                s3.put_object(Bucket=bucket, Key=file_key, Body=body)

-        elif filesystem == "dbfs":
+def _monitor_cluster(
+    cluster_log_destination, cluster_id: str, spark_context_id: int, polling_period: int
+) -> None:
+    (log_url, filesystem, bucket, base_prefix) = cluster_log_destination
+    # If the event log destination is just a *bucket* without any sub-path, then we don't want to include
+    # a leading `/` in our Prefix (which will make it so that we never actually find the event log), so
+    # we make sure to re-strip our final Prefix
+    file_key = f"{base_prefix}/sync_data/{spark_context_id}/cluster_instances.json".strip("/")

-            def write_file(body: bytes):
-                # TODO
-                raise Exception()
-
-        previous_instances = {}
-        while True:
-            try:
-                instances = ec2.describe_instances(
-                    Filters=[
-                        {"Name": "tag:Vendor", "Values": ["Databricks"]},
-                        {"Name": "tag:ClusterId", "Values": [cluster_id]},
-                        # {'Name': 'tag:JobId', 'Values': []}
-                    ]
-                )
+    aws_region_name = DB_CONFIG.aws_region_name
+    ec2 = boto.client("ec2", region_name=aws_region_name)

-                new_instances = [res for res in instances["Reservations"]]
-                new_instance_id_to_reservation = dict(
-                    zip(
-                        (res["Instances"][0]["InstanceId"] for res in new_instances),
-                        new_instances,
-                    )
+    if filesystem == "s3":
+        s3 = boto.client("s3")
+
+        def write_file(body: bytes):
+            logger.info("Saving state to S3")
+            s3.put_object(Bucket=bucket, Key=file_key, Body=body)
+
+    elif filesystem == "dbfs":
+        path = format_dbfs_filepath(file_key)
+        dbx_client = get_default_client()
+
+        def write_file(body: bytes):
+            logger.info("Saving state to DBFS")
+            write_dbfs_file(path, body, dbx_client)
+
+    previous_instances = {}
+    while True:
+        try:
+            instances = ec2.describe_instances(
+                Filters=[
+                    {"Name": "tag:Vendor", "Values": ["Databricks"]},
+                    {"Name": "tag:ClusterId", "Values": [cluster_id]},
+                    # {'Name': 'tag:JobId', 'Values': []}
+                ]
+            )
+
+            new_instances = [res for res in instances["Reservations"]]
+            new_instance_id_to_reservation = dict(
+                zip(
+                    (res["Instances"][0]["InstanceId"] for res in new_instances),
+                    new_instances,
                 )
+            )

-                old_instances = [res for res in previous_instances.get("Reservations", [])]
-                old_instance_id_to_reservation = dict(
-                    zip(
-                        (res["Instances"][0]["InstanceId"] for res in old_instances),
-                        old_instances,
-                    )
+            old_instances = [res for res in previous_instances.get("Reservations", [])]
+            old_instance_id_to_reservation = dict(
+                zip(
+                    (res["Instances"][0]["InstanceId"] for res in old_instances),
+                    old_instances,
                 )
+            )

-                old_instance_ids = set(old_instance_id_to_reservation)
-                new_instance_ids = set(new_instance_id_to_reservation)
+            old_instance_ids = set(old_instance_id_to_reservation)
+            new_instance_ids = set(new_instance_id_to_reservation)

-                # If we have the exact same set of instances, prefer the new set...
-                if old_instance_ids == new_instance_ids:
-                    instances = {"Reservations": new_instances}
-                else:
-                    # Otherwise, update old references and include any new instances in the list
-                    newly_added_instance_ids = new_instance_ids.difference(old_instance_ids)
-                    updated_instance_ids = newly_added_instance_ids.intersection(old_instance_ids)
-                    removed_instance_ids = old_instance_ids.difference(updated_instance_ids)
-
-                    removed_instances = [
-                        old_instance_id_to_reservation[id] for id in removed_instance_ids
-                    ]
-                    updated_instances = [
-                        new_instance_id_to_reservation[id] for id in updated_instance_ids
-                    ]
-                    new_instances = [
-                        new_instance_id_to_reservation[id] for id in newly_added_instance_ids
-                    ]
-
-                    instances = {
-                        "Reservations": [*removed_instances, *updated_instances, *new_instances]
-                    }
+            # If we have the exact same set of instances, prefer the new set...
+            if old_instance_ids == new_instance_ids:
+                instances = {"Reservations": new_instances}
+            else:
+                # Otherwise, update old references and include any new instances in the list
+                newly_added_instance_ids = new_instance_ids.difference(old_instance_ids)
+                updated_instance_ids = newly_added_instance_ids.intersection(old_instance_ids)
+                removed_instance_ids = old_instance_ids.difference(updated_instance_ids)
+
+                removed_instances = [
+                    old_instance_id_to_reservation[id] for id in removed_instance_ids
+                ]
+                updated_instances = [
+                    new_instance_id_to_reservation[id] for id in updated_instance_ids
+                ]
+                new_instances = [
+                    new_instance_id_to_reservation[id] for id in newly_added_instance_ids
+                ]
+
+                instances = {
+                    "Reservations": [*removed_instances, *updated_instances, *new_instances]
+                }

-                write_file(orjson.dumps(instances))
+            write_file(orjson.dumps(instances))

-                previous_instances = instances
-            except Exception as e:
-                logger.error(f"Exception encountered while polling cluster: {e}")
+            previous_instances = instances
+        except Exception as e:
+            logger.error(f"Exception encountered while polling cluster: {e}")

-            sleep(polling_period)
-    else:
-        logger.warning("Unable to monitor cluster due to missing cluster log destination - exiting")
+        sleep(polling_period)


 def _get_job_cluster(tasks: List[dict], job_clusters: list) -> Response[dict]:
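
The DBFS branch above relies on write_dbfs_file from sync.utils.dbfs, which is not shown in this diff. A plausible sketch, assuming it drives the stream endpoints added to the Databricks client below (open_dbfs_file_stream, add_block_to_dbfs_file_stream, close_dbfs_file_stream) and respects the documented 1 MB limit on DBFS add-block payloads; block size and error handling are assumptions:

# Hypothetical sketch of write_dbfs_file; not taken from this diff
import base64

DBFS_BLOCK_SIZE = 1024 * 1024  # the DBFS add-block endpoint caps each block at 1 MB

def write_dbfs_file(path: str, body: bytes, client) -> None:
    handle = client.open_dbfs_file_stream(path, overwrite=True)["handle"]
    try:
        # add-block expects base64-encoded string data, one block at a time
        for offset in range(0, len(body), DBFS_BLOCK_SIZE):
            block = base64.b64encode(body[offset:offset + DBFS_BLOCK_SIZE]).decode()
            client.add_block_to_dbfs_file_stream(handle, block)
    finally:
        client.close_dbfs_file_stream(handle)
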
@@ -1132,7 +1179,7 @@ def _get_eventlog_from_dbfs(
     run_end_time_millis: int,
     poll_duration_seconds: int,
 ):
-    prefix = f"dbfs:/{base_filepath}/eventlog/"
+    prefix = format_dbfs_filepath(f"{base_filepath}/eventlog/")

     root_dir = get_default_client().list_dbfs_directory(prefix)

@@ -1179,21 +1226,18 @@ def _get_eventlog_from_dbfs(
         eventlog_zip_file = zipfile.ZipFile(eventlog_zip, "a", zipfile.ZIP_DEFLATED)

         eventlog_files = eventlog_dir["files"]
+        dbx_client = get_default_client()
         for eventlog_file_metadata in eventlog_files:
-            filename = eventlog_file_metadata["path"].split("/")[-1]
-            filesize = eventlog_file_metadata["file_size"]
-
-            bytes_read = 0
-            # DBFS tells us exactly how many bytes to expect for each file, so pre-allocate an array
-            # to write the file chunks in to
-            content = bytearray(filesize)
-            while bytes_read < filesize:
-                file_content = get_default_client().read_dbfs_file_chunk(
-                    eventlog_file_metadata["path"], offset=bytes_read
-                )
-                new_bytes_read = bytes_read + file_content["bytes_read"]
-                content[bytes_read:new_bytes_read] = base64.b64decode(file_content["data"])
-                bytes_read = new_bytes_read
+            filename: str = eventlog_file_metadata["path"].split("/")[-1]
+            filesize: int = eventlog_file_metadata["file_size"]
+
+            content = read_dbfs_file(eventlog_file_metadata["path"], dbx_client, filesize)
+
+            # Databricks typically leaves the most recent rollover log uncompressed, so we may as well
+            # gzip it before upload
+            if not filename.endswith(".gz"):
+                content = gzip.compress(content)
+                filename += ".gz"

             eventlog_zip_file.writestr(filename, content)
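
For reference, read_dbfs_file (in sync/utils/dbfs.py, not shown in this diff) presumably wraps the chunked-download loop this hunk removes. A minimal sketch under that assumption; the behavior when file_size is not supplied is a guess:

# Hypothetical sketch of read_dbfs_file, mirroring the loop removed above
import base64

def read_dbfs_file(path: str, client, file_size: int = None) -> bytes:
    # DBFS reports exact file sizes, so the buffer can be pre-allocated when one is given
    content = bytearray(file_size) if file_size else bytearray()
    bytes_read = 0
    while True:
        chunk = client.read_dbfs_file_chunk(path, offset=bytes_read)
        data = base64.b64decode(chunk["data"])  # chunks come back base64-encoded
        if file_size:
            content[bytes_read:bytes_read + chunk["bytes_read"]] = data
        else:
            content.extend(data)
        bytes_read += chunk["bytes_read"]
        # the read endpoint reports 0 bytes once the offset reaches the end of the file
        if chunk["bytes_read"] == 0 or (file_size and bytes_read >= file_size):
            break
    return bytes(content)
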

@@ -1257,7 +1301,7 @@ def _get_all_cluster_events(cluster_id: str):

     next_args = response.get("next_page")
     while next_args:
-        response = get_default_client().get_cluster_events(cluster_id, **next_args)
+        response = get_default_client().get_cluster_events(**next_args)
         responses.append(response)
         next_args = response.get("next_page")

sync/clients/databricks.py

Lines changed: 24 additions & 0 deletions
@@ -113,6 +113,30 @@ def read_dbfs_file_chunk(self, path: str, offset: int = 0, length: int = 1024 *
             )
         )

+    def open_dbfs_file_stream(self, path: str, overwrite: bool = False) -> dict:
+        headers, content = encode_json({"path": path, "overwrite": overwrite})
+        return self._send(
+            self._client.build_request(
+                "POST", "/api/2.0/dbfs/create", headers=headers, content=content
+            )
+        )
+
+    def add_block_to_dbfs_file_stream(self, handle: int, data: str) -> dict:
+        headers, content = encode_json({"handle": handle, "data": data})
+        return self._send(
+            self._client.build_request(
+                "POST", "/api/2.0/dbfs/add-block", headers=headers, content=content
+            )
+        )
+
+    def close_dbfs_file_stream(self, handle: int) -> dict:
+        headers, content = encode_json({"handle": handle})
+        return self._send(
+            self._client.build_request(
+                "POST", "/api/2.0/dbfs/close", headers=headers, content=content
+            )
+        )
+
     def _send(self, request: httpx.Request) -> dict:
         response = self._send_request(request)
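
These three wrappers map onto the DBFS streaming upload flow: create a handle, append base64-encoded blocks, then close. A hedged usage sketch; the handle/base64 handling follows the public DBFS 2.0 API rather than code in this diff, and the path shown is purely illustrative:

# Example only: upload a small payload through the new stream endpoints
import base64

client = get_default_client()  # assuming the default client factory used throughout this repo
handle = client.open_dbfs_file_stream("dbfs:/cluster-logs/sync_data/example.json", overwrite=True)["handle"]
client.add_block_to_dbfs_file_stream(handle, base64.b64encode(b'{"Reservations": []}').decode())
client.close_dbfs_file_stream(handle)
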

sync/utils/__init__.py

Whitespace-only changes.
