-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* WIP: Haystack integration * Added integration for Haystack V2 * Missing commit * Haystack documentation fixes * Minor fix in haystack v2 integration notebook
- Loading branch information
Showing
4 changed files
with
1,142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
try:
    from haystack.nodes import PromptNode
except ImportError as e:
    # Chain the original exception so the real failure (e.g. a missing
    # transitive dependency inside haystack) is not masked by this message.
    raise ImportError('haystack is not installed. Please install it with "pip install farm-haystack"') from e
|
||
import enum | ||
from typing import Callable, Optional | ||
from lmformatenforcer import CharacterLevelParser | ||
|
||
class LMFormatEnforcerPromptNode(PromptNode):
    """A prompt node for the Haystack V1 API that activates the LMFormatEnforcer on the generated text."""

    class ModelType(enum.Enum):
        """Maps a supported invocation-layer class name to a model type."""
        HUGGINGFACE = 'HFLocalInvocationLayer'
        # VLLM = 'vLLMLocalInvocationLayer' TODO: After vLLM 0.22 is released, this will be possible

    def __init__(self, *args, character_level_parser: Optional[CharacterLevelParser] = None, **kwargs):
        """Create a new prompt node that activates the LMFormatEnforcer on the generated text.

        See the PromptNode documentation for all of the regular arguments.
        :param character_level_parser: A CharacterLevelParser that will be used to enforce
            the format of the generated text. If None, no format enforcement takes place.
        """
        super().__init__(*args, **kwargs)
        self.character_level_parser = character_level_parser
        self.model_type = self._resolve_model_type()
        self.token_enforcer_fn = self._prepare_token_enforcer_fn()

    def _prepare_token_enforcer_fn(self) -> Optional[Callable]:
        """Build the prefix_allowed_tokens_fn for the underlying model, or None when no parser is set."""
        if not self.character_level_parser:
            return None
        if self.model_type == LMFormatEnforcerPromptNode.ModelType.HUGGINGFACE:
            tokenizer = self.prompt_model.model_invocation_layer.pipe.tokenizer
            # Imported lazily so the transformers integration is only required
            # when a HuggingFace model is actually used.
            from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
            return build_transformers_prefix_allowed_tokens_fn(tokenizer, self.character_level_parser)
        raise NotImplementedError(f"Token enforcer not implemented for model type {self.model_type.name}")

    def _resolve_model_type(self) -> ModelType:
        """Detect the model type from the invocation layer's class name.

        :raises ValueError: if the invocation layer is not one of the supported types.
        """
        invocation_layer_name = self.prompt_model.model_invocation_layer.__class__.__name__
        try:
            return LMFormatEnforcerPromptNode.ModelType(invocation_layer_name)
        except ValueError:
            supported_strings = ",".join(str(t.name) for t in LMFormatEnforcerPromptNode.ModelType)
            raise ValueError(f"Unsupported invocation layer: {invocation_layer_name}. "
                             f"Must be one of {supported_strings}")

    def _prepare_model_kwargs(self):
        """Add the token enforcer to the generation kwargs passed to the underlying model."""
        model_kwargs = super()._prepare_model_kwargs()
        if self.token_enforcer_fn:
            if self.model_type == LMFormatEnforcerPromptNode.ModelType.HUGGINGFACE:
                # setdefault keeps any generation_kwargs the caller already supplied.
                model_kwargs.setdefault('generation_kwargs', {})['prefix_allowed_tokens_fn'] = self.token_enforcer_fn
        return model_kwargs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
try:
    from haystack.preview import component
    from canals import Component
except ImportError as e:
    # Fixed: the original message was missing the closing double quote after haystack-ai.
    # Chain the original exception so the real failure is not masked by this message.
    raise ImportError('haystack is not installed. Please install it with "pip install farm-haystack" or "pip install haystack-ai"') from e
