Skip to content

Commit 21b6200

Browse files
Merge pull request #135 from synccomputingcode/PROD-2276/bump-version-to-1.10.1
[PROD-2276] Move CachedToken class to not break globals and imports
2 parents 5668e6d + de3d551 commit 21b6200

File tree

3 files changed

+110
-100
lines changed

3 files changed

+110
-100
lines changed

sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Library for leveraging the power of Sync"""
22

3-
__version__ = "1.10.1"
3+
__version__ = "1.10.2"
44

55
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

sync/clients/cache.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import logging
2+
from datetime import datetime, timedelta, timezone
3+
from typing import Optional, Tuple, Union, Callable, Type
4+
from pathlib import Path
5+
import json
6+
7+
from platformdirs import user_cache_dir
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class CachedToken:
13+
def __init__(self, token_refresh_before_expiry=timedelta(seconds=30)):
14+
self.token_refresh_before_expiry = token_refresh_before_expiry
15+
16+
cache = self._get_cached_token()
17+
18+
if cache:
19+
self._access_token, self._access_token_expires_at_utc = cache
20+
else:
21+
self._access_token: Optional[str] = None
22+
self._access_token_expires_at_utc: Optional[datetime] = None
23+
24+
@property
25+
def access_token(self) -> Optional[str]:
26+
return self._access_token if self.is_access_token_valid else None
27+
28+
@property
29+
def is_access_token_valid(self) -> bool:
30+
if not self._access_token:
31+
return False
32+
33+
if self._access_token_expires_at_utc:
34+
return datetime.now(tz=timezone.utc) < (
35+
self._access_token_expires_at_utc - self.token_refresh_before_expiry
36+
)
37+
38+
return False
39+
40+
def set_cached_token(self, access_token: str, expires_at_utc: datetime) -> None:
41+
self._access_token = access_token
42+
self._access_token_expires_at_utc = expires_at_utc
43+
self._set_cached_token()
44+
45+
def _set_cached_token(self) -> None:
46+
raise NotImplementedError
47+
48+
def _get_cached_token(self) -> Optional[Tuple[str, datetime]]:
49+
raise NotImplementedError
50+
51+
52+
class FileCachedToken(CachedToken):
53+
def __init__(self):
54+
self._cache_file = Path(user_cache_dir("syncsparkpy")) / "auth.json"
55+
56+
super().__init__()
57+
58+
def _get_cached_token(self) -> Optional[Tuple[str, datetime]]:
59+
# Cache is optional, we can fail to read it and not worry
60+
if self._cache_file.exists():
61+
try:
62+
cached_token = json.loads(self._cache_file.read_text())
63+
cached_access_token = cached_token["access_token"]
64+
cached_expiry = datetime.fromisoformat(cached_token["expires_at_utc"])
65+
return cached_access_token, cached_expiry
66+
except Exception as e:
67+
logger.warning(
68+
f"Failed to read cached access token @ {self._cache_file}", exc_info=e
69+
)
70+
71+
return None
72+
73+
def _set_cached_token(self) -> None:
74+
# Cache is optional, we can fail to read it and not worry
75+
try:
76+
self._cache_file.parent.mkdir(parents=True, exist_ok=True)
77+
self._cache_file.write_text(
78+
json.dumps(
79+
{
80+
"access_token": self._access_token,
81+
"expires_at_utc": self._access_token_expires_at_utc.isoformat(),
82+
}
83+
)
84+
)
85+
except Exception as e:
86+
logger.warning(
87+
f"Failed to write cached access token @ {self._cache_file}", exc_info=e
88+
)
89+
90+
91+
# Putting this here instead of config.py because circular imports and typing.
92+
ACCESS_TOKEN_CACHE_CLS_TYPE = Union[Type[CachedToken], Callable[[], CachedToken]]
93+
_access_token_cache_cls: ACCESS_TOKEN_CACHE_CLS_TYPE = FileCachedToken # Default to local file caching.
94+
95+
96+
def set_access_token_cache_cls(access_token_cache_cls: ACCESS_TOKEN_CACHE_CLS_TYPE) -> None:
97+
global _access_token_cache_cls
98+
_access_token_cache_cls = access_token_cache_cls
99+
100+
101+
def get_access_token_cache_cache() -> ACCESS_TOKEN_CACHE_CLS_TYPE:
102+
global _access_token_cache_cls
103+
return _access_token_cache_cls

sync/clients/sync.py

Lines changed: 6 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,15 @@
1-
import json
21
import logging
3-
from datetime import datetime, timedelta, timezone
4-
from pathlib import Path
5-
from typing import Generator, Optional, Tuple, Type, Callable, Union
2+
from typing import Generator, Optional
63

74
import dateutil.parser
85
import httpx
9-
from platformdirs import user_cache_dir
106

117
from ..config import API_KEY, CONFIG, APIKey
128
from . import USER_AGENT, RetryableHTTPClient, encode_json
13-
9+
from .cache import ACCESS_TOKEN_CACHE_CLS_TYPE, FileCachedToken, get_access_token_cache_cache
1410
logger = logging.getLogger(__name__)
1511

1612

17-
class CachedToken:
18-
def __init__(self, token_refresh_before_expiry=timedelta(seconds=30)):
19-
self.token_refresh_before_expiry = token_refresh_before_expiry
20-
21-
cache = self._get_cached_token()
22-
23-
if cache:
24-
self._access_token, self._access_token_expires_at_utc = cache
25-
else:
26-
self._access_token: Optional[str] = None
27-
self._access_token_expires_at_utc: Optional[datetime] = None
28-
29-
@property
30-
def access_token(self) -> Optional[str]:
31-
return self._access_token if self.is_access_token_valid else None
32-
33-
@property
34-
def is_access_token_valid(self) -> bool:
35-
if not self._access_token:
36-
return False
37-
38-
if self._access_token_expires_at_utc:
39-
return datetime.now(tz=timezone.utc) < (
40-
self._access_token_expires_at_utc - self.token_refresh_before_expiry
41-
)
42-
43-
return False
44-
45-
def set_cached_token(self, access_token: str, expires_at_utc: datetime) -> None:
46-
self._access_token = access_token
47-
self._access_token_expires_at_utc = expires_at_utc
48-
self._set_cached_token()
49-
50-
def _set_cached_token(self) -> None:
51-
raise NotImplementedError
52-
53-
def _get_cached_token(self) -> Optional[Tuple[str, datetime]]:
54-
raise NotImplementedError
55-
56-
57-
class FileCachedToken(CachedToken):
58-
def __init__(self):
59-
self._cache_file = Path(user_cache_dir("syncsparkpy")) / "auth.json"
60-
61-
super().__init__()
62-
63-
def _get_cached_token(self) -> Optional[Tuple[str, datetime]]:
64-
# Cache is optional, we can fail to read it and not worry
65-
if self._cache_file.exists():
66-
try:
67-
cached_token = json.loads(self._cache_file.read_text())
68-
cached_access_token = cached_token["access_token"]
69-
cached_expiry = datetime.fromisoformat(cached_token["expires_at_utc"])
70-
return cached_access_token, cached_expiry
71-
except Exception as e:
72-
logger.warning(
73-
f"Failed to read cached access token @ {self._cache_file}", exc_info=e
74-
)
75-
76-
return None
77-
78-
def _set_cached_token(self) -> None:
79-
# Cache is optional, we can fail to read it and not worry
80-
try:
81-
self._cache_file.parent.mkdir(parents=True, exist_ok=True)
82-
self._cache_file.write_text(
83-
json.dumps(
84-
{
85-
"access_token": self._access_token,
86-
"expires_at_utc": self._access_token_expires_at_utc.isoformat(),
87-
}
88-
)
89-
)
90-
except Exception as e:
91-
logger.warning(
92-
f"Failed to write cached access token @ {self._cache_file}", exc_info=e
93-
)
94-
95-
96-
# Putting this here instead of config.py because circular imports and typing.
97-
_access_token_cache_cls = FileCachedToken # Default to local file caching.
98-
ACCESS_TOKEN_CACHE_CLS_TYPE = Union[Type[CachedToken], Callable[[], CachedToken]]
99-
100-
101-
def set_access_token_cache_cls(access_token_cache_cls: ACCESS_TOKEN_CACHE_CLS_TYPE) -> None:
102-
global _access_token_cache_cls
103-
_access_token_cache_cls = access_token_cache_cls
104-
105-
10613
class SyncAuth(httpx.Auth):
10714
requires_response_body = True
10815

@@ -425,7 +332,7 @@ async def _send(self, request: httpx.Request) -> dict:
425332
return {"error": {"code": "Sync API Error", "message": "Transaction failure"}}
426333

427334

428-
_sync_client: SyncClient = None
335+
_sync_client: Optional[SyncClient] = None
429336

430337

431338
def get_default_client() -> SyncClient:
@@ -434,12 +341,12 @@ def get_default_client() -> SyncClient:
434341
_sync_client = SyncClient(
435342
CONFIG.api_url,
436343
API_KEY,
437-
access_token_cache_cls=_access_token_cache_cls
344+
access_token_cache_cls=get_access_token_cache_cache()
438345
)
439346
return _sync_client
440347

441348

442-
_async_sync_client: ASyncClient = None
349+
_async_sync_client: Optional[ASyncClient] = None
443350

444351

445352
def get_default_async_client() -> ASyncClient:
@@ -448,6 +355,6 @@ def get_default_async_client() -> ASyncClient:
448355
_async_sync_client = ASyncClient(
449356
CONFIG.api_url,
450357
API_KEY,
451-
access_token_cache_cls=_access_token_cache_cls
358+
access_token_cache_cls=get_access_token_cache_cache()
452359
)
453360
return _async_sync_client

0 commit comments

Comments
 (0)