Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Range query materialization #185

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
13 changes: 8 additions & 5 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_SparkSerializedArtifacts,
)
from feast.infra.provider import get_provider
from feast.sorted_feature_view import SortedFeatureView
from feast.stream_feature_view import StreamFeatureView


Expand Down Expand Up @@ -258,11 +259,13 @@ def batch_write(row: DataFrame, batch_id: int):
ts_field = self.sfv.timestamp_field
else:
ts_field = self.sfv.stream_source.timestamp_field # type: ignore
rows = (
rows.sort_values(by=[*self.join_keys, ts_field], ascending=False)
.groupby(self.join_keys)
.nth(0)
)

if not isinstance(self.sfv, SortedFeatureView):
rows = (
rows.sort_values(by=[*self.join_keys, ts_field], ascending=False)
.groupby(self.join_keys)
.nth(0)
)
# Created column is not used anywhere in the code, but it is added to the dataframe.
# Commenting this out as it is not used anywhere in the code
# rows["created"] = pd.to_datetime("now", utc=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from feast.infra.registry.base_registry import BaseRegistry
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.sorted_feature_view import SortedFeatureView
from feast.stream_feature_view import StreamFeatureView
from feast.utils import (
_convert_arrow_to_proto,
Expand Down Expand Up @@ -135,7 +136,7 @@ def materialize(
def _materialize_one(
self,
registry: BaseRegistry,
feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
feature_view: Union[BatchFeatureView, SortedFeatureView, StreamFeatureView, FeatureView],
start_date: datetime,
end_date: datetime,
project: str,
Expand All @@ -155,19 +156,33 @@ def _materialize_one(
job_id = f"{feature_view.name}-{start_date}-{end_date}"

try:
offline_job = cast(
SparkRetrievalJob,
self.offline_store.pull_latest_from_table_or_query(
config=self.repo_config,
data_source=feature_view.batch_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
),
)
if isinstance(feature_view, SortedFeatureView):
offline_job = cast(
SparkRetrievalJob,
self.offline_store.pull_all_from_table_or_query(
config=self.repo_config,
data_source=feature_view.batch_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
start_date=start_date,
end_date=end_date,
),
)
else:
offline_job = cast(
SparkRetrievalJob,
self.offline_store.pull_latest_from_table_or_query(
config=self.repo_config,
data_source=feature_view.batch_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
),
)

spark_serialized_artifacts = _SparkSerializedArtifacts.serialize(
feature_view=feature_view, repo_config=self.repo_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import logging
import string
import time
from collections import defaultdict
from datetime import datetime
from functools import partial
from queue import Queue
Expand All @@ -47,6 +48,7 @@
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.rate_limiter import SlidingWindowRateLimiter
from feast.repo_config import FeastConfigBaseModel
from feast.sorted_feature_view import SortedFeatureView

# Error messages
E_CASSANDRA_UNEXPECTED_CONFIGURATION_CLASS = (
Expand Down Expand Up @@ -74,6 +76,10 @@
" (?, ?, ?, ?) USING TTL {ttl};"
)

INSERT_TIME_SERIES_TEMPLATE = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't call it "time series" since it can be used for any sorted set. I suggest going with "sorted features" to be consistent with the SortedFeatureView name.

"INSERT INTO {fqtable} ({feature_names}, entity_key, event_ts) VALUES ({parameters}) USING TTL {ttl};"
)

SELECT_CQL_TEMPLATE = "SELECT {columns} FROM {fqtable} WHERE entity_key = ?;"

CREATE_TABLE_CQL_TEMPLATE = """
Expand All @@ -94,6 +100,7 @@
CQL_TEMPLATE_MAP = {
# Queries/DML, statements to be prepared
"insert4": (INSERT_CQL_4_TEMPLATE, True),
"insert_time_series": (INSERT_TIME_SERIES_TEMPLATE, True),
"select": (SELECT_CQL_TEMPLATE, True),
# DDL, do not prepare these
"drop": (DROP_TABLE_CQL_TEMPLATE, False),
Expand Down Expand Up @@ -400,50 +407,73 @@ def on_failure(exc, concurrent_queue):
keyspace, project, table, table_name_version
)

insert_cql = self._get_cql_statement(
config,
"insert4",
fqtable=fqtable,
ttl=ttl,
session=session,
)
if isinstance(table, SortedFeatureView):
# Split the data into multiple batches, with each batch having the same entity key (partition key).
# NOTE: It is not a good practice to have data from multiple partitions in the same batch.
# Doing so can affect write latency and also data loss among other things.
entity_dict: Dict[str, List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]] = \
defaultdict(list[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]])
for row in data:
entity_key_bin = serialize_entity_key(
row[0],
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
entity_dict[entity_key_bin].append(row)

# Get the list of feature names from data to use in the insert query
feature_names = list(data[0][1].keys())
feature_names_str = ', '.join(feature_names)
params_str = ", ".join(["?"] * (len(feature_names)+2))

insert_cql = self._get_cql_statement(
config,
"insert_time_series",
fqtable=fqtable,
ttl=ttl,
session=session,
feature_names_str=feature_names_str,
params_str=params_str,
)

for entity_key, values, timestamp, created_ts in data:
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for feature_name, val in values.items():
params: Tuple[str, bytes, str, datetime] = (
feature_name,
val.SerializeToString(),
entity_key_bin,
timestamp,
)
batch.add(insert_cql, params)

# Wait until the rate limiter allows
if not rate_limiter.acquire():
while not rate_limiter.acquire():
time.sleep(0.001)

future = session.execute_async(batch)
concurrent_queue.put(future)
future.add_callbacks(
partial(
on_success,
concurrent_queue=concurrent_queue,
),
partial(
on_failure,
concurrent_queue=concurrent_queue,
),
# Write each batch with the same entity key into the online store
for entity_key_bin, batch_to_write in entity_dict.items():
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
for entity_key, feat_dict, timestamp, created_ts in batch_to_write:
feature_values = ()
for valProto in feat_dict.values():
feature_value = getattr(valProto, valProto.WhichOneof('val'))
feature_values += (feature_value,)

feature_values = feature_values + (entity_key_bin,timestamp)
batch.add(insert_cql, feature_values)

CassandraOnlineStore._apply_batch(rate_limiter, batch, progress, session, concurrent_queue, on_success, on_failure)
else:
insert_cql = self._get_cql_statement(
config,
"insert4",
fqtable=fqtable,
ttl=ttl,
session=session,
)

# this happens N-1 times, will be corrected outside:
if progress:
progress(1)
for entity_key, values, timestamp, created_ts in data:
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for feature_name, val in values.items():
params: Tuple[str, bytes, str, datetime] = (
feature_name,
val.SerializeToString(),
entity_key_bin,
timestamp,
)
batch.add(insert_cql, params)

CassandraOnlineStore._apply_batch(rate_limiter, batch, progress, session, concurrent_queue, on_success, on_failure)

# Wait for all tasks to complete
while not concurrent_queue.empty():
time.sleep(0.001)
Expand Down Expand Up @@ -724,10 +754,19 @@ def _get_cql_statement(
session = self._get_session(config)

template, prepare = CQL_TEMPLATE_MAP[op_name]
statement = template.format(
fqtable=fqtable,
**kwargs,
)
if op_name == "insert_time_series":
statement = template.format(
fqtable=fqtable,
feature_names=kwargs.get('feature_names_str'),
parameters=kwargs.get('params_str'),
**kwargs,
)
else:
statement = template.format(
fqtable=fqtable,
**kwargs,
)

if prepare:
# using the statement itself as key (no problem with that)
cache_key = statement
Expand All @@ -737,3 +776,35 @@ def _get_cql_statement(
return self._prepared_statements[cache_key]
else:
return statement

@staticmethod
def _apply_batch(
rate_limiter: SlidingWindowRateLimiter,
batch: BatchStatement,
progress: Optional[Callable[[int], Any]],
session: Session,
concurrent_queue: Queue,
on_success,
on_failure
):
# Wait until the rate limiter allows
if not rate_limiter.acquire():
while not rate_limiter.acquire():
time.sleep(0.001)

future = session.execute_async(batch)
concurrent_queue.put(future)
future.add_callbacks(
partial(
on_success,
concurrent_queue=concurrent_queue,
),
partial(
on_failure,
concurrent_queue=concurrent_queue,
),
)

# this happens N-1 times, will be corrected outside:
if progress:
progress(1)
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def create_online_store(self) -> Dict[str, object]:
"hosts": ["127.0.0.1"],
"port": exposed_port,
"keyspace": keyspace_name,
"container": self.container
}

def teardown(self):
Expand Down
Loading
Loading