diff --git a/quixstreams/dataframe/utils.py b/quixstreams/dataframe/utils.py
index b3cf9cba8..bb793afa4 100644
--- a/quixstreams/dataframe/utils.py
+++ b/quixstreams/dataframe/utils.py
@@ -1,4 +1,4 @@
-from datetime import timedelta
+from datetime import datetime, timedelta
 from typing import Union
 
 
@@ -22,3 +22,8 @@ def ensure_milliseconds(delta: Union[int, timedelta]) -> int:
         f'Timedelta must be either "int" representing milliseconds '
         f'or "datetime.timedelta", got "{type(delta)}"'
     )
+
+
+def now() -> int:
+    # TODO: Should be UTC time
+    return int(datetime.now().timestamp() * 1000)
diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py
index 8040b2774..04d14de1b 100644
--- a/quixstreams/dataframe/windows/base.py
+++ b/quixstreams/dataframe/windows/base.py
@@ -17,7 +17,11 @@
 from typing_extensions import TypeAlias
 
 from quixstreams.context import message_context
-from quixstreams.core.stream import TransformExpandedCallback
+from quixstreams.core.stream import (
+    Stream,
+    TransformExpandedCallback,
+    TransformFunction,
+)
 from quixstreams.core.stream.exceptions import InvalidOperation
 from quixstreams.models.topics.manager import TopicManager
 from quixstreams.state import WindowedPartitionTransaction
@@ -42,6 +46,8 @@
     Iterable[Message],
 ]
 
+WallClockCallback = Callable[[WindowedPartitionTransaction], Iterable[Message]]
+
 
 class Window(abc.ABC):
     def __init__(
@@ -69,6 +75,13 @@ def process_window(
     ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
         pass
 
+    @abstractmethod
+    def process_wall_clock(
+        self,
+        transaction: WindowedPartitionTransaction,
+    ) -> Iterable[WindowKeyResult]:
+        pass
+
     def register_store(self) -> None:
         TopicManager.ensure_topics_copartitioned(*self._dataframe.topics)
         # Create a config for the changelog topic based on the underlying SDF topics
@@ -83,6 +96,7 @@ def _apply_window(
         self,
         func: TransformRecordCallbackExpandedWindowed,
         name: str,
+        wall_clock_func: WallClockCallback,
     ) -> "StreamingDataFrame":
         self.register_store()
 
@@ -92,12 +106,24 @@ def _apply_window(
             self,
             processing_context=self._dataframe.processing_context,
             store_name=name,
         )
+        wall_clock_transform_func = _as_wall_clock(
+            func=wall_clock_func,
+            stream_id=self._dataframe.stream_id,
+            processing_context=self._dataframe.processing_context,
+            store_name=name,
+        )
         # Manually modify the Stream and clone the source StreamingDataFrame
         # to avoid adding "transform" API to it.
         # Transform callbacks can modify record key and timestamp,
         # and it's prone to misuse.
-        stream = self._dataframe.stream.add_transform(func=windowed_func, expand=True)
-        return self._dataframe.__dataframe_clone__(stream=stream)
+        windowed_stream = self._dataframe.stream.add_transform(
+            func=windowed_func, expand=True
+        )
+        wall_clock_stream = Stream(
+            func=TransformFunction(wall_clock_transform_func, expand=True)
+        )
+        sdf = self._dataframe.__dataframe_clone__(stream=windowed_stream)
+        return sdf.concat_wall_clock(wall_clock_stream)
 
     def final(self) -> "StreamingDataFrame":
         """
@@ -140,9 +166,17 @@ def window_callback(
             for key, window in expired_windows:
                 yield (window, key, window["start"], None)
 
+        def wall_clock_callback(
+            transaction: WindowedPartitionTransaction,
+        ) -> Iterable[Message]:
+            # TODO: Check if this will work for sliding windows
+            for key, window in self.process_wall_clock(transaction):
+                yield (window, key, window["start"], None)
+
         return self._apply_window(
             func=window_callback,
             name=self._name,
+            wall_clock_func=wall_clock_callback,
         )
 
     def current(self) -> "StreamingDataFrame":
         """
@@ -188,7 +222,17 @@ def window_callback(
             for key, window in updated_windows:
                 yield (window, key, window["start"], None)
 
-        return self._apply_window(func=window_callback, name=self._name)
+        def wall_clock_callback(
+            transaction: WindowedPartitionTransaction,
+        ) -> Iterable[Message]:
+            # TODO: Implement wall_clock callback
+            return []
+
+        return self._apply_window(
+            func=window_callback,
+            name=self._name,
+            wall_clock_func=wall_clock_callback,
+        )
 
     # Implemented by SingleAggregationWindowMixin and MultiAggregationWindowMixin
     # Single aggregation and multi aggregation windows store aggregations and collections
@@ -424,6 +468,28 @@ def wrapper(
     return wrapper
 
 
+def _as_wall_clock(
+    func: WallClockCallback,
+    processing_context: "ProcessingContext",
+    store_name: str,
+    stream_id: str,
+) -> TransformExpandedCallback:
+    @functools.wraps(func)
+    def wrapper(
+        value: Any, key: Any, timestamp: int, headers: Any
+    ) -> Iterable[Message]:
+        ctx = message_context()
+        transaction = cast(
+            WindowedPartitionTransaction,
+            processing_context.checkpoint.get_store_transaction(
+                stream_id=stream_id, partition=ctx.partition, store_name=store_name
+            ),
+        )
+        return func(transaction)
+
+    return wrapper
+
+
 class WindowOnLateCallback(Protocol):
     def __call__(
         self,
diff --git a/quixstreams/dataframe/windows/count_based.py b/quixstreams/dataframe/windows/count_based.py
index 57c6b36e5..973d23cce 100644
--- a/quixstreams/dataframe/windows/count_based.py
+++ b/quixstreams/dataframe/windows/count_based.py
@@ -189,6 +189,12 @@ def process_window(
         state.set(key=self.STATE_KEY, value=data)
         return updated_windows, expired_windows
 
+    def process_wall_clock(
+        self,
+        transaction: WindowedPartitionTransaction,
+    ) -> Iterable[WindowKeyResult]:
+        return []
+
     def _get_collection_start_id(self, window: CountWindowData) -> int:
         start_id = window.get("collection_start_id", _MISSING)
         if start_id is _MISSING:
diff --git a/quixstreams/dataframe/windows/time_based.py b/quixstreams/dataframe/windows/time_based.py
index c403cfdfa..4bfed4857 100644
--- a/quixstreams/dataframe/windows/time_based.py
+++ b/quixstreams/dataframe/windows/time_based.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional
 
 from quixstreams.context import message_context
+from quixstreams.dataframe.utils import now
 from quixstreams.state import WindowedPartitionTransaction, WindowedState
 
 from .base import (
@@ -200,11 +201,23 @@ def process_window(
 
         return updated_windows, expired_windows
 
+    def process_wall_clock(
+        self,
+        transaction: WindowedPartitionTransaction,
+    ) -> Iterable[WindowKeyResult]:
+        return self.expire_by_partition(
+            transaction=transaction,
+            max_expired_end=now() - self._grace_ms,
+            collect=self.collect,
+            advance_last_expired_timestamp=False,
+        )
+
     def expire_by_partition(
         self,
         transaction: WindowedPartitionTransaction,
         max_expired_end: int,
         collect: bool,
+        advance_last_expired_timestamp: bool = True,
     ) -> Iterable[WindowKeyResult]:
         for (
             window_start,
@@ -214,6 +227,7 @@
             step_ms=self._step_ms if self._step_ms else self._duration_ms,
             collect=collect,
             delete=True,
+            advance_last_expired_timestamp=advance_last_expired_timestamp,
         ):
             yield key, self._results(aggregated, collected, window_start, window_end)
diff --git a/quixstreams/state/rocksdb/windowed/transaction.py b/quixstreams/state/rocksdb/windowed/transaction.py
index 3779b3e29..dfa3fa12a 100644
--- a/quixstreams/state/rocksdb/windowed/transaction.py
+++ b/quixstreams/state/rocksdb/windowed/transaction.py
@@ -298,6 +298,7 @@ def expire_all_windows(
         step_ms: int,
         delete: bool = True,
         collect: bool = False,
+        advance_last_expired_timestamp: bool = True,
     ) -> Iterable[ExpiredWindowDetail]:
         """
         Get all expired windows for all prefix from RocksDB up to the specified `max_end_time` timestamp.
@@ -360,9 +361,12 @@ def expire_all_windows(
             if collect:
                 self.delete_from_collection(end=start, prefix=prefix)
 
-        self._set_timestamp(
-            prefix=b"", cache=self._last_expired_timestamps, timestamp_ms=last_expired
-        )
+        if advance_last_expired_timestamp:
+            self._set_timestamp(
+                prefix=b"",
+                cache=self._last_expired_timestamps,
+                timestamp_ms=last_expired,
+            )
 
     def delete_windows(
         self, max_start_time: int, delete_values: bool, prefix: bytes
diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py
index c80c9e2ad..9f07b9499 100644
--- a/quixstreams/state/types.py
+++ b/quixstreams/state/types.py
@@ -378,6 +378,7 @@ def expire_all_windows(
         step_ms: int,
         delete: bool = True,
         collect: bool = False,
+        advance_last_expired_timestamp: bool = True,
     ) -> Iterable[ExpiredWindowDetail[V]]:
         """
         Get all expired windows for all prefix from RocksDB up to the specified `max_start_time` timestamp.
@@ -388,6 +389,7 @@ def expire_all_windows(
         :param max_end_time: The timestamp up to which windows are considered expired, inclusive.
         :param delete: If True, expired windows will be deleted.
        :param collect: If True, values will be collected into windows.
+        :param advance_last_expired_timestamp: If True, the last expired timestamp will be persisted.
         """
         ...
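
Note on the `# TODO: Should be UTC time` comment in the new `now()` helper: below is a minimal sketch of a timezone-aware variant. This is only an assumption about how the TODO could be resolved, not part of this diff.

```python
from datetime import datetime, timezone


def now() -> int:
    # Wall-clock time in epoch milliseconds. Using an aware UTC datetime makes
    # the intent explicit; .timestamp() always returns seconds since the Unix
    # epoch regardless of the local timezone.
    return int(datetime.now(timezone.utc).timestamp() * 1000)
```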
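For context, a rough usage sketch of the behavior this patch wires up: time windows can now be expired by wall-clock time (`now() - grace_ms`) through `process_wall_clock` and the new wall-clock stream, instead of only when a newer record advances event time. The window definition below uses the existing public API; the broker address and topic name are placeholders, and the mechanism that actually drives the wall-clock stream (e.g. `concat_wall_clock`) is not shown in this diff.

```python
from datetime import timedelta

from quixstreams import Application

app = Application(broker_address="localhost:9092")  # placeholder broker
topic = app.topic("events")  # placeholder topic

sdf = app.dataframe(topic)

# A tumbling window that emits only closed windows. With this patch, a window
# whose end (plus grace) is already behind wall-clock time can be expired and
# emitted even if no new messages arrive for the partition.
sdf = (
    sdf.tumbling_window(timedelta(seconds=10), grace_ms=timedelta(seconds=1))
    .sum()
    .final()
)

if __name__ == "__main__":
    app.run()
```

Passing `advance_last_expired_timestamp=False` on the wall-clock path looks intentional: wall-clock expiration does not persist the last-expired watermark, so the regular event-time expiry logic keeps working unchanged.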