diff --git a/src/python/txtai/vectors/factory.py b/src/python/txtai/vectors/factory.py
index de07575b7..7abdd8755 100644
--- a/src/python/txtai/vectors/factory.py
+++ b/src/python/txtai/vectors/factory.py
@@ -8,7 +8,7 @@
 from .huggingface import HFVectors
 from .litellm import LiteLLM
 from .llama import LlamaCpp
-from .words import WordVectors, WORDS
+from .words import WordVectors
 
 
 class VectorsFactory:
@@ -47,13 +47,6 @@ def create(config, scoring=None, models=None):
 
         # Word vectors
         if method == "words":
-            if not WORDS:
-                # Raise error if trying to create Word Vectors without vectors extra
-                raise ImportError(
-                    'Word vector models are not available - install "vectors" extra to enable. Otherwise, specify '
-                    + 'method="transformers" to use transformer backed models'
-                )
-
             return WordVectors(config, scoring, models)
 
         # Transformers vectors
diff --git a/src/python/txtai/vectors/litellm.py b/src/python/txtai/vectors/litellm.py
index d7d414d21..e1ecbe390 100644
--- a/src/python/txtai/vectors/litellm.py
+++ b/src/python/txtai/vectors/litellm.py
@@ -48,11 +48,12 @@ def ismodel(path):
         return False
 
     def __init__(self, config, scoring, models):
-        super().__init__(config, scoring, models)
-
+        # Check before parent constructor since it calls loadmodel
         if not LITELLM:
             raise ImportError('LiteLLM is not available - install "vectors" extra to enable')
 
+        super().__init__(config, scoring, models)
+
     def loadmodel(self, path):
         return None
 
diff --git a/src/python/txtai/vectors/words.py b/src/python/txtai/vectors/words.py
index d97514900..46f5d9ed7 100644
--- a/src/python/txtai/vectors/words.py
+++ b/src/python/txtai/vectors/words.py
@@ -66,6 +66,17 @@ class WordVectors(Vectors):
     Builds vectors using weighted word embeddings.
     """
 
+    def __init__(self, config, scoring, models):
+        # Check before parent constructor since it calls loadmodel
+        if not WORDS:
+            # Raise error if trying to create Word Vectors without vectors extra
+            raise ImportError(
+                'Word vector models are not available - install "vectors" extra to enable. Otherwise, specify '
+                + 'method="transformers" to use transformer backed models'
+            )
+
+        super().__init__(config, scoring, models)
+
     def loadmodel(self, path):
         # Ensure that vector path exists
         if not path or not os.path.isfile(path):
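
Not part of the patch: a minimal usage sketch of the behavior this change keeps in place, assuming the txtai.vectors package layout shown above; the path value is a placeholder. The ImportError is now raised by the WordVectors constructor itself, before the parent constructor attempts to load the model, so the same message appears whether the instance is built through VectorsFactory or instantiated directly.

from txtai.vectors import VectorsFactory

try:
    # Placeholder path; with the "vectors" extra missing, WordVectors.__init__
    # raises ImportError before the parent constructor calls loadmodel
    vectors = VectorsFactory.create({"method": "words", "path": "vectors.magnitude"})
except ImportError as error:
    print(error)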