|
||
import enum | ||
from typing import Any, Callable, Dict, List, Optional | ||
from lmformatenforcer import CharacterLevelParser | ||
|
||
|
||
class _ModelType(enum.Enum): | ||
HUGGINGFACE = 'HuggingFaceLocalGenerator' | ||
# VLLM = 'vLLMLocalInvocationLayer' TODO: Add this when vLLM has Haystack V2 support | ||
|
||
@component
class LMFormatEnforcerLocalGenerator:
    """A generator component for the Haystack V2 API that activates the LMFormatEnforcer on the generated text.

    It wraps a local generator, and should be added to the pipeline instead of it.
    """

    def __init__(self, model_component: Component, character_level_parser: Optional[CharacterLevelParser] = None):
        """Initialize the generator component.

        :param model_component: A local generator component to wrap.
        :param character_level_parser: A CharacterLevelParser that will be used to enforce
            the format of the generated text. If None, no format enforcement takes place.
        """
        self.model_component = model_component
        self.character_level_parser = character_level_parser
        self._model_type = self._resolve_model_type()
        # Built lazily in warm_up(), after the wrapped model's tokenizer is available.
        self.token_enforcer_fn: Optional[Callable] = None

    @component.output_types(replies=List[str])
    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        """Run the wrapped generator with format enforcement active.

        The enforcer is injected into the wrapped component's generation kwargs for the
        duration of this call only, and removed afterwards even if generation fails.
        :param prompt: The prompt to pass to the wrapped generator.
        :param generation_kwargs: Extra generation kwargs forwarded to the wrapped generator.
        """
        try:
            self._inject_enforcer_into_model()
            kwargs = {}
            if generation_kwargs:
                kwargs['generation_kwargs'] = generation_kwargs
            return self.model_component.run(prompt, **kwargs)
        finally:
            self._release_model_injection()

    def warm_up(self):
        """Warm up the wrapped component (if supported) and prepare the token enforcer function.

        Must be called before run() for format enforcement to take effect.
        """
        if hasattr(self.model_component, 'warm_up'):
            self.model_component.warm_up()
        self.token_enforcer_fn = self._prepare_token_enforcer_fn()

    def _prepare_token_enforcer_fn(self) -> Optional[Callable]:
        """Build the prefix_allowed_tokens_fn for the wrapped model, or None when no parser is set."""
        if not self.character_level_parser:
            return None
        if self._model_type == _ModelType.HUGGINGFACE:
            tokenizer = self.model_component.pipeline.tokenizer
            # Imported lazily so the transformers integration is only required
            # when a HuggingFace model is actually used.
            from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
            return build_transformers_prefix_allowed_tokens_fn(tokenizer, self.character_level_parser)
        raise NotImplementedError(f"Token enforcer not implemented for model type {self._model_type.name}")

    def _resolve_model_type(self) -> _ModelType:
        """Detect the model type from the wrapped component's class name.

        :raises ValueError: if the wrapped component is not one of the supported generators.
        """
        generator_component_name = self.model_component.__class__.__name__
        try:
            return _ModelType(generator_component_name)
        except ValueError:
            supported_strings = ",".join(str(t.name) for t in _ModelType)
            raise ValueError(f"Unsupported local generator component layer: {generator_component_name}. "
                             f"Must be one of {supported_strings}")

    def _inject_enforcer_into_model(self):
        """Temporarily add the enforcer to the wrapped component's generation kwargs."""
        if not self.token_enforcer_fn:
            return
        if self._model_type == _ModelType.HUGGINGFACE:
            self.model_component.generation_kwargs['prefix_allowed_tokens_fn'] = self.token_enforcer_fn

    def _release_model_injection(self):
        """Remove the enforcer from the wrapped component's generation kwargs.

        Uses pop() with a default so cleanup is safe even if the key was never
        injected (a bare del would raise KeyError in that case).
        """
        if not self.token_enforcer_fn:
            return
        if self._model_type == _ModelType.HUGGINGFACE:
            self.model_component.generation_kwargs.pop('prefix_allowed_tokens_fn', None)
Oops, something went wrong.