Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@ boto3>=1.26.0
redis>=4.5.5
hiredis
libnacl>=2.1.0
azure-storage-blob>=12.9.0
azure-identity>=1.15.0
setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability
173 changes: 169 additions & 4 deletions tensorizer/stream_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@
import time
import typing
import weakref
from datetime import datetime, timedelta, timezone
from io import SEEK_CUR, SEEK_END, SEEK_SET
from typing import Any, Dict, List, Optional, Tuple, Union
from tempfile import _TemporaryFileWrapper
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse

import azure.core.exceptions
import boto3
import botocore
import redis
from azure.identity import DefaultAzureCredential
from azure.storage.blob import (
BlobSasPermissions,
BlobServiceClient,
generate_blob_sas,
)

import tensorizer._version as _version
import tensorizer._wide_pipes as _wide_pipes
Expand Down Expand Up @@ -59,6 +68,13 @@ class _ParsedCredentials(typing.NamedTuple):
s3_secret_key: Optional[str]


class _ParsedAzureCredentials(typing.NamedTuple):
account_name: Optional[str]
account_key: Optional[str]
endpoint: Optional[str]
protocol: Optional[str]


@functools.lru_cache(maxsize=None)
def _get_s3cfg_values(
config_paths: Optional[
Expand Down Expand Up @@ -970,6 +986,29 @@ def s3_upload(
client.upload_file(path, bucket, key)


def azure_upload(
    path: str,
    azure_storage_account_name: str,
    azure_container: str,
    azure_blob: str,
    azure_credentials: Optional[Any] = None,
):
    """
    Upload a local file to a blob in Azure Blob Storage.

    The target container is created if it does not already exist,
    and any existing blob with the same name is overwritten.

    Args:
        path: Path of the local file to upload.
        azure_storage_account_name: Name of the storage account. The service
            URL is built as ``https://<name>.blob.core.windows.net``.
        azure_container: Name of the container to upload into.
        azure_blob: Name of the destination blob inside the container.
        azure_credentials: Credential object passed to ``BlobServiceClient``.
            If ``None``, a new ``DefaultAzureCredential`` is created
            at call time.

    Raises:
        OSError: If ``path`` cannot be opened for reading.
        azure.core.exceptions.AzureError: Presumably raised for service
            failures other than the container already existing
            -- confirm against the Azure SDK documentation.
    """
    # Build the default credential lazily. A default argument of
    # ``DefaultAzureCredential()`` would be evaluated once at import time,
    # making credential construction a side effect of importing this module
    # and sharing a single instance between all calls.
    if azure_credentials is None:
        azure_credentials = DefaultAzureCredential()

    blob_service_client = BlobServiceClient(
        account_url=f"https://{azure_storage_account_name}.blob.core.windows.net",
        credential=azure_credentials,
    )

    # Create the container unconditionally and treat "already exists"
    # as success, rather than checking for existence first.
    try:
        blob_service_client.create_container(azure_container)
    except azure.core.exceptions.ResourceExistsError:
        pass

    blob_client = blob_service_client.get_blob_client(
        container=azure_container, blob=azure_blob
    )
    with open(path, "rb") as data:
        # Without overwrite=True, upload_blob raises ResourceExistsError
        # when the blob already exists, so rewriting an existing
        # object would fail after all the data had been produced.
        blob_client.upload_blob(data, overwrite=True)


def _s3_download_url(
path_uri: str,
s3_access_key_id: str,
Expand Down Expand Up @@ -1060,6 +1099,45 @@ def s3_download(
)


def azure_download(
    account_name: str,
    container: str,
    blob: str,
    azure_credentials: DefaultAzureCredential,
    buffer_size: Optional[int] = None,
    force_http: bool = False,
    begin: Optional[int] = None,
    end: Optional[int] = None,
):
    """
    Open a streaming download of an Azure Blob Storage blob.

    A user delegation key is requested with the given credentials and
    used to sign a short-lived, read-only SAS token for the blob.
    The resulting pre-signed URL is then opened as a ``CURLStreamFile``.

    Args:
        account_name: Name of the storage account. The service URL is built
            as ``https://<name>.blob.core.windows.net``.
        container: Name of the container holding the blob.
        blob: Name of the blob inside the container.
        azure_credentials: Credential object passed to ``BlobServiceClient``
            to obtain the user delegation key.
        buffer_size: Forwarded to ``CURLStreamFile``.
        force_http: If true, the download URL uses ``http://``
            instead of ``https://``. The delegation key request
            itself always uses HTTPS.
        begin: Forwarded to ``CURLStreamFile``.
        end: Forwarded to ``CURLStreamFile``.

    Returns:
        A ``CURLStreamFile`` reading from the pre-signed blob URL.
    """
    # Imported locally because only ``urlparse`` is imported at file level.
    from urllib.parse import quote

    # Backdate the validity window so that a small clock difference between
    # this machine and the storage service cannot make the freshly issued
    # token "not yet valid". The token still expires 15 minutes from now.
    now = datetime.now(timezone.utc)
    start_time: datetime = now - timedelta(minutes=15)
    expiry_time: datetime = now + timedelta(minutes=15)

    bsc = BlobServiceClient(
        account_url=f"https://{account_name}.blob.core.windows.net",
        credential=azure_credentials,
    )
    key = bsc.get_user_delegation_key(start_time, expiry_time)
    sas = generate_blob_sas(
        account_name=account_name,
        container_name=container,
        blob_name=blob,
        user_delegation_key=key,
        permission=BlobSasPermissions(read=True),
        start=start_time,
        expiry=expiry_time,
    )

    # The SAS is signed over the literal container and blob names, so the
    # URL path must percent-encode those names. Otherwise characters such as
    # spaces, "?", "#", or "%" produce a URL that is malformed or that
    # resolves to a different name than the one that was signed.
    # "/" is left intact because blob names may contain virtual directories.
    blob_path = quote(f"{container}/{blob}", safe="/")
    scheme = "http" if force_http else "https"
    url = f"{scheme}://{account_name}.blob.core.windows.net/{blob_path}?{sas}"

    return CURLStreamFile(
        url,
        buffer_size=buffer_size,
        begin=begin,
        end=end,
    )


def _infer_credentials(
s3_access_key_id: Optional[str],
s3_secret_access_key: Optional[str],
Expand Down Expand Up @@ -1151,7 +1229,9 @@ def _infer_credentials(
)


def _temp_file_closer(file: io.IOBase, file_name: str, *upload_args):
def _temp_file_closer(
file: io.IOBase, file_name: str, callback_fn: typing.Callable, *upload_args
):
"""
Close, upload by name, and then delete the file.
Meant to be placed as a hook before both .close() and .__exit__()
Expand Down Expand Up @@ -1180,7 +1260,7 @@ def _temp_file_closer(file: io.IOBase, file_name: str, *upload_args):

try:
file.close()
s3_upload(file_name, *upload_args)
callback_fn(file_name, *upload_args)
finally:
try:
os.unlink(file_name)
Expand All @@ -1203,7 +1283,7 @@ def open_stream(
s3_region_name: Optional[str] = None,
s3_signature_version: Optional[str] = None,
certificate_handling: Optional[CAInfo] = None,
) -> Union[CURLStreamFile, RedisStreamFile, typing.BinaryIO]:
) -> Union[CURLStreamFile, RedisStreamFile, _TemporaryFileWrapper, BinaryIO]:
"""
Open a file path, http(s):// URL, or s3:// URI.

Expand Down Expand Up @@ -1343,6 +1423,90 @@ def open_stream(
path_uri, buffer_size=buffer_size, begin=begin, end=end
)

elif not local_only and scheme == "azure":
if normalized_mode not in ("br", "bw", "ab", "+bw", "+ab"):
raise ValueError(
'Only the modes "rb", "wb[+]", and "ab[+]" are valid'
" when opening azure:// streams."
)
is_azure_upload = "w" in mode or "a" in mode
try:
azure_credentials = DefaultAzureCredential()
except Exception as e:
raise ValueError(
"Failed to get Azure credentials. "
"Please ensure that you have the Azure Identity library installed."
) from e

# Parse our path_uri
uri_components = urlparse(path_uri)
account_name = uri_components.netloc
path_elements = uri_components.path.split("/")
if len(path_elements) < 2:
raise ValueError(f"Invalid Azure URI: {path_uri}")
container = uri_components.path.split("/")[1]
blob = "/".join(uri_components.path.split("/")[2:])

if is_azure_upload:
temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)

# Attach a callback to upload the temporary file when it closes.
# weakref finalizers are idempotent, so this upload callback
# is guaranteed to run at most once.
guaranteed_closer = weakref.finalize(
temp_file,
_temp_file_closer,
temp_file.file,
temp_file.name,
azure_upload,
account_name,
container,
blob,
azure_credentials,
)

# Always run the close + upload procedure
# before any code from Python's NamedTemporaryFile wrapper.
# It isn't safe to call a bound method from a weakref finalizer,
# but calling a weakref finalizer alongside a bound method
# creates no problems, other than that the code outside the
# finalizer is not guaranteed to be run at any point.
# In this case, the weakref finalizer performs all necessary
# cleanup itself, but the original NamedTemporaryFile methods
# are invoked as well, just in case.
wrapped_close = temp_file.close

def close_wrapper():
guaranteed_closer()
return wrapped_close()

# Python 3.12+ doesn't call NamedTemporaryFile.close() during
# .__exit__(), so it must be wrapped separately.
# Since guaranteed_closer is idempotent, it's fine to call it in
# both methods, even if both are called back-to-back.
wrapped_exit = temp_file.__exit__

def exit_wrapper(exc, value, tb):
guaranteed_closer()
return wrapped_exit(exc, value, tb)

temp_file.close = close_wrapper
temp_file.__exit__ = exit_wrapper

return temp_file
else:
curl_stream_file = azure_download(
account_name,
container,
blob,
azure_credentials,
buffer_size=buffer_size,
begin=begin,
end=end,
force_http=force_http,
)
return curl_stream_file

elif not local_only and scheme == "s3":
if normalized_mode not in ("br", "bw", "ab", "+bw", "+ab"):
raise ValueError(
Expand Down Expand Up @@ -1402,6 +1566,7 @@ def open_stream(
_temp_file_closer,
temp_file.file,
temp_file.name,
s3_upload,
path_uri,
s3_access_key_id,
s3_secret_access_key,
Expand Down