Skip to content

Commit

Permalink
checking in progress but this Pr still is not ready yet
Browse files Browse the repository at this point in the history
Signed-off-by: Francisco Javier Arceo <[email protected]>
  • Loading branch information
franciscojavierarceo committed Jan 25, 2025
1 parent dc2c1dc commit 7597c28
Show file tree
Hide file tree
Showing 5 changed files with 340 additions and 96 deletions.
153 changes: 117 additions & 36 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
from feast.ssl_ca_trust_store_setup import configure_ca_trust_store_env_variables
from feast.stream_feature_view import StreamFeatureView
from feast.utils import _utc_now
from feast.type_map import feast_value_type_to_python_type

warnings.simplefilter("once", DeprecationWarning)

Expand Down Expand Up @@ -1825,42 +1826,52 @@ def retrieve_online_documents(
requested_feature_view = available_feature_views[0]

provider = self._get_provider()
document_features = self._retrieve_from_online_store(
provider,
requested_feature_view,
requested_feature,
requested_features,
query,
top_k,
distance_metric,
)

# TODO currently not return the vector value since it is same as feature value, if embedding is supported,
# the feature value can be raw text before embedded
entity_key_vals = [feature[1] for feature in document_features]
join_key_values: Dict[str, List[ValueProto]] = {}
for entity_key_val in entity_key_vals:
if entity_key_val is not None:
for join_key, entity_value in zip(
entity_key_val.join_keys, entity_key_val.entity_values
):
if join_key not in join_key_values:
join_key_values[join_key] = []
join_key_values[join_key].append(entity_value)

document_feature_vals = [feature[4] for feature in document_features]
document_feature_distance_vals = [feature[5] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])
requested_feature = requested_feature or requested_features[0]
utils._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={
**join_key_values,
requested_feature: document_feature_vals,
"distance": document_feature_distance_vals,
},
)
return OnlineResponse(online_features_response)
if self.config.online_store.type != 'milvus':
document_features = self._retrieve_from_online_store(
provider,
requested_feature_view,
requested_feature,
requested_features,
query,
top_k,
distance_metric,
)
# TODO currently not return the vector value since it is same as feature value, if embedding is supported,
# the feature value can be raw text before embedded
entity_key_vals = [feature[1] for feature in document_features]
join_key_values: Dict[str, List[ValueProto]] = {}
for entity_key_val in entity_key_vals:
if entity_key_val is not None:
for join_key, entity_value in zip(
entity_key_val.join_keys, entity_key_val.entity_values
):
if join_key not in join_key_values:
join_key_values[join_key] = []
join_key_values[join_key].append(entity_value)

document_feature_vals = [feature[4] for feature in document_features]
document_feature_distance_vals = [feature[5] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])
requested_feature = requested_feature or requested_features[0]
utils._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={
**join_key_values,
requested_feature: document_feature_vals,
"distance": document_feature_distance_vals,
},
)
return OnlineResponse(online_features_response)
else:
return self._retrieve_from_online_store_v2(
provider,
requested_feature_view,
requested_feature,
requested_features,
query,
top_k,
distance_metric,
)

