Skip to content
4 changes: 2 additions & 2 deletions batch/pinned-requirements.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion ci/pinned-requirements.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 16 additions & 5 deletions gear/pinned-requirements.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions hail/python/dev/pinned-requirements.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

177 changes: 172 additions & 5 deletions hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,39 @@
import logging
import os
import urllib.parse
from concurrent.futures.thread import ThreadPoolExecutor
from contextlib import AsyncExitStack
from types import TracebackType
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, MutableMapping, Optional, Set, Tuple, Type, cast
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
List,
MutableMapping,
Optional,
Set,
Tuple,
Type,
cast,
)

import aiohttp
from google.auth.aio.credentials import AnonymousCredentials
from google.cloud.storage import Client, transfer_manager
from google.oauth2 import service_account
from google.oauth2.credentials import Credentials
from multidict import CIMultiDictProxy # pylint: disable=unused-import # pylint: disable=unused-import

from hailtop import timex
from hailtop.aiotools import FeedableAsyncIterable, WriteBuffer
from hailtop.aiocloud.common import AnonymousCloudCredentials
from hailtop.aiotools import (
FeedableAsyncIterable,
LocalAsyncFS,
WeightedSemaphore,
WriteBuffer,
)
from hailtop.aiotools.fs import (
AsyncFS,
AsyncFSFactory,
Expand All @@ -25,10 +49,17 @@
UnexpectedEOFError,
WritableStream,
)
from hailtop.utils import OnlineBoundedGather2, TransientError, retry_transient_errors, secret_alnum_string
from hailtop.utils import (
OnlineBoundedGather2,
TransientError,
async_to_blocking,
blocking_to_async,
retry_transient_errors,
secret_alnum_string,
)

from ...common.session import BaseSession
from ..credentials import GoogleCredentials
from ..credentials import GoogleCredentials, GoogleServiceAccountCredentials
from ..user_config import GCSRequesterPaysConfiguration, get_gcs_requester_pays_configuration
from .base_client import GoogleBaseClient

Expand Down Expand Up @@ -315,14 +346,47 @@


class GoogleStorageClient(GoogleBaseClient):
def __init__(self, gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, **kwargs):
CHUNK_SIZE = 8 * 1024 * 1024
MAX_WORKERS = 8

def __init__(

Check warning on line 352 in hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py#L352

Method __init__ has a cyclomatic complexity of 9 (limit is 8)
self,
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None,
thread_pool: Optional[ThreadPoolExecutor] = None,
**kwargs,
):
if 'timeout' not in kwargs and 'http_session' not in kwargs:
# Around May 2022, GCS started timing out a lot with our default 5s timeout
kwargs['timeout'] = aiohttp.ClientTimeout(total=20)

timeout = kwargs.get('timeout')
if isinstance(timeout, aiohttp.ClientTimeout):
self._timeout = timeout.total
elif isinstance(timeout, (int, float)):
self._timeout = timeout
else:
self._timeout = 20

if not thread_pool:
thread_pool = ThreadPoolExecutor()
self._thread_pool = thread_pool

super().__init__('https://storage.googleapis.com/storage/v1', **kwargs)
self._gcs_requester_pays_configuration = get_gcs_requester_pays_configuration(
gcs_requester_pays_configuration=gcs_requester_pays_configuration
)
credentials = kwargs.get('credentials')
if isinstance(credentials, GoogleServiceAccountCredentials):
gcp_credentials = service_account.Credentials.from_service_account_info(info=credentials.key)
elif isinstance(credentials, GoogleCredentials):
access_token = async_to_blocking(credentials.access_token())
gcp_credentials = Credentials(token=access_token)
elif isinstance(credentials, AnonymousCloudCredentials):
gcp_credentials = AnonymousCredentials()
else:
gcp_credentials = None

self._client = Client(credentials=gcp_credentials)

async def bucket_info(self, bucket: str) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -415,6 +479,44 @@
self._update_params_with_user_project(kwargs, bucket)
return PageIterator(self, f'/b/{bucket}/o', kwargs)

async def download_single_file(self, bucket: str, filename: str, dest: str) -> None:
user_project = self._get_user_project_for_bucket(bucket)
bucket_instance = self._client.bucket(bucket, user_project=user_project)
blob = bucket_instance.blob(filename)

