diff --git a/burr/tracking/s3client.py b/burr/tracking/s3client.py index bb5de6ae..e2ce9c62 100644 --- a/burr/tracking/s3client.py +++ b/burr/tracking/s3client.py @@ -17,6 +17,7 @@ from burr.core import Action, ApplicationGraph, State, serde from burr.integrations.base import require_plugin from burr.tracking.base import SyncTrackingClient +from burr.tracking.client import StateKey, StreamState from burr.tracking.common.models import ( ApplicationMetadataModel, ApplicationModel, @@ -25,6 +26,9 @@ BeginSpanModel, EndEntryModel, EndSpanModel, + EndStreamModel, + FirstItemStreamModel, + InitializeStreamModel, PointerModel, ) from burr.visibility import ActionSpan @@ -86,7 +90,16 @@ def _allowed_project_name(project_name: str, on_windows: bool) -> bool: return bool(re.match(pattern, project_name)) -EventType = Union[BeginEntryModel, EndEntryModel, BeginSpanModel, EndSpanModel, AttributeModel] +EventType = Union[ + BeginEntryModel, + EndEntryModel, + BeginSpanModel, + EndSpanModel, + AttributeModel, + InitializeStreamModel, + FirstItemStreamModel, + EndStreamModel, +] def unique_ordered_prefix() -> str: @@ -171,6 +184,7 @@ def __init__( self.initialized = False self.running = True self.init() + self.stream_state: Dict[StateKey, StreamState] = dict() def _get_time_partition(self): time = datetime.datetime.utcnow().isoformat() @@ -398,3 +412,85 @@ def post_end_span( span_id=span.uid, ) self.submit_log_event(end_span_model, app_id, partition_key) + + def copy(self): + return S3TrackingClient( + project=self.project, + bucket=self.bucket, + region=self.region, + endpoint_url=self.endpoint_url, + non_blocking=self.non_blocking, + serde_kwargs=self.serde_kwargs, + unique_tracker_id=self.unique_tracker_id, + flush_interval=self.flush_interval, + ) + + def pre_start_stream( + self, + *, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + initialize_stream_model = InitializeStreamModel( + action_sequence_id=sequence_id, + span_id=None, + stream_init_time=system.now(), + ) + self.submit_log_event(initialize_stream_model, app_id, partition_key) + self.stream_state[app_id, action, partition_key] = StreamState( + stream_init_time=system.now(), + count=0, + ) + + def post_stream_item( + self, + *, + item: Any, + item_index: int, + stream_initialize_time: datetime.datetime, + first_stream_item_start_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + stream_state = self.stream_state[app_id, action, partition_key] + if stream_state.count == 0: + stream_state.count += 1 + self.submit_log_event( + FirstItemStreamModel( + action_sequence_id=sequence_id, + span_id=None, + first_item_time=system.now(), + ), + app_id, + partition_key, + ) + else: + stream_state.count += 1 + + def post_end_stream( + self, + *, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + stream_state = self.stream_state[app_id, action, partition_key] + self.submit_log_event( + EndStreamModel( + action_sequence_id=sequence_id, + span_id=None, + end_time=system.now(), + items_streamed=stream_state.count, + ), + app_id, + partition_key, + ) + del self.stream_state[app_id, action, partition_key]