@@ -1,7 +1,7 @@
 """
 Utilities for interacting with Databricks
 """
-import base64
+import gzip
 import io
 import logging
 import time
@@ -26,6 +26,7 @@
     Platform,
     Response,
 )
+from sync.utils.dbfs import format_dbfs_filepath, read_dbfs_file, write_dbfs_file

 logger = logging.getLogger(__name__)

@@ -254,34 +255,60 @@ def _get_cluster_report(
     )


+def _get_cluster_instances_from_s3(bucket: str, file_key: str, cluster_id):
+    s3 = boto.client("s3")
+    try:
+        return s3.get_object(Bucket=bucket, Key=file_key)["Body"].read()
+    except ClientError as ex:
+        if ex.response["Error"]["Code"] == "NoSuchKey":
+            logger.warning(
+                f"Could not find sync_data/cluster_instances.json for cluster: {cluster_id}"
+            )
+        else:
+            logger.error(
+                f"Unexpected error encountered while attempting to fetch sync_data/cluster_instances.json: {ex}"
+            )
+
+
+def _get_cluster_instances_from_dbfs(filepath: str):
+    filepath = format_dbfs_filepath(filepath)
+    dbx_client = get_default_client()
+    try:
+        return read_dbfs_file(filepath, dbx_client)
+    except Exception as ex:
+        logger.error(f"Unexpected error encountered while attempting to fetch {filepath}: {ex}")
+
+
 def _get_cluster_instances(cluster: dict) -> Response[dict]:
     cluster_instances = None
     aws_region_name = DB_CONFIG.aws_region_name

-    cluster_id = cluster["cluster_id"]
     cluster_log_dest = _cluster_log_destination(cluster)

     if cluster_log_dest:
         (_, filesystem, bucket, base_prefix) = cluster_log_dest
-        if filesystem == "s3":
-            s3 = boto.client("s3")

-            cluster_instances_file_key = f"{base_prefix}/sync_data/cluster_instances.json"
+        cluster_id = cluster["cluster_id"]
+        spark_context_id = cluster["spark_context_id"]
+        cluster_instances_file_key = (
+            f"{base_prefix}/sync_data/{spark_context_id}/cluster_instances.json"
+        )

-            try:
-                cluster_instances_file_response = s3.get_object(
-                    Bucket=bucket, Key=cluster_instances_file_key
-                )
-                cluster_instances = orjson.loads(cluster_instances_file_response["Body"].read())
-            except ClientError as ex:
-                if ex.response["Error"]["Code"] == "NoSuchKey":
-                    logger.warning(
-                        f"Could not find sync_data/cluster_instances.json for cluster: {cluster_id}"
-                    )
-                else:
-                    logger.error(
-                        f"Unexpected error encountered while attempting to fetch sync_data/cluster_instances.json: {ex}"
-                    )
+        cluster_instances_file_response = None
+        if filesystem == "s3":
+            cluster_instances_file_response = _get_cluster_instances_from_s3(
+                bucket, cluster_instances_file_key, cluster_id
+            )
+        elif filesystem == "dbfs":
+            cluster_instances_file_response = _get_cluster_instances_from_dbfs(
+                cluster_instances_file_key
+            )
+
+        cluster_instances = (
+            orjson.loads(cluster_instances_file_response)
+            if cluster_instances_file_response
+            else None
+        )

     # If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that
     # are associated with this cluster
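For reference, both helpers return the raw bytes of cluster_instances.json, which orjson.loads then parses. The file is produced by _monitor_cluster further down in this diff, so the parsed value is simply the EC2 DescribeInstances snapshot it wrote. A rough, hypothetical sketch of the parsed shape (the instance fields shown are abbreviated placeholders; the real reservations carry whatever describe_instances returned):

    # Hypothetical, trimmed-down view of orjson.loads(cluster_instances_file_response)
    cluster_instances = {
        "Reservations": [
            {"Instances": [{"InstanceId": "i-0123456789abcdef0", "InstanceType": "i3.xlarge"}]},
        ]
    }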
@@ -886,91 +913,111 @@ def _cluster_log_destination(

 def monitor_cluster(cluster_id: str, polling_period: int = 30) -> None:
     cluster = get_default_client().get_cluster(cluster_id)
-    cluster_log_dest = _cluster_log_destination(cluster)
-    if cluster_log_dest:
-        (_, filesystem, bucket, base_prefix) = cluster_log_dest
-        # If the event log destination is just a *bucket* without any sub-path, then we don't want to include
-        # a leading `/` in our Prefix (which will make it so that we never actually find the event log), so
-        # we make sure to re-strip our final Prefix
-        file_key = f"{base_prefix}/sync_data/cluster_instances.json".strip("/")
+    spark_context_id = cluster.get("spark_context_id")

-        aws_region_name = DB_CONFIG.aws_region_name
-        ec2 = boto.client("ec2", region_name=aws_region_name)
+    while not spark_context_id:
+        # This is largely just a convenience for when this command is run by someone locally
+        logger.info("Waiting for cluster startup...")
+        sleep(30)
+        cluster = get_default_client().get_cluster(cluster_id)
+        spark_context_id = cluster.get("spark_context_id")

-        if filesystem == "s3":
-            s3 = boto.client("s3")
+    (log_url, filesystem, bucket, base_prefix) = _cluster_log_destination(cluster)
+    if log_url:
+        _monitor_cluster(
+            (log_url, filesystem, bucket, base_prefix), cluster_id, spark_context_id, polling_period
+        )
+    else:
+        logger.warning("Unable to monitor cluster due to missing cluster log destination - exiting")

-            def write_file(body: bytes):
-                s3.put_object(Bucket=bucket, Key=file_key, Body=body)

-        elif filesystem == "dbfs":
+def _monitor_cluster(
+    cluster_log_destination, cluster_id: str, spark_context_id: int, polling_period: int
+) -> None:
+    (log_url, filesystem, bucket, base_prefix) = cluster_log_destination
+    # If the event log destination is just a *bucket* without any sub-path, then we don't want to include
+    # a leading `/` in our Prefix (which will make it so that we never actually find the event log), so
+    # we make sure to re-strip our final Prefix
+    file_key = f"{base_prefix}/sync_data/{spark_context_id}/cluster_instances.json".strip("/")

-            def write_file(body: bytes):
-                # TODO
-                raise Exception()
-
-        previous_instances = {}
-        while True:
-            try:
-                instances = ec2.describe_instances(
-                    Filters=[
-                        {"Name": "tag:Vendor", "Values": ["Databricks"]},
-                        {"Name": "tag:ClusterId", "Values": [cluster_id]},
-                        # {'Name': 'tag:JobId', 'Values': []}
-                    ]
-                )
+    aws_region_name = DB_CONFIG.aws_region_name
+    ec2 = boto.client("ec2", region_name=aws_region_name)

-                new_instances = [res for res in instances["Reservations"]]
-                new_instance_id_to_reservation = dict(
-                    zip(
-                        (res["Instances"][0]["InstanceId"] for res in new_instances),
-                        new_instances,
-                    )
+    if filesystem == "s3":
+        s3 = boto.client("s3")
+
+        def write_file(body: bytes):
+            logger.info("Saving state to S3")
+            s3.put_object(Bucket=bucket, Key=file_key, Body=body)
+
+    elif filesystem == "dbfs":
+        path = format_dbfs_filepath(file_key)
+        dbx_client = get_default_client()
+
+        def write_file(body: bytes):
+            logger.info("Saving state to DBFS")
+            write_dbfs_file(path, body, dbx_client)
+
+    previous_instances = {}
+    while True:
+        try:
+            instances = ec2.describe_instances(
+                Filters=[
+                    {"Name": "tag:Vendor", "Values": ["Databricks"]},
+                    {"Name": "tag:ClusterId", "Values": [cluster_id]},
+                    # {'Name': 'tag:JobId', 'Values': []}
+                ]
+            )
+
+            new_instances = [res for res in instances["Reservations"]]
+            new_instance_id_to_reservation = dict(
+                zip(
+                    (res["Instances"][0]["InstanceId"] for res in new_instances),
+                    new_instances,
                 )
+            )

-                old_instances = [res for res in previous_instances.get("Reservations", [])]
-                old_instance_id_to_reservation = dict(
-                    zip(
-                        (res["Instances"][0]["InstanceId"] for res in old_instances),
-                        old_instances,
-                    )
+            old_instances = [res for res in previous_instances.get("Reservations", [])]
+            old_instance_id_to_reservation = dict(
+                zip(
+                    (res["Instances"][0]["InstanceId"] for res in old_instances),
+                    old_instances,
                 )
+            )

-                old_instance_ids = set(old_instance_id_to_reservation)
-                new_instance_ids = set(new_instance_id_to_reservation)
+            old_instance_ids = set(old_instance_id_to_reservation)
+            new_instance_ids = set(new_instance_id_to_reservation)

-                # If we have the exact same set of instances, prefer the new set...
-                if old_instance_ids == new_instance_ids:
-                    instances = {"Reservations": new_instances}
-                else:
-                    # Otherwise, update old references and include any new instances in the list
-                    newly_added_instance_ids = new_instance_ids.difference(old_instance_ids)
-                    updated_instance_ids = newly_added_instance_ids.intersection(old_instance_ids)
-                    removed_instance_ids = old_instance_ids.difference(updated_instance_ids)
-
-                    removed_instances = [
-                        old_instance_id_to_reservation[id] for id in removed_instance_ids
-                    ]
-                    updated_instances = [
-                        new_instance_id_to_reservation[id] for id in updated_instance_ids
-                    ]
-                    new_instances = [
-                        new_instance_id_to_reservation[id] for id in newly_added_instance_ids
-                    ]
-
-                    instances = {
-                        "Reservations": [*removed_instances, *updated_instances, *new_instances]
-                    }
+            # If we have the exact same set of instances, prefer the new set...
+            if old_instance_ids == new_instance_ids:
+                instances = {"Reservations": new_instances}
+            else:
+                # Otherwise, update old references and include any new instances in the list
+                newly_added_instance_ids = new_instance_ids.difference(old_instance_ids)
+                updated_instance_ids = newly_added_instance_ids.intersection(old_instance_ids)
+                removed_instance_ids = old_instance_ids.difference(updated_instance_ids)
+
+                removed_instances = [
+                    old_instance_id_to_reservation[id] for id in removed_instance_ids
+                ]
+                updated_instances = [
+                    new_instance_id_to_reservation[id] for id in updated_instance_ids
+                ]
+                new_instances = [
+                    new_instance_id_to_reservation[id] for id in newly_added_instance_ids
+                ]
+
+                instances = {
+                    "Reservations": [*removed_instances, *updated_instances, *new_instances]
+                }

-                write_file(orjson.dumps(instances))
+            write_file(orjson.dumps(instances))

-                previous_instances = instances
-            except Exception as e:
-                logger.error(f"Exception encountered while polling cluster: {e}")
+            previous_instances = instances
+        except Exception as e:
+            logger.error(f"Exception encountered while polling cluster: {e}")

-            sleep(polling_period)
-    else:
-        logger.warning("Unable to monitor cluster due to missing cluster log destination - exiting")
+        sleep(polling_period)


 def _get_job_cluster(tasks: List[dict], job_clusters: list) -> Response[dict]:
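The refactor splits the public entry point from the polling loop: monitor_cluster now blocks until the cluster reports a spark_context_id, then resolves the log destination and hands everything to _monitor_cluster, which owns the EC2 polling and the storage-specific write_file closure. A minimal, hypothetical invocation follows; the cluster ID is made up and the import path is assumed rather than taken from this diff:

    from sync.awsdatabricks import monitor_cluster  # assumed module path; adjust to wherever this file lives

    # Snapshot the cluster's EC2 instances every 60 seconds into
    # <cluster log destination>/sync_data/<spark_context_id>/cluster_instances.json
    monitor_cluster("0123-456789-abcdefgh", polling_period=60)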
@@ -1132,7 +1179,7 @@ def _get_eventlog_from_dbfs(
     run_end_time_millis: int,
     poll_duration_seconds: int,
 ):
-    prefix = f"dbfs:/{base_filepath}/eventlog/"
+    prefix = format_dbfs_filepath(f"{base_filepath}/eventlog/")

     root_dir = get_default_client().list_dbfs_directory(prefix)

@@ -1179,21 +1226,18 @@ def _get_eventlog_from_dbfs(
     eventlog_zip_file = zipfile.ZipFile(eventlog_zip, "a", zipfile.ZIP_DEFLATED)

     eventlog_files = eventlog_dir["files"]
+    dbx_client = get_default_client()
     for eventlog_file_metadata in eventlog_files:
-        filename = eventlog_file_metadata["path"].split("/")[-1]
-        filesize = eventlog_file_metadata["file_size"]
-
-        bytes_read = 0
-        # DBFS tells us exactly how many bytes to expect for each file, so pre-allocate an array
-        # to write the file chunks in to
-        content = bytearray(filesize)
-        while bytes_read < filesize:
-            file_content = get_default_client().read_dbfs_file_chunk(
-                eventlog_file_metadata["path"], offset=bytes_read
-            )
-            new_bytes_read = bytes_read + file_content["bytes_read"]
-            content[bytes_read:new_bytes_read] = base64.b64decode(file_content["data"])
-            bytes_read = new_bytes_read
+        filename: str = eventlog_file_metadata["path"].split("/")[-1]
+        filesize: int = eventlog_file_metadata["file_size"]
+
+        content = read_dbfs_file(eventlog_file_metadata["path"], dbx_client, filesize)
+
+        # Databricks typically leaves the most recent rollover log uncompressed, so we may as well
+        # gzip it before upload
+        if not filename.endswith(".gz"):
+            content = gzip.compress(content)
+            filename += ".gz"

         eventlog_zip_file.writestr(filename, content)

@@ -1257,7 +1301,7 @@ def _get_all_cluster_events(cluster_id: str):

     next_args = response.get("next_page")
     while next_args:
-        response = get_default_client().get_cluster_events(cluster_id, **next_args)
+        response = get_default_client().get_cluster_events(**next_args)
         responses.append(response)
         next_args = response.get("next_page")
