Skip to content

Commit 2d2d170

Browse files
authored
Merge pull request #17 from corruptmane/develop
2 parents c98d848 + 0f73fca commit 2d2d170

File tree

4 files changed

+203
-2
lines changed

4 files changed

+203
-2
lines changed

taskiq_nats/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
PushBasedJetStreamBroker,
1212
)
1313
from taskiq_nats.result_backend import NATSObjectStoreResultBackend
14+
from taskiq_nats.schedule_source import NATSKeyValueScheduleSource
1415

1516
__all__ = [
1617
"NatsBroker",
1718
"PushBasedJetStreamBroker",
1819
"PullBasedJetStreamBroker",
1920
"NATSObjectStoreResultBackend",
21+
"NATSKeyValueScheduleSource",
2022
]

taskiq_nats/schedule_source.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import logging
2+
from typing import Any, Final, List, Optional, Union
3+
4+
import nats
5+
from nats import NATS
6+
from nats.js import JetStreamContext
7+
from nats.js.errors import BucketNotFoundError, NoKeysError
8+
from nats.js.kv import KeyValue
9+
from taskiq import ScheduledTask, ScheduleSource
10+
from taskiq.abc.serializer import TaskiqSerializer
11+
from taskiq.compat import model_dump, model_validate
12+
from taskiq.serializers import PickleSerializer
13+
14+
log = logging.getLogger(__name__)
15+
16+
17+
class NATSKeyValueScheduleSource(ScheduleSource):
18+
"""
19+
Source of schedules for NATS Key-Value storage.
20+
21+
This class allows you to store schedules in NATS Key-Value storage.
22+
Also it supports dynamic schedules.
23+
"""
24+
25+
def __init__(
26+
self,
27+
servers: Union[str, List[str]],
28+
bucket_name: str = "taskiq_schedules",
29+
prefix: str = "schedule",
30+
serializer: Optional[TaskiqSerializer] = None,
31+
**connect_options: Any,
32+
) -> None:
33+
"""Construct new result backend.
34+
35+
:param servers: NATS servers.
36+
:param bucket_name: name of the bucket where schedules would be stored.
37+
:param prefix: prefix for nats kv storage schedule keys.
38+
:param serializer: serializer for data.
39+
:param connect_kwargs: additional arguments for nats `connect()` method.
40+
"""
41+
self.servers: Final = servers
42+
self.bucket_name: Final = bucket_name
43+
self.prefix: Final = prefix
44+
self.serializer = serializer or PickleSerializer()
45+
self.connect_options: Final = connect_options
46+
47+
self.nats_client: NATS
48+
self.nats_jetstream: JetStreamContext
49+
self.kv: KeyValue
50+
51+
async def startup(self) -> None:
52+
"""Create new connection to NATS.
53+
54+
Initialize JetStream context and new KeyValue instance.
55+
"""
56+
self.nats_client = await nats.connect(
57+
servers=self.servers,
58+
**self.connect_options,
59+
)
60+
self.nats_jetstream = self.nats_client.jetstream()
61+
62+
try:
63+
self.kv = await self.nats_jetstream.key_value(self.bucket_name)
64+
except BucketNotFoundError:
65+
self.kv = await self.nats_jetstream.create_key_value(
66+
bucket=self.bucket_name,
67+
)
68+
69+
async def shutdown(self) -> None:
70+
"""Close nats connection."""
71+
if self.nats_client.is_closed:
72+
return
73+
await self.nats_client.close()
74+
75+
async def delete_schedule(self, schedule_id: str) -> None:
76+
"""Remove schedule by id."""
77+
await self.kv.delete(f"{self.prefix}.{schedule_id}")
78+
79+
async def add_schedule(self, schedule: ScheduledTask) -> None:
80+
"""
81+
Add schedule to NATS Key-Value storage.
82+
83+
:param schedule: schedule to add.
84+
:param schedule_id: schedule id.
85+
"""
86+
await self.kv.put(
87+
f"{self.prefix}.{schedule.schedule_id}",
88+
self.serializer.dumpb(model_dump(schedule)),
89+
)
90+
91+
async def get_schedules(self) -> List[ScheduledTask]:
92+
"""
93+
Get all schedules from NATS Key-Value storage.
94+
95+
This method is used by scheduler to get all schedules.
96+
97+
:return: list of schedules.
98+
"""
99+
try:
100+
schedules = await self.kv.history(f"{self.prefix}.*")
101+
except NoKeysError:
102+
return []
103+
104+
return [
105+
model_validate(ScheduledTask, self.serializer.loadb(schedule.value))
106+
for schedule in schedules
107+
if schedule and schedule.value
108+
]
109+
110+
async def post_send(self, task: ScheduledTask) -> None:
111+
"""Delete a task after it's completed."""
112+
if task.time is not None:
113+
await self.delete_schedule(task.schedule_id)

