From 21850d1d324669cb3ceadd3cec2c0e8f7dc0ab59 Mon Sep 17 00:00:00 2001
From: Wes Brown
Date: Fri, 22 Mar 2024 17:14:58 -0400
Subject: [PATCH] feat(azure): Preliminary support for Azure

---
 requirements.txt        |   2 +
 tensorizer/stream_io.py | 221 +++++++++++++++++++++++++++++++++++++++-
 2 files changed, 217 insertions(+), 6 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index de7a50c1..7786015b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,5 +6,7 @@ boto3>=1.26.0
 redis>=4.5.5
 hiredis
 libnacl>=2.1.0
+azure-storage-blob>=12.9.0
+azure-identity>=1.15.0
 setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
 wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability
diff --git a/tensorizer/stream_io.py b/tensorizer/stream_io.py
index 924b8b48..026c5841 100644
--- a/tensorizer/stream_io.py
+++ b/tensorizer/stream_io.py
@@ -13,13 +13,22 @@
 import time
 import typing
 import weakref
+from datetime import datetime, timedelta, timezone
 from io import SEEK_CUR, SEEK_END, SEEK_SET
-from typing import Any, Dict, List, Optional, Tuple, Union
+from tempfile import _TemporaryFileWrapper
+from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
-from urllib.parse import urlparse
+from urllib.parse import quote, urlparse
 
+import azure.core.exceptions
 import boto3
 import botocore
 import redis
+from azure.identity import DefaultAzureCredential
+from azure.storage.blob import (
+    BlobSasPermissions,
+    BlobServiceClient,
+    generate_blob_sas,
+)
 
 import tensorizer._version as _version
 import tensorizer._wide_pipes as _wide_pipes
@@ -59,6 +68,13 @@ class _ParsedCredentials(typing.NamedTuple):
     s3_secret_key: Optional[str]
 
 
