diff --git a/CHANGELOG.md b/CHANGELOG.md index 84e0238..29055af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # Changelog -## 0.28.1 - TBD +## 0.29.0 - TBD + +#### Enhancements +- Added `tz` parameter to `DBNStore.to_df` which will convert all timestamp fields from UTC to a specified timezone when used with `pretty_ts` #### Bug fixes - `Live.block_for_close` and `Live.wait_for_close` will now call `Live.stop` when a timeout is reached instead of `Live.terminate` to close the stream more gracefully diff --git a/databento/common/dbnstore.py b/databento/common/dbnstore.py index 3b75b83..76bff62 100644 --- a/databento/common/dbnstore.py +++ b/databento/common/dbnstore.py @@ -26,6 +26,7 @@ import pandas as pd import pyarrow as pa import pyarrow.parquet as pq +import pytz import zstandard from databento_dbn import FIXED_PRICE_SCALE from databento_dbn import Compression @@ -47,6 +48,7 @@ from databento.common.error import BentoError from databento.common.symbology import InstrumentMap from databento.common.types import DBNRecord +from databento.common.types import Default from databento.common.validation import validate_enum from databento.common.validation import validate_file_write_path from databento.common.validation import validate_maybe_enum @@ -830,6 +832,7 @@ def to_df( pretty_ts: bool = ..., map_symbols: bool = ..., schema: Schema | str | None = ..., + tz: pytz.BaseTzInfo | str = ..., count: None = ..., ) -> pd.DataFrame: ... @@ -841,6 +844,7 @@ def to_df( pretty_ts: bool = ..., map_symbols: bool = ..., schema: Schema | str | None = ..., + tz: pytz.BaseTzInfo | str = ..., count: int = ..., ) -> DataFrameIterator: ... @@ -851,6 +855,7 @@ def to_df( pretty_ts: bool = True, map_symbols: bool = True, schema: Schema | str | None = None, + tz: pytz.BaseTzInfo | str | Default[pytz.BaseTzInfo] = Default[pytz.BaseTzInfo](pytz.UTC), count: int | None = None, ) -> pd.DataFrame | DataFrameIterator: """ @@ -865,7 +870,7 @@ def to_df( If "decimal", prices will be instances of `decimal.Decimal`. pretty_ts : bool, default True If all timestamp columns should be converted from UNIX nanosecond - `int` to tz-aware UTC `pd.Timestamp`. + `int` to tz-aware `pd.Timestamp`. The timezone can be specified using the `tz` parameter. map_symbols : bool, default True If symbology mappings from the metadata should be used to create a 'symbol' column, mapping the instrument ID to its requested symbol for @@ -873,6 +878,8 @@ def to_df( schema : Schema or str, optional The DBN schema for the dataframe. This is only required when reading a DBN stream with mixed record types. + tz : pytz.BaseTzInfo or str, default UTC + If `pretty_ts` is `True`, all timestamps will be converted to the specified timezone. count : int, optional If set, instead of returning a single `DataFrame` a `DataFrameIterator` instance will be returned. When iterated, this object will yield @@ -892,6 +899,14 @@ def to_df( """ schema = validate_maybe_enum(schema, Schema, "schema") + + if isinstance(tz, Default): + tz = tz.value # consume default + elif not pretty_ts: + raise ValueError("A timezone was specified when `pretty_ts` is `False`. Did you mean to set `pretty_ts=True`?") + + if not isinstance(tz, pytz.BaseTzInfo): + tz = pytz.timezone(tz) if schema is None: if self.schema is None: raise ValueError("a schema must be specified for mixed DBN data") @@ -910,6 +925,7 @@ def to_df( count=count, struct_type=self._schema_struct_map[schema], instrument_map=self._instrument_map, + tz=tz, price_type=price_type, pretty_ts=pretty_ts, map_symbols=map_symbols, @@ -1334,6 +1350,7 @@ def __init__( count: int | None, struct_type: type[DBNRecord], instrument_map: InstrumentMap, + tz: pytz.BaseTzInfo, price_type: Literal["fixed", "float", "decimal"] = "float", pretty_ts: bool = True, map_symbols: bool = True, @@ -1345,6 +1362,7 @@ def __init__( self._pretty_ts = pretty_ts self._map_symbols = map_symbols self._instrument_map = instrument_map + self._tz = tz def __iter__(self) -> DataFrameIterator: return self @@ -1411,7 +1429,7 @@ def _format_px( def _format_pretty_ts(self, df: pd.DataFrame) -> None: for field in self._struct_type._timestamp_fields: - df[field] = pd.to_datetime(df[field], utc=True, errors="coerce") + df[field] = pd.to_datetime(df[field], utc=True, errors="coerce").dt.tz_convert(self._tz) def _format_set_index(self, df: pd.DataFrame) -> None: index_column = self._struct_type._ordered_fields[0] diff --git a/databento/common/types.py b/databento/common/types.py index 4c246be..01859f2 100644 --- a/databento/common/types.py +++ b/databento/common/types.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from typing import Callable, Generic, TypeVar, Union import databento_dbn @@ -21,3 +21,34 @@ RecordCallback = Callable[[DBNRecord], None] ExceptionCallback = Callable[[Exception], None] + +_T = TypeVar("_T") +class Default(Generic[_T]): + """ + A container for a default value. This is to be used when a callable wants + to detect if a default parameter value is being used. + + Example + ------- + def foo(param=Default[int](10)): + if isinstance(param, Default): + print(f"param={param.value} (default)") + else: + print(f"param={param.value}") + + """ + + def __init__(self, value: _T): + self._value = value + + @property + def value(self) -> _T: + """ + The default value. + + Returns + ------- + _T + + """ + return self._value diff --git a/pyproject.toml b/pyproject.toml index 3505c33..1bee201 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ ruff = "^0.0.291" types-requests = "^2.30.0.0" tomli = "^2.0.1" teamcity-messages = "^1.32" +types-pytz = "^2024.1.0.20240203" [build-system] requires = ["poetry-core"] diff --git a/tests/test_historical_bento.py b/tests/test_historical_bento.py index 8f7022d..366d528 100644 --- a/tests/test_historical_bento.py +++ b/tests/test_historical_bento.py @@ -13,7 +13,9 @@ import numpy as np import pandas as pd import pytest +import pytz import zstandard +from databento.common.constants import SCHEMA_STRUCT_MAP from databento.common.dbnstore import DBNStore from databento.common.error import BentoError from databento.common.publishers import Dataset @@ -1330,3 +1332,68 @@ def test_dbnstore_to_df_cannot_map_symbols_default_to_false( # Assert assert len(df_iter) == 4 + + +@pytest.mark.parametrize( + "timezone", + [ + "US/Central", + "US/Eastern", + "Europe/Vienna", + "Asia/Dubai", + "UTC", + ], +) +@pytest.mark.parametrize( + "schema", + [pytest.param(schema, id=str(schema)) for schema in Schema.variants()], +) +def test_dbnstore_to_df_with_timezone( + test_data: Callable[[Dataset, Schema], bytes], + schema: Schema, + timezone: str, +) -> None: + """ + Test that setting the `tz` parameter in `DBNStore.to_df` converts all + timestamp fields into the specified timezone. + """ + # Arrange + dbn_stub_data = ( + zstandard.ZstdDecompressor().stream_reader(test_data(Dataset.GLBX_MDP3, schema)).read() + ) + dbnstore = DBNStore.from_bytes(data=dbn_stub_data) + + # Act + df = dbnstore.to_df(tz=timezone) + df.reset_index(inplace=True) + + # Assert + expected_timezone = pytz.timezone(timezone)._utcoffset + failures = [] + struct = SCHEMA_STRUCT_MAP[schema] + for field in struct._timestamp_fields: + if df[field].dt.tz._utcoffset != expected_timezone: + failures.append(field) + + assert not failures + + +def test_dbnstore_to_df_with_timezone_pretty_ts_error( + test_data: Callable[[Dataset, Schema], bytes], +) -> None: + """ + Test that setting the `tz` parameter in `DBNStore.to_df` when `pretty_ts` + is `False` causes an error. + """ + # Arrange + dbn_stub_data = ( + zstandard.ZstdDecompressor().stream_reader(test_data(Dataset.GLBX_MDP3, Schema.MBO)).read() + ) + dbnstore = DBNStore.from_bytes(data=dbn_stub_data) + + # Act, Assert + with pytest.raises(ValueError): + dbnstore.to_df( + pretty_ts=False, + tz=pytz.UTC, + )