dest_parent = os.path.dirname(dest)
if dest_parent:
if os.path.exists(dest_parent) and not os.path.isdir(dest_parent):
raise NotADirectoryError(dest_parent)
os.makedirs(dest_parent, exist_ok=True)
await blocking_to_async(
self._thread_pool,
blob.download_to_filename,
dest,
single_shot_download=True,
timeout=self._timeout,
)

async def download_large_file(self, bucket: str, src: str, dest: str) -> None:
user_project = self._get_user_project_for_bucket(bucket)
bucket_instance = self._client.bucket(bucket, user_project=user_project)
blob = bucket_instance.blob(src)

dest_parent = os.path.dirname(dest)
if dest_parent:
if os.path.exists(dest_parent) and not os.path.isdir(dest_parent):
raise NotADirectoryError(dest_parent)
os.makedirs(dest_parent, exist_ok=True)
await blocking_to_async(
self._thread_pool,
transfer_manager.download_chunks_concurrently,
blob=blob,
filename=dest,
chunk_size=self.CHUNK_SIZE,
download_kwargs={'timeout': self._timeout},
max_workers=self.MAX_WORKERS,
)

async def compose(self, bucket: str, names: List[str], destination: str, **kwargs) -> None:
assert destination
n = len(names)
Expand Down Expand Up @@ -455,6 +557,18 @@
if bucket in buckets:
params.update({'userProject': project})

def _get_user_project_for_bucket(self, bucket: str) -> Optional[str]:
if isinstance(self._gcs_requester_pays_configuration, str):
return self._gcs_requester_pays_configuration
elif isinstance(self._gcs_requester_pays_configuration, tuple):
project, buckets = self._gcs_requester_pays_configuration
if bucket in buckets:
return project
else:
return None
else:
return None


class GetObjectFileStatus(FileStatus):
def __init__(self, items: Dict[str, str], url: str):
Expand Down Expand Up @@ -646,6 +760,8 @@
if bucket_allow_list is None:
bucket_allow_list = []
self.allowed_storage_locations = bucket_allow_list
self._sema = WeightedSemaphore(self._storage_client.MAX_WORKERS * 10)
self._xfer_sema = WeightedSemaphore(self._storage_client.CHUNK_SIZE * 10)

@staticmethod
def schemes() -> Set[str]:
Expand Down Expand Up @@ -875,6 +991,57 @@
raise FileNotFoundError(url) from e
raise

async def copy_between_fs(
self,
srcfile: str,
srcstat: FileStatus,
destfile: str,
**kwargs,
):
if LocalAsyncFS.valid_url(destfile):
if kwargs.get('sema'):
sema = cast(WeightedSemaphore, kwargs.get('sema'))
else:
sema = self._sema
if kwargs.get('xfer_sema'):
xfer_sema = cast(WeightedSemaphore, kwargs.get('xfer_sema'))
else:
xfer_sema = self._xfer_sema

size = await srcstat.size()
if destfile.startswith('file://'):
local_dest = destfile[len('file://') :]
else:
local_dest = destfile
if size > self._storage_client.CHUNK_SIZE:
async with sema.acquire_manager(self._storage_client.MAX_WORKERS):
await self._copy_single_large_file(xfer_sema, srcfile, local_dest)
else:
await self._copy_single_local_file(xfer_sema, srcfile, local_dest, size)
else:
raise NotImplementedError

async def _copy_single_local_file(
self,
xfer_sema: WeightedSemaphore,
src: str,
dest: str,
size: int,
) -> None:
async with xfer_sema.acquire_manager(size):
bucket, name = self.get_bucket_and_name(src)
await retry_transient_errors(self._storage_client.download_single_file, bucket, name, dest)

async def _copy_single_large_file(
self,
xfer_sema: WeightedSemaphore,
src: str,
dest: str,
) -> None:
async with xfer_sema.acquire_manager(self._storage_client.CHUNK_SIZE * self._storage_client.MAX_WORKERS):
bucket, name = self.get_bucket_and_name(src)
await retry_transient_errors(self._storage_client.download_large_file, bucket, name, dest)

async def copy_within_gcs(
self, src: str, dest: str, callback: Optional[Callable[[Dict[str, Any], bool], None]] = None
) -> None:
Expand Down
Loading
Loading