Extract event wrap and unwrap code so it can be called from mlrun (#531)
* Extract event wrap and unwrap code so it can be called from mlrun

[ML-7202](https://iguazio.atlassian.net/browse/ML-7202)

* Fix
gtopper authored Jul 23, 2024
1 parent 45674a0 commit 64f7699
Showing 4 changed files with 56 additions and 46 deletions.
14 changes: 0 additions & 14 deletions storey/dtypes.py
@@ -41,9 +41,6 @@ class Event:
    :type awaitable_result: AwaitableResult (Optional)
    """

    _serialize_event_marker = "full_event_wrapper"
    _serialize_fields = ["key", "id"]

    def __init__(
        self,
        body: object,
@@ -93,17 +90,6 @@ def __eq__(self, other):
    def __str__(self):
        return f"Event(id={self.id}, key={str(self.key)}, body={self.body})"

    @staticmethod
    def wrap_for_serialization(event, record):
        record = {Event._serialize_event_marker: True, "body": record}
        for field in Event._serialize_fields:
            val = getattr(event, field)
            if val is not None:
                if isinstance(val, datetime):
                    val = datetime.isoformat(val)
                record[field] = val
        return record


class V3ioError(Exception):
    pass
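With the staticmethod removed from `Event`, call sites switch to the module-level helper in `storey.utils` — the targets.py diff below shows exactly this change. A minimal before/after sketch (not part of the diff):

```python
# Before this commit: wrapping went through a staticmethod on Event.
record = Event.wrap_for_serialization(event, record)

# After: the equivalent module-level helper, importable without the Event
# class — which is what lets mlrun call it directly.
from storey.utils import wrap_event_for_serialization

record = wrap_event_for_serialization(event, record)
```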
37 changes: 8 additions & 29 deletions storey/sources.py
@@ -35,7 +35,12 @@
from .dtypes import Event, _termination_obj
from .flow import Complete, Flow, WithUUID
from .queue import SimpleAsyncQueue
from .utils import find_filters, find_partitions, url_to_file_system
from .utils import (
    find_filters,
    find_partitions,
    unpack_event_if_wrapped,
    url_to_file_system,
)


class AwaitableResult:
@@ -79,20 +84,6 @@ def _set_error(self, ex):
        self._set_result(ex)


def _convert_to_datetime(obj, time_format: Optional[str] = None):
    if isinstance(obj, datetime):
        return obj
    elif isinstance(obj, float) or isinstance(obj, int):
        return datetime.fromtimestamp(obj, tz=pytz.utc)
    elif isinstance(obj, str):
        if time_format is None:
            return datetime.fromisoformat(obj)
        else:
            return datetime.strptime(obj, time_format)
    else:
        raise ValueError(f"Could not parse '{obj}' (of type {type(obj)}) as a time.")


class FlowControllerBase(WithUUID):
    def __init__(
        self,
@@ -106,20 +97,8 @@ def __init__(
    def _build_event(self, element, key):
        element_is_event = hasattr(element, "id")
        if element_is_event:
            if isinstance(element.body, dict) and element.body.get(Event._serialize_event_marker):
                serialized_event = element.body
                body = serialized_event.get("body")
                element.body = body
                for field in Event._serialize_fields:
                    val = serialized_event.get(field)
                    if val is not None:
                        val = serialized_event.get(field)
                        if val is not None:
                            if field == "time":
                                val = _convert_to_datetime(val)
                            setattr(element, field, val)
            else:
                body = element.body
            element = unpack_event_if_wrapped(element)
            body = element.body
            if not hasattr(element, "processing_time"):
                if hasattr(element, "timestamp"):
                    element.processing_time = element.timestamp
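`_build_event` now delegates unwrapping to `storey.utils.unpack_event_if_wrapped`: bodies carrying the `full_event_wrapper` marker are replaced by the real body, and the serialized `key` and `id` are restored onto the event. Note that the old `if field == "time"` branch was unreachable given `_serialize_fields = ["key", "id"]`, which is presumably why the relocated helper drops it. A minimal sketch of the new behavior (not part of the diff, and assuming `Event(body)` constructs a bare event):

```python
from storey.dtypes import Event
from storey.utils import unpack_event_if_wrapped

wrapped = Event({"full_event_wrapper": True, "body": {"x": 1}, "key": "k1", "id": "42"})
event = unpack_event_if_wrapped(wrapped)
assert event.body == {"x": 1}                   # marker dict replaced by the real body
assert event.key == "k1" and event.id == "42"   # serialized fields restored

plain = Event({"x": 1})
assert unpack_event_if_wrapped(plain) is plain  # non-wrapped events pass through untouched
```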
6 changes: 3 additions & 3 deletions storey/targets.py
@@ -36,7 +36,7 @@
from .dtypes import Event, V3ioError
from .flow import Flow, _Batching, _split_path, _termination_obj
from .table import Table, _PersistJob
from .utils import stringify_key, url_to_file_system
from .utils import stringify_key, url_to_file_system, wrap_event_for_serialization


class _Writer:
@@ -1089,7 +1089,7 @@ async def _worker(self):
                    shard_id %= self._shards
                    record = self._event_to_writer_entry(event)
                    if self._full_event:
                        record = Event.wrap_for_serialization(event, record)
                        record = wrap_event_for_serialization(event, record)
                    buffers[shard_id].append(record)
                    buffer_events[shard_id].append(event)
                    if len(buffers[shard_id]) >= self._batch_size:
@@ -1248,7 +1248,7 @@ async def _do(self, event):
            key = stringify_key(event.key).encode("UTF-8")
        record = self._event_to_writer_entry(event)
        if self._full_event:
            record = Event.wrap_for_serialization(event, record)
            record = wrap_event_for_serialization(event, record)
        record = json.dumps(record, default=str).encode("UTF-8")
        partition = None
        if self._sharding_func:
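When a target is configured with `full_event=True`, the outgoing record is now wrapped by the relocated helper in both call sites above. A hedged sketch of the resulting wire format (not part of the diff, assuming `Event(body, key=..., id=...)` constructs an event as in storey):

```python
from storey.dtypes import Event
from storey.utils import wrap_event_for_serialization

event = Event({"x": 1}, key="k1", id="42")
record = wrap_event_for_serialization(event, {"x": 1})
# The record is marked so consumers can detect wrapping:
# {"full_event_wrapper": True, "body": {"x": 1}, "key": "k1", "id": "42"}
```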
45 changes: 45 additions & 0 deletions storey/utils.py
@@ -17,13 +17,19 @@
import os
import struct
from array import array
from datetime import datetime
from typing import Optional
from urllib.parse import urlparse

import fsspec
import pytz

bucketPerWindow = 2
schema_file_name = ".schema"

serialize_event_marker = "full_event_wrapper"
event_fields_to_serialize = ["key", "id"]


def parse_duration(string_time):
    unit = string_time[-1]
@@ -344,3 +350,42 @@ def find_filters(partitions_time_attributes, start, end, filters, filter_column)
        filters,
        filter_column,
    )


def _convert_to_datetime(obj, time_format: Optional[str] = None):
    if isinstance(obj, datetime):
        return obj
    elif isinstance(obj, float) or isinstance(obj, int):
        return datetime.fromtimestamp(obj, tz=pytz.utc)
    elif isinstance(obj, str):
        if time_format is None:
            return datetime.fromisoformat(obj)
        else:
            return datetime.strptime(obj, time_format)
    else:
        raise ValueError(f"Could not parse '{obj}' (of type {type(obj)}) as a time.")


def unpack_event_if_wrapped(event):
    if isinstance(event.body, dict) and event.body.get(serialize_event_marker):
        serialized_event = event.body
        body = serialized_event.get("body")
        event.body = body
        for field in event_fields_to_serialize:
            val = serialized_event.get(field)
            if val is not None:
                setattr(event, field, val)
    return event


def wrap_event_for_serialization(event, record):
    record = {serialize_event_marker: True, "body": record}
    for field in event_fields_to_serialize:
        val = getattr(event, field, None)
        if val is not None:
            if isinstance(val, datetime):
                val = datetime.isoformat(val)
            record[field] = val
    return record
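Taken together, the two new helpers form a round trip: a target wraps the record on the way out, and a source (or mlrun) unpacks it on the way in. Note that `wrap_event_for_serialization` ISO-formats datetime values but `unpack_event_if_wrapped` performs no inverse conversion — harmless while `event_fields_to_serialize` is only `key` and `id`, with `_convert_to_datetime` available should time fields ever be added. A minimal round-trip sketch (not part of the commit):

```python
from storey.dtypes import Event
from storey.utils import unpack_event_if_wrapped, wrap_event_for_serialization

original = Event({"temp": 21.5}, key="sensor-1", id="abc")
record = wrap_event_for_serialization(original, {"temp": 21.5})

# A consumer receives the record as an event body and unpacks it:
restored = unpack_event_if_wrapped(Event(record))
assert restored.body == {"temp": 21.5}
assert restored.key == "sensor-1" and restored.id == "abc"
```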
