Separate Indexing Logic from Inference #33

Draft · wants to merge 3 commits into main
153 changes: 82 additions & 71 deletions byaldi/RAGModel.py
@@ -4,39 +4,36 @@
from PIL import Image

from byaldi.colpali import ColPaliModel

from byaldi.indexing import IndexManager
from byaldi.objects import Result

# Optional langchain integration
try:
from byaldi.integrations import ByaldiLangChainRetriever
except ImportError:
pass


class RAGMultiModalModel:
"""
Wrapper class for a pretrained RAG multi-modal model, and all the associated utilities.
Wrapper class for a pretrained RAG multi-modal model, and an associated index manager.
Allows you to load a pretrained model from disk or from the hub, build or query an index.

## Usage

Load a pre-trained checkpoint:

```python
from byaldi import RAGMultiModalModel

RAG = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
```

Both methods will load a fully initialised instance of ColPali, which you can use to build and query indexes.

```python
RAG.search("How many people live in France?")
```
"""

model: Optional[ColPaliModel] = None
def __init__(
self,
model: Optional[ColPaliModel] = None,
index_root: str = ".byaldi",
device: str = "cuda",
verbose: int = 1,
):
self.model = model
self.index_manager = IndexManager(index_root=index_root, verbose=verbose)
self.device = device
self.verbose = verbose

@classmethod
def from_pretrained(
@@ -46,113 +43,120 @@ def from_pretrained(
device: str = "cuda",
verbose: int = 1,
):
"""Load a ColPali model from a pre-trained checkpoint.

"""
Load a ColPali model from a pre-trained checkpoint.
Parameters:
pretrained_model_name_or_path (str): Local path or huggingface model name.
device (str): The device to load the model on. Default is "cuda".

Returns:
cls (RAGMultiModalModel): The current instance of RAGMultiModalModel, with the model initialised.
"""
instance = cls()
instance.model = ColPaliModel.from_pretrained(
model = ColPaliModel.from_pretrained(
pretrained_model_name_or_path,
index_root=index_root,
device=device,
verbose=verbose,
)
return instance
return cls(model, index_root, device, verbose)

@classmethod
def from_index(
cls,
index_path: Union[str, Path],
index_name: str,
index_root: str = ".byaldi",
device: str = "cuda",
verbose: int = 1,
):
"""Load an Index and the associated ColPali model from an existing document index.
"""
Load an Index and the associated ColPali model from an existing document index.

Parameters:
index_path (Union[str, Path]): Path to the index.
index_name (str): Name of the index.
index_root (str): Path to the index root directory.
device (str): The device to load the model on. Default is "cuda".
verbose (int): Verbosity level.

Returns:
cls (RAGMultiModalModel): The current instance of RAGMultiModalModel, with the model and index initialised.
cls (RAGMultiModalModel): The current instance of RAGMultiModalModel, with the index and model loaded.
"""
instance = cls()
index_path = Path(index_path)
instance.model = ColPaliModel.from_index(
index_path, index_root=index_root, device=device, verbose=verbose
instance = cls(index_root=index_root, device=device, verbose=verbose)
instance.index_manager.load_index(index_name)
instance.model = ColPaliModel.from_pretrained(
instance.index_manager.model_name,
device=device,
verbose=verbose,
)

return instance

def index(
self,
input_path: Union[str, Path],
index_name: Optional[str] = None,
doc_ids: Optional[int] = None,
index_name: str,
store_collection_with_index: bool = False,
doc_ids: Optional[List[int]] = None,
metadata: Optional[List[Dict[str, Union[str, int]]]] = None,
overwrite: bool = False,
metadata: Optional[
Union[
Dict[Union[str, int], Dict[str, Union[str, int]]],
List[Dict[str, Union[str, int]]],
]
] = None,
max_image_width: Optional[int] = None,
max_image_height: Optional[int] = None,
**kwargs,
):
"""Build an index from input documents.
"""
Wrapper function to create and add to an index.

Parameters:
input_path (Union[str, Path]): Path to the input documents.
index_name (Optional[str]): The name of the index that will be built.
doc_ids (Optional[List[Union[str, int]]]): List of document IDs.
input_path (str, Path): Path to the input file or directory.
index_name (str): Name of the index.
store_collection_with_index (bool): Whether to store the collection with the index.
overwrite (bool): Whether to overwrite an existing index with the same name.
metadata (Optional[Union[Dict[Union[str, int], Dict[str, Union[str, int]]], List[Dict[str, Union[str, int]]]]]):
Metadata for the documents. Can be a dictionary mapping doc_ids to metadata dictionaries,
or a list of metadata dictionaries (one for each document).
doc_ids (List[int]): List of document IDs.
metadata (List[Dict[str, Union[str, int]]]): List of metadata dictionaries.
overwrite (bool): Whether to overwrite the existing index.
max_image_width (int): Maximum image width.
max_image_height (int): Maximum image height.
**kwargs: Additional keyword arguments.

Returns:
None
"""
return self.model.index(
input_path,
self.index_manager.create_index(
index_name,
doc_ids,
self.model.pretrained_model_name_or_path,
store_collection_with_index,
overwrite=overwrite,
metadata=metadata,
max_image_width=max_image_width,
max_image_height=max_image_height,
**kwargs,
overwrite,
max_image_width,
max_image_height,
)
return self.index_manager.add_to_index(
input_path,
self.model.encode_image,
doc_ids,
metadata,
)

def add_to_index(
self,
input_item: Union[str, Path, Image.Image],
store_collection_with_index: bool,
store_collection_with_index: bool = False,
doc_id: Optional[int] = None,
metadata: Optional[Dict[str, Union[str, int]]] = None,
):
"""Add an item to an existing index.
"""
Wrapper function to add to an existing index.

Parameters:
input_item (Union[str, Path, Image.Image]): The item to add to the index.
input_item (str, Path, Image.Image): Input file or directory.
store_collection_with_index (bool): Whether to store the collection with the index.
doc_id (Union[str, int]): The document ID for the item being added.
metadata (Optional[Dict[str, Union[str, int]]]): Metadata for the document being added.
doc_id (int): Document ID.
metadata (Dict[str, Union[str, int]]): Metadata dictionary.

Returns:
None
"""
return self.model.add_to_index(
input_item, store_collection_with_index, doc_id, metadata=metadata

return self.index_manager.add_to_index(
input_item,
self.model.encode_image,
doc_id,
metadata,
store_collection_with_index,
)

def search(
@@ -161,20 +165,27 @@ def search(
k: int = 10,
return_base64_results: Optional[bool] = None,
) -> Union[List[Result], List[List[Result]]]:
"""Query an index.
"""
Search the index for the given query.

Parameters:
query (Union[str, List[str]]): The query or queries to search for.
k (int): The number of results to return. Default is 10.
return_base64_results (Optional[bool]): Whether to return base64-encoded image results.
query (str, List[str]): Query string or list of query strings.
k (int): Number of results to return.
return_base64_results (bool): Whether to return base64 encoded results.

Returns:
Union[List[Result], List[List[Result]]]: A list of Result objects or a list of lists of Result objects.
Union[List[Result], List[List[Result]]]: List of Result objects or list of lists of Result objects.
"""
return self.model.search(query, k, return_base64_results)
return self.index_manager.search(
query,
self.model.score,
k,
return_base64_results,
)

def get_doc_ids_to_file_names(self):
return self.model.get_doc_ids_to_file_names()
return self.index_manager.get_doc_ids_to_file_names()

def as_langchain_retriever(self, **kwargs: Any):
return ByaldiLangChainRetriever(model=self, kwargs=kwargs)
from byaldi.integrations import ByaldiLangChainRetriever
return ByaldiLangChainRetriever(model=self, kwargs=kwargs)
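
For reference, below is a minimal usage sketch of the refactored flow as implied by the call sites in this diff. The IndexManager surface (create_index, add_to_index, load_index, search, model_name) and the model hooks (encode_image, score) are read off the wrapper code above and should be treated as assumptions rather than a confirmed public interface; the input path and index name are hypothetical.

```python
# Minimal usage sketch, assuming the IndexManager interface implied by the
# call sites in RAGModel.py above (not a confirmed public API).
from byaldi import RAGMultiModalModel

# The wrapper now constructs the inference model (ColPaliModel) and the
# IndexManager separately and wires them together in __init__.
RAG = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")

# index() first calls IndexManager.create_index, then IndexManager.add_to_index,
# passing the model's encode_image method as the embedding callback.
RAG.index(
    input_path="docs/report.pdf",  # hypothetical input document
    index_name="demo_index",       # hypothetical index name
    store_collection_with_index=False,
    overwrite=True,
)

# search() delegates retrieval to IndexManager.search, handing it the model's
# score method to rank candidates.
results = RAG.search("How many people live in France?", k=5)

# Reloading later goes through IndexManager.load_index, which records the
# model name (index_manager.model_name) used to re-instantiate ColPali.
RAG = RAGMultiModalModel.from_index("demo_index")
```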