Commit 60803c3

[PROD-1103] Access report (#31)
1 parent a77cbf7 commit 60803c3

File tree: 9 files changed, +354 additions, -30 deletions

sync/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 """Library for leveraging the power of Sync"""
-__version__ = "0.1.0"
+__version__ = "0.1.1"
 
 TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

sync/api/__init__.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+from sync.clients.sync import get_default_client
+from sync.models import AccessReport, AccessReportLine, AccessStatusCode
+
+
+def get_access_report() -> AccessReport:
+    """Reports access to the Sync API
+
+    :return: access report
+    :rtype: AccessReport
+    """
+    response = get_default_client().get_products()
+
+    error = response.get("error")
+    if error:
+        return AccessReport(
+            [
+                AccessReportLine(
+                    name="Sync Authentication",
+                    status=AccessStatusCode.RED,
+                    message=f"{error['code']}: {error['message']}",
+                )
+            ]
+        )
+
+    return AccessReport(
+        [
+            AccessReportLine(
+                name="Sync Authentication",
+                status=AccessStatusCode.GREEN,
+                message="Sync credentials are valid",
+            )
+        ]
+    )
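
As a quick illustration of how this new API-level check might be consumed (a minimal sketch: it assumes AccessReport can be iterated over its AccessReportLines, which depends on the models change not shown in this section):

from sync.api import get_access_report
from sync.models import AccessStatusCode

# Run the Sync API credential check and surface each report line.
report = get_access_report()
for line in report:  # assumed: AccessReport iterates over its AccessReportLines
    print(f"{line.name}: {line.status} - {line.message}")

# Treat any RED line as a hard failure, e.g. in a setup script.
if any(line.status == AccessStatusCode.RED for line in report):
    raise SystemExit("Sync API access check failed")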

sync/awsdatabricks.py

Lines changed: 116 additions & 16 deletions
@@ -12,14 +12,18 @@
 from urllib.parse import urlparse
 
 import boto3 as boto
+import orjson
 from botocore.exceptions import ClientError
-from orjson import orjson
 
+from sync.api import get_access_report as get_api_access_report
 from sync.api.predictions import create_prediction_with_eventlog_bytes, get_prediction
 from sync.api.projects import get_project
 from sync.clients.databricks import get_default_client
 from sync.config import CONFIG, DB_CONFIG
 from sync.models import (
+    AccessReport,
+    AccessReportLine,
+    AccessStatusCode,
     DatabricksAPIError,
     DatabricksClusterReport,
     DatabricksError,
@@ -31,6 +35,100 @@
 logger = logging.getLogger(__name__)
 
 
+def get_access_report(log_url: str = None) -> AccessReport:
+    """Reports access to Databricks, AWS and Sync required for integrating jobs with Sync.
+    Access is partially determined by the configuration of this library and boto3.
+
+    :param log_url: location of event logs, defaults to None
+    :type log_url: str, optional
+    :return: access report
+    :rtype: AccessReport
+    """
+    report = get_api_access_report()
+    dbx_client = get_default_client()
+
+    response = dbx_client.get_current_user()
+    user_name = response.get("userName")
+    if user_name:
+        report.append(
+            AccessReportLine(
+                name="Databricks Authentication",
+                status=AccessStatusCode.GREEN,
+                message=f"Authenticated as '{user_name}'",
+            )
+        )
+    else:
+        report.append(
+            AccessReportLine(
+                name="Databricks Authentication",
+                status=AccessStatusCode.RED,
+                message=f"{response.get('error_code')}: {response.get('message')}",
+            )
+        )
+
+    response = boto.client("sts").get_caller_identity()
+    arn = response.get("Arn")
+    if arn:
+        report.append(
+            AccessReportLine(
+                name="AWS Authentication",
+                status=AccessStatusCode.GREEN,
+                message=f"Authenticated as '{arn}'",
+            )
+        )
+
+        ec2 = boto.client("ec2", region_name=DB_CONFIG.aws_region_name)
+        report.add_boto_method_call(ec2.describe_instances, AccessStatusCode.YELLOW, DryRun=True)
+    else:
+        report.append(
+            AccessReportLine(
+                name="AWS Authentication",
+                status=AccessStatusCode.RED,
+                message="Failed to authenticate AWS credentials",
+            )
+        )
+
+    if log_url:
+        parsed_log_url = urlparse(log_url)
+
+        if parsed_log_url.scheme == "s3" and arn:
+            s3 = boto.client("s3")
+            report.add_boto_method_call(
+                s3.list_objects_v2,
+                Bucket=parsed_log_url.netloc,
+                Prefix=parsed_log_url.params.rstrip("/"),
+                MaxKeys=1,
+            )
+        elif parsed_log_url.scheme == "dbfs":
+            response = dbx_client.list_dbfs_directory(parsed_log_url.geturl())
+            if "error_code" not in response:
+                report.append(
+                    AccessReportLine(
+                        name="Log Access",
+                        status=AccessStatusCode.GREEN,
+                        message=f"Can list objects at {parsed_log_url.geturl()}",
+                    )
+                )
+            else:
+                report.append(
+                    AccessReportLine(
+                        name="Log Access",
+                        status=AccessStatusCode.RED,
+                        message=f"Cannot list objects at {parsed_log_url.geturl()}",
+                    )
+                )
+        else:
+            report.append(
+                AccessReportLine(
+                    name="Log Access",
+                    status=AccessStatusCode.RED,
+                    message=f"scheme in {parsed_log_url.geturl()} is not supported",
+                )
+            )
+
+    return report
+
+
 def create_prediction(
     plan_type: str,
     compute_type: str,
@@ -314,15 +412,18 @@ def _get_cluster_instances(cluster: dict) -> Response[dict]:
     # are associated with this cluster
     if cluster_instances is None:
         ec2 = boto.client("ec2", region_name=aws_region_name)
-        cluster_instances = ec2.describe_instances(
-            Filters=[
-                {"Name": "tag:Vendor", "Values": ["Databricks"]},
-                {"Name": "tag:ClusterId", "Values": [cluster_id]},
-                # {'Name': 'tag:JobId', 'Values': []}
-            ]
-        )
+        try:
+            cluster_instances = ec2.describe_instances(
+                Filters=[
+                    {"Name": "tag:Vendor", "Values": ["Databricks"]},
+                    {"Name": "tag:ClusterId", "Values": [cluster_id]},
+                    # {'Name': 'tag:JobId', 'Values': []}
+                ]
+            )
+        except Exception as exc:
+            logger.warning(exc)
 
-    if not cluster_instances["Reservations"]:
+    if not cluster_instances or not cluster_instances["Reservations"]:
         no_instances_message = (
             f"Unable to find any active or recently terminated instances for cluster `{cluster_id}` in `{aws_region_name}`. "
             + "Please refer to the following documentation for options on how to address this - "
@@ -1179,10 +1280,10 @@ def _get_eventlog_from_dbfs(
     run_end_time_millis: int,
     poll_duration_seconds: int,
 ):
-    prefix = format_dbfs_filepath(f"{base_filepath}/eventlog/")
-
-    root_dir = get_default_client().list_dbfs_directory(prefix)
+    dbx_client = get_default_client()
 
+    prefix = format_dbfs_filepath(f"{base_filepath}/eventlog/")
+    root_dir = dbx_client.list_dbfs_directory(prefix)
     eventlog_files = [f for f in root_dir["files"] if f["is_dir"]]
     matching_subdirectory = None
 
@@ -1195,7 +1296,7 @@ def _get_eventlog_from_dbfs(
         eventlog_file_metadata = eventlog_files.pop()
         path = eventlog_file_metadata["path"]
 
-        subdir = get_default_client().list_dbfs_directory(path)
+        subdir = dbx_client.list_dbfs_directory(path)
 
         subdir_files = subdir["files"]
         matching_subdirectory = next(
@@ -1204,7 +1305,7 @@ def _get_eventlog_from_dbfs(
         )
 
     if matching_subdirectory:
-        eventlog_dir = get_default_client().list_dbfs_directory(matching_subdirectory["path"])
+        eventlog_dir = dbx_client.list_dbfs_directory(matching_subdirectory["path"])
 
         poll_num_attempts = 0
         poll_max_attempts = 20  # 5 minutes / 15 seconds = 20 attempts
@@ -1219,14 +1320,13 @@ def _get_eventlog_from_dbfs(
             )
             sleep(poll_duration_seconds)
 
-            eventlog_dir = get_default_client().list_dbfs_directory(matching_subdirectory["path"])
+            eventlog_dir = dbx_client.list_dbfs_directory(matching_subdirectory["path"])
             poll_num_attempts += 1
 
     eventlog_zip = io.BytesIO()
     eventlog_zip_file = zipfile.ZipFile(eventlog_zip, "a", zipfile.ZIP_DEFLATED)
 
     eventlog_files = eventlog_dir["files"]
-    dbx_client = get_default_client()
     for eventlog_file_metadata in eventlog_files:
         filename: str = eventlog_file_metadata["path"].split("/")[-1]
         filesize: int = eventlog_file_metadata["file_size"]
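
To exercise the Databricks report above end to end, a caller with Databricks and AWS credentials configured could do something like this (a sketch only; the dbfs:/cluster-logs/eventlog value is a placeholder, not a path from this change):

from sync.awsdatabricks import get_access_report

# Checks Sync API, Databricks, and AWS credentials, a DryRun EC2 DescribeInstances
# call, and listing of the event log location when a log URL is supplied.
report = get_access_report(log_url="dbfs:/cluster-logs/eventlog")
print(report)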

sync/awsemr.py

Lines changed: 96 additions & 2 deletions
@@ -7,25 +7,119 @@
 import logging
 import re
 from copy import deepcopy
+from typing import Tuple
 from urllib.parse import urlparse
 from uuid import uuid4
 
 import boto3 as boto
 import orjson
 from dateutil.parser import parse as dateparse
-from typing import Tuple
 
 from sync import TIME_FORMAT
+from sync.api import get_access_report as get_api_access_report
 from sync.api.predictions import create_prediction, wait_for_prediction
 from sync.api.projects import get_project
-from sync.models import EMRError, Platform, ProjectError, Response
+from sync.models import (
+    AccessReport,
+    AccessReportLine,
+    AccessStatusCode,
+    EMRError,
+    Platform,
+    ProjectError,
+    Response,
+)
 
 logger = logging.getLogger(__name__)
 
 
 RUN_DIR_PATTERN_TEMPLATE = r"{project_prefix}/{project_id}/(?P<timestamp>\d{{4}}-[^/]+)/{run_id}"
 
 
+def get_access_report(
+    log_url: str = None, cluster_id: str = None, region_name: str = None
+) -> AccessReport:
+    """Reports access to systems required for integration of EMR jobs with Sync.
+    Access is partially determined by the configuration of this library and boto3.
+
+    :param log_url: location of event logs, defaults to None
+    :type log_url: str, optional
+    :param cluster_id: cluster ID with which to test EMR access, defaults to None
+    :type cluster_id: str, optional
+    :param region_name: region override, defaults to None
+    :type region_name: str, optional
+    :return: access report
+    :rtype: AccessReport
+    """
+    report = get_api_access_report()
+    sts = boto.client("sts")
+    response = sts.get_caller_identity()
+
+    arn = response.get("Arn")
+    if not arn:
+        report.append(
+            AccessReportLine(
+                name="AWS Authentication",
+                status=AccessStatusCode.RED,
+                message="Failed to authenticate AWS credentials",
+            )
+        )
+    else:
+        report.append(
+            AccessReportLine(
+                name="AWS Authentication",
+                status=AccessStatusCode.GREEN,
+                message=f"Authenticated as '{arn}'",
+            )
+        )
+
+    if log_url:
+        parsed_log_url = urlparse(log_url)
+
+        if parsed_log_url.scheme == "s3" and arn:
+            s3 = boto.client("s3")
+            report.add_boto_method_call(
+                s3.list_objects_v2,
+                Bucket=parsed_log_url.netloc,
+                Prefix=parsed_log_url.params.rstrip("/"),
+                MaxKeys=1,
+            )
+        else:
+            report.append(
+                AccessReportLine(
+                    name="Logging",
+                    status=AccessStatusCode.RED,
+                    message=f"scheme in {parsed_log_url.geturl()} is not supported",
+                )
+            )
+
+    if arn and cluster_id:
+        emr = boto.client("emr", region_name=region_name)
+
+        try:
+            response = emr.describe_cluster(ClusterId=cluster_id)
+            report.append(
+                AccessReportLine(
+                    "EMR DescribeCluster",
+                    AccessStatusCode.GREEN,
+                    "describe_cluster call succeeded",
+                )
+            )
+
+            if response["Cluster"]["InstanceCollectionType"] == "INSTANCE_FLEET":
+                report.add_boto_method_call(emr.list_instance_fleets, ClusterId=cluster_id)
+            elif response["Cluster"]["InstanceCollectionType"] == "INSTANCE_GROUP":
+                report.add_boto_method_call(emr.list_instance_groups, ClusterId=cluster_id)
+
+            report.add_boto_method_call(emr.list_bootstrap_actions, ClusterId=cluster_id)
+            report.add_boto_method_call(emr.list_instances, ClusterId=cluster_id)
+            report.add_boto_method_call(emr.list_steps, ClusterId=cluster_id)
+
+        except Exception as exc:
+            report.append(AccessReportLine("EMR DescribeCluster", AccessStatusCode.RED, str(exc)))
+
+    return report
+
+
 def get_project_job_flow(job_flow: dict, project_id: str) -> Response[dict]:
     """Returns a copy of the incoming job flow with project configuration.
 
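
The EMR variant can be exercised the same way; the cluster ID, bucket, and region below are placeholders (a sketch assuming boto3 credentials are already configured):

from sync.awsemr import get_access_report

# Checks Sync API and AWS credentials, S3 access to the log URL, and a set of
# read-only EMR calls (DescribeCluster, ListInstances, ListSteps, etc.).
report = get_access_report(
    log_url="s3://my-bucket/emr-logs/",
    cluster_id="j-XXXXXXXXXXXXX",
    region_name="us-east-1",
)
print(report)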

sync/cli/awsdatabricks.py

Lines changed: 8 additions & 2 deletions
@@ -1,5 +1,5 @@
 import click
-from orjson import orjson
+import orjson
 
 from sync import awsdatabricks
 from sync.cli.util import validate_project
@@ -10,7 +10,13 @@
 @click.group
 def aws_databricks():
     """Databricks on AWS commands"""
-    pass
+
+
+@aws_databricks.command
+@click.option("--log-url")
+def access_report(log_url: str = None):
+    """Get access report"""
+    click.echo(awsdatabricks.get_access_report(log_url))
 
 
 @aws_databricks.command
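
The new subcommand can be smoke-tested without an installed console script using click's test runner (a sketch; it relies on click mapping the access_report function name to an access-report command, and the log URL is a placeholder):

from click.testing import CliRunner

from sync.cli.awsdatabricks import aws_databricks

# Invoke the group directly, as the installed CLI would.
result = CliRunner().invoke(
    aws_databricks, ["access-report", "--log-url", "dbfs:/cluster-logs/eventlog"]
)
print(result.output)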

sync/cli/awsemr.py

Lines changed: 9 additions & 0 deletions
@@ -17,6 +17,15 @@ def aws_emr():
     pass
 
 
+@aws_emr.command
+@click.option("--cluster-id")
+@click.option("--log-url")
+@click.option("-r", "--region")
+def access_report(cluster_id: str = None, log_url: str = None, region: str = None):
+    """Get access report"""
+    click.echo(awsemr.get_access_report(log_url=log_url, cluster_id=cluster_id, region_name=region))
+
+
 @aws_emr.command
 @click.argument("job-flow", type=click.File("r"))
 @click.option("-p", "--project", callback=validate_project)
