diff --git a/geos_ats_package/geos_ats/baseline_io.py b/geos_ats_package/geos_ats/baseline_io.py index 26b1920b..88d5370a 100644 --- a/geos_ats_package/geos_ats/baseline_io.py +++ b/geos_ats_package/geos_ats/baseline_io.py @@ -6,9 +6,13 @@ import time import requests import pathlib +import ssl from functools import partial from tqdm.auto import tqdm from google.cloud import storage +from google.auth.transport.requests import AuthorizedSession +from google.auth import default + logger = logging.getLogger( 'geos_ats' ) tmpdir = tempfile.TemporaryDirectory() @@ -28,7 +32,28 @@ def file_download_progress( headers: dict, url: str, filename: str ): path = pathlib.Path( filename ).expanduser().resolve() path.parent.mkdir( parents=True, exist_ok=True ) - r = requests.get( url, stream=True, allow_redirects=True, headers=headers ) + certs = ["/usr/local/share/ca-certificates/ADPKI_LLNLROOT.crt.crt", + "/usr/local/share/ca-certificates/DigiCertGlobalCAG2.crt.crt", + "/usr/local/share/ca-certificates/cspca.crt.crt", + "usr/local/share/ca-certificates/ADPKI-11.the-lab.llnl.gov_ADPKI-11.crt.crt", + "/usr/local/share/ca-certificates/ADPKI-12.the-lab.llnl.gov_ADPKI-12.crt.crt", + "/usr/local/share/ca-certificates/ADPKI-13.the-lab.llnl.gov_ADPKI-13.crt.crt", + "/usr/local/share/ca-certificates/ADPKI-14.the-lab.llnl.gov_ADPKI-14.crt.crt", + "/usr/local/share/ca-certificates/ADPKI-15.the-lab.llnl.gov_ADPKI-15.crt.crt", + "/usr/local/share/ca-certificates/ADPKI-16.the-lab.llnl.gov_ADPKI-16.crt.crt"] + + combined_cert_path = "/usr/local/share/ca-certificates/combined.crt" + + logger.info("file name.") + + with open(combined_cert_path, 'w') as outputfile: + for cert in certs: + with open(cert) as infile: + outputfile.write(infile.read()) + outputfile.write("\n") + + + r = requests.get( url, stream=True, allow_redirects=True, headers=headers, cert=combined_cert_path ) if r.status_code != 200: r.raise_for_status() raise RuntimeError( f"Request to {url} returned status code {r.status_code}" ) @@ -47,6 +72,21 @@ def file_download_progress( headers: dict, url: str, filename: str ): for chunk in r.iter_content( chunk_size=128 ): f.write( chunk ) +def create_anonymous_client_with_custom_cert(cert_path): + # Create a custom SSL context + ssl_context = ssl.create_default_context(cafile=cert_path) + + # Obtain default credentials + credentials, project = default() + + # Create an authorized session with the custom SSL context + authed_session = AuthorizedSession(credentials, ssl_context=ssl_context) + + # Initialize the storage client with the custom session + client = storage.Client(credentials=credentials, _http=authed_session) + + return client + def collect_baselines( bucket_name: str, blob_name: str, @@ -71,7 +111,7 @@ def collect_baselines( bucket_name: str, short_blob_name = os.path.basename( blob_name ) # Check to see if the baselines are already downloaded - logger.info( 'Checking for existing baseline files...' ) + logger.info( f'Checking for existing baseline files in {baseline_path}' ) if os.path.isdir( baseline_path ): if os.listdir( baseline_path ): logger.info( f'Target baseline directory already exists: {baseline_path}' ) @@ -113,6 +153,7 @@ def collect_baselines( bucket_name: str, archive_name = '' blob_tar = f'{blob_name}.tar.gz' short_blob_tar = f'{short_blob_name}.tar.gz' + logger.info( f'Checking cache directory ({cache_directory}) for existing baseline named {short_blob_name}' ) if cache_directory and not force_redownload: cache_directory = os.path.abspath( os.path.expanduser( cache_directory ) ) logger.info( f'Checking cache directory ({cache_directory}) for existing baseline...' ) @@ -128,7 +169,8 @@ def collect_baselines( bucket_name: str, archive_name = os.path.join( cache_directory, short_blob_tar ) else: archive_name = os.path.join( baseline_temporary_directory, short_blob_tar ) - + + logger.info( f'bucket_name {bucket_name}' ) if 'https://' in bucket_name: # Download from URL try: @@ -139,10 +181,37 @@ def collect_baselines( bucket_name: str, else: # Download from GCP try: - client = storage.Client.create_anonymous_client() + certs = ["/usr/local/share/ca-certificates/ADPKI_LLNLROOT.crt.crt", + "/usr/local/share/ca-certificates/DigiCertGlobalCAG2.crt.crt", + "/usr/local/share/ca-certificates/cspca.crt.crt", + "usr/local/share/ca-certificates/ADPKI-11.the-lab.llnl.gov_ADPKI-11.crt.crt", + "/usr/local/share/ca-certificates/ADPKI-12.the-lab.llnl.gov_ADPKI-12.crt.crt", + "/usr/local/share/ca-certificates/ADPKI-13.the-lab.llnl.gov_ADPKI-13.crt.crt", + "/usr/local/share/ca-certificates/ADPKI-14.the-lab.llnl.gov_ADPKI-14.crt.crt", + "/usr/local/share/ca-certificates/ADPKI-15.the-lab.llnl.gov_ADPKI-15.crt.crt", + "/usr/local/share/ca-certificates/ADPKI-16.the-lab.llnl.gov_ADPKI-16.crt.crt"] + + combined_cert_path = "/usr/local/share/ca-certificates/combined.crt" + + with open(combined_cert_path, 'w') as outputfile: + for cert in certs: + with open(cert) as infile: + outputfile.write(infile.read()) + outputfile.write("\n") + + os.environ['GRPC_DEFAULT_SSL_ROOTS_FILE_PATH'] = combined_cert_path + + # Obtain default credentials + credentials_path = os.getenv('GOOGLE_APPLICATION_CREDENTIALS') + + # Print the environment variable + logger.info(f"GOOGLE_APPLICATION_CREDENTIALS: {credentials_path}") + + client = create_anonymous_client_with_custom_cert(combined_cert_path) + bucket = client.bucket( bucket_name ) blob = bucket.blob( blob_tar ) - blob.download_to_filename( archive_name ) + blob.download_to_filename( archive_name ) except Exception as e: logger.error( f'Failed to download baseline from GCP ({bucket_name}/{blob_tar})' ) logger.error( repr( e ) )