diff --git a/.github/workflows/connector-tests.yml b/.github/workflows/connector-tests.yml index e24761f4..5b4da7ba 100644 --- a/.github/workflows/connector-tests.yml +++ b/.github/workflows/connector-tests.yml @@ -75,9 +75,12 @@ jobs: # Chargebee is being flaky: # - connector: source-chargebee # cdk_extra: n/a - # These two are behind in CDK updates and can't be used as tests until they are updated: + # This is behind in CDK updates and can't be used in tests until updated: # - connector: source-s3 # cdk_extra: file-based + - connector: destination-pgvector + cdk_extra: vector-db-based + # Can bring back Pinecone for testing once it is updated to latest CDK version: # - connector: destination-pinecone # cdk_extra: vector-db-based - connector: destination-motherduck diff --git a/airbyte_cdk/destinations/vector_db_based/config.py b/airbyte_cdk/destinations/vector_db_based/config.py index c7c40eca..7102c6ea 100644 --- a/airbyte_cdk/destinations/vector_db_based/config.py +++ b/airbyte_cdk/destinations/vector_db_based/config.py @@ -109,6 +109,12 @@ class ProcessingConfigModel(BaseModel): always_show=True, examples=["text", "user.name", "users.*.name"], ) + omit_field_names_from_embeddings: bool = Field( + default=False, + title="Omit field names from embeddings", + description="Do not include the field names in the text that gets embedded. By default field names are embedded (e.g., 'user.name: John Doe \n user.email: john@example.com'). If set to true, only the values are embedded (e.g., 'John Doe \n john@example.com').", + always_show=True, + ) metadata_fields: Optional[List[str]] = Field( default=[], title="Fields to store as metadata", diff --git a/airbyte_cdk/destinations/vector_db_based/document_processor.py b/airbyte_cdk/destinations/vector_db_based/document_processor.py index c007bf9e..4dd6eea3 100644 --- a/airbyte_cdk/destinations/vector_db_based/document_processor.py +++ b/airbyte_cdk/destinations/vector_db_based/document_processor.py @@ -5,7 +5,7 @@ import json import logging from dataclasses import dataclass -from typing import Any, Dict, List, Mapping, Optional, Tuple +from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import dpath from langchain.text_splitter import Language, RecursiveCharacterTextSplitter @@ -126,6 +126,7 @@ def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCata self.text_fields = config.text_fields self.metadata_fields = config.metadata_fields self.field_name_mappings = config.field_name_mappings + self.omit_field_names_from_embeddings = config.omit_field_names_from_embeddings self.logger = logging.getLogger("airbyte.document_processor") def process(self, record: AirbyteRecordMessage) -> Tuple[List[Chunk], Optional[str]]: @@ -163,10 +164,28 @@ def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document] relevant_fields = self._extract_relevant_fields(record, self.text_fields) if len(relevant_fields) == 0: return None - text = stringify_dict(relevant_fields) + text = self._generate_text_from_fields(relevant_fields) metadata = self._extract_metadata(record) return Document(page_content=text, metadata=metadata) + def _generate_text_from_fields(self, fields: Dict[str, Any]) -> str: + if self.omit_field_names_from_embeddings: + return self._extract_values_from_dict(fields) + else: + return stringify_dict(fields) + + def _extract_values_from_dict( + self, data: Union[Dict[Any, Any], List[Any], Any], join_char: str = "\n" + ) -> str: + if data is None: + return "" + elif isinstance(data, dict): + return join_char.join(self._extract_values_from_dict(value) for value in data.values()) + elif isinstance(data, list): + return join_char.join(self._extract_values_from_dict(item) for item in data) + else: + return str(data) + def _extract_relevant_fields( self, record: AirbyteRecordMessage, fields: Optional[List[str]] ) -> Dict[str, Any]: diff --git a/unit_tests/destinations/vector_db_based/config_test.py b/unit_tests/destinations/vector_db_based/config_test.py index ea6f446b..45d963f1 100644 --- a/unit_tests/destinations/vector_db_based/config_test.py +++ b/unit_tests/destinations/vector_db_based/config_test.py @@ -243,6 +243,13 @@ def test_json_schema_generation(): "type": "array", "items": {"type": "string"}, }, + "omit_field_names_from_embeddings": { + "title": "Omit field names from embeddings", + "description": "Do not include the field names in the text that gets embedded. By default field names are embedded (e.g., 'user.name: John Doe \n user.email: john@example.com'). If set to true, only the values are embedded (e.g., 'John Doe \n john@example.com').", + "default": False, + "always_show": True, + "type": "boolean", + }, "metadata_fields": { "title": "Fields to store as metadata", "description": "List of fields in the record that should be stored as metadata. The field list is applied to all streams in the same way and non-existing fields are ignored. If none are defined, all fields are considered metadata fields. When specifying text fields, you can access nested fields in the record by using dot notation, e.g. `user.name` will access the `name` field in the `user` object. It's also possible to use wildcards to access all fields in an object, e.g. `users.*.name` will access all `names` fields in all entries of the `users` array. When specifying nested paths, all matching values are flattened into an array set to a field named by the path.", diff --git a/unit_tests/destinations/vector_db_based/document_processor_test.py b/unit_tests/destinations/vector_db_based/document_processor_test.py index ede88921..84e2b924 100644 --- a/unit_tests/destinations/vector_db_based/document_processor_test.py +++ b/unit_tests/destinations/vector_db_based/document_processor_test.py @@ -194,6 +194,7 @@ def test_complex_text_fields(): "non.*.existing", ] processor.metadata_fields = ["non_text", "non_text_2", "id"] + processor.omit_field_names_from_embeddings = False chunks, _ = processor.process(record) @@ -214,6 +215,63 @@ def test_complex_text_fields(): } +def test_complex_text_fields_omit_field_names(): + processor = initialize_processor() + + record = AirbyteRecordMessage( + stream="stream1", + namespace="namespace1", + data={ + "id": 1, + "nested": { + "texts": [ + {"text": "This is the text"}, + {"text": "And another"}, + ] + }, + "non_text": "a", + "non_text_2": 1, + "text": "This is the regular text", + "other_nested": {"non_text": {"a": "xyz", "b": "abc"}}, + "empty_list": [], + "empty_dict": {}, + "large_nested": {"a": {"b": {"c": {"d": {"e": {"f": {"g": "h"}}}}}}}, + }, + emitted_at=1234, + ) + + processor.text_fields = [ + "nested.texts.*.text", + "text", + "other_nested.non_text", + "non.*.existing", + "large_nested", + "empty_list", + "empty_dict", + ] + processor.metadata_fields = ["non_text", "non_text_2", "id"] + processor.omit_field_names_from_embeddings = True + + chunks, _ = processor.process(record) + + assert len(chunks) == 1 + assert ( + chunks[0].page_content + == """This is the text +And another +This is the regular text +xyz +abc +h""" + ) + assert chunks[0].metadata == { + "id": 1, + "non_text": "a", + "non_text_2": 1, + "_ab_stream": "namespace1_stream1", + } + + def test_no_text_fields(): processor = initialize_processor()