1
- import json
2
1
import logging
3
- from datetime import datetime , timedelta , timezone
4
- from pathlib import Path
5
- from typing import Generator , Optional , Tuple , Type , Callable , Union
2
+ from typing import Generator , Optional
6
3
7
4
import dateutil .parser
8
5
import httpx
9
- from platformdirs import user_cache_dir
10
6
11
7
from ..config import API_KEY , CONFIG , APIKey
12
8
from . import USER_AGENT , RetryableHTTPClient , encode_json
13
-
9
+ from . cache import ACCESS_TOKEN_CACHE_CLS_TYPE , FileCachedToken , get_access_token_cache_cache
14
10
logger = logging .getLogger (__name__ )
15
11
16
12
17
- class CachedToken :
18
- def __init__ (self , token_refresh_before_expiry = timedelta (seconds = 30 )):
19
- self .token_refresh_before_expiry = token_refresh_before_expiry
20
-
21
- cache = self ._get_cached_token ()
22
-
23
- if cache :
24
- self ._access_token , self ._access_token_expires_at_utc = cache
25
- else :
26
- self ._access_token : Optional [str ] = None
27
- self ._access_token_expires_at_utc : Optional [datetime ] = None
28
-
29
- @property
30
- def access_token (self ) -> Optional [str ]:
31
- return self ._access_token if self .is_access_token_valid else None
32
-
33
- @property
34
- def is_access_token_valid (self ) -> bool :
35
- if not self ._access_token :
36
- return False
37
-
38
- if self ._access_token_expires_at_utc :
39
- return datetime .now (tz = timezone .utc ) < (
40
- self ._access_token_expires_at_utc - self .token_refresh_before_expiry
41
- )
42
-
43
- return False
44
-
45
- def set_cached_token (self , access_token : str , expires_at_utc : datetime ) -> None :
46
- self ._access_token = access_token
47
- self ._access_token_expires_at_utc = expires_at_utc
48
- self ._set_cached_token ()
49
-
50
- def _set_cached_token (self ) -> None :
51
- raise NotImplementedError
52
-
53
- def _get_cached_token (self ) -> Optional [Tuple [str , datetime ]]:
54
- raise NotImplementedError
55
-
56
-
57
- class FileCachedToken (CachedToken ):
58
- def __init__ (self ):
59
- self ._cache_file = Path (user_cache_dir ("syncsparkpy" )) / "auth.json"
60
-
61
- super ().__init__ ()
62
-
63
- def _get_cached_token (self ) -> Optional [Tuple [str , datetime ]]:
64
- # Cache is optional, we can fail to read it and not worry
65
- if self ._cache_file .exists ():
66
- try :
67
- cached_token = json .loads (self ._cache_file .read_text ())
68
- cached_access_token = cached_token ["access_token" ]
69
- cached_expiry = datetime .fromisoformat (cached_token ["expires_at_utc" ])
70
- return cached_access_token , cached_expiry
71
- except Exception as e :
72
- logger .warning (
73
- f"Failed to read cached access token @ { self ._cache_file } " , exc_info = e
74
- )
75
-
76
- return None
77
-
78
- def _set_cached_token (self ) -> None :
79
- # Cache is optional, we can fail to read it and not worry
80
- try :
81
- self ._cache_file .parent .mkdir (parents = True , exist_ok = True )
82
- self ._cache_file .write_text (
83
- json .dumps (
84
- {
85
- "access_token" : self ._access_token ,
86
- "expires_at_utc" : self ._access_token_expires_at_utc .isoformat (),
87
- }
88
- )
89
- )
90
- except Exception as e :
91
- logger .warning (
92
- f"Failed to write cached access token @ { self ._cache_file } " , exc_info = e
93
- )
94
-
95
-
96
- # Putting this here instead of config.py because circular imports and typing.
97
- _access_token_cache_cls = FileCachedToken # Default to local file caching.
98
- ACCESS_TOKEN_CACHE_CLS_TYPE = Union [Type [CachedToken ], Callable [[], CachedToken ]]
99
-
100
-
101
- def set_access_token_cache_cls (access_token_cache_cls : ACCESS_TOKEN_CACHE_CLS_TYPE ) -> None :
102
- global _access_token_cache_cls
103
- _access_token_cache_cls = access_token_cache_cls
104
-
105
-
106
13
class SyncAuth (httpx .Auth ):
107
14
requires_response_body = True
108
15
@@ -425,7 +332,7 @@ async def _send(self, request: httpx.Request) -> dict:
425
332
return {"error" : {"code" : "Sync API Error" , "message" : "Transaction failure" }}
426
333
427
334
428
- _sync_client : SyncClient = None
335
+ _sync_client : Optional [ SyncClient ] = None
429
336
430
337
431
338
def get_default_client () -> SyncClient :
@@ -434,12 +341,12 @@ def get_default_client() -> SyncClient:
434
341
_sync_client = SyncClient (
435
342
CONFIG .api_url ,
436
343
API_KEY ,
437
- access_token_cache_cls = _access_token_cache_cls
344
+ access_token_cache_cls = get_access_token_cache_cache ()
438
345
)
439
346
return _sync_client
440
347
441
348
442
- _async_sync_client : ASyncClient = None
349
+ _async_sync_client : Optional [ ASyncClient ] = None
443
350
444
351
445
352
def get_default_async_client () -> ASyncClient :
@@ -448,6 +355,6 @@ def get_default_async_client() -> ASyncClient:
448
355
_async_sync_client = ASyncClient (
449
356
CONFIG .api_url ,
450
357
API_KEY ,
451
- access_token_cache_cls = _access_token_cache_cls
358
+ access_token_cache_cls = get_access_token_cache_cache ()
452
359
)
453
360
return _async_sync_client
0 commit comments