Skip to content

Commit

Permalink
Brings S3 tracker to parity with standard tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Aug 26, 2024
1 parent 64c9d73 commit f98f09b
Showing 1 changed file with 97 additions and 1 deletion.
98 changes: 97 additions & 1 deletion burr/tracking/s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from burr.core import Action, ApplicationGraph, State, serde
from burr.integrations.base import require_plugin
from burr.tracking.base import SyncTrackingClient
from burr.tracking.client import StateKey, StreamState
from burr.tracking.common.models import (
ApplicationMetadataModel,
ApplicationModel,
Expand All @@ -25,6 +26,9 @@
BeginSpanModel,
EndEntryModel,
EndSpanModel,
EndStreamModel,
FirstItemStreamModel,
InitializeStreamModel,
PointerModel,
)
from burr.visibility import ActionSpan
Expand Down Expand Up @@ -86,7 +90,16 @@ def _allowed_project_name(project_name: str, on_windows: bool) -> bool:
return bool(re.match(pattern, project_name))


EventType = Union[BeginEntryModel, EndEntryModel, BeginSpanModel, EndSpanModel, AttributeModel]
EventType = Union[
BeginEntryModel,
EndEntryModel,
BeginSpanModel,
EndSpanModel,
AttributeModel,
InitializeStreamModel,
FirstItemStreamModel,
EndStreamModel,
]


def unique_ordered_prefix() -> str:
Expand Down Expand Up @@ -171,6 +184,7 @@ def __init__(
self.initialized = False
self.running = True
self.init()
self.stream_state: Dict[StateKey, StreamState] = dict()

def _get_time_partition(self):
time = datetime.datetime.utcnow().isoformat()
Expand Down Expand Up @@ -398,3 +412,85 @@ def post_end_span(
span_id=span.uid,
)
self.submit_log_event(end_span_model, app_id, partition_key)

def copy(self):
return S3TrackingClient(
project=self.project,
bucket=self.bucket,
region=self.region,
endpoint_url=self.endpoint_url,
non_blocking=self.non_blocking,
serde_kwargs=self.serde_kwargs,
unique_tracker_id=self.unique_tracker_id,
flush_interval=self.flush_interval,
)

def pre_start_stream(
self,
*,
action: str,
sequence_id: int,
app_id: str,
partition_key: Optional[str],
**future_kwargs: Any,
):
initialize_stream_model = InitializeStreamModel(
action_sequence_id=sequence_id,
span_id=None,
stream_init_time=system.now(),
)
self.submit_log_event(initialize_stream_model, app_id, partition_key)
self.stream_state[app_id, action, partition_key] = StreamState(
stream_init_time=system.now(),
count=0,
)

def post_stream_item(
self,
*,
item: Any,
item_index: int,
stream_initialize_time: datetime.datetime,
first_stream_item_start_time: datetime.datetime,
action: str,
sequence_id: int,
app_id: str,
partition_key: Optional[str],
**future_kwargs: Any,
):
stream_state = self.stream_state[app_id, action, partition_key]
if stream_state.count == 0:
stream_state.count += 1
self.submit_log_event(
FirstItemStreamModel(
action_sequence_id=sequence_id,
span_id=None,
first_item_time=system.now(),
),
app_id,
partition_key,
)
else:
stream_state.count += 1

def post_end_stream(
self,
*,
action: str,
sequence_id: int,
app_id: str,
partition_key: Optional[str],
**future_kwargs: Any,
):
stream_state = self.stream_state[app_id, action, partition_key]
self.submit_log_event(
EndStreamModel(
action_sequence_id=sequence_id,
span_id=None,
end_time=system.now(),
items_streamed=stream_state.count,
),
app_id,
partition_key,
)
del self.stream_state[app_id, action, partition_key]

0 comments on commit f98f09b

Please sign in to comment.