tests/test_schedule_source.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import datetime as dt
2+
import uuid
3+
from typing import List
4+
5+
import pytest
6+
from taskiq import ScheduledTask
7+
8+
from taskiq_nats import NATSKeyValueScheduleSource
9+
10+
11+
@pytest.mark.anyio
12+
async def test_set_schedule(nats_urls: List[str]) -> None:
13+
prefix = uuid.uuid4().hex
14+
source = NATSKeyValueScheduleSource(servers=nats_urls, prefix=prefix)
15+
await source.startup()
16+
schedule = ScheduledTask(
17+
task_name="test_task",
18+
labels={},
19+
args=[],
20+
kwargs={},
21+
cron="* * * * *",
22+
)
23+
await source.add_schedule(schedule)
24+
schedules = await source.get_schedules()
25+
assert schedules == [schedule]
26+
await source.shutdown()
27+
28+
29+
@pytest.mark.anyio
30+
async def test_delete_schedule(nats_urls: List[str]) -> None:
31+
prefix = uuid.uuid4().hex
32+
source = NATSKeyValueScheduleSource(servers=nats_urls, prefix=prefix)
33+
await source.startup()
34+
schedule = ScheduledTask(
35+
task_name="test_task",
36+
labels={},
37+
args=[],
38+
kwargs={},
39+
cron="* * * * *",
40+
)
41+
await source.add_schedule(schedule)
42+
schedules = await source.get_schedules()
43+
assert schedules == [schedule]
44+
await source.delete_schedule(schedule.schedule_id)
45+
schedules = await source.get_schedules()
46+
# Schedules are empty.
47+
assert not schedules
48+
await source.shutdown()
49+
50+
51+
@pytest.mark.anyio
52+
async def test_post_run_cron(nats_urls: List[str]) -> None:
53+
prefix = uuid.uuid4().hex
54+
source = NATSKeyValueScheduleSource(servers=nats_urls, prefix=prefix)
55+
await source.startup()
56+
schedule = ScheduledTask(
57+
task_name="test_task",
58+
labels={},
59+
args=[],
60+
kwargs={},
61+
cron="* * * * *",
62+
)
63+
await source.add_schedule(schedule)
64+
assert await source.get_schedules() == [schedule]
65+
await source.post_send(schedule)
66+
assert await source.get_schedules() == [schedule]
67+
await source.shutdown()
68+
69+
70+
@pytest.mark.anyio
71+
async def test_post_run_time(nats_urls: List[str]) -> None:
72+
prefix = uuid.uuid4().hex
73+
source = NATSKeyValueScheduleSource(servers=nats_urls, prefix=prefix)
74+
await source.startup()
75+
schedule = ScheduledTask(
76+
task_name="test_task",
77+
labels={},
78+
args=[],
79+
kwargs={},
80+
time=dt.datetime(2000, 1, 1),
81+
)
82+
await source.add_schedule(schedule)
83+
assert await source.get_schedules() == [schedule]
84+
await source.post_send(schedule)
85+
assert await source.get_schedules() == []
86+
await source.shutdown()

tests/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
async def read_message(broker: AsyncBroker) -> Union[bytes, AckableMessage]:
77
"""
8-
Read signle message from the broker's listen method.
8+
Read single message from the broker's listen method.
99
1010
:param broker: current broker.
11-
:return: firs message.
11+
:return: first message.
1212
"""
1313
msg: Union[bytes, AckableMessage] = b"error"
1414
async for message in broker.listen():

0 commit comments

Comments
 (0)