Commit dfe4899

Merge pull request #121 from synccomputingcode/PROD-2015/collect-cluster-report-refac
[PROD-2015] Add save_cluster_report and refactor _monitor_cluster
2 parents 97fc459 + e611748 commit dfe4899

File tree

3 files changed: +67 -33 lines changed

sync/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 """Library for leveraging the power of Sync"""
 
-__version__ = "1.8.4"
+__version__ = "1.9.0"
 
 TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

sync/awsdatabricks.py

Lines changed: 66 additions & 29 deletions
@@ -2,7 +2,7 @@
 import logging
 from pathlib import Path
 from time import sleep
-from typing import Generator, List, Tuple
+from typing import Dict, Generator, List, Optional, Tuple
 from urllib.parse import urlparse
 
 import boto3 as boto
@@ -84,6 +84,7 @@
     "wait_for_run_and_cluster",
     "terminate_cluster",
     "apply_project_recommendation",
+    "save_cluster_report",
 ]
 
 logger = logging.getLogger(__name__)
@@ -348,6 +349,62 @@ def _get_aws_cluster_info_from_s3(bucket: str, file_key: str, cluster_id):
         logger.warning(f"Failed to retrieve cluster info from S3 with key, '{file_key}': {err}")
 
 
+def save_cluster_report(
+    cluster_id: str,
+    instance_timelines: List[dict],
+    cluster_log_destination: Optional[Tuple[str, ...]] = None,
+    cluster_report_destination_override: Optional[Dict[str, str]] = None,
+    write_function=None,
+) -> bool:
+    cluster = get_default_client().get_cluster(cluster_id)
+    spark_context_id = cluster.get("spark_context_id")
+
+    if not spark_context_id:
+        return False
+
+    cluster_log_destination = cluster_log_destination or _cluster_log_destination(cluster)
+    (log_url, filesystem, bucket, base_prefix) = cluster_log_destination
+
+    if cluster_report_destination_override:
+        filesystem = cluster_report_destination_override.get("filesystem", filesystem)
+        base_prefix = cluster_report_destination_override.get("base_prefix", base_prefix)
+
+    if not log_url and not cluster_report_destination_override:
+        logger.warning(
+            "Unable to save cluster report due to missing cluster log destination - exiting"
+        )
+        return False
+
+    # If the event log destination is just a *bucket* without any sub-path, then we don't want to include
+    # a leading `/` in our Prefix (which will make it so that we never actually find the event log), so
+    # we make sure to re-strip our final Prefix
+    file_key = f"{base_prefix}/sync_data/{spark_context_id}/aws_cluster_info.json".strip("/")
+
+    aws_region_name = DB_CONFIG.aws_region_name
+    ec2 = boto.client("ec2", region_name=aws_region_name)
+
+    current_insts = _get_ec2_instances(cluster_id, ec2)
+    ebs_volumes = _get_ebs_volumes_for_instances(current_insts, ec2)
+
+    write_file = _define_write_file(file_key, filesystem, bucket, write_function)
+
+    write_file(
+        bytes(
+            json.dumps(
+                {
+                    "instances": current_insts,
+                    "instance_timelines": instance_timelines,
+                    "volumes": ebs_volumes,
+                },
+                cls=DefaultDateTimeEncoder,
+            ),
+            "utf-8",
+        )
+    )
+
+    return True
+
+
 def monitor_cluster(
     cluster_id: str,
     polling_period: int = 20,
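
For reference, the new function can also be called on its own, outside the monitoring loop, to take a one-off snapshot of a cluster's instances and EBS volumes. A minimal sketch, where the cluster id is a placeholder and the timelines are assumed to have been gathered elsewhere (e.g. by _monitor_cluster); the optional cluster_report_destination_override accepts "filesystem" and "base_prefix" keys as shown in the hunk above:

    from sync.awsdatabricks import save_cluster_report

    # Placeholder cluster id (the same one used in the tests in this commit).
    # Returns False if the cluster has no spark_context_id or no log destination.
    saved = save_cluster_report(
        "0101-214342-tpi6qdp2",
        instance_timelines=[],  # timelines collected elsewhere, e.g. by _monitor_cluster
    )
    if not saved:
        print("Cluster report was not written")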
@@ -375,7 +432,6 @@ def monitor_cluster(
     _monitor_cluster(
         (log_url, filesystem, bucket, base_prefix),
         cluster_id,
-        spark_context_id,
         polling_period,
         kill_on_termination,
         write_function,
@@ -387,40 +443,26 @@
 def _monitor_cluster(
     cluster_log_destination,
     cluster_id: str,
-    spark_context_id: int,
     polling_period: int,
     kill_on_termination: bool = False,
     write_function=None,
 ) -> None:
-    (log_url, filesystem, bucket, base_prefix) = cluster_log_destination
-    # If the event log destination is just a *bucket* without any sub-path, then we don't want to include
-    # a leading `/` in our Prefix (which will make it so that we never actually find the event log), so
-    # we make sure to re-strip our final Prefix
-    file_key = f"{base_prefix}/sync_data/{spark_context_id}/aws_cluster_info.json".strip("/")
-
     aws_region_name = DB_CONFIG.aws_region_name
     ec2 = boto.client("ec2", region_name=aws_region_name)
 
-    write_file = _define_write_file(file_key, filesystem, bucket, write_function)
-
-    all_inst_by_id = {}
     active_timelines_by_id = {}
     retired_timelines = []
-    recorded_volumes_by_id = {}
 
     while_condition = True
+
     while while_condition:
         try:
             current_insts = _get_ec2_instances(cluster_id, ec2)
-            recorded_volumes_by_id.update(
-                {v["VolumeId"]: v for v in _get_ebs_volumes_for_instances(current_insts, ec2)}
-            )
 
             # Record new (or overwrite) existing instances.
             # Separately record the ids of those that are in the "running" state.
             running_inst_ids = set({})
             for inst in current_insts:
-                all_inst_by_id[inst["InstanceId"]] = inst
                 if inst["State"]["Name"] == "running":
                     running_inst_ids.add(inst["InstanceId"])
@@ -431,24 +473,19 @@ def _monitor_cluster(
             retired_timelines.extend(new_retired_timelines)
             all_timelines = retired_timelines + list(active_timelines_by_id.values())
 
-            write_file(
-                bytes(
-                    json.dumps(
-                        {
-                            "instances": list(all_inst_by_id.values()),
-                            "instance_timelines": all_timelines,
-                            "volumes": list(recorded_volumes_by_id.values()),
-                        },
-                        cls=DefaultDateTimeEncoder,
-                    ),
-                    "utf-8",
-                )
+            save_cluster_report(
+                cluster_id,
+                all_timelines,
+                cluster_log_destination=cluster_log_destination,
+                write_function=write_function,
             )
 
             if kill_on_termination:
                 cluster_state = get_default_client().get_cluster(cluster_id).get("state")
+
                 if cluster_state == "TERMINATED":
                     while_condition = False
+
         except Exception as e:
             logger.error(f"Exception encountered while polling cluster: {e}")
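With report persistence extracted into save_cluster_report, _monitor_cluster now only tracks instance timelines and delegates each write; it no longer needs spark_context_id, which save_cluster_report looks up from the cluster itself. The public entry point keeps its existing signature, so callers are unaffected. A hedged sketch with placeholder arguments (remaining parameters left at their defaults):

    from sync.awsdatabricks import monitor_cluster

    # Poll the placeholder cluster every 20 seconds; the cluster log destination and
    # spark_context_id are resolved internally rather than passed by the caller.
    monitor_cluster("0101-214342-tpi6qdp2", polling_period=20)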

tests/test_awsdatabricks.py

Lines changed: 0 additions & 3 deletions
@@ -41,7 +41,6 @@ def test_monitor_cluster_with_override(
     mock_monitor_cluster.assert_called_with(
         expected_log_destination_override,
         "0101-214342-tpi6qdp2",
-        1443449481634833945,
         1,
         True,
         None,
@@ -53,7 +52,6 @@
     mock_monitor_cluster.assert_called_with(
         expected_log_destination_override,
         "0101-214342-tpi6qdp2",
-        1443449481634833945,
         1,
         True,
         None,
@@ -77,7 +75,6 @@ def test_monitor_cluster_without_override(
     mock_monitor_cluster.assert_called_with(
         mock_cluster_log_destination.return_value,
         "0101-214342-tpi6qdp2",
-        1443449481634833945,
         1,
         False,
         None,
