Skip to content

Commit

Permalink
add store_embed task to ingestor (#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 authored Nov 27, 2024
1 parent 93b0781 commit d9b2fc5
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 14 deletions.
24 changes: 23 additions & 1 deletion client/src/nv_ingest_client/client/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nv_ingest_client.primitives.tasks import SplitTask
from nv_ingest_client.primitives.tasks import StoreTask
from nv_ingest_client.primitives.tasks import VdbUploadTask
from nv_ingest_client.primitives.tasks import StoreEmbedTask
from nv_ingest_client.primitives.tasks.chart_extraction import ChartExtractionTask
from nv_ingest_client.primitives.tasks.table_extraction import TableExtractionTask
from nv_ingest_client.util.util import filter_function_kwargs
Expand Down Expand Up @@ -301,7 +302,8 @@ def all_tasks(self) -> "Ingestor":
.dedup() \
.filter() \
.split() \
.embed()
.embed() \
.store_embed()
# .store() \
# .vdb_upload()
# fmt: on
Expand Down Expand Up @@ -438,6 +440,26 @@ def store(self, **kwargs: Any) -> "Ingestor":

return self

@ensure_job_specs
def store_embed(self, **kwargs: Any) -> "Ingestor":
"""
Adds a StoreTask to the batch job specification.
Parameters
----------
kwargs : dict
Parameters specific to the StoreTask.
Returns
-------
Ingestor
Returns self for chaining.
"""
store_task = StoreEmbedTask(**kwargs)
self._job_specs.add_task(store_task)

return self

@ensure_job_specs
def vdb_upload(self, **kwargs: Any) -> "Ingestor":
"""
Expand Down
2 changes: 1 addition & 1 deletion client/src/nv_ingest_client/primitives/tasks/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(
"""
super().__init__()

self._params = params
self._params = params or {}
self._extra_params = extra_params

def __str__(self) -> str:
Expand Down
7 changes: 5 additions & 2 deletions client/src/nv_ingest_client/primitives/tasks/vdb_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

class VdbUploadTaskSchema(BaseModel):
filter_errors: bool = False
bulk_ingest: bool = False,
bulk_ingest_path: str = None,
params: dict = None

class Config:
extra = "forbid"
Expand All @@ -32,7 +35,7 @@ def __init__(
self,
filter_errors: bool = False,
bulk_ingest: bool = False,
bulk_ingest_path: str = None,
bulk_ingest_path: str = "embeddings/",
params: dict = None
) -> None:
"""
Expand All @@ -42,7 +45,7 @@ def __init__(
self._filter_errors = filter_errors
self._bulk_ingest = bulk_ingest
self._bulk_ingest_path = bulk_ingest_path
self._params = params
self._params = params or {}

def __str__(self) -> str:
"""
Expand Down
17 changes: 8 additions & 9 deletions src/nv_ingest/stages/storages/embedding_storage_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,18 @@ def upload_embeddings(df: pd.DataFrame, params: Dict[str, Any]) -> pd.DataFrame:
for idx, row in df.iterrows():
uu_id = row["uuid"]
metadata = row["metadata"].copy()

metadata["source_metadata"]["source_location"] = bucket_path

if row["document_type"] == ContentTypeEnum.EMBEDDING:
logger.debug("Storing embedding data to Minio")
metadata["embedding_metadata"][
"uploaded_embedding_url"
] = bucket_path
metadata["embedding_metadata"] = {}
metadata["embedding_metadata"]["uploaded_embedding_url"] = bucket_path
doc_type = row["document_type"]
content_replace = doc_type in [ContentTypeEnum.IMAGE, ContentTypeEnum.STRUCTURED]
location = metadata["source_metadata"]["source_location"]
content = metadata["content"]
# TODO: validate metadata before putting it back in.
if metadata["embedding"] is not None:
logger.error(f"row type: {doc_type} - {location} - {len(content)}")
df.at[idx, "metadata"] = metadata
writer.append_row({
"text": metadata["content"],
"text": location if content_replace else content,
"source": metadata["source_metadata"],
"content_metadata": metadata["content_metadata"],
"vector": metadata["embedding"]}
Expand Down
17 changes: 16 additions & 1 deletion tests/nv_ingest_client/client/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nv_ingest_client.primitives.tasks import FilterTask
from nv_ingest_client.primitives.tasks import SplitTask
from nv_ingest_client.primitives.tasks import StoreTask
from nv_ingest_client.primitives.tasks import StoreEmbedTask
from nv_ingest_client.primitives.tasks import TableExtractionTask
from nv_ingest_client.primitives.tasks import VdbUploadTask

Expand Down Expand Up @@ -174,6 +175,20 @@ def test_store_task_some_args(ingestor):
assert task._store_method == "s3"


def test_store_embed_task_no_args(ingestor):
ingestor.store_embed()

assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[0], StoreEmbedTask)


def test_store_task_some_args(ingestor):
ingestor.store_embed(params={"extra_param": "extra"})

task = ingestor._job_specs.job_specs["pdf"][0]._tasks[0]
assert isinstance(task, StoreEmbedTask)
assert task._params["extra_param"] == "extra"


def test_vdb_upload_task_no_args(ingestor):
ingestor.vdb_upload()

Expand Down Expand Up @@ -315,7 +330,7 @@ def test_files_with_remote_files(ingestor_without_doc):
def test_all_tasks_adds_default_tasks(ingestor):
ingestor.all_tasks()

task_classes = {ExtractTask, DedupTask, FilterTask, SplitTask, EmbedTask}
task_classes = {ExtractTask, DedupTask, FilterTask, SplitTask, EmbedTask, StoreEmbedTask}
added_tasks = {
type(task) for job_specs in ingestor._job_specs._file_type_to_job_spec.values() for task in job_specs[0]._tasks
}
Expand Down

0 comments on commit d9b2fc5

Please sign in to comment.