@@ -2,7 +2,7 @@
import logging
from pathlib import Path
from time import sleep
- from typing import Generator, List, Tuple
+ from typing import Dict, Generator, List, Optional, Tuple
from urllib.parse import urlparse

import boto3 as boto
@@ -84,6 +84,7 @@
    "wait_for_run_and_cluster",
    "terminate_cluster",
    "apply_project_recommendation",
+     "save_cluster_report",
]

logger = logging.getLogger(__name__)
@@ -348,6 +349,62 @@ def _get_aws_cluster_info_from_s3(bucket: str, file_key: str, cluster_id):
        logger.warning(f"Failed to retrieve cluster info from S3 with key, '{file_key}': {err}")


+ def save_cluster_report(
+     cluster_id: str,
+     instance_timelines: List[dict],
+     cluster_log_destination: Optional[Tuple[str, ...]] = None,
+     cluster_report_destination_override: Optional[Dict[str, str]] = None,
+     write_function=None,
+ ) -> bool:
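+     """Write a point-in-time report of this cluster's EC2 instances, EBS volumes, and the
+     supplied instance timelines as JSON to the cluster's log destination (or the override
+     destination). Returns False if the cluster has no Spark context or no destination can
+     be resolved, and True once the report has been written."""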
+     cluster = get_default_client().get_cluster(cluster_id)
+     spark_context_id = cluster.get("spark_context_id")
+
+     if not spark_context_id:
+         return False
+
+     cluster_log_destination = cluster_log_destination or _cluster_log_destination(cluster)
+     (log_url, filesystem, bucket, base_prefix) = cluster_log_destination
+
+     if cluster_report_destination_override:
+         filesystem = cluster_report_destination_override.get("filesystem", filesystem)
+         base_prefix = cluster_report_destination_override.get("base_prefix", base_prefix)
+
+     if not log_url and not cluster_report_destination_override:
+         logger.warning(
+             "Unable to save cluster report due to missing cluster log destination - exiting"
+         )
+         return False
+
+     # If the event log destination is just a *bucket* without any sub-path, then we don't want to include
+     # a leading `/` in our Prefix (which will make it so that we never actually find the event log), so
+     # we make sure to re-strip our final Prefix
+     file_key = f"{base_prefix}/sync_data/{spark_context_id}/aws_cluster_info.json".strip("/")
+
+     aws_region_name = DB_CONFIG.aws_region_name
+     ec2 = boto.client("ec2", region_name=aws_region_name)
+
+     current_insts = _get_ec2_instances(cluster_id, ec2)
+     ebs_volumes = _get_ebs_volumes_for_instances(current_insts, ec2)
+
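+     # Resolve the write callable for the destination (file_key on the given filesystem/bucket,
+     # or the caller-supplied write_function)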
+     write_file = _define_write_file(file_key, filesystem, bucket, write_function)
+
+     write_file(
+         bytes(
+             json.dumps(
+                 {
+                     "instances": current_insts,
+                     "instance_timelines": instance_timelines,
+                     "volumes": ebs_volumes,
+                 },
+                 cls=DefaultDateTimeEncoder,
+             ),
+             "utf-8",
+         )
+     )
+
+     return True
+
+
def monitor_cluster(
    cluster_id: str,
    polling_period: int = 20,
@@ -375,7 +432,6 @@ def monitor_cluster(
        _monitor_cluster(
            (log_url, filesystem, bucket, base_prefix),
            cluster_id,
-             spark_context_id,
            polling_period,
            kill_on_termination,
            write_function,
@@ -387,40 +443,26 @@ def monitor_cluster(
def _monitor_cluster(
    cluster_log_destination,
    cluster_id: str,
-     spark_context_id: int,
    polling_period: int,
    kill_on_termination: bool = False,
    write_function=None,
) -> None:
-     (log_url, filesystem, bucket, base_prefix) = cluster_log_destination
-     # If the event log destination is just a *bucket* without any sub-path, then we don't want to include
-     # a leading `/` in our Prefix (which will make it so that we never actually find the event log), so
-     # we make sure to re-strip our final Prefix
-     file_key = f"{base_prefix}/sync_data/{spark_context_id}/aws_cluster_info.json".strip("/")
-
    aws_region_name = DB_CONFIG.aws_region_name
    ec2 = boto.client("ec2", region_name=aws_region_name)

-     write_file = _define_write_file(file_key, filesystem, bucket, write_function)
-
-     all_inst_by_id = {}
    active_timelines_by_id = {}
    retired_timelines = []
-     recorded_volumes_by_id = {}

    while_condition = True
+
    while while_condition:
        try:
            current_insts = _get_ec2_instances(cluster_id, ec2)
-             recorded_volumes_by_id.update(
-                 {v["VolumeId"]: v for v in _get_ebs_volumes_for_instances(current_insts, ec2)}
-             )

            # Record new (or overwrite) existing instances.
            # Separately record the ids of those that are in the "running" state.
            running_inst_ids = set({})
            for inst in current_insts:
-                 all_inst_by_id[inst["InstanceId"]] = inst
                if inst["State"]["Name"] == "running":
                    running_inst_ids.add(inst["InstanceId"])

@@ -431,24 +473,19 @@ def _monitor_cluster(
            retired_timelines.extend(new_retired_timelines)
            all_timelines = retired_timelines + list(active_timelines_by_id.values())

-             write_file(
-                 bytes(
-                     json.dumps(
-                         {
-                             "instances": list(all_inst_by_id.values()),
-                             "instance_timelines": all_timelines,
-                             "volumes": list(recorded_volumes_by_id.values()),
-                         },
-                         cls=DefaultDateTimeEncoder,
-                     ),
-                     "utf-8",
-                 )
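+             # Persist a fresh report on every polling cycle; save_cluster_report re-queries the
+             # instances and volumes itself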
+             save_cluster_report(
+                 cluster_id,
+                 all_timelines,
+                 cluster_log_destination=cluster_log_destination,
+                 write_function=write_function,
            )

            if kill_on_termination:
                cluster_state = get_default_client().get_cluster(cluster_id).get("state")
+
                if cluster_state == "TERMINATED":
                    while_condition = False
+
        except Exception as e:
            logger.error(f"Exception encountered while polling cluster: {e}")
