Add batch size to reduce memory load when writing to Elasticsearch #1339

Closed
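For context, a minimal usage sketch of the new parameter. The host URL, index name, and document contents below are placeholders for illustration, not part of this PR:

from haystack import Document
from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore

# Hypothetical local store; any reachable Elasticsearch cluster works.
document_store = ElasticsearchDocumentStore(hosts="http://localhost:9200", index="demo")

docs = [Document(id=f"{i}", content=f"document {i}") for i in range(50_000)]

# With this change, a bulk request is sent every batch_size documents
# instead of serializing all 50k actions into a single request.
written = document_store.write_documents(docs, batch_size=5_000)
assert written == len(docs)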
@@ -233,12 +233,15 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
         documents = self._search_documents(query=query)
         return documents

-    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
+    def write_documents(
+        self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE, batch_size: int = 1000
+    ) -> int:
         """
         Writes `Document`s to Elasticsearch.

         :param documents: List of Documents to write to the document store.
         :param policy: DuplicatePolicy to apply when a document with the same ID already exists in the document store.
+        :param batch_size: Number of documents to write to Elasticsearch in each bulk request.
         :raises ValueError: If `documents` is not a list of `Document`s.
         :raises DuplicateDocumentError: If a document with the same ID already exists in the document store and
             `policy` is set to `DuplicatePolicy.FAIL` or `DuplicatePolicy.NONE`.
@@ -254,56 +257,58 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
             policy = DuplicatePolicy.FAIL

         action = "index" if policy == DuplicatePolicy.OVERWRITE else "create"

-        elasticsearch_actions = []
-        for doc in documents:
-            doc_dict = doc.to_dict()
-            if "sparse_embedding" in doc_dict:
-                sparse_embedding = doc_dict.pop("sparse_embedding", None)
-                if sparse_embedding:
-                    logger.warning(
-                        "Document %s has the `sparse_embedding` field set, "
-                        "but storing sparse embeddings in Elasticsearch is not currently supported. "
-                        "The `sparse_embedding` field will be ignored.",
-                        doc.id,
-                    )
-            elasticsearch_actions.append(
-                {
-                    "_op_type": action,
-                    "_id": doc.id,
-                    "_source": doc_dict,
-                }
-            )
-
-        documents_written, errors = helpers.bulk(
-            client=self.client,
-            actions=elasticsearch_actions,
-            refresh="wait_for",
-            index=self._index,
-            raise_on_error=False,
-        )
-
-        if errors:
-            duplicate_errors_ids = []
-            other_errors = []
-            for e in errors:
-                error_type = e["create"]["error"]["type"]
-                if policy == DuplicatePolicy.FAIL and error_type == "version_conflict_engine_exception":
-                    duplicate_errors_ids.append(e["create"]["_id"])
-                elif policy == DuplicatePolicy.SKIP and error_type == "version_conflict_engine_exception":
-                    # when the policy is SKIP, duplicate errors are expected and we should not raise an exception
-                    continue
-                else:
-                    other_errors.append(e)
-
-            if len(duplicate_errors_ids) > 0:
-                msg = f"IDs '{', '.join(duplicate_errors_ids)}' already exist in the document store."
-                raise DuplicateDocumentError(msg)
-
-            if len(other_errors) > 0:
-                msg = f"Failed to write documents to Elasticsearch. Errors:\n{other_errors}"
-                raise DocumentStoreError(msg)
-
+        documents_written = 0
+        for i in range(0, len(documents), batch_size):
+            batched_documents = documents[i : i + batch_size]
+            elasticsearch_actions = []
+            for doc in batched_documents:
+                doc_dict = doc.to_dict()
+                if "sparse_embedding" in doc_dict:
+                    sparse_embedding = doc_dict.pop("sparse_embedding", None)
+                    if sparse_embedding:
+                        logger.warning(
+                            "Document %s has the `sparse_embedding` field set, "
+                            "but storing sparse embeddings in Elasticsearch is not currently supported. "
+                            "The `sparse_embedding` field will be ignored.",
+                            doc.id,
+                        )
+                elasticsearch_actions.append(
+                    {
+                        "_op_type": action,
+                        "_id": doc.id,
+                        "_source": doc_dict,
+                    }
+                )
+
+            batched_documents_written, errors = helpers.bulk(
+                client=self.client,
+                actions=elasticsearch_actions,
+                refresh="wait_for",
+                index=self._index,
+                raise_on_error=False,
+            )
+
+            if errors:
+                duplicate_errors_ids = []
+                other_errors = []
+                for e in errors:
+                    error_type = e["create"]["error"]["type"]
+                    if policy == DuplicatePolicy.FAIL and error_type == "version_conflict_engine_exception":
+                        duplicate_errors_ids.append(e["create"]["_id"])
+                    elif policy == DuplicatePolicy.SKIP and error_type == "version_conflict_engine_exception":
+                        # when the policy is SKIP, duplicate errors are expected and we should not raise an exception
+                        continue
+                    else:
+                        other_errors.append(e)
+
+                if len(duplicate_errors_ids) > 0:
+                    msg = f"IDs '{', '.join(duplicate_errors_ids)}' already exist in the document store."
+                    raise DuplicateDocumentError(msg)
+
+                if len(other_errors) > 0:
+                    msg = f"Failed to write documents to Elasticsearch. Errors:\n{other_errors}"
+                    raise DocumentStoreError(msg)
+            documents_written += batched_documents_written
         return documents_written

     @staticmethod
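The heart of the change is the standard fixed-size chunking idiom, `range(0, len(documents), batch_size)` plus slicing. A self-contained sketch of that pattern (the names `batched` and `items` are illustrative, not from the PR):

from typing import Iterator, List, TypeVar

T = TypeVar("T")

def batched(items: List[T], batch_size: int) -> Iterator[List[T]]:
    """Yield successive slices of at most batch_size items each."""
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

# 12 items with batch_size=5 yield batches of 5, 5, and 2.
sizes = [len(b) for b in batched(list(range(12)), batch_size=5)]
assert sizes == [5, 5, 2]

Because each call to `helpers.bulk` now receives at most `batch_size` actions, peak memory for the serialized request scales with the batch size rather than with the full document list, at the cost of one HTTP round trip per batch.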
6 changes: 4 additions & 2 deletions integrations/elasticsearch/tests/test_document_store.py
@@ -125,8 +125,10 @@ def test_user_agent_header(self, document_store: ElasticsearchDocumentStore):
         assert document_store.client._headers["user-agent"].startswith("haystack-py-ds/")

     def test_write_documents(self, document_store: ElasticsearchDocumentStore):
-        docs = [Document(id="1")]
-        assert document_store.write_documents(docs) == 1
+        num_docs = 50000
+        batch_size = 5000
+        docs = [Document(id=f"{i}") for i in range(num_docs)]
+        assert document_store.write_documents(docs, batch_size=batch_size) == num_docs
         with pytest.raises(DuplicateDocumentError):
             document_store.write_documents(docs, DuplicatePolicy.FAIL)

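As a reading aid for the policies this test exercises, here is a sketch of how a second write of the same IDs is expected to behave, given the error handling above (hypothetical setup, small doc count for brevity):

from haystack import Document
from haystack.document_stores.types import DuplicatePolicy
from haystack.document_stores.errors import DuplicateDocumentError
from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore

document_store = ElasticsearchDocumentStore(hosts="http://localhost:9200", index="demo")
docs = [Document(id=f"{i}") for i in range(10)]
document_store.write_documents(docs)  # first write: all IDs are new

try:
    document_store.write_documents(docs, policy=DuplicatePolicy.FAIL)
except DuplicateDocumentError:
    pass  # the `create` op hits version_conflict_engine_exception per duplicate ID

# SKIP swallows the same version conflicts instead of raising.
document_store.write_documents(docs, policy=DuplicatePolicy.SKIP)

# OVERWRITE switches the bulk op type to `index`, which upserts existing IDs.
document_store.write_documents(docs, policy=DuplicatePolicy.OVERWRITE)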