+class _ParsedAzureCredentials(typing.NamedTuple):
+    account_name: Optional[str]
+    account_key: Optional[str]
+    endpoint: Optional[str]
+    protocol: Optional[str]
+
+
 @functools.lru_cache(maxsize=None)
 def _get_s3cfg_values(
     config_paths: Optional[
@@ -970,6 +986,49 @@ def s3_upload(
     client.upload_file(path, bucket, key)
 
 
+def azure_upload(
+    path: str,
+    azure_storage_account_name: str,
+    azure_container: str,
+    azure_blob: str,
+    azure_credentials=None,
+):
+    """
+    Upload a local file to Azure Blob Storage.
+
+    The container is created if it does not already exist, and
+    any existing blob of the same name is overwritten.
+
+    Args:
+        path: Path of the local file to upload.
+        azure_storage_account_name: Name of the storage account.
+        azure_container: Name of the destination container.
+        azure_blob: Name of the destination blob.
+        azure_credentials: Credential to authenticate with.
+            Defaults to a new ``DefaultAzureCredential``.
+    """
+    if azure_credentials is None:
+        # Built per call, not as a default argument value,
+        # so that nothing is instantiated at import time.
+        azure_credentials = DefaultAzureCredential()
+    blob_service_client = BlobServiceClient(
+        account_url=f"https://{azure_storage_account_name}.blob.core.windows.net",
+        credential=azure_credentials,
+    )
+    # Check if the container exists, if not create it
+    try:
+        blob_service_client.create_container(azure_container)
+    except azure.core.exceptions.ResourceExistsError:
+        pass
+    blob_client = blob_service_client.get_blob_client(
+        container=azure_container, blob=azure_blob
+    )
+    with open(path, "rb") as data:
+        # Without overwrite=True, upload_blob() raises
+        # ResourceExistsError when the blob already exists.
+        blob_client.upload_blob(data, overwrite=True)
+
+
 def _s3_download_url(
     path_uri: str,
     s3_access_key_id: str,
@@ -1060,6 +1119,69 @@ def s3_download(
     )
 
 
+def azure_download(
+    account_name: str,
+    container: str,
+    blob: str,
+    azure_credentials: DefaultAzureCredential,
+    buffer_size: Optional[int] = None,
+    force_http: bool = False,
+    begin: Optional[int] = None,
+    end: Optional[int] = None,
+):
+    """
+    Open an Azure Blob Storage blob for streaming reads.
+
+    A read-only SAS token valid for 15 minutes is generated from a
+    user delegation key and embedded in the URL that is opened.
+
+    Args:
+        account_name: Name of the storage account.
+        container: Name of the container holding the blob.
+        blob: Name of the blob to read.
+        azure_credentials: Credential used to obtain
+            the user delegation key.
+        buffer_size: Passed through to ``CURLStreamFile``.
+        force_http: If True, request the blob over http://
+            instead of https://.
+        begin: Passed through to ``CURLStreamFile``.
+        end: Passed through to ``CURLStreamFile``.
+
+    Returns:
+        A ``CURLStreamFile`` opened on the signed blob URL.
+    """
+    start_time = datetime.now(timezone.utc)
+    expiry_time: datetime = start_time + timedelta(minutes=15)
+    bsc = BlobServiceClient(
+        account_url=f"https://{account_name}.blob.core.windows.net",
+        credential=azure_credentials,
+    )
+    key = bsc.get_user_delegation_key(start_time, expiry_time)
+    sas = generate_blob_sas(
+        account_name=account_name,
+        container_name=container,
+        blob_name=blob,
+        user_delegation_key=key,
+        permission=BlobSasPermissions(read=True),
+        start=start_time,
+        expiry=expiry_time,
+    )
+    # The SAS is signed over the raw blob name, but the name has
+    # to be percent-encoded when it is placed in the request URL.
+    url = (
+        f"https://{account_name}.blob.core.windows.net"
+        f"/{container}/{quote(blob)}?{sas}"
+    )
+    if force_http and url.lower().startswith("https://"):
+        url = "http://" + url[8:]
+    return CURLStreamFile(
+        url,
+        buffer_size=buffer_size,
+        begin=begin,
+        end=end,
+    )
+
+
 def _infer_credentials(
     s3_access_key_id: Optional[str],
     s3_secret_access_key: Optional[str],
@@ -1151,7 +1273,9 @@ def _infer_credentials(
     )
 
 
-def _temp_file_closer(file: io.IOBase, file_name: str, *upload_args):
+def _temp_file_closer(
+    file: io.IOBase, file_name: str, callback_fn: typing.Callable, *upload_args
+):
     """
     Close, upload by name, and then delete the file.
     Meant to be placed as a hook before both .close() and .__exit__()
@@ -1180,7 +1304,7 @@ def _temp_file_closer(file: io.IOBase, file_name: str, *upload_args):
 
     try:
         file.close()
-        s3_upload(file_name, *upload_args)
+        callback_fn(file_name, *upload_args)
     finally:
         try:
             os.unlink(file_name)
@@ -1203,7 +1327,7 @@ def open_stream(
     s3_region_name: Optional[str] = None,
     s3_signature_version: Optional[str] = None,
     certificate_handling: Optional[CAInfo] = None,
-) -> Union[CURLStreamFile, RedisStreamFile, typing.BinaryIO]:
+) -> Union[CURLStreamFile, RedisStreamFile, _TemporaryFileWrapper, BinaryIO]:
     """
-    Open a file path, http(s):// URL, or s3:// URI.
+    Open a file path, http(s):// URL, s3:// URI, or azure:// URI.
 
@@ -1343,6 +1467,90 @@ def open_stream(
             path_uri, buffer_size=buffer_size, begin=begin, end=end
         )
 
+    elif not local_only and scheme == "azure":
+        if normalized_mode not in ("br", "bw", "ab", "+bw", "+ab"):
+            raise ValueError(
+                'Only the modes "rb", "wb[+]", and "ab[+]" are valid'
+                " when opening azure:// streams."
+            )
+        is_azure_upload = "w" in mode or "a" in mode
+        try:
+            azure_credentials = DefaultAzureCredential()
+        except Exception as e:
+            raise ValueError(
+                "Failed to get Azure credentials. "
+                "Please ensure that you have the Azure Identity library installed."
+            ) from e
+
+        # Parse our path_uri: azure://<account>/<container>/<blob>
+        uri_components = urlparse(path_uri)
+        account_name = uri_components.netloc
+        # A non-empty path starts with "/"; its first segment is the
+        # container, and everything after that is the blob name.
+        container, _sep, blob = uri_components.path[1:].partition("/")
+        if not (account_name and container and blob):
+            raise ValueError(f"Invalid Azure URI: {path_uri}")
+
+        if is_azure_upload:
+            temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)
+
+            # Attach a callback to upload the temporary file when it closes.
+            # weakref finalizers are idempotent, so this upload callback
+            # is guaranteed to run at most once.
+            guaranteed_closer = weakref.finalize(
+                temp_file,
+                _temp_file_closer,
+                temp_file.file,
+                temp_file.name,
+                azure_upload,
+                account_name,
+                container,
+                blob,
+                azure_credentials,
+            )
+
+            # Always run the close + upload procedure
+            # before any code from Python's NamedTemporaryFile wrapper.
+            # It isn't safe to call a bound method from a weakref finalizer,
+            # but calling a weakref finalizer alongside a bound method
+            # creates no problems, other than that the code outside the
+            # finalizer is not guaranteed to be run at any point.
+            # In this case, the weakref finalizer performs all necessary
+            # cleanup itself, but the original NamedTemporaryFile methods
+            # are invoked as well, just in case.
+            wrapped_close = temp_file.close
+
+            def close_wrapper():
+                guaranteed_closer()
+                return wrapped_close()
+
+            # Python 3.12+ doesn't call NamedTemporaryFile.close() during
+            # .__exit__(), so it must be wrapped separately.
+            # Since guaranteed_closer is idempotent, it's fine to call it in
+            # both methods, even if both are called back-to-back.
+            wrapped_exit = temp_file.__exit__
+
+            def exit_wrapper(exc, value, tb):
+                guaranteed_closer()
+                return wrapped_exit(exc, value, tb)
+
+            temp_file.close = close_wrapper
+            temp_file.__exit__ = exit_wrapper
+
+            return temp_file
+        else:
+            curl_stream_file = azure_download(
+                account_name,
+                container,
+                blob,
+                azure_credentials,
+                buffer_size=buffer_size,
+                begin=begin,
+                end=end,
+                force_http=force_http,
+            )
+            return curl_stream_file
+
     elif not local_only and scheme == "s3":
         if normalized_mode not in ("br", "bw", "ab", "+bw", "+ab"):
             raise ValueError(
@@ -1402,6 +1610,7 @@ def open_stream(
                 _temp_file_closer,
                 temp_file.file,
                 temp_file.name,
+                s3_upload,
                 path_uri,
                 s3_access_key_id,
                 s3_secret_access_key,