def _retrieve_from_online_store(
self,
Expand Down Expand Up @@ -1917,6 +1928,76 @@ def _retrieve_from_online_store(
)
return read_row_protos


def _retrieve_from_online_store_v2(
self,
provider: Provider,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
distance_metric: Optional[str],
) -> OnlineResponse:
"""
Search and return document features from the online document store.
"""
documents = provider.retrieve_online_documents(
config=self.config,
table=table,
requested_feature=requested_feature,
requested_features=requested_features,
query=query,
top_k=top_k,
distance_metric=distance_metric,
)

read_row_protos = []
entity_key_dict = {}
table_entity_values = []
for row_ts, entity_key, feature_dict in documents:
read_row_protos.append((
row_ts,
entity_key,
feature_dict,
))
if entity_key:
for key, value in zip(entity_key.join_keys, entity_key.entity_values):
python_value = value
if key not in entity_key_dict:
entity_key_dict[key] = []
entity_key_dict[key].append(python_value)

table_entity_values.append(
tuple(feast_value_type_to_python_type(e) for e in entity_key.entity_values)
)
table_entity_values, idxs = utils._get_unique_entities_from_values(
entity_key_dict,
)

datevals, entityvals, feature_dicts = [], [], []
for d, e, f in documents:
datevals.append(d)
entityvals.append(e)
feature_dicts.append(f)

feature_data = utils._convert_rows_to_protobuf(
requested_features=[requested_feature, 'distance'] if requested_feature else requested_features + ['distance'],
read_rows=list(zip(datevals, feature_dicts)),
)

online_features_response = GetOnlineFeaturesResponse(results=[])
utils._populate_response_from_feature_data(
feature_data=feature_data,
indexes=idxs,
online_features_response=online_features_response,
full_feature_names=False,
requested_features=requested_features + ['distance'],
table=table,
)

return OnlineResponse(online_features_response)

def serve(
self,
host: str,
Expand Down
135 changes: 95 additions & 40 deletions sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from feast.feature_view import FeatureView
from feast.infra.infra_object import InfraObject
from feast.infra.key_encoding_utils import (
deserialize_entity_key,
serialize_entity_key,
)
from feast.infra.online_stores.online_store import OnlineStore
Expand All @@ -24,7 +25,10 @@
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.type_map import PROTO_VALUE_TO_VALUE_TYPE_MAP
from feast.type_map import (
PROTO_VALUE_TO_VALUE_TYPE_MAP,
feast_value_type_to_python_type,
)
from feast.types import (
VALUE_TYPES_TO_FEAST_TYPES,
Array,
Expand All @@ -33,7 +37,6 @@
ValueType,
)
from feast.utils import (
_build_retrieve_online_document_record,
_serialize_vector_to_float_list,
to_naive_utc,
)
Expand Down Expand Up @@ -170,16 +173,14 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Dict[str, A
dim=config.online_store.embedding_dim,
)
)
elif dtype == DataType.VARCHAR:
else:
fields.append(
FieldSchema(
name=field.name,
dtype=dtype,
dtype=DataType.VARCHAR,
max_length=512,
)
)
else:
fields.append(FieldSchema(name=field.name, dtype=dtype))

schema = CollectionSchema(
fields=fields, description="Feast feature view data"
Expand Down Expand Up @@ -234,25 +235,35 @@ def online_write_batch(
) -> None:
self.client = self._connect(config)
collection = self._get_collection(config, table)
vector_cols = [f.name for f in table.features if f.vector_index]
entity_batch_to_insert = []
for entity_key, values_dict, timestamp, created_ts in data:
# need to construct the composite primary key also need to handle the fact that entities are a list
entity_key_str = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
# to recover the entity key just run:
# deserialize_entity_key(bytes.fromhex(entity_key_str), entity_key_serialization_version=3)
composite_key_name = (
"_".join([str(value) for value in entity_key.join_keys]) + "_pk"
)
timestamp_int = int(to_naive_utc(timestamp).timestamp() * 1e6)
created_ts_int = (
int(to_naive_utc(created_ts).timestamp() * 1e6) if created_ts else 0
)
values_dict = _extract_proto_values_to_dict(values_dict)
entity_dict = _extract_proto_values_to_dict(
dict(zip(entity_key.join_keys, entity_key.entity_values))
)
entity_dict = {
join_key: feast_value_type_to_python_type(value)
for join_key, value in zip(
entity_key.join_keys, entity_key.entity_values
)
}
values_dict.update(entity_dict)
values_dict = _extract_proto_values_to_dict(
values_dict,
vector_cols=vector_cols,
serialize_to_string=True,
)

single_entity_record = {
composite_key_name: entity_key_str,
Expand Down Expand Up @@ -329,11 +340,12 @@ def retrieve_online_documents(
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[Dict[str, ValueProto]],
]
]:
entity_name_feast_primitive_type_map = {
k.name: k.dtype for k in table.entity_columns
}
self.client = self._connect(config)
collection_name = _table_id(config.project, table)
collection = self._get_collection(config, table)
Expand Down Expand Up @@ -383,39 +395,67 @@ def retrieve_online_documents(
)

result_list = []
c = 0
for hits in results:
for hit in hits:
single_record = {}
c+=1
res = {}
res_ts = None
for field in output_fields:
single_record[field] = hit.get("entity", {}).get(field, None)

entity_key_bytes = bytes.fromhex(
hit.get("entity", {}).get(composite_key_name, None)
)
embedding = hit.get("entity", {}).get(ann_search_field)
serialized_embedding = _serialize_vector_to_float_list(embedding)
val = ValueProto()
if field in ["created_ts", "event_ts"]:
res_ts = datetime.fromtimestamp(
hit.get("entity", {}).get(field) / 1e6
)
elif field == ann_search_field:
serialized_embedding = _serialize_vector_to_float_list(
embedding
)
res[ann_search_field] = serialized_embedding
elif field == composite_key_name:
# In other approaches the entity keys are joined later
entity_key_bytes = bytes.fromhex(
hit.get("entity", {}).get(composite_key_name, None)
)
entity_key_proto = deserialize_entity_key(entity_key_bytes)
# res[field] = entity_key_proto
elif entity_name_feast_primitive_type_map.get(
field, PrimitiveFeastType.INVALID
) in [
PrimitiveFeastType.STRING,
PrimitiveFeastType.INT64,
PrimitiveFeastType.INT32,
PrimitiveFeastType.BYTES,
]:
res[field] = ValueProto(
string_val=hit.get("entity", {}).get(field, "")
)
else:
val.ParseFromString(
bytes(hit.get("entity", {}).get(field, b"").encode())
)
res[field] = val
distance = hit.get("distance", None)
event_ts = datetime.fromtimestamp(
hit.get("entity", {}).get("event_ts") / 1e6
res["distance"] = (
ValueProto(float_val=distance) if distance else ValueProto()
)
prepared_result = _build_retrieve_online_document_record(
entity_key_bytes,
# This may have a bug
serialized_embedding.SerializeToString(),
embedding,
distance,
event_ts,
config.entity_key_serialization_version,
)
result_list.append(prepared_result)
if not res:
result_list.append((None, None, None))
else:
result_list.append((res_ts, entity_key_proto, res))
print(f"{c} results found with k = {top_k}")
return result_list


def _table_id(project: str, table: FeatureView) -> str:
return f"{project}_{table.name}"


def _extract_proto_values_to_dict(input_dict: Dict[str, Any]) -> Dict[str, Any]:
def _extract_proto_values_to_dict(
input_dict: Dict[str, Any],
vector_cols: List[str],
serialize_to_string=False,
) -> Dict[str, Any]:
numeric_vector_list_types = [
k
for k in PROTO_VALUE_TO_VALUE_TYPE_MAP.keys()
Expand All @@ -424,12 +464,27 @@ def _extract_proto_values_to_dict(input_dict: Dict[str, Any]) -> Dict[str, Any]:
output_dict = {}
for feature_name, feature_values in input_dict.items():
for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP:
if feature_values.HasField(proto_val_type):
if proto_val_type in numeric_vector_list_types:
vector_values = getattr(feature_values, proto_val_type).val
else:
vector_values = getattr(feature_values, proto_val_type)
output_dict[feature_name] = vector_values
if not isinstance(feature_values, (int, float, str)):
if feature_values.HasField(proto_val_type):
if proto_val_type in numeric_vector_list_types:
if serialize_to_string and feature_name not in vector_cols:
vector_values = getattr(
feature_values, proto_val_type
).SerializeToString()
else:
vector_values = getattr(feature_values, proto_val_type).val
else:
if serialize_to_string:
vector_values = feature_values.SerializeToString().decode()
else:
vector_values = getattr(feature_values, proto_val_type)
output_dict[feature_name] = vector_values
else:
if serialize_to_string:
if not isinstance(feature_values, str):
feature_values = str(feature_values)
output_dict[feature_name] = feature_values

return output_dict


Expand Down
Loading

0 comments on commit 7597c28

Please sign in to comment.