Skip to content

Commit a745394

Browse files
authored
[PROD-1968] Add monitor_once functionality to library (#116)
* Add monitor_once functionality to library * Track in progress cluster object
1 parent 7c71815 commit a745394

File tree

3 files changed

+67
-10
lines changed

3 files changed

+67
-10
lines changed

sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Library for leveraging the power of Sync"""
22

3-
__version__ = "1.7.0"
3+
__version__ = "1.8.0"
44

55
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

sync/awsdatabricks.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"get_cluster_report",
6565
"get_all_cluster_events",
6666
"monitor_cluster",
67+
"monitor_once",
6768
"create_cluster",
6869
"get_cluster",
6970
"handle_successful_job_run",
@@ -85,7 +86,6 @@
8586
"apply_project_recommendation",
8687
]
8788

88-
8989
logger = logging.getLogger(__name__)
9090

9191

@@ -268,7 +268,6 @@ def _create_cluster_report(
268268

269269

270270
def _load_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict]]:
271-
272271
cluster_info = None
273272
cluster_id = None
274273
cluster_log_dest = _cluster_log_destination(cluster)
@@ -312,7 +311,6 @@ def _load_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict
312311

313312

314313
def _get_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict], Response[dict]]:
315-
316314
aws_region_name = DB_CONFIG.aws_region_name
317315

318316
cluster_info, cluster_id = _load_aws_cluster_info(cluster)
@@ -394,7 +392,6 @@ def _monitor_cluster(
394392
kill_on_termination: bool = False,
395393
write_function=None,
396394
) -> None:
397-
398395
(log_url, filesystem, bucket, base_prefix) = cluster_log_destination
399396
# If the event log destination is just a *bucket* without any sub-path, then we don't want to include
400397
# a leading `/` in our Prefix (which will make it so that we never actually find the event log), so
@@ -458,6 +455,42 @@ def _monitor_cluster(
458455
sleep(polling_period)
459456

460457

458+
def monitor_once(cluster_id: str, in_progress_cluster: dict = None) -> dict:
    """Take one monitoring snapshot of the EC2 instances backing a Databricks cluster.

    Merges the current EC2 view of the cluster into the state accumulated by
    previous calls, so callers can invoke this repeatedly (e.g. on a timer)
    instead of using the blocking monitor loop.

    :param cluster_id: Databricks cluster ID; instances are located via their
        ``Vendor``/``ClusterId`` EC2 tags (see ``_get_ec2_instances``).
    :param in_progress_cluster: state dict returned by a previous call, with
        optional keys ``all_inst_by_id``, ``active_timelines_by_id``,
        ``retired_timelines`` and ``recorded_volumes_by_id``. Defaults to a
        fresh, empty state.
    :return: updated state dict with the same four keys, suitable to be passed
        back in on the next call.
    """
    # NOTE: use a None sentinel rather than a mutable `{}` default argument,
    # which would be shared across calls.
    if in_progress_cluster is None:
        in_progress_cluster = {}

    # `or` (not a plain default) so that explicit None values in the incoming
    # state are also replaced with fresh containers.
    all_inst_by_id = in_progress_cluster.get("all_inst_by_id") or {}
    active_timelines_by_id = in_progress_cluster.get("active_timelines_by_id") or {}
    retired_timelines = in_progress_cluster.get("retired_timelines") or []
    recorded_volumes_by_id = in_progress_cluster.get("recorded_volumes_by_id") or {}

    aws_region_name = DB_CONFIG.aws_region_name
    ec2 = boto.client("ec2", region_name=aws_region_name)

    current_insts = _get_ec2_instances(cluster_id, ec2)
    recorded_volumes_by_id.update(
        {v["VolumeId"]: v for v in _get_ebs_volumes_for_instances(current_insts, ec2)}
    )

    # Record new (or overwrite) existing instances.
    # Separately record the ids of those that are in the "running" state.
    running_inst_ids = set()
    for inst in current_insts:
        all_inst_by_id[inst["InstanceId"]] = inst
        if inst["State"]["Name"] == "running":
            running_inst_ids.add(inst["InstanceId"])

    # Close out timelines for instances that are no longer running and keep
    # timelines for those that still are.
    active_timelines_by_id, new_retired_timelines = _update_monitored_timelines(
        running_inst_ids, active_timelines_by_id
    )
    retired_timelines.extend(new_retired_timelines)

    return {
        "all_inst_by_id": all_inst_by_id,
        "active_timelines_by_id": active_timelines_by_id,
        "retired_timelines": retired_timelines,
        "recorded_volumes_by_id": recorded_volumes_by_id,
    }
492+
493+
461494
def _define_write_file(file_key, filesystem, bucket, write_function):
462495
if filesystem == "lambda":
463496

@@ -499,7 +532,6 @@ def write_file(body: bytes):
499532

500533

501534
def _get_ec2_instances(cluster_id: str, ec2_client: "botocore.client.ec2") -> List[dict]:
502-
503535
filters = [
504536
{"Name": "tag:Vendor", "Values": ["Databricks"]},
505537
{"Name": "tag:ClusterId", "Values": [cluster_id]},

sync/azuredatabricks.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"get_access_report",
6464
"run_and_record_job",
6565
"monitor_cluster",
66+
"monitor_once",
6667
"create_cluster",
6768
"get_cluster",
6869
"create_submission_for_run",
@@ -89,7 +90,6 @@
8990
"apply_project_recommendation",
9091
]
9192

92-
9393
logger = logging.getLogger(__name__)
9494

9595

@@ -287,7 +287,6 @@ def _get_cluster_instances(cluster: dict) -> Response[dict]:
287287
# If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that
288288
# are associated with this cluster
289289
if not cluster_instances:
290-
291290
resource_group_name = _get_databricks_resource_group_name()
292291

293292
compute = _get_azure_client(ComputeManagementClient)
@@ -422,6 +421,34 @@ def _monitor_cluster(
422421
sleep(polling_period)
423422

424423

424+
def monitor_once(cluster_id: str, in_progress_cluster: dict = None) -> dict:
    """Take one monitoring snapshot of the Azure VMs backing a Databricks cluster.

    Merges the currently running VMs for the cluster into the state accumulated
    by previous calls, so callers can invoke this repeatedly instead of using
    the blocking monitor loop.

    :param cluster_id: Databricks cluster ID used to locate the cluster's VMs.
    :param in_progress_cluster: state dict returned by a previous call, with
        optional keys ``all_vms_by_id``, ``active_timelines_by_id`` and
        ``retired_timelines``. Defaults to a fresh, empty state.
    :return: updated state dict with the same three keys, suitable to be passed
        back in on the next call.
    """
    # NOTE: use a None sentinel rather than a mutable `{}` default argument,
    # which would be shared across calls.
    if in_progress_cluster is None:
        in_progress_cluster = {}

    # `or` (not a plain default) so that explicit None values in the incoming
    # state are also replaced with fresh containers.
    all_vms_by_id = in_progress_cluster.get("all_vms_by_id") or {}
    active_timelines_by_id = in_progress_cluster.get("active_timelines_by_id") or {}
    retired_timelines = in_progress_cluster.get("retired_timelines") or []

    resource_group_name = _get_databricks_resource_group_name()
    if not resource_group_name:
        # Best-effort: _get_running_vms_by_id accepts an Optional resource
        # group, so we warn and continue rather than fail the snapshot.
        logger.warning("Failed to find Databricks managed resource group")

    compute = _get_azure_client(ComputeManagementClient)

    running_vms_by_id = _get_running_vms_by_id(compute, resource_group_name, cluster_id)

    # Record new (or overwrite) existing VMs, keyed by VM name.
    for vm in running_vms_by_id.values():
        all_vms_by_id[vm["name"]] = vm

    # Close out timelines for VMs that are no longer running and keep
    # timelines for those that still are.
    active_timelines_by_id, new_retired_timelines = _update_monitored_timelines(
        set(running_vms_by_id.keys()), active_timelines_by_id
    )
    retired_timelines.extend(new_retired_timelines)

    return {
        "all_vms_by_id": all_vms_by_id,
        "active_timelines_by_id": active_timelines_by_id,
        "retired_timelines": retired_timelines,
    }
450+
451+
425452
def _define_write_file(file_key, filesystem, write_function):
426453
if filesystem == "lambda":
427454

@@ -469,7 +496,6 @@ def _get_databricks_resource_group_name() -> str:
469496
_azure_credential = None
470497
_azure_subscription_id = None
471498

472-
473499
AzureClient = TypeVar("AzureClient")
474500

475501

@@ -519,7 +545,6 @@ def _get_azure_subscription_id():
519545
def _get_running_vms_by_id(
520546
compute: AzureClient, resource_group_name: Optional[str], cluster_id: str
521547
) -> Dict[str, dict]:
522-
523548
if resource_group_name:
524549
vms = compute.virtual_machines.list(resource_group_name=resource_group_name)
525550
else:

0 commit comments

Comments
 (0)