ADD: DBNStore.to_df tz parameter
nmacholl committed Feb 8, 2024
1 parent ee406b4 commit 8e7c94e
Showing 5 changed files with 124 additions and 4 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -1,6 +1,9 @@
# Changelog

## 0.28.1 - TBD
## 0.29.0 - TBD

#### Enhancements
- Added `tz` parameter to `DBNStore.to_df` which will convert all timestamp fields from UTC to a specified timezone when used with `pretty_ts`

#### Bug fixes
- `Live.block_for_close` and `Live.wait_for_close` will now call `Live.stop` when a timeout is reached instead of `Live.terminate` to close the stream more gracefully
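As a quick illustration of the enhancement noted above, here is a minimal usage sketch of the new `tz` parameter. It is not part of the commit; the file path is a placeholder, and only the `to_df(tz=...)` argument reflects the change in this diff.

```python
import pytz

from databento import DBNStore

# Placeholder path; any DBN file works for this sketch.
store = DBNStore.from_file("glbx-mdp3-20240205.mbo.dbn.zst")

# pretty_ts=True (the default) converts UNIX-nanosecond ints to tz-aware
# pd.Timestamp values; tz selects the zone they are converted into.
df_central = store.to_df(tz="US/Central")

# tz also accepts a pytz timezone object.
df_vienna = store.to_df(tz=pytz.timezone("Europe/Vienna"))
```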
22 changes: 20 additions & 2 deletions databento/common/dbnstore.py
@@ -26,6 +26,7 @@
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytz
import zstandard
from databento_dbn import FIXED_PRICE_SCALE
from databento_dbn import Compression
@@ -47,6 +48,7 @@
from databento.common.error import BentoError
from databento.common.symbology import InstrumentMap
from databento.common.types import DBNRecord
from databento.common.types import Default
from databento.common.validation import validate_enum
from databento.common.validation import validate_file_write_path
from databento.common.validation import validate_maybe_enum
@@ -830,6 +832,7 @@ def to_df(
pretty_ts: bool = ...,
map_symbols: bool = ...,
schema: Schema | str | None = ...,
tz: pytz.BaseTzInfo | str = ...,
count: None = ...,
) -> pd.DataFrame:
...
@@ -841,6 +844,7 @@ def to_df(
pretty_ts: bool = ...,
map_symbols: bool = ...,
schema: Schema | str | None = ...,
tz: pytz.BaseTzInfo | str = ...,
count: int = ...,
) -> DataFrameIterator:
...
@@ -851,6 +855,7 @@ def to_df(
pretty_ts: bool = True,
map_symbols: bool = True,
schema: Schema | str | None = None,
tz: pytz.BaseTzInfo | str | Default[pytz.BaseTzInfo] = Default[pytz.BaseTzInfo](pytz.UTC),
count: int | None = None,
) -> pd.DataFrame | DataFrameIterator:
"""
@@ -865,14 +870,16 @@ def to_df(
If "decimal", prices will be instances of `decimal.Decimal`.
pretty_ts : bool, default True
If all timestamp columns should be converted from UNIX nanosecond
`int` to tz-aware UTC `pd.Timestamp`.
`int` to tz-aware `pd.Timestamp`. The timezone can be specified using the `tz` parameter.
map_symbols : bool, default True
If symbology mappings from the metadata should be used to create
a 'symbol' column, mapping the instrument ID to its requested symbol for
every record.
schema : Schema or str, optional
The DBN schema for the dataframe.
This is only required when reading a DBN stream with mixed record types.
tz : pytz.BaseTzInfo or str, default UTC
If `pretty_ts` is `True`, all timestamps will be converted to the specified timezone.
count : int, optional
If set, instead of returning a single `DataFrame` a `DataFrameIterator`
instance will be returned. When iterated, this object will yield
@@ -892,6 +899,14 @@
"""
schema = validate_maybe_enum(schema, Schema, "schema")

if isinstance(tz, Default):
    tz = tz.value  # consume default
elif not pretty_ts:
    raise ValueError("A timezone was specified when `pretty_ts` is `False`. Did you mean to set `pretty_ts=True`?")

if not isinstance(tz, pytz.BaseTzInfo):
    tz = pytz.timezone(tz)

if schema is None:
    if self.schema is None:
        raise ValueError("a schema must be specified for mixed DBN data")
@@ -910,6 +925,7 @@ def to_df(
count=count,
struct_type=self._schema_struct_map[schema],
instrument_map=self._instrument_map,
tz=tz,
price_type=price_type,
pretty_ts=pretty_ts,
map_symbols=map_symbols,
@@ -1334,6 +1350,7 @@ def __init__(
count: int | None,
struct_type: type[DBNRecord],
instrument_map: InstrumentMap,
tz: pytz.BaseTzInfo,
price_type: Literal["fixed", "float", "decimal"] = "float",
pretty_ts: bool = True,
map_symbols: bool = True,
@@ -1345,6 +1362,7 @@ def __init__(
self._pretty_ts = pretty_ts
self._map_symbols = map_symbols
self._instrument_map = instrument_map
self._tz = tz

def __iter__(self) -> DataFrameIterator:
return self
@@ -1411,7 +1429,7 @@ def _format_px(

def _format_pretty_ts(self, df: pd.DataFrame) -> None:
    for field in self._struct_type._timestamp_fields:
        df[field] = pd.to_datetime(df[field], utc=True, errors="coerce")
        df[field] = pd.to_datetime(df[field], utc=True, errors="coerce").dt.tz_convert(self._tz)

def _format_set_index(self, df: pd.DataFrame) -> None:
    index_column = self._struct_type._ordered_fields[0]
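For reference, a standalone sketch (not part of the diff) of the conversion `_format_pretty_ts` now applies to each timestamp column: nanosecond epoch integers become tz-aware UTC timestamps, and `tz_convert` then shifts them into the requested zone.

```python
import pandas as pd

# Example nanosecond epoch value: 2024-02-08 14:00:00 UTC.
ts_ns = pd.Series([1707400800000000000])

utc = pd.to_datetime(ts_ns, utc=True, errors="coerce")
central = utc.dt.tz_convert("US/Central")

print(utc.iloc[0])      # 2024-02-08 14:00:00+00:00
print(central.iloc[0])  # 2024-02-08 08:00:00-06:00
```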
33 changes: 32 additions & 1 deletion databento/common/types.py
@@ -1,4 +1,4 @@
from typing import Callable, Union
from typing import Callable, Generic, TypeVar, Union

import databento_dbn

@@ -21,3 +21,34 @@

RecordCallback = Callable[[DBNRecord], None]
ExceptionCallback = Callable[[Exception], None]

_T = TypeVar("_T")


class Default(Generic[_T]):
    """
    A container for a default value. This is to be used when a callable wants
    to detect if a default parameter value is being used.

    Example
    -------
        def foo(param=Default[int](10)):
            if isinstance(param, Default):
                print(f"param={param.value} (default)")
            else:
                print(f"param={param}")

    """

    def __init__(self, value: _T):
        self._value = value

    @property
    def value(self) -> _T:
        """
        The default value.

        Returns
        -------
        _T

        """
        return self._value
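The `Default` container above is what lets `to_df` tell "the caller passed `tz` explicitly" apart from "the UTC default is in effect". Below is a minimal sketch of consuming such a sentinel, mirroring the validation added to `to_df` in this commit; the function name is illustrative, not part of the library.

```python
from __future__ import annotations

import pytz

from databento.common.types import Default


def resolve_tz(
    tz: pytz.BaseTzInfo | str | Default[pytz.BaseTzInfo] = Default[pytz.BaseTzInfo](pytz.UTC),
    pretty_ts: bool = True,
) -> pytz.BaseTzInfo:
    """Resolve a tz argument the way the new to_df validation does (sketch)."""
    if isinstance(tz, Default):
        tz = tz.value  # caller did not pass tz; fall back to the UTC default
    elif not pretty_ts:
        # an explicit timezone only makes sense when pretty_ts is enabled
        raise ValueError("tz requires pretty_ts=True")
    if not isinstance(tz, pytz.BaseTzInfo):
        tz = pytz.timezone(tz)  # accept IANA names such as "US/Central"
    return tz


print(resolve_tz())                  # UTC
print(resolve_tz(tz="US/Eastern"))   # US/Eastern zone info
```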
1 change: 1 addition & 0 deletions pyproject.toml
@@ -51,6 +51,7 @@ ruff = "^0.0.291"
types-requests = "^2.30.0.0"
tomli = "^2.0.1"
teamcity-messages = "^1.32"
types-pytz = "^2024.1.0.20240203"

[build-system]
requires = ["poetry-core"]
67 changes: 67 additions & 0 deletions tests/test_historical_bento.py
@@ -13,7 +13,9 @@
import numpy as np
import pandas as pd
import pytest
import pytz
import zstandard
from databento.common.constants import SCHEMA_STRUCT_MAP
from databento.common.dbnstore import DBNStore
from databento.common.error import BentoError
from databento.common.publishers import Dataset
@@ -1330,3 +1332,68 @@ def test_dbnstore_to_df_cannot_map_symbols_default_to_false(

    # Assert
    assert len(df_iter) == 4


@pytest.mark.parametrize(
    "timezone",
    [
        "US/Central",
        "US/Eastern",
        "Europe/Vienna",
        "Asia/Dubai",
        "UTC",
    ],
)
@pytest.mark.parametrize(
    "schema",
    [pytest.param(schema, id=str(schema)) for schema in Schema.variants()],
)
def test_dbnstore_to_df_with_timezone(
    test_data: Callable[[Dataset, Schema], bytes],
    schema: Schema,
    timezone: str,
) -> None:
    """
    Test that setting the `tz` parameter in `DBNStore.to_df` converts all
    timestamp fields into the specified timezone.
    """
    # Arrange
    dbn_stub_data = (
        zstandard.ZstdDecompressor().stream_reader(test_data(Dataset.GLBX_MDP3, schema)).read()
    )
    dbnstore = DBNStore.from_bytes(data=dbn_stub_data)

    # Act
    df = dbnstore.to_df(tz=timezone)
    df.reset_index(inplace=True)

    # Assert
    expected_timezone = pytz.timezone(timezone)._utcoffset
    failures = []
    struct = SCHEMA_STRUCT_MAP[schema]
    for field in struct._timestamp_fields:
        if df[field].dt.tz._utcoffset != expected_timezone:
            failures.append(field)

    assert not failures


def test_dbnstore_to_df_with_timezone_pretty_ts_error(
    test_data: Callable[[Dataset, Schema], bytes],
) -> None:
    """
    Test that setting the `tz` parameter in `DBNStore.to_df` when `pretty_ts`
    is `False` causes an error.
    """
    # Arrange
    dbn_stub_data = (
        zstandard.ZstdDecompressor().stream_reader(test_data(Dataset.GLBX_MDP3, Schema.MBO)).read()
    )
    dbnstore = DBNStore.from_bytes(data=dbn_stub_data)

    # Act, Assert
    with pytest.raises(ValueError):
        dbnstore.to_df(
            pretty_ts=False,
            tz=pytz.UTC,
        )
